1 """Utilities shared by tests."""
2
3 import asyncio
4 import collections
5 import contextlib
6 import io
7 import logging
8 import os
9 import re
10 import selectors
11 import socket
12 import socketserver
13 import sys
14 import tempfile
15 import threading
16 import time
17 import unittest
18 import weakref
19
20 from unittest import mock
21
22 from http.server import HTTPServer
23 from wsgiref.simple_server import WSGIRequestHandler, WSGIServer
24
25 try:
26 import ssl
27 except ImportError: # pragma: no cover
28 ssl = None
29
30 from asyncio import base_events
31 from asyncio import events
32 from asyncio import format_helpers
33 from asyncio import futures
34 from asyncio import tasks
35 from asyncio.log import logger
36 from test import support
37 from test.support import threading_helper
38
39
# Use the maximum known clock resolution (gh-75191, gh-110088): Windows
# GetTickCount64() has a resolution of 15.6 ms. Use 50 ms to tolerate rounding
# issues.
CLOCK_RES = 0.050  # seconds
44
45
def data_file(*filename):
    """Return the path of an existing test data file.

    *filename* is a sequence of path components.  The file is looked up
    first below ``support.TEST_HOME_DIR`` (when that attribute exists) and
    then below the parent directory of this module.

    Raises FileNotFoundError, carrying the joined relative path, when the
    file exists in neither location.
    """
    candidates = []
    if hasattr(support, 'TEST_HOME_DIR'):
        candidates.append(os.path.join(support.TEST_HOME_DIR, *filename))
    candidates.append(
        os.path.join(os.path.dirname(__file__), '..', *filename))

    for fullname in candidates:
        if os.path.isfile(fullname):
            return fullname
    # The components must be unpacked: os.path.join() rejects a tuple with
    # TypeError, which would mask the intended FileNotFoundError.
    raise FileNotFoundError(os.path.join(*filename))
55
56
# Paths of the certificate fixtures used by the SSL helpers below.  They are
# resolved through data_file() when this module is imported.
ONLYCERT = data_file('certdata', 'ssl_cert.pem')
ONLYKEY = data_file('certdata', 'ssl_key.pem')
SIGNED_CERTFILE = data_file('certdata', 'keycert3.pem')
SIGNING_CA = data_file('certdata', 'pycacert.pem')
# Expected certificate fields for a peer with commonName 'localhost'.
# NOTE(review): presumably the decoded form of SIGNED_CERTFILE as reported by
# SSLSocket.getpeercert() -- confirm against the tests that use it.
PEERCERT = {
    'OCSP': ('http://testca.pythontest.net/testca/ocsp/',),
    'caIssuers': ('http://testca.pythontest.net/testca/pycacert.cer',),
    'crlDistributionPoints': ('http://testca.pythontest.net/testca/revocation.crl',),
    'issuer': ((('countryName', 'XY'),),
            (('organizationName', 'Python Software Foundation CA'),),
            (('commonName', 'our-ca-server'),)),
    'notAfter': 'Oct 28 14:23:16 2037 GMT',
    'notBefore': 'Aug 29 14:23:16 2018 GMT',
    'serialNumber': 'CB2D80995A69525C',
    'subject': ((('countryName', 'XY'),),
             (('localityName', 'Castle Anthrax'),),
             (('organizationName', 'Python Software Foundation'),),
             (('commonName', 'localhost'),)),
    'subjectAltName': (('DNS', 'localhost'),),
    'version': 3
}
78
79
def simple_server_sslcontext():
    """Return a server-side SSLContext loaded with the test certificate.

    Hostname checking and peer certificate verification are both disabled.
    """
    ctx = ssl.SSLContext(ssl.PROTOCOL_TLS_SERVER)
    ctx.load_cert_chain(ONLYCERT, ONLYKEY)
    # check_hostname has to be cleared before verify_mode may be CERT_NONE.
    ctx.check_hostname = False
    ctx.verify_mode = ssl.CERT_NONE
    return ctx
