(root)/
Python-3.11.7/
Lib/
test/
test_asyncio/
utils.py
       1  """Utilities shared by tests."""
       2  
       3  import asyncio
       4  import collections
       5  import contextlib
       6  import io
       7  import logging
       8  import os
       9  import re
      10  import selectors
      11  import socket
      12  import socketserver
      13  import sys
      14  import tempfile
      15  import threading
      16  import time
      17  import unittest
      18  import weakref
      19  
      20  from unittest import mock
      21  
      22  from http.server import HTTPServer
      23  from wsgiref.simple_server import WSGIRequestHandler, WSGIServer
      24  
      25  try:
      26      import ssl
      27  except ImportError:  # pragma: no cover
      28      ssl = None
      29  
      30  from asyncio import base_events
      31  from asyncio import events
      32  from asyncio import format_helpers
      33  from asyncio import futures
      34  from asyncio import tasks
      35  from asyncio.log import logger
      36  from test import support
      37  from test.support import threading_helper
      38  
      39  
      40  # Use the maximum known clock resolution (gh-75191, gh-110088): Windows
      41  # GetTickCount64() has a resolution of 15.6 ms. Use 50 ms to tolerate rounding
      42  # issues.
      43  CLOCK_RES = 0.050
      44  
      45  
      46  def data_file(*filename):
      47      if hasattr(support, 'TEST_HOME_DIR'):
      48          fullname = os.path.join(support.TEST_HOME_DIR, *filename)
      49          if os.path.isfile(fullname):
      50              return fullname
      51      fullname = os.path.join(os.path.dirname(__file__), '..', *filename)
      52      if os.path.isfile(fullname):
      53          return fullname
      54      raise FileNotFoundError(os.path.join(filename))
      55  
      56  
      57  ONLYCERT = data_file('certdata', 'ssl_cert.pem')
      58  ONLYKEY = data_file('certdata', 'ssl_key.pem')
      59  SIGNED_CERTFILE = data_file('certdata', 'keycert3.pem')
      60  SIGNING_CA = data_file('certdata', 'pycacert.pem')
      61  PEERCERT = {
      62      'OCSP': ('http://testca.pythontest.net/testca/ocsp/',),
      63      'caIssuers': ('http://testca.pythontest.net/testca/pycacert.cer',),
      64      'crlDistributionPoints': ('http://testca.pythontest.net/testca/revocation.crl',),
      65      'issuer': ((('countryName', 'XY'),),
      66              (('organizationName', 'Python Software Foundation CA'),),
      67              (('commonName', 'our-ca-server'),)),
      68      'notAfter': 'Oct 28 14:23:16 2037 GMT',
      69      'notBefore': 'Aug 29 14:23:16 2018 GMT',
      70      'serialNumber': 'CB2D80995A69525C',
      71      'subject': ((('countryName', 'XY'),),
      72               (('localityName', 'Castle Anthrax'),),
      73               (('organizationName', 'Python Software Foundation'),),
      74               (('commonName', 'localhost'),)),
      75      'subjectAltName': (('DNS', 'localhost'),),
      76      'version': 3
      77  }
      78  
      79  
      80  def simple_server_sslcontext():
      81      server_context = ssl.SSLContext(ssl.PROTOCOL_TLS_SERVER)
      82      server_context.load_cert_chain(ONLYCERT, ONLYKEY)
      83      server_context.check_hostname = False
      84      server_context.verify_mode = ssl.CERT_NONE
      85      return server_context
      86  
      87  
      88  def simple_client_sslcontext(*, disable_verify=True):
      89      client_context = ssl.SSLContext(ssl.PROTOCOL_TLS_CLIENT)
      90      client_context.check_hostname = False
      91      if disable_verify:
      92          client_context.verify_mode = ssl.CERT_NONE
      93      return client_context
      94  
      95  
      96  def dummy_ssl_context():
      97      if ssl is None:
      98          return None
      99      else:
     100          return simple_client_sslcontext(disable_verify=True)
     101  
     102  
     103  def run_briefly(loop):
     104      async def once():
     105          pass
     106      gen = once()
     107      t = loop.create_task(gen)
     108      # Don't log a warning if the task is not done after run_until_complete().
     109      # It occurs if the loop is stopped or if a task raises a BaseException.
     110      t._log_destroy_pending = False
     111      try:
     112          loop.run_until_complete(t)
     113      finally:
     114          gen.close()
     115  
     116  
     117  def run_until(loop, pred, timeout=support.SHORT_TIMEOUT):
     118      deadline = time.monotonic() + timeout
     119      while not pred():
     120          if timeout is not None:
     121              timeout = deadline - time.monotonic()
     122              if timeout <= 0:
     123                  raise futures.TimeoutError()
     124          loop.run_until_complete(tasks.sleep(0.001))
     125  
     126  
     127  def run_once(loop):
     128      """Legacy API to run once through the event loop.
     129  
     130      This is the recommended pattern for test code.  It will poll the
     131      selector once and run all callbacks scheduled in response to I/O
     132      events.
     133      """
     134      loop.call_soon(loop.stop)
     135      loop.run_forever()
     136  
     137  
     138  class ESC[4;38;5;81mSilentWSGIRequestHandler(ESC[4;38;5;149mWSGIRequestHandler):
     139  
     140      def get_stderr(self):
     141          return io.StringIO()
     142  
     143      def log_message(self, format, *args):
     144          pass
     145  
     146  
     147  class ESC[4;38;5;81mSilentWSGIServer(ESC[4;38;5;149mWSGIServer):
     148  
     149      request_timeout = support.LOOPBACK_TIMEOUT
     150  
     151      def get_request(self):
     152          request, client_addr = super().get_request()
     153          request.settimeout(self.request_timeout)
     154          return request, client_addr
     155  
     156      def handle_error(self, request, client_address):
     157          pass
     158  
     159  
     160  class ESC[4;38;5;81mSSLWSGIServerMixin:
     161  
     162      def finish_request(self, request, client_address):
     163          # The relative location of our test directory (which
     164          # contains the ssl key and certificate files) differs
     165          # between the stdlib and stand-alone asyncio.
     166          # Prefer our own if we can find it.
     167          context = ssl.SSLContext(ssl.PROTOCOL_TLS_SERVER)
     168          context.load_cert_chain(ONLYCERT, ONLYKEY)
     169  
     170          ssock = context.wrap_socket(request, server_side=True)
     171          try:
     172              self.RequestHandlerClass(ssock, client_address, self)
     173              ssock.close()
     174          except OSError:
     175              # maybe socket has been closed by peer
     176              pass
     177  
     178  
     179  class ESC[4;38;5;81mSSLWSGIServer(ESC[4;38;5;149mSSLWSGIServerMixin, ESC[4;38;5;149mSilentWSGIServer):
     180      pass
     181  
     182  
     183  def _run_test_server(*, address, use_ssl=False, server_cls, server_ssl_cls):
     184  
     185      def loop(environ):
     186          size = int(environ['CONTENT_LENGTH'])
     187          while size:
     188              data = environ['wsgi.input'].read(min(size, 0x10000))
     189              yield data
     190              size -= len(data)
     191  
     192      def app(environ, start_response):
     193          status = '200 OK'
     194          headers = [('Content-type', 'text/plain')]
     195          start_response(status, headers)
     196          if environ['PATH_INFO'] == '/loop':
     197              return loop(environ)
     198          else:
     199              return [b'Test message']
     200  
     201      # Run the test WSGI server in a separate thread in order not to
     202      # interfere with event handling in the main thread
     203      server_class = server_ssl_cls if use_ssl else server_cls
     204      httpd = server_class(address, SilentWSGIRequestHandler)
     205      httpd.set_app(app)
     206      httpd.address = httpd.server_address
     207      server_thread = threading.Thread(
     208          target=lambda: httpd.serve_forever(poll_interval=0.05))
     209      server_thread.start()
     210      try:
     211          yield httpd
     212      finally:
     213          httpd.shutdown()
     214          httpd.server_close()
     215          server_thread.join()
     216  
     217  
     218  if hasattr(socket, 'AF_UNIX'):
     219  
     220      class ESC[4;38;5;81mUnixHTTPServer(ESC[4;38;5;149msocketserverESC[4;38;5;149m.ESC[4;38;5;149mUnixStreamServer, ESC[4;38;5;149mHTTPServer):
     221  
     222          def server_bind(self):
     223              socketserver.UnixStreamServer.server_bind(self)
     224              self.server_name = '127.0.0.1'
     225              self.server_port = 80
     226  
     227  
     228      class ESC[4;38;5;81mUnixWSGIServer(ESC[4;38;5;149mUnixHTTPServer, ESC[4;38;5;149mWSGIServer):
     229  
     230          request_timeout = support.LOOPBACK_TIMEOUT
     231  
     232          def server_bind(self):
     233              UnixHTTPServer.server_bind(self)
     234              self.setup_environ()
     235  
     236          def get_request(self):
     237              request, client_addr = super().get_request()
     238              request.settimeout(self.request_timeout)
     239              # Code in the stdlib expects that get_request
     240              # will return a socket and a tuple (host, port).
     241              # However, this isn't true for UNIX sockets,
     242              # as the second return value will be a path;
     243              # hence we return some fake data sufficient
     244              # to get the tests going
     245              return request, ('127.0.0.1', '')
     246  
     247  
     248      class ESC[4;38;5;81mSilentUnixWSGIServer(ESC[4;38;5;149mUnixWSGIServer):
     249  
     250          def handle_error(self, request, client_address):
     251              pass
     252  
     253  
     254      class ESC[4;38;5;81mUnixSSLWSGIServer(ESC[4;38;5;149mSSLWSGIServerMixin, ESC[4;38;5;149mSilentUnixWSGIServer):
     255          pass
     256  
     257  
     258      def gen_unix_socket_path():
     259          with tempfile.NamedTemporaryFile() as file:
     260              return file.name
     261  
     262  
     263      @contextlib.contextmanager
     264      def unix_socket_path():
     265          path = gen_unix_socket_path()
     266          try:
     267              yield path
     268          finally:
     269              try:
     270                  os.unlink(path)
     271              except OSError:
     272                  pass
     273  
     274  
     275      @contextlib.contextmanager
     276      def run_test_unix_server(*, use_ssl=False):
     277          with unix_socket_path() as path:
     278              yield from _run_test_server(address=path, use_ssl=use_ssl,
     279                                          server_cls=SilentUnixWSGIServer,
     280                                          server_ssl_cls=UnixSSLWSGIServer)
     281  
     282  
     283  @contextlib.contextmanager
     284  def run_test_server(*, host='127.0.0.1', port=0, use_ssl=False):
     285      yield from _run_test_server(address=(host, port), use_ssl=use_ssl,
     286                                  server_cls=SilentWSGIServer,
     287                                  server_ssl_cls=SSLWSGIServer)
     288  
     289  
     290  def echo_datagrams(sock):
     291      while True:
     292          data, addr = sock.recvfrom(4096)
     293          if data == b'STOP':
     294              sock.close()
     295              break
     296          else:
     297              sock.sendto(data, addr)
     298  
     299  
     300  @contextlib.contextmanager
     301  def run_udp_echo_server(*, host='127.0.0.1', port=0):
     302      addr_info = socket.getaddrinfo(host, port, type=socket.SOCK_DGRAM)
     303      family, type, proto, _, sockaddr = addr_info[0]
     304      sock = socket.socket(family, type, proto)
     305      sock.bind((host, port))
     306      thread = threading.Thread(target=lambda: echo_datagrams(sock))
     307      thread.start()
     308      try:
     309          yield sock.getsockname()
     310      finally:
     311          sock.sendto(b'STOP', sock.getsockname())
     312          thread.join()
     313  
     314  
     315  def make_test_protocol(base):
     316      dct = {}
     317      for name in dir(base):
     318          if name.startswith('__') and name.endswith('__'):
     319              # skip magic names
     320              continue
     321          dct[name] = MockCallback(return_value=None)
     322      return type('TestProtocol', (base,) + base.__bases__, dct)()
     323  
     324  
     325  class ESC[4;38;5;81mTestSelector(ESC[4;38;5;149mselectorsESC[4;38;5;149m.ESC[4;38;5;149mBaseSelector):
     326  
     327      def __init__(self):
     328          self.keys = {}
     329  
     330      def register(self, fileobj, events, data=None):
     331          key = selectors.SelectorKey(fileobj, 0, events, data)
     332          self.keys[fileobj] = key
     333          return key
     334  
     335      def unregister(self, fileobj):
     336          return self.keys.pop(fileobj)
     337  
     338      def select(self, timeout):
     339          return []
     340  
     341      def get_map(self):
     342          return self.keys
     343  
     344  
     345  class ESC[4;38;5;81mTestLoop(ESC[4;38;5;149mbase_eventsESC[4;38;5;149m.ESC[4;38;5;149mBaseEventLoop):
     346      """Loop for unittests.
     347  
     348      It manages self time directly.
     349      If something scheduled to be executed later then
     350      on next loop iteration after all ready handlers done
     351      generator passed to __init__ is calling.
     352  
     353      Generator should be like this:
     354  
     355          def gen():
     356              ...
     357              when = yield ...
     358              ... = yield time_advance
     359  
     360      Value returned by yield is absolute time of next scheduled handler.
     361      Value passed to yield is time advance to move loop's time forward.
     362      """
     363  
     364      def __init__(self, gen=None):
     365          super().__init__()
     366  
     367          if gen is None:
     368              def gen():
     369                  yield
     370              self._check_on_close = False
     371          else:
     372              self._check_on_close = True
     373  
     374          self._gen = gen()
     375          next(self._gen)
     376          self._time = 0
     377          self._clock_resolution = 1e-9
     378          self._timers = []
     379          self._selector = TestSelector()
     380  
     381          self.readers = {}
     382          self.writers = {}
     383          self.reset_counters()
     384  
     385          self._transports = weakref.WeakValueDictionary()
     386  
     387      def time(self):
     388          return self._time
     389  
     390      def advance_time(self, advance):
     391          """Move test time forward."""
     392          if advance:
     393              self._time += advance
     394  
     395      def close(self):
     396          super().close()
     397          if self._check_on_close:
     398              try:
     399                  self._gen.send(0)
     400              except StopIteration:
     401                  pass
     402              else:  # pragma: no cover
     403                  raise AssertionError("Time generator is not finished")
     404  
     405      def _add_reader(self, fd, callback, *args):
     406          self.readers[fd] = events.Handle(callback, args, self, None)
     407  
     408      def _remove_reader(self, fd):
     409          self.remove_reader_count[fd] += 1
     410          if fd in self.readers:
     411              del self.readers[fd]
     412              return True
     413          else:
     414              return False
     415  
     416      def assert_reader(self, fd, callback, *args):
     417          if fd not in self.readers:
     418              raise AssertionError(f'fd {fd} is not registered')
     419          handle = self.readers[fd]
     420          if handle._callback != callback:
     421              raise AssertionError(
     422                  f'unexpected callback: {handle._callback} != {callback}')
     423          if handle._args != args:
     424              raise AssertionError(
     425                  f'unexpected callback args: {handle._args} != {args}')
     426  
     427      def assert_no_reader(self, fd):
     428          if fd in self.readers:
     429              raise AssertionError(f'fd {fd} is registered')
     430  
     431      def _add_writer(self, fd, callback, *args):
     432          self.writers[fd] = events.Handle(callback, args, self, None)
     433  
     434      def _remove_writer(self, fd):
     435          self.remove_writer_count[fd] += 1
     436          if fd in self.writers:
     437              del self.writers[fd]
     438              return True
     439          else:
     440              return False
     441  
     442      def assert_writer(self, fd, callback, *args):
     443          if fd not in self.writers:
     444              raise AssertionError(f'fd {fd} is not registered')
     445          handle = self.writers[fd]
     446          if handle._callback != callback:
     447              raise AssertionError(f'{handle._callback!r} != {callback!r}')
     448          if handle._args != args:
     449              raise AssertionError(f'{handle._args!r} != {args!r}')
     450  
     451      def _ensure_fd_no_transport(self, fd):
     452          if not isinstance(fd, int):
     453              try:
     454                  fd = int(fd.fileno())
     455              except (AttributeError, TypeError, ValueError):
     456                  # This code matches selectors._fileobj_to_fd function.
     457                  raise ValueError("Invalid file object: "
     458                                   "{!r}".format(fd)) from None
     459          try:
     460              transport = self._transports[fd]
     461          except KeyError:
     462              pass
     463          else:
     464              raise RuntimeError(
     465                  'File descriptor {!r} is used by transport {!r}'.format(
     466                      fd, transport))
     467  
     468      def add_reader(self, fd, callback, *args):
     469          """Add a reader callback."""
     470          self._ensure_fd_no_transport(fd)
     471          return self._add_reader(fd, callback, *args)
     472  
     473      def remove_reader(self, fd):
     474          """Remove a reader callback."""
     475          self._ensure_fd_no_transport(fd)
     476          return self._remove_reader(fd)
     477  
     478      def add_writer(self, fd, callback, *args):
     479          """Add a writer callback.."""
     480          self._ensure_fd_no_transport(fd)
     481          return self._add_writer(fd, callback, *args)
     482  
     483      def remove_writer(self, fd):
     484          """Remove a writer callback."""
     485          self._ensure_fd_no_transport(fd)
     486          return self._remove_writer(fd)
     487  
     488      def reset_counters(self):
     489          self.remove_reader_count = collections.defaultdict(int)
     490          self.remove_writer_count = collections.defaultdict(int)
     491  
     492      def _run_once(self):
     493          super()._run_once()
     494          for when in self._timers:
     495              advance = self._gen.send(when)
     496              self.advance_time(advance)
     497          self._timers = []
     498  
     499      def call_at(self, when, callback, *args, context=None):
     500          self._timers.append(when)
     501          return super().call_at(when, callback, *args, context=context)
     502  
     503      def _process_events(self, event_list):
     504          return
     505  
     506      def _write_to_self(self):
     507          pass
     508  
     509  
     510  def MockCallback(**kwargs):
     511      return mock.Mock(spec=['__call__'], **kwargs)
     512  
     513  
     514  class ESC[4;38;5;81mMockPattern(ESC[4;38;5;149mstr):
     515      """A regex based str with a fuzzy __eq__.
     516  
     517      Use this helper with 'mock.assert_called_with', or anywhere
     518      where a regex comparison between strings is needed.
     519  
     520      For instance:
     521         mock_call.assert_called_with(MockPattern('spam.*ham'))
     522      """
     523      def __eq__(self, other):
     524          return bool(re.search(str(self), other, re.S))
     525  
     526  
     527  class ESC[4;38;5;81mMockInstanceOf:
     528      def __init__(self, type):
     529          self._type = type
     530  
     531      def __eq__(self, other):
     532          return isinstance(other, self._type)
     533  
     534  
     535  def get_function_source(func):
     536      source = format_helpers._get_function_source(func)
     537      if source is None:
     538          raise ValueError("unable to get the source of %r" % (func,))
     539      return source
     540  
     541  
     542  class ESC[4;38;5;81mTestCase(ESC[4;38;5;149munittestESC[4;38;5;149m.ESC[4;38;5;149mTestCase):
     543      @staticmethod
     544      def close_loop(loop):
     545          if loop._default_executor is not None:
     546              if not loop.is_closed():
     547                  loop.run_until_complete(loop.shutdown_default_executor())
     548              else:
     549                  loop._default_executor.shutdown(wait=True)
     550          loop.close()
     551  
     552          policy = support.maybe_get_event_loop_policy()
     553          if policy is not None:
     554              try:
     555                  watcher = policy.get_child_watcher()
     556              except NotImplementedError:
     557                  # watcher is not implemented by EventLoopPolicy, e.g. Windows
     558                  pass
     559              else:
     560                  if isinstance(watcher, asyncio.ThreadedChildWatcher):
     561                      # Wait for subprocess to finish, but not forever
     562                      for thread in list(watcher._threads.values()):
     563                          thread.join(timeout=support.SHORT_TIMEOUT)
     564                          if thread.is_alive():
     565                              raise RuntimeError(f"thread {thread} still alive: "
     566                                                 "subprocess still running")
     567  
     568  
     569      def set_event_loop(self, loop, *, cleanup=True):
     570          if loop is None:
     571              raise AssertionError('loop is None')
     572          # ensure that the event loop is passed explicitly in asyncio
     573          events.set_event_loop(None)
     574          if cleanup:
     575              self.addCleanup(self.close_loop, loop)
     576  
     577      def new_test_loop(self, gen=None):
     578          loop = TestLoop(gen)
     579          self.set_event_loop(loop)
     580          return loop
     581  
     582      def setUp(self):
     583          self._thread_cleanup = threading_helper.threading_setup()
     584  
     585      def tearDown(self):
     586          events.set_event_loop(None)
     587  
     588          # Detect CPython bug #23353: ensure that yield/yield-from is not used
     589          # in an except block of a generator
     590          self.assertEqual(sys.exc_info(), (None, None, None))
     591  
     592          self.doCleanups()
     593          threading_helper.threading_cleanup(*self._thread_cleanup)
     594          support.reap_children()
     595  
     596  
     597  @contextlib.contextmanager
     598  def disable_logger():
     599      """Context manager to disable asyncio logger.
     600  
     601      For example, it can be used to ignore warnings in debug mode.
     602      """
     603      old_level = logger.level
     604      try:
     605          logger.setLevel(logging.CRITICAL+1)
     606          yield
     607      finally:
     608          logger.setLevel(old_level)
     609  
     610  
     611  def mock_nonblocking_socket(proto=socket.IPPROTO_TCP, type=socket.SOCK_STREAM,
     612                              family=socket.AF_INET):
     613      """Create a mock of a non-blocking socket."""
     614      sock = mock.MagicMock(socket.socket)
     615      sock.proto = proto
     616      sock.type = type
     617      sock.family = family
     618      sock.gettimeout.return_value = 0.0
     619      return sock
     620  
     621  
     622  async def await_without_task(coro):
     623      exc = None
     624      def func():
     625          try:
     626              for _ in coro.__await__():
     627                  pass
     628          except BaseException as err:
     629              nonlocal exc
     630              exc = err
     631      asyncio.get_running_loop().call_soon(func)
     632      await asyncio.sleep(0)
     633      if exc is not None:
     634          raise exc