python (3.12.0)
1 """Synchronization primitives."""
2
# Public names exported by ``from <module> import *``.
__all__ = (
    'Lock',
    'Event',
    'Condition',
    'Semaphore',
    'BoundedSemaphore',
    'Barrier',
)
5
6 import collections
7 import enum
8
9 from . import exceptions
10 from . import mixins
11
class _ContextManagerMixin:
    """Add ``async with`` support to objects exposing acquire()/release().

    Entering the block awaits ``self.acquire()``; leaving it calls
    ``self.release()``, whether or not the body raised.
    """

    async def __aenter__(self):
        # The value bound by an "as ..." clause is deliberately None:
        # there is nothing useful to hand back for a lock.
        await self.acquire()

    async def __aexit__(self, exc_type, exc, tb):
        # The exception details are ignored; nothing is returned, so an
        # exception raised in the body keeps propagating.
        self.release()
21
22
class Lock(_ContextManagerMixin, mixins._LoopBoundMixin):
    """Primitive lock objects.

    A primitive lock is always in one of two states, 'locked' or
    'unlocked', and while locked it is not owned by any particular
    coroutine.  A new lock starts out unlocked.

    acquire() on an unlocked lock marks it locked and returns at once.
    On a locked lock, acquire() blocks until release() is called from
    another coroutine, then marks the lock locked again and returns.
    release() must only be called while the lock is locked; it marks the
    lock unlocked and returns immediately.  Releasing an unlocked lock
    raises RuntimeError.

    When several coroutines are blocked in acquire(), a release() lets
    exactly one of them proceed: the one that has been waiting longest.

    acquire() is a coroutine and should be called with 'await'.

    Locks also support the asynchronous context management protocol.
    'async with lock' statement should be used.

    Usage:

        lock = Lock()
        ...
        await lock.acquire()
        try:
            ...
        finally:
            lock.release()

    Context manager usage:

        lock = Lock()
        ...
        async with lock:
            ...

    Lock objects can be tested for locking state:

        if not lock.locked():
            await lock.acquire()
        else:
            # lock is acquired
            ...

    """

    def __init__(self):
        # The deque of waiter futures is created lazily, on the first
        # acquire() that actually has to block.
        self._waiters = None
        self._locked = False

    def __repr__(self):
        state = 'locked' if self._locked else 'unlocked'
        if self._waiters:
            state = f'{state}, waiters:{len(self._waiters)}'
        base = super().__repr__()
        return f'<{base[1:-1]} [{state}]>'

    def locked(self):
        """Return True if lock is acquired."""
        return self._locked

    async def acquire(self):
        """Acquire a lock.

        Block until the lock is unlocked, then mark it locked and
        return True.
        """
        if not self._locked:
            queue = self._waiters
            # Fast path: only taken when nobody is queued ahead of us
            # (cancelled waiters do not count).
            if queue is None or all(w.cancelled() for w in queue):
                self._locked = True
                return True

        if self._waiters is None:
            self._waiters = collections.deque()
        waiter = self._get_loop().create_future()
        self._waiters.append(waiter)

        # The inner finally must run before the CancelledError handler:
        # our own future has to be out of the queue first, otherwise
        # _wake_up_first() could try to wake up ourselves.
        try:
            try:
                await waiter
            finally:
                self._waiters.remove(waiter)
        except exceptions.CancelledError:
            # The lock is free and we are not going to take it, so give
            # the waiter now at the head of the queue its chance.
            if not self._locked:
                self._wake_up_first()
            raise

        self._locked = True
        return True

    def release(self):
        """Release a lock.

        Mark a locked lock as unlocked and, if other coroutines are
        blocked waiting for it, allow exactly one of them to proceed.

        When invoked on an unlocked lock, a RuntimeError is raised.

        There is no return value.
        """
        if not self._locked:
            raise RuntimeError('Lock is not acquired.')

        self._locked = False
        self._wake_up_first()

    def _wake_up_first(self):
        """Wake up the first waiter if it isn't done."""
        if not self._waiters:
            return
        head = self._waiters[0]

        # A head that is already done will run later on and either take
        # the lock or, if it was cancelled while the lock is still free,
        # come back through here and wake up the next waiter.
        if not head.done():
            head.set_result(True)
