python (3.12.0)
1 """Utilities shared by tests."""
2
3 import asyncio
4 import collections
5 import contextlib
6 import io
7 import logging
8 import os
9 import re
10 import selectors
11 import socket
12 import socketserver
13 import sys
14 import threading
15 import unittest
16 import weakref
17 import warnings
18 from unittest import mock
19
20 from http.server import HTTPServer
21 from wsgiref.simple_server import WSGIRequestHandler, WSGIServer
22
23 try:
24 import ssl
25 except ImportError: # pragma: no cover
26 ssl = None
27
28 from asyncio import base_events
29 from asyncio import events
30 from asyncio import format_helpers
31 from asyncio import futures
32 from asyncio import tasks
33 from asyncio.log import logger
34 from test import support
35 from test.support import socket_helper
36 from test.support import threading_helper
37
38
def data_file(filename):
    """Return the path of test data file *filename*.

    Looks first in ``support.TEST_HOME_DIR`` (when the test support module
    defines it) and then in the parent of this module's directory.  The
    first candidate that is an existing regular file wins.

    Raises FileNotFoundError(filename) if no candidate exists.
    """
    search_paths = []
    if hasattr(support, 'TEST_HOME_DIR'):
        search_paths.append(os.path.join(support.TEST_HOME_DIR, filename))
    search_paths.append(os.path.join(os.path.dirname(__file__), '..', filename))

    for path in search_paths:
        if os.path.isfile(path):
            return path
    raise FileNotFoundError(filename)
48
49
# TLS fixture files shared by the SSL tests.  Each path is resolved at
# import time through data_file(), so a missing file makes importing this
# module fail with FileNotFoundError.

# Certificate and private key loaded by simple_server_sslcontext() and
# SSLWSGIServerMixin.finish_request().
ONLYCERT = data_file('ssl_cert.pem')
ONLYKEY = data_file('ssl_key.pem')
# NOTE(review): by their names, presumably a certificate signed by
# SIGNING_CA and that CA's certificate -- confirm against the tests that
# use them.
SIGNED_CERTFILE = data_file('keycert3.pem')
SIGNING_CA = data_file('pycacert.pem')
# Decoded peer-certificate dictionary in the shape produced by
# SSLSocket.getpeercert().  NOTE(review): presumably the expected value for
# SIGNED_CERTFILE -- verify against its callers.  Must match the
# certificate file exactly, so do not edit by hand.
PEERCERT = {
    'OCSP': ('http://testca.pythontest.net/testca/ocsp/',),
    'caIssuers': ('http://testca.pythontest.net/testca/pycacert.cer',),
    'crlDistributionPoints': ('http://testca.pythontest.net/testca/revocation.crl',),
    'issuer': ((('countryName', 'XY'),),
            (('organizationName', 'Python Software Foundation CA'),),
            (('commonName', 'our-ca-server'),)),
    'notAfter': 'Oct 28 14:23:16 2037 GMT',
    'notBefore': 'Aug 29 14:23:16 2018 GMT',
    'serialNumber': 'CB2D80995A69525C',
    'subject': ((('countryName', 'XY'),),
             (('localityName', 'Castle Anthrax'),),
             (('organizationName', 'Python Software Foundation'),),
             (('commonName', 'localhost'),)),
    'subjectAltName': (('DNS', 'localhost'),),
    'version': 3
}
71
72
def simple_server_sslcontext():
    """Build a server-side TLS context that does not verify clients.

    The context presents the ONLYCERT/ONLYKEY pair, performs no hostname
    check and requests no client certificate.
    """
    ctx = ssl.SSLContext(ssl.PROTOCOL_TLS_SERVER)
    # Peer verification is switched off entirely (hostname check first,
    # since verify_mode cannot be CERT_NONE while it is enabled).
    ctx.check_hostname = False
    ctx.verify_mode = ssl.CERT_NONE
    ctx.load_cert_chain(ONLYCERT, ONLYKEY)
    return ctx
