python (3.12.0)
1 import asyncio
2 import gc
3 import inspect
4 import re
5 import unittest
6 from contextlib import contextmanager
7 from test import support
8
# Gate the whole module on socket support (module=True applies the
# requirement module-wide); presumably needed because the asyncio event
# loops driven by run()/IsolatedAsyncioTestCase below rely on sockets.
support.requires_working_socket(module=True)
10
11 from asyncio import run, iscoroutinefunction
12 from unittest import IsolatedAsyncioTestCase
13 from unittest.mock import (ANY, call, AsyncMock, patch, MagicMock, Mock,
14 create_autospec, sentinel, _CallList, seal)
15
16
def tearDownModule():
    """Put asyncio's event loop policy back to its default after this module runs."""
    default_policy = None  # per asyncio docs, None restores the default policy
    asyncio.set_event_loop_policy(default_policy)
19
20
class AsyncClass:
    """Patch/spec target mixing coroutine and plain methods of every binding kind."""

    def __init__(self):
        """Take no arguments and store no state."""

    async def async_method(self):
        """Coroutine instance method; resolves to None."""

    def normal_method(self):
        """Plain synchronous instance method; returns None."""

    @classmethod
    async def async_class_method(cls):
        """Coroutine classmethod; resolves to None."""

    @staticmethod
    async def async_static_method():
        """Coroutine staticmethod; resolves to None."""
31
32
class AwaitableClass:
    """Minimal awaitable: ``__await__`` is a generator that suspends exactly once."""

    def __await__(self):
        yield None
35
async def async_func():
    """Argument-free coroutine function used as a patch/spec target."""
    return None
37
async def async_func_args(a, b, *, c):
    """Coroutine taking two positional-or-keyword args and a keyword-only ``c``."""
    return None
39
def normal_func():
    """Synchronous, argument-free counterpart of ``async_func``."""
    return None
41
class NormalClass:
    """Plain class whose only method ``a`` is a synchronous patch/spec target."""

    def a(self):
        """Synchronous method; returns None."""
44
45
# Dotted import paths of the fixture classes above, for string-based patch().
async_foo_name = '.'.join((__name__, 'AsyncClass'))
normal_foo_name = '.'.join((__name__, 'NormalClass'))
48
49
@contextmanager
def assertNeverAwaited(test):
    """Require that the wrapped body leaves a coroutine un-awaited.

    ``test`` is the TestCase whose ``assertWarnsRegex`` checks that a
    RuntimeWarning ending in "was never awaited" is emitted before exit.
    """
    expected = test.assertWarnsRegex(RuntimeWarning, "was never awaited$")
    with expected:
        yield
        # Timely deallocation is not guaranteed outside CPython, so force a
        # collection to make the un-awaited coroutine's warning fire now.
        gc.collect()
57
58
class AsyncPatchDecoratorTest(unittest.TestCase):
    """patch() / patch.object() used as decorators on coroutine-function targets."""

    def test_is_coroutine_function_patch(self):
        # Patching an async method yields a mock reported as a coroutine function.
        @patch.object(AsyncClass, 'async_method')
        def test_async(mock_method):
            self.assertTrue(iscoroutinefunction(mock_method))
        test_async()

    def test_is_async_patch(self):
        # Calling the patched mock returns an awaitable; run() awaits it so no
        # "never awaited" warning is left behind.
        @patch.object(AsyncClass, 'async_method')
        def test_async(mock_method):
            m = mock_method()
            self.assertTrue(inspect.isawaitable(m))
            run(m)

        # Same check, but the target is given as a dotted-path string.
        @patch(f'{async_foo_name}.async_method')
        def test_no_parent_attribute(mock_method):
            m = mock_method()
            self.assertTrue(inspect.isawaitable(m))
            run(m)

        test_async()
        test_no_parent_attribute()

    def test_is_AsyncMock_patch(self):
        @patch.object(AsyncClass, 'async_method')
        def test_async(mock_method):
            self.assertIsInstance(mock_method, AsyncMock)

        test_async()

    def test_is_AsyncMock_patch_staticmethod(self):
        @patch.object(AsyncClass, 'async_static_method')
        def test_async(mock_method):
            self.assertIsInstance(mock_method, AsyncMock)

        test_async()

    def test_is_AsyncMock_patch_classmethod(self):
        @patch.object(AsyncClass, 'async_class_method')
        def test_async(mock_method):
            self.assertIsInstance(mock_method, AsyncMock)

        test_async()

    def test_async_def_patch(self):
        # The innermost decorator (async_func_args) supplies the first mock
        # argument, as the _mock_name assertions below confirm.
        @patch(f"{__name__}.async_func", return_value=1)
        @patch(f"{__name__}.async_func_args", return_value=2)
        async def test_async(func_args_mock, func_mock):
            self.assertEqual(func_args_mock._mock_name, "async_func_args")
            self.assertEqual(func_mock._mock_name, "async_func")

            self.assertIsInstance(async_func, AsyncMock)
            self.assertIsInstance(async_func_args, AsyncMock)

            self.assertEqual(await async_func(), 1)
            self.assertEqual(await async_func_args(1, 2, c=3), 2)

        run(test_async())
        # Once the decorated coroutine has finished, the real function is restored.
        self.assertTrue(inspect.iscoroutinefunction(async_func))
