1 """Tests for asyncio/threads.py"""
2
3 import asyncio
4 import unittest
5
6 from contextvars import ContextVar
7 from unittest import mock
8
9
def tearDownModule():
    """Restore the default event loop policy once this module's tests finish.

    Passing ``None`` asks asyncio to reinstate its default policy, so any
    policy changes made while the tests ran do not leak into later modules.
    NOTE(review): ``set_event_loop_policy`` is deprecated in newer Python
    releases (3.14+); confirm the target version before relying on it.
    """
    default_policy = None
    asyncio.set_event_loop_policy(default_policy)
12
13
class ToThreadTests(unittest.IsolatedAsyncioTestCase):
    """Behavioral tests for :func:`asyncio.to_thread`."""

    async def test_to_thread(self):
        """The worker's return value is delivered to the awaiting coroutine."""
        result = await asyncio.to_thread(sum, [40, 2])
        self.assertEqual(result, 42)

    async def test_to_thread_exception(self):
        """An exception raised in the worker thread propagates on ``await``."""
        def raise_runtime():
            raise RuntimeError("test")

        with self.assertRaisesRegex(RuntimeError, "test"):
            await asyncio.to_thread(raise_runtime)

    async def test_to_thread_once(self):
        """A single ``to_thread`` call invokes the callable exactly once."""
        func = mock.Mock()

        await asyncio.to_thread(func)
        func.assert_called_once()

    async def test_to_thread_concurrent(self):
        """Several ``to_thread`` calls gathered together each run once."""
        func = mock.Mock()

        # Create every coroutine up front so gather() awaits them together.
        futs = [asyncio.to_thread(func) for _ in range(10)]
        await asyncio.gather(*futs)

        self.assertEqual(func.call_count, 10)

    async def test_to_thread_args_kwargs(self):
        """Positional and keyword arguments are forwarded to the callable."""
        # Unlike run_in_executor(), to_thread() should directly accept kwargs.
        func = mock.Mock()

        await asyncio.to_thread(func, 'test', something=True)

        func.assert_called_once_with('test', something=True)

    async def test_to_thread_contextvars(self):
        """Context variables set by the caller are visible in the worker."""
        test_ctx = ContextVar('test_ctx')

        def get_ctx():
            return test_ctx.get()

        test_ctx.set('parrot')
        result = await asyncio.to_thread(get_ctx)

        self.assertEqual(result, 'parrot')
61
62
# Allow running this module directly as a script, outside a test runner.
if __name__ == "__main__":
    unittest.main()