1 # Adapted with permission from the EdgeDB project;
2 # license: PSFL.
3
4
5 import asyncio
6 import contextvars
7 import contextlib
8 from asyncio import taskgroups
9 import unittest
10
11 from test.test_asyncio.utils import await_without_task
12
13
14 # To prevent a warning "test altered the execution environment"
def tearDownModule():
    """Restore the default event loop policy after the module's tests run.

    Resetting the global policy keeps regrtest from reporting that the
    test "altered the execution environment".
    """
    default_policy = None
    asyncio.set_event_loop_policy(default_policy)
17
18
class MyExc(Exception):
    """Ordinary ``Exception`` subclass raised by the tests in this module."""
21
22
class MyBaseExc(BaseException):
    """``BaseException`` subclass that is deliberately *not* an ``Exception``.

    The tests use it to check that TaskGroup reports such errors in a
    ``BaseExceptionGroup`` rather than an ``ExceptionGroup``.
    """
25
26
27 def get_error_types(eg):
28 return {type(exc) for exc in eg.exceptions}
29
30
31 class ESC[4;38;5;81mTestTaskGroup(ESC[4;38;5;149munittestESC[4;38;5;149m.ESC[4;38;5;149mIsolatedAsyncioTestCase):
32
33 async def test_taskgroup_01(self):
34
35 async def foo1():
36 await asyncio.sleep(0.1)
37 return 42
38
39 async def foo2():
40 await asyncio.sleep(0.2)
41 return 11
42
43 async with taskgroups.TaskGroup() as g:
44 t1 = g.create_task(foo1())
45 t2 = g.create_task(foo2())
46
47 self.assertEqual(t1.result(), 42)
48 self.assertEqual(t2.result(), 11)
49
50 async def test_taskgroup_02(self):
51
52 async def foo1():
53 await asyncio.sleep(0.1)
54 return 42
55
56 async def foo2():
57 await asyncio.sleep(0.2)
58 return 11
59
60 async with taskgroups.TaskGroup() as g:
61 t1 = g.create_task(foo1())
62 await asyncio.sleep(0.15)
63 t2 = g.create_task(foo2())
64
65 self.assertEqual(t1.result(), 42)
66 self.assertEqual(t2.result(), 11)
67
68 async def test_taskgroup_03(self):
69
70 async def foo1():
71 await asyncio.sleep(1)
72 return 42
73
74 async def foo2():
75 await asyncio.sleep(0.2)
76 return 11
77
78 async with taskgroups.TaskGroup() as g:
79 t1 = g.create_task(foo1())
80 await asyncio.sleep(0.15)
81 # cancel t1 explicitly, i.e. everything should continue
82 # working as expected.
83 t1.cancel()
84
85 t2 = g.create_task(foo2())
86
87 self.assertTrue(t1.cancelled())
88 self.assertEqual(t2.result(), 11)
89
90 async def test_taskgroup_04(self):
91
92 NUM = 0
93 t2_cancel = False
94 t2 = None
95
96 async def foo1():
97 await asyncio.sleep(0.1)
98 1 / 0
99
100 async def foo2():
101 nonlocal NUM, t2_cancel
102 try:
103 await asyncio.sleep(1)
104 except asyncio.CancelledError:
105 t2_cancel = True
106 raise
107 NUM += 1
108
109 async def runner():
110 nonlocal NUM, t2
111
112 async with taskgroups.TaskGroup() as g:
113 g.create_task(foo1())
114 t2 = g.create_task(foo2())
115
116 NUM += 10
117
118 with self.assertRaises(ExceptionGroup) as cm:
119 await asyncio.create_task(runner())
120
121 self.assertEqual(get_error_types(cm.exception), {ZeroDivisionError})
122
123 self.assertEqual(NUM, 0)
124 self.assertTrue(t2_cancel)
125 self.assertTrue(t2.cancelled())
126
127 async def test_cancel_children_on_child_error(self):
128 # When a child task raises an error, the rest of the children
129 # are cancelled and the errors are gathered into an EG.
130
131 NUM = 0
132 t2_cancel = False
133 runner_cancel = False
134
135 async def foo1():
136 await asyncio.sleep(0.1)
137 1 / 0
138
139 async def foo2():
140 nonlocal NUM, t2_cancel
141 try:
142 await asyncio.sleep(5)
143 except asyncio.CancelledError:
144 t2_cancel = True
145 raise
146 NUM += 1
147
148 async def runner():
149 nonlocal NUM, runner_cancel
150
151 async with taskgroups.TaskGroup() as g:
152 g.create_task(foo1())
153 g.create_task(foo1())
154 g.create_task(foo1())
155 g.create_task(foo2())
156 try:
157 await asyncio.sleep(10)
158 except asyncio.CancelledError:
159 runner_cancel = True
160 raise
161
162 NUM += 10
163
164 # The 3 foo1 sub tasks can be racy when the host is busy - if the
165 # cancellation happens in the middle, we'll see partial sub errors here
166 with self.assertRaises(ExceptionGroup) as cm:
167 await asyncio.create_task(runner())
168
169 self.assertEqual(get_error_types(cm.exception), {ZeroDivisionError})
170 self.assertEqual(NUM, 0)
171 self.assertTrue(t2_cancel)
172 self.assertTrue(runner_cancel)
173
174 async def test_cancellation(self):
175
176 NUM = 0
177
178 async def foo():
179 nonlocal NUM
180 try:
181 await asyncio.sleep(5)
182 except asyncio.CancelledError:
183 NUM += 1
184 raise
185
186 async def runner():
187 async with taskgroups.TaskGroup() as g:
188 for _ in range(5):
189 g.create_task(foo())
190
191 r = asyncio.create_task(runner())
192 await asyncio.sleep(0.1)
193
194 self.assertFalse(r.done())
195 r.cancel()
196 with self.assertRaises(asyncio.CancelledError) as cm:
197 await r
198
199 self.assertEqual(NUM, 5)
200
201 async def test_taskgroup_07(self):
202
203 NUM = 0
204
205 async def foo():
206 nonlocal NUM
207 try:
208 await asyncio.sleep(5)
209 except asyncio.CancelledError:
210 NUM += 1
211 raise
212
213 async def runner():
214 nonlocal NUM
215 async with taskgroups.TaskGroup() as g:
216 for _ in range(5):
217 g.create_task(foo())
218
219 try:
220 await asyncio.sleep(10)
221 except asyncio.CancelledError:
222 NUM += 10
223 raise
224
225 r = asyncio.create_task(runner())
226 await asyncio.sleep(0.1)
227
228 self.assertFalse(r.done())
229 r.cancel()
230 with self.assertRaises(asyncio.CancelledError):
231 await r
232
233 self.assertEqual(NUM, 15)
234
235 async def test_taskgroup_08(self):
236
237 async def foo():
238 try:
239 await asyncio.sleep(10)
240 finally:
241 1 / 0
242
243 async def runner():
244 async with taskgroups.TaskGroup() as g:
245 for _ in range(5):
246 g.create_task(foo())
247
248 await asyncio.sleep(10)
249
250 r = asyncio.create_task(runner())
251 await asyncio.sleep(0.1)
252
253 self.assertFalse(r.done())
254 r.cancel()
255 with self.assertRaises(ExceptionGroup) as cm:
256 await r
257 self.assertEqual(get_error_types(cm.exception), {ZeroDivisionError})
258
259 async def test_taskgroup_09(self):
260
261 t1 = t2 = None
262
263 async def foo1():
264 await asyncio.sleep(1)
265 return 42
266
267 async def foo2():
268 await asyncio.sleep(2)
269 return 11
270
271 async def runner():
272 nonlocal t1, t2
273 async with taskgroups.TaskGroup() as g:
274 t1 = g.create_task(foo1())
275 t2 = g.create_task(foo2())
276 await asyncio.sleep(0.1)
277 1 / 0
278
279 try:
280 await runner()
281 except ExceptionGroup as t:
282 self.assertEqual(get_error_types(t), {ZeroDivisionError})
283 else:
284 self.fail('ExceptionGroup was not raised')
285
286 self.assertTrue(t1.cancelled())
287 self.assertTrue(t2.cancelled())
288
289 async def test_taskgroup_10(self):
290
291 t1 = t2 = None
292
293 async def foo1():
294 await asyncio.sleep(1)
295 return 42
296
297 async def foo2():
298 await asyncio.sleep(2)
299 return 11
300
301 async def runner():
302 nonlocal t1, t2
303 async with taskgroups.TaskGroup() as g:
304 t1 = g.create_task(foo1())
305 t2 = g.create_task(foo2())
306 1 / 0
307
308 try:
309 await runner()
310 except ExceptionGroup as t:
311 self.assertEqual(get_error_types(t), {ZeroDivisionError})
312 else:
313 self.fail('ExceptionGroup was not raised')
314
315 self.assertTrue(t1.cancelled())
316 self.assertTrue(t2.cancelled())
317
318 async def test_taskgroup_11(self):
319
320 async def foo():
321 try:
322 await asyncio.sleep(10)
323 finally:
324 1 / 0
325
326 async def runner():
327 async with taskgroups.TaskGroup():
328 async with taskgroups.TaskGroup() as g2:
329 for _ in range(5):
330 g2.create_task(foo())
331
332 await asyncio.sleep(10)
333
334 r = asyncio.create_task(runner())
335 await asyncio.sleep(0.1)
336
337 self.assertFalse(r.done())
338 r.cancel()
339 with self.assertRaises(ExceptionGroup) as cm:
340 await r
341
342 self.assertEqual(get_error_types(cm.exception), {ExceptionGroup})
343 self.assertEqual(get_error_types(cm.exception.exceptions[0]), {ZeroDivisionError})
344
345 async def test_taskgroup_12(self):
346
347 async def foo():
348 try:
349 await asyncio.sleep(10)
350 finally:
351 1 / 0
352
353 async def runner():
354 async with taskgroups.TaskGroup() as g1:
355 g1.create_task(asyncio.sleep(10))
356
357 async with taskgroups.TaskGroup() as g2:
358 for _ in range(5):
359 g2.create_task(foo())
360
361 await asyncio.sleep(10)
362
363 r = asyncio.create_task(runner())
364 await asyncio.sleep(0.1)
365
366 self.assertFalse(r.done())
367 r.cancel()
368 with self.assertRaises(ExceptionGroup) as cm:
369 await r
370
371 self.assertEqual(get_error_types(cm.exception), {ExceptionGroup})
372 self.assertEqual(get_error_types(cm.exception.exceptions[0]), {ZeroDivisionError})
373
374 async def test_taskgroup_13(self):
375
376 async def crash_after(t):
377 await asyncio.sleep(t)
378 raise ValueError(t)
379
380 async def runner():
381 async with taskgroups.TaskGroup() as g1:
382 g1.create_task(crash_after(0.1))
383
384 async with taskgroups.TaskGroup() as g2:
385 g2.create_task(crash_after(10))
386
387 r = asyncio.create_task(runner())
388 with self.assertRaises(ExceptionGroup) as cm:
389 await r
390
391 self.assertEqual(get_error_types(cm.exception), {ValueError})
392
393 async def test_taskgroup_14(self):
394
395 async def crash_after(t):
396 await asyncio.sleep(t)
397 raise ValueError(t)
398
399 async def runner():
400 async with taskgroups.TaskGroup() as g1:
401 g1.create_task(crash_after(10))
402
403 async with taskgroups.TaskGroup() as g2:
404 g2.create_task(crash_after(0.1))
405
406 r = asyncio.create_task(runner())
407 with self.assertRaises(ExceptionGroup) as cm:
408 await r
409
410 self.assertEqual(get_error_types(cm.exception), {ExceptionGroup})
411 self.assertEqual(get_error_types(cm.exception.exceptions[0]), {ValueError})
412
413 async def test_taskgroup_15(self):
414
415 async def crash_soon():
416 await asyncio.sleep(0.3)
417 1 / 0
418
419 async def runner():
420 async with taskgroups.TaskGroup() as g1:
421 g1.create_task(crash_soon())
422 try:
423 await asyncio.sleep(10)
424 except asyncio.CancelledError:
425 await asyncio.sleep(0.5)
426 raise
427
428 r = asyncio.create_task(runner())
429 await asyncio.sleep(0.1)
430
431 self.assertFalse(r.done())
432 r.cancel()
433 with self.assertRaises(ExceptionGroup) as cm:
434 await r
435 self.assertEqual(get_error_types(cm.exception), {ZeroDivisionError})
436
437 async def test_taskgroup_16(self):
438
439 async def crash_soon():
440 await asyncio.sleep(0.3)
441 1 / 0
442
443 async def nested_runner():
444 async with taskgroups.TaskGroup() as g1:
445 g1.create_task(crash_soon())
446 try:
447 await asyncio.sleep(10)
448 except asyncio.CancelledError:
449 await asyncio.sleep(0.5)
450 raise
451
452 async def runner():
453 t = asyncio.create_task(nested_runner())
454 await t
455
456 r = asyncio.create_task(runner())
457 await asyncio.sleep(0.1)
458
459 self.assertFalse(r.done())
460 r.cancel()
461 with self.assertRaises(ExceptionGroup) as cm:
462 await r
463 self.assertEqual(get_error_types(cm.exception), {ZeroDivisionError})
464
465 async def test_taskgroup_17(self):
466 NUM = 0
467
468 async def runner():
469 nonlocal NUM
470 async with taskgroups.TaskGroup():
471 try:
472 await asyncio.sleep(10)
473 except asyncio.CancelledError:
474 NUM += 10
475 raise
476
477 r = asyncio.create_task(runner())
478 await asyncio.sleep(0.1)
479
480 self.assertFalse(r.done())
481 r.cancel()
482 with self.assertRaises(asyncio.CancelledError):
483 await r
484
485 self.assertEqual(NUM, 10)
486
487 async def test_taskgroup_18(self):
488 NUM = 0
489
490 async def runner():
491 nonlocal NUM
492 async with taskgroups.TaskGroup():
493 try:
494 await asyncio.sleep(10)
495 except asyncio.CancelledError:
496 NUM += 10
497 # This isn't a good idea, but we have to support
498 # this weird case.
499 raise MyExc
500
501 r = asyncio.create_task(runner())
502 await asyncio.sleep(0.1)
503
504 self.assertFalse(r.done())
505 r.cancel()
506
507 try:
508 await r
509 except ExceptionGroup as t:
510 self.assertEqual(get_error_types(t),{MyExc})
511 else:
512 self.fail('ExceptionGroup was not raised')
513
514 self.assertEqual(NUM, 10)
515
516 async def test_taskgroup_19(self):
517 async def crash_soon():
518 await asyncio.sleep(0.1)
519 1 / 0
520
521 async def nested():
522 try:
523 await asyncio.sleep(10)
524 finally:
525 raise MyExc
526
527 async def runner():
528 async with taskgroups.TaskGroup() as g:
529 g.create_task(crash_soon())
530 await nested()
531
532 r = asyncio.create_task(runner())
533 try:
534 await r
535 except ExceptionGroup as t:
536 self.assertEqual(get_error_types(t), {MyExc, ZeroDivisionError})
537 else:
538 self.fail('TasgGroupError was not raised')
539
540 async def test_taskgroup_20(self):
541 async def crash_soon():
542 await asyncio.sleep(0.1)
543 1 / 0
544
545 async def nested():
546 try:
547 await asyncio.sleep(10)
548 finally:
549 raise KeyboardInterrupt
550
551 async def runner():
552 async with taskgroups.TaskGroup() as g:
553 g.create_task(crash_soon())
554 await nested()
555
556 with self.assertRaises(KeyboardInterrupt):
557 await runner()
558
559 async def test_taskgroup_20a(self):
560 async def crash_soon():
561 await asyncio.sleep(0.1)
562 1 / 0
563
564 async def nested():
565 try:
566 await asyncio.sleep(10)
567 finally:
568 raise MyBaseExc
569
570 async def runner():
571 async with taskgroups.TaskGroup() as g:
572 g.create_task(crash_soon())
573 await nested()
574
575 with self.assertRaises(BaseExceptionGroup) as cm:
576 await runner()
577
578 self.assertEqual(
579 get_error_types(cm.exception), {MyBaseExc, ZeroDivisionError}
580 )
581
582 async def _test_taskgroup_21(self):
583 # This test doesn't work as asyncio, currently, doesn't
584 # correctly propagate KeyboardInterrupt (or SystemExit) --
585 # those cause the event loop itself to crash.
586 # (Compare to the previous (passing) test -- that one raises
587 # a plain exception but raises KeyboardInterrupt in nested();
588 # this test does it the other way around.)
589
590 async def crash_soon():
591 await asyncio.sleep(0.1)
592 raise KeyboardInterrupt
593
594 async def nested():
595 try:
596 await asyncio.sleep(10)
597 finally:
598 raise TypeError
599
600 async def runner():
601 async with taskgroups.TaskGroup() as g:
602 g.create_task(crash_soon())
603 await nested()
604
605 with self.assertRaises(KeyboardInterrupt):
606 await runner()
607
608 async def test_taskgroup_21a(self):
609
610 async def crash_soon():
611 await asyncio.sleep(0.1)
612 raise MyBaseExc
613
614 async def nested():
615 try:
616 await asyncio.sleep(10)
617 finally:
618 raise TypeError
619
620 async def runner():
621 async with taskgroups.TaskGroup() as g:
622 g.create_task(crash_soon())
623 await nested()
624
625 with self.assertRaises(BaseExceptionGroup) as cm:
626 await runner()
627
628 self.assertEqual(get_error_types(cm.exception), {MyBaseExc, TypeError})
629
630 async def test_taskgroup_22(self):
631
632 async def foo1():
633 await asyncio.sleep(1)
634 return 42
635
636 async def foo2():
637 await asyncio.sleep(2)
638 return 11
639
640 async def runner():
641 async with taskgroups.TaskGroup() as g:
642 g.create_task(foo1())
643 g.create_task(foo2())
644
645 r = asyncio.create_task(runner())
646 await asyncio.sleep(0.05)
647 r.cancel()
648
649 with self.assertRaises(asyncio.CancelledError):
650 await r
651
652 async def test_taskgroup_23(self):
653
654 async def do_job(delay):
655 await asyncio.sleep(delay)
656
657 async with taskgroups.TaskGroup() as g:
658 for count in range(10):
659 await asyncio.sleep(0.1)
660 g.create_task(do_job(0.3))
661 if count == 5:
662 self.assertLess(len(g._tasks), 5)
663 await asyncio.sleep(1.35)
664 self.assertEqual(len(g._tasks), 0)
665
666 async def test_taskgroup_24(self):
667
668 async def root(g):
669 await asyncio.sleep(0.1)
670 g.create_task(coro1(0.1))
671 g.create_task(coro1(0.2))
672
673 async def coro1(delay):
674 await asyncio.sleep(delay)
675
676 async def runner():
677 async with taskgroups.TaskGroup() as g:
678 g.create_task(root(g))
679
680 await runner()
681
682 async def test_taskgroup_25(self):
683 nhydras = 0
684
685 async def hydra(g):
686 nonlocal nhydras
687 nhydras += 1
688 await asyncio.sleep(0.01)
689 g.create_task(hydra(g))
690 g.create_task(hydra(g))
691
692 async def hercules():
693 while nhydras < 10:
694 await asyncio.sleep(0.015)
695 1 / 0
696
697 async def runner():
698 async with taskgroups.TaskGroup() as g:
699 g.create_task(hydra(g))
700 g.create_task(hercules())
701
702 with self.assertRaises(ExceptionGroup) as cm:
703 await runner()
704
705 self.assertEqual(get_error_types(cm.exception), {ZeroDivisionError})
706 self.assertGreaterEqual(nhydras, 10)
707
708 async def test_taskgroup_task_name(self):
709 async def coro():
710 await asyncio.sleep(0)
711 async with taskgroups.TaskGroup() as g:
712 t = g.create_task(coro(), name="yolo")
713 self.assertEqual(t.get_name(), "yolo")
714
715 async def test_taskgroup_task_context(self):
716 cvar = contextvars.ContextVar('cvar')
717
718 async def coro(val):
719 await asyncio.sleep(0)
720 cvar.set(val)
721
722 async with taskgroups.TaskGroup() as g:
723 ctx = contextvars.copy_context()
724 self.assertIsNone(ctx.get(cvar))
725 t1 = g.create_task(coro(1), context=ctx)
726 await t1
727 self.assertEqual(1, ctx.get(cvar))
728 t2 = g.create_task(coro(2), context=ctx)
729 await t2
730 self.assertEqual(2, ctx.get(cvar))
731
732 async def test_taskgroup_no_create_task_after_failure(self):
733 async def coro1():
734 await asyncio.sleep(0.001)
735 1 / 0
736 async def coro2(g):
737 try:
738 await asyncio.sleep(1)
739 except asyncio.CancelledError:
740 with self.assertRaises(RuntimeError):
741 g.create_task(c1 := coro1())
742 # We still have to await c1 to avoid a warning
743 with self.assertRaises(ZeroDivisionError):
744 await c1
745
746 with self.assertRaises(ExceptionGroup) as cm:
747 async with taskgroups.TaskGroup() as g:
748 g.create_task(coro1())
749 g.create_task(coro2(g))
750
751 self.assertEqual(get_error_types(cm.exception), {ZeroDivisionError})
752
753 async def test_taskgroup_context_manager_exit_raises(self):
754 # See https://github.com/python/cpython/issues/95289
755 class ESC[4;38;5;81mCustomException(ESC[4;38;5;149mException):
756 pass
757
758 async def raise_exc():
759 raise CustomException
760
761 @contextlib.asynccontextmanager
762 async def database():
763 try:
764 yield
765 finally:
766 raise CustomException
767
768 async def main():
769 task = asyncio.current_task()
770 try:
771 async with taskgroups.TaskGroup() as tg:
772 async with database():
773 tg.create_task(raise_exc())
774 await asyncio.sleep(1)
775 except* CustomException as err:
776 self.assertEqual(task.cancelling(), 0)
777 self.assertEqual(len(err.exceptions), 2)
778
779 else:
780 self.fail('CustomException not raised')
781
782 await asyncio.create_task(main())
783
784 async def test_taskgroup_already_entered(self):
785 tg = taskgroups.TaskGroup()
786 async with tg:
787 with self.assertRaisesRegex(RuntimeError, "has already been entered"):
788 async with tg:
789 pass
790
791 async def test_taskgroup_double_enter(self):
792 tg = taskgroups.TaskGroup()
793 async with tg:
794 pass
795 with self.assertRaisesRegex(RuntimeError, "has already been entered"):
796 async with tg:
797 pass
798
799 async def test_taskgroup_finished(self):
800 tg = taskgroups.TaskGroup()
801 async with tg:
802 pass
803 coro = asyncio.sleep(0)
804 with self.assertRaisesRegex(RuntimeError, "is finished"):
805 tg.create_task(coro)
806 # We still have to await coro to avoid a warning
807 await coro
808
809 async def test_taskgroup_not_entered(self):
810 tg = taskgroups.TaskGroup()
811 coro = asyncio.sleep(0)
812 with self.assertRaisesRegex(RuntimeError, "has not been entered"):
813 tg.create_task(coro)
814 # We still have to await coro to avoid a warning
815 await coro
816
817 async def test_taskgroup_without_parent_task(self):
818 tg = taskgroups.TaskGroup()
819 with self.assertRaisesRegex(RuntimeError, "parent task"):
820 await await_without_task(tg.__aenter__())
821 coro = asyncio.sleep(0)
822 with self.assertRaisesRegex(RuntimeError, "has not been entered"):
823 tg.create_task(coro)
824 # We still have to await coro to avoid a warning
825 await coro
826
827
if __name__ == "__main__":
    # Run this module's tests when the file is executed as a script.
    unittest.main()