79
80
def simple_client_sslcontext(*, disable_verify=True):
    """Build a client-side TLS context for tests.

    Hostname checking is always off.  When *disable_verify* is true (the
    default), the server certificate is not verified either; otherwise the
    PROTOCOL_TLS_CLIENT default verification mode is kept.
    """
    ctx = ssl.SSLContext(ssl.PROTOCOL_TLS_CLIENT)
    ctx.check_hostname = False
    if not disable_verify:
        return ctx
    ctx.verify_mode = ssl.CERT_NONE
    return ctx
87
88
def dummy_ssl_context():
    """Return a non-verifying client TLS context, or None without ssl."""
    return (None if ssl is None
            else simple_client_sslcontext(disable_verify=True))
94
95
def run_briefly(loop):
    """Spin *loop* just long enough to complete a trivial task.

    Callbacks that are already scheduled get a chance to run.  The task is
    flagged so that no "Task was destroyed but it is pending" warning is
    logged if the loop is stopped early or a BaseException escapes, and the
    coroutine object is always closed afterwards.
    """
    async def once():
        pass

    coro = once()
    task = loop.create_task(coro)
    # Don't log a warning if the task is not done after run_until_complete().
    # It occurs if the loop is stopped or if a task raises a BaseException.
    task._log_destroy_pending = False
    try:
        loop.run_until_complete(task)
    finally:
        coro.close()
108
109
110 def run_until(loop, pred, timeout=support.SHORT_TIMEOUT):
111 delay = 0.001
112 for _ in support.busy_retry(timeout, error=False):
113 if pred():
114 break
115 loop.run_until_complete(tasks.sleep(delay))
116 delay = max(delay * 2, 1.0)
117 else:
118 raise futures.TimeoutError()
119
120
def run_once(loop):
    """Run a single iteration of *loop* (legacy test helper).

    Schedules ``loop.stop()`` with ``call_soon()`` and then calls
    ``run_forever()``, so the loop exits after processing the callbacks that
    are ready in that iteration (including any made ready by I/O polled in
    it).
    """
    stop_soon = loop.stop
    loop.call_soon(stop_soon)
    loop.run_forever()
130
131
class SilentWSGIRequestHandler(WSGIRequestHandler):
    """WSGI request handler that discards its diagnostic output."""

    def get_stderr(self):
        # A fresh, never-read buffer: anything written to the handler's
        # error stream is thrown away.
        return io.StringIO()

    def log_message(self, format, *args):
        # Suppress per-request log lines.
        return None
139
140
class SilentWSGIServer(WSGIServer):
    """WSGIServer that times out idle connections and hides request errors."""

    # Socket timeout applied to every accepted connection.
    request_timeout = support.LOOPBACK_TIMEOUT

    def get_request(self):
        conn, peer = super().get_request()
        conn.settimeout(self.request_timeout)
        return conn, peer

    def handle_error(self, request, client_address):
        # Deliberately silent: failed requests are not reported.
        return None
152
153
class SSLWSGIServerMixin:
    """Mixin that serves each accepted connection over server-side TLS.

    The ONLYCERT/ONLYKEY pair (located via data_file()) is used as the
    server certificate.
    """

    def finish_request(self, request, client_address):
        """Wrap *request* in TLS and hand it to the request handler.

        The TLS socket is always closed when the handler returns or raises.
        An OSError raised while handling the request (or while closing) is
        ignored.  Errors from the TLS handshake itself propagate to the
        server's error handling, as before.
        """
        context = ssl.SSLContext(ssl.PROTOCOL_TLS_SERVER)
        context.load_cert_chain(ONLYCERT, ONLYKEY)

        ssock = context.wrap_socket(request, server_side=True)
        try:
            # wrap_socket() detaches *request*, so the server's own
            # shutdown_request() cannot release this connection; close the
            # TLS socket ourselves however the handler exits.
            with ssock:
                self.RequestHandlerClass(ssock, client_address, self)
        except OSError:
            # maybe socket has been closed by peer
            pass
171
172
class SSLWSGIServer(SSLWSGIServerMixin, SilentWSGIServer):
    """SilentWSGIServer whose connections are wrapped in TLS by the mixin."""