118
119
class AsyncPatchCMTest(unittest.TestCase):
    """patch() / patch.object() / patch.dict() used as context managers with async code."""

    def test_is_async_function_cm(self):
        def test_async():
            with patch.object(AsyncClass, 'async_method') as mock_method:
                self.assertTrue(iscoroutinefunction(mock_method))

        test_async()

    def test_is_async_cm(self):
        def test_async():
            with patch.object(AsyncClass, 'async_method') as mock_method:
                m = mock_method()
                self.assertTrue(inspect.isawaitable(m))
                # Await the result so the coroutine is not left un-awaited.
                run(m)

        test_async()

    def test_is_AsyncMock_cm(self):
        def test_async():
            with patch.object(AsyncClass, 'async_method') as mock_method:
                self.assertIsInstance(mock_method, AsyncMock)

        test_async()

    def test_async_def_cm(self):
        # Explicit AsyncMock replacement, entered from inside a coroutine.
        async def test_async():
            with patch(f"{__name__}.async_func", AsyncMock()):
                self.assertIsInstance(async_func, AsyncMock)
            self.assertTrue(inspect.iscoroutinefunction(async_func))

        run(test_async())

    def test_patch_dict_async_def(self):
        foo = {'a': 'a'}
        # Decorating with patch.dict must keep the function a coroutine function.
        @patch.dict(foo, {'a': 'b'})
        async def test_async():
            self.assertEqual(foo['a'], 'b')

        self.assertTrue(iscoroutinefunction(test_async))
        run(test_async())

    def test_patch_dict_async_def_context(self):
        foo = {'a': 'a'}
        async def test_async():
            with patch.dict(foo, {'a': 'b'}):
                self.assertEqual(foo['a'], 'b')

        run(test_async())
168
169
class AsyncMockTest(unittest.TestCase):
    """Basic coroutine-function and awaitable behaviour of AsyncMock itself."""

    def test_iscoroutinefunction_default(self):
        self.assertTrue(iscoroutinefunction(AsyncMock()))

    def test_iscoroutinefunction_function(self):
        async def coro_spec(): pass
        mocked = AsyncMock(coro_spec)
        # Both asyncio's and inspect's checks must agree.
        for check in (iscoroutinefunction, inspect.iscoroutinefunction):
            self.assertTrue(check(mocked))

    def test_isawaitable(self):
        mocked = AsyncMock()
        pending = mocked()
        self.assertTrue(inspect.isawaitable(pending))
        run(pending)
        # The await-assertion helpers are discoverable through dir().
        self.assertIn('assert_awaited', dir(mocked))

    def test_iscoroutinefunction_normal_function(self):
        # Even a synchronous spec yields a mock that reports as a coroutine function.
        def plain_spec(): pass
        mocked = AsyncMock(plain_spec)
        for check in (iscoroutinefunction, inspect.iscoroutinefunction):
            self.assertTrue(check(mocked))

    def test_future_isfuture(self):
        # Create a Future on a throwaway loop and shut that loop down before
        # using the Future as the spec.
        event_loop = asyncio.new_event_loop()
        future = event_loop.create_future()
        event_loop.stop()
        event_loop.close()
        self.assertIsInstance(AsyncMock(future), asyncio.Future)
201
202
class AsyncAutospecTest(unittest.TestCase):
    """autospec / create_autospec applied to coroutine functions and classes containing them."""

    def test_is_AsyncMock_patch(self):
        # With autospec=True on a class, the async method becomes an AsyncMock
        # while the class mock itself and its sync method stay MagicMock.
        @patch(async_foo_name, autospec=True)
        def test_async(mock_method):
            self.assertIsInstance(mock_method.async_method, AsyncMock)
            self.assertIsInstance(mock_method, MagicMock)

        @patch(async_foo_name, autospec=True)
        def test_normal_method(mock_method):
            self.assertIsInstance(mock_method.normal_method, MagicMock)

        test_async()
        test_normal_method()

    def test_create_autospec_instance(self):
        # instance=True is rejected for a plain coroutine function.
        with self.assertRaises(RuntimeError):
            create_autospec(async_func, instance=True)

    def test_create_autospec(self):
        spec = create_autospec(async_func_args)
        # Calling the autospec creates the coroutine but is not yet an await.
        awaitable = spec(1, 2, c=3)
        async def main():
            await awaitable

        self.assertEqual(spec.await_count, 0)
        self.assertIsNone(spec.await_args)
        self.assertEqual(spec.await_args_list, [])
        spec.assert_not_awaited()

        run(main())

        # After exactly one await, every await-tracking attribute reflects it.
        self.assertTrue(iscoroutinefunction(spec))
        self.assertTrue(asyncio.iscoroutine(awaitable))
        self.assertEqual(spec.await_count, 1)
        self.assertEqual(spec.await_args, call(1, 2, c=3))
        self.assertEqual(spec.await_args_list, [call(1, 2, c=3)])
        spec.assert_awaited_once()
        spec.assert_awaited_once_with(1, 2, c=3)
        spec.assert_awaited_with(1, 2, c=3)
        spec.assert_awaited()

        with self.assertRaises(AssertionError):
            spec.assert_any_await(e=1)

    def test_patch_with_autospec(self):

        async def test_async():
            with patch(f"{__name__}.async_func_args", autospec=True) as mock_method:
                awaitable = mock_method(1, 2, c=3)
                # The autospecced function delegates to an AsyncMock via .mock.
                self.assertIsInstance(mock_method.mock, AsyncMock)

                self.assertTrue(iscoroutinefunction(mock_method))
                self.assertTrue(asyncio.iscoroutine(awaitable))
                self.assertTrue(inspect.isawaitable(awaitable))

                # Verify the default values during mock setup
                self.assertEqual(mock_method.await_count, 0)
                self.assertEqual(mock_method.await_args_list, [])
                self.assertIsNone(mock_method.await_args)
                mock_method.assert_not_awaited()

                await awaitable

                self.assertEqual(mock_method.await_count, 1)
                self.assertEqual(mock_method.await_args, call(1, 2, c=3))
                self.assertEqual(mock_method.await_args_list, [call(1, 2, c=3)])
                mock_method.assert_awaited_once()
                mock_method.assert_awaited_once_with(1, 2, c=3)
                mock_method.assert_awaited_with(1, 2, c=3)
                mock_method.assert_awaited()

                # reset_mock() also clears the await bookkeeping.
                mock_method.reset_mock()
                self.assertEqual(mock_method.await_count, 0)
                self.assertIsNone(mock_method.await_args)
                self.assertEqual(mock_method.await_args_list, [])

        run(test_async())
