# (root)/Python-3.11.7/Lib/test/test_asyncio/test_sock_lowlevel.py
       1  import socket
       2  import asyncio
       3  import sys
       4  import unittest
       5  
       6  from asyncio import proactor_events
       7  from itertools import cycle, islice
       8  from unittest.mock import patch, Mock
       9  from test.test_asyncio import utils as test_utils
      10  from test import support
      11  from test.support import socket_helper
      12  
      13  if socket_helper.tcp_blackhole():
      14      raise unittest.SkipTest('Not relevant to ProactorEventLoop')
      15  
      16  
      17  def tearDownModule():
      18      asyncio.set_event_loop_policy(None)
      19  
      20  
      21  class ESC[4;38;5;81mMyProto(ESC[4;38;5;149masyncioESC[4;38;5;149m.ESC[4;38;5;149mProtocol):
      22      connected = None
      23      done = None
      24  
      25      def __init__(self, loop=None):
      26          self.transport = None
      27          self.state = 'INITIAL'
      28          self.nbytes = 0
      29          if loop is not None:
      30              self.connected = loop.create_future()
      31              self.done = loop.create_future()
      32  
      33      def _assert_state(self, *expected):
      34          if self.state not in expected:
      35              raise AssertionError(f'state: {self.state!r}, expected: {expected!r}')
      36  
      37      def connection_made(self, transport):
      38          self.transport = transport
      39          self._assert_state('INITIAL')
      40          self.state = 'CONNECTED'
      41          if self.connected:
      42              self.connected.set_result(None)
      43          transport.write(b'GET / HTTP/1.0\r\nHost: example.com\r\n\r\n')
      44  
      45      def data_received(self, data):
      46          self._assert_state('CONNECTED')
      47          self.nbytes += len(data)
      48  
      49      def eof_received(self):
      50          self._assert_state('CONNECTED')
      51          self.state = 'EOF'
      52  
      53      def connection_lost(self, exc):
      54          self._assert_state('CONNECTED', 'EOF')
      55          self.state = 'CLOSED'
      56          if self.done:
      57              self.done.set_result(None)
      58  
      59  
      60  class ESC[4;38;5;81mBaseSockTestsMixin:
      61  
      62      def create_event_loop(self):
      63          raise NotImplementedError
      64  
      65      def setUp(self):
      66          self.loop = self.create_event_loop()
      67          self.set_event_loop(self.loop)
      68          super().setUp()
      69  
      70      def tearDown(self):
      71          # just in case if we have transport close callbacks
      72          if not self.loop.is_closed():
      73              test_utils.run_briefly(self.loop)
      74  
      75          self.doCleanups()
      76          support.gc_collect()
      77          super().tearDown()
      78  
      79      def _basetest_sock_client_ops(self, httpd, sock):
      80          if not isinstance(self.loop, proactor_events.BaseProactorEventLoop):
      81              # in debug mode, socket operations must fail
      82              # if the socket is not in blocking mode
      83              self.loop.set_debug(True)
      84              sock.setblocking(True)
      85              with self.assertRaises(ValueError):
      86                  self.loop.run_until_complete(
      87                      self.loop.sock_connect(sock, httpd.address))
      88              with self.assertRaises(ValueError):
      89                  self.loop.run_until_complete(
      90                      self.loop.sock_sendall(sock, b'GET / HTTP/1.0\r\n\r\n'))
      91              with self.assertRaises(ValueError):
      92                  self.loop.run_until_complete(
      93                      self.loop.sock_recv(sock, 1024))
      94              with self.assertRaises(ValueError):
      95                  self.loop.run_until_complete(
      96                      self.loop.sock_recv_into(sock, bytearray()))
      97              with self.assertRaises(ValueError):
      98                  self.loop.run_until_complete(
      99                      self.loop.sock_accept(sock))
     100  
     101          # test in non-blocking mode
     102          sock.setblocking(False)
     103          self.loop.run_until_complete(
     104              self.loop.sock_connect(sock, httpd.address))
     105          self.loop.run_until_complete(
     106              self.loop.sock_sendall(sock, b'GET / HTTP/1.0\r\n\r\n'))
     107          data = self.loop.run_until_complete(
     108              self.loop.sock_recv(sock, 1024))
     109          # consume data
     110          self.loop.run_until_complete(
     111              self.loop.sock_recv(sock, 1024))
     112          sock.close()
     113          self.assertTrue(data.startswith(b'HTTP/1.0 200 OK'))
     114  
     115      def _basetest_sock_recv_into(self, httpd, sock):
     116          # same as _basetest_sock_client_ops, but using sock_recv_into
     117          sock.setblocking(False)
     118          self.loop.run_until_complete(
     119              self.loop.sock_connect(sock, httpd.address))
     120          self.loop.run_until_complete(
     121              self.loop.sock_sendall(sock, b'GET / HTTP/1.0\r\n\r\n'))
     122          data = bytearray(1024)
     123          with memoryview(data) as buf:
     124              nbytes = self.loop.run_until_complete(
     125                  self.loop.sock_recv_into(sock, buf[:1024]))
     126              # consume data
     127              self.loop.run_until_complete(
     128                  self.loop.sock_recv_into(sock, buf[nbytes:]))
     129          sock.close()
     130          self.assertTrue(data.startswith(b'HTTP/1.0 200 OK'))
     131  
     132      def test_sock_client_ops(self):
     133          with test_utils.run_test_server() as httpd:
     134              sock = socket.socket()
     135              self._basetest_sock_client_ops(httpd, sock)
     136              sock = socket.socket()
     137              self._basetest_sock_recv_into(httpd, sock)
     138  
     139      async def _basetest_sock_recv_racing(self, httpd, sock):
     140          sock.setblocking(False)
     141          await self.loop.sock_connect(sock, httpd.address)
     142  
     143          task = asyncio.create_task(self.loop.sock_recv(sock, 1024))
     144          await asyncio.sleep(0)
     145          task.cancel()
     146  
     147          asyncio.create_task(
     148              self.loop.sock_sendall(sock, b'GET / HTTP/1.0\r\n\r\n'))
     149          data = await self.loop.sock_recv(sock, 1024)
     150          # consume data
     151          await self.loop.sock_recv(sock, 1024)
     152  
     153          self.assertTrue(data.startswith(b'HTTP/1.0 200 OK'))
     154  
     155      async def _basetest_sock_recv_into_racing(self, httpd, sock):
     156          sock.setblocking(False)
     157          await self.loop.sock_connect(sock, httpd.address)
     158  
     159          data = bytearray(1024)
     160          with memoryview(data) as buf:
     161              task = asyncio.create_task(
     162                  self.loop.sock_recv_into(sock, buf[:1024]))
     163              await asyncio.sleep(0)
     164              task.cancel()
     165  
     166              task = asyncio.create_task(
     167                  self.loop.sock_sendall(sock, b'GET / HTTP/1.0\r\n\r\n'))
     168              nbytes = await self.loop.sock_recv_into(sock, buf[:1024])
     169              # consume data
     170              await self.loop.sock_recv_into(sock, buf[nbytes:])
     171              self.assertTrue(data.startswith(b'HTTP/1.0 200 OK'))
     172  
     173          await task
     174  
     175      async def _basetest_sock_send_racing(self, listener, sock):
     176          listener.bind(('127.0.0.1', 0))
     177          listener.listen(1)
     178  
     179          # make connection
     180          sock.setsockopt(socket.SOL_SOCKET, socket.SO_SNDBUF, 1024)
     181          sock.setblocking(False)
     182          task = asyncio.create_task(
     183              self.loop.sock_connect(sock, listener.getsockname()))
     184          await asyncio.sleep(0)
     185          server = listener.accept()[0]
     186          server.setblocking(False)
     187  
     188          with server:
     189              await task
     190  
     191              # fill the buffer until sending 5 chars would block
     192              size = 8192
     193              while size >= 4:
     194                  with self.assertRaises(BlockingIOError):
     195                      while True:
     196                          sock.send(b' ' * size)
     197                  size = int(size / 2)
     198  
     199              # cancel a blocked sock_sendall
     200              task = asyncio.create_task(
     201                  self.loop.sock_sendall(sock, b'hello'))
     202              await asyncio.sleep(0)
     203              task.cancel()
     204  
     205              # receive everything that is not a space
     206              async def recv_all():
     207                  rv = b''
     208                  while True:
     209                      buf = await self.loop.sock_recv(server, 8192)
     210                      if not buf:
     211                          return rv
     212                      rv += buf.strip()
     213              task = asyncio.create_task(recv_all())
     214  
     215              # immediately make another sock_sendall call
     216              await self.loop.sock_sendall(sock, b'world')
     217              sock.shutdown(socket.SHUT_WR)
     218              data = await task
     219              # ProactorEventLoop could deliver hello, so endswith is necessary
     220              self.assertTrue(data.endswith(b'world'))
     221  
     222      # After the first connect attempt before the listener is ready,
     223      # the socket needs time to "recover" to make the next connect call.
     224      # On Linux, a second retry will do. On Windows, the waiting time is
     225      # unpredictable; and on FreeBSD the socket may never come back
     226      # because it's a loopback address. Here we'll just retry for a few
     227      # times, and have to skip the test if it's not working. See also:
     228      # https://stackoverflow.com/a/54437602/3316267
     229      # https://lists.freebsd.org/pipermail/freebsd-current/2005-May/049876.html
     230      async def _basetest_sock_connect_racing(self, listener, sock):
     231          listener.bind(('127.0.0.1', 0))
     232          addr = listener.getsockname()
     233          sock.setblocking(False)
     234  
     235          task = asyncio.create_task(self.loop.sock_connect(sock, addr))
     236          await asyncio.sleep(0)
     237          task.cancel()
     238  
     239          listener.listen(1)
     240  
     241          skip_reason = "Max retries reached"
     242          for i in range(128):
     243              try:
     244                  await self.loop.sock_connect(sock, addr)
     245              except ConnectionRefusedError as e:
     246                  skip_reason = e
     247              except OSError as e:
     248                  skip_reason = e
     249  
     250                  # Retry only for this error:
     251                  # [WinError 10022] An invalid argument was supplied
     252                  if getattr(e, 'winerror', 0) != 10022:
     253                      break
     254              else:
     255                  # success
     256                  return
     257  
     258          self.skipTest(skip_reason)
     259  
     260      def test_sock_client_racing(self):
     261          with test_utils.run_test_server() as httpd:
     262              sock = socket.socket()
     263              with sock:
     264                  self.loop.run_until_complete(asyncio.wait_for(
     265                      self._basetest_sock_recv_racing(httpd, sock), 10))
     266              sock = socket.socket()
     267              with sock:
     268                  self.loop.run_until_complete(asyncio.wait_for(
     269                      self._basetest_sock_recv_into_racing(httpd, sock), 10))
     270          listener = socket.socket()
     271          sock = socket.socket()
     272          with listener, sock:
     273              self.loop.run_until_complete(asyncio.wait_for(
     274                  self._basetest_sock_send_racing(listener, sock), 10))
     275  
     276      def test_sock_client_connect_racing(self):
     277          listener = socket.socket()
     278          sock = socket.socket()
     279          with listener, sock:
     280              self.loop.run_until_complete(asyncio.wait_for(
     281                  self._basetest_sock_connect_racing(listener, sock), 10))
     282  
     283      async def _basetest_huge_content(self, address):
     284          sock = socket.socket()
     285          sock.setblocking(False)
     286          DATA_SIZE = 10_000_00
     287  
     288          chunk = b'0123456789' * (DATA_SIZE // 10)
     289  
     290          await self.loop.sock_connect(sock, address)
     291          await self.loop.sock_sendall(sock,
     292                                       (b'POST /loop HTTP/1.0\r\n' +
     293                                        b'Content-Length: %d\r\n' % DATA_SIZE +
     294                                        b'\r\n'))
     295  
     296          task = asyncio.create_task(self.loop.sock_sendall(sock, chunk))
     297  
     298          data = await self.loop.sock_recv(sock, DATA_SIZE)
     299          # HTTP headers size is less than MTU,
     300          # they are sent by the first packet always
     301          self.assertTrue(data.startswith(b'HTTP/1.0 200 OK'))
     302          while data.find(b'\r\n\r\n') == -1:
     303              data += await self.loop.sock_recv(sock, DATA_SIZE)
     304          # Strip headers
     305          headers = data[:data.index(b'\r\n\r\n') + 4]
     306          data = data[len(headers):]
     307  
     308          size = DATA_SIZE
     309          checker = cycle(b'0123456789')
     310  
     311          expected = bytes(islice(checker, len(data)))
     312          self.assertEqual(data, expected)
     313          size -= len(data)
     314  
     315          while True:
     316              data = await self.loop.sock_recv(sock, DATA_SIZE)
     317              if not data:
     318                  break
     319              expected = bytes(islice(checker, len(data)))
     320              self.assertEqual(data, expected)
     321              size -= len(data)
     322          self.assertEqual(size, 0)
     323  
     324          await task
     325          sock.close()
     326  
     327      def test_huge_content(self):
     328          with test_utils.run_test_server() as httpd:
     329              self.loop.run_until_complete(
     330                  self._basetest_huge_content(httpd.address))
     331  
     332      async def _basetest_huge_content_recvinto(self, address):
     333          sock = socket.socket()
     334          sock.setblocking(False)
     335          DATA_SIZE = 10_000_00
     336  
     337          chunk = b'0123456789' * (DATA_SIZE // 10)
     338  
     339          await self.loop.sock_connect(sock, address)
     340          await self.loop.sock_sendall(sock,
     341                                       (b'POST /loop HTTP/1.0\r\n' +
     342                                        b'Content-Length: %d\r\n' % DATA_SIZE +
     343                                        b'\r\n'))
     344  
     345          task = asyncio.create_task(self.loop.sock_sendall(sock, chunk))
     346  
     347          array = bytearray(DATA_SIZE)
     348          buf = memoryview(array)
     349  
     350          nbytes = await self.loop.sock_recv_into(sock, buf)
     351          data = bytes(buf[:nbytes])
     352          # HTTP headers size is less than MTU,
     353          # they are sent by the first packet always
     354          self.assertTrue(data.startswith(b'HTTP/1.0 200 OK'))
     355          while data.find(b'\r\n\r\n') == -1:
     356              nbytes = await self.loop.sock_recv_into(sock, buf)
     357              data = bytes(buf[:nbytes])
     358          # Strip headers
     359          headers = data[:data.index(b'\r\n\r\n') + 4]
     360          data = data[len(headers):]
     361  
     362          size = DATA_SIZE
     363          checker = cycle(b'0123456789')
     364  
     365          expected = bytes(islice(checker, len(data)))
     366          self.assertEqual(data, expected)
     367          size -= len(data)
     368  
     369          while True:
     370              nbytes = await self.loop.sock_recv_into(sock, buf)
     371              data = buf[:nbytes]
     372              if not data:
     373                  break
     374              expected = bytes(islice(checker, len(data)))
     375              self.assertEqual(data, expected)
     376              size -= len(data)
     377          self.assertEqual(size, 0)
     378  
     379          await task
     380          sock.close()
     381  
     382      def test_huge_content_recvinto(self):
     383          with test_utils.run_test_server() as httpd:
     384              self.loop.run_until_complete(
     385                  self._basetest_huge_content_recvinto(httpd.address))
     386  
     387      async def _basetest_datagram_recvfrom(self, server_address):
     388          # Happy path, sock.sendto() returns immediately
     389          data = b'\x01' * 4096
     390          with socket.socket(socket.AF_INET, socket.SOCK_DGRAM) as sock:
     391              sock.setblocking(False)
     392              await self.loop.sock_sendto(sock, data, server_address)
     393              received_data, from_addr = await self.loop.sock_recvfrom(
     394                  sock, 4096)
     395              self.assertEqual(received_data, data)
     396              self.assertEqual(from_addr, server_address)
     397  
     398      def test_recvfrom(self):
     399          with test_utils.run_udp_echo_server() as server_address:
     400              self.loop.run_until_complete(
     401                  self._basetest_datagram_recvfrom(server_address))
     402  
     403      async def _basetest_datagram_recvfrom_into(self, server_address):
     404          # Happy path, sock.sendto() returns immediately
     405          with socket.socket(socket.AF_INET, socket.SOCK_DGRAM) as sock:
     406              sock.setblocking(False)
     407  
     408              buf = bytearray(4096)
     409              data = b'\x01' * 4096
     410              await self.loop.sock_sendto(sock, data, server_address)
     411              num_bytes, from_addr = await self.loop.sock_recvfrom_into(
     412                  sock, buf)
     413              self.assertEqual(num_bytes, 4096)
     414              self.assertEqual(buf, data)
     415              self.assertEqual(from_addr, server_address)
     416  
     417              buf = bytearray(8192)
     418              await self.loop.sock_sendto(sock, data, server_address)
     419              num_bytes, from_addr = await self.loop.sock_recvfrom_into(
     420                  sock, buf, 4096)
     421              self.assertEqual(num_bytes, 4096)
     422              self.assertEqual(buf[:4096], data[:4096])
     423              self.assertEqual(from_addr, server_address)
     424  
     425      def test_recvfrom_into(self):
     426          with test_utils.run_udp_echo_server() as server_address:
     427              self.loop.run_until_complete(
     428                  self._basetest_datagram_recvfrom_into(server_address))
     429  
     430      async def _basetest_datagram_sendto_blocking(self, server_address):
     431          # Sad path, sock.sendto() raises BlockingIOError
     432          # This involves patching sock.sendto() to raise BlockingIOError but
     433          # sendto() is not used by the proactor event loop
     434          data = b'\x01' * 4096
     435          with socket.socket(socket.AF_INET, socket.SOCK_DGRAM) as sock:
     436              sock.setblocking(False)
     437              mock_sock = Mock(sock)
     438              mock_sock.gettimeout = sock.gettimeout
     439              mock_sock.sendto.configure_mock(side_effect=BlockingIOError)
     440              mock_sock.fileno = sock.fileno
     441              self.loop.call_soon(
     442                  lambda: setattr(mock_sock, 'sendto', sock.sendto)
     443              )
     444              await self.loop.sock_sendto(mock_sock, data, server_address)
     445  
     446              received_data, from_addr = await self.loop.sock_recvfrom(
     447                  sock, 4096)
     448              self.assertEqual(received_data, data)
     449              self.assertEqual(from_addr, server_address)
     450  
     451      def test_sendto_blocking(self):
     452          if sys.platform == 'win32':
     453              if isinstance(self.loop, asyncio.ProactorEventLoop):
     454                  raise unittest.SkipTest('Not relevant to ProactorEventLoop')
     455  
     456          with test_utils.run_udp_echo_server() as server_address:
     457              self.loop.run_until_complete(
     458                  self._basetest_datagram_sendto_blocking(server_address))
     459  
     460      @socket_helper.skip_unless_bind_unix_socket
     461      def test_unix_sock_client_ops(self):
     462          with test_utils.run_test_unix_server() as httpd:
     463              sock = socket.socket(socket.AF_UNIX)
     464              self._basetest_sock_client_ops(httpd, sock)
     465              sock = socket.socket(socket.AF_UNIX)
     466              self._basetest_sock_recv_into(httpd, sock)
     467  
     468      def test_sock_client_fail(self):
     469          # Make sure that we will get an unused port
     470          address = None
     471          try:
     472              s = socket.socket()
     473              s.bind(('127.0.0.1', 0))
     474              address = s.getsockname()
     475          finally:
     476              s.close()
     477  
     478          sock = socket.socket()
     479          sock.setblocking(False)
     480          with self.assertRaises(ConnectionRefusedError):
     481              self.loop.run_until_complete(
     482                  self.loop.sock_connect(sock, address))
     483          sock.close()
     484  
     485      def test_sock_accept(self):
     486          listener = socket.socket()
     487          listener.setblocking(False)
     488          listener.bind(('127.0.0.1', 0))
     489          listener.listen(1)
     490          client = socket.socket()
     491          client.connect(listener.getsockname())
     492  
     493          f = self.loop.sock_accept(listener)
     494          conn, addr = self.loop.run_until_complete(f)
     495          self.assertEqual(conn.gettimeout(), 0)
     496          self.assertEqual(addr, client.getsockname())
     497          self.assertEqual(client.getpeername(), listener.getsockname())
     498          client.close()
     499          conn.close()
     500          listener.close()
     501  
     502      def test_cancel_sock_accept(self):
     503          listener = socket.socket()
     504          listener.setblocking(False)
     505          listener.bind(('127.0.0.1', 0))
     506          listener.listen(1)
     507          sockaddr = listener.getsockname()
     508          f = asyncio.wait_for(self.loop.sock_accept(listener), 0.1)
     509          with self.assertRaises(asyncio.TimeoutError):
     510              self.loop.run_until_complete(f)
     511  
     512          listener.close()
     513          client = socket.socket()
     514          client.setblocking(False)
     515          f = self.loop.sock_connect(client, sockaddr)
     516          with self.assertRaises(ConnectionRefusedError):
     517              self.loop.run_until_complete(f)
     518  
     519          client.close()
     520  
     521      def test_create_connection_sock(self):
     522          with test_utils.run_test_server() as httpd:
     523              sock = None
     524              infos = self.loop.run_until_complete(
     525                  self.loop.getaddrinfo(
     526                      *httpd.address, type=socket.SOCK_STREAM))
     527              for family, type, proto, cname, address in infos:
     528                  try:
     529                      sock = socket.socket(family=family, type=type, proto=proto)
     530                      sock.setblocking(False)
     531                      self.loop.run_until_complete(
     532                          self.loop.sock_connect(sock, address))
     533                  except BaseException:
     534                      pass
     535                  else:
     536                      break
     537              else:
     538                  self.fail('Can not create socket.')
     539  
     540              f = self.loop.create_connection(
     541                  lambda: MyProto(loop=self.loop), sock=sock)
     542              tr, pr = self.loop.run_until_complete(f)
     543              self.assertIsInstance(tr, asyncio.Transport)
     544              self.assertIsInstance(pr, asyncio.Protocol)
     545              self.loop.run_until_complete(pr.done)
     546              self.assertGreater(pr.nbytes, 0)
     547              tr.close()
     548  
     549  
     550  if sys.platform == 'win32':
     551  
     552      class ESC[4;38;5;81mSelectEventLoopTests(ESC[4;38;5;149mBaseSockTestsMixin,
     553                                 ESC[4;38;5;149mtest_utilsESC[4;38;5;149m.ESC[4;38;5;149mTestCase):
     554  
     555          def create_event_loop(self):
     556              return asyncio.SelectorEventLoop()
     557  
     558      class ESC[4;38;5;81mProactorEventLoopTests(ESC[4;38;5;149mBaseSockTestsMixin,
     559                                   ESC[4;38;5;149mtest_utilsESC[4;38;5;149m.ESC[4;38;5;149mTestCase):
     560  
     561          def create_event_loop(self):
     562              return asyncio.ProactorEventLoop()
     563  
     564  else:
     565      import selectors
     566  
     567      if hasattr(selectors, 'KqueueSelector'):
     568          class ESC[4;38;5;81mKqueueEventLoopTests(ESC[4;38;5;149mBaseSockTestsMixin,
     569                                     ESC[4;38;5;149mtest_utilsESC[4;38;5;149m.ESC[4;38;5;149mTestCase):
     570  
     571              def create_event_loop(self):
     572                  return asyncio.SelectorEventLoop(
     573                      selectors.KqueueSelector())
     574  
     575      if hasattr(selectors, 'EpollSelector'):
     576          class ESC[4;38;5;81mEPollEventLoopTests(ESC[4;38;5;149mBaseSockTestsMixin,
     577                                    ESC[4;38;5;149mtest_utilsESC[4;38;5;149m.ESC[4;38;5;149mTestCase):
     578  
     579              def create_event_loop(self):
     580                  return asyncio.SelectorEventLoop(selectors.EpollSelector())
     581  
     582      if hasattr(selectors, 'PollSelector'):
     583          class ESC[4;38;5;81mPollEventLoopTests(ESC[4;38;5;149mBaseSockTestsMixin,
     584                                   ESC[4;38;5;149mtest_utilsESC[4;38;5;149m.ESC[4;38;5;149mTestCase):
     585  
     586              def create_event_loop(self):
     587                  return asyncio.SelectorEventLoop(selectors.PollSelector())
     588  
     589      # Should always exist.
     590      class ESC[4;38;5;81mSelectEventLoopTests(ESC[4;38;5;149mBaseSockTestsMixin,
     591                                 ESC[4;38;5;149mtest_utilsESC[4;38;5;149m.ESC[4;38;5;149mTestCase):
     592  
     593          def create_event_loop(self):
     594              return asyncio.SelectorEventLoop(selectors.SelectSelector())
     595  
     596  
     597  if __name__ == '__main__':
     598      unittest.main()