(root)/
Python-3.11.7/
Lib/
test/
test_asyncio/
test_sslproto.py
       1  """Tests for asyncio/sslproto.py."""
       2  
       3  import logging
       4  import socket
       5  import unittest
       6  import weakref
       7  from test import support
       8  from test.support import socket_helper
       9  from unittest import mock
      10  try:
      11      import ssl
      12  except ImportError:
      13      ssl = None
      14  
      15  import asyncio
      16  from asyncio import log
      17  from asyncio import protocols
      18  from asyncio import sslproto
      19  from test.test_asyncio import utils as test_utils
      20  from test.test_asyncio import functional as func_tests
      21  
      22  
      23  def tearDownModule():
      24      asyncio.set_event_loop_policy(None)
      25  
      26  
      27  @unittest.skipIf(ssl is None, 'No ssl module')
      28  class ESC[4;38;5;81mSslProtoHandshakeTests(ESC[4;38;5;149mtest_utilsESC[4;38;5;149m.ESC[4;38;5;149mTestCase):
      29  
      30      def setUp(self):
      31          super().setUp()
      32          self.loop = asyncio.new_event_loop()
      33          self.set_event_loop(self.loop)
      34  
      35      def ssl_protocol(self, *, waiter=None, proto=None):
      36          sslcontext = test_utils.dummy_ssl_context()
      37          if proto is None:  # app protocol
      38              proto = asyncio.Protocol()
      39          ssl_proto = sslproto.SSLProtocol(self.loop, proto, sslcontext, waiter,
      40                                           ssl_handshake_timeout=0.1)
      41          self.assertIs(ssl_proto._app_transport.get_protocol(), proto)
      42          self.addCleanup(ssl_proto._app_transport.close)
      43          return ssl_proto
      44  
      45      def connection_made(self, ssl_proto, *, do_handshake=None):
      46          transport = mock.Mock()
      47          sslobj = mock.Mock()
      48          # emulate reading decompressed data
      49          sslobj.read.side_effect = ssl.SSLWantReadError
      50          if do_handshake is not None:
      51              sslobj.do_handshake = do_handshake
      52          ssl_proto._sslobj = sslobj
      53          ssl_proto.connection_made(transport)
      54          return transport
      55  
      56      def test_handshake_timeout_zero(self):
      57          sslcontext = test_utils.dummy_ssl_context()
      58          app_proto = mock.Mock()
      59          waiter = mock.Mock()
      60          with self.assertRaisesRegex(ValueError, 'a positive number'):
      61              sslproto.SSLProtocol(self.loop, app_proto, sslcontext, waiter,
      62                                   ssl_handshake_timeout=0)
      63  
      64      def test_handshake_timeout_negative(self):
      65          sslcontext = test_utils.dummy_ssl_context()
      66          app_proto = mock.Mock()
      67          waiter = mock.Mock()
      68          with self.assertRaisesRegex(ValueError, 'a positive number'):
      69              sslproto.SSLProtocol(self.loop, app_proto, sslcontext, waiter,
      70                                   ssl_handshake_timeout=-10)
      71  
      72      def test_eof_received_waiter(self):
      73          waiter = self.loop.create_future()
      74          ssl_proto = self.ssl_protocol(waiter=waiter)
      75          self.connection_made(
      76              ssl_proto,
      77              do_handshake=mock.Mock(side_effect=ssl.SSLWantReadError)
      78          )
      79          ssl_proto.eof_received()
      80          test_utils.run_briefly(self.loop)
      81          self.assertIsInstance(waiter.exception(), ConnectionResetError)
      82  
      83      def test_fatal_error_no_name_error(self):
      84          # From issue #363.
      85          # _fatal_error() generates a NameError if sslproto.py
      86          # does not import base_events.
      87          waiter = self.loop.create_future()
      88          ssl_proto = self.ssl_protocol(waiter=waiter)
      89          # Temporarily turn off error logging so as not to spoil test output.
      90          log_level = log.logger.getEffectiveLevel()
      91          log.logger.setLevel(logging.FATAL)
      92          try:
      93              ssl_proto._fatal_error(None)
      94          finally:
      95              # Restore error logging.
      96              log.logger.setLevel(log_level)
      97  
      98      def test_connection_lost(self):
      99          # From issue #472.
     100          # yield from waiter hang if lost_connection was called.
     101          waiter = self.loop.create_future()
     102          ssl_proto = self.ssl_protocol(waiter=waiter)
     103          self.connection_made(
     104              ssl_proto,
     105              do_handshake=mock.Mock(side_effect=ssl.SSLWantReadError)
     106          )
     107          ssl_proto.connection_lost(ConnectionAbortedError)
     108          test_utils.run_briefly(self.loop)
     109          self.assertIsInstance(waiter.exception(), ConnectionAbortedError)
     110  
     111      def test_close_during_handshake(self):
     112          # bpo-29743 Closing transport during handshake process leaks socket
     113          waiter = self.loop.create_future()
     114          ssl_proto = self.ssl_protocol(waiter=waiter)
     115  
     116          transport = self.connection_made(
     117              ssl_proto,
     118              do_handshake=mock.Mock(side_effect=ssl.SSLWantReadError)
     119          )
     120          test_utils.run_briefly(self.loop)
     121  
     122          ssl_proto._app_transport.close()
     123          self.assertTrue(transport.abort.called)
     124  
     125      def test_get_extra_info_on_closed_connection(self):
     126          waiter = self.loop.create_future()
     127          ssl_proto = self.ssl_protocol(waiter=waiter)
     128          self.assertIsNone(ssl_proto._get_extra_info('socket'))
     129          default = object()
     130          self.assertIs(ssl_proto._get_extra_info('socket', default), default)
     131          self.connection_made(ssl_proto)
     132          self.assertIsNotNone(ssl_proto._get_extra_info('socket'))
     133          ssl_proto.connection_lost(None)
     134          self.assertIsNone(ssl_proto._get_extra_info('socket'))
     135  
     136      def test_set_new_app_protocol(self):
     137          waiter = self.loop.create_future()
     138          ssl_proto = self.ssl_protocol(waiter=waiter)
     139          new_app_proto = asyncio.Protocol()
     140          ssl_proto._app_transport.set_protocol(new_app_proto)
     141          self.assertIs(ssl_proto._app_transport.get_protocol(), new_app_proto)
     142          self.assertIs(ssl_proto._app_protocol, new_app_proto)
     143  
     144      def test_data_received_after_closing(self):
     145          ssl_proto = self.ssl_protocol()
     146          self.connection_made(ssl_proto)
     147          transp = ssl_proto._app_transport
     148  
     149          transp.close()
     150  
     151          # should not raise
     152          self.assertIsNone(ssl_proto.buffer_updated(5))
     153  
     154      def test_write_after_closing(self):
     155          ssl_proto = self.ssl_protocol()
     156          self.connection_made(ssl_proto)
     157          transp = ssl_proto._app_transport
     158          transp.close()
     159  
     160          # should not raise
     161          self.assertIsNone(transp.write(b'data'))
     162  
     163  
     164  ##############################################################################
     165  # Start TLS Tests
     166  ##############################################################################
     167  
     168  
class BaseStartTLS(func_tests.FunctionalTestCaseMixin):
    """Functional TLS tests shared by the per-event-loop test classes.

    Covers loop.start_tls() on the client and server side, create_connection()
    and open_connection() with ssl=..., SSL handshake timeouts, certificate
    verification failure and a corrupted TLS stream.  This class is not a
    TestCase on its own: concrete subclasses mix it with unittest.TestCase
    and implement new_loop() to choose the event loop under test.
    """

    # Size in bytes (100 KiB) of the plaintext payload the tests exchange.
    PAYLOAD_SIZE = 1024 * 100
    # Timeout used for blocking socket operations, the helper TCP
    # server/client and several wait_for()/handshake timeouts below.
    TIMEOUT = support.LONG_TIMEOUT

    def new_loop(self):
        # Subclasses return the event loop to test.  NOTE(review): presumably
        # called by FunctionalTestCaseMixin during setUp -- confirm there.
        raise NotImplementedError

    def test_buf_feed_data(self):
        # protocols._feed_data_to_buffered_proto() must deliver every byte
        # of the input whether the protocol's buffer is smaller than, equal
        # to or larger than the data (a 1-byte buffer needs several
        # get_buffer()/buffer_updated() rounds), for both bytearray and
        # memoryview buffers, and must reject a zero-length buffer.

        class Proto(asyncio.BufferedProtocol):
            # Minimal buffered protocol that accumulates every received
            # chunk into self.data.

            def __init__(self, bufsize, usemv):
                self.buf = bytearray(bufsize)
                self.mv = memoryview(self.buf)
                self.data = b''
                self.usemv = usemv

            def get_buffer(self, sizehint):
                # sizehint is ignored; the same fixed-size buffer is
                # always handed out, as a memoryview or as the bytearray.
                if self.usemv:
                    return self.mv
                else:
                    return self.buf

            def buffer_updated(self, nsize):
                if self.usemv:
                    self.data += self.mv[:nsize]
                else:
                    self.data += self.buf[:nsize]

        for usemv in [False, True]:
            proto = Proto(1, usemv)
            protocols._feed_data_to_buffered_proto(proto, b'12345')
            self.assertEqual(proto.data, b'12345')

            proto = Proto(2, usemv)
            protocols._feed_data_to_buffered_proto(proto, b'12345')
            self.assertEqual(proto.data, b'12345')

            proto = Proto(2, usemv)
            protocols._feed_data_to_buffered_proto(proto, b'1234')
            self.assertEqual(proto.data, b'1234')

            proto = Proto(4, usemv)
            protocols._feed_data_to_buffered_proto(proto, b'1234')
            self.assertEqual(proto.data, b'1234')

            proto = Proto(100, usemv)
            protocols._feed_data_to_buffered_proto(proto, b'12345')
            self.assertEqual(proto.data, b'12345')

            proto = Proto(0, usemv)
            with self.assertRaisesRegex(RuntimeError, 'empty buffer'):
                protocols._feed_data_to_buffered_proto(proto, b'12345')

    def test_start_tls_client_reg_proto_1(self):
        # Client sends a plaintext payload, upgrades the connection with
        # loop.start_tls() while keeping a regular Protocol, then exchanges
        # data over TLS.  connection_made() must fire exactly once, and the
        # client SSL context must be collectable afterwards.
        HELLO_MSG = b'1' * self.PAYLOAD_SIZE

        server_context = test_utils.simple_server_sslcontext()
        client_context = test_utils.simple_client_sslcontext()

        def serve(sock):
            # Server side, driven through the socket wrapper handed in by
            # self.tcp_server(): read the plaintext hello, switch to TLS,
            # send b'O', read the hello again over TLS, then shut down.
            sock.settimeout(self.TIMEOUT)

            data = sock.recv_all(len(HELLO_MSG))
            self.assertEqual(len(data), len(HELLO_MSG))

            sock.start_tls(server_context, server_side=True)

            sock.sendall(b'O')
            data = sock.recv_all(len(HELLO_MSG))
            self.assertEqual(len(data), len(HELLO_MSG))

            sock.shutdown(socket.SHUT_RDWR)
            sock.close()

        class ClientProto(asyncio.Protocol):
            def __init__(self, on_data, on_eof):
                self.on_data = on_data
                self.on_eof = on_eof
                self.con_made_cnt = 0

            # The first parameter is named `proto` so that `self` still
            # refers to the enclosing TestCase for assertions.
            def connection_made(proto, tr):
                proto.con_made_cnt += 1
                # Ensure connection_made gets called only once.
                self.assertEqual(proto.con_made_cnt, 1)

            def data_received(self, data):
                self.on_data.set_result(data)

            def eof_received(self):
                self.on_eof.set_result(True)

        async def client(addr):
            # NOTE(review): the initial sleep presumably gives the server
            # side time to get ready -- confirm.
            await asyncio.sleep(0.5)

            on_data = self.loop.create_future()
            on_eof = self.loop.create_future()

            tr, proto = await self.loop.create_connection(
                lambda: ClientProto(on_data, on_eof), *addr)

            tr.write(HELLO_MSG)
            new_tr = await self.loop.start_tls(tr, proto, client_context)

            self.assertEqual(await on_data, b'O')
            new_tr.write(HELLO_MSG)
            await on_eof

            new_tr.close()

        with self.tcp_server(serve, timeout=self.TIMEOUT) as srv:
            self.loop.run_until_complete(
                asyncio.wait_for(client(srv.addr),
                                 timeout=support.SHORT_TIMEOUT))

        # No garbage is left if SSL is closed uncleanly
        # Rebinding the name (also closed over by client()) drops the test's
        # own strong reference, so the context must now be unreachable.
        client_context = weakref.ref(client_context)
        support.gc_collect()
        self.assertIsNone(client_context())

    def test_create_connection_memory_leak(self):
        # Same exchange as above, but TLS is set up directly by
        # create_connection(ssl=...) and the protocol keeps a reference to
        # its transport; the client SSL context must still be collectable.
        HELLO_MSG = b'1' * self.PAYLOAD_SIZE

        server_context = test_utils.simple_server_sslcontext()
        client_context = test_utils.simple_client_sslcontext()

        def serve(sock):
            # Server side: TLS from the start, send b'O', read the hello,
            # then shut down.
            sock.settimeout(self.TIMEOUT)

            sock.start_tls(server_context, server_side=True)

            sock.sendall(b'O')
            data = sock.recv_all(len(HELLO_MSG))
            self.assertEqual(len(data), len(HELLO_MSG))

            sock.shutdown(socket.SHUT_RDWR)
            sock.close()

        class ClientProto(asyncio.Protocol):
            def __init__(self, on_data, on_eof):
                self.on_data = on_data
                self.on_eof = on_eof
                self.con_made_cnt = 0

            def connection_made(proto, tr):
                # XXX: We assume user stores the transport in protocol
                proto.tr = tr
                proto.con_made_cnt += 1
                # Ensure connection_made gets called only once.
                self.assertEqual(proto.con_made_cnt, 1)

            def data_received(self, data):
                self.on_data.set_result(data)

            def eof_received(self):
                self.on_eof.set_result(True)

        async def client(addr):
            await asyncio.sleep(0.5)

            on_data = self.loop.create_future()
            on_eof = self.loop.create_future()

            tr, proto = await self.loop.create_connection(
                lambda: ClientProto(on_data, on_eof), *addr,
                ssl=client_context)

            self.assertEqual(await on_data, b'O')
            tr.write(HELLO_MSG)
            await on_eof

            tr.close()

        with self.tcp_server(serve, timeout=self.TIMEOUT) as srv:
            self.loop.run_until_complete(
                asyncio.wait_for(client(srv.addr),
                                 timeout=support.SHORT_TIMEOUT))

        # No garbage is left for SSL client from loop.create_connection, even
        # if user stores the SSLTransport in corresponding protocol instance
        client_context = weakref.ref(client_context)
        support.gc_collect()
        self.assertIsNone(client_context())

    @socket_helper.skip_if_tcp_blackhole
    def test_start_tls_client_buf_proto_1(self):
        # start_tls() with a BufferedProtocol, then switching the TLS
        # transport to a regular Protocol via set_protocol().  Neither
        # start_tls() nor set_protocol() may call connection_made() again.
        HELLO_MSG = b'1' * self.PAYLOAD_SIZE

        server_context = test_utils.simple_server_sslcontext()
        client_context = test_utils.simple_client_sslcontext()
        # Total connection_made() calls across both client protocols.
        client_con_made_calls = 0

        def serve(sock):
            # Plaintext hello, switch to TLS, then two rounds of
            # "server sends one byte, client sends the hello".
            sock.settimeout(self.TIMEOUT)

            data = sock.recv_all(len(HELLO_MSG))
            self.assertEqual(len(data), len(HELLO_MSG))

            sock.start_tls(server_context, server_side=True)

            sock.sendall(b'O')
            data = sock.recv_all(len(HELLO_MSG))
            self.assertEqual(len(data), len(HELLO_MSG))

            sock.sendall(b'2')
            data = sock.recv_all(len(HELLO_MSG))
            self.assertEqual(len(data), len(HELLO_MSG))

            sock.shutdown(socket.SHUT_RDWR)
            sock.close()

        class ClientProtoFirst(asyncio.BufferedProtocol):
            # Buffered protocol with a 1-byte buffer; it expects to receive
            # exactly the single byte b'O'.
            def __init__(self, on_data):
                self.on_data = on_data
                self.buf = bytearray(1)

            def connection_made(self, tr):
                nonlocal client_con_made_calls
                client_con_made_calls += 1

            def get_buffer(self, sizehint):
                return self.buf

            # `slf` is the protocol; `self` is the enclosing TestCase.
            def buffer_updated(slf, nsize):
                self.assertEqual(nsize, 1)
                slf.on_data.set_result(bytes(slf.buf[:nsize]))

        class ClientProtoSecond(asyncio.Protocol):
            # Regular protocol installed with set_protocol() after the
            # first TLS exchange.
            def __init__(self, on_data, on_eof):
                self.on_data = on_data
                self.on_eof = on_eof
                # Not read anywhere in this test; the call count is tracked
                # through client_con_made_calls instead.
                self.con_made_cnt = 0

            def connection_made(self, tr):
                nonlocal client_con_made_calls
                client_con_made_calls += 1

            def data_received(self, data):
                self.on_data.set_result(data)

            def eof_received(self):
                self.on_eof.set_result(True)

        async def client(addr):
            await asyncio.sleep(0.5)

            on_data1 = self.loop.create_future()
            on_data2 = self.loop.create_future()
            on_eof = self.loop.create_future()

            tr, proto = await self.loop.create_connection(
                lambda: ClientProtoFirst(on_data1), *addr)

            tr.write(HELLO_MSG)
            new_tr = await self.loop.start_tls(tr, proto, client_context)

            self.assertEqual(await on_data1, b'O')
            new_tr.write(HELLO_MSG)

            new_tr.set_protocol(ClientProtoSecond(on_data2, on_eof))
            self.assertEqual(await on_data2, b'2')
            new_tr.write(HELLO_MSG)
            await on_eof

            new_tr.close()

            # connection_made() should be called only once -- when
            # we establish connection for the first time. Start TLS
            # doesn't call connection_made() on application protocols.
            self.assertEqual(client_con_made_calls, 1)

        with self.tcp_server(serve, timeout=self.TIMEOUT) as srv:
            self.loop.run_until_complete(
                asyncio.wait_for(client(srv.addr),
                                 timeout=self.TIMEOUT))

    def test_start_tls_slow_client_cancel(self):
        # The server reads the plaintext hello but never performs a TLS
        # handshake, so start_tls() on the client cannot complete; wrapping
        # it in a 0.5 s wait_for() must raise TimeoutError.
        HELLO_MSG = b'1' * self.PAYLOAD_SIZE

        client_context = test_utils.simple_client_sslcontext()
        # Resolved by the server (via call_soon_threadsafe) once it has
        # read the hello and starts blocking.
        server_waits_on_handshake = self.loop.create_future()

        def serve(sock):
            sock.settimeout(self.TIMEOUT)

            data = sock.recv_all(len(HELLO_MSG))
            self.assertEqual(len(data), len(HELLO_MSG))

            try:
                self.loop.call_soon_threadsafe(
                    server_waits_on_handshake.set_result, None)
                # Block until the client goes away.  NOTE(review):
                # recv_all() presumably raises ConnectionAbortedError on
                # EOF -- confirm in func_tests.
                data = sock.recv_all(1024 * 1024)
            except ConnectionAbortedError:
                pass
            finally:
                sock.close()

        class ClientProto(asyncio.Protocol):
            def __init__(self, on_data, on_eof):
                self.on_data = on_data
                self.on_eof = on_eof
                self.con_made_cnt = 0

            def connection_made(proto, tr):
                proto.con_made_cnt += 1
                # Ensure connection_made gets called only once.
                self.assertEqual(proto.con_made_cnt, 1)

            def data_received(self, data):
                self.on_data.set_result(data)

            def eof_received(self):
                self.on_eof.set_result(True)

        async def client(addr):
            await asyncio.sleep(0.5)

            on_data = self.loop.create_future()
            on_eof = self.loop.create_future()

            tr, proto = await self.loop.create_connection(
                lambda: ClientProto(on_data, on_eof), *addr)

            tr.write(HELLO_MSG)

            await server_waits_on_handshake

            with self.assertRaises(asyncio.TimeoutError):
                await asyncio.wait_for(
                    self.loop.start_tls(tr, proto, client_context),
                    0.5)

        with self.tcp_server(serve, timeout=self.TIMEOUT) as srv:
            self.loop.run_until_complete(
                asyncio.wait_for(client(srv.addr),
                                 timeout=support.SHORT_TIMEOUT))

    @socket_helper.skip_if_tcp_blackhole
    def test_start_tls_server_1(self):
        # Server-side start_tls(): the asyncio server sends a plaintext
        # hello, upgrades to TLS, receives the hello over TLS and answers.
        # The peer is a blocking client socket run through self.tcp_client().
        HELLO_MSG = b'1' * self.PAYLOAD_SIZE
        ANSWER = b'answer'

        server_context = test_utils.simple_server_sslcontext()
        client_context = test_utils.simple_client_sslcontext()
        # Filled in by client(); checked after the server is closed.
        answer = None

        def client(sock, addr):
            nonlocal answer
            sock.settimeout(self.TIMEOUT)

            sock.connect(addr)
            data = sock.recv_all(len(HELLO_MSG))
            self.assertEqual(len(data), len(HELLO_MSG))

            sock.start_tls(client_context)
            sock.sendall(HELLO_MSG)
            answer = sock.recv_all(len(ANSWER))
            sock.close()

        class ServerProto(asyncio.Protocol):
            def __init__(self, on_con, on_con_lost, on_got_hello):
                self.on_con = on_con
                self.on_con_lost = on_con_lost
                self.on_got_hello = on_got_hello
                self.data = b''
                self.transport = None

            def connection_made(self, tr):
                self.transport = tr
                self.on_con.set_result(tr)

            def replace_transport(self, tr):
                # Called by main() with the transport returned by
                # start_tls().
                self.transport = tr

            def data_received(self, data):
                self.data += data
                if len(self.data) >= len(HELLO_MSG):
                    self.on_got_hello.set_result(None)

            def connection_lost(self, exc):
                self.transport = None
                if exc is None:
                    self.on_con_lost.set_result(None)
                else:
                    self.on_con_lost.set_exception(exc)

        async def main(proto, on_con, on_con_lost, on_got_hello):
            tr = await on_con
            tr.write(HELLO_MSG)

            # Nothing may have arrived before the TLS upgrade.
            self.assertEqual(proto.data, b'')

            new_tr = await self.loop.start_tls(
                tr, proto, server_context,
                server_side=True,
                ssl_handshake_timeout=self.TIMEOUT)
            proto.replace_transport(new_tr)

            await on_got_hello
            new_tr.write(ANSWER)

            await on_con_lost
            self.assertEqual(proto.data, HELLO_MSG)
            new_tr.close()

        async def run_main():
            on_con = self.loop.create_future()
            on_con_lost = self.loop.create_future()
            on_got_hello = self.loop.create_future()
            proto = ServerProto(on_con, on_con_lost, on_got_hello)

            server = await self.loop.create_server(
                lambda: proto, '127.0.0.1', 0)
            addr = server.sockets[0].getsockname()

            with self.tcp_client(lambda sock: client(sock, addr),
                                 timeout=self.TIMEOUT):
                await asyncio.wait_for(
                    main(proto, on_con, on_con_lost, on_got_hello),
                    timeout=self.TIMEOUT)

            server.close()
            await server.wait_closed()
            self.assertEqual(answer, ANSWER)

        self.loop.run_until_complete(run_main())

    def test_start_tls_wrong_args(self):
        # start_tls() rejects a non-SSLContext argument, and rejects a
        # transport it cannot upgrade (here None) with "is not supported".
        async def main():
            with self.assertRaisesRegex(TypeError, 'SSLContext, got'):
                await self.loop.start_tls(None, None, None)

            sslctx = test_utils.simple_server_sslcontext()
            with self.assertRaisesRegex(TypeError, 'is not supported'):
                await self.loop.start_tls(None, None, sslctx)

        self.loop.run_until_complete(main())

    def test_handshake_timeout(self):
        # bpo-29970: Check that a connection is aborted if handshake is not
        # completed in timeout period, instead of remaining open indefinitely
        client_sslctx = test_utils.simple_client_sslcontext()

        # Collect every context passed to the loop's exception handler.
        messages = []
        self.loop.set_exception_handler(lambda loop, ctx: messages.append(ctx))

        server_side_aborted = False

        def server(sock):
            # Never answers the handshake; just waits for the client to
            # drop the connection.
            nonlocal server_side_aborted
            try:
                sock.recv_all(1024 * 1024)
            except ConnectionAbortedError:
                server_side_aborted = True
            finally:
                sock.close()

        async def client(addr):
            # The outer 0.5 s wait_for() fires long before the
            # ssl_handshake_timeout, cancelling create_connection().
            await asyncio.wait_for(
                self.loop.create_connection(
                    asyncio.Protocol,
                    *addr,
                    ssl=client_sslctx,
                    server_hostname='',
                    ssl_handshake_timeout=support.SHORT_TIMEOUT),
                0.5)

        with self.tcp_server(server,
                             max_clients=1,
                             backlog=1) as srv:

            with self.assertRaises(asyncio.TimeoutError):
                self.loop.run_until_complete(client(srv.addr))

        self.assertTrue(server_side_aborted)

        # Python issue #23197: cancelling a handshake must not raise an
        # exception or log an error, even if the handshake failed
        self.assertEqual(messages, [])

        # The pending handshake-timeout timer (ssl_handshake_timeout, set to
        # support.SHORT_TIMEOUT above) should be cancelled to free related
        # objects without really waiting for it to expire
        client_sslctx = weakref.ref(client_sslctx)
        support.gc_collect()
        self.assertIsNone(client_sslctx())

    def test_create_connection_ssl_slow_handshake(self):
        # The server never completes the handshake, so the 1 s
        # ssl_handshake_timeout aborts open_connection() with
        # ConnectionAbortedError, without reporting anything to the
        # loop's exception handler.
        client_sslctx = test_utils.simple_client_sslcontext()

        messages = []
        self.loop.set_exception_handler(lambda loop, ctx: messages.append(ctx))

        def server(sock):
            try:
                sock.recv_all(1024 * 1024)
            except ConnectionAbortedError:
                pass
            finally:
                sock.close()

        async def client(addr):
            reader, writer = await asyncio.open_connection(
                *addr,
                ssl=client_sslctx,
                server_hostname='',
                ssl_handshake_timeout=1.0)

        with self.tcp_server(server,
                             max_clients=1,
                             backlog=1) as srv:

            with self.assertRaisesRegex(
                    ConnectionAbortedError,
                    r'SSL handshake.*is taking longer'):

                self.loop.run_until_complete(client(srv.addr))

        self.assertEqual(messages, [])

    def test_create_connection_ssl_failed_certificate(self):
        # With certificate verification enabled on the client
        # (disable_verify=False), the handshake against the test server's
        # certificate must fail with SSLCertVerificationError.
        # NOTE(review): presumably the test certificate is not trusted by
        # the client context -- confirm in test_utils.
        self.loop.set_exception_handler(lambda loop, ctx: None)

        sslctx = test_utils.simple_server_sslcontext()
        client_sslctx = test_utils.simple_client_sslcontext(
            disable_verify=False)

        def server(sock):
            # The server-side handshake is expected to fail as well; both
            # SSL and socket errors are ignored here.
            try:
                sock.start_tls(
                    sslctx,
                    server_side=True)
            except ssl.SSLError:
                pass
            except OSError:
                pass
            finally:
                sock.close()

        async def client(addr):
            reader, writer = await asyncio.open_connection(
                *addr,
                ssl=client_sslctx,
                server_hostname='',
                ssl_handshake_timeout=support.LOOPBACK_TIMEOUT)

        with self.tcp_server(server,
                             max_clients=1,
                             backlog=1) as srv:

            with self.assertRaises(ssl.SSLCertVerificationError):
                self.loop.run_until_complete(client(srv.addr))

    def test_start_tls_client_corrupted_ssl(self):
        # After a valid TLS exchange the server writes raw plaintext
        # through a duplicate of the underlying socket, bypassing TLS; the
        # client's next read must fail with ssl.SSLError.
        self.loop.set_exception_handler(lambda loop, ctx: None)

        sslctx = test_utils.simple_server_sslcontext()
        client_sslctx = test_utils.simple_client_sslcontext()

        def server(sock):
            # dup() taken before start_tls() so raw bytes can be written
            # underneath the TLS layer later.
            orig_sock = sock.dup()
            try:
                sock.start_tls(
                    sslctx,
                    server_side=True)
                sock.sendall(b'A\n')
                sock.recv_all(1)
                orig_sock.send(b'please corrupt the SSL connection')
            except ssl.SSLError:
                pass
            finally:
                orig_sock.close()
                sock.close()

        async def client(addr):
            reader, writer = await asyncio.open_connection(
                *addr,
                ssl=client_sslctx,
                server_hostname='')

            self.assertEqual(await reader.readline(), b'A\n')
            writer.write(b'B')
            with self.assertRaises(ssl.SSLError):
                await reader.readline()

            writer.close()
            return 'OK'

        with self.tcp_server(server,
                             max_clients=1,
                             backlog=1) as srv:

            res = self.loop.run_until_complete(client(srv.addr))

        self.assertEqual(res, 'OK')
     764  
     765  
     766  @unittest.skipIf(ssl is None, 'No ssl module')
     767  class ESC[4;38;5;81mSelectorStartTLSTests(ESC[4;38;5;149mBaseStartTLS, ESC[4;38;5;149munittestESC[4;38;5;149m.ESC[4;38;5;149mTestCase):
     768  
     769      def new_loop(self):
     770          return asyncio.SelectorEventLoop()
     771  
     772  
     773  @unittest.skipIf(ssl is None, 'No ssl module')
     774  @unittest.skipUnless(hasattr(asyncio, 'ProactorEventLoop'), 'Windows only')
     775  class ESC[4;38;5;81mProactorStartTLSTests(ESC[4;38;5;149mBaseStartTLS, ESC[4;38;5;149munittestESC[4;38;5;149m.ESC[4;38;5;149mTestCase):
     776  
     777      def new_loop(self):
     778          return asyncio.ProactorEventLoop()
     779  
     780  
     781  if __name__ == '__main__':
     782      unittest.main()