# python (3.12.0)
# Adapted with permission from the EdgeDB project;
# license: PSFL.


# Public names exported by ``from <module> import *``.
__all__ = ('TaskGroup',)
6
7 from . import events
8 from . import exceptions
9 from . import tasks
10
11
12 class ESC[4;38;5;81mTaskGroup:
13 """Asynchronous context manager for managing groups of tasks.
14
15 Example use:
16
17 async with asyncio.TaskGroup() as group:
18 task1 = group.create_task(some_coroutine(...))
19 task2 = group.create_task(other_coroutine(...))
20 print("Both tasks have completed now.")
21
22 All tasks are awaited when the context manager exits.
23
24 Any exceptions other than `asyncio.CancelledError` raised within
25 a task will cancel all remaining tasks and wait for them to exit.
26 The exceptions are then combined and raised as an `ExceptionGroup`.
27 """
28 def __init__(self):
29 self._entered = False
30 self._exiting = False
31 self._aborting = False
32 self._loop = None
33 self._parent_task = None
34 self._parent_cancel_requested = False
35 self._tasks = set()
36 self._errors = []
37 self._base_error = None
38 self._on_completed_fut = None
39
40 def __repr__(self):
41 info = ['']
42 if self._tasks:
43 info.append(f'tasks={len(self._tasks)}')
44 if self._errors:
45 info.append(f'errors={len(self._errors)}')
46 if self._aborting:
47 info.append('cancelling')
48 elif self._entered:
49 info.append('entered')
50
51 info_str = ' '.join(info)
52 return f'<TaskGroup{info_str}>'
53
54 async def __aenter__(self):
55 if self._entered:
56 raise RuntimeError(
57 f"TaskGroup {self!r} has been already entered")
58 self._entered = True
59
60 if self._loop is None:
61 self._loop = events.get_running_loop()
62
63 self._parent_task = tasks.current_task(self._loop)
64 if self._parent_task is None:
65 raise RuntimeError(
66 f'TaskGroup {self!r} cannot determine the parent task')
67
68 return self
69
70 async def __aexit__(self, et, exc, tb):
71 self._exiting = True
72
73 if (exc is not None and
74 self._is_base_error(exc) and
75 self._base_error is None):
76 self._base_error = exc
77
78 propagate_cancellation_error = \
79 exc if et is exceptions.CancelledError else None
80 if self._parent_cancel_requested:
81 # If this flag is set we *must* call uncancel().
82 if self._parent_task.uncancel() == 0:
83 # If there are no pending cancellations left,
84 # don't propagate CancelledError.
85 propagate_cancellation_error = None
86
87 if et is not None:
88 if not self._aborting:
89 # Our parent task is being cancelled:
90 #
91 # async with TaskGroup() as g:
92 # g.create_task(...)
93 # await ... # <- CancelledError
94 #
95 # or there's an exception in "async with":
96 #
97 # async with TaskGroup() as g:
98 # g.create_task(...)
99 # 1 / 0
100 #
101 self._abort()
102
103 # We use while-loop here because "self._on_completed_fut"
104 # can be cancelled multiple times if our parent task
105 # is being cancelled repeatedly (or even once, when
106 # our own cancellation is already in progress)
107 while self._tasks:
108 if self._on_completed_fut is None:
109 self._on_completed_fut = self._loop.create_future()
110
111 try:
112 await self._on_completed_fut
113 except exceptions.CancelledError as ex:
114 if not self._aborting:
115 # Our parent task is being cancelled:
116 #
117 # async def wrapper():
118 # async with TaskGroup() as g:
119 # g.create_task(foo)
120 #
121 # "wrapper" is being cancelled while "foo" is
122 # still running.
123 propagate_cancellation_error = ex
124 self._abort()
125
126 self._on_completed_fut = None
127
128 assert not self._tasks
129
130 if self._base_error is not None:
131 raise self._base_error
132
133 # Propagate CancelledError if there is one, except if there
134 # are other errors -- those have priority.
135 if propagate_cancellation_error and not self._errors:
136 raise propagate_cancellation_error
137
138 if et is not None and et is not exceptions.CancelledError:
139 self._errors.append(exc)
140
141 if self._errors:
142 # Exceptions are heavy objects that can have object
143 # cycles (bad for GC); let's not keep a reference to
144 # a bunch of them.
145 try:
146 me = BaseExceptionGroup('unhandled errors in a TaskGroup', self._errors)
147 raise me from None
148 finally:
149 self._errors = None
150
151 def create_task(self, coro, *, name=None, context=None):
152 """Create a new task in this group and return it.
153
154 Similar to `asyncio.create_task`.
155 """
156 if not self._entered:
157 raise RuntimeError(f"TaskGroup {self!r} has not been entered")
158 if self._exiting and not self._tasks:
159 raise RuntimeError(f"TaskGroup {self!r} is finished")
160 if self._aborting:
161 raise RuntimeError(f"TaskGroup {self!r} is shutting down")
162 if context is None:
163 task = self._loop.create_task(coro)
164 else:
165 task = self._loop.create_task(coro, context=context)
166 tasks._set_task_name(task, name)
167 # optimization: Immediately call the done callback if the task is
168 # already done (e.g. if the coro was able to complete eagerly),
169 # and skip scheduling a done callback
170 if task.done():
171 self._on_task_done(task)
172 else:
173 self._tasks.add(task)
174 task.add_done_callback(self._on_task_done)
175 return task
176
177 # Since Python 3.8 Tasks propagate all exceptions correctly,
178 # except for KeyboardInterrupt and SystemExit which are
179 # still considered special.
180
181 def _is_base_error(self, exc: BaseException) -> bool:
182 assert isinstance(exc, BaseException)
183 return isinstance(exc, (SystemExit, KeyboardInterrupt))
184
185 def _abort(self):
186 self._aborting = True
187
188 for t in self._tasks:
189 if not t.done():
190 t.cancel()
191
192 def _on_task_done(self, task):
193 self._tasks.discard(task)
194
195 if self._on_completed_fut is not None and not self._tasks:
196 if not self._on_completed_fut.done():
197 self._on_completed_fut.set_result(True)
198
199 if task.cancelled():
200 return
201
202 exc = task.exception()
203 if exc is None:
204 return
205
206 self._errors.append(exc)
207 if self._is_base_error(exc) and self._base_error is None:
208 self._base_error = exc
209
210 if self._parent_task.done():
211 # Not sure if this case is possible, but we want to handle
212 # it anyways.
213 self._loop.call_exception_handler({
214 'message': f'Task {task!r} has errored out but its parent '
215 f'task {self._parent_task} is already completed',
216 'exception': exc,
217 'task': task,
218 })
219 return
220
221 if not self._aborting and not self._parent_cancel_requested:
222 # If parent task *is not* being cancelled, it means that we want
223 # to manually cancel it to abort whatever is being run right now
224 # in the TaskGroup. But we want to mark parent task as
225 # "not cancelled" later in __aexit__. Example situation that
226 # we need to handle:
227 #
228 # async def foo():
229 # try:
230 # async with TaskGroup() as g:
231 # g.create_task(crash_soon())
232 # await something # <- this needs to be canceled
233 # # by the TaskGroup, e.g.
234 # # foo() needs to be cancelled
235 # except Exception:
236 # # Ignore any exceptions raised in the TaskGroup
237 # pass
238 # await something_else # this line has to be called
239 # # after TaskGroup is finished.
240 self._abort()
241 self._parent_cancel_requested = True
242 self._parent_task.cancel()