# Adapted with permission from the EdgeDB project;
# license: PSFL.
3
4
# Names exported by ``from <module> import *``.
__all__ = [
    "TaskGroup",
]
6
7 from . import events
8 from . import exceptions
9 from . import tasks
10
11
12 class ESC[4;38;5;81mTaskGroup:
13 """Asynchronous context manager for managing groups of tasks.
14
15 Example use:
16
17 async with asyncio.TaskGroup() as group:
18 task1 = group.create_task(some_coroutine(...))
19 task2 = group.create_task(other_coroutine(...))
20 print("Both tasks have completed now.")
21
22 All tasks are awaited when the context manager exits.
23
24 Any exceptions other than `asyncio.CancelledError` raised within
25 a task will cancel all remaining tasks and wait for them to exit.
26 The exceptions are then combined and raised as an `ExceptionGroup`.
27 """
28 def __init__(self):
29 self._entered = False
30 self._exiting = False
31 self._aborting = False
32 self._loop = None
33 self._parent_task = None
34 self._parent_cancel_requested = False
35 self._tasks = set()
36 self._errors = []
37 self._base_error = None
38 self._on_completed_fut = None
39
40 def __repr__(self):
41 info = ['']
42 if self._tasks:
43 info.append(f'tasks={len(self._tasks)}')
44 if self._errors:
45 info.append(f'errors={len(self._errors)}')
46 if self._aborting:
47 info.append('cancelling')
48 elif self._entered:
49 info.append('entered')
50
51 info_str = ' '.join(info)
52 return f'<TaskGroup{info_str}>'
53
54 async def __aenter__(self):
55 if self._entered:
56 raise RuntimeError(
57 f"TaskGroup {self!r} has already been entered")
58 if self._loop is None:
59 self._loop = events.get_running_loop()
60 self._parent_task = tasks.current_task(self._loop)
61 if self._parent_task is None:
62 raise RuntimeError(
63 f'TaskGroup {self!r} cannot determine the parent task')
64 self._entered = True
65
66 return self
67
68 async def __aexit__(self, et, exc, tb):
69 self._exiting = True
70
71 if (exc is not None and
72 self._is_base_error(exc) and
73 self._base_error is None):
74 self._base_error = exc
75
76 propagate_cancellation_error = \
77 exc if et is exceptions.CancelledError else None
78 if self._parent_cancel_requested:
79 # If this flag is set we *must* call uncancel().
80 if self._parent_task.uncancel() == 0:
81 # If there are no pending cancellations left,
82 # don't propagate CancelledError.
83 propagate_cancellation_error = None
84
85 if et is not None:
86 if not self._aborting:
87 # Our parent task is being cancelled:
88 #
89 # async with TaskGroup() as g:
90 # g.create_task(...)
91 # await ... # <- CancelledError
92 #
93 # or there's an exception in "async with":
94 #
95 # async with TaskGroup() as g:
96 # g.create_task(...)
97 # 1 / 0
98 #
99 self._abort()
100
101 # We use while-loop here because "self._on_completed_fut"
102 # can be cancelled multiple times if our parent task
103 # is being cancelled repeatedly (or even once, when
104 # our own cancellation is already in progress)
105 while self._tasks:
106 if self._on_completed_fut is None:
107 self._on_completed_fut = self._loop.create_future()
108
109 try:
110 await self._on_completed_fut
111 except exceptions.CancelledError as ex:
112 if not self._aborting:
113 # Our parent task is being cancelled:
114 #
115 # async def wrapper():
116 # async with TaskGroup() as g:
117 # g.create_task(foo)
118 #
119 # "wrapper" is being cancelled while "foo" is
120 # still running.
121 propagate_cancellation_error = ex
122 self._abort()
123
124 self._on_completed_fut = None
125
126 assert not self._tasks
127
128 if self._base_error is not None:
129 raise self._base_error
130
131 # Propagate CancelledError if there is one, except if there
132 # are other errors -- those have priority.
133 if propagate_cancellation_error and not self._errors:
134 raise propagate_cancellation_error
135
136 if et is not None and et is not exceptions.CancelledError:
137 self._errors.append(exc)
138
139 if self._errors:
140 # Exceptions are heavy objects that can have object
141 # cycles (bad for GC); let's not keep a reference to
142 # a bunch of them.
143 try:
144 me = BaseExceptionGroup('unhandled errors in a TaskGroup', self._errors)
145 raise me from None
146 finally:
147 self._errors = None
148
149 def create_task(self, coro, *, name=None, context=None):
150 """Create a new task in this group and return it.
151
152 Similar to `asyncio.create_task`.
153 """
154 if not self._entered:
155 raise RuntimeError(f"TaskGroup {self!r} has not been entered")
156 if self._exiting and not self._tasks:
157 raise RuntimeError(f"TaskGroup {self!r} is finished")
158 if self._aborting:
159 raise RuntimeError(f"TaskGroup {self!r} is shutting down")
160 if context is None:
161 task = self._loop.create_task(coro)
162 else:
163 task = self._loop.create_task(coro, context=context)
164 tasks._set_task_name(task, name)
165 task.add_done_callback(self._on_task_done)
166 self._tasks.add(task)
167 return task
168
169 # Since Python 3.8 Tasks propagate all exceptions correctly,
170 # except for KeyboardInterrupt and SystemExit which are
171 # still considered special.
172
173 def _is_base_error(self, exc: BaseException) -> bool:
174 assert isinstance(exc, BaseException)
175 return isinstance(exc, (SystemExit, KeyboardInterrupt))
176
177 def _abort(self):
178 self._aborting = True
179
180 for t in self._tasks:
181 if not t.done():
182 t.cancel()
183
184 def _on_task_done(self, task):
185 self._tasks.discard(task)
186
187 if self._on_completed_fut is not None and not self._tasks:
188 if not self._on_completed_fut.done():
189 self._on_completed_fut.set_result(True)
190
191 if task.cancelled():
192 return
193
194 exc = task.exception()
195 if exc is None:
196 return
197
198 self._errors.append(exc)
199 if self._is_base_error(exc) and self._base_error is None:
200 self._base_error = exc
201
202 if self._parent_task.done():
203 # Not sure if this case is possible, but we want to handle
204 # it anyways.
205 self._loop.call_exception_handler({
206 'message': f'Task {task!r} has errored out but its parent '
207 f'task {self._parent_task} is already completed',
208 'exception': exc,
209 'task': task,
210 })
211 return
212
213 if not self._aborting and not self._parent_cancel_requested:
214 # If parent task *is not* being cancelled, it means that we want
215 # to manually cancel it to abort whatever is being run right now
216 # in the TaskGroup. But we want to mark parent task as
217 # "not cancelled" later in __aexit__. Example situation that
218 # we need to handle:
219 #
220 # async def foo():
221 # try:
222 # async with TaskGroup() as g:
223 # g.create_task(crash_soon())
224 # await something # <- this needs to be canceled
225 # # by the TaskGroup, e.g.
226 # # foo() needs to be cancelled
227 # except Exception:
228 # # Ignore any exceptions raised in the TaskGroup
229 # pass
230 # await something_else # this line has to be called
231 # # after TaskGroup is finished.
232 self._abort()
233 self._parent_cancel_requested = True
234 self._parent_task.cancel()