(root)/
Python-3.12.0/
Lib/
test/
test_asyncio/
test_taskgroups.py
       1  # Adapted with permission from the EdgeDB project;
       2  # license: PSFL.
       3  
       4  
       5  import asyncio
       6  import contextvars
       7  import contextlib
       8  from asyncio import taskgroups
       9  import unittest
      10  
      11  
      12  # To prevent a warning "test altered the execution environment"
      13  def tearDownModule():
      14      asyncio.set_event_loop_policy(None)
      15  
      16  
      17  class ESC[4;38;5;81mMyExc(ESC[4;38;5;149mException):
      18      pass
      19  
      20  
      21  class ESC[4;38;5;81mMyBaseExc(ESC[4;38;5;149mBaseException):
      22      pass
      23  
      24  
      25  def get_error_types(eg):
      26      return {type(exc) for exc in eg.exceptions}
      27  
      28  
      29  class ESC[4;38;5;81mTestTaskGroup(ESC[4;38;5;149munittestESC[4;38;5;149m.ESC[4;38;5;149mIsolatedAsyncioTestCase):
      30  
      31      async def test_taskgroup_01(self):
      32  
      33          async def foo1():
      34              await asyncio.sleep(0.1)
      35              return 42
      36  
      37          async def foo2():
      38              await asyncio.sleep(0.2)
      39              return 11
      40  
      41          async with taskgroups.TaskGroup() as g:
      42              t1 = g.create_task(foo1())
      43              t2 = g.create_task(foo2())
      44  
      45          self.assertEqual(t1.result(), 42)
      46          self.assertEqual(t2.result(), 11)
      47  
      48      async def test_taskgroup_02(self):
      49  
      50          async def foo1():
      51              await asyncio.sleep(0.1)
      52              return 42
      53  
      54          async def foo2():
      55              await asyncio.sleep(0.2)
      56              return 11
      57  
      58          async with taskgroups.TaskGroup() as g:
      59              t1 = g.create_task(foo1())
      60              await asyncio.sleep(0.15)
      61              t2 = g.create_task(foo2())
      62  
      63          self.assertEqual(t1.result(), 42)
      64          self.assertEqual(t2.result(), 11)
      65  
      66      async def test_taskgroup_03(self):
      67  
      68          async def foo1():
      69              await asyncio.sleep(1)
      70              return 42
      71  
      72          async def foo2():
      73              await asyncio.sleep(0.2)
      74              return 11
      75  
      76          async with taskgroups.TaskGroup() as g:
      77              t1 = g.create_task(foo1())
      78              await asyncio.sleep(0.15)
      79              # cancel t1 explicitly, i.e. everything should continue
      80              # working as expected.
      81              t1.cancel()
      82  
      83              t2 = g.create_task(foo2())
      84  
      85          self.assertTrue(t1.cancelled())
      86          self.assertEqual(t2.result(), 11)
      87  
      88      async def test_taskgroup_04(self):
      89  
      90          NUM = 0
      91          t2_cancel = False
      92          t2 = None
      93  
      94          async def foo1():
      95              await asyncio.sleep(0.1)
      96              1 / 0
      97  
      98          async def foo2():
      99              nonlocal NUM, t2_cancel
     100              try:
     101                  await asyncio.sleep(1)
     102              except asyncio.CancelledError:
     103                  t2_cancel = True
     104                  raise
     105              NUM += 1
     106  
     107          async def runner():
     108              nonlocal NUM, t2
     109  
     110              async with taskgroups.TaskGroup() as g:
     111                  g.create_task(foo1())
     112                  t2 = g.create_task(foo2())
     113  
     114              NUM += 10
     115  
     116          with self.assertRaises(ExceptionGroup) as cm:
     117              await asyncio.create_task(runner())
     118  
     119          self.assertEqual(get_error_types(cm.exception), {ZeroDivisionError})
     120  
     121          self.assertEqual(NUM, 0)
     122          self.assertTrue(t2_cancel)
     123          self.assertTrue(t2.cancelled())
     124  
     125      async def test_cancel_children_on_child_error(self):
     126          # When a child task raises an error, the rest of the children
     127          # are cancelled and the errors are gathered into an EG.
     128  
     129          NUM = 0
     130          t2_cancel = False
     131          runner_cancel = False
     132  
     133          async def foo1():
     134              await asyncio.sleep(0.1)
     135              1 / 0
     136  
     137          async def foo2():
     138              nonlocal NUM, t2_cancel
     139              try:
     140                  await asyncio.sleep(5)
     141              except asyncio.CancelledError:
     142                  t2_cancel = True
     143                  raise
     144              NUM += 1
     145  
     146          async def runner():
     147              nonlocal NUM, runner_cancel
     148  
     149              async with taskgroups.TaskGroup() as g:
     150                  g.create_task(foo1())
     151                  g.create_task(foo1())
     152                  g.create_task(foo1())
     153                  g.create_task(foo2())
     154                  try:
     155                      await asyncio.sleep(10)
     156                  except asyncio.CancelledError:
     157                      runner_cancel = True
     158                      raise
     159  
     160              NUM += 10
     161  
     162          # The 3 foo1 sub tasks can be racy when the host is busy - if the
     163          # cancellation happens in the middle, we'll see partial sub errors here
     164          with self.assertRaises(ExceptionGroup) as cm:
     165              await asyncio.create_task(runner())
     166  
     167          self.assertEqual(get_error_types(cm.exception), {ZeroDivisionError})
     168          self.assertEqual(NUM, 0)
     169          self.assertTrue(t2_cancel)
     170          self.assertTrue(runner_cancel)
     171  
     172      async def test_cancellation(self):
     173  
     174          NUM = 0
     175  
     176          async def foo():
     177              nonlocal NUM
     178              try:
     179                  await asyncio.sleep(5)
     180              except asyncio.CancelledError:
     181                  NUM += 1
     182                  raise
     183  
     184          async def runner():
     185              async with taskgroups.TaskGroup() as g:
     186                  for _ in range(5):
     187                      g.create_task(foo())
     188  
     189          r = asyncio.create_task(runner())
     190          await asyncio.sleep(0.1)
     191  
     192          self.assertFalse(r.done())
     193          r.cancel()
     194          with self.assertRaises(asyncio.CancelledError) as cm:
     195              await r
     196  
     197          self.assertEqual(NUM, 5)
     198  
     199      async def test_taskgroup_07(self):
     200  
     201          NUM = 0
     202  
     203          async def foo():
     204              nonlocal NUM
     205              try:
     206                  await asyncio.sleep(5)
     207              except asyncio.CancelledError:
     208                  NUM += 1
     209                  raise
     210  
     211          async def runner():
     212              nonlocal NUM
     213              async with taskgroups.TaskGroup() as g:
     214                  for _ in range(5):
     215                      g.create_task(foo())
     216  
     217                  try:
     218                      await asyncio.sleep(10)
     219                  except asyncio.CancelledError:
     220                      NUM += 10
     221                      raise
     222  
     223          r = asyncio.create_task(runner())
     224          await asyncio.sleep(0.1)
     225  
     226          self.assertFalse(r.done())
     227          r.cancel()
     228          with self.assertRaises(asyncio.CancelledError):
     229              await r
     230  
     231          self.assertEqual(NUM, 15)
     232  
     233      async def test_taskgroup_08(self):
     234  
     235          async def foo():
     236              try:
     237                  await asyncio.sleep(10)
     238              finally:
     239                  1 / 0
     240  
     241          async def runner():
     242              async with taskgroups.TaskGroup() as g:
     243                  for _ in range(5):
     244                      g.create_task(foo())
     245  
     246                  await asyncio.sleep(10)
     247  
     248          r = asyncio.create_task(runner())
     249          await asyncio.sleep(0.1)
     250  
     251          self.assertFalse(r.done())
     252          r.cancel()
     253          with self.assertRaises(ExceptionGroup) as cm:
     254              await r
     255          self.assertEqual(get_error_types(cm.exception), {ZeroDivisionError})
     256  
     257      async def test_taskgroup_09(self):
     258  
     259          t1 = t2 = None
     260  
     261          async def foo1():
     262              await asyncio.sleep(1)
     263              return 42
     264  
     265          async def foo2():
     266              await asyncio.sleep(2)
     267              return 11
     268  
     269          async def runner():
     270              nonlocal t1, t2
     271              async with taskgroups.TaskGroup() as g:
     272                  t1 = g.create_task(foo1())
     273                  t2 = g.create_task(foo2())
     274                  await asyncio.sleep(0.1)
     275                  1 / 0
     276  
     277          try:
     278              await runner()
     279          except ExceptionGroup as t:
     280              self.assertEqual(get_error_types(t), {ZeroDivisionError})
     281          else:
     282              self.fail('ExceptionGroup was not raised')
     283  
     284          self.assertTrue(t1.cancelled())
     285          self.assertTrue(t2.cancelled())
     286  
     287      async def test_taskgroup_10(self):
     288  
     289          t1 = t2 = None
     290  
     291          async def foo1():
     292              await asyncio.sleep(1)
     293              return 42
     294  
     295          async def foo2():
     296              await asyncio.sleep(2)
     297              return 11
     298  
     299          async def runner():
     300              nonlocal t1, t2
     301              async with taskgroups.TaskGroup() as g:
     302                  t1 = g.create_task(foo1())
     303                  t2 = g.create_task(foo2())
     304                  1 / 0
     305  
     306          try:
     307              await runner()
     308          except ExceptionGroup as t:
     309              self.assertEqual(get_error_types(t), {ZeroDivisionError})
     310          else:
     311              self.fail('ExceptionGroup was not raised')
     312  
     313          self.assertTrue(t1.cancelled())
     314          self.assertTrue(t2.cancelled())
     315  
     316      async def test_taskgroup_11(self):
     317  
     318          async def foo():
     319              try:
     320                  await asyncio.sleep(10)
     321              finally:
     322                  1 / 0
     323  
     324          async def runner():
     325              async with taskgroups.TaskGroup():
     326                  async with taskgroups.TaskGroup() as g2:
     327                      for _ in range(5):
     328                          g2.create_task(foo())
     329  
     330                      await asyncio.sleep(10)
     331  
     332          r = asyncio.create_task(runner())
     333          await asyncio.sleep(0.1)
     334  
     335          self.assertFalse(r.done())
     336          r.cancel()
     337          with self.assertRaises(ExceptionGroup) as cm:
     338              await r
     339  
     340          self.assertEqual(get_error_types(cm.exception), {ExceptionGroup})
     341          self.assertEqual(get_error_types(cm.exception.exceptions[0]), {ZeroDivisionError})
     342  
     343      async def test_taskgroup_12(self):
     344  
     345          async def foo():
     346              try:
     347                  await asyncio.sleep(10)
     348              finally:
     349                  1 / 0
     350  
     351          async def runner():
     352              async with taskgroups.TaskGroup() as g1:
     353                  g1.create_task(asyncio.sleep(10))
     354  
     355                  async with taskgroups.TaskGroup() as g2:
     356                      for _ in range(5):
     357                          g2.create_task(foo())
     358  
     359                      await asyncio.sleep(10)
     360  
     361          r = asyncio.create_task(runner())
     362          await asyncio.sleep(0.1)
     363  
     364          self.assertFalse(r.done())
     365          r.cancel()
     366          with self.assertRaises(ExceptionGroup) as cm:
     367              await r
     368  
     369          self.assertEqual(get_error_types(cm.exception), {ExceptionGroup})
     370          self.assertEqual(get_error_types(cm.exception.exceptions[0]), {ZeroDivisionError})
     371  
     372      async def test_taskgroup_13(self):
     373  
     374          async def crash_after(t):
     375              await asyncio.sleep(t)
     376              raise ValueError(t)
     377  
     378          async def runner():
     379              async with taskgroups.TaskGroup() as g1:
     380                  g1.create_task(crash_after(0.1))
     381  
     382                  async with taskgroups.TaskGroup() as g2:
     383                      g2.create_task(crash_after(10))
     384  
     385          r = asyncio.create_task(runner())
     386          with self.assertRaises(ExceptionGroup) as cm:
     387              await r
     388  
     389          self.assertEqual(get_error_types(cm.exception), {ValueError})
     390  
     391      async def test_taskgroup_14(self):
     392  
     393          async def crash_after(t):
     394              await asyncio.sleep(t)
     395              raise ValueError(t)
     396  
     397          async def runner():
     398              async with taskgroups.TaskGroup() as g1:
     399                  g1.create_task(crash_after(10))
     400  
     401                  async with taskgroups.TaskGroup() as g2:
     402                      g2.create_task(crash_after(0.1))
     403  
     404          r = asyncio.create_task(runner())
     405          with self.assertRaises(ExceptionGroup) as cm:
     406              await r
     407  
     408          self.assertEqual(get_error_types(cm.exception), {ExceptionGroup})
     409          self.assertEqual(get_error_types(cm.exception.exceptions[0]), {ValueError})
     410  
     411      async def test_taskgroup_15(self):
     412  
     413          async def crash_soon():
     414              await asyncio.sleep(0.3)
     415              1 / 0
     416  
     417          async def runner():
     418              async with taskgroups.TaskGroup() as g1:
     419                  g1.create_task(crash_soon())
     420                  try:
     421                      await asyncio.sleep(10)
     422                  except asyncio.CancelledError:
     423                      await asyncio.sleep(0.5)
     424                      raise
     425  
     426          r = asyncio.create_task(runner())
     427          await asyncio.sleep(0.1)
     428  
     429          self.assertFalse(r.done())
     430          r.cancel()
     431          with self.assertRaises(ExceptionGroup) as cm:
     432              await r
     433          self.assertEqual(get_error_types(cm.exception), {ZeroDivisionError})
     434  
     435      async def test_taskgroup_16(self):
     436  
     437          async def crash_soon():
     438              await asyncio.sleep(0.3)
     439              1 / 0
     440  
     441          async def nested_runner():
     442              async with taskgroups.TaskGroup() as g1:
     443                  g1.create_task(crash_soon())
     444                  try:
     445                      await asyncio.sleep(10)
     446                  except asyncio.CancelledError:
     447                      await asyncio.sleep(0.5)
     448                      raise
     449  
     450          async def runner():
     451              t = asyncio.create_task(nested_runner())
     452              await t
     453  
     454          r = asyncio.create_task(runner())
     455          await asyncio.sleep(0.1)
     456  
     457          self.assertFalse(r.done())
     458          r.cancel()
     459          with self.assertRaises(ExceptionGroup) as cm:
     460              await r
     461          self.assertEqual(get_error_types(cm.exception), {ZeroDivisionError})
     462  
     463      async def test_taskgroup_17(self):
     464          NUM = 0
     465  
     466          async def runner():
     467              nonlocal NUM
     468              async with taskgroups.TaskGroup():
     469                  try:
     470                      await asyncio.sleep(10)
     471                  except asyncio.CancelledError:
     472                      NUM += 10
     473                      raise
     474  
     475          r = asyncio.create_task(runner())
     476          await asyncio.sleep(0.1)
     477  
     478          self.assertFalse(r.done())
     479          r.cancel()
     480          with self.assertRaises(asyncio.CancelledError):
     481              await r
     482  
     483          self.assertEqual(NUM, 10)
     484  
     485      async def test_taskgroup_18(self):
     486          NUM = 0
     487  
     488          async def runner():
     489              nonlocal NUM
     490              async with taskgroups.TaskGroup():
     491                  try:
     492                      await asyncio.sleep(10)
     493                  except asyncio.CancelledError:
     494                      NUM += 10
     495                      # This isn't a good idea, but we have to support
     496                      # this weird case.
     497                      raise MyExc
     498  
     499          r = asyncio.create_task(runner())
     500          await asyncio.sleep(0.1)
     501  
     502          self.assertFalse(r.done())
     503          r.cancel()
     504  
     505          try:
     506              await r
     507          except ExceptionGroup as t:
     508              self.assertEqual(get_error_types(t),{MyExc})
     509          else:
     510              self.fail('ExceptionGroup was not raised')
     511  
     512          self.assertEqual(NUM, 10)
     513  
     514      async def test_taskgroup_19(self):
     515          async def crash_soon():
     516              await asyncio.sleep(0.1)
     517              1 / 0
     518  
     519          async def nested():
     520              try:
     521                  await asyncio.sleep(10)
     522              finally:
     523                  raise MyExc
     524  
     525          async def runner():
     526              async with taskgroups.TaskGroup() as g:
     527                  g.create_task(crash_soon())
     528                  await nested()
     529  
     530          r = asyncio.create_task(runner())
     531          try:
     532              await r
     533          except ExceptionGroup as t:
     534              self.assertEqual(get_error_types(t), {MyExc, ZeroDivisionError})
     535          else:
     536              self.fail('TasgGroupError was not raised')
     537  
     538      async def test_taskgroup_20(self):
     539          async def crash_soon():
     540              await asyncio.sleep(0.1)
     541              1 / 0
     542  
     543          async def nested():
     544              try:
     545                  await asyncio.sleep(10)
     546              finally:
     547                  raise KeyboardInterrupt
     548  
     549          async def runner():
     550              async with taskgroups.TaskGroup() as g:
     551                  g.create_task(crash_soon())
     552                  await nested()
     553  
     554          with self.assertRaises(KeyboardInterrupt):
     555              await runner()
     556  
     557      async def test_taskgroup_20a(self):
     558          async def crash_soon():
     559              await asyncio.sleep(0.1)
     560              1 / 0
     561  
     562          async def nested():
     563              try:
     564                  await asyncio.sleep(10)
     565              finally:
     566                  raise MyBaseExc
     567  
     568          async def runner():
     569              async with taskgroups.TaskGroup() as g:
     570                  g.create_task(crash_soon())
     571                  await nested()
     572  
     573          with self.assertRaises(BaseExceptionGroup) as cm:
     574              await runner()
     575  
     576          self.assertEqual(
     577              get_error_types(cm.exception), {MyBaseExc, ZeroDivisionError}
     578          )
     579  
     580      async def _test_taskgroup_21(self):
     581          # This test doesn't work as asyncio, currently, doesn't
     582          # correctly propagate KeyboardInterrupt (or SystemExit) --
     583          # those cause the event loop itself to crash.
     584          # (Compare to the previous (passing) test -- that one raises
     585          # a plain exception but raises KeyboardInterrupt in nested();
     586          # this test does it the other way around.)
     587  
     588          async def crash_soon():
     589              await asyncio.sleep(0.1)
     590              raise KeyboardInterrupt
     591  
     592          async def nested():
     593              try:
     594                  await asyncio.sleep(10)
     595              finally:
     596                  raise TypeError
     597  
     598          async def runner():
     599              async with taskgroups.TaskGroup() as g:
     600                  g.create_task(crash_soon())
     601                  await nested()
     602  
     603          with self.assertRaises(KeyboardInterrupt):
     604              await runner()
     605  
     606      async def test_taskgroup_21a(self):
     607  
     608          async def crash_soon():
     609              await asyncio.sleep(0.1)
     610              raise MyBaseExc
     611  
     612          async def nested():
     613              try:
     614                  await asyncio.sleep(10)
     615              finally:
     616                  raise TypeError
     617  
     618          async def runner():
     619              async with taskgroups.TaskGroup() as g:
     620                  g.create_task(crash_soon())
     621                  await nested()
     622  
     623          with self.assertRaises(BaseExceptionGroup) as cm:
     624              await runner()
     625  
     626          self.assertEqual(get_error_types(cm.exception), {MyBaseExc, TypeError})
     627  
     628      async def test_taskgroup_22(self):
     629  
     630          async def foo1():
     631              await asyncio.sleep(1)
     632              return 42
     633  
     634          async def foo2():
     635              await asyncio.sleep(2)
     636              return 11
     637  
     638          async def runner():
     639              async with taskgroups.TaskGroup() as g:
     640                  g.create_task(foo1())
     641                  g.create_task(foo2())
     642  
     643          r = asyncio.create_task(runner())
     644          await asyncio.sleep(0.05)
     645          r.cancel()
     646  
     647          with self.assertRaises(asyncio.CancelledError):
     648              await r
     649  
     650      async def test_taskgroup_23(self):
     651  
     652          async def do_job(delay):
     653              await asyncio.sleep(delay)
     654  
     655          async with taskgroups.TaskGroup() as g:
     656              for count in range(10):
     657                  await asyncio.sleep(0.1)
     658                  g.create_task(do_job(0.3))
     659                  if count == 5:
     660                      self.assertLess(len(g._tasks), 5)
     661              await asyncio.sleep(1.35)
     662              self.assertEqual(len(g._tasks), 0)
     663  
     664      async def test_taskgroup_24(self):
     665  
     666          async def root(g):
     667              await asyncio.sleep(0.1)
     668              g.create_task(coro1(0.1))
     669              g.create_task(coro1(0.2))
     670  
     671          async def coro1(delay):
     672              await asyncio.sleep(delay)
     673  
     674          async def runner():
     675              async with taskgroups.TaskGroup() as g:
     676                  g.create_task(root(g))
     677  
     678          await runner()
     679  
     680      async def test_taskgroup_25(self):
     681          nhydras = 0
     682  
     683          async def hydra(g):
     684              nonlocal nhydras
     685              nhydras += 1
     686              await asyncio.sleep(0.01)
     687              g.create_task(hydra(g))
     688              g.create_task(hydra(g))
     689  
     690          async def hercules():
     691              while nhydras < 10:
     692                  await asyncio.sleep(0.015)
     693              1 / 0
     694  
     695          async def runner():
     696              async with taskgroups.TaskGroup() as g:
     697                  g.create_task(hydra(g))
     698                  g.create_task(hercules())
     699  
     700          with self.assertRaises(ExceptionGroup) as cm:
     701              await runner()
     702  
     703          self.assertEqual(get_error_types(cm.exception), {ZeroDivisionError})
     704          self.assertGreaterEqual(nhydras, 10)
     705  
     706      async def test_taskgroup_task_name(self):
     707          async def coro():
     708              await asyncio.sleep(0)
     709          async with taskgroups.TaskGroup() as g:
     710              t = g.create_task(coro(), name="yolo")
     711              self.assertEqual(t.get_name(), "yolo")
     712  
     713      async def test_taskgroup_task_context(self):
     714          cvar = contextvars.ContextVar('cvar')
     715  
     716          async def coro(val):
     717              await asyncio.sleep(0)
     718              cvar.set(val)
     719  
     720          async with taskgroups.TaskGroup() as g:
     721              ctx = contextvars.copy_context()
     722              self.assertIsNone(ctx.get(cvar))
     723              t1 = g.create_task(coro(1), context=ctx)
     724              await t1
     725              self.assertEqual(1, ctx.get(cvar))
     726              t2 = g.create_task(coro(2), context=ctx)
     727              await t2
     728              self.assertEqual(2, ctx.get(cvar))
     729  
     730      async def test_taskgroup_no_create_task_after_failure(self):
     731          async def coro1():
     732              await asyncio.sleep(0.001)
     733              1 / 0
     734          async def coro2(g):
     735              try:
     736                  await asyncio.sleep(1)
     737              except asyncio.CancelledError:
     738                  with self.assertRaises(RuntimeError):
     739                      g.create_task(c1 := coro1())
     740                  # We still have to await c1 to avoid a warning
     741                  with self.assertRaises(ZeroDivisionError):
     742                      await c1
     743  
     744          with self.assertRaises(ExceptionGroup) as cm:
     745              async with taskgroups.TaskGroup() as g:
     746                  g.create_task(coro1())
     747                  g.create_task(coro2(g))
     748  
     749          self.assertEqual(get_error_types(cm.exception), {ZeroDivisionError})
     750  
     751      async def test_taskgroup_context_manager_exit_raises(self):
     752          # See https://github.com/python/cpython/issues/95289
     753          class ESC[4;38;5;81mCustomException(ESC[4;38;5;149mException):
     754              pass
     755  
     756          async def raise_exc():
     757              raise CustomException
     758  
     759          @contextlib.asynccontextmanager
     760          async def database():
     761              try:
     762                  yield
     763              finally:
     764                  raise CustomException
     765  
     766          async def main():
     767              task = asyncio.current_task()
     768              try:
     769                  async with taskgroups.TaskGroup() as tg:
     770                      async with database():
     771                          tg.create_task(raise_exc())
     772                          await asyncio.sleep(1)
     773              except* CustomException as err:
     774                  self.assertEqual(task.cancelling(), 0)
     775                  self.assertEqual(len(err.exceptions), 2)
     776  
     777              else:
     778                  self.fail('CustomException not raised')
     779  
     780          await asyncio.create_task(main())
     781  
     782  
     783  if __name__ == "__main__":
     784      unittest.main()