1 """Synchronization primitives."""
2
# Public names of this module: the six synchronization primitives below.
__all__ = ('Lock', 'Event', 'Condition', 'Semaphore',
           'BoundedSemaphore', 'Barrier')
5
6 import collections
7 import enum
8
9 from . import exceptions
10 from . import mixins
11 from . import tasks
12
class _ContextManagerMixin:
    """Add ``async with`` support on top of acquire() / release().

    Subclasses must provide a coroutine ``acquire()`` and a plain
    ``release()`` method.
    """

    async def __aenter__(self):
        # Block until the primitive is acquired.  Nothing is bound by an
        # "as" clause: locks have no useful value to hand back.
        await self.acquire()

    async def __aexit__(self, exc_type, exc, tb):
        # Release unconditionally; returning None lets any exception raised
        # inside the "async with" body propagate unchanged.
        self.release()
22
23
class Lock(_ContextManagerMixin, mixins._LoopBoundMixin):
    """Primitive lock objects for coroutines.

    A Lock is never owned by a particular coroutine; it is simply in one of
    two states, 'locked' or 'unlocked', and it starts out unlocked.

    acquire() is a coroutine (call it with 'await').
    - If the lock is unlocked and no live waiter is queued, acquire()
      switches the state to locked and returns True at once.
    - Otherwise the caller is queued. Callers resume in first-come,
      first-served order once a release() hands the lock over.

    release() switches the state back to unlocked and lets the first queued
    waiter proceed. Calling it on an unlocked lock raises RuntimeError.

    Preferred usage is the asynchronous context manager:

        lock = Lock()
        ...
        async with lock:
            ...

    which is equivalent to:

        lock = Lock()
        ...
        await lock.acquire()
        try:
            ...
        finally:
            lock.release()

    locked() reports the current state:

        if not lock.locked():
            await lock.acquire()
        else:
            # lock is acquired
            ...
    """

    def __init__(self):
        # Deque of futures, one per blocked acquire(); created lazily.
        self._waiters = None
        self._locked = False

    def __repr__(self):
        base = super().__repr__()
        state = 'locked' if self._locked else 'unlocked'
        if self._waiters:
            state += f', waiters:{len(self._waiters)}'
        return f'<{base[1:-1]} [{state}]>'

    def locked(self):
        """Tell whether the lock is currently held."""
        return self._locked

    async def acquire(self):
        """Take the lock, waiting for it if needed, and return True."""
        if not self._locked:
            waiters = self._waiters
            # Only grab a free lock directly when nobody live is queued
            # ahead of us; this keeps the hand-off fair (FIFO).
            if waiters is None or all(w.cancelled() for w in waiters):
                self._locked = True
                return True

        if self._waiters is None:
            self._waiters = collections.deque()
        waiter = self._get_loop().create_future()
        self._waiters.append(waiter)

        # The inner "finally" drops our future from the queue BEFORE the
        # cancellation handler runs, so _wake_up_first() cannot pick our
        # own future again.
        try:
            try:
                await waiter
            finally:
                self._waiters.remove(waiter)
        except exceptions.CancelledError:
            # We are giving up.  If the lock is still free, pass the
            # wake-up on to whoever is now at the head of the queue.
            if not self._locked:
                self._wake_up_first()
            raise

        self._locked = True
        return True

    def release(self):
        """Unlock the lock and let the first queued waiter proceed.

        Raises RuntimeError if the lock is not currently held.
        Returns nothing.
        """
        if not self._locked:
            raise RuntimeError('Lock is not acquired.')
        self._locked = False
        self._wake_up_first()

    def _wake_up_first(self):
        """Resolve the future at the head of the queue if still pending."""
        if not self._waiters:
            return
        head = self._waiters[0]
        # An already-done head belongs to a task that will run soon anyway:
        # it either takes the lock or, if it was cancelled and the lock is
        # still free, calls this method again to wake the next waiter.
        if not head.done():
            head.set_result(True)
