(root)/
Python-3.12.0/
Lib/
test/
test_asyncio/
test_eager_task_factory.py
       1  """Tests for base_events.py"""
       2  
       3  import asyncio
       4  import contextvars
       5  import gc
       6  import time
       7  import unittest
       8  
       9  from types import GenericAlias
      10  from unittest import mock
      11  from asyncio import base_events
      12  from asyncio import tasks
      13  from test.test_asyncio import utils as test_utils
      14  from test.test_asyncio.test_tasks import get_innermost_context
      15  from test.support.script_helper import assert_python_ok
      16  
      17  MOCK_ANY = mock.ANY
      18  
      19  
      20  def tearDownModule():
      21      asyncio.set_event_loop_policy(None)
      22  
      23  
      24  class ESC[4;38;5;81mEagerTaskFactoryLoopTests:
      25  
      26      Task = None
      27  
      28      def run_coro(self, coro):
      29          """
      30          Helper method to run the `coro` coroutine in the test event loop.
      31          It helps with making sure the event loop is running before starting
      32          to execute `coro`. This is important for testing the eager step
      33          functionality, since an eager step is taken only if the event loop
      34          is already running.
      35          """
      36  
      37          async def coro_runner():
      38              self.assertTrue(asyncio.get_event_loop().is_running())
      39              return await coro
      40  
      41          return self.loop.run_until_complete(coro)
      42  
      43      def setUp(self):
      44          super().setUp()
      45          self.loop = asyncio.new_event_loop()
      46          self.eager_task_factory = asyncio.create_eager_task_factory(self.Task)
      47          self.loop.set_task_factory(self.eager_task_factory)
      48          self.set_event_loop(self.loop)
      49  
      50      def test_eager_task_factory_set(self):
      51          self.assertIsNotNone(self.eager_task_factory)
      52          self.assertIs(self.loop.get_task_factory(), self.eager_task_factory)
      53  
      54          async def noop(): pass
      55  
      56          async def run():
      57              t = self.loop.create_task(noop())
      58              self.assertIsInstance(t, self.Task)
      59              await t
      60  
      61          self.run_coro(run())
      62  
      63      def test_await_future_during_eager_step(self):
      64  
      65          async def set_result(fut, val):
      66              fut.set_result(val)
      67  
      68          async def run():
      69              fut = self.loop.create_future()
      70              t = self.loop.create_task(set_result(fut, 'my message'))
      71              # assert the eager step completed the task
      72              self.assertTrue(t.done())
      73              return await fut
      74  
      75          self.assertEqual(self.run_coro(run()), 'my message')
      76  
      77      def test_eager_completion(self):
      78  
      79          async def coro():
      80              return 'hello'
      81  
      82          async def run():
      83              t = self.loop.create_task(coro())
      84              # assert the eager step completed the task
      85              self.assertTrue(t.done())
      86              return await t
      87  
      88          self.assertEqual(self.run_coro(run()), 'hello')
      89  
      90      def test_block_after_eager_step(self):
      91  
      92          async def coro():
      93              await asyncio.sleep(0.1)
      94              return 'finished after blocking'
      95  
      96          async def run():
      97              t = self.loop.create_task(coro())
      98              self.assertFalse(t.done())
      99              result = await t
     100              self.assertTrue(t.done())
     101              return result
     102  
     103          self.assertEqual(self.run_coro(run()), 'finished after blocking')
     104  
     105      def test_cancellation_after_eager_completion(self):
     106  
     107          async def coro():
     108              return 'finished without blocking'
     109  
     110          async def run():
     111              t = self.loop.create_task(coro())
     112              t.cancel()
     113              result = await t
     114              # finished task can't be cancelled
     115              self.assertFalse(t.cancelled())
     116              return result
     117  
     118          self.assertEqual(self.run_coro(run()), 'finished without blocking')
     119  
     120      def test_cancellation_after_eager_step_blocks(self):
     121  
     122          async def coro():
     123              await asyncio.sleep(0.1)
     124              return 'finished after blocking'
     125  
     126          async def run():
     127              t = self.loop.create_task(coro())
     128              t.cancel('cancellation message')
     129              self.assertGreater(t.cancelling(), 0)
     130              result = await t
     131  
     132          with self.assertRaises(asyncio.CancelledError) as cm:
     133              self.run_coro(run())
     134  
     135          self.assertEqual('cancellation message', cm.exception.args[0])
     136  
     137      def test_current_task(self):
     138          captured_current_task = None
     139  
     140          async def coro():
     141              nonlocal captured_current_task
     142              captured_current_task = asyncio.current_task()
     143              # verify the task before and after blocking is identical
     144              await asyncio.sleep(0.1)
     145              self.assertIs(asyncio.current_task(), captured_current_task)
     146  
     147          async def run():
     148              t = self.loop.create_task(coro())
     149              self.assertIs(captured_current_task, t)
     150              await t
     151  
     152          self.run_coro(run())
     153          captured_current_task = None
     154  
     155      def test_all_tasks_with_eager_completion(self):
     156          captured_all_tasks = None
     157  
     158          async def coro():
     159              nonlocal captured_all_tasks
     160              captured_all_tasks = asyncio.all_tasks()
     161  
     162          async def run():
     163              t = self.loop.create_task(coro())
     164              self.assertIn(t, captured_all_tasks)
     165              self.assertNotIn(t, asyncio.all_tasks())
     166  
     167          self.run_coro(run())
     168  
     169      def test_all_tasks_with_blocking(self):
     170          captured_eager_all_tasks = None
     171  
     172          async def coro(fut1, fut2):
     173              nonlocal captured_eager_all_tasks
     174              captured_eager_all_tasks = asyncio.all_tasks()
     175              await fut1
     176              fut2.set_result(None)
     177  
     178          async def run():
     179              fut1 = self.loop.create_future()
     180              fut2 = self.loop.create_future()
     181              t = self.loop.create_task(coro(fut1, fut2))
     182              self.assertIn(t, captured_eager_all_tasks)
     183              self.assertIn(t, asyncio.all_tasks())
     184              fut1.set_result(None)
     185              await fut2
     186              self.assertNotIn(t, asyncio.all_tasks())
     187  
     188          self.run_coro(run())
     189  
     190      def test_context_vars(self):
     191          cv = contextvars.ContextVar('cv', default=0)
     192  
     193          coro_first_step_ran = False
     194          coro_second_step_ran = False
     195  
     196          async def coro():
     197              nonlocal coro_first_step_ran
     198              nonlocal coro_second_step_ran
     199              self.assertEqual(cv.get(), 1)
     200              cv.set(2)
     201              self.assertEqual(cv.get(), 2)
     202              coro_first_step_ran = True
     203              await asyncio.sleep(0.1)
     204              self.assertEqual(cv.get(), 2)
     205              cv.set(3)
     206              self.assertEqual(cv.get(), 3)
     207              coro_second_step_ran = True
     208  
     209          async def run():
     210              cv.set(1)
     211              t = self.loop.create_task(coro())
     212              self.assertTrue(coro_first_step_ran)
     213              self.assertFalse(coro_second_step_ran)
     214              self.assertEqual(cv.get(), 1)
     215              await t
     216              self.assertTrue(coro_second_step_ran)
     217              self.assertEqual(cv.get(), 1)
     218  
     219          self.run_coro(run())
     220  
     221  
     222  class ESC[4;38;5;81mPyEagerTaskFactoryLoopTests(ESC[4;38;5;149mEagerTaskFactoryLoopTests, ESC[4;38;5;149mtest_utilsESC[4;38;5;149m.ESC[4;38;5;149mTestCase):
     223      Task = tasks._PyTask
     224  
     225  
     226  @unittest.skipUnless(hasattr(tasks, '_CTask'),
     227                       'requires the C _asyncio module')
     228  class ESC[4;38;5;81mCEagerTaskFactoryLoopTests(ESC[4;38;5;149mEagerTaskFactoryLoopTests, ESC[4;38;5;149mtest_utilsESC[4;38;5;149m.ESC[4;38;5;149mTestCase):
     229      Task = getattr(tasks, '_CTask', None)
     230  
     231      def test_issue105987(self):
     232          code = """if 1:
     233          from _asyncio import _swap_current_task
     234  
     235          class DummyTask:
     236              pass
     237  
     238          class DummyLoop:
     239              pass
     240  
     241          l = DummyLoop()
     242          _swap_current_task(l, DummyTask())
     243          t = _swap_current_task(l, None)
     244          """
     245  
     246          _, out, err = assert_python_ok("-c", code)
     247          self.assertFalse(err)
     248  
     249  class ESC[4;38;5;81mAsyncTaskCounter:
     250      def __init__(self, loop, *, task_class, eager):
     251          self.suspense_count = 0
     252          self.task_count = 0
     253  
     254          def CountingTask(*args, eager_start=False, **kwargs):
     255              if not eager_start:
     256                  self.task_count += 1
     257              kwargs["eager_start"] = eager_start
     258              return task_class(*args, **kwargs)
     259  
     260          if eager:
     261              factory = asyncio.create_eager_task_factory(CountingTask)
     262          else:
     263              def factory(loop, coro, **kwargs):
     264                  return CountingTask(coro, loop=loop, **kwargs)
     265          loop.set_task_factory(factory)
     266  
     267      def get(self):
     268          return self.task_count
     269  
     270  
     271  async def awaitable_chain(depth):
     272      if depth == 0:
     273          return 0
     274      return 1 + await awaitable_chain(depth - 1)
     275  
     276  
     277  async def recursive_taskgroups(width, depth):
     278      if depth == 0:
     279          return
     280  
     281      async with asyncio.TaskGroup() as tg:
     282          futures = [
     283              tg.create_task(recursive_taskgroups(width, depth - 1))
     284              for _ in range(width)
     285          ]
     286  
     287  
     288  async def recursive_gather(width, depth):
     289      if depth == 0:
     290          return
     291  
     292      await asyncio.gather(
     293          *[recursive_gather(width, depth - 1) for _ in range(width)]
     294      )
     295  
     296  
     297  class ESC[4;38;5;81mBaseTaskCountingTests:
     298  
     299      Task = None
     300      eager = None
     301      expected_task_count = None
     302  
     303      def setUp(self):
     304          super().setUp()
     305          self.loop = asyncio.new_event_loop()
     306          self.counter = AsyncTaskCounter(self.loop, task_class=self.Task, eager=self.eager)
     307          self.set_event_loop(self.loop)
     308  
     309      def test_awaitables_chain(self):
     310          observed_depth = self.loop.run_until_complete(awaitable_chain(100))
     311          self.assertEqual(observed_depth, 100)
     312          self.assertEqual(self.counter.get(), 0 if self.eager else 1)
     313  
     314      def test_recursive_taskgroups(self):
     315          num_tasks = self.loop.run_until_complete(recursive_taskgroups(5, 4))
     316          self.assertEqual(self.counter.get(), self.expected_task_count)
     317  
     318      def test_recursive_gather(self):
     319          self.loop.run_until_complete(recursive_gather(5, 4))
     320          self.assertEqual(self.counter.get(), self.expected_task_count)
     321  
     322  
     323  class ESC[4;38;5;81mBaseNonEagerTaskFactoryTests(ESC[4;38;5;149mBaseTaskCountingTests):
     324      eager = False
     325      expected_task_count = 781  # 1 + 5 + 5^2 + 5^3 + 5^4
     326  
     327  
     328  class ESC[4;38;5;81mBaseEagerTaskFactoryTests(ESC[4;38;5;149mBaseTaskCountingTests):
     329      eager = True
     330      expected_task_count = 0
     331  
     332  
     333  class ESC[4;38;5;81mNonEagerTests(ESC[4;38;5;149mBaseNonEagerTaskFactoryTests, ESC[4;38;5;149mtest_utilsESC[4;38;5;149m.ESC[4;38;5;149mTestCase):
     334      Task = asyncio.Task
     335  
     336  
     337  class ESC[4;38;5;81mEagerTests(ESC[4;38;5;149mBaseEagerTaskFactoryTests, ESC[4;38;5;149mtest_utilsESC[4;38;5;149m.ESC[4;38;5;149mTestCase):
     338      Task = asyncio.Task
     339  
     340  
     341  class ESC[4;38;5;81mNonEagerPyTaskTests(ESC[4;38;5;149mBaseNonEagerTaskFactoryTests, ESC[4;38;5;149mtest_utilsESC[4;38;5;149m.ESC[4;38;5;149mTestCase):
     342      Task = tasks._PyTask
     343  
     344  
     345  class ESC[4;38;5;81mEagerPyTaskTests(ESC[4;38;5;149mBaseEagerTaskFactoryTests, ESC[4;38;5;149mtest_utilsESC[4;38;5;149m.ESC[4;38;5;149mTestCase):
     346      Task = tasks._PyTask
     347  
     348  
     349  @unittest.skipUnless(hasattr(tasks, '_CTask'),
     350                       'requires the C _asyncio module')
     351  class ESC[4;38;5;81mNonEagerCTaskTests(ESC[4;38;5;149mBaseNonEagerTaskFactoryTests, ESC[4;38;5;149mtest_utilsESC[4;38;5;149m.ESC[4;38;5;149mTestCase):
     352      Task = getattr(tasks, '_CTask', None)
     353  
     354  
     355  @unittest.skipUnless(hasattr(tasks, '_CTask'),
     356                       'requires the C _asyncio module')
     357  class ESC[4;38;5;81mEagerCTaskTests(ESC[4;38;5;149mBaseEagerTaskFactoryTests, ESC[4;38;5;149mtest_utilsESC[4;38;5;149m.ESC[4;38;5;149mTestCase):
     358      Task = getattr(tasks, '_CTask', None)
     359  
     360  if __name__ == '__main__':
     361      unittest.main()