86
87
def simple_client_sslcontext(*, disable_verify=True):
    """Return a client-side SSLContext that does not check hostnames.

    When *disable_verify* is true (the default) certificate verification is
    switched off as well; otherwise the context keeps its default
    verification mode.
    """
    ctx = ssl.SSLContext(ssl.PROTOCOL_TLS_CLIENT)
    ctx.check_hostname = False
    if not disable_verify:
        return ctx
    ctx.verify_mode = ssl.CERT_NONE
    return ctx
94
95
def dummy_ssl_context():
    """Return a non-verifying client SSLContext, or None without ssl."""
    if ssl is not None:
        return simple_client_sslcontext(disable_verify=True)
    return None
101
102
def run_briefly(loop):
    """Run *loop* just long enough for a freshly created no-op task to finish."""
    async def once():
        pass

    coro = once()
    task = loop.create_task(coro)
    # The task may be left pending when the loop is stopped or when another
    # task raises a BaseException; do not log a warning about it in that case.
    task._log_destroy_pending = False
    try:
        loop.run_until_complete(task)
    finally:
        coro.close()
115
116
def run_until(loop, pred, timeout=support.SHORT_TIMEOUT):
    """Run *loop* in 1 ms slices until ``pred()`` returns a true value.

    *timeout* is the maximum number of seconds to wait; pass None to wait
    without a limit.  Raises futures.TimeoutError when the deadline passes
    while ``pred()`` is still false.
    """
    # Only compute a deadline when there is a timeout: adding None to the
    # clock value would raise TypeError and make the no-timeout mode unusable.
    deadline = None if timeout is None else time.monotonic() + timeout
    while not pred():
        if deadline is not None and time.monotonic() >= deadline:
            raise futures.TimeoutError()
        loop.run_until_complete(tasks.sleep(0.001))
125
126
def run_once(loop):
    """Legacy API: make a single pass through the event loop.

    A call to ``loop.stop()`` is queued and the loop is then run until that
    stop takes effect.  The selector is polled once and every callback
    scheduled in response to I/O events is run.  This is the recommended
    pattern for test code.
    """
    stop = loop.stop
    loop.call_soon(stop)
    loop.run_forever()
136
137
class SilentWSGIRequestHandler(WSGIRequestHandler):
    """WSGI request handler that produces no log or error output."""

    def get_stderr(self):
        # Hand out a fresh in-memory buffer so error text is discarded.
        sink = io.StringIO()
        return sink

    def log_message(self, format, *args):
        """Ignore the log message."""
        return None
145
146
class SilentWSGIServer(WSGIServer):
    """WSGI server that times out idle connections and ignores errors."""

    # Timeout, in seconds, applied to every accepted connection.
    request_timeout = support.LOOPBACK_TIMEOUT

    def get_request(self):
        conn, peer = super().get_request()
        conn.settimeout(self.request_timeout)
        return conn, peer

    def handle_error(self, request, client_address):
        """Do nothing when handling a request fails."""
        return None
158
159
class SSLWSGIServerMixin:
    """Server mixin that handles every accepted connection over TLS."""

    def finish_request(self, request, client_address):
        """Wrap *request* in a server-side TLS socket and run the handler.

        OSError raised while handling the request or while closing the TLS
        socket is ignored; any other exception propagates to the caller.
        """
        # Build a server context from the test certificate and key.
        context = ssl.SSLContext(ssl.PROTOCOL_TLS_SERVER)
        context.load_cert_chain(ONLYCERT, ONLYKEY)

        ssock = context.wrap_socket(request, server_side=True)
        try:
            self.RequestHandlerClass(ssock, client_address, self)
        except OSError:
            # maybe socket has been closed by peer
            pass
        finally:
            # Always release the TLS socket, including when the handler
            # failed, so that it is not leaked.
            try:
                ssock.close()
            except OSError:
                pass
177
178
class SSLWSGIServer(SSLWSGIServerMixin, SilentWSGIServer):
    """Silent WSGI server that serves its requests over TLS."""