281
282
class AsyncSpecTest(unittest.TestCase):
    """How a spec containing coroutine functions decides AsyncMock vs sync-mock children."""

    def test_spec_normal_methods_on_class(self):
        def inner_test(mock_type):
            mock = mock_type(AsyncClass)
            self.assertIsInstance(mock.async_method, AsyncMock)
            self.assertIsInstance(mock.normal_method, MagicMock)

        for mock_type in [AsyncMock, MagicMock]:
            with self.subTest(f"test method types with {mock_type}"):
                inner_test(mock_type)

    def test_spec_normal_methods_on_class_with_mock(self):
        mock = Mock(AsyncClass)
        self.assertIsInstance(mock.async_method, AsyncMock)
        self.assertIsInstance(mock.normal_method, Mock)

    def test_spec_normal_methods_on_class_with_mock_seal(self):
        # After seal(), accessing a not-yet-created child raises AttributeError,
        # for sync and async spec attributes alike.
        mock = Mock(AsyncClass)
        seal(mock)
        with self.assertRaises(AttributeError):
            mock.normal_method
        with self.assertRaises(AttributeError):
            mock.async_method

    def test_spec_async_attributes_instance(self):
        async_instance = AsyncClass()
        async_instance.async_func_attr = async_func
        async_instance.later_async_func_attr = normal_func

        mock_async_instance = Mock(spec_set=async_instance)

        # Mutating the spec object after the mock exists must not affect the mock.
        async_instance.later_async_func_attr = async_func

        self.assertIsInstance(mock_async_instance.async_func_attr, AsyncMock)
        # only the shape of the spec at the time of mock construction matters
        self.assertNotIsInstance(mock_async_instance.later_async_func_attr, AsyncMock)

    def test_spec_mock_type_kw(self):
        def inner_test(mock_type):
            async_mock = mock_type(spec=async_func)
            self.assertIsInstance(async_mock, mock_type)
            # The coroutine returned by the call is deliberately discarded,
            # so the "never awaited" warning is expected here.
            with assertNeverAwaited(self):
                self.assertTrue(inspect.isawaitable(async_mock()))

            sync_mock = mock_type(spec=normal_func)
            self.assertIsInstance(sync_mock, mock_type)

        for mock_type in [AsyncMock, MagicMock, Mock]:
            with self.subTest(f"test spec kwarg with {mock_type}"):
                inner_test(mock_type)

    def test_spec_mock_type_positional(self):
        # Same as test_spec_mock_type_kw, but with the spec passed positionally.
        def inner_test(mock_type):
            async_mock = mock_type(async_func)
            self.assertIsInstance(async_mock, mock_type)
            with assertNeverAwaited(self):
                self.assertTrue(inspect.isawaitable(async_mock()))

            sync_mock = mock_type(normal_func)
            self.assertIsInstance(sync_mock, mock_type)

        for mock_type in [AsyncMock, MagicMock, Mock]:
            with self.subTest(f"test spec positional with {mock_type}"):
                inner_test(mock_type)

    def test_spec_as_normal_kw_AsyncMock(self):
        # An explicit AsyncMock stays awaitable even when spec'd on a sync function.
        mock = AsyncMock(spec=normal_func)
        self.assertIsInstance(mock, AsyncMock)
        m = mock()
        self.assertTrue(inspect.isawaitable(m))
        run(m)

    def test_spec_as_normal_positional_AsyncMock(self):
        mock = AsyncMock(normal_func)
        self.assertIsInstance(mock, AsyncMock)
        m = mock()
        self.assertTrue(inspect.isawaitable(m))
        run(m)

    def test_spec_async_mock(self):
        @patch.object(AsyncClass, 'async_method', spec=True)
        def test_async(mock_method):
            self.assertIsInstance(mock_method, AsyncMock)

        test_async()

    def test_spec_parent_not_async_attribute_is(self):
        @patch(async_foo_name, spec=True)
        def test_async(mock_method):
            self.assertIsInstance(mock_method, MagicMock)
            self.assertIsInstance(mock_method.async_method, AsyncMock)

        test_async()

    def test_target_async_spec_not(self):
        # The (sync) spec, not the (async) patch target, decides the mock type.
        @patch.object(AsyncClass, 'async_method', spec=NormalClass.a)
        def test_async_attribute(mock_method):
            self.assertIsInstance(mock_method, MagicMock)
            self.assertFalse(inspect.iscoroutine(mock_method))
            self.assertFalse(inspect.isawaitable(mock_method))

        test_async_attribute()

    def test_target_not_async_spec_is(self):
        # Conversely, an async spec on a sync target yields an AsyncMock.
        @patch.object(NormalClass, 'a', spec=async_func)
        def test_attribute_not_async_spec_is(mock_async_func):
            self.assertIsInstance(mock_async_func, AsyncMock)
        test_attribute_not_async_spec_is()

    def test_spec_async_attributes(self):
        @patch(normal_foo_name, spec=AsyncClass)
        def test_async_attributes_coroutines(MockNormalClass):
            self.assertIsInstance(MockNormalClass.async_method, AsyncMock)
            self.assertIsInstance(MockNormalClass, MagicMock)

        test_async_attributes_coroutines()