156
157
class Event(mixins._LoopBoundMixin):
    """Asynchronous equivalent to threading.Event.

    An event wraps a boolean flag that starts out false.
    - set() makes the flag true and wakes every coroutine blocked in wait().
    - clear() makes the flag false again.
    - wait() returns immediately while the flag is true. Otherwise it blocks
      until set() is called.
    """

    def __init__(self):
        # One future per coroutine currently blocked in wait().
        self._waiters = collections.deque()
        self._value = False

    def __repr__(self):
        base = super().__repr__()
        state = 'set' if self._value else 'unset'
        if self._waiters:
            state += f', waiters:{len(self._waiters)}'
        return f'<{base[1:-1]} [{state}]>'

    def is_set(self):
        """Return True exactly when the internal flag is true."""
        return self._value

    def set(self):
        """Make the flag true and wake every coroutine waiting on it.

        Coroutines that call wait() while the flag stays true do not block.
        Calling set() on an already-set event does nothing.
        """
        if self._value:
            return
        self._value = True
        for waiter in self._waiters:
            if not waiter.done():
                waiter.set_result(True)

    def clear(self):
        """Make the flag false, so later wait() calls block until set()."""
        self._value = False

    async def wait(self):
        """Block until the flag is true, then return True.

        Returns True immediately if the flag is already true on entry.
        """
        if self._value:
            return True

        waiter = self._get_loop().create_future()
        self._waiters.append(waiter)
        try:
            await waiter
        finally:
            # Always unregister, whether we were woken or cancelled.
            self._waiters.remove(waiter)
        return True
217
218
class Condition(_ContextManagerMixin, mixins._LoopBoundMixin):
    """Asynchronous equivalent to threading.Condition.

    This class implements condition variable objects. A condition variable
    allows one or more coroutines to wait until they are notified by another
    coroutine.

    If *lock* is given it is used as the underlying lock; otherwise a new
    Lock object is created and used.
    """

    def __init__(self, lock=None):
        if lock is None:
            lock = Lock()

        self._lock = lock
        # Export the lock's locked(), acquire() and release() methods.
        self.locked = lock.locked
        self.acquire = lock.acquire
        self.release = lock.release

        # One future per coroutine blocked in wait(), in arrival order.
        self._waiters = collections.deque()

    def __repr__(self):
        res = super().__repr__()
        extra = 'locked' if self.locked() else 'unlocked'
        if self._waiters:
            extra = f'{extra}, waiters:{len(self._waiters)}'
        return f'<{res[1:-1]} [{extra}]>'

    async def wait(self):
        """Wait until notified.

        If the calling coroutine has not acquired the lock when this
        method is called, a RuntimeError is raised.

        This method releases the underlying lock, and then blocks
        until it is awakened by a notify() or notify_all() call for
        the same condition variable in another coroutine. Once
        awakened, it re-acquires the lock and returns True.

        Cancellation:
        - The lock is always re-acquired before the wait is abandoned.
        - If the re-acquisition itself was cancelled, that CancelledError
          instance (with its message) is re-raised.
        - If this coroutine had already been notified when the error
          struck, the notification is handed to another waiter, so it is
          not lost.
        """
        if not self.locked():
            raise RuntimeError('cannot wait on un-acquired lock')

        fut = None
        self.release()
        try:
            try:
                fut = self._get_loop().create_future()
                self._waiters.append(fut)
                try:
                    await fut
                    return True
                finally:
                    self._waiters.remove(fut)

            finally:
                # Must reacquire lock even if wait is cancelled.
                err = None
                while True:
                    try:
                        await self.acquire()
                        break
                    except exceptions.CancelledError as e:
                        err = e

                if err is not None:
                    try:
                        # Re-raise the caught instance rather than a fresh
                        # CancelledError so the cancel message survives.
                        raise err
                    finally:
                        err = None  # Break the traceback reference cycle.
        except BaseException:
            # A future holding a result means notify() selected this
            # coroutine, yet we are leaving with an error instead of
            # returning True.  Pass that notification on to another waiter.
            # The lock is held again here (re-acquired above).
            if fut is not None and fut.done() and not fut.cancelled():
                self._notify(1)
            raise

    async def wait_for(self, predicate):
        """Wait until a predicate becomes true.

        The predicate should be a callable which result will be
        interpreted as a boolean value. The final predicate value is
        the return value.
        """
        result = predicate()
        while not result:
            await self.wait()
            result = predicate()
        return result

    def notify(self, n=1):
        """By default, wake up one coroutine waiting on this condition, if any.
        If the calling coroutine has not acquired the lock when this method
        is called, a RuntimeError is raised.

        This method wakes up at most n of the coroutines waiting for the
        condition variable; it is a no-op if no coroutines are waiting.

        Note: an awakened coroutine does not actually return from its
        wait() call until it can reacquire the lock. Since notify() does
        not release the lock, its caller should.
        """
        if not self.locked():
            raise RuntimeError('cannot notify on un-acquired lock')
        self._notify(n)

    def _notify(self, n):
        """Resolve up to *n* still-pending waiter futures, oldest first.

        Does not check that the lock is held; callers must ensure it.
        """
        idx = 0
        for fut in self._waiters:
            if idx >= n:
                break

            # Skip futures already resolved or cancelled; they do not
            # count towards n.
            if not fut.done():
                idx += 1
                fut.set_result(False)

    def notify_all(self):
        """Wake up all coroutines waiting on this condition.

        This method acts like notify(), but wakes up all waiting coroutines
        instead of one. If the calling coroutine has not acquired the lock
        when this method is called, a RuntimeError is raised.
        """
        self.notify(len(self._waiters))