181
182
def _run_test_server(*, address, use_ssl=False, server_cls, server_ssl_cls):
    """Generator running a small WSGI server in a background thread.

    The server is an instance of *server_ssl_cls* when *use_ssl* is true
    and of *server_cls* otherwise, bound to *address*.  It is yielded once
    and is shut down, closed and its thread joined when the generator is
    resumed or closed.

    Requests to ``/loop`` get their request body echoed back; every other
    path is answered with ``b'Test message'``.
    """

    def loop(environ):
        # Echo the request body back in chunks of at most 64 KiB.
        size = int(environ['CONTENT_LENGTH'])
        while size:
            data = environ['wsgi.input'].read(min(size, 0x10000))
            if not data:
                # The body ended before CONTENT_LENGTH bytes arrived.
                # Stop here: otherwise this loop would spin forever and
                # the server thread could never be shut down.
                break
            yield data
            size -= len(data)

    def app(environ, start_response):
        status = '200 OK'
        headers = [('Content-type', 'text/plain')]
        start_response(status, headers)
        if environ['PATH_INFO'] == '/loop':
            return loop(environ)
        else:
            return [b'Test message']

    # Run the test WSGI server in a separate thread in order not to
    # interfere with event handling in the main thread
    server_class = server_ssl_cls if use_ssl else server_cls
    httpd = server_class(address, SilentWSGIRequestHandler)
    httpd.set_app(app)
    httpd.address = httpd.server_address
    server_thread = threading.Thread(
        target=lambda: httpd.serve_forever(poll_interval=0.05))
    server_thread.start()
    try:
        yield httpd
    finally:
        httpd.shutdown()
        httpd.server_close()
        server_thread.join()
216
217
if hasattr(socket, 'AF_UNIX'):

    class UnixHTTPServer(socketserver.UnixStreamServer, HTTPServer):
        """HTTP server listening on a UNIX domain socket."""

        def server_bind(self):
            socketserver.UnixStreamServer.server_bind(self)
            # Fixed placeholder values for the server name and port.
            self.server_name = '127.0.0.1'
            self.server_port = 80


    class UnixWSGIServer(UnixHTTPServer, WSGIServer):
        """WSGI server listening on a UNIX domain socket."""

        # Timeout, in seconds, applied to every accepted connection.
        request_timeout = support.LOOPBACK_TIMEOUT

        def server_bind(self):
            UnixHTTPServer.server_bind(self)
            self.setup_environ()

        def get_request(self):
            conn, _peer = super().get_request()
            conn.settimeout(self.request_timeout)
            # Code in the stdlib expects get_request() to return a socket
            # and a (host, port) tuple.  For UNIX sockets the second value
            # would be a path instead, so return some fake data that is
            # sufficient to get the tests going.
            fake_client = ('127.0.0.1', '')
            return conn, fake_client


    class SilentUnixWSGIServer(UnixWSGIServer):
        """UNIX-socket WSGI server that ignores request errors."""

        def handle_error(self, request, client_address):
            """Do nothing when handling a request fails."""
            return None


    class UnixSSLWSGIServer(SSLWSGIServerMixin, SilentUnixWSGIServer):
        """Silent UNIX-socket WSGI server that serves requests over TLS."""


    def gen_unix_socket_path():
        """Return the name of a temporary file that has just been closed."""
        with tempfile.NamedTemporaryFile() as tmp:
            path = tmp.name
        return path


    @contextlib.contextmanager
    def unix_socket_path():
        """Context manager yielding a socket path, unlinked on exit."""
        path = gen_unix_socket_path()
        try:
            yield path
        finally:
            # Failure to remove the path is ignored.
            with contextlib.suppress(OSError):
                os.unlink(path)


    @contextlib.contextmanager
    def run_test_unix_server(*, use_ssl=False):
        """Context manager running the test WSGI server on a UNIX socket."""
        with unix_socket_path() as sock_path:
            yield from _run_test_server(address=sock_path, use_ssl=use_ssl,
                                        server_cls=SilentUnixWSGIServer,
                                        server_ssl_cls=UnixSSLWSGIServer)
281
282
@contextlib.contextmanager
def run_test_server(*, host='127.0.0.1', port=0, use_ssl=False):
    """Context manager running the test WSGI server on a TCP address.

    The server is bound to ``(host, port)`` and is of the TLS flavour when
    *use_ssl* is true.
    """
    tcp_address = (host, port)
    yield from _run_test_server(address=tcp_address, use_ssl=use_ssl,
                                server_cls=SilentWSGIServer,
                                server_ssl_cls=SSLWSGIServer)
