(root)/
Python-3.11.7/
Lib/
test/
test_asyncio/
functional.py
       1  import asyncio
       2  import asyncio.events
       3  import contextlib
       4  import os
       5  import pprint
       6  import select
       7  import socket
       8  import tempfile
       9  import threading
      10  from test import support
      11  
      12  
      13  class ESC[4;38;5;81mFunctionalTestCaseMixin:
      14  
      15      def new_loop(self):
      16          return asyncio.new_event_loop()
      17  
      18      def run_loop_briefly(self, *, delay=0.01):
      19          self.loop.run_until_complete(asyncio.sleep(delay))
      20  
      21      def loop_exception_handler(self, loop, context):
      22          self.__unhandled_exceptions.append(context)
      23          self.loop.default_exception_handler(context)
      24  
      25      def setUp(self):
      26          self.loop = self.new_loop()
      27          asyncio.set_event_loop(None)
      28  
      29          self.loop.set_exception_handler(self.loop_exception_handler)
      30          self.__unhandled_exceptions = []
      31  
      32      def tearDown(self):
      33          try:
      34              self.loop.close()
      35  
      36              if self.__unhandled_exceptions:
      37                  print('Unexpected calls to loop.call_exception_handler():')
      38                  pprint.pprint(self.__unhandled_exceptions)
      39                  self.fail('unexpected calls to loop.call_exception_handler()')
      40  
      41          finally:
      42              asyncio.set_event_loop(None)
      43              self.loop = None
      44  
      45      def tcp_server(self, server_prog, *,
      46                     family=socket.AF_INET,
      47                     addr=None,
      48                     timeout=support.LOOPBACK_TIMEOUT,
      49                     backlog=1,
      50                     max_clients=10):
      51  
      52          if addr is None:
      53              if hasattr(socket, 'AF_UNIX') and family == socket.AF_UNIX:
      54                  with tempfile.NamedTemporaryFile() as tmp:
      55                      addr = tmp.name
      56              else:
      57                  addr = ('127.0.0.1', 0)
      58  
      59          sock = socket.create_server(addr, family=family, backlog=backlog)
      60          if timeout is None:
      61              raise RuntimeError('timeout is required')
      62          if timeout <= 0:
      63              raise RuntimeError('only blocking sockets are supported')
      64          sock.settimeout(timeout)
      65  
      66          return TestThreadedServer(
      67              self, sock, server_prog, timeout, max_clients)
      68  
      69      def tcp_client(self, client_prog,
      70                     family=socket.AF_INET,
      71                     timeout=support.LOOPBACK_TIMEOUT):
      72  
      73          sock = socket.socket(family, socket.SOCK_STREAM)
      74  
      75          if timeout is None:
      76              raise RuntimeError('timeout is required')
      77          if timeout <= 0:
      78              raise RuntimeError('only blocking sockets are supported')
      79          sock.settimeout(timeout)
      80  
      81          return TestThreadedClient(
      82              self, sock, client_prog, timeout)
      83  
      84      def unix_server(self, *args, **kwargs):
      85          if not hasattr(socket, 'AF_UNIX'):
      86              raise NotImplementedError
      87          return self.tcp_server(*args, family=socket.AF_UNIX, **kwargs)
      88  
      89      def unix_client(self, *args, **kwargs):
      90          if not hasattr(socket, 'AF_UNIX'):
      91              raise NotImplementedError
      92          return self.tcp_client(*args, family=socket.AF_UNIX, **kwargs)
      93  
      94      @contextlib.contextmanager
      95      def unix_sock_name(self):
      96          with tempfile.TemporaryDirectory() as td:
      97              fn = os.path.join(td, 'sock')
      98              try:
      99                  yield fn
     100              finally:
     101                  try:
     102                      os.unlink(fn)
     103                  except OSError:
     104                      pass
     105  
     106      def _abort_socket_test(self, ex):
     107          try:
     108              self.loop.stop()
     109          finally:
     110              self.fail(ex)
     111  
     112  
     113  ##############################################################################
     114  # Socket Testing Utilities
     115  ##############################################################################
     116  
     117  
     118  class ESC[4;38;5;81mTestSocketWrapper:
     119  
     120      def __init__(self, sock):
     121          self.__sock = sock
     122  
     123      def recv_all(self, n):
     124          buf = b''
     125          while len(buf) < n:
     126              data = self.recv(n - len(buf))
     127              if data == b'':
     128                  raise ConnectionAbortedError
     129              buf += data
     130          return buf
     131  
     132      def start_tls(self, ssl_context, *,
     133                    server_side=False,
     134                    server_hostname=None):
     135  
     136          ssl_sock = ssl_context.wrap_socket(
     137              self.__sock, server_side=server_side,
     138              server_hostname=server_hostname,
     139              do_handshake_on_connect=False)
     140  
     141          try:
     142              ssl_sock.do_handshake()
     143          except:
     144              ssl_sock.close()
     145              raise
     146          finally:
     147              self.__sock.close()
     148  
     149          self.__sock = ssl_sock
     150  
     151      def __getattr__(self, name):
     152          return getattr(self.__sock, name)
     153  
     154      def __repr__(self):
     155          return '<{} {!r}>'.format(type(self).__name__, self.__sock)
     156  
     157  
     158  class ESC[4;38;5;81mSocketThread(ESC[4;38;5;149mthreadingESC[4;38;5;149m.ESC[4;38;5;149mThread):
     159  
     160      def stop(self):
     161          self._active = False
     162          self.join()
     163  
     164      def __enter__(self):
     165          self.start()
     166          return self
     167  
     168      def __exit__(self, *exc):
     169          self.stop()
     170  
     171  
     172  class ESC[4;38;5;81mTestThreadedClient(ESC[4;38;5;149mSocketThread):
     173  
     174      def __init__(self, test, sock, prog, timeout):
     175          threading.Thread.__init__(self, None, None, 'test-client')
     176          self.daemon = True
     177  
     178          self._timeout = timeout
     179          self._sock = sock
     180          self._active = True
     181          self._prog = prog
     182          self._test = test
     183  
     184      def run(self):
     185          try:
     186              self._prog(TestSocketWrapper(self._sock))
     187          except Exception as ex:
     188              self._test._abort_socket_test(ex)
     189  
     190  
     191  class ESC[4;38;5;81mTestThreadedServer(ESC[4;38;5;149mSocketThread):
     192  
     193      def __init__(self, test, sock, prog, timeout, max_clients):
     194          threading.Thread.__init__(self, None, None, 'test-server')
     195          self.daemon = True
     196  
     197          self._clients = 0
     198          self._finished_clients = 0
     199          self._max_clients = max_clients
     200          self._timeout = timeout
     201          self._sock = sock
     202          self._active = True
     203  
     204          self._prog = prog
     205  
     206          self._s1, self._s2 = socket.socketpair()
     207          self._s1.setblocking(False)
     208  
     209          self._test = test
     210  
     211      def stop(self):
     212          try:
     213              if self._s2 and self._s2.fileno() != -1:
     214                  try:
     215                      self._s2.send(b'stop')
     216                  except OSError:
     217                      pass
     218          finally:
     219              super().stop()
     220  
     221      def run(self):
     222          try:
     223              with self._sock:
     224                  self._sock.setblocking(False)
     225                  self._run()
     226          finally:
     227              self._s1.close()
     228              self._s2.close()
     229  
     230      def _run(self):
     231          while self._active:
     232              if self._clients >= self._max_clients:
     233                  return
     234  
     235              r, w, x = select.select(
     236                  [self._sock, self._s1], [], [], self._timeout)
     237  
     238              if self._s1 in r:
     239                  return
     240  
     241              if self._sock in r:
     242                  try:
     243                      conn, addr = self._sock.accept()
     244                  except BlockingIOError:
     245                      continue
     246                  except TimeoutError:
     247                      if not self._active:
     248                          return
     249                      else:
     250                          raise
     251                  else:
     252                      self._clients += 1
     253                      conn.settimeout(self._timeout)
     254                      try:
     255                          with conn:
     256                              self._handle_client(conn)
     257                      except Exception as ex:
     258                          self._active = False
     259                          try:
     260                              raise
     261                          finally:
     262                              self._test._abort_socket_test(ex)
     263  
     264      def _handle_client(self, sock):
     265          self._prog(TestSocketWrapper(sock))
     266  
     267      @property
     268      def addr(self):
     269          return self._sock.getsockname()