175
176
def _run_test_server(*, address, use_ssl=False, server_cls, server_ssl_cls):
    """Generator that runs a throwaway WSGI server in a background thread.

    Intended to be delegated to with ``yield from`` inside a
    ``contextlib.contextmanager`` function.  It yields the running server,
    whose ``address`` attribute mirrors ``server_address``, and stops and
    closes the server when the generator is resumed or closed.

    The WSGI app answers every request with ``200 OK`` and a ``text/plain``
    content type.  For the path ``/loop`` it echoes the request body back
    in chunks of at most 64 KiB; for any other path the body is
    ``b'Test message'``.
    """

    def loop(environ):
        # A missing or empty Content-Length means there is no body.
        size = int(environ.get('CONTENT_LENGTH') or 0)
        while size > 0:
            data = environ['wsgi.input'].read(min(size, 0x10000))
            if not data:
                # The peer closed the connection before sending the whole
                # body.  Yielding b'' forever would wedge serve_forever()
                # and make shutdown() hang.
                break
            yield data
            size -= len(data)

    def app(environ, start_response):
        status = '200 OK'
        headers = [('Content-type', 'text/plain')]
        start_response(status, headers)
        if environ['PATH_INFO'] == '/loop':
            return loop(environ)
        else:
            return [b'Test message']

    # Run the test WSGI server in a separate thread in order not to
    # interfere with event handling in the main thread
    server_class = server_ssl_cls if use_ssl else server_cls
    httpd = server_class(address, SilentWSGIRequestHandler)
    try:
        httpd.set_app(app)
        httpd.address = httpd.server_address
        server_thread = threading.Thread(
            target=lambda: httpd.serve_forever(poll_interval=0.05))
        server_thread.start()
        try:
            yield httpd
        finally:
            # shutdown() waits for serve_forever() to return, so it is only
            # called once the serving thread has actually been started.
            httpd.shutdown()
            server_thread.join()
    finally:
        # Release the listening socket even if startup failed.
        httpd.server_close()
210
211
if hasattr(socket, 'AF_UNIX'):
    # The Unix-domain-socket server helpers below exist only on platforms
    # whose socket module provides AF_UNIX.

    class UnixHTTPServer(socketserver.UnixStreamServer, HTTPServer):
        """HTTPServer that listens on a Unix domain socket path."""

        def server_bind(self):
            # Deliberately calls UnixStreamServer's server_bind by name
            # instead of super().server_bind(), which would reach
            # HTTPServer.server_bind next in the MRO.  NOTE(review):
            # presumably because that method expects a (host, port)
            # address, not a filesystem path -- confirm.
            socketserver.UnixStreamServer.server_bind(self)
            # Fake TCP-style identity for code that reads these fields.
            self.server_name = '127.0.0.1'
            self.server_port = 80


    class UnixWSGIServer(UnixHTTPServer, WSGIServer):
        """WSGIServer variant for Unix domain sockets."""

        # Socket timeout applied to every accepted connection.
        request_timeout = support.LOOPBACK_TIMEOUT

        def server_bind(self):
            UnixHTTPServer.server_bind(self)
            # Prepare the base WSGI environ explicitly.  NOTE(review):
            # presumably needed because WSGIServer.server_bind, which would
            # normally do this, is bypassed by the call above.
            self.setup_environ()

        def get_request(self):
            request, client_addr = super().get_request()
            request.settimeout(self.request_timeout)
            # Code in the stdlib expects that get_request
            # will return a socket and a tuple (host, port).
            # However, this isn't true for UNIX sockets,
            # as the second return value will be a path;
            # hence we return some fake data sufficient
            # to get the tests going
            return request, ('127.0.0.1', '')


    class SilentUnixWSGIServer(UnixWSGIServer):
        """UnixWSGIServer that does not report request errors."""

        def handle_error(self, request, client_address):
            pass


    class UnixSSLWSGIServer(SSLWSGIServerMixin, SilentUnixWSGIServer):
        """SilentUnixWSGIServer whose connections are wrapped in TLS."""
        pass


    def gen_unix_socket_path():
        """Return a new Unix socket path from socket_helper.

        The path is not created here; see unix_socket_path() for a
        variant that removes it afterwards.
        """
        return socket_helper.create_unix_domain_name()


    @contextlib.contextmanager
    def unix_socket_path():
        """Yield a fresh Unix socket path and unlink it on exit.

        Removal is best effort: any OSError (e.g. the path was never
        created) is ignored.
        """
        path = gen_unix_socket_path()
        try:
            yield path
        finally:
            try:
                os.unlink(path)
            except OSError:
                pass


    @contextlib.contextmanager
    def run_test_unix_server(*, use_ssl=False):
        """Run the test WSGI server on a temporary Unix socket path.

        Yields the running server (TLS-wrapped when *use_ssl* is true); the
        server is stopped and the socket path removed on exit.
        """
        with unix_socket_path() as path:
            yield from _run_test_server(address=path, use_ssl=use_ssl,
                                        server_cls=SilentUnixWSGIServer,
                                        server_ssl_cls=UnixSSLWSGIServer)
