"""Synchronization primitives for asyncio.

Provides locks, events, condition variables, semaphores and barriers.
"""

__all__ = (
    'Lock',
    'Event',
    'Condition',
    'Semaphore',
    'BoundedSemaphore',
    'Barrier',
)
       5  
       6  import collections
       7  import enum
       8  
       9  from . import exceptions
      10  from . import mixins
      11  
      12  class ESC[4;38;5;81m_ContextManagerMixin:
      13      async def __aenter__(self):
      14          await self.acquire()
      15          # We have no use for the "as ..."  clause in the with
      16          # statement for locks.
      17          return None
      18  
      19      async def __aexit__(self, exc_type, exc, tb):
      20          self.release()
      21  
      22  
      23  class ESC[4;38;5;81mLock(ESC[4;38;5;149m_ContextManagerMixin, ESC[4;38;5;149mmixinsESC[4;38;5;149m.ESC[4;38;5;149m_LoopBoundMixin):
      24      """Primitive lock objects.
      25  
      26      A primitive lock is a synchronization primitive that is not owned
      27      by a particular coroutine when locked.  A primitive lock is in one
      28      of two states, 'locked' or 'unlocked'.
      29  
      30      It is created in the unlocked state.  It has two basic methods,
      31      acquire() and release().  When the state is unlocked, acquire()
      32      changes the state to locked and returns immediately.  When the
      33      state is locked, acquire() blocks until a call to release() in
      34      another coroutine changes it to unlocked, then the acquire() call
      35      resets it to locked and returns.  The release() method should only
      36      be called in the locked state; it changes the state to unlocked
      37      and returns immediately.  If an attempt is made to release an
      38      unlocked lock, a RuntimeError will be raised.
      39  
      40      When more than one coroutine is blocked in acquire() waiting for
      41      the state to turn to unlocked, only one coroutine proceeds when a
      42      release() call resets the state to unlocked; first coroutine which
      43      is blocked in acquire() is being processed.
      44  
      45      acquire() is a coroutine and should be called with 'await'.
      46  
      47      Locks also support the asynchronous context management protocol.
      48      'async with lock' statement should be used.
      49  
      50      Usage:
      51  
      52          lock = Lock()
      53          ...
      54          await lock.acquire()
      55          try:
      56              ...
      57          finally:
      58              lock.release()
      59  
      60      Context manager usage:
      61  
      62          lock = Lock()
      63          ...
      64          async with lock:
      65               ...
      66  
      67      Lock objects can be tested for locking state:
      68  
      69          if not lock.locked():
      70             await lock.acquire()
      71          else:
      72             # lock is acquired
      73             ...
      74  
      75      """
      76  
      77      def __init__(self):
      78          self._waiters = None
      79          self._locked = False
      80  
      81      def __repr__(self):
      82          res = super().__repr__()
      83          extra = 'locked' if self._locked else 'unlocked'
      84          if self._waiters:
      85              extra = f'{extra}, waiters:{len(self._waiters)}'
      86          return f'<{res[1:-1]} [{extra}]>'
      87  
      88      def locked(self):
      89          """Return True if lock is acquired."""
      90          return self._locked
      91  
      92      async def acquire(self):
      93          """Acquire a lock.
      94  
      95          This method blocks until the lock is unlocked, then sets it to
      96          locked and returns True.
      97          """
      98          if (not self._locked and (self._waiters is None or
      99                  all(w.cancelled() for w in self._waiters))):
     100              self._locked = True
     101              return True
     102  
     103          if self._waiters is None:
     104              self._waiters = collections.deque()
     105          fut = self._get_loop().create_future()
     106          self._waiters.append(fut)
     107  
     108          # Finally block should be called before the CancelledError
     109          # handling as we don't want CancelledError to call
     110          # _wake_up_first() and attempt to wake up itself.
     111          try:
     112              try:
     113                  await fut
     114              finally:
     115                  self._waiters.remove(fut)
     116          except exceptions.CancelledError:
     117              if not self._locked:
     118                  self._wake_up_first()
     119              raise
     120  
     121          self._locked = True
     122          return True
     123  
     124      def release(self):
     125          """Release a lock.
     126  
     127          When the lock is locked, reset it to unlocked, and return.
     128          If any other coroutines are blocked waiting for the lock to become
     129          unlocked, allow exactly one of them to proceed.
     130  
     131          When invoked on an unlocked lock, a RuntimeError is raised.
     132  
     133          There is no return value.
     134          """
     135          if self._locked:
     136              self._locked = False
     137              self._wake_up_first()
     138          else:
     139              raise RuntimeError('Lock is not acquired.')
     140  
     141      def _wake_up_first(self):
     142          """Wake up the first waiter if it isn't done."""
     143          if not self._waiters:
     144              return
     145          try:
     146              fut = next(iter(self._waiters))
     147          except StopIteration:
     148              return
     149  
     150          # .done() necessarily means that a waiter will wake up later on and
     151          # either take the lock, or, if it was cancelled and lock wasn't
     152          # taken already, will hit this again and wake up a new waiter.
     153          if not fut.done():
     154              fut.set_result(True)
     155  
     156  
     157  class ESC[4;38;5;81mEvent(ESC[4;38;5;149mmixinsESC[4;38;5;149m.ESC[4;38;5;149m_LoopBoundMixin):
     158      """Asynchronous equivalent to threading.Event.
     159  
     160      Class implementing event objects. An event manages a flag that can be set
     161      to true with the set() method and reset to false with the clear() method.
     162      The wait() method blocks until the flag is true. The flag is initially
     163      false.
     164      """
     165  
     166      def __init__(self):
     167          self._waiters = collections.deque()
     168          self._value = False
     169  
     170      def __repr__(self):
     171          res = super().__repr__()
     172          extra = 'set' if self._value else 'unset'
     173          if self._waiters:
     174              extra = f'{extra}, waiters:{len(self._waiters)}'
     175          return f'<{res[1:-1]} [{extra}]>'
     176  
     177      def is_set(self):
     178          """Return True if and only if the internal flag is true."""
     179          return self._value
     180  
     181      def set(self):
     182          """Set the internal flag to true. All coroutines waiting for it to
     183          become true are awakened. Coroutine that call wait() once the flag is
     184          true will not block at all.
     185          """
     186          if not self._value:
     187              self._value = True
     188  
     189              for fut in self._waiters:
     190                  if not fut.done():
     191                      fut.set_result(True)
     192  
     193      def clear(self):
     194          """Reset the internal flag to false. Subsequently, coroutines calling
     195          wait() will block until set() is called to set the internal flag
     196          to true again."""
     197          self._value = False
     198  
     199      async def wait(self):
     200          """Block until the internal flag is true.
     201  
     202          If the internal flag is true on entry, return True
     203          immediately.  Otherwise, block until another coroutine calls
     204          set() to set the flag to true, then return True.
     205          """
     206          if self._value:
     207              return True
     208  
     209          fut = self._get_loop().create_future()
     210          self._waiters.append(fut)
     211          try:
     212              await fut
     213              return True
     214          finally:
     215              self._waiters.remove(fut)
     216  
     217  
     218  class ESC[4;38;5;81mCondition(ESC[4;38;5;149m_ContextManagerMixin, ESC[4;38;5;149mmixinsESC[4;38;5;149m.ESC[4;38;5;149m_LoopBoundMixin):
     219      """Asynchronous equivalent to threading.Condition.
     220  
     221      This class implements condition variable objects. A condition variable
     222      allows one or more coroutines to wait until they are notified by another
     223      coroutine.
     224  
     225      A new Lock object is created and used as the underlying lock.
     226      """
     227  
     228      def __init__(self, lock=None):
     229          if lock is None:
     230              lock = Lock()
     231  
     232          self._lock = lock
     233          # Export the lock's locked(), acquire() and release() methods.
     234          self.locked = lock.locked
     235          self.acquire = lock.acquire
     236          self.release = lock.release
     237  
     238          self._waiters = collections.deque()
     239  
     240      def __repr__(self):
     241          res = super().__repr__()
     242          extra = 'locked' if self.locked() else 'unlocked'
     243          if self._waiters:
     244              extra = f'{extra}, waiters:{len(self._waiters)}'
     245          return f'<{res[1:-1]} [{extra}]>'
     246  
     247      async def wait(self):
     248          """Wait until notified.
     249  
     250          If the calling coroutine has not acquired the lock when this
     251          method is called, a RuntimeError is raised.
     252  
     253          This method releases the underlying lock, and then blocks
     254          until it is awakened by a notify() or notify_all() call for
     255          the same condition variable in another coroutine.  Once
     256          awakened, it re-acquires the lock and returns True.
     257          """
     258          if not self.locked():
     259              raise RuntimeError('cannot wait on un-acquired lock')
     260  
     261          self.release()
     262          try:
     263              fut = self._get_loop().create_future()
     264              self._waiters.append(fut)
     265              try:
     266                  await fut
     267                  return True
     268              finally:
     269                  self._waiters.remove(fut)
     270  
     271          finally:
     272              # Must reacquire lock even if wait is cancelled
     273              cancelled = False
     274              while True:
     275                  try:
     276                      await self.acquire()
     277                      break
     278                  except exceptions.CancelledError:
     279                      cancelled = True
     280  
     281              if cancelled:
     282                  raise exceptions.CancelledError
     283  
     284      async def wait_for(self, predicate):
     285          """Wait until a predicate becomes true.
     286  
     287          The predicate should be a callable which result will be
     288          interpreted as a boolean value.  The final predicate value is
     289          the return value.
     290          """
     291          result = predicate()
     292          while not result:
     293              await self.wait()
     294              result = predicate()
     295          return result
     296  
     297      def notify(self, n=1):
     298          """By default, wake up one coroutine waiting on this condition, if any.
     299          If the calling coroutine has not acquired the lock when this method
     300          is called, a RuntimeError is raised.
     301  
     302          This method wakes up at most n of the coroutines waiting for the
     303          condition variable; it is a no-op if no coroutines are waiting.
     304  
     305          Note: an awakened coroutine does not actually return from its
     306          wait() call until it can reacquire the lock. Since notify() does
     307          not release the lock, its caller should.
     308          """
     309          if not self.locked():
     310              raise RuntimeError('cannot notify on un-acquired lock')
     311  
     312          idx = 0
     313          for fut in self._waiters:
     314              if idx >= n:
     315                  break
     316  
     317              if not fut.done():
     318                  idx += 1
     319                  fut.set_result(False)
     320  
     321      def notify_all(self):
     322          """Wake up all threads waiting on this condition. This method acts
     323          like notify(), but wakes up all waiting threads instead of one. If the
     324          calling thread has not acquired the lock when this method is called,
     325          a RuntimeError is raised.
     326          """
     327          self.notify(len(self._waiters))
     328  
     329  
     330  class ESC[4;38;5;81mSemaphore(ESC[4;38;5;149m_ContextManagerMixin, ESC[4;38;5;149mmixinsESC[4;38;5;149m.ESC[4;38;5;149m_LoopBoundMixin):
     331      """A Semaphore implementation.
     332  
     333      A semaphore manages an internal counter which is decremented by each
     334      acquire() call and incremented by each release() call. The counter
     335      can never go below zero; when acquire() finds that it is zero, it blocks,
     336      waiting until some other thread calls release().
     337  
     338      Semaphores also support the context management protocol.
     339  
     340      The optional argument gives the initial value for the internal
     341      counter; it defaults to 1. If the value given is less than 0,
     342      ValueError is raised.
     343      """
     344  
     345      def __init__(self, value=1):
     346          if value < 0:
     347              raise ValueError("Semaphore initial value must be >= 0")
     348          self._waiters = None
     349          self._value = value
     350  
     351      def __repr__(self):
     352          res = super().__repr__()
     353          extra = 'locked' if self.locked() else f'unlocked, value:{self._value}'
     354          if self._waiters:
     355              extra = f'{extra}, waiters:{len(self._waiters)}'
     356          return f'<{res[1:-1]} [{extra}]>'
     357  
     358      def locked(self):
     359          """Returns True if semaphore cannot be acquired immediately."""
     360          return self._value == 0 or (
     361              any(not w.cancelled() for w in (self._waiters or ())))
     362  
     363      async def acquire(self):
     364          """Acquire a semaphore.
     365  
     366          If the internal counter is larger than zero on entry,
     367          decrement it by one and return True immediately.  If it is
     368          zero on entry, block, waiting until some other coroutine has
     369          called release() to make it larger than 0, and then return
     370          True.
     371          """
     372          if not self.locked():
     373              self._value -= 1
     374              return True
     375  
     376          if self._waiters is None:
     377              self._waiters = collections.deque()
     378          fut = self._get_loop().create_future()
     379          self._waiters.append(fut)
     380  
     381          # Finally block should be called before the CancelledError
     382          # handling as we don't want CancelledError to call
     383          # _wake_up_first() and attempt to wake up itself.
     384          try:
     385              try:
     386                  await fut
     387              finally:
     388                  self._waiters.remove(fut)
     389          except exceptions.CancelledError:
     390              if not fut.cancelled():
     391                  self._value += 1
     392                  self._wake_up_next()
     393              raise
     394  
     395          if self._value > 0:
     396              self._wake_up_next()
     397          return True
     398  
     399      def release(self):
     400          """Release a semaphore, incrementing the internal counter by one.
     401  
     402          When it was zero on entry and another coroutine is waiting for it to
     403          become larger than zero again, wake up that coroutine.
     404          """
     405          self._value += 1
     406          self._wake_up_next()
     407  
     408      def _wake_up_next(self):
     409          """Wake up the first waiter that isn't done."""
     410          if not self._waiters:
     411              return
     412  
     413          for fut in self._waiters:
     414              if not fut.done():
     415                  self._value -= 1
     416                  fut.set_result(True)
     417                  return
     418  
     419  
     420  class ESC[4;38;5;81mBoundedSemaphore(ESC[4;38;5;149mSemaphore):
     421      """A bounded semaphore implementation.
     422  
     423      This raises ValueError in release() if it would increase the value
     424      above the initial value.
     425      """
     426  
     427      def __init__(self, value=1):
     428          self._bound_value = value
     429          super().__init__(value)
     430  
     431      def release(self):
     432          if self._value >= self._bound_value:
     433              raise ValueError('BoundedSemaphore released too many times')
     434          super().release()
     435  
     436  
     437  
     438  class ESC[4;38;5;81m_BarrierState(ESC[4;38;5;149menumESC[4;38;5;149m.ESC[4;38;5;149mEnum):
     439      FILLING = 'filling'
     440      DRAINING = 'draining'
     441      RESETTING = 'resetting'
     442      BROKEN = 'broken'
     443  
     444  
     445  class ESC[4;38;5;81mBarrier(ESC[4;38;5;149mmixinsESC[4;38;5;149m.ESC[4;38;5;149m_LoopBoundMixin):
     446      """Asyncio equivalent to threading.Barrier
     447  
     448      Implements a Barrier primitive.
     449      Useful for synchronizing a fixed number of tasks at known synchronization
     450      points. Tasks block on 'wait()' and are simultaneously awoken once they
     451      have all made their call.
     452      """
     453  
     454      def __init__(self, parties):
     455          """Create a barrier, initialised to 'parties' tasks."""
     456          if parties < 1:
     457              raise ValueError('parties must be > 0')
     458  
     459          self._cond = Condition() # notify all tasks when state changes
     460  
     461          self._parties = parties
     462          self._state = _BarrierState.FILLING
     463          self._count = 0       # count tasks in Barrier
     464  
     465      def __repr__(self):
     466          res = super().__repr__()
     467          extra = f'{self._state.value}'
     468          if not self.broken:
     469              extra += f', waiters:{self.n_waiting}/{self.parties}'
     470          return f'<{res[1:-1]} [{extra}]>'
     471  
     472      async def __aenter__(self):
     473          # wait for the barrier reaches the parties number
     474          # when start draining release and return index of waited task
     475          return await self.wait()
     476  
     477      async def __aexit__(self, *args):
     478          pass
     479  
     480      async def wait(self):
     481          """Wait for the barrier.
     482  
     483          When the specified number of tasks have started waiting, they are all
     484          simultaneously awoken.
     485          Returns an unique and individual index number from 0 to 'parties-1'.
     486          """
     487          async with self._cond:
     488              await self._block() # Block while the barrier drains or resets.
     489              try:
     490                  index = self._count
     491                  self._count += 1
     492                  if index + 1 == self._parties:
     493                      # We release the barrier
     494                      await self._release()
     495                  else:
     496                      await self._wait()
     497                  return index
     498              finally:
     499                  self._count -= 1
     500                  # Wake up any tasks waiting for barrier to drain.
     501                  self._exit()
     502  
     503      async def _block(self):
     504          # Block until the barrier is ready for us,
     505          # or raise an exception if it is broken.
     506          #
     507          # It is draining or resetting, wait until done
     508          # unless a CancelledError occurs
     509          await self._cond.wait_for(
     510              lambda: self._state not in (
     511                  _BarrierState.DRAINING, _BarrierState.RESETTING
     512              )
     513          )
     514  
     515          # see if the barrier is in a broken state
     516          if self._state is _BarrierState.BROKEN:
     517              raise exceptions.BrokenBarrierError("Barrier aborted")
     518  
     519      async def _release(self):
     520          # Release the tasks waiting in the barrier.
     521  
     522          # Enter draining state.
     523          # Next waiting tasks will be blocked until the end of draining.
     524          self._state = _BarrierState.DRAINING
     525          self._cond.notify_all()
     526  
     527      async def _wait(self):
     528          # Wait in the barrier until we are released. Raise an exception
     529          # if the barrier is reset or broken.
     530  
     531          # wait for end of filling
     532          # unless a CancelledError occurs
     533          await self._cond.wait_for(lambda: self._state is not _BarrierState.FILLING)
     534  
     535          if self._state in (_BarrierState.BROKEN, _BarrierState.RESETTING):
     536              raise exceptions.BrokenBarrierError("Abort or reset of barrier")
     537  
     538      def _exit(self):
     539          # If we are the last tasks to exit the barrier, signal any tasks
     540          # waiting for the barrier to drain.
     541          if self._count == 0:
     542              if self._state in (_BarrierState.RESETTING, _BarrierState.DRAINING):
     543                  self._state = _BarrierState.FILLING
     544              self._cond.notify_all()
     545  
     546      async def reset(self):
     547          """Reset the barrier to the initial state.
     548  
     549          Any tasks currently waiting will get the BrokenBarrier exception
     550          raised.
     551          """
     552          async with self._cond:
     553              if self._count > 0:
     554                  if self._state is not _BarrierState.RESETTING:
     555                      #reset the barrier, waking up tasks
     556                      self._state = _BarrierState.RESETTING
     557              else:
     558                  self._state = _BarrierState.FILLING
     559              self._cond.notify_all()
     560  
     561      async def abort(self):
     562          """Place the barrier into a 'broken' state.
     563  
     564          Useful in case of error.  Any currently waiting tasks and tasks
     565          attempting to 'wait()' will have BrokenBarrierError raised.
     566          """
     567          async with self._cond:
     568              self._state = _BarrierState.BROKEN
     569              self._cond.notify_all()
     570  
     571      @property
     572      def parties(self):
     573          """Return the number of tasks required to trip the barrier."""
     574          return self._parties
     575  
     576      @property
     577      def n_waiting(self):
     578          """Return the number of tasks currently waiting at the barrier."""
     579          if self._state is _BarrierState.FILLING:
     580              return self._count
     581          return 0
     582  
     583      @property
     584      def broken(self):
     585          """Return True if the barrier is in a broken state."""
     586          return self._state is _BarrierState.BROKEN