155
156
class Event(mixins._LoopBoundMixin):
    """Asynchronous equivalent to threading.Event.

    An event wraps a boolean flag that starts out false.  set() turns
    the flag on, clear() turns it off, and wait() blocks until the flag
    is true.
    """

    def __init__(self):
        self._waiters = collections.deque()
        self._value = False

    def __repr__(self):
        state = 'set' if self._value else 'unset'
        if self._waiters:
            state = f'{state}, waiters:{len(self._waiters)}'
        base = super().__repr__()
        return f'<{base[1:-1]} [{state}]>'

    def is_set(self):
        """Return True if and only if the internal flag is true."""
        return self._value

    def set(self):
        """Set the internal flag to true and wake every waiting coroutine.

        Coroutines that call wait() while the flag is true do not block
        at all.  Calling set() on an already-set event does nothing.
        """
        if self._value:
            return

        self._value = True
        for waiter in self._waiters:
            if not waiter.done():
                waiter.set_result(True)

    def clear(self):
        """Reset the internal flag to false.

        Coroutines that call wait() afterwards block until set() is
        called again.
        """
        self._value = False

    async def wait(self):
        """Block until the internal flag is true.

        Return True immediately if the flag is already true; otherwise
        block until another coroutine calls set(), then return True.
        """
        if self._value:
            return True

        waiter = self._get_loop().create_future()
        self._waiters.append(waiter)
        try:
            await waiter
        finally:
            # Always drop our future from the queue, including when the
            # wait is interrupted by an exception.
            self._waiters.remove(waiter)
        return True
216
217
class Condition(_ContextManagerMixin, mixins._LoopBoundMixin):
    """Asynchronous equivalent to threading.Condition.

    This class implements condition variable objects. A condition variable
    allows one or more coroutines to wait until they are notified by another
    coroutine.

    If the optional *lock* argument is given, it is used as the underlying
    lock; otherwise a new Lock object is created and used instead.
    """

    def __init__(self, lock=None):
        if lock is None:
            lock = Lock()

        self._lock = lock
        # Export the lock's locked(), acquire() and release() methods.
        self.locked = lock.locked
        self.acquire = lock.acquire
        self.release = lock.release

        self._waiters = collections.deque()

    def __repr__(self):
        res = super().__repr__()
        extra = 'locked' if self.locked() else 'unlocked'
        if self._waiters:
            extra = f'{extra}, waiters:{len(self._waiters)}'
        return f'<{res[1:-1]} [{extra}]>'

    async def wait(self):
        """Wait until notified.

        If the calling coroutine has not acquired the lock when this
        method is called, a RuntimeError is raised.

        This method releases the underlying lock, and then blocks
        until it is awakened by a notify() or notify_all() call for
        the same condition variable in another coroutine. Once
        awakened, it re-acquires the lock and returns True.

        The lock is re-acquired even when the wait is cancelled; in that
        case the most recent CancelledError instance is re-raised once
        the lock is held again.

        This method may return spuriously, so the caller should always
        re-check the state it is waiting for and be prepared to wait()
        again. For this reason, prefer wait_for().
        """
        if not self.locked():
            raise RuntimeError('cannot wait on un-acquired lock')

        # Create the future before releasing the lock: if this fails, the
        # lock is simply left held instead of being released and then
        # re-acquired for nothing.
        fut = self._get_loop().create_future()
        self.release()
        try:
            try:
                self._waiters.append(fut)
                try:
                    await fut
                    return True
                finally:
                    self._waiters.remove(fut)

            finally:
                # Must re-acquire the lock even if wait is cancelled.
                # Only CancelledError is caught here, so that any other
                # error from acquire() cannot make this loop spin.
                err = None
                while True:
                    try:
                        await self.acquire()
                        break
                    except exceptions.CancelledError as e:
                        err = e

                if err is not None:
                    try:
                        # Re-raise the most recent exception instance so
                        # its message and traceback are preserved.
                        raise err
                    finally:
                        err = None  # Break reference cycles.
        except BaseException:
            # An error raised out of here may have occurred after this
            # coroutine had already been picked by notify().  Pass the
            # notification on to another waiter so it is not lost.  This
            # can cause a spurious wake-up, which callers must tolerate.
            self._notify(1)
            raise

    async def wait_for(self, predicate):
        """Wait until a predicate becomes true.

        The predicate should be a callable whose result will be
        interpreted as a boolean value. The predicate is checked once up
        front and again after every wake-up. The final predicate value
        is the return value.
        """
        result = predicate()
        while not result:
            await self.wait()
            result = predicate()
        return result

    def notify(self, n=1):
        """By default, wake up one coroutine waiting on this condition, if any.
        If the calling coroutine has not acquired the lock when this method
        is called, a RuntimeError is raised.

        This method wakes up at most n of the coroutines waiting for the
        condition variable; it is a no-op if no coroutines are waiting.

        Note: an awakened coroutine does not actually return from its
        wait() call until it can reacquire the lock. Since notify() does
        not release the lock, its caller should.
        """
        if not self.locked():
            raise RuntimeError('cannot notify on un-acquired lock')
        self._notify(n)

    def _notify(self, n):
        """Wake up at most n waiters, without checking the lock state."""
        idx = 0
        for fut in self._waiters:
            if idx >= n:
                break

            # Futures that are already done (notified or cancelled) do
            # not count towards n.
            if not fut.done():
                idx += 1
                fut.set_result(False)

    def notify_all(self):
        """Wake up all coroutines waiting on this condition. This method acts
        like notify(), but wakes up all waiting coroutines instead of one. If
        the calling coroutine has not acquired the lock when this method is
        called, a RuntimeError is raised.
        """
        self.notify(len(self._waiters))