274
275
@contextlib.contextmanager
def run_test_server(*, host='127.0.0.1', port=0, use_ssl=False):
    """Run the test WSGI server on a TCP address for a ``with`` block.

    Yields the running server (TLS-wrapped when *use_ssl* is true).  Port 0
    lets the OS pick a free port; read it from the yielded server's
    ``address``.
    """
    tcp_address = (host, port)
    yield from _run_test_server(
        address=tcp_address,
        use_ssl=use_ssl,
        server_cls=SilentWSGIServer,
        server_ssl_cls=SSLWSGIServer,
    )
281
282
def echo_datagrams(sock):
    """Echo every datagram received on *sock* back to its sender.

    Reads up to 4096 bytes per datagram.  A datagram whose payload is
    exactly b'STOP' is not echoed; it closes *sock* and ends the loop.
    """
    while True:
        payload, peer = sock.recvfrom(4096)
        if payload == b'STOP':
            break
        sock.sendto(payload, peer)
    sock.close()
291
292
@contextlib.contextmanager
def run_udp_echo_server(*, host='127.0.0.1', port=0):
    """Run echo_datagrams() on a fresh UDP socket in a background thread.

    Yields the address the socket is bound to.  On exit a b'STOP' datagram
    is sent to that address, which makes the echo thread close the socket
    and return, and the thread is then joined.
    """
    family, socktype, proto, _canonname, _sockaddr = socket.getaddrinfo(
        host, port, type=socket.SOCK_DGRAM)[0]
    server_sock = socket.socket(family, socktype, proto)
    server_sock.bind((host, port))
    echo_thread = threading.Thread(target=lambda: echo_datagrams(server_sock))
    echo_thread.start()
    try:
        yield server_sock.getsockname()
    finally:
        server_sock.sendto(b'STOP', server_sock.getsockname())
        echo_thread.join()
306
307
def make_test_protocol(base):
    """Instantiate a subclass of *base* whose public methods are mocks.

    Every attribute name reported by ``dir(base)`` except dunder names is
    replaced by a fresh ``MockCallback(return_value=None)``.  The new class
    derives from *base* followed by *base*'s own direct bases.
    """
    namespace = {
        attr: MockCallback(return_value=None)
        for attr in dir(base)
        # skip magic names
        if not (attr.startswith('__') and attr.endswith('__'))
    }
    bases = (base, *base.__bases__)
    return type('TestProtocol', bases, namespace)()
316
317
class TestSelector(selectors.BaseSelector):
    """In-memory selector stub that records registrations but never fires.

    Every key is created with ``fd=0``, and ``select()`` always reports no
    ready events.
    """

    def __init__(self):
        # fileobj -> SelectorKey; also what get_map() returns.
        self.keys = {}

    def register(self, fileobj, events, data=None):
        new_key = selectors.SelectorKey(
            fileobj=fileobj, fd=0, events=events, data=data)
        self.keys[fileobj] = new_key
        return new_key

    def unregister(self, fileobj):
        # KeyError propagates for unknown file objects.
        return self.keys.pop(fileobj)

    def select(self, timeout):
        return []

    def get_map(self):
        return self.keys
