(root)/
Python-3.11.7/
Lib/
test/
test_asyncio/
test_ssl.py
       1  import asyncio
       2  import asyncio.sslproto
       3  import contextlib
       4  import gc
       5  import logging
       6  import select
       7  import socket
       8  import sys
       9  import tempfile
      10  import threading
      11  import time
      12  import weakref
      13  import unittest
      14  
      15  try:
      16      import ssl
      17  except ImportError:
      18      ssl = None
      19  
      20  from test import support
      21  from test.test_asyncio import utils as test_utils
      22  
      23  
      24  MACOS = (sys.platform == 'darwin')
      25  BUF_MULTIPLIER = 1024 if not MACOS else 64
      26  
      27  
      28  def tearDownModule():
      29      asyncio.set_event_loop_policy(None)
      30  
      31  
      32  class ESC[4;38;5;81mMyBaseProto(ESC[4;38;5;149masyncioESC[4;38;5;149m.ESC[4;38;5;149mProtocol):
      33      connected = None
      34      done = None
      35  
      36      def __init__(self, loop=None):
      37          self.transport = None
      38          self.state = 'INITIAL'
      39          self.nbytes = 0
      40          if loop is not None:
      41              self.connected = asyncio.Future(loop=loop)
      42              self.done = asyncio.Future(loop=loop)
      43  
      44      def connection_made(self, transport):
      45          self.transport = transport
      46          assert self.state == 'INITIAL', self.state
      47          self.state = 'CONNECTED'
      48          if self.connected:
      49              self.connected.set_result(None)
      50  
      51      def data_received(self, data):
      52          assert self.state == 'CONNECTED', self.state
      53          self.nbytes += len(data)
      54  
      55      def eof_received(self):
      56          assert self.state == 'CONNECTED', self.state
      57          self.state = 'EOF'
      58  
      59      def connection_lost(self, exc):
      60          assert self.state in ('CONNECTED', 'EOF'), self.state
      61          self.state = 'CLOSED'
      62          if self.done:
      63              self.done.set_result(None)
      64  
      65  
      66  class ESC[4;38;5;81mMessageOutFilter(ESC[4;38;5;149mloggingESC[4;38;5;149m.ESC[4;38;5;149mFilter):
      67      def __init__(self, msg):
      68          self.msg = msg
      69  
      70      def filter(self, record):
      71          if self.msg in record.msg:
      72              return False
      73          return True
      74  
      75  
      76  @unittest.skipIf(ssl is None, 'No ssl module')
      77  class ESC[4;38;5;81mTestSSL(ESC[4;38;5;149mtest_utilsESC[4;38;5;149m.ESC[4;38;5;149mTestCase):
      78  
    # Size in bytes of the HELLO_MSG payloads sent by the start_tls and
    # memory-leak tests.
    PAYLOAD_SIZE = 1024 * 100
    # Blocking-socket timeout (seconds) that the payload tests give their
    # threaded servers.
    TIMEOUT = support.LONG_TIMEOUT
      81  
      82      def setUp(self):
      83          super().setUp()
      84          self.loop = asyncio.new_event_loop()
      85          self.set_event_loop(self.loop)
      86          self.addCleanup(self.loop.close)
      87  
      88      def tearDown(self):
      89          # just in case if we have transport close callbacks
      90          if not self.loop.is_closed():
      91              test_utils.run_briefly(self.loop)
      92  
      93          self.doCleanups()
      94          support.gc_collect()
      95          super().tearDown()
      96  
      97      def tcp_server(self, server_prog, *,
      98                     family=socket.AF_INET,
      99                     addr=None,
     100                     timeout=support.SHORT_TIMEOUT,
     101                     backlog=1,
     102                     max_clients=10):
     103  
     104          if addr is None:
     105              if family == getattr(socket, "AF_UNIX", None):
     106                  with tempfile.NamedTemporaryFile() as tmp:
     107                      addr = tmp.name
     108              else:
     109                  addr = ('127.0.0.1', 0)
     110  
     111          sock = socket.socket(family, socket.SOCK_STREAM)
     112  
     113          if timeout is None:
     114              raise RuntimeError('timeout is required')
     115          if timeout <= 0:
     116              raise RuntimeError('only blocking sockets are supported')
     117          sock.settimeout(timeout)
     118  
     119          try:
     120              sock.bind(addr)
     121              sock.listen(backlog)
     122          except OSError as ex:
     123              sock.close()
     124              raise ex
     125  
     126          return TestThreadedServer(
     127              self, sock, server_prog, timeout, max_clients)
     128  
     129      def tcp_client(self, client_prog,
     130                     family=socket.AF_INET,
     131                     timeout=support.SHORT_TIMEOUT):
     132  
     133          sock = socket.socket(family, socket.SOCK_STREAM)
     134  
     135          if timeout is None:
     136              raise RuntimeError('timeout is required')
     137          if timeout <= 0:
     138              raise RuntimeError('only blocking sockets are supported')
     139          sock.settimeout(timeout)
     140  
     141          return TestThreadedClient(
     142              self, sock, client_prog, timeout)
     143  
     144      def unix_server(self, *args, **kwargs):
     145          return self.tcp_server(*args, family=socket.AF_UNIX, **kwargs)
     146  
     147      def unix_client(self, *args, **kwargs):
     148          return self.tcp_client(*args, family=socket.AF_UNIX, **kwargs)
     149  
     150      def _create_server_ssl_context(self, certfile, keyfile=None):
     151          sslcontext = ssl.SSLContext(ssl.PROTOCOL_TLS_SERVER)
     152          sslcontext.options |= ssl.OP_NO_SSLv2
     153          sslcontext.load_cert_chain(certfile, keyfile)
     154          return sslcontext
     155  
     156      def _create_client_ssl_context(self, *, disable_verify=True):
     157          sslcontext = ssl.create_default_context()
     158          sslcontext.check_hostname = False
     159          if disable_verify:
     160              sslcontext.verify_mode = ssl.CERT_NONE
     161          return sslcontext
     162  
     163      @contextlib.contextmanager
     164      def _silence_eof_received_warning(self):
     165          # TODO This warning has to be fixed in asyncio.
     166          logger = logging.getLogger('asyncio')
     167          filter = MessageOutFilter('has no effect when using ssl')
     168          logger.addFilter(filter)
     169          try:
     170              yield
     171          finally:
     172              logger.removeFilter(filter)
     173  
     174      def _abort_socket_test(self, ex):
     175          try:
     176              self.loop.stop()
     177          finally:
     178              self.fail(ex)
     179  
     180      def new_loop(self):
     181          return asyncio.new_event_loop()
     182  
     183      def new_policy(self):
     184          return asyncio.DefaultEventLoopPolicy()
     185  
     186      async def wait_closed(self, obj):
     187          if not isinstance(obj, asyncio.StreamWriter):
     188              return
     189          try:
     190              await obj.wait_closed()
     191          except (BrokenPipeError, ConnectionError):
     192              pass
     193  
     194      def test_create_server_ssl_1(self):
     195          CNT = 0           # number of clients that were successful
     196          TOTAL_CNT = 25    # total number of clients that test will create
     197          TIMEOUT = support.LONG_TIMEOUT  # timeout for this test
     198  
     199          A_DATA = b'A' * 1024 * BUF_MULTIPLIER
     200          B_DATA = b'B' * 1024 * BUF_MULTIPLIER
     201  
     202          sslctx = self._create_server_ssl_context(
     203              test_utils.ONLYCERT, test_utils.ONLYKEY
     204          )
     205          client_sslctx = self._create_client_ssl_context()
     206  
     207          clients = []
     208  
     209          async def handle_client(reader, writer):
     210              nonlocal CNT
     211  
     212              data = await reader.readexactly(len(A_DATA))
     213              self.assertEqual(data, A_DATA)
     214              writer.write(b'OK')
     215  
     216              data = await reader.readexactly(len(B_DATA))
     217              self.assertEqual(data, B_DATA)
     218              writer.writelines([b'SP', bytearray(b'A'), memoryview(b'M')])
     219  
     220              await writer.drain()
     221              writer.close()
     222  
     223              CNT += 1
     224  
     225          async def test_client(addr):
     226              fut = asyncio.Future()
     227  
     228              def prog(sock):
     229                  try:
     230                      sock.starttls(client_sslctx)
     231                      sock.connect(addr)
     232                      sock.send(A_DATA)
     233  
     234                      data = sock.recv_all(2)
     235                      self.assertEqual(data, b'OK')
     236  
     237                      sock.send(B_DATA)
     238                      data = sock.recv_all(4)
     239                      self.assertEqual(data, b'SPAM')
     240  
     241                      sock.close()
     242  
     243                  except Exception as ex:
     244                      self.loop.call_soon_threadsafe(fut.set_exception, ex)
     245                  else:
     246                      self.loop.call_soon_threadsafe(fut.set_result, None)
     247  
     248              client = self.tcp_client(prog)
     249              client.start()
     250              clients.append(client)
     251  
     252              await fut
     253  
     254          async def start_server():
     255              extras = {}
     256              extras = dict(ssl_handshake_timeout=support.SHORT_TIMEOUT)
     257  
     258              srv = await asyncio.start_server(
     259                  handle_client,
     260                  '127.0.0.1', 0,
     261                  family=socket.AF_INET,
     262                  ssl=sslctx,
     263                  **extras)
     264  
     265              try:
     266                  srv_socks = srv.sockets
     267                  self.assertTrue(srv_socks)
     268  
     269                  addr = srv_socks[0].getsockname()
     270  
     271                  tasks = []
     272                  for _ in range(TOTAL_CNT):
     273                      tasks.append(test_client(addr))
     274  
     275                  await asyncio.wait_for(asyncio.gather(*tasks), TIMEOUT)
     276  
     277              finally:
     278                  self.loop.call_soon(srv.close)
     279                  await srv.wait_closed()
     280  
     281          with self._silence_eof_received_warning():
     282              self.loop.run_until_complete(start_server())
     283  
     284          self.assertEqual(CNT, TOTAL_CNT)
     285  
     286          for client in clients:
     287              client.stop()
     288  
     289      def test_create_connection_ssl_1(self):
     290          self.loop.set_exception_handler(None)
     291  
     292          CNT = 0
     293          TOTAL_CNT = 25
     294  
     295          A_DATA = b'A' * 1024 * BUF_MULTIPLIER
     296          B_DATA = b'B' * 1024 * BUF_MULTIPLIER
     297  
     298          sslctx = self._create_server_ssl_context(
     299              test_utils.ONLYCERT,
     300              test_utils.ONLYKEY
     301          )
     302          client_sslctx = self._create_client_ssl_context()
     303  
     304          def server(sock):
     305              sock.starttls(
     306                  sslctx,
     307                  server_side=True)
     308  
     309              data = sock.recv_all(len(A_DATA))
     310              self.assertEqual(data, A_DATA)
     311              sock.send(b'OK')
     312  
     313              data = sock.recv_all(len(B_DATA))
     314              self.assertEqual(data, B_DATA)
     315              sock.send(b'SPAM')
     316  
     317              sock.close()
     318  
     319          async def client(addr):
     320              extras = {}
     321              extras = dict(ssl_handshake_timeout=support.SHORT_TIMEOUT)
     322  
     323              reader, writer = await asyncio.open_connection(
     324                  *addr,
     325                  ssl=client_sslctx,
     326                  server_hostname='',
     327                  **extras)
     328  
     329              writer.write(A_DATA)
     330              self.assertEqual(await reader.readexactly(2), b'OK')
     331  
     332              writer.write(B_DATA)
     333              self.assertEqual(await reader.readexactly(4), b'SPAM')
     334  
     335              nonlocal CNT
     336              CNT += 1
     337  
     338              writer.close()
     339              await self.wait_closed(writer)
     340  
     341          async def client_sock(addr):
     342              sock = socket.socket()
     343              sock.connect(addr)
     344              reader, writer = await asyncio.open_connection(
     345                  sock=sock,
     346                  ssl=client_sslctx,
     347                  server_hostname='')
     348  
     349              writer.write(A_DATA)
     350              self.assertEqual(await reader.readexactly(2), b'OK')
     351  
     352              writer.write(B_DATA)
     353              self.assertEqual(await reader.readexactly(4), b'SPAM')
     354  
     355              nonlocal CNT
     356              CNT += 1
     357  
     358              writer.close()
     359              await self.wait_closed(writer)
     360              sock.close()
     361  
     362          def run(coro):
     363              nonlocal CNT
     364              CNT = 0
     365  
     366              async def _gather(*tasks):
     367                  # trampoline
     368                  return await asyncio.gather(*tasks)
     369  
     370              with self.tcp_server(server,
     371                                   max_clients=TOTAL_CNT,
     372                                   backlog=TOTAL_CNT) as srv:
     373                  tasks = []
     374                  for _ in range(TOTAL_CNT):
     375                      tasks.append(coro(srv.addr))
     376  
     377                  self.loop.run_until_complete(_gather(*tasks))
     378  
     379              self.assertEqual(CNT, TOTAL_CNT)
     380  
     381          with self._silence_eof_received_warning():
     382              run(client)
     383  
     384          with self._silence_eof_received_warning():
     385              run(client_sock)
     386  
     387      def test_create_connection_ssl_slow_handshake(self):
     388          client_sslctx = self._create_client_ssl_context()
     389  
     390          # silence error logger
     391          self.loop.set_exception_handler(lambda *args: None)
     392  
     393          def server(sock):
     394              try:
     395                  sock.recv_all(1024 * 1024)
     396              except ConnectionAbortedError:
     397                  pass
     398              finally:
     399                  sock.close()
     400  
     401          async def client(addr):
     402              reader, writer = await asyncio.open_connection(
     403                  *addr,
     404                  ssl=client_sslctx,
     405                  server_hostname='',
     406                  ssl_handshake_timeout=1.0)
     407              writer.close()
     408              await self.wait_closed(writer)
     409  
     410          with self.tcp_server(server,
     411                               max_clients=1,
     412                               backlog=1) as srv:
     413  
     414              with self.assertRaisesRegex(
     415                      ConnectionAbortedError,
     416                      r'SSL handshake.*is taking longer'):
     417  
     418                  self.loop.run_until_complete(client(srv.addr))
     419  
    def test_create_connection_ssl_failed_certificate(self):
        """A client that verifies certificates rejects the test server's cert.

        The client context keeps certificate verification enabled
        (disable_verify=False), so open_connection() is expected to raise
        ssl.SSLCertVerificationError.  Presumably ONLYCERT is not trusted
        by the default CA store; confirm against test_utils.
        """
        # silence error logger
        self.loop.set_exception_handler(lambda *args: None)

        sslctx = self._create_server_ssl_context(
            test_utils.ONLYCERT,
            test_utils.ONLYKEY
        )
        client_sslctx = self._create_client_ssl_context(disable_verify=False)

        def server(sock):
            # Best-effort server: the handshake is expected to fail, so TLS
            # and socket errors are ignored and the socket is always closed.
            try:
                sock.starttls(
                    sslctx,
                    server_side=True)
                # NOTE(review): connect() with no address would raise
                # TypeError (not caught below) if `sock` delegates to a plain
                # socket; it is presumably never reached because starttls()
                # raises first when the client rejects the certificate.
                # Confirm.
                sock.connect()
            except (ssl.SSLError, OSError):
                pass
            finally:
                sock.close()

        async def client(addr):
            reader, writer = await asyncio.open_connection(
                *addr,
                ssl=client_sslctx,
                server_hostname='',
                ssl_handshake_timeout=support.SHORT_TIMEOUT)
            writer.close()
            await self.wait_closed(writer)

        with self.tcp_server(server,
                             max_clients=1,
                             backlog=1) as srv:

            with self.assertRaises(ssl.SSLCertVerificationError):
                self.loop.run_until_complete(client(srv.addr))
     456  
    def test_ssl_handshake_timeout(self):
        # bpo-29970: Check that a connection is aborted if handshake is not
        # completed in timeout period, instead of remaining open indefinitely
        client_sslctx = test_utils.simple_client_sslcontext()

        # Capture, rather than log, every error context the loop reports;
        # the test asserts at the end that none were reported.
        messages = []
        self.loop.set_exception_handler(lambda loop, ctx: messages.append(ctx))

        server_side_aborted = False

        def server(sock):
            # Never reply; record whether the client's abort reaches us as
            # ConnectionAbortedError.
            nonlocal server_side_aborted
            try:
                sock.recv_all(1024 * 1024)
            except ConnectionAbortedError:
                server_side_aborted = True
            finally:
                sock.close()

        async def client(addr):
            # The 0.5 s wait_for() deadline fires long before the 10 s
            # ssl_handshake_timeout, so cancellation is the path exercised.
            await asyncio.wait_for(
                self.loop.create_connection(
                    asyncio.Protocol,
                    *addr,
                    ssl=client_sslctx,
                    server_hostname='',
                    ssl_handshake_timeout=10.0),
                0.5)

        with self.tcp_server(server,
                             max_clients=1,
                             backlog=1) as srv:

            with self.assertRaises(asyncio.TimeoutError):
                self.loop.run_until_complete(client(srv.addr))

        self.assertTrue(server_side_aborted)

        # Python issue #23197: cancelling a handshake must not raise an
        # exception or log an error, even if the handshake failed
        self.assertEqual(messages, [])
     499  
     500      def test_ssl_handshake_connection_lost(self):
     501          # #246: make sure that no connection_lost() is called before
     502          # connection_made() is called first
     503  
     504          client_sslctx = test_utils.simple_client_sslcontext()
     505  
     506          # silence error logger
     507          self.loop.set_exception_handler(lambda loop, ctx: None)
     508  
     509          connection_made_called = False
     510          connection_lost_called = False
     511  
     512          def server(sock):
     513              sock.recv(1024)
     514              # break the connection during handshake
     515              sock.close()
     516  
     517          class ESC[4;38;5;81mClientProto(ESC[4;38;5;149masyncioESC[4;38;5;149m.ESC[4;38;5;149mProtocol):
     518              def connection_made(self, transport):
     519                  nonlocal connection_made_called
     520                  connection_made_called = True
     521  
     522              def connection_lost(self, exc):
     523                  nonlocal connection_lost_called
     524                  connection_lost_called = True
     525  
     526          async def client(addr):
     527              await self.loop.create_connection(
     528                  ClientProto,
     529                  *addr,
     530                  ssl=client_sslctx,
     531                  server_hostname=''),
     532  
     533          with self.tcp_server(server,
     534                               max_clients=1,
     535                               backlog=1) as srv:
     536  
     537              with self.assertRaises(ConnectionResetError):
     538                  self.loop.run_until_complete(client(srv.addr))
     539  
     540          if connection_lost_called:
     541              if connection_made_called:
     542                  self.fail("unexpected call to connection_lost()")
     543              else:
     544                  self.fail("unexpected call to connection_lost() without"
     545                            "calling connection_made()")
     546          elif connection_made_called:
     547              self.fail("unexpected call to connection_made()")
     548  
     549      def test_ssl_connect_accepted_socket(self):
     550          proto = ssl.PROTOCOL_TLS_SERVER
     551          server_context = ssl.SSLContext(proto)
     552          server_context.load_cert_chain(test_utils.ONLYCERT, test_utils.ONLYKEY)
     553          if hasattr(server_context, 'check_hostname'):
     554              server_context.check_hostname = False
     555          server_context.verify_mode = ssl.CERT_NONE
     556  
     557          client_context = ssl.SSLContext(proto)
     558          if hasattr(server_context, 'check_hostname'):
     559              client_context.check_hostname = False
     560          client_context.verify_mode = ssl.CERT_NONE
     561  
     562      def test_connect_accepted_socket(self, server_ssl=None, client_ssl=None):
     563          loop = self.loop
     564  
     565          class ESC[4;38;5;81mMyProto(ESC[4;38;5;149mMyBaseProto):
     566  
     567              def connection_lost(self, exc):
     568                  super().connection_lost(exc)
     569                  loop.call_soon(loop.stop)
     570  
     571              def data_received(self, data):
     572                  super().data_received(data)
     573                  self.transport.write(expected_response)
     574  
     575          lsock = socket.socket(socket.AF_INET)
     576          lsock.bind(('127.0.0.1', 0))
     577          lsock.listen(1)
     578          addr = lsock.getsockname()
     579  
     580          message = b'test data'
     581          response = None
     582          expected_response = b'roger'
     583  
     584          def client():
     585              nonlocal response
     586              try:
     587                  csock = socket.socket(socket.AF_INET)
     588                  if client_ssl is not None:
     589                      csock = client_ssl.wrap_socket(csock)
     590                  csock.connect(addr)
     591                  csock.sendall(message)
     592                  response = csock.recv(99)
     593                  csock.close()
     594              except Exception as exc:
     595                  print(
     596                      "Failure in client thread in test_connect_accepted_socket",
     597                      exc)
     598  
     599          thread = threading.Thread(target=client, daemon=True)
     600          thread.start()
     601  
     602          conn, _ = lsock.accept()
     603          proto = MyProto(loop=loop)
     604          proto.loop = loop
     605  
     606          extras = {}
     607          if server_ssl:
     608              extras = dict(ssl_handshake_timeout=support.SHORT_TIMEOUT)
     609  
     610          f = loop.create_task(
     611              loop.connect_accepted_socket(
     612                  (lambda: proto), conn, ssl=server_ssl,
     613                  **extras))
     614          loop.run_forever()
     615          conn.close()
     616          lsock.close()
     617  
     618          thread.join(1)
     619          self.assertFalse(thread.is_alive())
     620          self.assertEqual(proto.state, 'CLOSED')
     621          self.assertEqual(proto.nbytes, len(message))
     622          self.assertEqual(response, expected_response)
     623          tr, _ = f.result()
     624  
     625          if server_ssl:
     626              self.assertIn('SSL', tr.__class__.__name__)
     627  
     628          tr.close()
     629          # let it close
     630          self.loop.run_until_complete(asyncio.sleep(0.1))
     631  
    def test_start_tls_client_corrupted_ssl(self):
        """Plaintext injected into an established TLS stream raises SSLError.

        After the handshake the server writes raw bytes onto the connection
        beneath the TLS layer.  The client's next read must fail with
        ssl.SSLError.  An SSLError while waiting for the writer to close
        afterwards is tolerated.
        """
        self.loop.set_exception_handler(lambda loop, ctx: None)

        sslctx = test_utils.simple_server_sslcontext()
        client_sslctx = test_utils.simple_client_sslcontext()

        def server(sock):
            # Duplicate taken before starttls(), so it can later write bytes
            # directly onto the connection, bypassing TLS.
            orig_sock = sock.dup()
            try:
                sock.starttls(
                    sslctx,
                    server_side=True)
                sock.sendall(b'A\n')
                sock.recv_all(1)
                # These raw bytes corrupt the client's TLS record stream.
                orig_sock.send(b'please corrupt the SSL connection')
            except ssl.SSLError:
                pass
            finally:
                sock.close()
                orig_sock.close()

        async def client(addr):
            reader, writer = await asyncio.open_connection(
                *addr,
                ssl=client_sslctx,
                server_hostname='')

            self.assertEqual(await reader.readline(), b'A\n')
            writer.write(b'B')
            with self.assertRaises(ssl.SSLError):
                await reader.readline()
            writer.close()
            try:
                await self.wait_closed(writer)
            except ssl.SSLError:
                pass
            return 'OK'

        with self.tcp_server(server,
                             max_clients=1,
                             backlog=1) as srv:

            res = self.loop.run_until_complete(client(srv.addr))

        self.assertEqual(res, 'OK')
     677  
     678      def test_start_tls_client_reg_proto_1(self):
     679          HELLO_MSG = b'1' * self.PAYLOAD_SIZE
     680  
     681          server_context = test_utils.simple_server_sslcontext()
     682          client_context = test_utils.simple_client_sslcontext()
     683  
     684          def serve(sock):
     685              sock.settimeout(self.TIMEOUT)
     686  
     687              data = sock.recv_all(len(HELLO_MSG))
     688              self.assertEqual(len(data), len(HELLO_MSG))
     689  
     690              sock.starttls(server_context, server_side=True)
     691  
     692              sock.sendall(b'O')
     693              data = sock.recv_all(len(HELLO_MSG))
     694              self.assertEqual(len(data), len(HELLO_MSG))
     695  
     696              sock.unwrap()
     697              sock.close()
     698  
     699          class ESC[4;38;5;81mClientProto(ESC[4;38;5;149masyncioESC[4;38;5;149m.ESC[4;38;5;149mProtocol):
     700              def __init__(self, on_data, on_eof):
     701                  self.on_data = on_data
     702                  self.on_eof = on_eof
     703                  self.con_made_cnt = 0
     704  
     705              def connection_made(proto, tr):
     706                  proto.con_made_cnt += 1
     707                  # Ensure connection_made gets called only once.
     708                  self.assertEqual(proto.con_made_cnt, 1)
     709  
     710              def data_received(self, data):
     711                  self.on_data.set_result(data)
     712  
     713              def eof_received(self):
     714                  self.on_eof.set_result(True)
     715  
     716          async def client(addr):
     717              await asyncio.sleep(0.5)
     718  
     719              on_data = self.loop.create_future()
     720              on_eof = self.loop.create_future()
     721  
     722              tr, proto = await self.loop.create_connection(
     723                  lambda: ClientProto(on_data, on_eof), *addr)
     724  
     725              tr.write(HELLO_MSG)
     726              new_tr = await self.loop.start_tls(tr, proto, client_context)
     727  
     728              self.assertEqual(await on_data, b'O')
     729              new_tr.write(HELLO_MSG)
     730              await on_eof
     731  
     732              new_tr.close()
     733  
     734          with self.tcp_server(serve, timeout=self.TIMEOUT) as srv:
     735              self.loop.run_until_complete(
     736                  asyncio.wait_for(client(srv.addr),
     737                                   timeout=support.SHORT_TIMEOUT))
     738  
     739      def test_create_connection_memory_leak(self):
     740          HELLO_MSG = b'1' * self.PAYLOAD_SIZE
     741  
     742          server_context = self._create_server_ssl_context(
     743              test_utils.ONLYCERT, test_utils.ONLYKEY)
     744          client_context = self._create_client_ssl_context()
     745  
     746          def serve(sock):
     747              sock.settimeout(self.TIMEOUT)
     748  
     749              sock.starttls(server_context, server_side=True)
     750  
     751              sock.sendall(b'O')
     752              data = sock.recv_all(len(HELLO_MSG))
     753              self.assertEqual(len(data), len(HELLO_MSG))
     754  
     755              sock.unwrap()
     756              sock.close()
     757  
     758          class ESC[4;38;5;81mClientProto(ESC[4;38;5;149masyncioESC[4;38;5;149m.ESC[4;38;5;149mProtocol):
     759              def __init__(self, on_data, on_eof):
     760                  self.on_data = on_data
     761                  self.on_eof = on_eof
     762                  self.con_made_cnt = 0
     763  
     764              def connection_made(proto, tr):
     765                  # XXX: We assume user stores the transport in protocol
     766                  proto.tr = tr
     767                  proto.con_made_cnt += 1
     768                  # Ensure connection_made gets called only once.
     769                  self.assertEqual(proto.con_made_cnt, 1)
     770  
     771              def data_received(self, data):
     772                  self.on_data.set_result(data)
     773  
     774              def eof_received(self):
     775                  self.on_eof.set_result(True)
     776  
     777          async def client(addr):
     778              await asyncio.sleep(0.5)
     779  
     780              on_data = self.loop.create_future()
     781              on_eof = self.loop.create_future()
     782  
     783              tr, proto = await self.loop.create_connection(
     784                  lambda: ClientProto(on_data, on_eof), *addr,
     785                  ssl=client_context)
     786  
     787              self.assertEqual(await on_data, b'O')
     788              tr.write(HELLO_MSG)
     789              await on_eof
     790  
     791              tr.close()
     792  
     793          with self.tcp_server(serve, timeout=self.TIMEOUT) as srv:
     794              self.loop.run_until_complete(
     795                  asyncio.wait_for(client(srv.addr),
     796                                   timeout=support.SHORT_TIMEOUT))
     797  
     798          # No garbage is left for SSL client from loop.create_connection, even
     799          # if user stores the SSLTransport in corresponding protocol instance
     800          client_context = weakref.ref(client_context)
     801          self.assertIsNone(client_context())
     802  
     803      def test_start_tls_client_buf_proto_1(self):
     804          HELLO_MSG = b'1' * self.PAYLOAD_SIZE
     805  
     806          server_context = test_utils.simple_server_sslcontext()
     807          client_context = test_utils.simple_client_sslcontext()
     808  
     809          client_con_made_calls = 0
     810  
     811          def serve(sock):
     812              sock.settimeout(self.TIMEOUT)
     813  
     814              data = sock.recv_all(len(HELLO_MSG))
     815              self.assertEqual(len(data), len(HELLO_MSG))
     816  
     817              sock.starttls(server_context, server_side=True)
     818  
     819              sock.sendall(b'O')
     820              data = sock.recv_all(len(HELLO_MSG))
     821              self.assertEqual(len(data), len(HELLO_MSG))
     822  
     823              sock.sendall(b'2')
     824              data = sock.recv_all(len(HELLO_MSG))
     825              self.assertEqual(len(data), len(HELLO_MSG))
     826  
     827              sock.unwrap()
     828              sock.close()
     829  
     830          class ESC[4;38;5;81mClientProtoFirst(ESC[4;38;5;149masyncioESC[4;38;5;149m.ESC[4;38;5;149mBufferedProtocol):
     831              def __init__(self, on_data):
     832                  self.on_data = on_data
     833                  self.buf = bytearray(1)
     834  
     835              def connection_made(self, tr):
     836                  nonlocal client_con_made_calls
     837                  client_con_made_calls += 1
     838  
     839              def get_buffer(self, sizehint):
     840                  return self.buf
     841  
     842              def buffer_updated(self, nsize):
     843                  assert nsize == 1
     844                  self.on_data.set_result(bytes(self.buf[:nsize]))
     845  
     846              def eof_received(self):
     847                  pass
     848  
     849          class ESC[4;38;5;81mClientProtoSecond(ESC[4;38;5;149masyncioESC[4;38;5;149m.ESC[4;38;5;149mProtocol):
     850              def __init__(self, on_data, on_eof):
     851                  self.on_data = on_data
     852                  self.on_eof = on_eof
     853                  self.con_made_cnt = 0
     854  
     855              def connection_made(self, tr):
     856                  nonlocal client_con_made_calls
     857                  client_con_made_calls += 1
     858  
     859              def data_received(self, data):
     860                  self.on_data.set_result(data)
     861  
     862              def eof_received(self):
     863                  self.on_eof.set_result(True)
     864  
     865          async def client(addr):
     866              await asyncio.sleep(0.5)
     867  
     868              on_data1 = self.loop.create_future()
     869              on_data2 = self.loop.create_future()
     870              on_eof = self.loop.create_future()
     871  
     872              tr, proto = await self.loop.create_connection(
     873                  lambda: ClientProtoFirst(on_data1), *addr)
     874  
     875              tr.write(HELLO_MSG)
     876              new_tr = await self.loop.start_tls(tr, proto, client_context)
     877  
     878              self.assertEqual(await on_data1, b'O')
     879              new_tr.write(HELLO_MSG)
     880  
     881              new_tr.set_protocol(ClientProtoSecond(on_data2, on_eof))
     882              self.assertEqual(await on_data2, b'2')
     883              new_tr.write(HELLO_MSG)
     884              await on_eof
     885  
     886              new_tr.close()
     887  
     888              # connection_made() should be called only once -- when
     889              # we establish connection for the first time. Start TLS
     890              # doesn't call connection_made() on application protocols.
     891              self.assertEqual(client_con_made_calls, 1)
     892  
     893          with self.tcp_server(serve, timeout=self.TIMEOUT) as srv:
     894              self.loop.run_until_complete(
     895                  asyncio.wait_for(client(srv.addr),
     896                                   timeout=self.TIMEOUT))
     897  
     898      def test_start_tls_slow_client_cancel(self):
     899          HELLO_MSG = b'1' * self.PAYLOAD_SIZE
     900  
     901          client_context = test_utils.simple_client_sslcontext()
     902          server_waits_on_handshake = self.loop.create_future()
     903  
     904          def serve(sock):
     905              sock.settimeout(self.TIMEOUT)
     906  
     907              data = sock.recv_all(len(HELLO_MSG))
     908              self.assertEqual(len(data), len(HELLO_MSG))
     909  
     910              try:
     911                  self.loop.call_soon_threadsafe(
     912                      server_waits_on_handshake.set_result, None)
     913                  data = sock.recv_all(1024 * 1024)
     914              except ConnectionAbortedError:
     915                  pass
     916              finally:
     917                  sock.close()
     918  
     919          class ESC[4;38;5;81mClientProto(ESC[4;38;5;149masyncioESC[4;38;5;149m.ESC[4;38;5;149mProtocol):
     920              def __init__(self, on_data, on_eof):
     921                  self.on_data = on_data
     922                  self.on_eof = on_eof
     923                  self.con_made_cnt = 0
     924  
     925              def connection_made(proto, tr):
     926                  proto.con_made_cnt += 1
     927                  # Ensure connection_made gets called only once.
     928                  self.assertEqual(proto.con_made_cnt, 1)
     929  
     930              def data_received(self, data):
     931                  self.on_data.set_result(data)
     932  
     933              def eof_received(self):
     934                  self.on_eof.set_result(True)
     935  
     936          async def client(addr):
     937              await asyncio.sleep(0.5)
     938  
     939              on_data = self.loop.create_future()
     940              on_eof = self.loop.create_future()
     941  
     942              tr, proto = await self.loop.create_connection(
     943                  lambda: ClientProto(on_data, on_eof), *addr)
     944  
     945              tr.write(HELLO_MSG)
     946  
     947              await server_waits_on_handshake
     948  
     949              with self.assertRaises(asyncio.TimeoutError):
     950                  await asyncio.wait_for(
     951                      self.loop.start_tls(tr, proto, client_context),
     952                      0.5)
     953  
     954          with self.tcp_server(serve, timeout=self.TIMEOUT) as srv:
     955              self.loop.run_until_complete(
     956                  asyncio.wait_for(client(srv.addr),
     957                                   timeout=support.SHORT_TIMEOUT))
     958  
     959      def test_start_tls_server_1(self):
     960          HELLO_MSG = b'1' * self.PAYLOAD_SIZE
     961  
     962          server_context = test_utils.simple_server_sslcontext()
     963          client_context = test_utils.simple_client_sslcontext()
     964  
     965          def client(sock, addr):
     966              sock.settimeout(self.TIMEOUT)
     967  
     968              sock.connect(addr)
     969              data = sock.recv_all(len(HELLO_MSG))
     970              self.assertEqual(len(data), len(HELLO_MSG))
     971  
     972              sock.starttls(client_context)
     973              sock.sendall(HELLO_MSG)
     974  
     975              sock.unwrap()
     976              sock.close()
     977  
     978          class ESC[4;38;5;81mServerProto(ESC[4;38;5;149masyncioESC[4;38;5;149m.ESC[4;38;5;149mProtocol):
     979              def __init__(self, on_con, on_eof, on_con_lost):
     980                  self.on_con = on_con
     981                  self.on_eof = on_eof
     982                  self.on_con_lost = on_con_lost
     983                  self.data = b''
     984  
     985              def connection_made(self, tr):
     986                  self.on_con.set_result(tr)
     987  
     988              def data_received(self, data):
     989                  self.data += data
     990  
     991              def eof_received(self):
     992                  self.on_eof.set_result(1)
     993  
     994              def connection_lost(self, exc):
     995                  if exc is None:
     996                      self.on_con_lost.set_result(None)
     997                  else:
     998                      self.on_con_lost.set_exception(exc)
     999  
    1000          async def main(proto, on_con, on_eof, on_con_lost):
    1001              tr = await on_con
    1002              tr.write(HELLO_MSG)
    1003  
    1004              self.assertEqual(proto.data, b'')
    1005  
    1006              new_tr = await self.loop.start_tls(
    1007                  tr, proto, server_context,
    1008                  server_side=True,
    1009                  ssl_handshake_timeout=self.TIMEOUT)
    1010  
    1011              await on_eof
    1012              await on_con_lost
    1013              self.assertEqual(proto.data, HELLO_MSG)
    1014              new_tr.close()
    1015  
    1016          async def run_main():
    1017              on_con = self.loop.create_future()
    1018              on_eof = self.loop.create_future()
    1019              on_con_lost = self.loop.create_future()
    1020              proto = ServerProto(on_con, on_eof, on_con_lost)
    1021  
    1022              server = await self.loop.create_server(
    1023                  lambda: proto, '127.0.0.1', 0)
    1024              addr = server.sockets[0].getsockname()
    1025  
    1026              with self.tcp_client(lambda sock: client(sock, addr),
    1027                                   timeout=self.TIMEOUT):
    1028                  await asyncio.wait_for(
    1029                      main(proto, on_con, on_eof, on_con_lost),
    1030                      timeout=self.TIMEOUT)
    1031  
    1032              server.close()
    1033              await server.wait_closed()
    1034  
    1035          self.loop.run_until_complete(run_main())
    1036  
    def test_create_server_ssl_over_ssl(self):
        # TLS tunnelled inside TLS: the server accepts with sslctx_1 and each
        # protocol immediately runs start_tls() with sslctx_2 on top of it.
        # The clients drive the inner TLS layer by hand through memory BIOs.
        CNT = 0           # number of clients that were successful
        TOTAL_CNT = 25    # total number of clients that test will create
        TIMEOUT = support.LONG_TIMEOUT  # timeout for this test

        A_DATA = b'A' * 1024 * BUF_MULTIPLIER
        B_DATA = b'B' * 1024 * BUF_MULTIPLIER

        sslctx_1 = self._create_server_ssl_context(
            test_utils.ONLYCERT, test_utils.ONLYKEY)
        client_sslctx_1 = self._create_client_ssl_context()
        sslctx_2 = self._create_server_ssl_context(
            test_utils.ONLYCERT, test_utils.ONLYKEY)
        client_sslctx_2 = self._create_client_ssl_context()

        clients = []

        async def handle_client(reader, writer):
            nonlocal CNT

            data = await reader.readexactly(len(A_DATA))
            self.assertEqual(data, A_DATA)
            writer.write(b'OK')

            data = await reader.readexactly(len(B_DATA))
            self.assertEqual(data, B_DATA)
            writer.writelines([b'SP', bytearray(b'A'), memoryview(b'M')])

            await writer.drain()
            writer.close()

            CNT += 1

        class ServerProtocol(asyncio.StreamReaderProtocol):
            def connection_made(self, transport):
                # Defer the stream setup until the inner TLS handshake is
                # done; reading is paused until start_tls() takes over.
                super_ = super()
                transport.pause_reading()
                fut = self._loop.create_task(self._loop.start_tls(
                    transport, self, sslctx_2, server_side=True))

                def cb(_):
                    try:
                        tr = fut.result()
                    except Exception as ex:
                        super_.connection_lost(ex)
                    else:
                        super_.connection_made(tr)
                fut.add_done_callback(cb)

        def server_protocol_factory():
            reader = asyncio.StreamReader()
            protocol = ServerProtocol(reader, handle_client)
            return protocol

        async def test_client(addr):
            fut = asyncio.Future()

            def prog(sock):
                try:
                    sock.connect(addr)
                    sock.starttls(client_sslctx_1)

                    # because wrap_socket() doesn't work correctly on
                    # SSLSocket, we have to do the 2nd level SSL manually
                    incoming = ssl.MemoryBIO()
                    outgoing = ssl.MemoryBIO()
                    sslobj = client_sslctx_2.wrap_bio(incoming, outgoing)

                    def do(func, *args):
                        # Retry ``func`` while pumping BIO data over the
                        # outer TLS socket, then flush any pending output.
                        while True:
                            try:
                                rv = func(*args)
                                break
                            except ssl.SSLWantReadError:
                                if outgoing.pending:
                                    sock.send(outgoing.read())
                                incoming.write(sock.recv(65536))
                        if outgoing.pending:
                            sock.send(outgoing.read())
                        return rv

                    do(sslobj.do_handshake)

                    do(sslobj.write, A_DATA)
                    data = do(sslobj.read, 2)
                    self.assertEqual(data, b'OK')

                    do(sslobj.write, B_DATA)
                    data = b''
                    while True:
                        chunk = do(sslobj.read, 4)
                        if not chunk:
                            break
                        data += chunk
                    self.assertEqual(data, b'SPAM')

                    do(sslobj.unwrap)
                    sock.close()

                except Exception as ex:
                    self.loop.call_soon_threadsafe(fut.set_exception, ex)
                    sock.close()
                else:
                    self.loop.call_soon_threadsafe(fut.set_result, None)

            client = self.tcp_client(prog)
            client.start()
            clients.append(client)

            await fut

        async def start_server():
            srv = await self.loop.create_server(
                server_protocol_factory,
                '127.0.0.1', 0,
                family=socket.AF_INET,
                ssl=sslctx_1)

            try:
                srv_socks = srv.sockets
                self.assertTrue(srv_socks)

                addr = srv_socks[0].getsockname()

                tasks = []
                for _ in range(TOTAL_CNT):
                    tasks.append(test_client(addr))

                await asyncio.wait_for(asyncio.gather(*tasks), TIMEOUT)

            finally:
                self.loop.call_soon(srv.close)
                await srv.wait_closed()

        try:
            with self._silence_eof_received_warning():
                self.loop.run_until_complete(start_server())

            self.assertEqual(CNT, TOTAL_CNT)
        finally:
            # Join the client threads even when the run or the assertion
            # fails, so they do not outlive the test.
            for client in clients:
                client.stop()
    1181  
    def test_shutdown_cleanly(self):
        # TOTAL_CNT concurrent TLS clients each send A_DATA, read b'OK' and
        # then observe a clean EOF after the server's TLS unwrap().
        CNT = 0
        TOTAL_CNT = 25

        A_DATA = b'A' * 1024 * BUF_MULTIPLIER

        sslctx = self._create_server_ssl_context(
            test_utils.ONLYCERT, test_utils.ONLYKEY)
        client_sslctx = self._create_client_ssl_context()

        def server(sock):
            sock.starttls(
                sslctx,
                server_side=True)

            data = sock.recv_all(len(A_DATA))
            self.assertEqual(data, A_DATA)
            sock.sendall(b'OK')

            sock.unwrap()

            sock.close()

        async def client(addr):
            nonlocal CNT

            reader, writer = await asyncio.open_connection(
                *addr,
                ssl=client_sslctx,
                server_hostname='',
                ssl_handshake_timeout=support.SHORT_TIMEOUT)

            writer.write(A_DATA)
            self.assertEqual(await reader.readexactly(2), b'OK')

            # The server's unwrap() must surface as a clean EOF.
            self.assertEqual(await reader.read(), b'')

            CNT += 1

            writer.close()
            await self.wait_closed(writer)

        def run(coro):
            nonlocal CNT
            CNT = 0

            # gather() must be created inside the running loop.
            async def _gather(*tasks):
                return await asyncio.gather(*tasks)

            with self.tcp_server(server,
                                 max_clients=TOTAL_CNT,
                                 backlog=TOTAL_CNT) as srv:
                tasks = []
                for _ in range(TOTAL_CNT):
                    tasks.append(coro(srv.addr))

                self.loop.run_until_complete(
                    _gather(*tasks))

            self.assertEqual(CNT, TOTAL_CNT)

        with self._silence_eof_received_warning():
            run(client)
    1247  
    def test_flush_before_shutdown(self):
        # The client queues a large payload while the SSL protocol's writing
        # is paused, closes the writer, and only then resumes writing. The
        # server must still receive every queued byte before the shutdown.
        chunk_len = 1024 * 128
        n_chunks = 32
        expected_total = chunk_len * n_chunks

        sslctx = self._create_server_ssl_context(
            test_utils.ONLYCERT, test_utils.ONLYKEY)
        client_sslctx = self._create_client_ssl_context()

        # Created by client(); resolved from the server thread by the
        # wrapper that reporting() builds.
        done = None

        def server(sock):
            sock.starttls(sslctx, server_side=True)
            self.assertEqual(sock.recv_all(4), b'ping')
            sock.send(b'pong')
            time.sleep(0.5)  # hopefully stuck the TCP buffer
            received = sock.recv_all(expected_total)
            self.assertEqual(len(received), expected_total)
            sock.close()

        def reporting(server_func):
            # Forward the outcome of ``server_func`` to ``done`` on the
            # event loop thread.
            def wrapper(sock):
                try:
                    server_func(sock)
                except Exception as exc:
                    self.loop.call_soon_threadsafe(done.set_exception, exc)
                    return
                self.loop.call_soon_threadsafe(done.set_result, None)
            return wrapper

        async def client(addr):
            nonlocal done
            done = self.loop.create_future()
            reader, writer = await asyncio.open_connection(
                *addr,
                ssl=client_sslctx,
                server_hostname='')
            ssl_proto = writer.transport._ssl_protocol
            writer.write(b'ping')
            self.assertEqual(await reader.readexactly(4), b'pong')

            ssl_proto.pause_writing()
            payload = b'x' * chunk_len
            for _ in range(n_chunks):
                writer.write(payload)

            writer.close()
            ssl_proto.resume_writing()

            await self.wait_closed(writer)
            with contextlib.suppress(ConnectionResetError):
                self.assertEqual(await reader.read(), b'')
            await done

        with self.tcp_server(reporting(server)) as srv:
            self.loop.run_until_complete(client(srv.addr))
    1306  
    def test_remote_shutdown_receives_trailing_data(self):
        # The remote side starts closing while the client still has a large
        # write backlog; the server must nonetheless receive all of it.
        # Runs twice: against ``server``, which sends a TLS close_notify, and
        # against ``eof_server``, which only half-closes the TCP socket.
        CHUNK = 1024 * 128
        SIZE = 32

        sslctx = self._create_server_ssl_context(
            test_utils.ONLYCERT,
            test_utils.ONLYKEY
        )
        client_sslctx = self._create_client_ssl_context()
        # Created by client(); resolved from the server thread by run().
        future = None

        def server(sock):
            # TLS is driven by hand through memory BIOs so the test controls
            # exactly when close_notify goes out relative to reading data.
            incoming = ssl.MemoryBIO()
            outgoing = ssl.MemoryBIO()
            sslobj = sslctx.wrap_bio(incoming, outgoing, server_side=True)

            # Handshake: flush pending TLS output and feed socket input until
            # do_handshake() stops asking for more data.
            # NOTE(review): if recv() returns b'' (peer closed), this loop and
            # the read loops below would spin forever; consider failing on EOF.
            while True:
                try:
                    sslobj.do_handshake()
                except ssl.SSLWantReadError:
                    if outgoing.pending:
                        sock.send(outgoing.read())
                    incoming.write(sock.recv(16384))
                else:
                    if outgoing.pending:
                        sock.send(outgoing.read())
                    break

            # Read the client's 4-byte greeting.
            while True:
                try:
                    data = sslobj.read(4)
                except ssl.SSLWantReadError:
                    incoming.write(sock.recv(16384))
                else:
                    break

            self.assertEqual(data, b'ping')
            sslobj.write(b'pong')
            sock.send(outgoing.read())

            time.sleep(0.2)  # wait for the peer to fill its backlog

            # send close_notify but don't wait for response
            with self.assertRaises(ssl.SSLWantReadError):
                sslobj.unwrap()
            sock.send(outgoing.read())

            # should receive all data; SSLZeroReturnError marks the peer's
            # clean TLS close, i.e. the end of the trailing data.
            data_len = 0
            while True:
                try:
                    chunk = len(sslobj.read(16384))
                    data_len += chunk
                except ssl.SSLWantReadError:
                    incoming.write(sock.recv(16384))
                except ssl.SSLZeroReturnError:
                    break

            self.assertEqual(data_len, CHUNK * SIZE)

            # verify that close_notify is received
            sslobj.unwrap()

            sock.close()

        def eof_server(sock):
            # Same exchange, but the server signals shutdown with a plain TCP
            # half-close instead of a TLS close_notify.
            sock.starttls(sslctx, server_side=True)
            self.assertEqual(sock.recv_all(4), b'ping')
            sock.send(b'pong')

            time.sleep(0.2)  # wait for the peer to fill its backlog

            # send EOF
            sock.shutdown(socket.SHUT_WR)

            # should receive all data
            data = sock.recv_all(CHUNK * SIZE)
            self.assertEqual(len(data), CHUNK * SIZE)

            sock.close()

        async def client(addr):
            nonlocal future
            future = self.loop.create_future()

            reader, writer = await asyncio.open_connection(
                *addr,
                ssl=client_sslctx,
                server_hostname='')
            writer.write(b'ping')
            data = await reader.readexactly(4)
            self.assertEqual(data, b'pong')

            # fill write backlog in a hacky way - renegotiation won't help
            for _ in range(SIZE):
                writer.transport._test__append_write_backlog(b'x' * CHUNK)

            # The server's shutdown should reach the reader as EOF; a broken
            # or reset connection is tolerated here.
            try:
                data = await reader.read()
                self.assertEqual(data, b'')
            except (BrokenPipeError, ConnectionResetError):
                pass

            # Wait for the server thread's verdict.
            await future

            writer.close()
            await self.wait_closed(writer)

        def run(meth):
            # Wrap a server function so its outcome (exception or success)
            # is forwarded to ``future`` on the event loop thread.
            def wrapper(sock):
                try:
                    meth(sock)
                except Exception as ex:
                    self.loop.call_soon_threadsafe(future.set_exception, ex)
                else:
                    self.loop.call_soon_threadsafe(future.set_result, None)
            return wrapper

        with self.tcp_server(run(server)) as srv:
            self.loop.run_until_complete(client(srv.addr))

        with self.tcp_server(run(eof_server)) as srv:
            self.loop.run_until_complete(client(srv.addr))
    1430  
    def test_connect_timeout_warning(self):
        # Connecting over TLS to a socket that is bound but not listening
        # must fail (refused or timed out) without emitting a
        # ResourceWarning, even after forcing garbage collection.
        sock = socket.socket(socket.AF_INET)
        sock.bind(('127.0.0.1', 0))
        addr = sock.getsockname()

        async def test():
            try:
                await asyncio.wait_for(
                    self.loop.create_connection(asyncio.Protocol,
                                                *addr, ssl=True),
                    0.1)
            except (ConnectionRefusedError, asyncio.TimeoutError):
                return
            self.fail('TimeoutError is not raised')

        with sock:
            # assertWarns() is used inverted: its "not triggered" failure is
            # the passing outcome here.
            try:
                with self.assertWarns(ResourceWarning) as cm:
                    self.loop.run_until_complete(test())
                    for _ in range(3):
                        gc.collect()
            except AssertionError as err:
                self.assertEqual(str(err), 'ResourceWarning not triggered')
            else:
                self.fail('Unexpected ResourceWarning: {}'.format(cm.warning))
    1458  
    def test_handshake_timeout_handler_leak(self):
        # The listener never accepts, so the TLS handshake cannot complete
        # and wait_for() aborts create_connection(). Afterwards the SSL
        # context must be freed by reference counting alone (no gc.collect()
        # here).
        listener = socket.socket(socket.AF_INET)
        listener.bind(('127.0.0.1', 0))
        listener.listen(1)
        addr = listener.getsockname()

        async def test(ctx):
            try:
                await asyncio.wait_for(
                    self.loop.create_connection(asyncio.Protocol, *addr,
                                                ssl=ctx),
                    0.1)
            except (ConnectionRefusedError, asyncio.TimeoutError):
                return
            self.fail('TimeoutError is not raised')

        with listener:
            ctx = ssl.create_default_context()
            self.loop.run_until_complete(test(ctx))
            ctx_ref = weakref.ref(ctx)
            del ctx

        # SSLProtocol should be DECREF to 0
        self.assertIsNone(ctx_ref())
    1483  
    def test_shutdown_timeout_handler_leak(self):
        # A client that closes its TLS transport right after connecting must
        # not keep its SSL context alive once the connection is gone.
        loop = self.loop

        def server(sock):
            sslctx = self._create_server_ssl_context(
                test_utils.ONLYCERT,
                test_utils.ONLYKEY
            )
            sock = sslctx.wrap_socket(sock, server_side=True)
            sock.recv(32)
            sock.close()

        class Protocol(asyncio.Protocol):
            def __init__(self):
                self.fut = asyncio.Future(loop=loop)

            def connection_lost(self, exc):
                self.fut.set_result(None)

        async def client(addr, ctx):
            tr, pr = await loop.create_connection(Protocol, *addr, ssl=ctx)
            tr.close()
            await pr.fut

        with self.tcp_server(server) as srv:
            ctx = self._create_client_ssl_context()
            loop.run_until_complete(client(srv.addr, ctx))
            ctx = weakref.ref(ctx)

        # Unlike the handshake-timeout case, the SSL objects may remain in a
        # reference cycle after shutdown (not ideal -- it introduces gc
        # glitches -- but not a leak), so collect cycles before checking.
        gc.collect()
        gc.collect()
        gc.collect()

        # Nothing may keep the client SSL context alive any more.
        self.assertIsNone(ctx())
    1522  
    def test_shutdown_timeout_handler_not_set(self):
        """Data buffered while reading is paused must survive the peer's EOF.

        The server sends ``b'hello'``, waits for ``b'world'``, sends
        ``b'extra bytes'`` and then shuts down its write side.  The client
        pauses reading after ``b'hello'``, so the extra bytes stay buffered.
        It resumes reading only after the server's EOF.  The extra bytes must
        still reach ``data_received`` before the connection is lost.
        """
        loop = self.loop
        eof = asyncio.Event()
        extra = None

        def server(sock):
            sslctx = self._create_server_ssl_context(
                test_utils.ONLYCERT,
                test_utils.ONLYKEY
            )
            with sslctx.wrap_socket(sock, server_side=True) as ssl_sock:
                ssl_sock.send(b'hello')
                # Socket I/O is kept out of ``assert`` so that it still runs
                # (and the check still happens) under ``python -O``.
                self.assertEqual(ssl_sock.recv(1024), b'world')
                ssl_sock.send(b'extra bytes')
                # Send EOF: shut down the write side of the connection.
                ssl_sock.shutdown(socket.SHUT_WR)
                loop.call_soon_threadsafe(eof.set)
                # make sure we have enough time to reproduce the issue
                self.assertEqual(ssl_sock.recv(1024), b'')

        class Protocol(asyncio.Protocol):
            def __init__(self):
                self.fut = asyncio.Future(loop=loop)
                self.transport = None

            def connection_made(self, transport):
                self.transport = transport

            def data_received(self, data):
                nonlocal extra
                if data == b'hello':
                    self.transport.write(b'world')
                    # pause reading would make incoming data stay in the sslobj
                    self.transport.pause_reading()
                else:
                    extra = data

            def connection_lost(self, exc):
                if exc is None:
                    self.fut.set_result(None)
                else:
                    self.fut.set_exception(exc)

        async def client(addr):
            ctx = self._create_client_ssl_context()
            tr, pr = await loop.create_connection(Protocol, *addr, ssl=ctx)
            await eof.wait()
            tr.resume_reading()
            await pr.fut
            tr.close()
            # A real assertion: a bare ``assert`` is stripped under -O, which
            # would make this test verify nothing.
            self.assertEqual(extra, b'extra bytes')

        with self.tcp_server(server) as srv:
            loop.run_until_complete(client(srv.addr))
    1578  
    1579  
    1580  ###############################################################################
    1581  # Socket Testing Utilities
    1582  ###############################################################################
    1583  
    1584  
    1585  class ESC[4;38;5;81mTestSocketWrapper:
    1586  
    1587      def __init__(self, sock):
    1588          self.__sock = sock
    1589  
    1590      def recv_all(self, n):
    1591          buf = b''
    1592          while len(buf) < n:
    1593              data = self.recv(n - len(buf))
    1594              if data == b'':
    1595                  raise ConnectionAbortedError
    1596              buf += data
    1597          return buf
    1598  
    1599      def starttls(self, ssl_context, *,
    1600                   server_side=False,
    1601                   server_hostname=None,
    1602                   do_handshake_on_connect=True):
    1603  
    1604          assert isinstance(ssl_context, ssl.SSLContext)
    1605  
    1606          ssl_sock = ssl_context.wrap_socket(
    1607              self.__sock, server_side=server_side,
    1608              server_hostname=server_hostname,
    1609              do_handshake_on_connect=do_handshake_on_connect)
    1610  
    1611          if server_side:
    1612              ssl_sock.do_handshake()
    1613  
    1614          self.__sock.close()
    1615          self.__sock = ssl_sock
    1616  
    1617      def __getattr__(self, name):
    1618          return getattr(self.__sock, name)
    1619  
    1620      def __repr__(self):
    1621          return '<{} {!r}>'.format(type(self).__name__, self.__sock)
    1622  
    1623  
    1624  class ESC[4;38;5;81mSocketThread(ESC[4;38;5;149mthreadingESC[4;38;5;149m.ESC[4;38;5;149mThread):
    1625  
    1626      def stop(self):
    1627          self._active = False
    1628          self.join()
    1629  
    1630      def __enter__(self):
    1631          self.start()
    1632          return self
    1633  
    1634      def __exit__(self, *exc):
    1635          self.stop()
    1636  
    1637  
    1638  class ESC[4;38;5;81mTestThreadedClient(ESC[4;38;5;149mSocketThread):
    1639  
    1640      def __init__(self, test, sock, prog, timeout):
    1641          threading.Thread.__init__(self, None, None, 'test-client')
    1642          self.daemon = True
    1643  
    1644          self._timeout = timeout
    1645          self._sock = sock
    1646          self._active = True
    1647          self._prog = prog
    1648          self._test = test
    1649  
    1650      def run(self):
    1651          try:
    1652              self._prog(TestSocketWrapper(self._sock))
    1653          except (KeyboardInterrupt, SystemExit):
    1654              raise
    1655          except BaseException as ex:
    1656              self._test._abort_socket_test(ex)
    1657  
    1658  
    1659  class ESC[4;38;5;81mTestThreadedServer(ESC[4;38;5;149mSocketThread):
    1660  
    1661      def __init__(self, test, sock, prog, timeout, max_clients):
    1662          threading.Thread.__init__(self, None, None, 'test-server')
    1663          self.daemon = True
    1664  
    1665          self._clients = 0
    1666          self._finished_clients = 0
    1667          self._max_clients = max_clients
    1668          self._timeout = timeout
    1669          self._sock = sock
    1670          self._active = True
    1671  
    1672          self._prog = prog
    1673  
    1674          self._s1, self._s2 = socket.socketpair()
    1675          self._s1.setblocking(False)
    1676  
    1677          self._test = test
    1678  
    1679      def stop(self):
    1680          try:
    1681              if self._s2 and self._s2.fileno() != -1:
    1682                  try:
    1683                      self._s2.send(b'stop')
    1684                  except OSError:
    1685                      pass
    1686          finally:
    1687              super().stop()
    1688  
    1689      def run(self):
    1690          try:
    1691              with self._sock:
    1692                  self._sock.setblocking(False)
    1693                  self._run()
    1694          finally:
    1695              self._s1.close()
    1696              self._s2.close()
    1697  
    1698      def _run(self):
    1699          while self._active:
    1700              if self._clients >= self._max_clients:
    1701                  return
    1702  
    1703              r, w, x = select.select(
    1704                  [self._sock, self._s1], [], [], self._timeout)
    1705  
    1706              if self._s1 in r:
    1707                  return
    1708  
    1709              if self._sock in r:
    1710                  try:
    1711                      conn, addr = self._sock.accept()
    1712                  except BlockingIOError:
    1713                      continue
    1714                  except socket.timeout:
    1715                      if not self._active:
    1716                          return
    1717                      else:
    1718                          raise
    1719                  else:
    1720                      self._clients += 1
    1721                      conn.settimeout(self._timeout)
    1722                      try:
    1723                          with conn:
    1724                              self._handle_client(conn)
    1725                      except (KeyboardInterrupt, SystemExit):
    1726                          raise
    1727                      except BaseException as ex:
    1728                          self._active = False
    1729                          try:
    1730                              raise
    1731                          finally:
    1732                              self._test._abort_socket_test(ex)
    1733  
    1734      def _handle_client(self, sock):
    1735          self._prog(TestSocketWrapper(sock))
    1736  
    1737      @property
    1738      def addr(self):
    1739          return self._sock.getsockname()