1 """Support for running coroutines in parallel with staggered start times."""
2
# Public API of this module: only the race helper is exported.
__all__ = ('staggered_race',)
4
5 import contextlib
6 import typing
7
8 from . import events
9 from . import exceptions as exceptions_mod
10 from . import locks
11 from . import tasks
12
13
async def staggered_race(
        coro_fns: typing.Iterable[typing.Callable[[], typing.Awaitable]],
        delay: typing.Optional[float],
        *,
        loop: typing.Optional[events.AbstractEventLoop] = None,
) -> typing.Tuple[
    typing.Any,
    typing.Optional[int],
    typing.List[typing.Optional[BaseException]]
]:
    """Run coroutines with staggered start times and take the first to finish.

    This method takes an iterable of coroutine functions. The first one is
    started immediately. From then on, whenever the immediately preceding one
    fails (raises an exception), or when *delay* seconds have passed, the next
    coroutine is started. This continues until one of the coroutines completes
    successfully, in which case all others are cancelled, or until all
    coroutines fail.

    The coroutines provided should be well-behaved in the following way:

    * They should only ``return`` if completed successfully.

    * They should always raise an exception if they did not complete
      successfully. In particular, if they handle cancellation, they should
      probably reraise, like this::

        try:
            # do work
        except asyncio.CancelledError:
            # undo partially completed work
            raise

    Args:
        coro_fns: an iterable of coroutine functions, i.e. callables that
            return a coroutine object when called. Use ``functools.partial`` or
            lambdas to pass arguments.

        delay: amount of time, in seconds, between starting coroutines. If
            ``None``, the coroutines will run sequentially.

        loop: the event loop to use. Defaults to the running event loop.

    Returns:
        tuple *(winner_result, winner_index, exceptions)* where

        - *winner_result*: the result of the winning coroutine, or ``None``
          if no coroutines won.

        - *winner_index*: the index of the winning coroutine in
          ``coro_fns``, or ``None`` if no coroutines won. If the winning
          coroutine may return None on success, *winner_index* can be used
          to definitively determine whether any coroutine won.

        - *exceptions*: list of exceptions raised by the coroutines.
          ``len(exceptions)`` is equal to the number of coroutines actually
          started, and the order is the same as in ``coro_fns``. The winning
          coroutine's entry is ``None``. Coroutines that were cancelled
          because another one won may have their ``CancelledError`` recorded
          here.

    Raises:
        SystemExit, KeyboardInterrupt: propagated from a coroutine (only
            surfaced while ``__debug__`` is true).
    """
    # TODO: when we have aiter() and anext(), allow async iterables in coro_fns.
    loop = loop or events.get_running_loop()
    enum_coro_fns = enumerate(coro_fns)
    winner_result = None
    winner_index = None
    exceptions = []
    running_tasks = []

    async def run_one_coro(
            ok_to_start: locks.Event,
            previous_failed: typing.Optional[locks.Event]) -> None:
        # Do not touch the shared bookkeeping until whoever created this task
        # has recorded it in running_tasks. With an eager task factory,
        # loop.create_task() runs this coroutine synchronously up to its first
        # suspension point, i.e. before create_task() has returned to the
        # creator. With a regular task factory the event is already set by the
        # time this runs, so Event.wait() returns without suspending.
        await ok_to_start.wait()
        # Wait for the previous task to finish, or for delay seconds
        if previous_failed is not None:
            with contextlib.suppress(exceptions_mod.TimeoutError):
                # Use asyncio.wait_for() instead of asyncio.wait() here, so
                # that if we get cancelled at this point, Event.wait() is also
                # cancelled, otherwise there will be a "Task destroyed but it is
                # pending" later.
                await tasks.wait_for(previous_failed.wait(), delay)
        # Get the next coroutine to run
        try:
            this_index, coro_fn = next(enum_coro_fns)
        except StopIteration:
            return
        # Start the task that will run the next coroutine; it is only allowed
        # to proceed once it has been appended to running_tasks.
        this_failed = locks.Event()
        next_ok_to_start = locks.Event()
        next_task = loop.create_task(
            run_one_coro(next_ok_to_start, this_failed))
        running_tasks.append(next_task)
        next_ok_to_start.set()
        assert len(running_tasks) == this_index + 2
        # Prepare place to put this coroutine's exceptions if not won
        exceptions.append(None)
        assert len(exceptions) == this_index + 1

        try:
            result = await coro_fn()
        except (SystemExit, KeyboardInterrupt):
            raise
        except BaseException as e:
            exceptions[this_index] = e
            this_failed.set()  # Kickstart the next coroutine
        else:
            # Store winner's results
            nonlocal winner_index, winner_result
            assert winner_index is None
            winner_index = this_index
            winner_result = result
            # Cancel all other tasks. We take care to not cancel the current
            # task as well. If we do so, then since there is no `await` after
            # here and CancelledError are usually thrown at one, we will
            # encounter a curious corner case where the current task will end
            # up as done() == True, cancelled() == False, exception() ==
            # asyncio.CancelledError. This behavior is specified in
            # https://bugs.python.org/issue30048
            # Compare by identity rather than list position so the current
            # task is skipped regardless of how running_tasks was filled.
            this_task = tasks.current_task(loop)
            for t in running_tasks:
                if t is not this_task:
                    t.cancel()

    try:
        # Same start handshake as inside run_one_coro: record the task first,
        # then let it run.
        ok_to_start = locks.Event()
        first_task = loop.create_task(run_one_coro(ok_to_start, None))
        running_tasks.append(first_task)
        ok_to_start.set()
        # Wait for a growing list of tasks to all finish: poor man's version of
        # curio's TaskGroup or trio's nursery. Tasks are only appended by tasks
        # that are still running, so once every listed task is done no new
        # ones can appear.
        done_count = 0
        while done_count != len(running_tasks):
            done, _ = await tasks.wait(running_tasks)
            done_count = len(done)
            # If run_one_coro raises an unhandled exception, it's probably a
            # programming error, and I want to see it.
            if __debug__:
                for d in done:
                    if not d.cancelled() and d.exception() is not None:
                        raise d.exception()
        return winner_result, winner_index, exceptions
    finally:
        # Make sure no tasks are left running if we leave this function
        for t in running_tasks:
            t.cancel()