336
337
class TestLoop(base_events.BaseEventLoop):
    """Event loop for unit tests with a manually driven clock.

    ``time()`` returns an internal counter that only moves when the time
    generator says so.  Every deadline passed to ``call_at()`` is recorded;
    after each ``_run_once()`` iteration the recorded deadlines are sent,
    one at a time, into the generator, and each value the generator yields
    back is added to the loop's time.

    The generator should look like this:

        def gen():
            ...
            when = yield ...
            ... = yield time_advance

    The value received from ``yield`` is the absolute time of a scheduled
    handle; the value yielded is how far to move the loop's time forward.

    When a generator function is supplied, ``close()`` checks that it is
    exhausted.  NOTE: the default generator yields only once (consumed by
    ``__init__``), so with ``gen=None`` any timer scheduled through
    ``call_at()`` makes the next ``_run_once()`` raise StopIteration.
    Pass a generator whenever the test uses timers.

    I/O is simulated: readers and writers are stored in ``self.readers`` /
    ``self.writers`` instead of a real selector, and removals are counted
    in ``remove_reader_count`` / ``remove_writer_count``.
    """

    def __init__(self, gen=None):
        super().__init__()

        if gen is None:
            def gen():
                yield
            self._check_on_close = False
        else:
            self._check_on_close = True

        self._gen = gen()
        # Prime the generator so it is ready to receive deadlines via send().
        next(self._gen)
        self._time = 0
        self._clock_resolution = 1e-9
        # Deadlines passed to call_at() since the last _run_once().
        self._timers = []
        self._selector = TestSelector()

        # fd -> events.Handle registered via add_reader()/add_writer().
        self.readers = {}
        self.writers = {}
        self.reset_counters()

        # fd -> transport; consulted by _ensure_fd_no_transport().
        self._transports = weakref.WeakValueDictionary()

    def time(self):
        """Return the simulated loop time (starts at 0)."""
        return self._time

    def advance_time(self, advance):
        """Move test time forward."""
        if advance:
            self._time += advance

    def close(self):
        """Close the loop and, if a generator was given, check it finished."""
        super().close()
        if self._check_on_close:
            try:
                self._gen.send(0)
            except StopIteration:
                pass
            else:  # pragma: no cover
                raise AssertionError("Time generator is not finished")

    def _add_reader(self, fd, callback, *args):
        # Record the handle instead of registering with a real selector.
        self.readers[fd] = events.Handle(callback, args, self, None)

    def _remove_reader(self, fd):
        # Counted even when *fd* has no reader registered.
        self.remove_reader_count[fd] += 1
        if fd in self.readers:
            del self.readers[fd]
            return True
        else:
            return False

    def assert_reader(self, fd, callback, *args):
        """Raise AssertionError unless *fd* has this callback and args."""
        if fd not in self.readers:
            raise AssertionError(f'fd {fd} is not registered')
        handle = self.readers[fd]
        if handle._callback != callback:
            raise AssertionError(
                f'unexpected callback: {handle._callback} != {callback}')
        if handle._args != args:
            raise AssertionError(
                f'unexpected callback args: {handle._args} != {args}')

    def assert_no_reader(self, fd):
        """Raise AssertionError if *fd* has a reader registered."""
        if fd in self.readers:
            raise AssertionError(f'fd {fd} is registered')

    def _add_writer(self, fd, callback, *args):
        # Record the handle instead of registering with a real selector.
        self.writers[fd] = events.Handle(callback, args, self, None)

    def _remove_writer(self, fd):
        # Counted even when *fd* has no writer registered.
        self.remove_writer_count[fd] += 1
        if fd in self.writers:
            del self.writers[fd]
            return True
        else:
            return False

    def assert_writer(self, fd, callback, *args):
        """Raise AssertionError unless *fd* has this writer callback/args."""
        if fd not in self.writers:
            raise AssertionError(f'fd {fd} is not registered')
        handle = self.writers[fd]
        if handle._callback != callback:
            raise AssertionError(f'{handle._callback!r} != {callback!r}')
        if handle._args != args:
            raise AssertionError(f'{handle._args!r} != {args!r}')

    def _ensure_fd_no_transport(self, fd):
        """Reject *fd* if it is owned by a transport in self._transports.

        *fd* may be an int or an object with a fileno() method.  Raises
        ValueError for an invalid file object and RuntimeError if the
        descriptor is in use by a transport.
        """
        if not isinstance(fd, int):
            try:
                fd = int(fd.fileno())
            except (AttributeError, TypeError, ValueError):
                # This code matches selectors._fileobj_to_fd function.
                raise ValueError("Invalid file object: "
                                 "{!r}".format(fd)) from None
        try:
            transport = self._transports[fd]
        except KeyError:
            pass
        else:
            raise RuntimeError(
                'File descriptor {!r} is used by transport {!r}'.format(
                    fd, transport))

    def add_reader(self, fd, callback, *args):
        """Add a reader callback."""
        self._ensure_fd_no_transport(fd)
        return self._add_reader(fd, callback, *args)

    def remove_reader(self, fd):
        """Remove a reader callback."""
        self._ensure_fd_no_transport(fd)
        return self._remove_reader(fd)

    def add_writer(self, fd, callback, *args):
        """Add a writer callback."""
        self._ensure_fd_no_transport(fd)
        return self._add_writer(fd, callback, *args)

    def remove_writer(self, fd):
        """Remove a writer callback."""
        self._ensure_fd_no_transport(fd)
        return self._remove_writer(fd)

    def reset_counters(self):
        """Reset the per-fd remove_reader/remove_writer call counters."""
        self.remove_reader_count = collections.defaultdict(int)
        self.remove_writer_count = collections.defaultdict(int)

    def _run_once(self):
        super()._run_once()
        # Feed each deadline scheduled since the last iteration to the time
        # generator and advance the clock by the amount it answers with.
        for when in self._timers:
            advance = self._gen.send(when)
            self.advance_time(advance)
        self._timers = []

    def call_at(self, when, callback, *args, context=None):
        # Remember the deadline so _run_once() can report it to the
        # generator.
        self._timers.append(when)
        return super().call_at(when, callback, *args, context=context)

    def _process_events(self, event_list):
        # No real I/O: TestSelector.select() always returns [].
        return

    def _write_to_self(self):
        # No self-pipe to wake up in this test loop.
        pass