329
330
class Semaphore(_ContextManagerMixin, mixins._LoopBoundMixin):
    """A Semaphore implementation.

    A semaphore manages an internal counter which is decremented by each
    acquire() call and incremented by each release() call. The counter
    can never go below zero; when acquire() finds that it is zero, it blocks,
    waiting until some other coroutine calls release().

    Semaphores also support the asynchronous context management protocol
    ('async with').

    The optional argument gives the initial value for the internal
    counter; it defaults to 1. If the value given is less than 0,
    ValueError is raised.

    Waiters are served in FIFO order: while any non-cancelled waiter is
    queued, a newcomer queues behind it even if the counter is positive
    (see locked()).
    """

    def __init__(self, value=1):
        if value < 0:
            raise ValueError("Semaphore initial value must be >= 0")
        # Deque of waiter futures, created lazily on first contention.
        self._waiters = None
        # Permits available now.  _wake_up_next() takes a permit off this
        # counter at the moment it wakes a waiter, so permits already
        # promised to woken-but-not-yet-resumed waiters are excluded.
        self._value = value

    def __repr__(self):
        res = super().__repr__()
        extra = 'locked' if self.locked() else f'unlocked, value:{self._value}'
        if self._waiters:
            extra = f'{extra}, waiters:{len(self._waiters)}'
        return f'<{res[1:-1]} [{extra}]>'

    def locked(self):
        """Returns True if semaphore cannot be acquired immediately.

        That is the case when no permits are left, or when some waiter
        that was not cancelled (pending, or woken but not yet resumed) is
        still queued; this keeps newcomers from overtaking queued waiters.
        """
        return self._value == 0 or (
            any(not w.cancelled() for w in (self._waiters or ())))

    async def acquire(self):
        """Acquire a semaphore.

        If the internal counter is larger than zero on entry,
        decrement it by one and return True immediately. If it is
        zero on entry, block, waiting until some other coroutine has
        called release() to make it larger than 0, and then return
        True.
        """
        # Fast path: a permit is free and nobody live is queued ahead.
        if not self.locked():
            self._value -= 1
            return True

        if self._waiters is None:
            self._waiters = collections.deque()
        fut = self._get_loop().create_future()
        self._waiters.append(fut)

        # Finally block should be called before the CancelledError
        # handling as we don't want CancelledError to call
        # _wake_up_next() and attempt to wake up itself.
        try:
            try:
                await fut
            finally:
                self._waiters.remove(fut)
        except exceptions.CancelledError:
            # If our future was resolved (not cancelled), _wake_up_next()
            # already took a permit for us.  We are not going to use it,
            # so give it back and wake the next waiter in our place.
            if not fut.cancelled():
                self._value += 1
                self._wake_up_next()
            raise

        # Permits may still be left over, e.g. releases that happened while
        # newcomers queued behind us for FIFO reasons.  Chain the wake-up so
        # the next waiter gets one.
        if self._value > 0:
            self._wake_up_next()
        return True

    def release(self):
        """Release a semaphore, incrementing the internal counter by one.

        When it was zero on entry and another coroutine is waiting for it to
        become larger than zero again, wake up that coroutine.
        """
        self._value += 1
        self._wake_up_next()

    def _wake_up_next(self):
        """Wake up the first waiter that isn't done."""
        if not self._waiters:
            return

        for fut in self._waiters:
            if not fut.done():
                # Take the permit on the waiter's behalf now, so that
                # locked() and other callers see it as already used.
                self._value -= 1
                fut.set_result(True)
                return
418 return
419
420
class BoundedSemaphore(Semaphore):
    """A semaphore whose counter may never exceed its initial value.

    release() raises ValueError instead of pushing the counter above the
    value the semaphore was created with.
    """

    def __init__(self, value=1):
        # Remember the ceiling before the base class validates *value*.
        self._bound_value = value
        super().__init__(value)

    def release(self):
        """Release one permit, refusing to exceed the initial value."""
        if self._value < self._bound_value:
            super().release()
            return
        raise ValueError('BoundedSemaphore released too many times')