288
289
def echo_datagrams(sock):
    """Echo each datagram received on *sock* back to its sender.

    Receiving the payload ``b'STOP'`` ends the loop and closes *sock*.
    """
    while True:
        payload, peer = sock.recvfrom(4096)
        if payload == b'STOP':
            break
        sock.sendto(payload, peer)
    sock.close()
298
299
@contextlib.contextmanager
def run_udp_echo_server(*, host='127.0.0.1', port=0):
    """Context manager running a UDP echo server in a background thread.

    Yields the address the server socket is bound to.  On exit a ``b'STOP'``
    datagram is sent to the server and its thread is joined.
    """
    addr_info = socket.getaddrinfo(host, port, type=socket.SOCK_DGRAM)
    family, socktype, proto, _, _ = addr_info[0]
    sock = socket.socket(family, socktype, proto)
    try:
        sock.bind((host, port))
        # Remember the bound address now: the echo thread owns the socket
        # from here on and closes it once it has been told to stop.
        sockname = sock.getsockname()
    except BaseException:
        sock.close()
        raise
    thread = threading.Thread(target=echo_datagrams, args=(sock,))
    thread.start()
    try:
        yield sockname
    finally:
        # Send the stop message through a separate socket so that the
        # server socket is never used from two threads at once.
        with socket.socket(family, socktype, proto) as stopper:
            stopper.sendto(b'STOP', sockname)
        thread.join()
313
314
def make_test_protocol(base):
    """Return an instance of *base* whose attributes are mock callbacks.

    Every name listed by ``dir(base)``, except double-underscore names, is
    replaced by a MockCallback that returns None.
    """
    stubs = {
        name: MockCallback(return_value=None)
        for name in dir(base)
        # skip magic names
        if not (name.startswith('__') and name.endswith('__'))
    }
    protocol_cls = type('TestProtocol', (base,) + base.__bases__, stubs)
    return protocol_cls()
323
324
class TestSelector(selectors.BaseSelector):
    """Selector that records registrations and never reports an event."""

    def __init__(self):
        # Maps each registered file object to its SelectorKey.
        self.keys = {}

    def register(self, fileobj, events, data=None):
        # The fd field of the key is always 0.
        key = self.keys[fileobj] = selectors.SelectorKey(
            fileobj, 0, events, data)
        return key

    def unregister(self, fileobj):
        # Raises KeyError for a file object that is not registered.
        return self.keys.pop(fileobj)

    def select(self, timeout):
        # Nothing is ever ready.
        return []

    def get_map(self):
        return self.keys