501
502
def MockCallback(**kwargs):
    """Return a callable Mock that exposes no attributes besides ``__call__``.

    Keyword arguments (e.g. ``return_value``) are forwarded to mock.Mock.
    Accessing any other attribute raises AttributeError.
    """
    allowed_attributes = ['__call__']
    return mock.Mock(spec=allowed_attributes, **kwargs)
505
506
class MockPattern(str):
    """A regex based str with a fuzzy __eq__.

    Use this helper with 'mock.assert_called_with', or anywhere
    where a regex comparison between strings is needed.

    For instance:
        mock_call.assert_called_with(MockPattern('spam.*ham'))

    The pattern is *searched* for anywhere in the other string with
    re.DOTALL, so '.' also matches newlines.  Comparing against a non-str
    returns NotImplemented (so ``==`` ends up False) instead of raising
    TypeError, and ``!=`` is always the negation of ``==``.  Because
    __eq__ is overridden, instances are unhashable.
    """

    def __eq__(self, other):
        if not isinstance(other, str):
            return NotImplemented
        return bool(re.search(str(self), other, re.S))

    def __ne__(self, other):
        # str defines its own __ne__, which compares the raw text; without
        # this override ``MockPattern('a.') != 'ab'`` would be True even
        # though ``==`` is also True.
        result = self.__eq__(other)
        if result is NotImplemented:
            return result
        return not result
518
519
class MockInstanceOf:
    """Compare equal to any object that is an instance of *type*.

    Useful as a placeholder argument in ``mock.assert_called_with()`` when
    only an argument's type matters.  *type* is passed straight to
    isinstance(), so a class or a tuple of classes is accepted.
    """

    def __init__(self, type):
        self._type = type

    def __eq__(self, other):
        return isinstance(other, self._type)

    def __repr__(self):
        # Shown in mock assertion failure messages instead of an opaque
        # object address.
        return f'{self.__class__.__name__}({self._type!r})'