436
437
438
class _BarrierState(enum.Enum):
    """Internal lifecycle states of a Barrier.

    FILLING   -- tasks are arriving in wait() and block until all parties
                 are present.
    DRAINING  -- the last party arrived; released tasks are leaving and
                 newcomers block until the barrier is FILLING again.
    RESETTING -- reset() was called while tasks were inside; they leave
                 with BrokenBarrierError.
    BROKEN    -- abort() was called; wait() raises BrokenBarrierError.
    """

    FILLING = 'filling'
    DRAINING = 'draining'
    RESETTING = 'resetting'
    BROKEN = 'broken'
444
445
class Barrier(mixins._LoopBoundMixin):
    """Asyncio equivalent to threading.Barrier.

    Synchronizes a fixed number of tasks ("parties") at a meeting point.
    Each task calls wait() and blocks there. When the last party arrives,
    all of them are released together, and the barrier can then be used
    again.
    """

    def __init__(self, parties):
        """Create a barrier that trips once *parties* tasks are waiting."""
        if parties < 1:
            raise ValueError('parties must be > 0')

        # Every state change happens under this condition; waiters are
        # notified whenever the state moves.
        self._cond = Condition()

        self._parties = parties
        self._state = _BarrierState.FILLING
        # Number of tasks currently inside wait().
        self._count = 0

    def __repr__(self):
        base = super().__repr__()
        details = str(self._state.value)
        if not self.broken:
            details += f', waiters:{self.n_waiting}/{self.parties}'
        return f'<{base[1:-1]} [{details}]>'

    async def __aenter__(self):
        # "async with barrier as index" waits for all parties and binds
        # this task's index.
        return await self.wait()

    async def __aexit__(self, *args):
        # Nothing to undo: wait() has already left the barrier.
        return None

    async def wait(self):
        """Wait for the barrier.

        Once the specified number of tasks are waiting, all of them are
        released at the same time. Returns an index unique to this task,
        from 0 to 'parties-1'.
        """
        async with self._cond:
            # Don't join while a previous cycle is draining or resetting.
            await self._block()
            try:
                position = self._count
                self._count += 1
                if position + 1 != self._parties:
                    await self._wait()
                else:
                    # Last party to arrive: trip the barrier.
                    await self._release()
                return position
            finally:
                self._count -= 1
                # Possibly finish draining and wake blocked newcomers.
                self._exit()

    async def _block(self):
        # Wait while the barrier is draining or resetting (a CancelledError
        # may still interrupt), then refuse entry if it has been aborted.
        def _accepting():
            return self._state not in (
                _BarrierState.DRAINING, _BarrierState.RESETTING
            )

        await self._cond.wait_for(_accepting)

        if self._state is _BarrierState.BROKEN:
            raise exceptions.BrokenBarrierError("Barrier aborted")

    async def _release(self):
        # Switch to draining so new arrivals block until the cycle ends,
        # and wake every task waiting in _wait().
        self._state = _BarrierState.DRAINING
        self._cond.notify_all()

    async def _wait(self):
        # Sleep until the barrier leaves the filling state (or a
        # CancelledError interrupts).  Being woken by abort() or reset()
        # rather than by the last party is reported as an error.
        await self._cond.wait_for(
            lambda: self._state is not _BarrierState.FILLING)

        if self._state is _BarrierState.BROKEN or \
                self._state is _BarrierState.RESETTING:
            raise exceptions.BrokenBarrierError("Abort or reset of barrier")

    def _exit(self):
        # Only the last task out of the barrier has work to do: reopen a
        # draining/resetting barrier for filling and wake blocked tasks.
        if self._count != 0:
            return
        if self._state in (_BarrierState.RESETTING, _BarrierState.DRAINING):
            self._state = _BarrierState.FILLING
        self._cond.notify_all()

    async def reset(self):
        """Reset the barrier to the initial state.

        Any tasks currently waiting will get the BrokenBarrier exception
        raised.
        """
        async with self._cond:
            if self._count <= 0:
                # Nobody inside: reopen immediately.
                self._state = _BarrierState.FILLING
            elif self._state is not _BarrierState.RESETTING:
                # Tasks inside will see RESETTING and raise on wake-up.
                self._state = _BarrierState.RESETTING
            self._cond.notify_all()

    async def abort(self):
        """Put the barrier into a 'broken' state.

        Useful after an error. Tasks currently waiting, and tasks that call
        'wait()' later, get BrokenBarrierError raised.
        """
        async with self._cond:
            self._state = _BarrierState.BROKEN
            self._cond.notify_all()

    @property
    def parties(self):
        """Return the number of tasks required to trip the barrier."""
        return self._parties

    @property
    def n_waiting(self):
        """Return the number of tasks currently waiting at the barrier."""
        return self._count if self._state is _BarrierState.FILLING else 0

    @property
    def broken(self):
        """Return True if the barrier is in a broken state."""
        return self._state is _BarrierState.BROKEN