399
400
class AsyncSpecSetTest(unittest.TestCase):
    """spec_set variants: coroutine-function specs and children still map to AsyncMock."""

    def test_is_AsyncMock_patch(self):
        @patch.object(AsyncClass, 'async_method', spec_set=True)
        def test_async(async_method):
            self.assertIsInstance(async_method, AsyncMock)
        test_async()

    def test_is_async_AsyncMock(self):
        mock = AsyncMock(spec_set=AsyncClass.async_method)
        self.assertTrue(iscoroutinefunction(mock))
        self.assertIsInstance(mock, AsyncMock)

    def test_is_child_AsyncMock(self):
        # Only the async child becomes an AsyncMock; the parent and the sync
        # child remain MagicMock.
        mock = MagicMock(spec_set=AsyncClass)
        self.assertTrue(iscoroutinefunction(mock.async_method))
        self.assertFalse(iscoroutinefunction(mock.normal_method))
        self.assertIsInstance(mock.async_method, AsyncMock)
        self.assertIsInstance(mock.normal_method, MagicMock)
        self.assertIsInstance(mock, MagicMock)

    def test_magicmock_lambda_spec(self):
        # A lambda is a sync callable, so patching the attribute spec'd on it
        # yields a MagicMock.
        mock_obj = MagicMock()
        mock_obj.mock_func = MagicMock(spec=lambda x: x)

        with patch.object(mock_obj, "mock_func") as cm:
            self.assertIsInstance(cm, MagicMock)
427
428
class AsyncArguments(IsolatedAsyncioTestCase):
    """How spec, return_value, side_effect and wraps shape the result of awaiting an AsyncMock."""

    async def test_add_return_value(self):
        async def addition(self, var): pass

        mock = AsyncMock(addition, return_value=10)
        output = await mock(5)

        self.assertEqual(output, 10)

    async def test_add_side_effect_exception(self):
        class CustomError(Exception): pass
        async def addition(var): pass
        mock = AsyncMock(addition, side_effect=CustomError('side-effect'))
        # An exception instance as side_effect is raised when the call is awaited.
        with self.assertRaisesRegex(CustomError, 'side-effect'):
            await mock(5)

    async def test_add_side_effect_coroutine(self):
        async def addition(var):
            return var + 1
        mock = AsyncMock(side_effect=addition)
        result = await mock(5)
        self.assertEqual(result, 6)

    async def test_add_side_effect_normal_function(self):
        def addition(var):
            return var + 1
        mock = AsyncMock(side_effect=addition)
        result = await mock(5)
        self.assertEqual(result, 6)

    async def test_add_side_effect_iterable(self):
        vals = [1, 2, 3]
        mock = AsyncMock(side_effect=vals)
        for item in vals:
            self.assertEqual(await mock(), item)

        # Once the iterable is exhausted, awaiting raises StopAsyncIteration.
        with self.assertRaises(StopAsyncIteration):
            await mock()

    async def test_add_side_effect_exception_iterable(self):
        class SampleException(Exception):
            pass

        vals = [1, SampleException("foo")]
        mock = AsyncMock(side_effect=vals)
        self.assertEqual(await mock(), 1)

        # Exception instances inside the iterable are raised, not returned.
        with self.assertRaises(SampleException):
            await mock()

    async def test_return_value_AsyncMock(self):
        value = AsyncMock(return_value=10)
        mock = AsyncMock(return_value=value)
        result = await mock()
        # The return_value is handed back as-is, not awaited a second time.
        self.assertIs(result, value)

    async def test_return_value_awaitable(self):
        fut = asyncio.Future()
        fut.set_result(None)
        mock = AsyncMock(return_value=fut)
        result = await mock()
        self.assertIsInstance(result, asyncio.Future)

    async def test_side_effect_awaitable_values(self):
        fut = asyncio.Future()
        fut.set_result(None)

        mock = AsyncMock(side_effect=[fut])
        result = await mock()
        self.assertIsInstance(result, asyncio.Future)

        with self.assertRaises(StopAsyncIteration):
            await mock()

    async def test_side_effect_is_AsyncMock(self):
        effect = AsyncMock(return_value=10)
        mock = AsyncMock(side_effect=effect)

        result = await mock()
        self.assertEqual(result, 10)

    async def test_wraps_coroutine(self):
        value = asyncio.Future()

        ran = False
        async def inner():
            nonlocal ran
            ran = True
            return value

        mock = AsyncMock(wraps=inner)
        result = await mock()
        self.assertEqual(result, value)
        mock.assert_awaited()
        self.assertTrue(ran)

    async def test_wraps_normal_function(self):
        value = 1

        ran = False
        def inner():
            nonlocal ran
            ran = True
            return value

        mock = AsyncMock(wraps=inner)
        result = await mock()
        self.assertEqual(result, value)
        mock.assert_awaited()
        self.assertTrue(ran)

    async def test_await_args_list_order(self):
        # Calls are recorded at call time, awaits at await time, so the two
        # lists can be in different orders.
        async_mock = AsyncMock()
        mock2 = async_mock(2)
        mock1 = async_mock(1)
        await mock1
        await mock2
        async_mock.assert_has_awaits([call(1), call(2)])
        self.assertEqual(async_mock.await_args_list, [call(1), call(2)])
        self.assertEqual(async_mock.call_args_list, [call(2), call(1)])
