1  import asyncio
       2  import gc
       3  import inspect
       4  import re
       5  import unittest
       6  from contextlib import contextmanager
       7  from test import support
       8  
       9  support.requires_working_socket(module=True)
      10  
      11  from asyncio import run, iscoroutinefunction
      12  from unittest import IsolatedAsyncioTestCase
      13  from unittest.mock import (ANY, call, AsyncMock, patch, MagicMock, Mock,
      14                             create_autospec, sentinel, _CallList, seal)
      15  
      16  
      17  def tearDownModule():
      18      asyncio.set_event_loop_policy(None)
      19  
      20  
      21  class ESC[4;38;5;81mAsyncClass:
      22      def __init__(self): pass
      23      async def async_method(self): pass
      24      def normal_method(self): pass
      25  
      26      @classmethod
      27      async def async_class_method(cls): pass
      28  
      29      @staticmethod
      30      async def async_static_method(): pass
      31  
      32  
      33  class ESC[4;38;5;81mAwaitableClass:
      34      def __await__(self): yield
      35  
      36  async def async_func(): pass
      37  
      38  async def async_func_args(a, b, *, c): pass
      39  
      40  def normal_func(): pass
      41  
      42  class ESC[4;38;5;81mNormalClass(ESC[4;38;5;149mobject):
      43      def a(self): pass
      44  
      45  
      46  async_foo_name = f'{__name__}.AsyncClass'
      47  normal_foo_name = f'{__name__}.NormalClass'
      48  
      49  
      50  @contextmanager
      51  def assertNeverAwaited(test):
      52      with test.assertWarnsRegex(RuntimeWarning, "was never awaited$"):
      53          yield
      54          # In non-CPython implementations of Python, this is needed because timely
      55          # deallocation is not guaranteed by the garbage collector.
      56          gc.collect()
      57  
      58  
      59  class ESC[4;38;5;81mAsyncPatchDecoratorTest(ESC[4;38;5;149munittestESC[4;38;5;149m.ESC[4;38;5;149mTestCase):
      60      def test_is_coroutine_function_patch(self):
      61          @patch.object(AsyncClass, 'async_method')
      62          def test_async(mock_method):
      63              self.assertTrue(iscoroutinefunction(mock_method))
      64          test_async()
      65  
      66      def test_is_async_patch(self):
      67          @patch.object(AsyncClass, 'async_method')
      68          def test_async(mock_method):
      69              m = mock_method()
      70              self.assertTrue(inspect.isawaitable(m))
      71              run(m)
      72  
      73          @patch(f'{async_foo_name}.async_method')
      74          def test_no_parent_attribute(mock_method):
      75              m = mock_method()
      76              self.assertTrue(inspect.isawaitable(m))
      77              run(m)
      78  
      79          test_async()
      80          test_no_parent_attribute()
      81  
      82      def test_is_AsyncMock_patch(self):
      83          @patch.object(AsyncClass, 'async_method')
      84          def test_async(mock_method):
      85              self.assertIsInstance(mock_method, AsyncMock)
      86  
      87          test_async()
      88  
      89      def test_is_AsyncMock_patch_staticmethod(self):
      90          @patch.object(AsyncClass, 'async_static_method')
      91          def test_async(mock_method):
      92              self.assertIsInstance(mock_method, AsyncMock)
      93  
      94          test_async()
      95  
      96      def test_is_AsyncMock_patch_classmethod(self):
      97          @patch.object(AsyncClass, 'async_class_method')
      98          def test_async(mock_method):
      99              self.assertIsInstance(mock_method, AsyncMock)
     100  
     101          test_async()
     102  
     103      def test_async_def_patch(self):
     104          @patch(f"{__name__}.async_func", return_value=1)
     105          @patch(f"{__name__}.async_func_args", return_value=2)
     106          async def test_async(func_args_mock, func_mock):
     107              self.assertEqual(func_args_mock._mock_name, "async_func_args")
     108              self.assertEqual(func_mock._mock_name, "async_func")
     109  
     110              self.assertIsInstance(async_func, AsyncMock)
     111              self.assertIsInstance(async_func_args, AsyncMock)
     112  
     113              self.assertEqual(await async_func(), 1)
     114              self.assertEqual(await async_func_args(1, 2, c=3), 2)
     115  
     116          run(test_async())
     117          self.assertTrue(inspect.iscoroutinefunction(async_func))
     118  
     119  
     120  class ESC[4;38;5;81mAsyncPatchCMTest(ESC[4;38;5;149munittestESC[4;38;5;149m.ESC[4;38;5;149mTestCase):
     121      def test_is_async_function_cm(self):
     122          def test_async():
     123              with patch.object(AsyncClass, 'async_method') as mock_method:
     124                  self.assertTrue(iscoroutinefunction(mock_method))
     125  
     126          test_async()
     127  
     128      def test_is_async_cm(self):
     129          def test_async():
     130              with patch.object(AsyncClass, 'async_method') as mock_method:
     131                  m = mock_method()
     132                  self.assertTrue(inspect.isawaitable(m))
     133                  run(m)
     134  
     135          test_async()
     136  
     137      def test_is_AsyncMock_cm(self):
     138          def test_async():
     139              with patch.object(AsyncClass, 'async_method') as mock_method:
     140                  self.assertIsInstance(mock_method, AsyncMock)
     141  
     142          test_async()
     143  
     144      def test_async_def_cm(self):
     145          async def test_async():
     146              with patch(f"{__name__}.async_func", AsyncMock()):
     147                  self.assertIsInstance(async_func, AsyncMock)
     148              self.assertTrue(inspect.iscoroutinefunction(async_func))
     149  
     150          run(test_async())
     151  
     152      def test_patch_dict_async_def(self):
     153          foo = {'a': 'a'}
     154          @patch.dict(foo, {'a': 'b'})
     155          async def test_async():
     156              self.assertEqual(foo['a'], 'b')
     157  
     158          self.assertTrue(iscoroutinefunction(test_async))
     159          run(test_async())
     160  
     161      def test_patch_dict_async_def_context(self):
     162          foo = {'a': 'a'}
     163          async def test_async():
     164              with patch.dict(foo, {'a': 'b'}):
     165                  self.assertEqual(foo['a'], 'b')
     166  
     167          run(test_async())
     168  
     169  
     170  class ESC[4;38;5;81mAsyncMockTest(ESC[4;38;5;149munittestESC[4;38;5;149m.ESC[4;38;5;149mTestCase):
     171      def test_iscoroutinefunction_default(self):
     172          mock = AsyncMock()
     173          self.assertTrue(iscoroutinefunction(mock))
     174  
     175      def test_iscoroutinefunction_function(self):
     176          async def foo(): pass
     177          mock = AsyncMock(foo)
     178          self.assertTrue(iscoroutinefunction(mock))
     179          self.assertTrue(inspect.iscoroutinefunction(mock))
     180  
     181      def test_isawaitable(self):
     182          mock = AsyncMock()
     183          m = mock()
     184          self.assertTrue(inspect.isawaitable(m))
     185          run(m)
     186          self.assertIn('assert_awaited', dir(mock))
     187  
     188      def test_iscoroutinefunction_normal_function(self):
     189          def foo(): pass
     190          mock = AsyncMock(foo)
     191          self.assertTrue(iscoroutinefunction(mock))
     192          self.assertTrue(inspect.iscoroutinefunction(mock))
     193  
     194      def test_future_isfuture(self):
     195          loop = asyncio.new_event_loop()
     196          fut = loop.create_future()
     197          loop.stop()
     198          loop.close()
     199          mock = AsyncMock(fut)
     200          self.assertIsInstance(mock, asyncio.Future)
     201  
     202  
     203  class ESC[4;38;5;81mAsyncAutospecTest(ESC[4;38;5;149munittestESC[4;38;5;149m.ESC[4;38;5;149mTestCase):
     204      def test_is_AsyncMock_patch(self):
     205          @patch(async_foo_name, autospec=True)
     206          def test_async(mock_method):
     207              self.assertIsInstance(mock_method.async_method, AsyncMock)
     208              self.assertIsInstance(mock_method, MagicMock)
     209  
     210          @patch(async_foo_name, autospec=True)
     211          def test_normal_method(mock_method):
     212              self.assertIsInstance(mock_method.normal_method, MagicMock)
     213  
     214          test_async()
     215          test_normal_method()
     216  
     217      def test_create_autospec_instance(self):
     218          with self.assertRaises(RuntimeError):
     219              create_autospec(async_func, instance=True)
     220  
     221      def test_create_autospec(self):
     222          spec = create_autospec(async_func_args)
     223          awaitable = spec(1, 2, c=3)
     224          async def main():
     225              await awaitable
     226  
     227          self.assertEqual(spec.await_count, 0)
     228          self.assertIsNone(spec.await_args)
     229          self.assertEqual(spec.await_args_list, [])
     230          spec.assert_not_awaited()
     231  
     232          run(main())
     233  
     234          self.assertTrue(iscoroutinefunction(spec))
     235          self.assertTrue(asyncio.iscoroutine(awaitable))
     236          self.assertEqual(spec.await_count, 1)
     237          self.assertEqual(spec.await_args, call(1, 2, c=3))
     238          self.assertEqual(spec.await_args_list, [call(1, 2, c=3)])
     239          spec.assert_awaited_once()
     240          spec.assert_awaited_once_with(1, 2, c=3)
     241          spec.assert_awaited_with(1, 2, c=3)
     242          spec.assert_awaited()
     243  
     244          with self.assertRaises(AssertionError):
     245              spec.assert_any_await(e=1)
     246  
     247  
     248      def test_patch_with_autospec(self):
     249  
     250          async def test_async():
     251              with patch(f"{__name__}.async_func_args", autospec=True) as mock_method:
     252                  awaitable = mock_method(1, 2, c=3)
     253                  self.assertIsInstance(mock_method.mock, AsyncMock)
     254  
     255                  self.assertTrue(iscoroutinefunction(mock_method))
     256                  self.assertTrue(asyncio.iscoroutine(awaitable))
     257                  self.assertTrue(inspect.isawaitable(awaitable))
     258  
     259                  # Verify the default values during mock setup
     260                  self.assertEqual(mock_method.await_count, 0)
     261                  self.assertEqual(mock_method.await_args_list, [])
     262                  self.assertIsNone(mock_method.await_args)
     263                  mock_method.assert_not_awaited()
     264  
     265                  await awaitable
     266  
     267              self.assertEqual(mock_method.await_count, 1)
     268              self.assertEqual(mock_method.await_args, call(1, 2, c=3))
     269              self.assertEqual(mock_method.await_args_list, [call(1, 2, c=3)])
     270              mock_method.assert_awaited_once()
     271              mock_method.assert_awaited_once_with(1, 2, c=3)
     272              mock_method.assert_awaited_with(1, 2, c=3)
     273              mock_method.assert_awaited()
     274  
     275              mock_method.reset_mock()
     276              self.assertEqual(mock_method.await_count, 0)
     277              self.assertIsNone(mock_method.await_args)
     278              self.assertEqual(mock_method.await_args_list, [])
     279  
     280          run(test_async())
     281  
     282  
     283  class ESC[4;38;5;81mAsyncSpecTest(ESC[4;38;5;149munittestESC[4;38;5;149m.ESC[4;38;5;149mTestCase):
     284      def test_spec_normal_methods_on_class(self):
     285          def inner_test(mock_type):
     286              mock = mock_type(AsyncClass)
     287              self.assertIsInstance(mock.async_method, AsyncMock)
     288              self.assertIsInstance(mock.normal_method, MagicMock)
     289  
     290          for mock_type in [AsyncMock, MagicMock]:
     291              with self.subTest(f"test method types with {mock_type}"):
     292                  inner_test(mock_type)
     293  
     294      def test_spec_normal_methods_on_class_with_mock(self):
     295          mock = Mock(AsyncClass)
     296          self.assertIsInstance(mock.async_method, AsyncMock)
     297          self.assertIsInstance(mock.normal_method, Mock)
     298  
     299      def test_spec_normal_methods_on_class_with_mock_seal(self):
     300          mock = Mock(AsyncClass)
     301          seal(mock)
     302          with self.assertRaises(AttributeError):
     303              mock.normal_method
     304          with self.assertRaises(AttributeError):
     305              mock.async_method
     306  
     307      def test_spec_async_attributes_instance(self):
     308          async_instance = AsyncClass()
     309          async_instance.async_func_attr = async_func
     310          async_instance.later_async_func_attr = normal_func
     311  
     312          mock_async_instance = Mock(spec_set=async_instance)
     313  
     314          async_instance.later_async_func_attr = async_func
     315  
     316          self.assertIsInstance(mock_async_instance.async_func_attr, AsyncMock)
     317          # only the shape of the spec at the time of mock construction matters
     318          self.assertNotIsInstance(mock_async_instance.later_async_func_attr, AsyncMock)
     319  
     320      def test_spec_mock_type_kw(self):
     321          def inner_test(mock_type):
     322              async_mock = mock_type(spec=async_func)
     323              self.assertIsInstance(async_mock, mock_type)
     324              with assertNeverAwaited(self):
     325                  self.assertTrue(inspect.isawaitable(async_mock()))
     326  
     327              sync_mock = mock_type(spec=normal_func)
     328              self.assertIsInstance(sync_mock, mock_type)
     329  
     330          for mock_type in [AsyncMock, MagicMock, Mock]:
     331              with self.subTest(f"test spec kwarg with {mock_type}"):
     332                  inner_test(mock_type)
     333  
     334      def test_spec_mock_type_positional(self):
     335          def inner_test(mock_type):
     336              async_mock = mock_type(async_func)
     337              self.assertIsInstance(async_mock, mock_type)
     338              with assertNeverAwaited(self):
     339                  self.assertTrue(inspect.isawaitable(async_mock()))
     340  
     341              sync_mock = mock_type(normal_func)
     342              self.assertIsInstance(sync_mock, mock_type)
     343  
     344          for mock_type in [AsyncMock, MagicMock, Mock]:
     345              with self.subTest(f"test spec positional with {mock_type}"):
     346                  inner_test(mock_type)
     347  
     348      def test_spec_as_normal_kw_AsyncMock(self):
     349          mock = AsyncMock(spec=normal_func)
     350          self.assertIsInstance(mock, AsyncMock)
     351          m = mock()
     352          self.assertTrue(inspect.isawaitable(m))
     353          run(m)
     354  
     355      def test_spec_as_normal_positional_AsyncMock(self):
     356          mock = AsyncMock(normal_func)
     357          self.assertIsInstance(mock, AsyncMock)
     358          m = mock()
     359          self.assertTrue(inspect.isawaitable(m))
     360          run(m)
     361  
     362      def test_spec_async_mock(self):
     363          @patch.object(AsyncClass, 'async_method', spec=True)
     364          def test_async(mock_method):
     365              self.assertIsInstance(mock_method, AsyncMock)
     366  
     367          test_async()
     368  
     369      def test_spec_parent_not_async_attribute_is(self):
     370          @patch(async_foo_name, spec=True)
     371          def test_async(mock_method):
     372              self.assertIsInstance(mock_method, MagicMock)
     373              self.assertIsInstance(mock_method.async_method, AsyncMock)
     374  
     375          test_async()
     376  
     377      def test_target_async_spec_not(self):
     378          @patch.object(AsyncClass, 'async_method', spec=NormalClass.a)
     379          def test_async_attribute(mock_method):
     380              self.assertIsInstance(mock_method, MagicMock)
     381              self.assertFalse(inspect.iscoroutine(mock_method))
     382              self.assertFalse(inspect.isawaitable(mock_method))
     383  
     384          test_async_attribute()
     385  
     386      def test_target_not_async_spec_is(self):
     387          @patch.object(NormalClass, 'a', spec=async_func)
     388          def test_attribute_not_async_spec_is(mock_async_func):
     389              self.assertIsInstance(mock_async_func, AsyncMock)
     390          test_attribute_not_async_spec_is()
     391  
     392      def test_spec_async_attributes(self):
     393          @patch(normal_foo_name, spec=AsyncClass)
     394          def test_async_attributes_coroutines(MockNormalClass):
     395              self.assertIsInstance(MockNormalClass.async_method, AsyncMock)
     396              self.assertIsInstance(MockNormalClass, MagicMock)
     397  
     398          test_async_attributes_coroutines()
     399  
     400  
     401  class ESC[4;38;5;81mAsyncSpecSetTest(ESC[4;38;5;149munittestESC[4;38;5;149m.ESC[4;38;5;149mTestCase):
     402      def test_is_AsyncMock_patch(self):
     403          @patch.object(AsyncClass, 'async_method', spec_set=True)
     404          def test_async(async_method):
     405              self.assertIsInstance(async_method, AsyncMock)
     406          test_async()
     407  
     408      def test_is_async_AsyncMock(self):
     409          mock = AsyncMock(spec_set=AsyncClass.async_method)
     410          self.assertTrue(iscoroutinefunction(mock))
     411          self.assertIsInstance(mock, AsyncMock)
     412  
     413      def test_is_child_AsyncMock(self):
     414          mock = MagicMock(spec_set=AsyncClass)
     415          self.assertTrue(iscoroutinefunction(mock.async_method))
     416          self.assertFalse(iscoroutinefunction(mock.normal_method))
     417          self.assertIsInstance(mock.async_method, AsyncMock)
     418          self.assertIsInstance(mock.normal_method, MagicMock)
     419          self.assertIsInstance(mock, MagicMock)
     420  
     421      def test_magicmock_lambda_spec(self):
     422          mock_obj = MagicMock()
     423          mock_obj.mock_func = MagicMock(spec=lambda x: x)
     424  
     425          with patch.object(mock_obj, "mock_func") as cm:
     426              self.assertIsInstance(cm, MagicMock)
     427  
     428  
     429  class ESC[4;38;5;81mAsyncArguments(ESC[4;38;5;149mIsolatedAsyncioTestCase):
     430      async def test_add_return_value(self):
     431          async def addition(self, var): pass
     432  
     433          mock = AsyncMock(addition, return_value=10)
     434          output = await mock(5)
     435  
     436          self.assertEqual(output, 10)
     437  
     438      async def test_add_side_effect_exception(self):
     439          class ESC[4;38;5;81mCustomError(ESC[4;38;5;149mException): pass
     440          async def addition(var): pass
     441          mock = AsyncMock(addition, side_effect=CustomError('side-effect'))
     442          with self.assertRaisesRegex(CustomError, 'side-effect'):
     443              await mock(5)
     444  
     445      async def test_add_side_effect_coroutine(self):
     446          async def addition(var):
     447              return var + 1
     448          mock = AsyncMock(side_effect=addition)
     449          result = await mock(5)
     450          self.assertEqual(result, 6)
     451  
     452      async def test_add_side_effect_normal_function(self):
     453          def addition(var):
     454              return var + 1
     455          mock = AsyncMock(side_effect=addition)
     456          result = await mock(5)
     457          self.assertEqual(result, 6)
     458  
     459      async def test_add_side_effect_iterable(self):
     460          vals = [1, 2, 3]
     461          mock = AsyncMock(side_effect=vals)
     462          for item in vals:
     463              self.assertEqual(await mock(), item)
     464  
     465          with self.assertRaises(StopAsyncIteration) as e:
     466              await mock()
     467  
     468      async def test_add_side_effect_exception_iterable(self):
     469          class ESC[4;38;5;81mSampleException(ESC[4;38;5;149mException):
     470              pass
     471  
     472          vals = [1, SampleException("foo")]
     473          mock = AsyncMock(side_effect=vals)
     474          self.assertEqual(await mock(), 1)
     475  
     476          with self.assertRaises(SampleException) as e:
     477              await mock()
     478  
     479      async def test_return_value_AsyncMock(self):
     480          value = AsyncMock(return_value=10)
     481          mock = AsyncMock(return_value=value)
     482          result = await mock()
     483          self.assertIs(result, value)
     484  
     485      async def test_return_value_awaitable(self):
     486          fut = asyncio.Future()
     487          fut.set_result(None)
     488          mock = AsyncMock(return_value=fut)
     489          result = await mock()
     490          self.assertIsInstance(result, asyncio.Future)
     491  
     492      async def test_side_effect_awaitable_values(self):
     493          fut = asyncio.Future()
     494          fut.set_result(None)
     495  
     496          mock = AsyncMock(side_effect=[fut])
     497          result = await mock()
     498          self.assertIsInstance(result, asyncio.Future)
     499  
     500          with self.assertRaises(StopAsyncIteration):
     501              await mock()
     502  
     503      async def test_side_effect_is_AsyncMock(self):
     504          effect = AsyncMock(return_value=10)
     505          mock = AsyncMock(side_effect=effect)
     506  
     507          result = await mock()
     508          self.assertEqual(result, 10)
     509  
     510      async def test_wraps_coroutine(self):
     511          value = asyncio.Future()
     512  
     513          ran = False
     514          async def inner():
     515              nonlocal ran
     516              ran = True
     517              return value
     518  
     519          mock = AsyncMock(wraps=inner)
     520          result = await mock()
     521          self.assertEqual(result, value)
     522          mock.assert_awaited()
     523          self.assertTrue(ran)
     524  
     525      async def test_wraps_normal_function(self):
     526          value = 1
     527  
     528          ran = False
     529          def inner():
     530              nonlocal ran
     531              ran = True
     532              return value
     533  
     534          mock = AsyncMock(wraps=inner)
     535          result = await mock()
     536          self.assertEqual(result, value)
     537          mock.assert_awaited()
     538          self.assertTrue(ran)
     539  
     540      async def test_await_args_list_order(self):
     541          async_mock = AsyncMock()
     542          mock2 = async_mock(2)
     543          mock1 = async_mock(1)
     544          await mock1
     545          await mock2
     546          async_mock.assert_has_awaits([call(1), call(2)])
     547          self.assertEqual(async_mock.await_args_list, [call(1), call(2)])
     548          self.assertEqual(async_mock.call_args_list, [call(2), call(1)])
     549  
     550  
     551  class ESC[4;38;5;81mAsyncMagicMethods(ESC[4;38;5;149munittestESC[4;38;5;149m.ESC[4;38;5;149mTestCase):
     552      def test_async_magic_methods_return_async_mocks(self):
     553          m_mock = MagicMock()
     554          self.assertIsInstance(m_mock.__aenter__, AsyncMock)
     555          self.assertIsInstance(m_mock.__aexit__, AsyncMock)
     556          self.assertIsInstance(m_mock.__anext__, AsyncMock)
     557          # __aiter__ is actually a synchronous object
     558          # so should return a MagicMock
     559          self.assertIsInstance(m_mock.__aiter__, MagicMock)
     560  
     561      def test_sync_magic_methods_return_magic_mocks(self):
     562          a_mock = AsyncMock()
     563          self.assertIsInstance(a_mock.__enter__, MagicMock)
     564          self.assertIsInstance(a_mock.__exit__, MagicMock)
     565          self.assertIsInstance(a_mock.__next__, MagicMock)
     566          self.assertIsInstance(a_mock.__len__, MagicMock)
     567  
     568      def test_magicmock_has_async_magic_methods(self):
     569          m_mock = MagicMock()
     570          self.assertTrue(hasattr(m_mock, "__aenter__"))
     571          self.assertTrue(hasattr(m_mock, "__aexit__"))
     572          self.assertTrue(hasattr(m_mock, "__anext__"))
     573  
     574      def test_asyncmock_has_sync_magic_methods(self):
     575          a_mock = AsyncMock()
     576          self.assertTrue(hasattr(a_mock, "__enter__"))
     577          self.assertTrue(hasattr(a_mock, "__exit__"))
     578          self.assertTrue(hasattr(a_mock, "__next__"))
     579          self.assertTrue(hasattr(a_mock, "__len__"))
     580  
     581      def test_magic_methods_are_async_functions(self):
     582          m_mock = MagicMock()
     583          self.assertIsInstance(m_mock.__aenter__, AsyncMock)
     584          self.assertIsInstance(m_mock.__aexit__, AsyncMock)
     585          # AsyncMocks are also coroutine functions
     586          self.assertTrue(iscoroutinefunction(m_mock.__aenter__))
     587          self.assertTrue(iscoroutinefunction(m_mock.__aexit__))
     588  
     589  class ESC[4;38;5;81mAsyncContextManagerTest(ESC[4;38;5;149munittestESC[4;38;5;149m.ESC[4;38;5;149mTestCase):
     590  
     591      class ESC[4;38;5;81mWithAsyncContextManager:
     592          async def __aenter__(self, *args, **kwargs): pass
     593  
     594          async def __aexit__(self, *args, **kwargs): pass
     595  
     596      class ESC[4;38;5;81mWithSyncContextManager:
     597          def __enter__(self, *args, **kwargs): pass
     598  
     599          def __exit__(self, *args, **kwargs): pass
     600  
     601      class ESC[4;38;5;81mProductionCode:
     602          # Example real-world(ish) code
     603          def __init__(self):
     604              self.session = None
     605  
     606          async def main(self):
     607              async with self.session.post('https://python.org') as response:
     608                  val = await response.json()
     609                  return val
     610  
     611      def test_set_return_value_of_aenter(self):
     612          def inner_test(mock_type):
     613              pc = self.ProductionCode()
     614              pc.session = MagicMock(name='sessionmock')
     615              cm = mock_type(name='magic_cm')
     616              response = AsyncMock(name='response')
     617              response.json = AsyncMock(return_value={'json': 123})
     618              cm.__aenter__.return_value = response
     619              pc.session.post.return_value = cm
     620              result = run(pc.main())
     621              self.assertEqual(result, {'json': 123})
     622  
     623          for mock_type in [AsyncMock, MagicMock]:
     624              with self.subTest(f"test set return value of aenter with {mock_type}"):
     625                  inner_test(mock_type)
     626  
     627      def test_mock_supports_async_context_manager(self):
     628          def inner_test(mock_type):
     629              called = False
     630              cm = self.WithAsyncContextManager()
     631              cm_mock = mock_type(cm)
     632  
     633              async def use_context_manager():
     634                  nonlocal called
     635                  async with cm_mock as result:
     636                      called = True
     637                  return result
     638  
     639              cm_result = run(use_context_manager())
     640              self.assertTrue(called)
     641              self.assertTrue(cm_mock.__aenter__.called)
     642              self.assertTrue(cm_mock.__aexit__.called)
     643              cm_mock.__aenter__.assert_awaited()
     644              cm_mock.__aexit__.assert_awaited()
     645              # We mock __aenter__ so it does not return self
     646              self.assertIsNot(cm_mock, cm_result)
     647  
     648          for mock_type in [AsyncMock, MagicMock]:
     649              with self.subTest(f"test context manager magics with {mock_type}"):
     650                  inner_test(mock_type)
     651  
     652  
     653      def test_mock_customize_async_context_manager(self):
     654          instance = self.WithAsyncContextManager()
     655          mock_instance = MagicMock(instance)
     656  
     657          expected_result = object()
     658          mock_instance.__aenter__.return_value = expected_result
     659  
     660          async def use_context_manager():
     661              async with mock_instance as result:
     662                  return result
     663  
     664          self.assertIs(run(use_context_manager()), expected_result)
     665  
     666      def test_mock_customize_async_context_manager_with_coroutine(self):
     667          enter_called = False
     668          exit_called = False
     669  
     670          async def enter_coroutine(*args):
     671              nonlocal enter_called
     672              enter_called = True
     673  
     674          async def exit_coroutine(*args):
     675              nonlocal exit_called
     676              exit_called = True
     677  
     678          instance = self.WithAsyncContextManager()
     679          mock_instance = MagicMock(instance)
     680  
     681          mock_instance.__aenter__ = enter_coroutine
     682          mock_instance.__aexit__ = exit_coroutine
     683  
     684          async def use_context_manager():
     685              async with mock_instance:
     686                  pass
     687  
     688          run(use_context_manager())
     689          self.assertTrue(enter_called)
     690          self.assertTrue(exit_called)
     691  
     692      def test_context_manager_raise_exception_by_default(self):
     693          async def raise_in(context_manager):
     694              async with context_manager:
     695                  raise TypeError()
     696  
     697          instance = self.WithAsyncContextManager()
     698          mock_instance = MagicMock(instance)
     699          with self.assertRaises(TypeError):
     700              run(raise_in(mock_instance))
     701  
     702  
     703  class ESC[4;38;5;81mAsyncIteratorTest(ESC[4;38;5;149munittestESC[4;38;5;149m.ESC[4;38;5;149mTestCase):
     704      class ESC[4;38;5;81mWithAsyncIterator(ESC[4;38;5;149mobject):
     705          def __init__(self):
     706              self.items = ["foo", "NormalFoo", "baz"]
     707  
     708          def __aiter__(self): pass
     709  
     710          async def __anext__(self): pass
     711  
     712      def test_aiter_set_return_value(self):
     713          mock_iter = AsyncMock(name="tester")
     714          mock_iter.__aiter__.return_value = [1, 2, 3]
     715          async def main():
     716              return [i async for i in mock_iter]
     717          result = run(main())
     718          self.assertEqual(result, [1, 2, 3])
     719  
     720      def test_mock_aiter_and_anext_asyncmock(self):
     721          def inner_test(mock_type):
     722              instance = self.WithAsyncIterator()
     723              mock_instance = mock_type(instance)
     724              # Check that the mock and the real thing bahave the same
     725              # __aiter__ is not actually async, so not a coroutinefunction
     726              self.assertFalse(iscoroutinefunction(instance.__aiter__))
     727              self.assertFalse(iscoroutinefunction(mock_instance.__aiter__))
     728              # __anext__ is async
     729              self.assertTrue(iscoroutinefunction(instance.__anext__))
     730              self.assertTrue(iscoroutinefunction(mock_instance.__anext__))
     731  
     732          for mock_type in [AsyncMock, MagicMock]:
     733              with self.subTest(f"test aiter and anext corourtine with {mock_type}"):
     734                  inner_test(mock_type)
     735  
     736  
     737      def test_mock_async_for(self):
     738          async def iterate(iterator):
     739              accumulator = []
     740              async for item in iterator:
     741                  accumulator.append(item)
     742  
     743              return accumulator
     744  
     745          expected = ["FOO", "BAR", "BAZ"]
     746          def test_default(mock_type):
     747              mock_instance = mock_type(self.WithAsyncIterator())
     748              self.assertEqual(run(iterate(mock_instance)), [])
     749  
     750  
     751          def test_set_return_value(mock_type):
     752              mock_instance = mock_type(self.WithAsyncIterator())
     753              mock_instance.__aiter__.return_value = expected[:]
     754              self.assertEqual(run(iterate(mock_instance)), expected)
     755  
     756          def test_set_return_value_iter(mock_type):
     757              mock_instance = mock_type(self.WithAsyncIterator())
     758              mock_instance.__aiter__.return_value = iter(expected[:])
     759              self.assertEqual(run(iterate(mock_instance)), expected)
     760  
     761          for mock_type in [AsyncMock, MagicMock]:
     762              with self.subTest(f"default value with {mock_type}"):
     763                  test_default(mock_type)
     764  
     765              with self.subTest(f"set return_value with {mock_type}"):
     766                  test_set_return_value(mock_type)
     767  
     768              with self.subTest(f"set return_value iterator with {mock_type}"):
     769                  test_set_return_value_iter(mock_type)
     770  
     771  
     772  class ESC[4;38;5;81mAsyncMockAssert(ESC[4;38;5;149munittestESC[4;38;5;149m.ESC[4;38;5;149mTestCase):
     773      def setUp(self):
     774          self.mock = AsyncMock()
     775  
     776      async def _runnable_test(self, *args, **kwargs):
     777          await self.mock(*args, **kwargs)
     778  
     779      async def _await_coroutine(self, coroutine):
     780          return await coroutine
     781  
     782      def test_assert_called_but_not_awaited(self):
     783          mock = AsyncMock(AsyncClass)
     784          with assertNeverAwaited(self):
     785              mock.async_method()
     786          self.assertTrue(iscoroutinefunction(mock.async_method))
     787          mock.async_method.assert_called()
     788          mock.async_method.assert_called_once()
     789          mock.async_method.assert_called_once_with()
     790          with self.assertRaises(AssertionError):
     791              mock.assert_awaited()
     792          with self.assertRaises(AssertionError):
     793              mock.async_method.assert_awaited()
     794  
     795      def test_assert_called_then_awaited(self):
     796          mock = AsyncMock(AsyncClass)
     797          mock_coroutine = mock.async_method()
     798          mock.async_method.assert_called()
     799          mock.async_method.assert_called_once()
     800          mock.async_method.assert_called_once_with()
     801          with self.assertRaises(AssertionError):
     802              mock.async_method.assert_awaited()
     803  
     804          run(self._await_coroutine(mock_coroutine))
     805          # Assert we haven't re-called the function
     806          mock.async_method.assert_called_once()
     807          mock.async_method.assert_awaited()
     808          mock.async_method.assert_awaited_once()
     809          mock.async_method.assert_awaited_once_with()
     810  
     811      def test_assert_called_and_awaited_at_same_time(self):
     812          with self.assertRaises(AssertionError):
     813              self.mock.assert_awaited()
     814  
     815          with self.assertRaises(AssertionError):
     816              self.mock.assert_called()
     817  
     818          run(self._runnable_test())
     819          self.mock.assert_called_once()
     820          self.mock.assert_awaited_once()
     821  
     822      def test_assert_called_twice_and_awaited_once(self):
     823          mock = AsyncMock(AsyncClass)
     824          coroutine = mock.async_method()
     825          # The first call will be awaited so no warning there
     826          # But this call will never get awaited, so it will warn here
     827          with assertNeverAwaited(self):
     828              mock.async_method()
     829          with self.assertRaises(AssertionError):
     830              mock.async_method.assert_awaited()
     831          mock.async_method.assert_called()
     832          run(self._await_coroutine(coroutine))
     833          mock.async_method.assert_awaited()
     834          mock.async_method.assert_awaited_once()
     835  
     836      def test_assert_called_once_and_awaited_twice(self):
     837          mock = AsyncMock(AsyncClass)
     838          coroutine = mock.async_method()
     839          mock.async_method.assert_called_once()
     840          run(self._await_coroutine(coroutine))
     841          with self.assertRaises(RuntimeError):
     842              # Cannot reuse already awaited coroutine
     843              run(self._await_coroutine(coroutine))
     844          mock.async_method.assert_awaited()
     845  
     846      def test_assert_awaited_but_not_called(self):
     847          with self.assertRaises(AssertionError):
     848              self.mock.assert_awaited()
     849          with self.assertRaises(AssertionError):
     850              self.mock.assert_called()
     851          with self.assertRaises(TypeError):
     852              # You cannot await an AsyncMock, it must be a coroutine
     853              run(self._await_coroutine(self.mock))
     854  
     855          with self.assertRaises(AssertionError):
     856              self.mock.assert_awaited()
     857          with self.assertRaises(AssertionError):
     858              self.mock.assert_called()
     859  
     860      def test_assert_has_calls_not_awaits(self):
     861          kalls = [call('foo')]
     862          with assertNeverAwaited(self):
     863              self.mock('foo')
     864          self.mock.assert_has_calls(kalls)
     865          with self.assertRaises(AssertionError):
     866              self.mock.assert_has_awaits(kalls)
     867  
     868      def test_assert_has_mock_calls_on_async_mock_no_spec(self):
     869          with assertNeverAwaited(self):
     870              self.mock()
     871          kalls_empty = [('', (), {})]
     872          self.assertEqual(self.mock.mock_calls, kalls_empty)
     873  
     874          with assertNeverAwaited(self):
     875              self.mock('foo')
     876          with assertNeverAwaited(self):
     877              self.mock('baz')
     878          mock_kalls = ([call(), call('foo'), call('baz')])
     879          self.assertEqual(self.mock.mock_calls, mock_kalls)
     880  
     881      def test_assert_has_mock_calls_on_async_mock_with_spec(self):
     882          a_class_mock = AsyncMock(AsyncClass)
     883          with assertNeverAwaited(self):
     884              a_class_mock.async_method()
     885          kalls_empty = [('', (), {})]
     886          self.assertEqual(a_class_mock.async_method.mock_calls, kalls_empty)
     887          self.assertEqual(a_class_mock.mock_calls, [call.async_method()])
     888  
     889          with assertNeverAwaited(self):
     890              a_class_mock.async_method(1, 2, 3, a=4, b=5)
     891          method_kalls = [call(), call(1, 2, 3, a=4, b=5)]
     892          mock_kalls = [call.async_method(), call.async_method(1, 2, 3, a=4, b=5)]
     893          self.assertEqual(a_class_mock.async_method.mock_calls, method_kalls)
     894          self.assertEqual(a_class_mock.mock_calls, mock_kalls)
     895  
     896      def test_async_method_calls_recorded(self):
     897          with assertNeverAwaited(self):
     898              self.mock.something(3, fish=None)
     899          with assertNeverAwaited(self):
     900              self.mock.something_else.something(6, cake=sentinel.Cake)
     901  
     902          self.assertEqual(self.mock.method_calls, [
     903              ("something", (3,), {'fish': None}),
     904              ("something_else.something", (6,), {'cake': sentinel.Cake})
     905          ],
     906              "method calls not recorded correctly")
     907          self.assertEqual(self.mock.something_else.method_calls,
     908                           [("something", (6,), {'cake': sentinel.Cake})],
     909                           "method calls not recorded correctly")
     910  
     911      def test_async_arg_lists(self):
     912          def assert_attrs(mock):
     913              names = ('call_args_list', 'method_calls', 'mock_calls')
     914              for name in names:
     915                  attr = getattr(mock, name)
     916                  self.assertIsInstance(attr, _CallList)
     917                  self.assertIsInstance(attr, list)
     918                  self.assertEqual(attr, [])
     919  
     920          assert_attrs(self.mock)
     921          with assertNeverAwaited(self):
     922              self.mock()
     923          with assertNeverAwaited(self):
     924              self.mock(1, 2)
     925          with assertNeverAwaited(self):
     926              self.mock(a=3)
     927  
     928          self.mock.reset_mock()
     929          assert_attrs(self.mock)
     930  
     931          a_mock = AsyncMock(AsyncClass)
     932          with assertNeverAwaited(self):
     933              a_mock.async_method()
     934          with assertNeverAwaited(self):
     935              a_mock.async_method(1, a=3)
     936  
     937          a_mock.reset_mock()
     938          assert_attrs(a_mock)
     939  
     940      def test_assert_awaited(self):
     941          with self.assertRaises(AssertionError):
     942              self.mock.assert_awaited()
     943  
     944          run(self._runnable_test())
     945          self.mock.assert_awaited()
     946  
     947      def test_assert_awaited_once(self):
     948          with self.assertRaises(AssertionError):
     949              self.mock.assert_awaited_once()
     950  
     951          run(self._runnable_test())
     952          self.mock.assert_awaited_once()
     953  
     954          run(self._runnable_test())
     955          with self.assertRaises(AssertionError):
     956              self.mock.assert_awaited_once()
     957  
     958      def test_assert_awaited_with(self):
     959          msg = 'Not awaited'
     960          with self.assertRaisesRegex(AssertionError, msg):
     961              self.mock.assert_awaited_with('foo')
     962  
     963          run(self._runnable_test())
     964          msg = 'expected await not found'
     965          with self.assertRaisesRegex(AssertionError, msg):
     966              self.mock.assert_awaited_with('foo')
     967  
     968          run(self._runnable_test('foo'))
     969          self.mock.assert_awaited_with('foo')
     970  
     971          run(self._runnable_test('SomethingElse'))
     972          with self.assertRaises(AssertionError):
     973              self.mock.assert_awaited_with('foo')
     974  
     975      def test_assert_awaited_once_with(self):
     976          with self.assertRaises(AssertionError):
     977              self.mock.assert_awaited_once_with('foo')
     978  
     979          run(self._runnable_test('foo'))
     980          self.mock.assert_awaited_once_with('foo')
     981  
     982          run(self._runnable_test('foo'))
     983          with self.assertRaises(AssertionError):
     984              self.mock.assert_awaited_once_with('foo')
     985  
     986      def test_assert_any_wait(self):
     987          with self.assertRaises(AssertionError):
     988              self.mock.assert_any_await('foo')
     989  
     990          run(self._runnable_test('baz'))
     991          with self.assertRaises(AssertionError):
     992              self.mock.assert_any_await('foo')
     993  
     994          run(self._runnable_test('foo'))
     995          self.mock.assert_any_await('foo')
     996  
     997          run(self._runnable_test('SomethingElse'))
     998          self.mock.assert_any_await('foo')
     999  
    1000      def test_assert_has_awaits_no_order(self):
    1001          calls = [call('foo'), call('baz')]
    1002  
    1003          with self.assertRaises(AssertionError) as cm:
    1004              self.mock.assert_has_awaits(calls)
    1005          self.assertEqual(len(cm.exception.args), 1)
    1006  
    1007          run(self._runnable_test('foo'))
    1008          with self.assertRaises(AssertionError):
    1009              self.mock.assert_has_awaits(calls)
    1010  
    1011          run(self._runnable_test('foo'))
    1012          with self.assertRaises(AssertionError):
    1013              self.mock.assert_has_awaits(calls)
    1014  
    1015          run(self._runnable_test('baz'))
    1016          self.mock.assert_has_awaits(calls)
    1017  
    1018          run(self._runnable_test('SomethingElse'))
    1019          self.mock.assert_has_awaits(calls)
    1020  
    1021      def test_awaits_asserts_with_any(self):
    1022          class ESC[4;38;5;81mFoo:
    1023              def __eq__(self, other): pass
    1024  
    1025          run(self._runnable_test(Foo(), 1))
    1026  
    1027          self.mock.assert_has_awaits([call(ANY, 1)])
    1028          self.mock.assert_awaited_with(ANY, 1)
    1029          self.mock.assert_any_await(ANY, 1)
    1030  
    1031      def test_awaits_asserts_with_spec_and_any(self):
    1032          class ESC[4;38;5;81mFoo:
    1033              def __eq__(self, other): pass
    1034  
    1035          mock_with_spec = AsyncMock(spec=Foo)
    1036  
    1037          async def _custom_mock_runnable_test(*args):
    1038              await mock_with_spec(*args)
    1039  
    1040          run(_custom_mock_runnable_test(Foo(), 1))
    1041          mock_with_spec.assert_has_awaits([call(ANY, 1)])
    1042          mock_with_spec.assert_awaited_with(ANY, 1)
    1043          mock_with_spec.assert_any_await(ANY, 1)
    1044  
    1045      def test_assert_has_awaits_ordered(self):
    1046          calls = [call('foo'), call('baz')]
    1047          with self.assertRaises(AssertionError):
    1048              self.mock.assert_has_awaits(calls, any_order=True)
    1049  
    1050          run(self._runnable_test('baz'))
    1051          with self.assertRaises(AssertionError):
    1052              self.mock.assert_has_awaits(calls, any_order=True)
    1053  
    1054          run(self._runnable_test('bamf'))
    1055          with self.assertRaises(AssertionError):
    1056              self.mock.assert_has_awaits(calls, any_order=True)
    1057  
    1058          run(self._runnable_test('foo'))
    1059          self.mock.assert_has_awaits(calls, any_order=True)
    1060  
    1061          run(self._runnable_test('qux'))
    1062          self.mock.assert_has_awaits(calls, any_order=True)
    1063  
    1064      def test_assert_not_awaited(self):
    1065          self.mock.assert_not_awaited()
    1066  
    1067          run(self._runnable_test())
    1068          with self.assertRaises(AssertionError):
    1069              self.mock.assert_not_awaited()
    1070  
    1071      def test_assert_has_awaits_not_matching_spec_error(self):
    1072          async def f(x=None): pass
    1073  
    1074          self.mock = AsyncMock(spec=f)
    1075          run(self._runnable_test(1))
    1076  
    1077          with self.assertRaisesRegex(
    1078                  AssertionError,
    1079                  '^{}$'.format(
    1080                      re.escape('Awaits not found.\n'
    1081                                'Expected: [call()]\n'
    1082                                'Actual: [call(1)]'))) as cm:
    1083              self.mock.assert_has_awaits([call()])
    1084          self.assertIsNone(cm.exception.__cause__)
    1085  
    1086          with self.assertRaisesRegex(
    1087                  AssertionError,
    1088                  '^{}$'.format(
    1089                      re.escape(
    1090                          'Error processing expected awaits.\n'
    1091                          "Errors: [None, TypeError('too many positional "
    1092                          "arguments')]\n"
    1093                          'Expected: [call(), call(1, 2)]\n'
    1094                          'Actual: [call(1)]'))) as cm:
    1095              self.mock.assert_has_awaits([call(), call(1, 2)])
    1096          self.assertIsInstance(cm.exception.__cause__, TypeError)
    1097  
    1098  
    1099  if __name__ == '__main__':
    1100      unittest.main()