526
527
def get_function_source(func):
    """Return the source location of *func*, or raise ValueError.

    Thin wrapper over ``format_helpers._get_function_source()`` that turns
    its ``None`` result (location unknown, e.g. for builtins) into a
    ValueError.  NOTE(review): in current CPython the location is a
    ``(filename, first_lineno)`` pair.
    """
    location = format_helpers._get_function_source(func)
    if location is not None:
        return location
    raise ValueError("unable to get the source of %r" % (func,))
533
534
class TestCase(unittest.TestCase):
    """Base class for asyncio tests.

    Snapshots running threads in setUp() and checks them in tearDown().
    Never leaves a global event loop installed, and closes loops
    registered through set_event_loop() during cleanup.
    """

    @staticmethod
    def close_loop(loop):
        """Shut down *loop*'s default executor, then close the loop.

        Afterwards, if the current event loop policy's child watcher is a
        ThreadedChildWatcher, wait for all of its threads to finish.
        """
        executor = loop._default_executor
        if executor is not None:
            if loop.is_closed():
                executor.shutdown(wait=True)
            else:
                loop.run_until_complete(loop.shutdown_default_executor())
        loop.close()

        policy = support.maybe_get_event_loop_policy()
        if policy is None:
            return
        try:
            # Silence the DeprecationWarning get_child_watcher() may emit.
            with warnings.catch_warnings():
                warnings.simplefilter('ignore', DeprecationWarning)
                watcher = policy.get_child_watcher()
        except NotImplementedError:
            # watcher is not implemented by EventLoopPolicy, e.g. Windows
            return
        if isinstance(watcher, asyncio.ThreadedChildWatcher):
            for thread in list(watcher._threads.values()):
                thread.join()

    def set_event_loop(self, loop, *, cleanup=True):
        """Adopt *loop* for this test without installing it globally.

        The global event loop is reset to None so that asyncio code under
        test has to receive the loop explicitly.  Unless *cleanup* is
        false, ``close_loop(loop)`` is registered as a test cleanup.
        """
        if loop is None:
            raise AssertionError('loop is None')
        events.set_event_loop(None)
        if not cleanup:
            return
        self.addCleanup(self.close_loop, loop)

    def new_test_loop(self, gen=None):
        """Create a TestLoop driven by *gen* and adopt it for this test."""
        test_loop = TestLoop(gen)
        self.set_event_loop(test_loop)
        return test_loop

    def setUp(self):
        self._thread_cleanup = threading_helper.threading_setup()

    def tearDown(self):
        events.set_event_loop(None)

        # Detect CPython bug #23353: ensure that yield/yield-from is not used
        # in an except block of a generator
        self.assertIsNone(sys.exception())

        self.doCleanups()
        threading_helper.threading_cleanup(*self._thread_cleanup)
        support.reap_children()
585
586
@contextlib.contextmanager
def disable_logger():
    """Silence the asyncio logger for the duration of a ``with`` block.

    The logger's level is raised above CRITICAL, so records logged through
    it are dropped (useful, for example, to hide debug-mode warnings).
    The previous level is restored on exit, even if the block raises.
    """
    saved_level = logger.level
    silenced_level = logging.CRITICAL + 1
    try:
        logger.setLevel(silenced_level)
        yield
    finally:
        logger.setLevel(saved_level)
599
600
def mock_nonblocking_socket(proto=socket.IPPROTO_TCP, type=socket.SOCK_STREAM,
                            family=socket.AF_INET):
    """Create a mock of a non-blocking socket.

    Returns a MagicMock specced on socket.socket with the given ``proto``,
    ``type`` and ``family`` attributes, whose ``gettimeout()`` returns 0.0
    (the timeout value of a non-blocking socket).
    """
    fake = mock.MagicMock(spec=socket.socket)
    fake.configure_mock(proto=proto, type=type, family=family)
    fake.gettimeout.return_value = 0.0
    return fake