(root)/
Python-3.12.0/
Lib/
test/
test_asyncio/
utils.py
       1  """Utilities shared by tests."""
       2  
       3  import asyncio
       4  import collections
       5  import contextlib
       6  import io
       7  import logging
       8  import os
       9  import re
      10  import selectors
      11  import socket
      12  import socketserver
      13  import sys
      14  import threading
      15  import unittest
      16  import weakref
      17  import warnings
      18  from unittest import mock
      19  
      20  from http.server import HTTPServer
      21  from wsgiref.simple_server import WSGIRequestHandler, WSGIServer
      22  
      23  try:
      24      import ssl
      25  except ImportError:  # pragma: no cover
      26      ssl = None
      27  
      28  from asyncio import base_events
      29  from asyncio import events
      30  from asyncio import format_helpers
      31  from asyncio import futures
      32  from asyncio import tasks
      33  from asyncio.log import logger
      34  from test import support
      35  from test.support import socket_helper
      36  from test.support import threading_helper
      37  
      38  
      39  def data_file(filename):
      40      if hasattr(support, 'TEST_HOME_DIR'):
      41          fullname = os.path.join(support.TEST_HOME_DIR, filename)
      42          if os.path.isfile(fullname):
      43              return fullname
      44      fullname = os.path.join(os.path.dirname(__file__), '..', filename)
      45      if os.path.isfile(fullname):
      46          return fullname
      47      raise FileNotFoundError(filename)
      48  
      49  
      50  ONLYCERT = data_file('ssl_cert.pem')
      51  ONLYKEY = data_file('ssl_key.pem')
      52  SIGNED_CERTFILE = data_file('keycert3.pem')
      53  SIGNING_CA = data_file('pycacert.pem')
      54  PEERCERT = {
      55      'OCSP': ('http://testca.pythontest.net/testca/ocsp/',),
      56      'caIssuers': ('http://testca.pythontest.net/testca/pycacert.cer',),
      57      'crlDistributionPoints': ('http://testca.pythontest.net/testca/revocation.crl',),
      58      'issuer': ((('countryName', 'XY'),),
      59              (('organizationName', 'Python Software Foundation CA'),),
      60              (('commonName', 'our-ca-server'),)),
      61      'notAfter': 'Oct 28 14:23:16 2037 GMT',
      62      'notBefore': 'Aug 29 14:23:16 2018 GMT',
      63      'serialNumber': 'CB2D80995A69525C',
      64      'subject': ((('countryName', 'XY'),),
      65               (('localityName', 'Castle Anthrax'),),
      66               (('organizationName', 'Python Software Foundation'),),
      67               (('commonName', 'localhost'),)),
      68      'subjectAltName': (('DNS', 'localhost'),),
      69      'version': 3
      70  }
      71  
      72  
      73  def simple_server_sslcontext():
      74      server_context = ssl.SSLContext(ssl.PROTOCOL_TLS_SERVER)
      75      server_context.load_cert_chain(ONLYCERT, ONLYKEY)
      76      server_context.check_hostname = False
      77      server_context.verify_mode = ssl.CERT_NONE
      78      return server_context
      79  
      80  
      81  def simple_client_sslcontext(*, disable_verify=True):
      82      client_context = ssl.SSLContext(ssl.PROTOCOL_TLS_CLIENT)
      83      client_context.check_hostname = False
      84      if disable_verify:
      85          client_context.verify_mode = ssl.CERT_NONE
      86      return client_context
      87  
      88  
      89  def dummy_ssl_context():
      90      if ssl is None:
      91          return None
      92      else:
      93          return simple_client_sslcontext(disable_verify=True)
      94  
      95  
      96  def run_briefly(loop):
      97      async def once():
      98          pass
      99      gen = once()
     100      t = loop.create_task(gen)
     101      # Don't log a warning if the task is not done after run_until_complete().
     102      # It occurs if the loop is stopped or if a task raises a BaseException.
     103      t._log_destroy_pending = False
     104      try:
     105          loop.run_until_complete(t)
     106      finally:
     107          gen.close()
     108  
     109  
     110  def run_until(loop, pred, timeout=support.SHORT_TIMEOUT):
     111      delay = 0.001
     112      for _ in support.busy_retry(timeout, error=False):
     113          if pred():
     114              break
     115          loop.run_until_complete(tasks.sleep(delay))
     116          delay = max(delay * 2, 1.0)
     117      else:
     118          raise futures.TimeoutError()
     119  
     120  
     121  def run_once(loop):
     122      """Legacy API to run once through the event loop.
     123  
     124      This is the recommended pattern for test code.  It will poll the
     125      selector once and run all callbacks scheduled in response to I/O
     126      events.
     127      """
     128      loop.call_soon(loop.stop)
     129      loop.run_forever()
     130  
     131  
     132  class ESC[4;38;5;81mSilentWSGIRequestHandler(ESC[4;38;5;149mWSGIRequestHandler):
     133  
     134      def get_stderr(self):
     135          return io.StringIO()
     136  
     137      def log_message(self, format, *args):
     138          pass
     139  
     140  
     141  class ESC[4;38;5;81mSilentWSGIServer(ESC[4;38;5;149mWSGIServer):
     142  
     143      request_timeout = support.LOOPBACK_TIMEOUT
     144  
     145      def get_request(self):
     146          request, client_addr = super().get_request()
     147          request.settimeout(self.request_timeout)
     148          return request, client_addr
     149  
     150      def handle_error(self, request, client_address):
     151          pass
     152  
     153  
     154  class ESC[4;38;5;81mSSLWSGIServerMixin:
     155  
     156      def finish_request(self, request, client_address):
     157          # The relative location of our test directory (which
     158          # contains the ssl key and certificate files) differs
     159          # between the stdlib and stand-alone asyncio.
     160          # Prefer our own if we can find it.
     161          context = ssl.SSLContext(ssl.PROTOCOL_TLS_SERVER)
     162          context.load_cert_chain(ONLYCERT, ONLYKEY)
     163  
     164          ssock = context.wrap_socket(request, server_side=True)
     165          try:
     166              self.RequestHandlerClass(ssock, client_address, self)
     167              ssock.close()
     168          except OSError:
     169              # maybe socket has been closed by peer
     170              pass
     171  
     172  
     173  class ESC[4;38;5;81mSSLWSGIServer(ESC[4;38;5;149mSSLWSGIServerMixin, ESC[4;38;5;149mSilentWSGIServer):
     174      pass
     175  
     176  
     177  def _run_test_server(*, address, use_ssl=False, server_cls, server_ssl_cls):
     178  
     179      def loop(environ):
     180          size = int(environ['CONTENT_LENGTH'])
     181          while size:
     182              data = environ['wsgi.input'].read(min(size, 0x10000))
     183              yield data
     184              size -= len(data)
     185  
     186      def app(environ, start_response):
     187          status = '200 OK'
     188          headers = [('Content-type', 'text/plain')]
     189          start_response(status, headers)
     190          if environ['PATH_INFO'] == '/loop':
     191              return loop(environ)
     192          else:
     193              return [b'Test message']
     194  
     195      # Run the test WSGI server in a separate thread in order not to
     196      # interfere with event handling in the main thread
     197      server_class = server_ssl_cls if use_ssl else server_cls
     198      httpd = server_class(address, SilentWSGIRequestHandler)
     199      httpd.set_app(app)
     200      httpd.address = httpd.server_address
     201      server_thread = threading.Thread(
     202          target=lambda: httpd.serve_forever(poll_interval=0.05))
     203      server_thread.start()
     204      try:
     205          yield httpd
     206      finally:
     207          httpd.shutdown()
     208          httpd.server_close()
     209          server_thread.join()
     210  
     211  
     212  if hasattr(socket, 'AF_UNIX'):
     213  
     214      class ESC[4;38;5;81mUnixHTTPServer(ESC[4;38;5;149msocketserverESC[4;38;5;149m.ESC[4;38;5;149mUnixStreamServer, ESC[4;38;5;149mHTTPServer):
     215  
     216          def server_bind(self):
     217              socketserver.UnixStreamServer.server_bind(self)
     218              self.server_name = '127.0.0.1'
     219              self.server_port = 80
     220  
     221  
     222      class ESC[4;38;5;81mUnixWSGIServer(ESC[4;38;5;149mUnixHTTPServer, ESC[4;38;5;149mWSGIServer):
     223  
     224          request_timeout = support.LOOPBACK_TIMEOUT
     225  
     226          def server_bind(self):
     227              UnixHTTPServer.server_bind(self)
     228              self.setup_environ()
     229  
     230          def get_request(self):
     231              request, client_addr = super().get_request()
     232              request.settimeout(self.request_timeout)
     233              # Code in the stdlib expects that get_request
     234              # will return a socket and a tuple (host, port).
     235              # However, this isn't true for UNIX sockets,
     236              # as the second return value will be a path;
     237              # hence we return some fake data sufficient
     238              # to get the tests going
     239              return request, ('127.0.0.1', '')
     240  
     241  
     242      class ESC[4;38;5;81mSilentUnixWSGIServer(ESC[4;38;5;149mUnixWSGIServer):
     243  
     244          def handle_error(self, request, client_address):
     245              pass
     246  
     247  
     248      class ESC[4;38;5;81mUnixSSLWSGIServer(ESC[4;38;5;149mSSLWSGIServerMixin, ESC[4;38;5;149mSilentUnixWSGIServer):
     249          pass
     250  
     251  
     252      def gen_unix_socket_path():
     253          return socket_helper.create_unix_domain_name()
     254  
     255  
     256      @contextlib.contextmanager
     257      def unix_socket_path():
     258          path = gen_unix_socket_path()
     259          try:
     260              yield path
     261          finally:
     262              try:
     263                  os.unlink(path)
     264              except OSError:
     265                  pass
     266  
     267  
     268      @contextlib.contextmanager
     269      def run_test_unix_server(*, use_ssl=False):
     270          with unix_socket_path() as path:
     271              yield from _run_test_server(address=path, use_ssl=use_ssl,
     272                                          server_cls=SilentUnixWSGIServer,
     273                                          server_ssl_cls=UnixSSLWSGIServer)
     274  
     275  
     276  @contextlib.contextmanager
     277  def run_test_server(*, host='127.0.0.1', port=0, use_ssl=False):
     278      yield from _run_test_server(address=(host, port), use_ssl=use_ssl,
     279                                  server_cls=SilentWSGIServer,
     280                                  server_ssl_cls=SSLWSGIServer)
     281  
     282  
     283  def echo_datagrams(sock):
     284      while True:
     285          data, addr = sock.recvfrom(4096)
     286          if data == b'STOP':
     287              sock.close()
     288              break
     289          else:
     290              sock.sendto(data, addr)
     291  
     292  
     293  @contextlib.contextmanager
     294  def run_udp_echo_server(*, host='127.0.0.1', port=0):
     295      addr_info = socket.getaddrinfo(host, port, type=socket.SOCK_DGRAM)
     296      family, type, proto, _, sockaddr = addr_info[0]
     297      sock = socket.socket(family, type, proto)
     298      sock.bind((host, port))
     299      thread = threading.Thread(target=lambda: echo_datagrams(sock))
     300      thread.start()
     301      try:
     302          yield sock.getsockname()
     303      finally:
     304          sock.sendto(b'STOP', sock.getsockname())
     305          thread.join()
     306  
     307  
     308  def make_test_protocol(base):
     309      dct = {}
     310      for name in dir(base):
     311          if name.startswith('__') and name.endswith('__'):
     312              # skip magic names
     313              continue
     314          dct[name] = MockCallback(return_value=None)
     315      return type('TestProtocol', (base,) + base.__bases__, dct)()
     316  
     317  
     318  class ESC[4;38;5;81mTestSelector(ESC[4;38;5;149mselectorsESC[4;38;5;149m.ESC[4;38;5;149mBaseSelector):
     319  
     320      def __init__(self):
     321          self.keys = {}
     322  
     323      def register(self, fileobj, events, data=None):
     324          key = selectors.SelectorKey(fileobj, 0, events, data)
     325          self.keys[fileobj] = key
     326          return key
     327  
     328      def unregister(self, fileobj):
     329          return self.keys.pop(fileobj)
     330  
     331      def select(self, timeout):
     332          return []
     333  
     334      def get_map(self):
     335          return self.keys
     336  
     337  
class TestLoop(base_events.BaseEventLoop):
    """Event loop for unit tests driven by a virtual clock.

    ``time()`` returns an internal counter (starting at 0) that only moves
    forward through ``advance_time()``.

    Every ``call_at()`` records the requested deadline. After each
    iteration of the base loop (``_run_once``), the recorded deadlines are
    sent one by one into the generator passed to ``__init__``. Each value
    the generator yields back is added to the clock.

    The generator should look like this:

        def gen():
            ...
            when = yield ...
            ... = yield time_advance

    The value received from ``yield`` is the absolute time of a scheduled
    handle; the value passed to ``yield`` is how far to move the loop's
    time forward. When a generator is supplied, ``close()`` asserts that it
    has run to completion.
    """

    def __init__(self, gen=None):
        super().__init__()

        if gen is None:
            # NOTE(review): this default generator yields only once, so the
            # first send() from _run_once() (i.e. any call_at() followed by a
            # loop iteration) raises StopIteration. Confirm that loops
            # created without a generator never schedule timers.
            def gen():
                yield
            self._check_on_close = False
        else:
            self._check_on_close = True

        self._gen = gen()
        next(self._gen)  # prime the generator up to its first yield
        self._time = 0
        self._clock_resolution = 1e-9
        # Deadlines passed to call_at() since the last _run_once().
        self._timers = []
        # Selector whose select() always returns [] (no I/O events).
        self._selector = TestSelector()

        self.readers = {}  # fd -> Handle registered via add_reader()
        self.writers = {}  # fd -> Handle registered via add_writer()
        self.reset_counters()

        # fd -> transport, consulted by _ensure_fd_no_transport().
        self._transports = weakref.WeakValueDictionary()

    def time(self):
        # Virtual clock: changed only by advance_time().
        return self._time

    def advance_time(self, advance):
        """Move test time forward by *advance*; falsy values are ignored."""
        if advance:
            self._time += advance

    def close(self):
        """Close the loop and check that the time generator has finished.

        The check only runs when a generator was passed to ``__init__``: it
        is sent one more value and must raise StopIteration. Otherwise
        AssertionError is raised.
        """
        super().close()
        if self._check_on_close:
            try:
                self._gen.send(0)
            except StopIteration:
                pass
            else:  # pragma: no cover
                raise AssertionError("Time generator is not finished")

    def _add_reader(self, fd, callback, *args):
        # Store the handle so tests can inspect it with assert_reader().
        self.readers[fd] = events.Handle(callback, args, self, None)

    def _remove_reader(self, fd):
        # Count every removal attempt, even for unregistered fds.
        self.remove_reader_count[fd] += 1
        if fd in self.readers:
            del self.readers[fd]
            return True
        else:
            return False

    def assert_reader(self, fd, callback, *args):
        """Raise AssertionError unless *fd* has a reader with this callback/args."""
        if fd not in self.readers:
            raise AssertionError(f'fd {fd} is not registered')
        handle = self.readers[fd]
        if handle._callback != callback:
            raise AssertionError(
                f'unexpected callback: {handle._callback} != {callback}')
        if handle._args != args:
            raise AssertionError(
                f'unexpected callback args: {handle._args} != {args}')

    def assert_no_reader(self, fd):
        """Raise AssertionError if a reader is registered for *fd*."""
        if fd in self.readers:
            raise AssertionError(f'fd {fd} is registered')

    def _add_writer(self, fd, callback, *args):
        # Store the handle so tests can inspect it with assert_writer().
        self.writers[fd] = events.Handle(callback, args, self, None)

    def _remove_writer(self, fd):
        # Count every removal attempt, even for unregistered fds.
        self.remove_writer_count[fd] += 1
        if fd in self.writers:
            del self.writers[fd]
            return True
        else:
            return False

    def assert_writer(self, fd, callback, *args):
        """Raise AssertionError unless *fd* has a writer with this callback/args."""
        if fd not in self.writers:
            raise AssertionError(f'fd {fd} is not registered')
        handle = self.writers[fd]
        if handle._callback != callback:
            raise AssertionError(f'{handle._callback!r} != {callback!r}')
        if handle._args != args:
            raise AssertionError(f'{handle._args!r} != {args!r}')

    def _ensure_fd_no_transport(self, fd):
        """Raise RuntimeError if *fd* is owned by a registered transport.

        *fd* may be an int or an object with a ``fileno()`` method. An
        object whose fileno() cannot be turned into an int raises
        ValueError.
        """
        if not isinstance(fd, int):
            try:
                fd = int(fd.fileno())
            except (AttributeError, TypeError, ValueError):
                # This code matches selectors._fileobj_to_fd function.
                raise ValueError("Invalid file object: "
                                 "{!r}".format(fd)) from None
        try:
            transport = self._transports[fd]
        except KeyError:
            pass
        else:
            raise RuntimeError(
                'File descriptor {!r} is used by transport {!r}'.format(
                    fd, transport))

    def add_reader(self, fd, callback, *args):
        """Add a reader callback."""
        self._ensure_fd_no_transport(fd)
        return self._add_reader(fd, callback, *args)

    def remove_reader(self, fd):
        """Remove a reader callback."""
        self._ensure_fd_no_transport(fd)
        return self._remove_reader(fd)

    def add_writer(self, fd, callback, *args):
        """Add a writer callback."""
        self._ensure_fd_no_transport(fd)
        return self._add_writer(fd, callback, *args)

    def remove_writer(self, fd):
        """Remove a writer callback."""
        self._ensure_fd_no_transport(fd)
        return self._remove_writer(fd)

    def reset_counters(self):
        """Reset the per-fd counters of remove_reader/remove_writer calls."""
        self.remove_reader_count = collections.defaultdict(int)
        self.remove_writer_count = collections.defaultdict(int)

    def _run_once(self):
        # Run one base-loop iteration, then feed each deadline recorded by
        # call_at() to the time generator and apply the advance it yields.
        super()._run_once()
        for when in self._timers:
            advance = self._gen.send(when)
            self.advance_time(advance)
        self._timers = []

    def call_at(self, when, callback, *args, context=None):
        # Record the deadline for the time generator, then schedule normally.
        self._timers.append(when)
        return super().call_at(when, callback, *args, context=context)

    def _process_events(self, event_list):
        # Nothing to dispatch: TestSelector.select() never returns events.
        return

    def _write_to_self(self):
        # Intentionally a no-op for this test loop.
        pass
     501  
     502  
     503  def MockCallback(**kwargs):
     504      return mock.Mock(spec=['__call__'], **kwargs)
     505  
     506  
     507  class ESC[4;38;5;81mMockPattern(ESC[4;38;5;149mstr):
     508      """A regex based str with a fuzzy __eq__.
     509  
     510      Use this helper with 'mock.assert_called_with', or anywhere
     511      where a regex comparison between strings is needed.
     512  
     513      For instance:
     514         mock_call.assert_called_with(MockPattern('spam.*ham'))
     515      """
     516      def __eq__(self, other):
     517          return bool(re.search(str(self), other, re.S))
     518  
     519  
     520  class ESC[4;38;5;81mMockInstanceOf:
     521      def __init__(self, type):
     522          self._type = type
     523  
     524      def __eq__(self, other):
     525          return isinstance(other, self._type)
     526  
     527  
     528  def get_function_source(func):
     529      source = format_helpers._get_function_source(func)
     530      if source is None:
     531          raise ValueError("unable to get the source of %r" % (func,))
     532      return source
     533  
     534  
     535  class ESC[4;38;5;81mTestCase(ESC[4;38;5;149munittestESC[4;38;5;149m.ESC[4;38;5;149mTestCase):
     536      @staticmethod
     537      def close_loop(loop):
     538          if loop._default_executor is not None:
     539              if not loop.is_closed():
     540                  loop.run_until_complete(loop.shutdown_default_executor())
     541              else:
     542                  loop._default_executor.shutdown(wait=True)
     543          loop.close()
     544          policy = support.maybe_get_event_loop_policy()
     545          if policy is not None:
     546              try:
     547                  with warnings.catch_warnings():
     548                      warnings.simplefilter('ignore', DeprecationWarning)
     549                      watcher = policy.get_child_watcher()
     550              except NotImplementedError:
     551                  # watcher is not implemented by EventLoopPolicy, e.g. Windows
     552                  pass
     553              else:
     554                  if isinstance(watcher, asyncio.ThreadedChildWatcher):
     555                      threads = list(watcher._threads.values())
     556                      for thread in threads:
     557                          thread.join()
     558  
     559      def set_event_loop(self, loop, *, cleanup=True):
     560          if loop is None:
     561              raise AssertionError('loop is None')
     562          # ensure that the event loop is passed explicitly in asyncio
     563          events.set_event_loop(None)
     564          if cleanup:
     565              self.addCleanup(self.close_loop, loop)
     566  
     567      def new_test_loop(self, gen=None):
     568          loop = TestLoop(gen)
     569          self.set_event_loop(loop)
     570          return loop
     571  
     572      def setUp(self):
     573          self._thread_cleanup = threading_helper.threading_setup()
     574  
     575      def tearDown(self):
     576          events.set_event_loop(None)
     577  
     578          # Detect CPython bug #23353: ensure that yield/yield-from is not used
     579          # in an except block of a generator
     580          self.assertIsNone(sys.exception())
     581  
     582          self.doCleanups()
     583          threading_helper.threading_cleanup(*self._thread_cleanup)
     584          support.reap_children()
     585  
     586  
     587  @contextlib.contextmanager
     588  def disable_logger():
     589      """Context manager to disable asyncio logger.
     590  
     591      For example, it can be used to ignore warnings in debug mode.
     592      """
     593      old_level = logger.level
     594      try:
     595          logger.setLevel(logging.CRITICAL+1)
     596          yield
     597      finally:
     598          logger.setLevel(old_level)
     599  
     600  
     601  def mock_nonblocking_socket(proto=socket.IPPROTO_TCP, type=socket.SOCK_STREAM,
     602                              family=socket.AF_INET):
     603      """Create a mock of a non-blocking socket."""
     604      sock = mock.MagicMock(socket.socket)
     605      sock.proto = proto
     606      sock.type = type
     607      sock.family = family
     608      sock.gettimeout.return_value = 0.0
     609      return sock