343
344
class TestLoop(base_events.BaseEventLoop):
    """Event loop for unit tests that keeps its own clock.

    The loop's time only moves when a test-supplied generator says so.
    After each loop iteration, once all ready handlers have run, the
    generator passed to __init__ is resumed once for every handler that was
    scheduled through call_at() since the previous iteration.

    The generator should look like this:

        def gen():
            ...
            when = yield ...
            ... = yield time_advance

    The value returned by ``yield`` is the absolute time of the next
    scheduled handler.  The value passed to ``yield`` is the amount by which
    the loop's time is moved forward.
    """

    def __init__(self, gen=None):
        super().__init__()

        # Only a caller-supplied generator is checked for exhaustion when
        # the loop is closed.
        self._check_on_close = gen is not None
        if gen is None:
            def gen():
                yield

        self._gen = gen()
        # Advance to the first yield so that send() can be used later.
        next(self._gen)
        self._time = 0
        self._clock_resolution = 1e-9
        self._timers = []
        self._selector = TestSelector()

        self.readers = {}
        self.writers = {}
        self.reset_counters()

        self._transports = weakref.WeakValueDictionary()

    def time(self):
        return self._time

    def advance_time(self, advance):
        """Move test time forward."""
        if advance:
            self._time += advance

    def close(self):
        super().close()
        if not self._check_on_close:
            return
        try:
            self._gen.send(0)
        except StopIteration:
            return
        raise AssertionError(  # pragma: no cover
            "Time generator is not finished")

    def _add_reader(self, fd, callback, *args):
        self.readers[fd] = events.Handle(callback, args, self, None)

    def _remove_reader(self, fd):
        # Count every removal attempt, successful or not.
        self.remove_reader_count[fd] += 1
        if fd not in self.readers:
            return False
        del self.readers[fd]
        return True

    def assert_reader(self, fd, callback, *args):
        """Check that *fd* has a reader with this callback and arguments."""
        if fd not in self.readers:
            raise AssertionError(f'fd {fd} is not registered')
        handle = self.readers[fd]
        if handle._callback != callback:
            raise AssertionError(
                f'unexpected callback: {handle._callback} != {callback}')
        if handle._args != args:
            raise AssertionError(
                f'unexpected callback args: {handle._args} != {args}')

    def assert_no_reader(self, fd):
        """Check that no reader is registered for *fd*."""
        if fd in self.readers:
            raise AssertionError(f'fd {fd} is registered')

    def _add_writer(self, fd, callback, *args):
        self.writers[fd] = events.Handle(callback, args, self, None)

    def _remove_writer(self, fd):
        # Count every removal attempt, successful or not.
        self.remove_writer_count[fd] += 1
        if fd not in self.writers:
            return False
        del self.writers[fd]
        return True

    def assert_writer(self, fd, callback, *args):
        """Check that *fd* has a writer with this callback and arguments."""
        if fd not in self.writers:
            raise AssertionError(f'fd {fd} is not registered')
        handle = self.writers[fd]
        if handle._callback != callback:
            raise AssertionError(f'{handle._callback!r} != {callback!r}')
        if handle._args != args:
            raise AssertionError(f'{handle._args!r} != {args!r}')

    def _ensure_fd_no_transport(self, fd):
        """Raise RuntimeError if a transport is recorded for *fd*."""
        if not isinstance(fd, int):
            try:
                fd = int(fd.fileno())
            except (AttributeError, TypeError, ValueError):
                # This code matches selectors._fileobj_to_fd function.
                raise ValueError("Invalid file object: "
                                 "{!r}".format(fd)) from None
        transport = self._transports.get(fd)
        if transport is None:
            return
        raise RuntimeError(
            'File descriptor {!r} is used by transport {!r}'.format(
                fd, transport))

    def add_reader(self, fd, callback, *args):
        """Add a reader callback."""
        self._ensure_fd_no_transport(fd)
        return self._add_reader(fd, callback, *args)

    def remove_reader(self, fd):
        """Remove a reader callback."""
        self._ensure_fd_no_transport(fd)
        return self._remove_reader(fd)

    def add_writer(self, fd, callback, *args):
        """Add a writer callback."""
        self._ensure_fd_no_transport(fd)
        return self._add_writer(fd, callback, *args)

    def remove_writer(self, fd):
        """Remove a writer callback."""
        self._ensure_fd_no_transport(fd)
        return self._remove_writer(fd)

    def reset_counters(self):
        """Reset the per-fd counts of reader and writer removals."""
        self.remove_reader_count = collections.defaultdict(int)
        self.remove_writer_count = collections.defaultdict(int)

    def _run_once(self):
        super()._run_once()
        # Report each newly scheduled deadline to the time generator and
        # move the clock by whatever it answers.
        for deadline in self._timers:
            step = self._gen.send(deadline)
            self.advance_time(step)
        self._timers = []

    def call_at(self, when, callback, *args, context=None):
        # Remember the deadline for the next _run_once() pass.
        self._timers.append(when)
        return super().call_at(when, callback, *args, context=context)

    def _process_events(self, event_list):
        return

    def _write_to_self(self):
        pass
508
509
def MockCallback(**kwargs):
    """Return a Mock whose spec only allows it to be called.

    Keyword arguments are passed through to ``mock.Mock``.
    """
    callable_only = ['__call__']
    return mock.Mock(spec=callable_only, **kwargs)
512
513
class MockPattern(str):
    """A str holding a regex, with a fuzzy __eq__.

    An instance compares equal to any string in which the pattern is found
    (searched with re.S).  Use it with 'mock.assert_called_with', or
    anywhere a regex comparison between strings is needed.

    For instance:
        mock_call.assert_called_with(MockPattern('spam.*ham'))
    """
    def __eq__(self, other):
        found = re.search(str(self), other, re.S)
        return found is not None
