# Access WeakSet through the weakref module.
# This code is kept in a separate module because abc.py needs it
# to load everything else at startup.
4
5 from _weakref import ref
6 from types import GenericAlias
7
# Public names exported by this module.
__all__ = ["WeakSet"]
9
10
class _IterationGuard:
    """Context manager marking a weak container as "currently iterated".

    On entry the guard adds itself to the container's ``_iterating`` set so
    that removals can be delayed; on exit it takes itself out again and,
    once no guard remains, asks the container to ``_commit_removals()``.
    This technique should be relatively thread-safe (since sets are).
    """

    def __init__(self, weakcontainer):
        # Only a weak reference to the container is kept, so the guard
        # neither keeps it alive nor forms a reference cycle with it.
        self.weakcontainer = ref(weakcontainer)

    def __enter__(self):
        container = self.weakcontainer()
        if container is not None:
            container._iterating.add(self)
        return self

    def __exit__(self, e, t, b):
        container = self.weakcontainer()
        if container is None:
            # The container is already gone; nothing to unregister.
            return
        active = container._iterating
        active.remove(self)
        if not active:
            # Last guard out: apply the removals postponed during iteration.
            container._commit_removals()
34
35
class WeakSet:
    """Set-like container holding only weak references to its elements.

    Elements are stored in ``self.data`` as weak references.  A weakref
    callback drops an entry once its referent is garbage collected.  While
    an iteration is running, that removal is queued in
    ``_pending_removals`` and applied when the iteration finishes.
    Adding an object that cannot be weakly referenced raises ``TypeError``.
    """

    def __init__(self, data=None):
        self.data = set()

        # Weakref callback that drops a dead reference from ``data``.  While
        # ``__iter__`` is looping over ``data`` the removal is queued instead,
        # because ``data`` must not change size during that loop.  The set is
        # reached through a weak reference (default argument), so storing the
        # callback on the instance does not create a reference cycle.
        def _remove(item, selfref=ref(self)):
            self = selfref()
            if self is not None:
                if self._iterating:
                    self._pending_removals.append(item)
                else:
                    self.data.discard(item)
        self._remove = _remove
        # Weak references whose removal was deferred during iteration.
        self._pending_removals = []
        # Active _IterationGuard objects; non-empty while iterating.
        self._iterating = set()
        if data is not None:
            self.update(data)

    def _commit_removals(self):
        """Discard from ``data`` every reference queued for removal."""
        pop = self._pending_removals.pop
        discard = self.data.discard
        while True:
            try:
                item = pop()
            except IndexError:
                return
            discard(item)

    def __iter__(self):
        """Yield the live elements, skipping references that have died."""
        with _IterationGuard(self):
            for itemref in self.data:
                item = itemref()
                if item is not None:
                    # Caveat: the iterator will keep a strong reference to
                    # `item` until it is resumed or closed.
                    yield item

    def __len__(self):
        # Entries queued for removal are still in ``data`` but no longer count.
        return len(self.data) - len(self._pending_removals)

    def __contains__(self, item):
        """Membership test; objects that cannot be weakly referenced are
        reported as absent instead of raising."""
        try:
            wr = ref(item)
        except TypeError:
            return False
        return wr in self.data

    def __reduce__(self):
        # Pickle as: class, a list of the currently live elements, state.
        return self.__class__, (list(self),), self.__getstate__()

    def add(self, item):
        """Add *item*, held weakly with the removal callback attached."""
        if self._pending_removals:
            self._commit_removals()
        self.data.add(ref(item, self._remove))

    def clear(self):
        """Remove every element."""
        if self._pending_removals:
            self._commit_removals()
        self.data.clear()

    def copy(self):
        """Return a new set of the same class with the same live elements."""
        return self.__class__(self)

    def pop(self):
        """Remove and return an arbitrary live element.

        Raises:
            KeyError: if no live element remains.
        """
        if self._pending_removals:
            self._commit_removals()
        while True:
            try:
                itemref = self.data.pop()
            except KeyError:
                raise KeyError('pop from empty WeakSet') from None
            item = itemref()
            if item is not None:
                return item

    def remove(self, item):
        """Remove *item*; raises ``KeyError`` if it is not present."""
        if self._pending_removals:
            self._commit_removals()
        self.data.remove(ref(item))

    def discard(self, item):
        """Remove *item* if present."""
        if self._pending_removals:
            self._commit_removals()
        self.data.discard(ref(item))

    def update(self, other):
        """Add every element of the iterable *other*."""
        if self._pending_removals:
            self._commit_removals()
        for element in other:
            self.add(element)

    def __ior__(self, other):
        self.update(other)
        return self

    def difference(self, other):
        """Return a new set with the elements not in *other*."""
        newset = self.copy()
        newset.difference_update(other)
        return newset
    __sub__ = difference

    def difference_update(self, other):
        """Remove every element of *other* from this set."""
        self.__isub__(other)
    def __isub__(self, other):
        if self._pending_removals:
            self._commit_removals()
        if self is other:
            self.data.clear()
        else:
            self.data.difference_update(ref(item) for item in other)
        return self

    def intersection(self, other):
        """Return a new set with the elements of *other* also in this set."""
        return self.__class__(item for item in other if item in self)
    __and__ = intersection

    def intersection_update(self, other):
        """Keep only the elements that are also in *other*."""
        self.__iand__(other)
    def __iand__(self, other):
        if self._pending_removals:
            self._commit_removals()
        # Hold the items strongly while comparing, so equal objects that
        # exist only because `other` produced them cannot die mid-operation.
        items = list(other)
        wanted = {ref(item) for item in items}
        # Drop our own unwanted entries in place rather than calling
        # self.data.intersection_update(...): in CPython, with a non-set
        # argument that method keeps the *argument's* objects, i.e. the
        # probe refs above, which lack the _remove callback.  Entries whose
        # referent later died would then never be purged and len() would
        # keep counting them.
        self.data.difference_update(self.data - wanted)
        return self

    def issubset(self, other):
        """Return True if every element of this set is in *other*."""
        return self.data.issubset(ref(item) for item in other)
    __le__ = issubset

    def __lt__(self, other):
        return self.data < set(map(ref, other))

    def issuperset(self, other):
        """Return True if every element of *other* is in this set."""
        return self.data.issuperset(ref(item) for item in other)
    __ge__ = issuperset

    def __gt__(self, other):
        return self.data > set(map(ref, other))

    def __eq__(self, other):
        # Only instances of this class (or subclasses) are comparable.
        if not isinstance(other, self.__class__):
            return NotImplemented
        return self.data == set(map(ref, other))

    def symmetric_difference(self, other):
        """Return a new set with elements in exactly one of the two sets."""
        newset = self.copy()
        newset.symmetric_difference_update(other)
        return newset
    __xor__ = symmetric_difference

    def symmetric_difference_update(self, other):
        """Update this set to the elements in exactly one of the two sets."""
        self.__ixor__(other)
    def __ixor__(self, other):
        if self._pending_removals:
            self._commit_removals()
        if self is other:
            self.data.clear()
        else:
            # Refs are created with the callback because those for new
            # elements are stored in ``data``.
            self.data.symmetric_difference_update(ref(item, self._remove) for item in other)
        return self

    def union(self, other):
        """Return a new set with the elements of both sets."""
        return self.__class__(e for s in (self, other) for e in s)
    __or__ = union

    def isdisjoint(self, other):
        """Return True if no element of *other* is in this set.

        Stops at the first shared element.  The previous approach built a
        temporary WeakSet, which consumed all of *other*.  It also lost
        transient equal objects from that temporary set before it was
        measured.
        """
        return not any(item in self for item in other)

    def __repr__(self):
        # Shows the underlying set of weak references.
        return repr(self.data)

    __class_getitem__ = classmethod(GenericAlias)