python (3.12.0)
1 # Adapted with permission from the EdgeDB project;
2 # license: PSFL.
3
4
5 import asyncio
6 import contextvars
7 import contextlib
8 from asyncio import taskgroups
9 import unittest
10
11
# unittest invokes this module-level hook once, after all tests here have run.
def tearDownModule():
    """Put the default asyncio event loop policy back in place.

    Resetting the policy avoids the regression-test warning
    "test altered the execution environment".
    """
    default_policy = None  # passing None restores the default policy
    asyncio.set_event_loop_policy(default_policy)
15
16
class MyExc(Exception):
    """Plain ``Exception`` subclass that tests raise as a recognisable error."""
19
20
class MyBaseExc(BaseException):
    """``BaseException`` subclass that is deliberately *not* an ``Exception``."""
23
24
25 def get_error_types(eg):
26 return {type(exc) for exc in eg.exceptions}
27
28
29 class ESC[4;38;5;81mTestTaskGroup(ESC[4;38;5;149munittestESC[4;38;5;149m.ESC[4;38;5;149mIsolatedAsyncioTestCase):
30
31 async def test_taskgroup_01(self):
32
33 async def foo1():
34 await asyncio.sleep(0.1)
35 return 42
36
37 async def foo2():
38 await asyncio.sleep(0.2)
39 return 11
40
41 async with taskgroups.TaskGroup() as g:
42 t1 = g.create_task(foo1())
43 t2 = g.create_task(foo2())
44
45 self.assertEqual(t1.result(), 42)
46 self.assertEqual(t2.result(), 11)
47
48 async def test_taskgroup_02(self):
49
50 async def foo1():
51 await asyncio.sleep(0.1)
52 return 42
53
54 async def foo2():
55 await asyncio.sleep(0.2)
56 return 11
57
58 async with taskgroups.TaskGroup() as g:
59 t1 = g.create_task(foo1())
60 await asyncio.sleep(0.15)
61 t2 = g.create_task(foo2())
62
63 self.assertEqual(t1.result(), 42)
64 self.assertEqual(t2.result(), 11)
65
66 async def test_taskgroup_03(self):
67
68 async def foo1():
69 await asyncio.sleep(1)
70 return 42
71
72 async def foo2():
73 await asyncio.sleep(0.2)
74 return 11
75
76 async with taskgroups.TaskGroup() as g:
77 t1 = g.create_task(foo1())
78 await asyncio.sleep(0.15)
79 # cancel t1 explicitly, i.e. everything should continue
80 # working as expected.
81 t1.cancel()
82
83 t2 = g.create_task(foo2())
84
85 self.assertTrue(t1.cancelled())
86 self.assertEqual(t2.result(), 11)
87
88 async def test_taskgroup_04(self):
89
90 NUM = 0
91 t2_cancel = False
92 t2 = None
93
94 async def foo1():
95 await asyncio.sleep(0.1)
96 1 / 0
97
98 async def foo2():
99 nonlocal NUM, t2_cancel
100 try:
101 await asyncio.sleep(1)
102 except asyncio.CancelledError:
103 t2_cancel = True
104 raise
105 NUM += 1
106
107 async def runner():
108 nonlocal NUM, t2
109
110 async with taskgroups.TaskGroup() as g:
111 g.create_task(foo1())
112 t2 = g.create_task(foo2())
113
114 NUM += 10
115
116 with self.assertRaises(ExceptionGroup) as cm:
117 await asyncio.create_task(runner())
118
119 self.assertEqual(get_error_types(cm.exception), {ZeroDivisionError})
120
121 self.assertEqual(NUM, 0)
122 self.assertTrue(t2_cancel)
123 self.assertTrue(t2.cancelled())
124
125 async def test_cancel_children_on_child_error(self):
126 # When a child task raises an error, the rest of the children
127 # are cancelled and the errors are gathered into an EG.
128
129 NUM = 0
130 t2_cancel = False
131 runner_cancel = False
132
133 async def foo1():
134 await asyncio.sleep(0.1)
135 1 / 0
136
137 async def foo2():
138 nonlocal NUM, t2_cancel
139 try:
140 await asyncio.sleep(5)
141 except asyncio.CancelledError:
142 t2_cancel = True
143 raise
144 NUM += 1
145
146 async def runner():
147 nonlocal NUM, runner_cancel
148
149 async with taskgroups.TaskGroup() as g:
150 g.create_task(foo1())
151 g.create_task(foo1())
152 g.create_task(foo1())
153 g.create_task(foo2())
154 try:
155 await asyncio.sleep(10)
156 except asyncio.CancelledError:
157 runner_cancel = True
158 raise
159
160 NUM += 10
161
162 # The 3 foo1 sub tasks can be racy when the host is busy - if the
163 # cancellation happens in the middle, we'll see partial sub errors here
164 with self.assertRaises(ExceptionGroup) as cm:
165 await asyncio.create_task(runner())
166
167 self.assertEqual(get_error_types(cm.exception), {ZeroDivisionError})
168 self.assertEqual(NUM, 0)
169 self.assertTrue(t2_cancel)
170 self.assertTrue(runner_cancel)
171
172 async def test_cancellation(self):
173
174 NUM = 0
175
176 async def foo():
177 nonlocal NUM
178 try:
179 await asyncio.sleep(5)
180 except asyncio.CancelledError:
181 NUM += 1
182 raise
183
184 async def runner():
185 async with taskgroups.TaskGroup() as g:
186 for _ in range(5):
187 g.create_task(foo())
188
189 r = asyncio.create_task(runner())
190 await asyncio.sleep(0.1)
191
192 self.assertFalse(r.done())
193 r.cancel()
194 with self.assertRaises(asyncio.CancelledError) as cm:
195 await r
196
197 self.assertEqual(NUM, 5)
198
199 async def test_taskgroup_07(self):
200
201 NUM = 0
202
203 async def foo():
204 nonlocal NUM
205 try:
206 await asyncio.sleep(5)
207 except asyncio.CancelledError:
208 NUM += 1
209 raise
210
211 async def runner():
212 nonlocal NUM
213 async with taskgroups.TaskGroup() as g:
214 for _ in range(5):
215 g.create_task(foo())
216
217 try:
218 await asyncio.sleep(10)
219 except asyncio.CancelledError:
220 NUM += 10
221 raise
222
223 r = asyncio.create_task(runner())
224 await asyncio.sleep(0.1)
225
226 self.assertFalse(r.done())
227 r.cancel()
228 with self.assertRaises(asyncio.CancelledError):
229 await r
230
231 self.assertEqual(NUM, 15)
232
233 async def test_taskgroup_08(self):
234
235 async def foo():
236 try:
237 await asyncio.sleep(10)
238 finally:
239 1 / 0
240
241 async def runner():
242 async with taskgroups.TaskGroup() as g:
243 for _ in range(5):
244 g.create_task(foo())
245
246 await asyncio.sleep(10)
247
248 r = asyncio.create_task(runner())
249 await asyncio.sleep(0.1)
250
251 self.assertFalse(r.done())
252 r.cancel()
253 with self.assertRaises(ExceptionGroup) as cm:
254 await r
255 self.assertEqual(get_error_types(cm.exception), {ZeroDivisionError})
256
257 async def test_taskgroup_09(self):
258
259 t1 = t2 = None
260
261 async def foo1():
262 await asyncio.sleep(1)
263 return 42
264
265 async def foo2():
266 await asyncio.sleep(2)
267 return 11
268
269 async def runner():
270 nonlocal t1, t2
271 async with taskgroups.TaskGroup() as g:
272 t1 = g.create_task(foo1())
273 t2 = g.create_task(foo2())
274 await asyncio.sleep(0.1)
275 1 / 0
276
277 try:
278 await runner()
279 except ExceptionGroup as t:
280 self.assertEqual(get_error_types(t), {ZeroDivisionError})
281 else:
282 self.fail('ExceptionGroup was not raised')
283
284 self.assertTrue(t1.cancelled())
285 self.assertTrue(t2.cancelled())
286
287 async def test_taskgroup_10(self):
288
289 t1 = t2 = None
290
291 async def foo1():
292 await asyncio.sleep(1)
293 return 42
294
295 async def foo2():
296 await asyncio.sleep(2)
297 return 11
298
299 async def runner():
300 nonlocal t1, t2
301 async with taskgroups.TaskGroup() as g:
302 t1 = g.create_task(foo1())
303 t2 = g.create_task(foo2())
304 1 / 0
305
306 try:
307 await runner()
308 except ExceptionGroup as t:
309 self.assertEqual(get_error_types(t), {ZeroDivisionError})
310 else:
311 self.fail('ExceptionGroup was not raised')
312
313 self.assertTrue(t1.cancelled())
314 self.assertTrue(t2.cancelled())
315
316 async def test_taskgroup_11(self):
317
318 async def foo():
319 try:
320 await asyncio.sleep(10)
321 finally:
322 1 / 0
323
324 async def runner():
325 async with taskgroups.TaskGroup():
326 async with taskgroups.TaskGroup() as g2:
327 for _ in range(5):
328 g2.create_task(foo())
329
330 await asyncio.sleep(10)
331
332 r = asyncio.create_task(runner())
333 await asyncio.sleep(0.1)
334
335 self.assertFalse(r.done())
336 r.cancel()
337 with self.assertRaises(ExceptionGroup) as cm:
338 await r
339
340 self.assertEqual(get_error_types(cm.exception), {ExceptionGroup})
341 self.assertEqual(get_error_types(cm.exception.exceptions[0]), {ZeroDivisionError})
342
343 async def test_taskgroup_12(self):
344
345 async def foo():
346 try:
347 await asyncio.sleep(10)
348 finally:
349 1 / 0
350
351 async def runner():
352 async with taskgroups.TaskGroup() as g1:
353 g1.create_task(asyncio.sleep(10))
354
355 async with taskgroups.TaskGroup() as g2:
356 for _ in range(5):
357 g2.create_task(foo())
358
359 await asyncio.sleep(10)
360
361 r = asyncio.create_task(runner())
362 await asyncio.sleep(0.1)
363
364 self.assertFalse(r.done())
365 r.cancel()
366 with self.assertRaises(ExceptionGroup) as cm:
367 await r
368
369 self.assertEqual(get_error_types(cm.exception), {ExceptionGroup})
370 self.assertEqual(get_error_types(cm.exception.exceptions[0]), {ZeroDivisionError})
371
372 async def test_taskgroup_13(self):
373
374 async def crash_after(t):
375 await asyncio.sleep(t)
376 raise ValueError(t)
377
378 async def runner():
379 async with taskgroups.TaskGroup() as g1:
380 g1.create_task(crash_after(0.1))
381
382 async with taskgroups.TaskGroup() as g2:
383 g2.create_task(crash_after(10))
384
385 r = asyncio.create_task(runner())
386 with self.assertRaises(ExceptionGroup) as cm:
387 await r
388
389 self.assertEqual(get_error_types(cm.exception), {ValueError})
390
391 async def test_taskgroup_14(self):
392
393 async def crash_after(t):
394 await asyncio.sleep(t)
395 raise ValueError(t)
396
397 async def runner():
398 async with taskgroups.TaskGroup() as g1:
399 g1.create_task(crash_after(10))
400
401 async with taskgroups.TaskGroup() as g2:
402 g2.create_task(crash_after(0.1))
403
404 r = asyncio.create_task(runner())
405 with self.assertRaises(ExceptionGroup) as cm:
406 await r
407
408 self.assertEqual(get_error_types(cm.exception), {ExceptionGroup})
409 self.assertEqual(get_error_types(cm.exception.exceptions[0]), {ValueError})
410
411 async def test_taskgroup_15(self):
412
413 async def crash_soon():
414 await asyncio.sleep(0.3)
415 1 / 0
416
417 async def runner():
418 async with taskgroups.TaskGroup() as g1:
419 g1.create_task(crash_soon())
420 try:
421 await asyncio.sleep(10)
422 except asyncio.CancelledError:
423 await asyncio.sleep(0.5)
424 raise
425
426 r = asyncio.create_task(runner())
427 await asyncio.sleep(0.1)
428
429 self.assertFalse(r.done())
430 r.cancel()
431 with self.assertRaises(ExceptionGroup) as cm:
432 await r
433 self.assertEqual(get_error_types(cm.exception), {ZeroDivisionError})
434
435 async def test_taskgroup_16(self):
436
437 async def crash_soon():
438 await asyncio.sleep(0.3)
439 1 / 0
440
441 async def nested_runner():
442 async with taskgroups.TaskGroup() as g1:
443 g1.create_task(crash_soon())
444 try:
445 await asyncio.sleep(10)
446 except asyncio.CancelledError:
447 await asyncio.sleep(0.5)
448 raise
449
450 async def runner():
451 t = asyncio.create_task(nested_runner())
452 await t
453
454 r = asyncio.create_task(runner())
455 await asyncio.sleep(0.1)
456
457 self.assertFalse(r.done())
458 r.cancel()
459 with self.assertRaises(ExceptionGroup) as cm:
460 await r
461 self.assertEqual(get_error_types(cm.exception), {ZeroDivisionError})
462
463 async def test_taskgroup_17(self):
464 NUM = 0
465
466 async def runner():
467 nonlocal NUM
468 async with taskgroups.TaskGroup():
469 try:
470 await asyncio.sleep(10)
471 except asyncio.CancelledError:
472 NUM += 10
473 raise
474
475 r = asyncio.create_task(runner())
476 await asyncio.sleep(0.1)
477
478 self.assertFalse(r.done())
479 r.cancel()
480 with self.assertRaises(asyncio.CancelledError):
481 await r
482
483 self.assertEqual(NUM, 10)
484
485 async def test_taskgroup_18(self):
486 NUM = 0
487
488 async def runner():
489 nonlocal NUM
490 async with taskgroups.TaskGroup():
491 try:
492 await asyncio.sleep(10)
493 except asyncio.CancelledError:
494 NUM += 10
495 # This isn't a good idea, but we have to support
496 # this weird case.
497 raise MyExc
498
499 r = asyncio.create_task(runner())
500 await asyncio.sleep(0.1)
501
502 self.assertFalse(r.done())
503 r.cancel()
504
505 try:
506 await r
507 except ExceptionGroup as t:
508 self.assertEqual(get_error_types(t),{MyExc})
509 else:
510 self.fail('ExceptionGroup was not raised')
511
512 self.assertEqual(NUM, 10)
513
514 async def test_taskgroup_19(self):
515 async def crash_soon():
516 await asyncio.sleep(0.1)
517 1 / 0
518
519 async def nested():
520 try:
521 await asyncio.sleep(10)
522 finally:
523 raise MyExc
524
525 async def runner():
526 async with taskgroups.TaskGroup() as g:
527 g.create_task(crash_soon())
528 await nested()
529
530 r = asyncio.create_task(runner())
531 try:
532 await r
533 except ExceptionGroup as t:
534 self.assertEqual(get_error_types(t), {MyExc, ZeroDivisionError})
535 else:
536 self.fail('TasgGroupError was not raised')
537
538 async def test_taskgroup_20(self):
539 async def crash_soon():
540 await asyncio.sleep(0.1)
541 1 / 0
542
543 async def nested():
544 try:
545 await asyncio.sleep(10)
546 finally:
547 raise KeyboardInterrupt
548
549 async def runner():
550 async with taskgroups.TaskGroup() as g:
551 g.create_task(crash_soon())
552 await nested()
553
554 with self.assertRaises(KeyboardInterrupt):
555 await runner()
556
557 async def test_taskgroup_20a(self):
558 async def crash_soon():
559 await asyncio.sleep(0.1)
560 1 / 0
561
562 async def nested():
563 try:
564 await asyncio.sleep(10)
565 finally:
566 raise MyBaseExc
567
568 async def runner():
569 async with taskgroups.TaskGroup() as g:
570 g.create_task(crash_soon())
571 await nested()
572
573 with self.assertRaises(BaseExceptionGroup) as cm:
574 await runner()
575
576 self.assertEqual(
577 get_error_types(cm.exception), {MyBaseExc, ZeroDivisionError}
578 )
579
580 async def _test_taskgroup_21(self):
581 # This test doesn't work as asyncio, currently, doesn't
582 # correctly propagate KeyboardInterrupt (or SystemExit) --
583 # those cause the event loop itself to crash.
584 # (Compare to the previous (passing) test -- that one raises
585 # a plain exception but raises KeyboardInterrupt in nested();
586 # this test does it the other way around.)
587
588 async def crash_soon():
589 await asyncio.sleep(0.1)
590 raise KeyboardInterrupt
591
592 async def nested():
593 try:
594 await asyncio.sleep(10)
595 finally:
596 raise TypeError
597
598 async def runner():
599 async with taskgroups.TaskGroup() as g:
600 g.create_task(crash_soon())
601 await nested()
602
603 with self.assertRaises(KeyboardInterrupt):
604 await runner()
605
606 async def test_taskgroup_21a(self):
607
608 async def crash_soon():
609 await asyncio.sleep(0.1)
610 raise MyBaseExc
611
612 async def nested():
613 try:
614 await asyncio.sleep(10)
615 finally:
616 raise TypeError
617
618 async def runner():
619 async with taskgroups.TaskGroup() as g:
620 g.create_task(crash_soon())
621 await nested()
622
623 with self.assertRaises(BaseExceptionGroup) as cm:
624 await runner()
625
626 self.assertEqual(get_error_types(cm.exception), {MyBaseExc, TypeError})
627
628 async def test_taskgroup_22(self):
629
630 async def foo1():
631 await asyncio.sleep(1)
632 return 42
633
634 async def foo2():
635 await asyncio.sleep(2)
636 return 11
637
638 async def runner():
639 async with taskgroups.TaskGroup() as g:
640 g.create_task(foo1())
641 g.create_task(foo2())
642
643 r = asyncio.create_task(runner())
644 await asyncio.sleep(0.05)
645 r.cancel()
646
647 with self.assertRaises(asyncio.CancelledError):
648 await r
649
650 async def test_taskgroup_23(self):
651
652 async def do_job(delay):
653 await asyncio.sleep(delay)
654
655 async with taskgroups.TaskGroup() as g:
656 for count in range(10):
657 await asyncio.sleep(0.1)
658 g.create_task(do_job(0.3))
659 if count == 5:
660 self.assertLess(len(g._tasks), 5)
661 await asyncio.sleep(1.35)
662 self.assertEqual(len(g._tasks), 0)
663
664 async def test_taskgroup_24(self):
665
666 async def root(g):
667 await asyncio.sleep(0.1)
668 g.create_task(coro1(0.1))
669 g.create_task(coro1(0.2))
670
671 async def coro1(delay):
672 await asyncio.sleep(delay)
673
674 async def runner():
675 async with taskgroups.TaskGroup() as g:
676 g.create_task(root(g))
677
678 await runner()
679
680 async def test_taskgroup_25(self):
681 nhydras = 0
682
683 async def hydra(g):
684 nonlocal nhydras
685 nhydras += 1
686 await asyncio.sleep(0.01)
687 g.create_task(hydra(g))
688 g.create_task(hydra(g))
689
690 async def hercules():
691 while nhydras < 10:
692 await asyncio.sleep(0.015)
693 1 / 0
694
695 async def runner():
696 async with taskgroups.TaskGroup() as g:
697 g.create_task(hydra(g))
698 g.create_task(hercules())
699
700 with self.assertRaises(ExceptionGroup) as cm:
701 await runner()
702
703 self.assertEqual(get_error_types(cm.exception), {ZeroDivisionError})
704 self.assertGreaterEqual(nhydras, 10)
705
706 async def test_taskgroup_task_name(self):
707 async def coro():
708 await asyncio.sleep(0)
709 async with taskgroups.TaskGroup() as g:
710 t = g.create_task(coro(), name="yolo")
711 self.assertEqual(t.get_name(), "yolo")
712
713 async def test_taskgroup_task_context(self):
714 cvar = contextvars.ContextVar('cvar')
715
716 async def coro(val):
717 await asyncio.sleep(0)
718 cvar.set(val)
719
720 async with taskgroups.TaskGroup() as g:
721 ctx = contextvars.copy_context()
722 self.assertIsNone(ctx.get(cvar))
723 t1 = g.create_task(coro(1), context=ctx)
724 await t1
725 self.assertEqual(1, ctx.get(cvar))
726 t2 = g.create_task(coro(2), context=ctx)
727 await t2
728 self.assertEqual(2, ctx.get(cvar))
729
730 async def test_taskgroup_no_create_task_after_failure(self):
731 async def coro1():
732 await asyncio.sleep(0.001)
733 1 / 0
734 async def coro2(g):
735 try:
736 await asyncio.sleep(1)
737 except asyncio.CancelledError:
738 with self.assertRaises(RuntimeError):
739 g.create_task(c1 := coro1())
740 # We still have to await c1 to avoid a warning
741 with self.assertRaises(ZeroDivisionError):
742 await c1
743
744 with self.assertRaises(ExceptionGroup) as cm:
745 async with taskgroups.TaskGroup() as g:
746 g.create_task(coro1())
747 g.create_task(coro2(g))
748
749 self.assertEqual(get_error_types(cm.exception), {ZeroDivisionError})
750
751 async def test_taskgroup_context_manager_exit_raises(self):
752 # See https://github.com/python/cpython/issues/95289
753 class ESC[4;38;5;81mCustomException(ESC[4;38;5;149mException):
754 pass
755
756 async def raise_exc():
757 raise CustomException
758
759 @contextlib.asynccontextmanager
760 async def database():
761 try:
762 yield
763 finally:
764 raise CustomException
765
766 async def main():
767 task = asyncio.current_task()
768 try:
769 async with taskgroups.TaskGroup() as tg:
770 async with database():
771 tg.create_task(raise_exc())
772 await asyncio.sleep(1)
773 except* CustomException as err:
774 self.assertEqual(task.cancelling(), 0)
775 self.assertEqual(len(err.exceptions), 2)
776
777 else:
778 self.fail('CustomException not raised')
779
780 await asyncio.create_task(main())
781
782
if __name__ == "__main__":
    # Run directly rather than imported by a test runner: load and run the
    # test cases defined in this module.
    unittest.main(module="__main__")