549
550
class AsyncMagicMethods(unittest.TestCase):
    """Which magic methods MagicMock and AsyncMock expose as async vs sync mocks."""

    def test_async_magic_methods_return_async_mocks(self):
        magic = MagicMock()
        for dunder in ("__aenter__", "__aexit__", "__anext__"):
            self.assertIsInstance(getattr(magic, dunder), AsyncMock)
        # __aiter__ is itself synchronous (it returns the async iterator),
        # so it stays a plain MagicMock.
        self.assertIsInstance(magic.__aiter__, MagicMock)

    def test_sync_magic_methods_return_magic_mocks(self):
        async_mock = AsyncMock()
        for dunder in ("__enter__", "__exit__", "__next__", "__len__"):
            self.assertIsInstance(getattr(async_mock, dunder), MagicMock)

    def test_magicmock_has_async_magic_methods(self):
        magic = MagicMock()
        for dunder in ("__aenter__", "__aexit__", "__anext__"):
            self.assertTrue(hasattr(magic, dunder))

    def test_asyncmock_has_sync_magic_methods(self):
        async_mock = AsyncMock()
        for dunder in ("__enter__", "__exit__", "__next__", "__len__"):
            self.assertTrue(hasattr(async_mock, dunder))

    def test_magic_methods_are_async_functions(self):
        magic = MagicMock()
        enter, exit_ = magic.__aenter__, magic.__aexit__
        self.assertIsInstance(enter, AsyncMock)
        self.assertIsInstance(exit_, AsyncMock)
        # An AsyncMock also reports itself as a coroutine function.
        self.assertTrue(iscoroutinefunction(enter))
        self.assertTrue(iscoroutinefunction(exit_))
588
class AsyncContextManagerTest(unittest.TestCase):
    """``async with`` support (__aenter__/__aexit__) on AsyncMock and MagicMock."""

    class WithAsyncContextManager:
        # Spec class with real async __aenter__/__aexit__ hooks.
        async def __aenter__(self, *args, **kwargs): pass

        async def __aexit__(self, *args, **kwargs): pass

    class WithSyncContextManager:
        # Spec class with ordinary synchronous __enter__/__exit__ hooks.
        def __enter__(self, *args, **kwargs): pass

        def __exit__(self, *args, **kwargs): pass

    class ProductionCode:
        # Example real-world(ish) code
        def __init__(self):
            # Tests replace this with a mock before calling main().
            self.session = None

        async def main(self):
            async with self.session.post('https://python.org') as response:
                val = await response.json()
                return val

    def test_set_return_value_of_aenter(self):
        def inner_test(mock_type):
            pc = self.ProductionCode()
            pc.session = MagicMock(name='sessionmock')
            cm = mock_type(name='magic_cm')
            response = AsyncMock(name='response')
            response.json = AsyncMock(return_value={'json': 123})
            # __aenter__'s return value is what `async with ... as response` binds.
            cm.__aenter__.return_value = response
            pc.session.post.return_value = cm
            result = run(pc.main())
            self.assertEqual(result, {'json': 123})

        for mock_type in [AsyncMock, MagicMock]:
            with self.subTest(f"test set return value of aenter with {mock_type}"):
                inner_test(mock_type)

    def test_mock_supports_async_context_manager(self):
        def inner_test(mock_type):
            called = False
            cm = self.WithAsyncContextManager()
            cm_mock = mock_type(cm)

            async def use_context_manager():
                nonlocal called
                async with cm_mock as result:
                    called = True
                return result

            cm_result = run(use_context_manager())
            self.assertTrue(called)
            self.assertTrue(cm_mock.__aenter__.called)
            self.assertTrue(cm_mock.__aexit__.called)
            cm_mock.__aenter__.assert_awaited()
            cm_mock.__aexit__.assert_awaited()
            # We mock __aenter__ so it does not return self
            self.assertIsNot(cm_mock, cm_result)

        for mock_type in [AsyncMock, MagicMock]:
            with self.subTest(f"test context manager magics with {mock_type}"):
                inner_test(mock_type)


    def test_mock_customize_async_context_manager(self):
        instance = self.WithAsyncContextManager()
        mock_instance = MagicMock(instance)

        expected_result = object()
        mock_instance.__aenter__.return_value = expected_result

        async def use_context_manager():
            async with mock_instance as result:
                return result

        self.assertIs(run(use_context_manager()), expected_result)

    def test_mock_customize_async_context_manager_with_coroutine(self):
        # Plain coroutine functions may be assigned directly as the async hooks.
        enter_called = False
        exit_called = False

        async def enter_coroutine(*args):
            nonlocal enter_called
            enter_called = True

        async def exit_coroutine(*args):
            nonlocal exit_called
            exit_called = True

        instance = self.WithAsyncContextManager()
        mock_instance = MagicMock(instance)

        mock_instance.__aenter__ = enter_coroutine
        mock_instance.__aexit__ = exit_coroutine

        async def use_context_manager():
            async with mock_instance:
                pass

        run(use_context_manager())
        self.assertTrue(enter_called)
        self.assertTrue(exit_called)

    def test_context_manager_raise_exception_by_default(self):
        # The default mocked __aexit__ must not swallow an exception raised
        # inside the body; the TypeError has to propagate out of run().
        async def raise_in(context_manager):
            async with context_manager:
                raise TypeError()

        instance = self.WithAsyncContextManager()
        mock_instance = MagicMock(instance)
        with self.assertRaises(TypeError):
            run(raise_in(mock_instance))
