python (3.12.0)

(root)/
lib/
python3.12/
test/
test_asyncio/
test_ssl.py
       1  import asyncio
       2  import contextlib
       3  import gc
       4  import logging
       5  import select
       6  import socket
       7  import sys
       8  import tempfile
       9  import threading
      10  import time
      11  import weakref
      12  import unittest
      13  
      14  try:
      15      import ssl
      16  except ImportError:
      17      ssl = None
      18  
      19  from test import support
      20  from test.test_asyncio import utils as test_utils
      21  
      22  
      23  MACOS = (sys.platform == 'darwin')
      24  BUF_MULTIPLIER = 1024 if not MACOS else 64
      25  
      26  
      27  def tearDownModule():
      28      asyncio.set_event_loop_policy(None)
      29  
      30  
      31  class ESC[4;38;5;81mMyBaseProto(ESC[4;38;5;149masyncioESC[4;38;5;149m.ESC[4;38;5;149mProtocol):
      32      connected = None
      33      done = None
      34  
      35      def __init__(self, loop=None):
      36          self.transport = None
      37          self.state = 'INITIAL'
      38          self.nbytes = 0
      39          if loop is not None:
      40              self.connected = asyncio.Future(loop=loop)
      41              self.done = asyncio.Future(loop=loop)
      42  
      43      def connection_made(self, transport):
      44          self.transport = transport
      45          assert self.state == 'INITIAL', self.state
      46          self.state = 'CONNECTED'
      47          if self.connected:
      48              self.connected.set_result(None)
      49  
      50      def data_received(self, data):
      51          assert self.state == 'CONNECTED', self.state
      52          self.nbytes += len(data)
      53  
      54      def eof_received(self):
      55          assert self.state == 'CONNECTED', self.state
      56          self.state = 'EOF'
      57  
      58      def connection_lost(self, exc):
      59          assert self.state in ('CONNECTED', 'EOF'), self.state
      60          self.state = 'CLOSED'
      61          if self.done:
      62              self.done.set_result(None)
      63  
      64  
      65  class ESC[4;38;5;81mMessageOutFilter(ESC[4;38;5;149mloggingESC[4;38;5;149m.ESC[4;38;5;149mFilter):
      66      def __init__(self, msg):
      67          self.msg = msg
      68  
      69      def filter(self, record):
      70          if self.msg in record.msg:
      71              return False
      72          return True
      73  
      74  
      75  @unittest.skipIf(ssl is None, 'No ssl module')
      76  class ESC[4;38;5;81mTestSSL(ESC[4;38;5;149mtest_utilsESC[4;38;5;149m.ESC[4;38;5;149mTestCase):
      77  
      78      PAYLOAD_SIZE = 1024 * 100
      79      TIMEOUT = support.LONG_TIMEOUT
      80  
      81      def setUp(self):
      82          super().setUp()
      83          self.loop = asyncio.new_event_loop()
      84          self.set_event_loop(self.loop)
      85          self.addCleanup(self.loop.close)
      86  
      87      def tearDown(self):
      88          # just in case if we have transport close callbacks
      89          if not self.loop.is_closed():
      90              test_utils.run_briefly(self.loop)
      91  
      92          self.doCleanups()
      93          support.gc_collect()
      94          super().tearDown()
      95  
      96      def tcp_server(self, server_prog, *,
      97                     family=socket.AF_INET,
      98                     addr=None,
      99                     timeout=support.SHORT_TIMEOUT,
     100                     backlog=1,
     101                     max_clients=10):
     102  
     103          if addr is None:
     104              if family == getattr(socket, "AF_UNIX", None):
     105                  with tempfile.NamedTemporaryFile() as tmp:
     106                      addr = tmp.name
     107              else:
     108                  addr = ('127.0.0.1', 0)
     109  
     110          sock = socket.socket(family, socket.SOCK_STREAM)
     111  
     112          if timeout is None:
     113              raise RuntimeError('timeout is required')
     114          if timeout <= 0:
     115              raise RuntimeError('only blocking sockets are supported')
     116          sock.settimeout(timeout)
     117  
     118          try:
     119              sock.bind(addr)
     120              sock.listen(backlog)
     121          except OSError as ex:
     122              sock.close()
     123              raise ex
     124  
     125          return TestThreadedServer(
     126              self, sock, server_prog, timeout, max_clients)
     127  
     128      def tcp_client(self, client_prog,
     129                     family=socket.AF_INET,
     130                     timeout=support.SHORT_TIMEOUT):
     131  
     132          sock = socket.socket(family, socket.SOCK_STREAM)
     133  
     134          if timeout is None:
     135              raise RuntimeError('timeout is required')
     136          if timeout <= 0:
     137              raise RuntimeError('only blocking sockets are supported')
     138          sock.settimeout(timeout)
     139  
     140          return TestThreadedClient(
     141              self, sock, client_prog, timeout)
     142  
     143      def unix_server(self, *args, **kwargs):
     144          return self.tcp_server(*args, family=socket.AF_UNIX, **kwargs)
     145  
     146      def unix_client(self, *args, **kwargs):
     147          return self.tcp_client(*args, family=socket.AF_UNIX, **kwargs)
     148  
     149      def _create_server_ssl_context(self, certfile, keyfile=None):
     150          sslcontext = ssl.SSLContext(ssl.PROTOCOL_TLS_SERVER)
     151          sslcontext.options |= ssl.OP_NO_SSLv2
     152          sslcontext.load_cert_chain(certfile, keyfile)
     153          return sslcontext
     154  
     155      def _create_client_ssl_context(self, *, disable_verify=True):
     156          sslcontext = ssl.create_default_context()
     157          sslcontext.check_hostname = False
     158          if disable_verify:
     159              sslcontext.verify_mode = ssl.CERT_NONE
     160          return sslcontext
     161  
     162      @contextlib.contextmanager
     163      def _silence_eof_received_warning(self):
     164          # TODO This warning has to be fixed in asyncio.
     165          logger = logging.getLogger('asyncio')
     166          filter = MessageOutFilter('has no effect when using ssl')
     167          logger.addFilter(filter)
     168          try:
     169              yield
     170          finally:
     171              logger.removeFilter(filter)
     172  
    def _abort_socket_test(self, ex):
        """Stop the test's event loop and fail the test with *ex*.

        The try/finally guarantees self.fail(ex) runs even if stopping the
        loop raises.
        """
        try:
            self.loop.stop()
        finally:
            self.fail(ex)
     178  
     179      def new_loop(self):
     180          return asyncio.new_event_loop()
     181  
     182      def new_policy(self):
     183          return asyncio.DefaultEventLoopPolicy()
     184  
     185      async def wait_closed(self, obj):
     186          if not isinstance(obj, asyncio.StreamWriter):
     187              return
     188          try:
     189              await obj.wait_closed()
     190          except (BrokenPipeError, ConnectionError):
     191              pass
     192  
     193      def test_create_server_ssl_1(self):
     194          CNT = 0           # number of clients that were successful
     195          TOTAL_CNT = 25    # total number of clients that test will create
     196          TIMEOUT = support.LONG_TIMEOUT  # timeout for this test
     197  
     198          A_DATA = b'A' * 1024 * BUF_MULTIPLIER
     199          B_DATA = b'B' * 1024 * BUF_MULTIPLIER
     200  
     201          sslctx = self._create_server_ssl_context(
     202              test_utils.ONLYCERT, test_utils.ONLYKEY
     203          )
     204          client_sslctx = self._create_client_ssl_context()
     205  
     206          clients = []
     207  
     208          async def handle_client(reader, writer):
     209              nonlocal CNT
     210  
     211              data = await reader.readexactly(len(A_DATA))
     212              self.assertEqual(data, A_DATA)
     213              writer.write(b'OK')
     214  
     215              data = await reader.readexactly(len(B_DATA))
     216              self.assertEqual(data, B_DATA)
     217              writer.writelines([b'SP', bytearray(b'A'), memoryview(b'M')])
     218  
     219              await writer.drain()
     220              writer.close()
     221  
     222              CNT += 1
     223  
     224          async def test_client(addr):
     225              fut = asyncio.Future()
     226  
     227              def prog(sock):
     228                  try:
     229                      sock.starttls(client_sslctx)
     230                      sock.connect(addr)
     231                      sock.send(A_DATA)
     232  
     233                      data = sock.recv_all(2)
     234                      self.assertEqual(data, b'OK')
     235  
     236                      sock.send(B_DATA)
     237                      data = sock.recv_all(4)
     238                      self.assertEqual(data, b'SPAM')
     239  
     240                      sock.close()
     241  
     242                  except Exception as ex:
     243                      self.loop.call_soon_threadsafe(fut.set_exception, ex)
     244                  else:
     245                      self.loop.call_soon_threadsafe(fut.set_result, None)
     246  
     247              client = self.tcp_client(prog)
     248              client.start()
     249              clients.append(client)
     250  
     251              await fut
     252  
     253          async def start_server():
     254              extras = {}
     255              extras = dict(ssl_handshake_timeout=support.SHORT_TIMEOUT)
     256  
     257              srv = await asyncio.start_server(
     258                  handle_client,
     259                  '127.0.0.1', 0,
     260                  family=socket.AF_INET,
     261                  ssl=sslctx,
     262                  **extras)
     263  
     264              try:
     265                  srv_socks = srv.sockets
     266                  self.assertTrue(srv_socks)
     267  
     268                  addr = srv_socks[0].getsockname()
     269  
     270                  tasks = []
     271                  for _ in range(TOTAL_CNT):
     272                      tasks.append(test_client(addr))
     273  
     274                  await asyncio.wait_for(asyncio.gather(*tasks), TIMEOUT)
     275  
     276              finally:
     277                  self.loop.call_soon(srv.close)
     278                  await srv.wait_closed()
     279  
     280          with self._silence_eof_received_warning():
     281              self.loop.run_until_complete(start_server())
     282  
     283          self.assertEqual(CNT, TOTAL_CNT)
     284  
     285          for client in clients:
     286              client.stop()
     287  
     288      def test_create_connection_ssl_1(self):
     289          self.loop.set_exception_handler(None)
     290  
     291          CNT = 0
     292          TOTAL_CNT = 25
     293  
     294          A_DATA = b'A' * 1024 * BUF_MULTIPLIER
     295          B_DATA = b'B' * 1024 * BUF_MULTIPLIER
     296  
     297          sslctx = self._create_server_ssl_context(
     298              test_utils.ONLYCERT,
     299              test_utils.ONLYKEY
     300          )
     301          client_sslctx = self._create_client_ssl_context()
     302  
     303          def server(sock):
     304              sock.starttls(
     305                  sslctx,
     306                  server_side=True)
     307  
     308              data = sock.recv_all(len(A_DATA))
     309              self.assertEqual(data, A_DATA)
     310              sock.send(b'OK')
     311  
     312              data = sock.recv_all(len(B_DATA))
     313              self.assertEqual(data, B_DATA)
     314              sock.send(b'SPAM')
     315  
     316              sock.close()
     317  
     318          async def client(addr):
     319              extras = {}
     320              extras = dict(ssl_handshake_timeout=support.SHORT_TIMEOUT)
     321  
     322              reader, writer = await asyncio.open_connection(
     323                  *addr,
     324                  ssl=client_sslctx,
     325                  server_hostname='',
     326                  **extras)
     327  
     328              writer.write(A_DATA)
     329              self.assertEqual(await reader.readexactly(2), b'OK')
     330  
     331              writer.write(B_DATA)
     332              self.assertEqual(await reader.readexactly(4), b'SPAM')
     333  
     334              nonlocal CNT
     335              CNT += 1
     336  
     337              writer.close()
     338              await self.wait_closed(writer)
     339  
     340          async def client_sock(addr):
     341              sock = socket.socket()
     342              sock.connect(addr)
     343              reader, writer = await asyncio.open_connection(
     344                  sock=sock,
     345                  ssl=client_sslctx,
     346                  server_hostname='')
     347  
     348              writer.write(A_DATA)
     349              self.assertEqual(await reader.readexactly(2), b'OK')
     350  
     351              writer.write(B_DATA)
     352              self.assertEqual(await reader.readexactly(4), b'SPAM')
     353  
     354              nonlocal CNT
     355              CNT += 1
     356  
     357              writer.close()
     358              await self.wait_closed(writer)
     359              sock.close()
     360  
     361          def run(coro):
     362              nonlocal CNT
     363              CNT = 0
     364  
     365              async def _gather(*tasks):
     366                  # trampoline
     367                  return await asyncio.gather(*tasks)
     368  
     369              with self.tcp_server(server,
     370                                   max_clients=TOTAL_CNT,
     371                                   backlog=TOTAL_CNT) as srv:
     372                  tasks = []
     373                  for _ in range(TOTAL_CNT):
     374                      tasks.append(coro(srv.addr))
     375  
     376                  self.loop.run_until_complete(_gather(*tasks))
     377  
     378              self.assertEqual(CNT, TOTAL_CNT)
     379  
     380          with self._silence_eof_received_warning():
     381              run(client)
     382  
     383          with self._silence_eof_received_warning():
     384              run(client_sock)
     385  
    def test_create_connection_ssl_slow_handshake(self):
        # The server only reads and never answers the TLS handshake, so with
        # ssl_handshake_timeout=1.0 open_connection() must fail with a
        # ConnectionAbortedError whose message matches
        # 'SSL handshake.*is taking longer'.
        client_sslctx = self._create_client_ssl_context()

        # silence error logger
        self.loop.set_exception_handler(lambda *args: None)

        def server(sock):
            # Read whatever the client sends without ever replying; the
            # abort when the client gives up is expected and ignored.
            try:
                sock.recv_all(1024 * 1024)
            except ConnectionAbortedError:
                pass
            finally:
                sock.close()

        async def client(addr):
            reader, writer = await asyncio.open_connection(
                *addr,
                ssl=client_sslctx,
                server_hostname='',
                ssl_handshake_timeout=1.0)
            # Only reached if the handshake unexpectedly completes.
            writer.close()
            await self.wait_closed(writer)

        with self.tcp_server(server,
                             max_clients=1,
                             backlog=1) as srv:

            with self.assertRaisesRegex(
                    ConnectionAbortedError,
                    r'SSL handshake.*is taking longer'):

                self.loop.run_until_complete(client(srv.addr))
     418  
    def test_create_connection_ssl_failed_certificate(self):
        # The client keeps certificate verification enabled
        # (disable_verify=False), so connecting to a server that presents
        # test_utils.ONLYCERT must raise ssl.SSLCertVerificationError --
        # presumably because that certificate is not trusted by the default
        # CA store of ssl.create_default_context().

        # silence error logger
        self.loop.set_exception_handler(lambda *args: None)

        sslctx = self._create_server_ssl_context(
            test_utils.ONLYCERT,
            test_utils.ONLYKEY
        )
        client_sslctx = self._create_client_ssl_context(disable_verify=False)

        def server(sock):
            # SSL/OS errors are expected here because the client rejects
            # the certificate; they are ignored.
            try:
                sock.starttls(
                    sslctx,
                    server_side=True)
                # NOTE(review): connect() without arguments on this socket
                # wrapper is defined elsewhere in this module -- confirm its
                # role here.
                sock.connect()
            except (ssl.SSLError, OSError):
                pass
            finally:
                sock.close()

        async def client(addr):
            reader, writer = await asyncio.open_connection(
                *addr,
                ssl=client_sslctx,
                server_hostname='',
                ssl_handshake_timeout=support.SHORT_TIMEOUT)
            # Only reached if verification unexpectedly succeeds.
            writer.close()
            await self.wait_closed(writer)

        with self.tcp_server(server,
                             max_clients=1,
                             backlog=1) as srv:

            with self.assertRaises(ssl.SSLCertVerificationError):
                self.loop.run_until_complete(client(srv.addr))
     455  
    def test_ssl_handshake_timeout(self):
        # bpo-29970: Check that a connection is aborted if handshake is not
        # completed in timeout period, instead of remaining open indefinitely
        # Here the outer asyncio.wait_for() (0.5s) expires long before
        # ssl_handshake_timeout (10.0s), so the pending handshake is
        # cancelled and the server side must observe the abort.
        client_sslctx = test_utils.simple_client_sslcontext()

        # silence error logger
        # (contexts are collected so the test can assert none were reported)
        messages = []
        self.loop.set_exception_handler(lambda loop, ctx: messages.append(ctx))

        server_side_aborted = False

        def server(sock):
            nonlocal server_side_aborted
            # Never answers the handshake; records whether the connection
            # was aborted while it was waiting for data.
            try:
                sock.recv_all(1024 * 1024)
            except ConnectionAbortedError:
                server_side_aborted = True
            finally:
                sock.close()

        async def client(addr):
            await asyncio.wait_for(
                self.loop.create_connection(
                    asyncio.Protocol,
                    *addr,
                    ssl=client_sslctx,
                    server_hostname='',
                    ssl_handshake_timeout=10.0),
                0.5)

        with self.tcp_server(server,
                             max_clients=1,
                             backlog=1) as srv:

            with self.assertRaises(asyncio.TimeoutError):
                self.loop.run_until_complete(client(srv.addr))

        self.assertTrue(server_side_aborted)

        # Python issue #23197: cancelling a handshake must not raise an
        # exception or log an error, even if the handshake failed
        self.assertEqual(messages, [])
     498  
     499      def test_ssl_handshake_connection_lost(self):
     500          # #246: make sure that no connection_lost() is called before
     501          # connection_made() is called first
     502  
     503          client_sslctx = test_utils.simple_client_sslcontext()
     504  
     505          # silence error logger
     506          self.loop.set_exception_handler(lambda loop, ctx: None)
     507  
     508          connection_made_called = False
     509          connection_lost_called = False
     510  
     511          def server(sock):
     512              sock.recv(1024)
     513              # break the connection during handshake
     514              sock.close()
     515  
     516          class ESC[4;38;5;81mClientProto(ESC[4;38;5;149masyncioESC[4;38;5;149m.ESC[4;38;5;149mProtocol):
     517              def connection_made(self, transport):
     518                  nonlocal connection_made_called
     519                  connection_made_called = True
     520  
     521              def connection_lost(self, exc):
     522                  nonlocal connection_lost_called
     523                  connection_lost_called = True
     524  
     525          async def client(addr):
     526              await self.loop.create_connection(
     527                  ClientProto,
     528                  *addr,
     529                  ssl=client_sslctx,
     530                  server_hostname=''),
     531  
     532          with self.tcp_server(server,
     533                               max_clients=1,
     534                               backlog=1) as srv:
     535  
     536              with self.assertRaises(ConnectionResetError):
     537                  self.loop.run_until_complete(client(srv.addr))
     538  
     539          if connection_lost_called:
     540              if connection_made_called:
     541                  self.fail("unexpected call to connection_lost()")
     542              else:
     543                  self.fail("unexpected call to connection_lost() without"
     544                            "calling connection_made()")
     545          elif connection_made_called:
     546              self.fail("unexpected call to connection_made()")
     547  
     548      def test_ssl_connect_accepted_socket(self):
     549          proto = ssl.PROTOCOL_TLS_SERVER
     550          server_context = ssl.SSLContext(proto)
     551          server_context.load_cert_chain(test_utils.ONLYCERT, test_utils.ONLYKEY)
     552          if hasattr(server_context, 'check_hostname'):
     553              server_context.check_hostname = False
     554          server_context.verify_mode = ssl.CERT_NONE
     555  
     556          client_context = ssl.SSLContext(proto)
     557          if hasattr(server_context, 'check_hostname'):
     558              client_context.check_hostname = False
     559          client_context.verify_mode = ssl.CERT_NONE
     560  
     561      def test_connect_accepted_socket(self, server_ssl=None, client_ssl=None):
     562          loop = self.loop
     563  
     564          class ESC[4;38;5;81mMyProto(ESC[4;38;5;149mMyBaseProto):
     565  
     566              def connection_lost(self, exc):
     567                  super().connection_lost(exc)
     568                  loop.call_soon(loop.stop)
     569  
     570              def data_received(self, data):
     571                  super().data_received(data)
     572                  self.transport.write(expected_response)
     573  
     574          lsock = socket.socket(socket.AF_INET)
     575          lsock.bind(('127.0.0.1', 0))
     576          lsock.listen(1)
     577          addr = lsock.getsockname()
     578  
     579          message = b'test data'
     580          response = None
     581          expected_response = b'roger'
     582  
     583          def client():
     584              nonlocal response
     585              try:
     586                  csock = socket.socket(socket.AF_INET)
     587                  if client_ssl is not None:
     588                      csock = client_ssl.wrap_socket(csock)
     589                  csock.connect(addr)
     590                  csock.sendall(message)
     591                  response = csock.recv(99)
     592                  csock.close()
     593              except Exception as exc:
     594                  print(
     595                      "Failure in client thread in test_connect_accepted_socket",
     596                      exc)
     597  
     598          thread = threading.Thread(target=client, daemon=True)
     599          thread.start()
     600  
     601          conn, _ = lsock.accept()
     602          proto = MyProto(loop=loop)
     603          proto.loop = loop
     604  
     605          extras = {}
     606          if server_ssl:
     607              extras = dict(ssl_handshake_timeout=support.SHORT_TIMEOUT)
     608  
     609          f = loop.create_task(
     610              loop.connect_accepted_socket(
     611                  (lambda: proto), conn, ssl=server_ssl,
     612                  **extras))
     613          loop.run_forever()
     614          conn.close()
     615          lsock.close()
     616  
     617          thread.join(1)
     618          self.assertFalse(thread.is_alive())
     619          self.assertEqual(proto.state, 'CLOSED')
     620          self.assertEqual(proto.nbytes, len(message))
     621          self.assertEqual(response, expected_response)
     622          tr, _ = f.result()
     623  
     624          if server_ssl:
     625              self.assertIn('SSL', tr.__class__.__name__)
     626  
     627          tr.close()
     628          # let it close
     629          self.loop.run_until_complete(asyncio.sleep(0.1))
     630  
    def test_start_tls_client_corrupted_ssl(self):
        # After a valid TLS exchange, the server writes plaintext bytes
        # straight onto the connection through a duplicated socket,
        # corrupting the TLS stream; the client's next read must raise
        # ssl.SSLError and the client must still shut down cleanly.
        self.loop.set_exception_handler(lambda loop, ctx: None)

        sslctx = test_utils.simple_server_sslcontext()
        client_sslctx = test_utils.simple_client_sslcontext()

        def server(sock):
            # Duplicate taken before TLS starts; presumably a raw socket, so
            # data sent through it bypasses the TLS layer.
            orig_sock = sock.dup()
            try:
                sock.starttls(
                    sslctx,
                    server_side=True)
                sock.sendall(b'A\n')
                sock.recv_all(1)
                # Inject non-TLS bytes into the encrypted stream.
                orig_sock.send(b'please corrupt the SSL connection')
            except ssl.SSLError:
                pass
            finally:
                sock.close()
                orig_sock.close()

        async def client(addr):
            reader, writer = await asyncio.open_connection(
                *addr,
                ssl=client_sslctx,
                server_hostname='')

            self.assertEqual(await reader.readline(), b'A\n')
            writer.write(b'B')
            with self.assertRaises(ssl.SSLError):
                await reader.readline()
            writer.close()
            # Closing may surface the same SSL error again; that is fine.
            try:
                await self.wait_closed(writer)
            except ssl.SSLError:
                pass
            return 'OK'

        with self.tcp_server(server,
                             max_clients=1,
                             backlog=1) as srv:

            res = self.loop.run_until_complete(client(srv.addr))

        self.assertEqual(res, 'OK')
     676  
    def test_start_tls_client_reg_proto_1(self):
        # Upgrade an established plain-TCP connection to TLS with
        # loop.start_tls(), using a regular asyncio.Protocol on the client.
        HELLO_MSG = b'1' * self.PAYLOAD_SIZE

        server_context = test_utils.simple_server_sslcontext()
        client_context = test_utils.simple_client_sslcontext()

        def serve(sock):
            sock.settimeout(self.TIMEOUT)

            # First HELLO_MSG arrives in cleartext.
            data = sock.recv_all(len(HELLO_MSG))
            self.assertEqual(len(data), len(HELLO_MSG))

            # Switch the same connection to TLS.
            sock.starttls(server_context, server_side=True)

            sock.sendall(b'O')
            # Second HELLO_MSG arrives over TLS.
            data = sock.recv_all(len(HELLO_MSG))
            self.assertEqual(len(data), len(HELLO_MSG))

            # Shut down TLS before closing; presumably delivered to the
            # client as eof_received(), which the client awaits.
            sock.unwrap()
            sock.close()

        class ClientProto(asyncio.Protocol):
            def __init__(self, on_data, on_eof):
                self.on_data = on_data
                self.on_eof = on_eof
                self.con_made_cnt = 0

            # The instance is named ``proto`` so that ``self`` still refers
            # to the enclosing TestCase (used for assertEqual).
            def connection_made(proto, tr):
                proto.con_made_cnt += 1
                # Ensure connection_made gets called only once.
                self.assertEqual(proto.con_made_cnt, 1)

            def data_received(self, data):
                self.on_data.set_result(data)

            def eof_received(self):
                self.on_eof.set_result(True)

        async def client(addr):
            # NOTE(review): the reason for this initial delay is not evident
            # from this test -- confirm before removing.
            await asyncio.sleep(0.5)

            on_data = self.loop.create_future()
            on_eof = self.loop.create_future()

            tr, proto = await self.loop.create_connection(
                lambda: ClientProto(on_data, on_eof), *addr)

            tr.write(HELLO_MSG)
            # start_tls() returns a new transport; the old one must not be
            # used afterwards.
            new_tr = await self.loop.start_tls(tr, proto, client_context)

            self.assertEqual(await on_data, b'O')
            new_tr.write(HELLO_MSG)
            await on_eof

            new_tr.close()

        with self.tcp_server(serve, timeout=self.TIMEOUT) as srv:
            self.loop.run_until_complete(
                asyncio.wait_for(client(srv.addr),
                                 timeout=support.SHORT_TIMEOUT))
     737  
    def test_create_connection_memory_leak(self):
        # After a TLS connection made with loop.create_connection() is
        # closed, nothing may keep the client SSLContext alive -- even
        # though the protocol stores its transport (see ClientProto).
        HELLO_MSG = b'1' * self.PAYLOAD_SIZE

        server_context = self._create_server_ssl_context(
            test_utils.ONLYCERT, test_utils.ONLYKEY)
        client_context = self._create_client_ssl_context()

        def serve(sock):
            sock.settimeout(self.TIMEOUT)

            sock.starttls(server_context, server_side=True)

            sock.sendall(b'O')
            data = sock.recv_all(len(HELLO_MSG))
            self.assertEqual(len(data), len(HELLO_MSG))

            # Shut down TLS before closing; presumably seen by the client
            # as eof_received(), which it awaits.
            sock.unwrap()
            sock.close()

        class ClientProto(asyncio.Protocol):
            def __init__(self, on_data, on_eof):
                self.on_data = on_data
                self.on_eof = on_eof
                self.con_made_cnt = 0

            # The instance is named ``proto`` so that ``self`` still refers
            # to the enclosing TestCase (used for assertEqual).
            def connection_made(proto, tr):
                # XXX: We assume user stores the transport in protocol
                proto.tr = tr
                proto.con_made_cnt += 1
                # Ensure connection_made gets called only once.
                self.assertEqual(proto.con_made_cnt, 1)

            def data_received(self, data):
                self.on_data.set_result(data)

            def eof_received(self):
                self.on_eof.set_result(True)

        async def client(addr):
            # NOTE(review): the reason for this initial delay is not evident
            # from this test -- confirm before removing.
            await asyncio.sleep(0.5)

            on_data = self.loop.create_future()
            on_eof = self.loop.create_future()

            tr, proto = await self.loop.create_connection(
                lambda: ClientProto(on_data, on_eof), *addr,
                ssl=client_context)

            self.assertEqual(await on_data, b'O')
            tr.write(HELLO_MSG)
            await on_eof

            tr.close()

        with self.tcp_server(serve, timeout=self.TIMEOUT) as srv:
            self.loop.run_until_complete(
                asyncio.wait_for(client(srv.addr),
                                 timeout=support.SHORT_TIMEOUT))

        # No garbage is left for SSL client from loop.create_connection, even
        # if user stores the SSLTransport in corresponding protocol instance
        # Rebinding the name drops this function's strong reference, so the
        # weakref is dead only if nothing else still holds the context.
        client_context = weakref.ref(client_context)
        self.assertIsNone(client_context())
     801  
     802      def test_start_tls_client_buf_proto_1(self):
     803          HELLO_MSG = b'1' * self.PAYLOAD_SIZE
     804  
     805          server_context = test_utils.simple_server_sslcontext()
     806          client_context = test_utils.simple_client_sslcontext()
     807  
     808          client_con_made_calls = 0
     809  
     810          def serve(sock):
     811              sock.settimeout(self.TIMEOUT)
     812  
     813              data = sock.recv_all(len(HELLO_MSG))
     814              self.assertEqual(len(data), len(HELLO_MSG))
     815  
     816              sock.starttls(server_context, server_side=True)
     817  
     818              sock.sendall(b'O')
     819              data = sock.recv_all(len(HELLO_MSG))
     820              self.assertEqual(len(data), len(HELLO_MSG))
     821  
     822              sock.sendall(b'2')
     823              data = sock.recv_all(len(HELLO_MSG))
     824              self.assertEqual(len(data), len(HELLO_MSG))
     825  
     826              sock.unwrap()
     827              sock.close()
     828  
     829          class ESC[4;38;5;81mClientProtoFirst(ESC[4;38;5;149masyncioESC[4;38;5;149m.ESC[4;38;5;149mBufferedProtocol):
     830              def __init__(self, on_data):
     831                  self.on_data = on_data
     832                  self.buf = bytearray(1)
     833  
     834              def connection_made(self, tr):
     835                  nonlocal client_con_made_calls
     836                  client_con_made_calls += 1
     837  
     838              def get_buffer(self, sizehint):
     839                  return self.buf
     840  
     841              def buffer_updated(self, nsize):
     842                  assert nsize == 1
     843                  self.on_data.set_result(bytes(self.buf[:nsize]))
     844  
     845              def eof_received(self):
     846                  pass
     847  
     848          class ESC[4;38;5;81mClientProtoSecond(ESC[4;38;5;149masyncioESC[4;38;5;149m.ESC[4;38;5;149mProtocol):
     849              def __init__(self, on_data, on_eof):
     850                  self.on_data = on_data
     851                  self.on_eof = on_eof
     852                  self.con_made_cnt = 0
     853  
     854              def connection_made(self, tr):
     855                  nonlocal client_con_made_calls
     856                  client_con_made_calls += 1
     857  
     858              def data_received(self, data):
     859                  self.on_data.set_result(data)
     860  
     861              def eof_received(self):
     862                  self.on_eof.set_result(True)
     863  
     864          async def client(addr):
     865              await asyncio.sleep(0.5)
     866  
     867              on_data1 = self.loop.create_future()
     868              on_data2 = self.loop.create_future()
     869              on_eof = self.loop.create_future()
     870  
     871              tr, proto = await self.loop.create_connection(
     872                  lambda: ClientProtoFirst(on_data1), *addr)
     873  
     874              tr.write(HELLO_MSG)
     875              new_tr = await self.loop.start_tls(tr, proto, client_context)
     876  
     877              self.assertEqual(await on_data1, b'O')
     878              new_tr.write(HELLO_MSG)
     879  
     880              new_tr.set_protocol(ClientProtoSecond(on_data2, on_eof))
     881              self.assertEqual(await on_data2, b'2')
     882              new_tr.write(HELLO_MSG)
     883              await on_eof
     884  
     885              new_tr.close()
     886  
     887              # connection_made() should be called only once -- when
     888              # we establish connection for the first time. Start TLS
     889              # doesn't call connection_made() on application protocols.
     890              self.assertEqual(client_con_made_calls, 1)
     891  
     892          with self.tcp_server(serve, timeout=self.TIMEOUT) as srv:
     893              self.loop.run_until_complete(
     894                  asyncio.wait_for(client(srv.addr),
     895                                   timeout=self.TIMEOUT))
     896  
     897      def test_start_tls_slow_client_cancel(self):
     898          HELLO_MSG = b'1' * self.PAYLOAD_SIZE
     899  
     900          client_context = test_utils.simple_client_sslcontext()
     901          server_waits_on_handshake = self.loop.create_future()
     902  
     903          def serve(sock):
     904              sock.settimeout(self.TIMEOUT)
     905  
     906              data = sock.recv_all(len(HELLO_MSG))
     907              self.assertEqual(len(data), len(HELLO_MSG))
     908  
     909              try:
     910                  self.loop.call_soon_threadsafe(
     911                      server_waits_on_handshake.set_result, None)
     912                  data = sock.recv_all(1024 * 1024)
     913              except ConnectionAbortedError:
     914                  pass
     915              finally:
     916                  sock.close()
     917  
     918          class ESC[4;38;5;81mClientProto(ESC[4;38;5;149masyncioESC[4;38;5;149m.ESC[4;38;5;149mProtocol):
     919              def __init__(self, on_data, on_eof):
     920                  self.on_data = on_data
     921                  self.on_eof = on_eof
     922                  self.con_made_cnt = 0
     923  
     924              def connection_made(proto, tr):
     925                  proto.con_made_cnt += 1
     926                  # Ensure connection_made gets called only once.
     927                  self.assertEqual(proto.con_made_cnt, 1)
     928  
     929              def data_received(self, data):
     930                  self.on_data.set_result(data)
     931  
     932              def eof_received(self):
     933                  self.on_eof.set_result(True)
     934  
     935          async def client(addr):
     936              await asyncio.sleep(0.5)
     937  
     938              on_data = self.loop.create_future()
     939              on_eof = self.loop.create_future()
     940  
     941              tr, proto = await self.loop.create_connection(
     942                  lambda: ClientProto(on_data, on_eof), *addr)
     943  
     944              tr.write(HELLO_MSG)
     945  
     946              await server_waits_on_handshake
     947  
     948              with self.assertRaises(asyncio.TimeoutError):
     949                  await asyncio.wait_for(
     950                      self.loop.start_tls(tr, proto, client_context),
     951                      0.5)
     952  
     953          with self.tcp_server(serve, timeout=self.TIMEOUT) as srv:
     954              self.loop.run_until_complete(
     955                  asyncio.wait_for(client(srv.addr),
     956                                   timeout=support.SHORT_TIMEOUT))
     957  
     958      def test_start_tls_server_1(self):
     959          HELLO_MSG = b'1' * self.PAYLOAD_SIZE
     960  
     961          server_context = test_utils.simple_server_sslcontext()
     962          client_context = test_utils.simple_client_sslcontext()
     963  
     964          def client(sock, addr):
     965              sock.settimeout(self.TIMEOUT)
     966  
     967              sock.connect(addr)
     968              data = sock.recv_all(len(HELLO_MSG))
     969              self.assertEqual(len(data), len(HELLO_MSG))
     970  
     971              sock.starttls(client_context)
     972              sock.sendall(HELLO_MSG)
     973  
     974              sock.unwrap()
     975              sock.close()
     976  
     977          class ESC[4;38;5;81mServerProto(ESC[4;38;5;149masyncioESC[4;38;5;149m.ESC[4;38;5;149mProtocol):
     978              def __init__(self, on_con, on_eof, on_con_lost):
     979                  self.on_con = on_con
     980                  self.on_eof = on_eof
     981                  self.on_con_lost = on_con_lost
     982                  self.data = b''
     983  
     984              def connection_made(self, tr):
     985                  self.on_con.set_result(tr)
     986  
     987              def data_received(self, data):
     988                  self.data += data
     989  
     990              def eof_received(self):
     991                  self.on_eof.set_result(1)
     992  
     993              def connection_lost(self, exc):
     994                  if exc is None:
     995                      self.on_con_lost.set_result(None)
     996                  else:
     997                      self.on_con_lost.set_exception(exc)
     998  
     999          async def main(proto, on_con, on_eof, on_con_lost):
    1000              tr = await on_con
    1001              tr.write(HELLO_MSG)
    1002  
    1003              self.assertEqual(proto.data, b'')
    1004  
    1005              new_tr = await self.loop.start_tls(
    1006                  tr, proto, server_context,
    1007                  server_side=True,
    1008                  ssl_handshake_timeout=self.TIMEOUT)
    1009  
    1010              await on_eof
    1011              await on_con_lost
    1012              self.assertEqual(proto.data, HELLO_MSG)
    1013              new_tr.close()
    1014  
    1015          async def run_main():
    1016              on_con = self.loop.create_future()
    1017              on_eof = self.loop.create_future()
    1018              on_con_lost = self.loop.create_future()
    1019              proto = ServerProto(on_con, on_eof, on_con_lost)
    1020  
    1021              server = await self.loop.create_server(
    1022                  lambda: proto, '127.0.0.1', 0)
    1023              addr = server.sockets[0].getsockname()
    1024  
    1025              with self.tcp_client(lambda sock: client(sock, addr),
    1026                                   timeout=self.TIMEOUT):
    1027                  await asyncio.wait_for(
    1028                      main(proto, on_con, on_eof, on_con_lost),
    1029                      timeout=self.TIMEOUT)
    1030  
    1031              server.close()
    1032              await server.wait_closed()
    1033  
    1034          self.loop.run_until_complete(run_main())
    1035  
    def test_create_server_ssl_over_ssl(self):
        # TLS tunnelled inside TLS.  The server is created with ssl=sslctx_1,
        # and its protocol immediately calls start_tls(sslctx_2) on the
        # resulting transport.  Each client thread does the outer TLS with
        # starttls() and the inner TLS by hand through MemoryBIOs, then
        # exchanges A_DATA / B_DATA with the stream handler.
        CNT = 0           # number of clients that were successful
        TOTAL_CNT = 25    # total number of clients that test will create
        TIMEOUT = support.LONG_TIMEOUT  # timeout for this test

        A_DATA = b'A' * 1024 * BUF_MULTIPLIER
        B_DATA = b'B' * 1024 * BUF_MULTIPLIER

        sslctx_1 = self._create_server_ssl_context(
            test_utils.ONLYCERT, test_utils.ONLYKEY)
        client_sslctx_1 = self._create_client_ssl_context()
        sslctx_2 = self._create_server_ssl_context(
            test_utils.ONLYCERT, test_utils.ONLYKEY)
        client_sslctx_2 = self._create_client_ssl_context()

        clients = []

        async def handle_client(reader, writer):
            nonlocal CNT

            data = await reader.readexactly(len(A_DATA))
            self.assertEqual(data, A_DATA)
            writer.write(b'OK')

            data = await reader.readexactly(len(B_DATA))
            self.assertEqual(data, B_DATA)
            writer.writelines([b'SP', bytearray(b'A'), memoryview(b'M')])

            await writer.drain()
            writer.close()

            CNT += 1

        class ServerProtocol(asyncio.StreamReaderProtocol):
            def connection_made(self, transport):
                # Zero-argument super() cannot be used inside the nested
                # callback below, so capture it here.
                super_ = super()
                transport.pause_reading()
                # Add the inner TLS layer before the stream machinery is told
                # about the connection.
                fut = self._loop.create_task(self._loop.start_tls(
                    transport, self, sslctx_2, server_side=True))

                def cb(_):
                    try:
                        tr = fut.result()
                    except Exception as ex:
                        super_.connection_lost(ex)
                    else:
                        super_.connection_made(tr)
                fut.add_done_callback(cb)

        def server_protocol_factory():
            reader = asyncio.StreamReader()
            protocol = ServerProtocol(reader, handle_client)
            return protocol

        async def test_client(addr):
            fut = self.loop.create_future()

            def prog(sock):
                try:
                    sock.connect(addr)
                    sock.starttls(client_sslctx_1)

                    # because wrap_socket() doesn't work correctly on
                    # SSLSocket, we have to do the 2nd level SSL manually
                    incoming = ssl.MemoryBIO()
                    outgoing = ssl.MemoryBIO()
                    sslobj = client_sslctx_2.wrap_bio(incoming, outgoing)

                    def do(func, *args):
                        # Call an SSLObject method.  While it asks for more
                        # input, flush pending ciphertext to the outer TLS
                        # socket and feed its replies back into the BIO.
                        while True:
                            try:
                                rv = func(*args)
                                break
                            except ssl.SSLWantReadError:
                                if outgoing.pending:
                                    sock.send(outgoing.read())
                                incoming.write(sock.recv(65536))
                        if outgoing.pending:
                            sock.send(outgoing.read())
                        return rv

                    do(sslobj.do_handshake)

                    do(sslobj.write, A_DATA)
                    data = do(sslobj.read, 2)
                    self.assertEqual(data, b'OK')

                    do(sslobj.write, B_DATA)
                    data = b''
                    while True:
                        chunk = do(sslobj.read, 4)
                        if not chunk:
                            break
                        data += chunk
                    self.assertEqual(data, b'SPAM')

                    do(sslobj.unwrap)
                    sock.close()

                except Exception as ex:
                    self.loop.call_soon_threadsafe(fut.set_exception, ex)
                    sock.close()
                else:
                    self.loop.call_soon_threadsafe(fut.set_result, None)

            client = self.tcp_client(prog)
            client.start()
            clients.append(client)

            await fut

        async def start_server():
            srv = await self.loop.create_server(
                server_protocol_factory,
                '127.0.0.1', 0,
                family=socket.AF_INET,
                ssl=sslctx_1)

            try:
                srv_socks = srv.sockets
                self.assertTrue(srv_socks)

                addr = srv_socks[0].getsockname()

                tasks = [test_client(addr) for _ in range(TOTAL_CNT)]

                await asyncio.wait_for(asyncio.gather(*tasks), TIMEOUT)

            finally:
                self.loop.call_soon(srv.close)
                await srv.wait_closed()

        try:
            with self._silence_eof_received_warning():
                self.loop.run_until_complete(start_server())
        finally:
            # Stop the client threads even when the run above failed, so a
            # failing test does not leave them behind.
            for client in clients:
                client.stop()

        self.assertEqual(CNT, TOTAL_CNT)
    1180  
    def test_shutdown_cleanly(self):
        # TOTAL_CNT concurrent SSL clients each send A_DATA and receive b'OK'.
        # After the server's TLS unwrap(), each must see a clean EOF (b'').
        CNT = 0
        TOTAL_CNT = 25

        A_DATA = b'A' * 1024 * BUF_MULTIPLIER

        sslctx = self._create_server_ssl_context(
            test_utils.ONLYCERT, test_utils.ONLYKEY)
        client_sslctx = self._create_client_ssl_context()

        def server(sock):
            sock.starttls(
                sslctx,
                server_side=True)

            data = sock.recv_all(len(A_DATA))
            self.assertEqual(data, A_DATA)
            sock.send(b'OK')

            sock.unwrap()

            sock.close()

        async def client(addr):
            nonlocal CNT

            reader, writer = await asyncio.open_connection(
                *addr,
                ssl=client_sslctx,
                server_hostname='',
                ssl_handshake_timeout=support.SHORT_TIMEOUT)

            writer.write(A_DATA)
            self.assertEqual(await reader.readexactly(2), b'OK')

            # The server's TLS shutdown must surface as a clean EOF.
            self.assertEqual(await reader.read(), b'')

            CNT += 1

            writer.close()
            await self.wait_closed(writer)

        def run(coro):
            nonlocal CNT
            CNT = 0

            async def _gather(*tasks):
                # NOTE(review): presumably wrapped in a coroutine so gather()
                # is called while the loop is running.
                return await asyncio.gather(*tasks)

            with self.tcp_server(server,
                                 max_clients=TOTAL_CNT,
                                 backlog=TOTAL_CNT) as srv:
                tasks = []
                for _ in range(TOTAL_CNT):
                    tasks.append(coro(srv.addr))

                self.loop.run_until_complete(
                    _gather(*tasks))

            self.assertEqual(CNT, TOTAL_CNT)

        with self._silence_eof_received_warning():
            run(client)
    1246  
    def test_flush_before_shutdown(self):
        # Data queued while the SSL protocol's writing is paused must still
        # reach the peer when close() is called before writing resumes.  The
        # server thread checks that it receives all CHUNK * SIZE bytes.
        CHUNK = 1024 * 128
        SIZE = 32

        sslctx = self._create_server_ssl_context(
            test_utils.ONLYCERT, test_utils.ONLYKEY)
        client_sslctx = self._create_client_ssl_context()

        # Created by client(); completed from the server thread.
        future = None

        def report_outcome(body):
            # Wrap a server-thread routine so that its outcome (success or
            # the raised exception) is delivered to ``future`` on the loop.
            def server_entry(sock):
                try:
                    body(sock)
                except Exception as exc:
                    self.loop.call_soon_threadsafe(future.set_exception, exc)
                else:
                    self.loop.call_soon_threadsafe(future.set_result, None)
            return server_entry

        def server(sock):
            sock.starttls(sslctx, server_side=True)
            self.assertEqual(sock.recv_all(4), b'ping')
            sock.send(b'pong')
            # Pause so the client's writes can back up (best effort).
            time.sleep(0.5)
            received = sock.recv_all(CHUNK * SIZE)
            self.assertEqual(len(received), CHUNK * SIZE)
            sock.close()

        async def client(addr):
            nonlocal future
            future = self.loop.create_future()
            reader, writer = await asyncio.open_connection(
                *addr,
                ssl=client_sslctx,
                server_hostname='')
            ssl_proto = writer.transport._ssl_protocol

            writer.write(b'ping')
            self.assertEqual(await reader.readexactly(4), b'pong')

            # Queue everything while paused, close, and only then resume.
            ssl_proto.pause_writing()
            payload = b'x' * CHUNK
            for _ in range(SIZE):
                writer.write(payload)
            writer.close()
            ssl_proto.resume_writing()

            await self.wait_closed(writer)
            try:
                self.assertEqual(await reader.read(), b'')
            except ConnectionResetError:
                pass
            await future

        with self.tcp_server(report_outcome(server)) as srv:
            self.loop.run_until_complete(client(srv.addr))
    1305  
    def test_remote_shutdown_receives_trailing_data(self):
        # The peer shuts down first, either by sending TLS close_notify
        # (server) or a TCP half-close (eof_server).  The client must still
        # deliver its whole pending write backlog of CHUNK * SIZE bytes.
        CHUNK = 1024 * 128
        SIZE = 32

        sslctx = self._create_server_ssl_context(
            test_utils.ONLYCERT,
            test_utils.ONLYKEY
        )
        client_sslctx = self._create_client_ssl_context()
        # Created by client(); completed from the server thread by run().
        future = None

        def server(sock):
            # TLS is driven by hand through MemoryBIOs so that close_notify
            # can be sent without waiting for the peer's reply.
            incoming = ssl.MemoryBIO()
            outgoing = ssl.MemoryBIO()
            sslobj = sslctx.wrap_bio(incoming, outgoing, server_side=True)

            def feed_incoming():
                # Move ciphertext from the socket into the BIO.  A closed
                # connection (b'') is passed on as EOF, so the SSL object
                # raises instead of the loops below spinning forever.
                chunk = sock.recv(16384)
                if chunk:
                    incoming.write(chunk)
                else:
                    incoming.write_eof()

            while True:
                try:
                    sslobj.do_handshake()
                except ssl.SSLWantReadError:
                    if outgoing.pending:
                        # sendall(): a partial send() would drop TLS bytes.
                        sock.sendall(outgoing.read())
                    feed_incoming()
                else:
                    if outgoing.pending:
                        sock.sendall(outgoing.read())
                    break

            while True:
                try:
                    data = sslobj.read(4)
                except ssl.SSLWantReadError:
                    feed_incoming()
                else:
                    break

            self.assertEqual(data, b'ping')
            sslobj.write(b'pong')
            sock.sendall(outgoing.read())

            time.sleep(0.2)  # wait for the peer to fill its backlog

            # send close_notify but don't wait for response
            with self.assertRaises(ssl.SSLWantReadError):
                sslobj.unwrap()
            sock.sendall(outgoing.read())

            # should receive all data; the peer's close_notify ends the loop
            data_len = 0
            while True:
                try:
                    chunk = len(sslobj.read(16384))
                    data_len += chunk
                except ssl.SSLWantReadError:
                    feed_incoming()
                except ssl.SSLZeroReturnError:
                    break

            self.assertEqual(data_len, CHUNK * SIZE)

            # verify that close_notify is received
            sslobj.unwrap()

            sock.close()

        def eof_server(sock):
            # Same exchange over a regular TLS socket, but the shutdown is a
            # plain TCP half-close instead of close_notify.
            sock.starttls(sslctx, server_side=True)
            self.assertEqual(sock.recv_all(4), b'ping')
            sock.send(b'pong')

            time.sleep(0.2)  # wait for the peer to fill its backlog

            # send EOF
            sock.shutdown(socket.SHUT_WR)

            # should receive all data
            data = sock.recv_all(CHUNK * SIZE)
            self.assertEqual(len(data), CHUNK * SIZE)

            sock.close()

        async def client(addr):
            nonlocal future
            future = self.loop.create_future()

            reader, writer = await asyncio.open_connection(
                *addr,
                ssl=client_sslctx,
                server_hostname='')
            writer.write(b'ping')
            data = await reader.readexactly(4)
            self.assertEqual(data, b'pong')

            # fill write backlog in a hacky way - renegotiation won't help
            for _ in range(SIZE):
                writer.transport._test__append_write_backlog(b'x' * CHUNK)

            try:
                data = await reader.read()
                self.assertEqual(data, b'')
            except (BrokenPipeError, ConnectionResetError):
                pass

            await future

            writer.close()
            await self.wait_closed(writer)

        def run(meth):
            # Report the server routine's outcome to ``future`` on the loop.
            def wrapper(sock):
                try:
                    meth(sock)
                except Exception as ex:
                    self.loop.call_soon_threadsafe(future.set_exception, ex)
                else:
                    self.loop.call_soon_threadsafe(future.set_result, None)
            return wrapper

        with self.tcp_server(run(server)) as srv:
            self.loop.run_until_complete(client(srv.addr))

        with self.tcp_server(run(eof_server)) as srv:
            self.loop.run_until_complete(client(srv.addr))
    1429  
    def test_connect_timeout_warning(self):
        # A failed or timed-out SSL connection attempt must not emit any
        # ResourceWarning, including after explicit garbage collection.
        import warnings

        s = socket.socket(socket.AF_INET)
        # Bound but never listening, so the connection attempt is expected
        # to fail (refused) or time out.
        s.bind(('127.0.0.1', 0))
        addr = s.getsockname()

        async def test():
            try:
                await asyncio.wait_for(
                    self.loop.create_connection(asyncio.Protocol,
                                                *addr, ssl=True),
                    0.1)
            except (ConnectionRefusedError, asyncio.TimeoutError):
                pass
            else:
                self.fail('TimeoutError is not raised')

        with s:
            # Record ResourceWarnings directly rather than inverting
            # assertWarns(), so that failures raised by test() propagate
            # unchanged instead of being compared against a message string.
            with warnings.catch_warnings(record=True) as caught:
                warnings.simplefilter('always', ResourceWarning)
                self.loop.run_until_complete(test())
                # Collect while warnings are still being recorded, so leaked
                # objects that warn when finalized are caught.
                gc.collect()
                gc.collect()
                gc.collect()

        leaked = [w for w in caught
                  if issubclass(w.category, ResourceWarning)]
        if leaked:
            self.fail('Unexpected ResourceWarning: {}'.format(
                leaked[0].message))
    1457  
    def test_handshake_timeout_handler_leak(self):
        # A connection attempt that times out during the TLS handshake must
        # not leave anything (such as the SSLProtocol) referencing the client
        # SSLContext once the attempt is over.
        listener = socket.socket(socket.AF_INET)
        listener.bind(('127.0.0.1', 0))
        # The socket listens but accept() is never called and no TLS reply is
        # ever sent, so the handshake below cannot finish.
        listener.listen(1)
        addr = listener.getsockname()

        async def test(ctx):
            try:
                await asyncio.wait_for(
                    self.loop.create_connection(asyncio.Protocol, *addr,
                                                ssl=ctx),
                    0.1)
            except (ConnectionRefusedError, asyncio.TimeoutError):
                return
            self.fail('TimeoutError is not raised')

        with listener:
            ctx = ssl.create_default_context()
            self.loop.run_until_complete(test(ctx))
            # Keep only a weak reference from here on.
            ctx = weakref.ref(ctx)

        # SSLProtocol should be DECREF to 0
        self.assertIsNone(ctx())
    1482  
    def test_shutdown_timeout_handler_leak(self):
        # After a client closes its SSL transport and the connection is lost,
        # the client SSLContext must become unreachable.  The check runs
        # after explicit garbage collection (see the note below).
        loop = self.loop

        def server(sock):
            # Server thread: TLS handshake, one read, then close.
            sslctx = self._create_server_ssl_context(
                test_utils.ONLYCERT,
                test_utils.ONLYKEY
            )
            sock = sslctx.wrap_socket(sock, server_side=True)
            sock.recv(32)
            sock.close()

        class Protocol(asyncio.Protocol):
            def __init__(self):
                # Resolved when the connection is lost.
                self.fut = asyncio.Future(loop=loop)

            def connection_lost(self, exc):
                self.fut.set_result(None)

        async def client(addr, ctx):
            tr, pr = await loop.create_connection(Protocol, *addr, ssl=ctx)
            tr.close()
            await pr.fut

        with self.tcp_server(server) as srv:
            ctx = self._create_client_ssl_context()
            loop.run_until_complete(client(srv.addr, ctx))
            ctx = weakref.ref(ctx)

        # asyncio has no shutdown timeout, but it ends up with a circular
        # reference loop - not ideal (introduces gc glitches), but at least
        # not leaking
        gc.collect()
        gc.collect()
        gc.collect()

        # SSLProtocol should be DECREF to 0
        self.assertIsNone(ctx())
    1521  
    def test_shutdown_timeout_handler_not_set(self):
        """Data held in the SSL object while reading is paused must still
        reach the protocol after resume_reading(), even though the server
        has already shut down its write side in the meantime.
        """
        loop = self.loop
        eof = asyncio.Event()
        extra = None

        def server(sock):
            sslctx = self._create_server_ssl_context(
                test_utils.ONLYCERT,
                test_utils.ONLYKEY
            )
            sock = sslctx.wrap_socket(sock, server_side=True)
            sock.send(b'hello')
            # Keep recv() out of assert statements: under ``python -O`` an
            # assert is removed together with the side effects of its
            # expression, which would change the exchange being tested.
            data = sock.recv(1024)
            self.assertEqual(data, b'world')
            sock.send(b'extra bytes')
            # sending EOF here
            sock.shutdown(socket.SHUT_WR)
            loop.call_soon_threadsafe(eof.set)
            # make sure we have enough time to reproduce the issue
            data = sock.recv(1024)
            self.assertEqual(data, b'')
            sock.close()

        class Protocol(asyncio.Protocol):
            def __init__(self):
                self.fut = asyncio.Future(loop=loop)
                self.transport = None

            def connection_made(self, transport):
                self.transport = transport

            def data_received(self, data):
                nonlocal extra
                if data == b'hello':
                    self.transport.write(b'world')
                    # pause reading would make incoming data stay in the sslobj
                    self.transport.pause_reading()
                else:
                    extra = data

            def connection_lost(self, exc):
                if exc is None:
                    self.fut.set_result(None)
                else:
                    self.fut.set_exception(exc)

        async def client(addr):
            ctx = self._create_client_ssl_context()
            tr, pr = await loop.create_connection(Protocol, *addr, ssl=ctx)
            # Stay paused until the server has sent the extra bytes and EOF.
            await eof.wait()
            tr.resume_reading()
            await pr.fut
            tr.close()
            self.assertEqual(extra, b'extra bytes')

        with self.tcp_server(server) as srv:
            loop.run_until_complete(client(srv.addr))
    1577  
    1578  
    1579  ###############################################################################
    1580  # Socket Testing Utilities
    1581  ###############################################################################
    1582  
    1583  
    1584  class ESC[4;38;5;81mTestSocketWrapper:
    1585  
    1586      def __init__(self, sock):
    1587          self.__sock = sock
    1588  
    1589      def recv_all(self, n):
    1590          buf = b''
    1591          while len(buf) < n:
    1592              data = self.recv(n - len(buf))
    1593              if data == b'':
    1594                  raise ConnectionAbortedError
    1595              buf += data
    1596          return buf
    1597  
    1598      def starttls(self, ssl_context, *,
    1599                   server_side=False,
    1600                   server_hostname=None,
    1601                   do_handshake_on_connect=True):
    1602  
    1603          assert isinstance(ssl_context, ssl.SSLContext)
    1604  
    1605          ssl_sock = ssl_context.wrap_socket(
    1606              self.__sock, server_side=server_side,
    1607              server_hostname=server_hostname,
    1608              do_handshake_on_connect=do_handshake_on_connect)
    1609  
    1610          if server_side:
    1611              ssl_sock.do_handshake()
    1612  
    1613          self.__sock.close()
    1614          self.__sock = ssl_sock
    1615  
    1616      def __getattr__(self, name):
    1617          return getattr(self.__sock, name)
    1618  
    1619      def __repr__(self):
    1620          return '<{} {!r}>'.format(type(self).__name__, self.__sock)
    1621  
    1622  
    1623  class ESC[4;38;5;81mSocketThread(ESC[4;38;5;149mthreadingESC[4;38;5;149m.ESC[4;38;5;149mThread):
    1624  
    1625      def stop(self):
    1626          self._active = False
    1627          self.join()
    1628  
    1629      def __enter__(self):
    1630          self.start()
    1631          return self
    1632  
    1633      def __exit__(self, *exc):
    1634          self.stop()
    1635  
    1636  
    1637  class ESC[4;38;5;81mTestThreadedClient(ESC[4;38;5;149mSocketThread):
    1638  
    1639      def __init__(self, test, sock, prog, timeout):
    1640          threading.Thread.__init__(self, None, None, 'test-client')
    1641          self.daemon = True
    1642  
    1643          self._timeout = timeout
    1644          self._sock = sock
    1645          self._active = True
    1646          self._prog = prog
    1647          self._test = test
    1648  
    1649      def run(self):
    1650          try:
    1651              self._prog(TestSocketWrapper(self._sock))
    1652          except (KeyboardInterrupt, SystemExit):
    1653              raise
    1654          except BaseException as ex:
    1655              self._test._abort_socket_test(ex)
    1656  
    1657  
    1658  class ESC[4;38;5;81mTestThreadedServer(ESC[4;38;5;149mSocketThread):
    1659  
    1660      def __init__(self, test, sock, prog, timeout, max_clients):
    1661          threading.Thread.__init__(self, None, None, 'test-server')
    1662          self.daemon = True
    1663  
    1664          self._clients = 0
    1665          self._finished_clients = 0
    1666          self._max_clients = max_clients
    1667          self._timeout = timeout
    1668          self._sock = sock
    1669          self._active = True
    1670  
    1671          self._prog = prog
    1672  
    1673          self._s1, self._s2 = socket.socketpair()
    1674          self._s1.setblocking(False)
    1675  
    1676          self._test = test
    1677  
    1678      def stop(self):
    1679          try:
    1680              if self._s2 and self._s2.fileno() != -1:
    1681                  try:
    1682                      self._s2.send(b'stop')
    1683                  except OSError:
    1684                      pass
    1685          finally:
    1686              super().stop()
    1687  
    1688      def run(self):
    1689          try:
    1690              with self._sock:
    1691                  self._sock.setblocking(False)
    1692                  self._run()
    1693          finally:
    1694              self._s1.close()
    1695              self._s2.close()
    1696  
    1697      def _run(self):
    1698          while self._active:
    1699              if self._clients >= self._max_clients:
    1700                  return
    1701  
    1702              r, w, x = select.select(
    1703                  [self._sock, self._s1], [], [], self._timeout)
    1704  
    1705              if self._s1 in r:
    1706                  return
    1707  
    1708              if self._sock in r:
    1709                  try:
    1710                      conn, addr = self._sock.accept()
    1711                  except BlockingIOError:
    1712                      continue
    1713                  except socket.timeout:
    1714                      if not self._active:
    1715                          return
    1716                      else:
    1717                          raise
    1718                  else:
    1719                      self._clients += 1
    1720                      conn.settimeout(self._timeout)
    1721                      try:
    1722                          with conn:
    1723                              self._handle_client(conn)
    1724                      except (KeyboardInterrupt, SystemExit):
    1725                          raise
    1726                      except BaseException as ex:
    1727                          self._active = False
    1728                          try:
    1729                              raise
    1730                          finally:
    1731                              self._test._abort_socket_test(ex)
    1732  
    1733      def _handle_client(self, sock):
    1734          self._prog(TestSocketWrapper(sock))
    1735  
    1736      @property
    1737      def addr(self):
    1738          return self._sock.getsockname()