328
329
class Semaphore(_ContextManagerMixin, mixins._LoopBoundMixin):
    """A Semaphore implementation.

    A semaphore keeps an internal counter: every acquire() decrements it
    and every release() increments it.  The counter never drops below
    zero; an acquire() that finds it at zero blocks until some other
    coroutine calls release().

    Semaphores also support the context management protocol.

    The optional argument gives the initial value for the internal
    counter; it defaults to 1. If the value given is less than 0,
    ValueError is raised.
    """

    def __init__(self, value=1):
        if value < 0:
            raise ValueError("Semaphore initial value must be >= 0")
        # The deque of waiter futures is created lazily, on the first
        # acquire() that actually has to block.
        self._waiters = None
        self._value = value

    def __repr__(self):
        if self.locked():
            state = 'locked'
        else:
            state = f'unlocked, value:{self._value}'
        if self._waiters:
            state = f'{state}, waiters:{len(self._waiters)}'
        base = super().__repr__()
        return f'<{base[1:-1]} [{state}]>'

    def locked(self):
        """Returns True if semaphore cannot be acquired immediately."""
        if self._value == 0:
            return True
        # Even with a positive counter, newcomers must queue behind any
        # waiter that has not been cancelled.
        queued = self._waiters or ()
        return any(not w.cancelled() for w in queued)

    async def acquire(self):
        """Acquire a semaphore.

        If the semaphore can be acquired immediately, decrement the
        internal counter by one and return True at once.  Otherwise
        block until a release() from another coroutine wakes this
        waiter up, and then return True.
        """
        if not self.locked():
            self._value -= 1
            return True

        if self._waiters is None:
            self._waiters = collections.deque()
        waiter = self._get_loop().create_future()
        self._waiters.append(waiter)

        # The inner finally must run before the CancelledError handler:
        # our own future has to be out of the queue first, otherwise
        # _wake_up_next() could try to wake up ourselves.
        try:
            try:
                await waiter
            finally:
                self._waiters.remove(waiter)
        except exceptions.CancelledError:
            if not waiter.cancelled():
                # The counter was already decremented on our behalf when
                # the future got its result; give the slot back and pass
                # it on to the next waiter.
                self._value += 1
                self._wake_up_next()
            raise

        # Slots may still be free while others are queued; keep the
        # chain of wake-ups going.
        if self._value > 0:
            self._wake_up_next()
        return True

    def release(self):
        """Release a semaphore, incrementing the internal counter by one.

        If another coroutine is waiting for the semaphore, wake that
        coroutine up.
        """
        self._value += 1
        self._wake_up_next()

    def _wake_up_next(self):
        """Wake up the first waiter that isn't done."""
        for waiter in self._waiters or ():
            if waiter.done():
                continue
            # The slot is claimed on the waiter's behalf right here.
            self._value -= 1
            waiter.set_result(True)
            return
418
419
class BoundedSemaphore(Semaphore):
    """A bounded semaphore implementation.

    release() raises ValueError instead of letting the counter grow
    beyond the value the semaphore was created with.
    """

    def __init__(self, value=1):
        # Remember the initial value: it is the ceiling release() enforces.
        self._bound_value = value
        super().__init__(value)

    def release(self):
        """Release the semaphore, refusing to exceed the initial value."""
        if self._value >= self._bound_value:
            raise ValueError('BoundedSemaphore released too many times')
        super().release()
435
436
437
class _BarrierState(enum.Enum):
    """States a Barrier can be in."""

    # Tasks are still arriving at the barrier.
    FILLING = 'filling'
    # The barrier has tripped and its waiters are leaving.
    DRAINING = 'draining'
    # reset() was called while tasks were waiting.
    RESETTING = 'resetting'
    # abort() was called.
    BROKEN = 'broken'