701
702
class AsyncIteratorTest(unittest.TestCase):
    """Async iteration (__aiter__/__anext__) on AsyncMock and MagicMock."""

    class WithAsyncIterator(object):
        # Spec class: a real sync __aiter__ and async __anext__ for mocks to mirror.
        def __init__(self):
            self.items = ["foo", "NormalFoo", "baz"]

        def __aiter__(self): pass

        async def __anext__(self): pass

    def test_aiter_set_return_value(self):
        # An iterable assigned to __aiter__.return_value is what `async for` yields.
        mock_iter = AsyncMock(name="tester")
        mock_iter.__aiter__.return_value = [1, 2, 3]
        async def main():
            return [i async for i in mock_iter]
        result = run(main())
        self.assertEqual(result, [1, 2, 3])

    def test_mock_aiter_and_anext_asyncmock(self):
        def inner_test(mock_type):
            instance = self.WithAsyncIterator()
            mock_instance = mock_type(instance)
            # Check that the mock and the real thing behave the same
            # __aiter__ is not actually async, so not a coroutinefunction
            self.assertFalse(iscoroutinefunction(instance.__aiter__))
            self.assertFalse(iscoroutinefunction(mock_instance.__aiter__))
            # __anext__ is async
            self.assertTrue(iscoroutinefunction(instance.__anext__))
            self.assertTrue(iscoroutinefunction(mock_instance.__anext__))

        for mock_type in [AsyncMock, MagicMock]:
            with self.subTest(f"test aiter and anext coroutine with {mock_type}"):
                inner_test(mock_type)

    def test_mock_async_for(self):
        async def iterate(iterator):
            accumulator = []
            async for item in iterator:
                accumulator.append(item)

            return accumulator

        expected = ["FOO", "BAR", "BAZ"]

        # With nothing configured, iterating the mock yields no items.
        def test_default(mock_type):
            mock_instance = mock_type(self.WithAsyncIterator())
            self.assertEqual(run(iterate(mock_instance)), [])

        # A list copy assigned to __aiter__.return_value is iterated in order.
        def test_set_return_value(mock_type):
            mock_instance = mock_type(self.WithAsyncIterator())
            mock_instance.__aiter__.return_value = expected[:]
            self.assertEqual(run(iterate(mock_instance)), expected)

        # A plain iterator works as the return value too.
        def test_set_return_value_iter(mock_type):
            mock_instance = mock_type(self.WithAsyncIterator())
            mock_instance.__aiter__.return_value = iter(expected[:])
            self.assertEqual(run(iterate(mock_instance)), expected)

        for mock_type in [AsyncMock, MagicMock]:
            with self.subTest(f"default value with {mock_type}"):
                test_default(mock_type)

            with self.subTest(f"set return_value with {mock_type}"):
                test_set_return_value(mock_type)

            with self.subTest(f"set return_value iterator with {mock_type}"):
                test_set_return_value_iter(mock_type)