525
526
class MockInstanceOf:
    """Object that compares equal to any instance of the given type."""

    def __init__(self, type):
        self._type = type

    def __eq__(self, other):
        expected = self._type
        return isinstance(other, expected)
533
534
def get_function_source(func):
    """Return what format_helpers._get_function_source() reports for *func*.

    Raises ValueError when no source location is available.
    """
    location = format_helpers._get_function_source(func)
    if location is not None:
        return location
    raise ValueError("unable to get the source of %r" % (func,))
540
541
class TestCase(unittest.TestCase):
    """Base class for asyncio tests that own an event loop."""

    @staticmethod
    def close_loop(loop):
        """Shut down the default executor of *loop* and close the loop.

        Afterwards, if the event loop policy uses a ThreadedChildWatcher,
        wait for its threads and raise RuntimeError if one is still alive.
        """
        if loop._default_executor is not None:
            if loop.is_closed():
                loop._default_executor.shutdown(wait=True)
            else:
                loop.run_until_complete(loop.shutdown_default_executor())
        loop.close()

        policy = support.maybe_get_event_loop_policy()
        if policy is None:
            return
        try:
            watcher = policy.get_child_watcher()
        except NotImplementedError:
            # watcher is not implemented by EventLoopPolicy, e.g. Windows
            return
        if not isinstance(watcher, asyncio.ThreadedChildWatcher):
            return
        # Wait for subprocess to finish, but not forever
        for thread in list(watcher._threads.values()):
            thread.join(timeout=support.SHORT_TIMEOUT)
            if thread.is_alive():
                raise RuntimeError(f"thread {thread} still alive: "
                                   "subprocess still running")

    def set_event_loop(self, loop, *, cleanup=True):
        """Clear the current event loop and, by default, register *loop*
        to be closed when the test is cleaned up.
        """
        if loop is None:
            raise AssertionError('loop is None')
        # ensure that the event loop is passed explicitly in asyncio
        events.set_event_loop(None)
        if not cleanup:
            return
        self.addCleanup(self.close_loop, loop)

    def new_test_loop(self, gen=None):
        """Create a TestLoop driven by *gen* and register it for cleanup."""
        test_loop = TestLoop(gen)
        self.set_event_loop(test_loop)
        return test_loop

    def setUp(self):
        self._thread_cleanup = threading_helper.threading_setup()

    def tearDown(self):
        events.set_event_loop(None)

        # Detect CPython bug #23353: ensure that yield/yield-from is not used
        # in an except block of a generator
        no_exception = (None, None, None)
        self.assertEqual(sys.exc_info(), no_exception)

        self.doCleanups()
        threading_helper.threading_cleanup(*self._thread_cleanup)
        support.reap_children()
595
596
@contextlib.contextmanager
def disable_logger():
    """Context manager that silences the asyncio logger.

    The logger level is raised above CRITICAL for the duration of the block
    and restored afterwards.  For example, it can be used to ignore warnings
    in debug mode.
    """
    saved_level = logger.level
    silenced = logging.CRITICAL+1
    try:
        logger.setLevel(silenced)
        yield
    finally:
        logger.setLevel(saved_level)
609
610
def mock_nonblocking_socket(proto=socket.IPPROTO_TCP, type=socket.SOCK_STREAM,
                            family=socket.AF_INET):
    """Create a mock of a non-blocking socket."""
    fake = mock.MagicMock(socket.socket)
    fake.family = family
    fake.type = type
    fake.proto = proto
    # A timeout of 0.0 is what a non-blocking socket reports.
    fake.gettimeout.return_value = 0.0
    return fake
620
621
async def await_without_task(coro):
    """Drive *coro* to completion from a plain loop callback.

    The coroutine is iterated inside a ``call_soon()`` callback rather than
    by a task.  Any exception it raises, BaseException included, is
    captured and re-raised here.
    """
    captured = None

    def drive():
        nonlocal captured
        try:
            for _ in coro.__await__():
                pass
        except BaseException as err:
            captured = err

    asyncio.get_running_loop().call_soon(drive)
    # Yield to the loop once so that the callback gets to run.
    await asyncio.sleep(0)
    if captured is not None:
        raise captured