443
444
class Barrier(mixins._LoopBoundMixin):
    """Asyncio equivalent to threading.Barrier

    Implements a Barrier primitive.
    Lets a fixed number of tasks meet at a known synchronization point:
    each task blocks in 'wait()' and all of them are awoken together
    once the last one has made its call.
    """

    def __init__(self, parties):
        """Create a barrier, initialised to 'parties' tasks."""
        if parties < 1:
            raise ValueError('parties must be > 0')

        # Every state change is announced through this condition.
        self._cond = Condition()

        self._parties = parties
        self._state = _BarrierState.FILLING
        # Number of tasks currently inside the barrier.
        self._count = 0

    def __repr__(self):
        details = f'{self._state.value}'
        if not self.broken:
            details = f'{details}, waiters:{self.n_waiting}/{self.parties}'
        base = super().__repr__()
        return f'<{base[1:-1]} [{details}]>'

    async def __aenter__(self):
        # Entering the context waits at the barrier; the value bound by
        # "as ..." is the index returned by wait().
        return await self.wait()

    async def __aexit__(self, *args):
        # Nothing to undo on exit.
        pass

    async def wait(self):
        """Wait for the barrier.

        Once the specified number of tasks are waiting, they are all
        awoken simultaneously.
        Returns an unique and individual index number from 0 to 'parties-1'.
        """
        async with self._cond:
            # Hold back while a previous round is still draining or the
            # barrier is being reset.
            await self._block()
            try:
                index = self._count
                self._count += 1
                is_last = index + 1 == self._parties
                if not is_last:
                    await self._wait()
                else:
                    # The final party trips the barrier.
                    await self._release()
                return index
            finally:
                self._count -= 1
                # Wake up any tasks waiting for barrier to drain.
                self._exit()

    async def _block(self):
        # Wait until the barrier is ready to accept a new party, i.e. it
        # is neither draining nor resetting; a CancelledError cuts the
        # wait short.  Raise if the barrier turns out to be broken.
        def accepting():
            return self._state not in (
                _BarrierState.DRAINING, _BarrierState.RESETTING
            )

        await self._cond.wait_for(accepting)

        if self._state is _BarrierState.BROKEN:
            raise exceptions.BrokenBarrierError("Barrier aborted")

    async def _release(self):
        # Let go of the tasks waiting in the barrier: switch to the
        # draining state, which also makes newly arriving tasks block
        # until draining has finished, and announce the change.
        self._state = _BarrierState.DRAINING
        self._cond.notify_all()

    async def _wait(self):
        # Stay in the barrier until the filling phase ends (or a
        # CancelledError interrupts the wait).  Raise if the phase ended
        # because the barrier was reset or broken.
        def filling_over():
            return self._state is not _BarrierState.FILLING

        await self._cond.wait_for(filling_over)

        if self._state in (_BarrierState.BROKEN, _BarrierState.RESETTING):
            raise exceptions.BrokenBarrierError("Abort or reset of barrier")

    def _exit(self):
        # Only the last task to leave the barrier has anything to do:
        # return a draining/resetting barrier to the filling state and
        # signal the tasks waiting for that.
        if self._count != 0:
            return
        if self._state in (_BarrierState.RESETTING, _BarrierState.DRAINING):
            self._state = _BarrierState.FILLING
        self._cond.notify_all()

    async def reset(self):
        """Reset the barrier to the initial state.

        Any tasks currently waiting will get the BrokenBarrier exception
        raised.
        """
        async with self._cond:
            if self._count <= 0:
                # Nobody is inside: go straight back to filling.
                self._state = _BarrierState.FILLING
            elif self._state is not _BarrierState.RESETTING:
                # Tasks are inside: mark the barrier as resetting so that
                # they are woken up.
                self._state = _BarrierState.RESETTING
            self._cond.notify_all()

    async def abort(self):
        """Place the barrier into a 'broken' state.

        Useful in case of error. Any currently waiting tasks and tasks
        attempting to 'wait()' will have BrokenBarrierError raised.
        """
        async with self._cond:
            self._state = _BarrierState.BROKEN
            self._cond.notify_all()

    @property
    def parties(self):
        """Return the number of tasks required to trip the barrier."""
        return self._parties

    @property
    def n_waiting(self):
        """Return the number of tasks currently waiting at the barrier."""
        # Tasks only count as waiting while the barrier is filling.
        return self._count if self._state is _BarrierState.FILLING else 0

    @property
    def broken(self):
        """Return True if the barrier is in a broken state."""
        return self._state is _BarrierState.BROKEN