770
771
class AsyncMockAssert(unittest.TestCase):
    """Call-vs-await bookkeeping and the assert_*await* helpers on AsyncMock."""

    def setUp(self):
        # A fresh, spec-less AsyncMock for every test method.
        self.mock = AsyncMock()

    async def _runnable_test(self, *args, **kwargs):
        # Call self.mock and await the result, recording both a call and an await.
        await self.mock(*args, **kwargs)

    async def _await_coroutine(self, coroutine):
        # Await whatever object is passed in (normally a coroutine created earlier).
        return await coroutine

    def test_assert_called_but_not_awaited(self):
        mock = AsyncMock(AsyncClass)
        with assertNeverAwaited(self):
            mock.async_method()
        self.assertTrue(iscoroutinefunction(mock.async_method))
        mock.async_method.assert_called()
        mock.async_method.assert_called_once()
        mock.async_method.assert_called_once_with()
        # A call alone is not an await, on either the parent or the child.
        with self.assertRaises(AssertionError):
            mock.assert_awaited()
        with self.assertRaises(AssertionError):
            mock.async_method.assert_awaited()

    def test_assert_called_then_awaited(self):
        mock = AsyncMock(AsyncClass)
        mock_coroutine = mock.async_method()
        mock.async_method.assert_called()
        mock.async_method.assert_called_once()
        mock.async_method.assert_called_once_with()
        with self.assertRaises(AssertionError):
            mock.async_method.assert_awaited()

        run(self._await_coroutine(mock_coroutine))
        # Assert we haven't re-called the function
        mock.async_method.assert_called_once()
        mock.async_method.assert_awaited()
        mock.async_method.assert_awaited_once()
        mock.async_method.assert_awaited_once_with()

    def test_assert_called_and_awaited_at_same_time(self):
        with self.assertRaises(AssertionError):
            self.mock.assert_awaited()

        with self.assertRaises(AssertionError):
            self.mock.assert_called()

        run(self._runnable_test())
        self.mock.assert_called_once()
        self.mock.assert_awaited_once()

    def test_assert_called_twice_and_awaited_once(self):
        mock = AsyncMock(AsyncClass)
        coroutine = mock.async_method()
        # The first call will be awaited so no warning there
        # But this call will never get awaited, so it will warn here
        with assertNeverAwaited(self):
            mock.async_method()
        with self.assertRaises(AssertionError):
            mock.async_method.assert_awaited()
        mock.async_method.assert_called()
        run(self._await_coroutine(coroutine))
        mock.async_method.assert_awaited()
        mock.async_method.assert_awaited_once()

    def test_assert_called_once_and_awaited_twice(self):
        mock = AsyncMock(AsyncClass)
        coroutine = mock.async_method()
        mock.async_method.assert_called_once()
        run(self._await_coroutine(coroutine))
        with self.assertRaises(RuntimeError):
            # Cannot reuse already awaited coroutine
            run(self._await_coroutine(coroutine))
        mock.async_method.assert_awaited()

    def test_assert_awaited_but_not_called(self):
        with self.assertRaises(AssertionError):
            self.mock.assert_awaited()
        with self.assertRaises(AssertionError):
            self.mock.assert_called()
        with self.assertRaises(TypeError):
            # You cannot await an AsyncMock, it must be a coroutine
            run(self._await_coroutine(self.mock))

        # The failed await attempt recorded neither a call nor an await.
        with self.assertRaises(AssertionError):
            self.mock.assert_awaited()
        with self.assertRaises(AssertionError):
            self.mock.assert_called()

    def test_assert_has_calls_not_awaits(self):
        kalls = [call('foo')]
        with assertNeverAwaited(self):
            self.mock('foo')
        self.mock.assert_has_calls(kalls)
        with self.assertRaises(AssertionError):
            self.mock.assert_has_awaits(kalls)

    def test_assert_has_mock_calls_on_async_mock_no_spec(self):
        with assertNeverAwaited(self):
            self.mock()
        kalls_empty = [('', (), {})]
        self.assertEqual(self.mock.mock_calls, kalls_empty)

        with assertNeverAwaited(self):
            self.mock('foo')
        with assertNeverAwaited(self):
            self.mock('baz')
        mock_kalls = ([call(), call('foo'), call('baz')])
        self.assertEqual(self.mock.mock_calls, mock_kalls)

    def test_assert_has_mock_calls_on_async_mock_with_spec(self):
        # Child calls are recorded on the child and, prefixed with the
        # attribute name, on the parent's mock_calls.
        a_class_mock = AsyncMock(AsyncClass)
        with assertNeverAwaited(self):
            a_class_mock.async_method()
        kalls_empty = [('', (), {})]
        self.assertEqual(a_class_mock.async_method.mock_calls, kalls_empty)
        self.assertEqual(a_class_mock.mock_calls, [call.async_method()])

        with assertNeverAwaited(self):
            a_class_mock.async_method(1, 2, 3, a=4, b=5)
        method_kalls = [call(), call(1, 2, 3, a=4, b=5)]
        mock_kalls = [call.async_method(), call.async_method(1, 2, 3, a=4, b=5)]
        self.assertEqual(a_class_mock.async_method.mock_calls, method_kalls)
        self.assertEqual(a_class_mock.mock_calls, mock_kalls)

    def test_async_method_calls_recorded(self):
        with assertNeverAwaited(self):
            self.mock.something(3, fish=None)
        with assertNeverAwaited(self):
            self.mock.something_else.something(6, cake=sentinel.Cake)

        self.assertEqual(self.mock.method_calls, [
            ("something", (3,), {'fish': None}),
            ("something_else.something", (6,), {'cake': sentinel.Cake})
        ],
            "method calls not recorded correctly")
        self.assertEqual(self.mock.something_else.method_calls,
                         [("something", (6,), {'cake': sentinel.Cake})],
                         "method calls not recorded correctly")

    def test_async_arg_lists(self):
        # Each call-recording list is an empty _CallList (a list subclass).
        def assert_attrs(mock):
            names = ('call_args_list', 'method_calls', 'mock_calls')
            for name in names:
                attr = getattr(mock, name)
                self.assertIsInstance(attr, _CallList)
                self.assertIsInstance(attr, list)
                self.assertEqual(attr, [])

        assert_attrs(self.mock)
        with assertNeverAwaited(self):
            self.mock()
        with assertNeverAwaited(self):
            self.mock(1, 2)
        with assertNeverAwaited(self):
            self.mock(a=3)

        # reset_mock() empties the lists again.
        self.mock.reset_mock()
        assert_attrs(self.mock)

        a_mock = AsyncMock(AsyncClass)
        with assertNeverAwaited(self):
            a_mock.async_method()
        with assertNeverAwaited(self):
            a_mock.async_method(1, a=3)

        a_mock.reset_mock()
        assert_attrs(a_mock)

    def test_assert_awaited(self):
        with self.assertRaises(AssertionError):
            self.mock.assert_awaited()

        run(self._runnable_test())
        self.mock.assert_awaited()

    def test_assert_awaited_once(self):
        with self.assertRaises(AssertionError):
            self.mock.assert_awaited_once()

        run(self._runnable_test())
        self.mock.assert_awaited_once()

        run(self._runnable_test())
        with self.assertRaises(AssertionError):
            self.mock.assert_awaited_once()

    def test_assert_awaited_with(self):
        msg = 'Not awaited'
        with self.assertRaisesRegex(AssertionError, msg):
            self.mock.assert_awaited_with('foo')

        run(self._runnable_test())
        msg = 'expected await not found'
        with self.assertRaisesRegex(AssertionError, msg):
            self.mock.assert_awaited_with('foo')

        run(self._runnable_test('foo'))
        self.mock.assert_awaited_with('foo')

        # Only the most recent await is compared.
        run(self._runnable_test('SomethingElse'))
        with self.assertRaises(AssertionError):
            self.mock.assert_awaited_with('foo')

    def test_assert_awaited_once_with(self):
        with self.assertRaises(AssertionError):
            self.mock.assert_awaited_once_with('foo')

        run(self._runnable_test('foo'))
        self.mock.assert_awaited_once_with('foo')

        run(self._runnable_test('foo'))
        with self.assertRaises(AssertionError):
            self.mock.assert_awaited_once_with('foo')

    # NOTE(review): the name looks like a typo for test_assert_any_await.
    def test_assert_any_wait(self):
        with self.assertRaises(AssertionError):
            self.mock.assert_any_await('foo')

        run(self._runnable_test('baz'))
        with self.assertRaises(AssertionError):
            self.mock.assert_any_await('foo')

        run(self._runnable_test('foo'))
        self.mock.assert_any_await('foo')

        # A later, different await does not hide the earlier matching one.
        run(self._runnable_test('SomethingElse'))
        self.mock.assert_any_await('foo')

    # NOTE(review): this test exercises the default (any_order=False) path,
    # while test_assert_has_awaits_ordered uses any_order=True; the two
    # names look swapped.
    def test_assert_has_awaits_no_order(self):
        calls = [call('foo'), call('baz')]

        with self.assertRaises(AssertionError) as cm:
            self.mock.assert_has_awaits(calls)
        self.assertEqual(len(cm.exception.args), 1)

        run(self._runnable_test('foo'))
        with self.assertRaises(AssertionError):
            self.mock.assert_has_awaits(calls)

        run(self._runnable_test('foo'))
        with self.assertRaises(AssertionError):
            self.mock.assert_has_awaits(calls)

        # Extra awaits before or after the expected sequence are allowed.
        run(self._runnable_test('baz'))
        self.mock.assert_has_awaits(calls)

        run(self._runnable_test('SomethingElse'))
        self.mock.assert_has_awaits(calls)

    def test_awaits_asserts_with_any(self):
        # Foo.__eq__ returns None (falsy), so the argument is matched via ANY.
        class Foo:
            def __eq__(self, other): pass

        run(self._runnable_test(Foo(), 1))

        self.mock.assert_has_awaits([call(ANY, 1)])
        self.mock.assert_awaited_with(ANY, 1)
        self.mock.assert_any_await(ANY, 1)

    def test_awaits_asserts_with_spec_and_any(self):
        # Same as above, but on a spec'd AsyncMock.
        class Foo:
            def __eq__(self, other): pass

        mock_with_spec = AsyncMock(spec=Foo)

        async def _custom_mock_runnable_test(*args):
            await mock_with_spec(*args)

        run(_custom_mock_runnable_test(Foo(), 1))
        mock_with_spec.assert_has_awaits([call(ANY, 1)])
        mock_with_spec.assert_awaited_with(ANY, 1)
        mock_with_spec.assert_any_await(ANY, 1)

    def test_assert_has_awaits_ordered(self):
        calls = [call('foo'), call('baz')]
        with self.assertRaises(AssertionError):
            self.mock.assert_has_awaits(calls, any_order=True)

        run(self._runnable_test('baz'))
        with self.assertRaises(AssertionError):
            self.mock.assert_has_awaits(calls, any_order=True)

        run(self._runnable_test('bamf'))
        with self.assertRaises(AssertionError):
            self.mock.assert_has_awaits(calls, any_order=True)

        # With any_order=True, 'baz' awaited before 'foo' still satisfies the check.
        run(self._runnable_test('foo'))
        self.mock.assert_has_awaits(calls, any_order=True)

        run(self._runnable_test('qux'))
        self.mock.assert_has_awaits(calls, any_order=True)

    def test_assert_not_awaited(self):
        self.mock.assert_not_awaited()

        run(self._runnable_test())
        with self.assertRaises(AssertionError):
            self.mock.assert_not_awaited()

    def test_assert_has_awaits_not_matching_spec_error(self):
        async def f(x=None): pass

        self.mock = AsyncMock(spec=f)
        run(self._runnable_test(1))

        # A plain mismatch: no chained cause.
        with self.assertRaisesRegex(
                AssertionError,
                '^{}$'.format(
                    re.escape('Awaits not found.\n'
                              'Expected: [call()]\n'
                              'Actual: [call(1)]'))) as cm:
            self.mock.assert_has_awaits([call()])
        self.assertIsNone(cm.exception.__cause__)

        # call(1, 2) cannot bind to f's signature; that TypeError is reported
        # in the message and chained as __cause__.
        with self.assertRaisesRegex(
                AssertionError,
                '^{}$'.format(
                    re.escape(
                        'Error processing expected awaits.\n'
                        "Errors: [None, TypeError('too many positional "
                        "arguments')]\n"
                        'Expected: [call(), call(1, 2)]\n'
                        'Actual: [call(1)]'))) as cm:
            self.mock.assert_has_awaits([call(), call(1, 2)])
        self.assertIsInstance(cm.exception.__cause__, TypeError)
1097
1098
# Allow running this test module directly as a script.
if __name__ == '__main__':
    unittest.main()