1 """Tests for asyncio/sslproto.py."""
2
3 import logging
4 import socket
5 import unittest
6 import weakref
7 from test import support
8 from test.support import socket_helper
9 from unittest import mock
10 try:
11 import ssl
12 except ImportError:
13 ssl = None
14
15 import asyncio
16 from asyncio import log
17 from asyncio import protocols
18 from asyncio import sslproto
19 from test.test_asyncio import utils as test_utils
20 from test.test_asyncio import functional as func_tests
21
22
def tearDownModule():
    """Reset the global asyncio event loop policy once the module is done.

    Passing ``None`` to ``asyncio.set_event_loop_policy()`` restores the
    default policy.
    """
    asyncio.set_event_loop_policy(None)
25
26
@unittest.skipIf(ssl is None, 'No ssl module')
class SslProtoHandshakeTests(test_utils.TestCase):
    """Unit tests for ``sslproto.SSLProtocol`` driven by mocked objects.

    Each test builds an ``SSLProtocol`` around a dummy SSL context and, where
    a connection is needed, swaps in ``mock.Mock`` objects for the transport
    and the SSL object, so no real socket or TLS handshake is involved.
    """

    def setUp(self):
        """Create a fresh event loop and register it with the test case."""
        super().setUp()
        self.loop = asyncio.new_event_loop()
        self.set_event_loop(self.loop)

    def ssl_protocol(self, *, waiter=None, proto=None):
        """Return a new ``SSLProtocol`` bound to ``self.loop``.

        ``waiter`` is passed through to ``SSLProtocol``; ``proto`` is the
        application protocol and defaults to a plain ``asyncio.Protocol``.
        The protocol's application transport is closed on test cleanup.
        """
        sslcontext = test_utils.dummy_ssl_context()
        if proto is None:  # app protocol
            proto = asyncio.Protocol()
        ssl_proto = sslproto.SSLProtocol(self.loop, proto, sslcontext, waiter,
                                         ssl_handshake_timeout=0.1)
        self.assertIs(ssl_proto._app_transport.get_protocol(), proto)
        self.addCleanup(ssl_proto._app_transport.close)
        return ssl_proto

    def connection_made(self, ssl_proto, *, do_handshake=None):
        """Connect ``ssl_proto`` to a mock transport and return the transport.

        A mock SSL object is installed as ``ssl_proto._sslobj`` before
        ``connection_made()`` is called.  When ``do_handshake`` is given it
        replaces the mock's ``do_handshake`` attribute.
        """
        transport = mock.Mock()
        sslobj = mock.Mock()
        # emulate reading decompressed data
        sslobj.read.side_effect = ssl.SSLWantReadError
        if do_handshake is not None:
            sslobj.do_handshake = do_handshake
        ssl_proto._sslobj = sslobj
        ssl_proto.connection_made(transport)
        return transport

    def test_handshake_timeout_zero(self):
        # A zero handshake timeout must be rejected with ValueError.
        sslcontext = test_utils.dummy_ssl_context()
        app_proto = mock.Mock()
        waiter = mock.Mock()
        with self.assertRaisesRegex(ValueError, 'a positive number'):
            sslproto.SSLProtocol(self.loop, app_proto, sslcontext, waiter,
                                 ssl_handshake_timeout=0)

    def test_handshake_timeout_negative(self):
        # A negative handshake timeout must be rejected with ValueError.
        sslcontext = test_utils.dummy_ssl_context()
        app_proto = mock.Mock()
        waiter = mock.Mock()
        with self.assertRaisesRegex(ValueError, 'a positive number'):
            sslproto.SSLProtocol(self.loop, app_proto, sslcontext, waiter,
                                 ssl_handshake_timeout=-10)

    def test_eof_received_waiter(self):
        # EOF while the handshake is still pending (do_handshake keeps
        # raising SSLWantReadError) must fail the waiter with
        # ConnectionResetError.
        waiter = self.loop.create_future()
        ssl_proto = self.ssl_protocol(waiter=waiter)
        self.connection_made(
            ssl_proto,
            do_handshake=mock.Mock(side_effect=ssl.SSLWantReadError)
        )
        ssl_proto.eof_received()
        test_utils.run_briefly(self.loop)
        self.assertIsInstance(waiter.exception(), ConnectionResetError)

    def test_fatal_error_no_name_error(self):
        # From issue #363.
        # _fatal_error() generates a NameError if sslproto.py
        # does not import base_events.
        waiter = self.loop.create_future()
        ssl_proto = self.ssl_protocol(waiter=waiter)
        # Temporarily turn off error logging so as not to spoil test output.
        # Save the logger's *own* level rather than its effective level: if
        # the logger was inheriting its level (NOTSET), restoring the
        # effective level would leave it pinned to an explicit level for all
        # subsequent tests.
        logger = log.logger
        saved_level = logger.level
        logger.setLevel(logging.FATAL)
        try:
            ssl_proto._fatal_error(None)
        finally:
            # Restore error logging.
            logger.setLevel(saved_level)

    def test_connection_lost(self):
        # From issue #472.
        # yield from waiter hang if lost_connection was called.
        waiter = self.loop.create_future()
        ssl_proto = self.ssl_protocol(waiter=waiter)
        self.connection_made(
            ssl_proto,
            do_handshake=mock.Mock(side_effect=ssl.SSLWantReadError)
        )
        ssl_proto.connection_lost(ConnectionAbortedError)
        test_utils.run_briefly(self.loop)
        self.assertIsInstance(waiter.exception(), ConnectionAbortedError)

    def test_close_during_handshake(self):
        # bpo-29743 Closing transport during handshake process leaks socket
        waiter = self.loop.create_future()
        ssl_proto = self.ssl_protocol(waiter=waiter)

        transport = self.connection_made(
            ssl_proto,
            do_handshake=mock.Mock(side_effect=ssl.SSLWantReadError)
        )
        test_utils.run_briefly(self.loop)

        # Closing the app transport while the handshake is pending must
        # abort the underlying (mock) transport.
        ssl_proto._app_transport.close()
        self.assertTrue(transport.abort.called)

    def test_get_extra_info_on_closed_connection(self):
        # _get_extra_info() must return None (or the supplied default)
        # before the connection is made and again after it is lost.
        waiter = self.loop.create_future()
        ssl_proto = self.ssl_protocol(waiter=waiter)
        self.assertIsNone(ssl_proto._get_extra_info('socket'))
        default = object()
        self.assertIs(ssl_proto._get_extra_info('socket', default), default)
        self.connection_made(ssl_proto)
        self.assertIsNotNone(ssl_proto._get_extra_info('socket'))
        ssl_proto.connection_lost(None)
        self.assertIsNone(ssl_proto._get_extra_info('socket'))

    def test_set_new_app_protocol(self):
        # set_protocol() on the app transport must update both the
        # transport's protocol and SSLProtocol._app_protocol.
        waiter = self.loop.create_future()
        ssl_proto = self.ssl_protocol(waiter=waiter)
        new_app_proto = asyncio.Protocol()
        ssl_proto._app_transport.set_protocol(new_app_proto)
        self.assertIs(ssl_proto._app_transport.get_protocol(), new_app_proto)
        self.assertIs(ssl_proto._app_protocol, new_app_proto)

    def test_data_received_after_closing(self):
        ssl_proto = self.ssl_protocol()
        self.connection_made(ssl_proto)
        transp = ssl_proto._app_transport

        transp.close()

        # should not raise
        self.assertIsNone(ssl_proto.buffer_updated(5))

    def test_write_after_closing(self):
        ssl_proto = self.ssl_protocol()
        self.connection_made(ssl_proto)
        transp = ssl_proto._app_transport
        transp.close()

        # should not raise
        self.assertIsNone(transp.write(b'data'))
162
163
164 ##############################################################################
165 # Start TLS Tests
166 ##############################################################################
167
168
class BaseStartTLS(func_tests.FunctionalTestCaseMixin):
    """Functional start-TLS / SSL connection tests shared by loop flavours.

    Concrete subclasses must override ``new_loop()`` to return the event loop
    implementation under test.  The tests use ``self.loop``,
    ``self.tcp_server()`` and ``self.tcp_client()``, which are not defined in
    this class (presumably provided by ``FunctionalTestCaseMixin`` -- confirm
    against ``test.test_asyncio.functional``).
    """

    # Size in bytes of the HELLO_MSG payload exchanged by the tests (100 KiB).
    PAYLOAD_SIZE = 1024 * 100
    # Timeout used for the helper sockets/servers and some wait_for() calls.
    TIMEOUT = support.LONG_TIMEOUT

    def new_loop(self):
        """Return the event loop to test; must be overridden by subclasses."""
        raise NotImplementedError

    def test_buf_feed_data(self):
        # Feed bytes through protocols._feed_data_to_buffered_proto() into
        # buffers that are smaller than, equal to and larger than the
        # payload, using both a bytearray and a memoryview as the buffer.

        class Proto(asyncio.BufferedProtocol):

            def __init__(self, bufsize, usemv):
                self.buf = bytearray(bufsize)
                self.mv = memoryview(self.buf)
                self.data = b''
                self.usemv = usemv

            def get_buffer(self, sizehint):
                if self.usemv:
                    return self.mv
                else:
                    return self.buf

            def buffer_updated(self, nsize):
                if self.usemv:
                    self.data += self.mv[:nsize]
                else:
                    self.data += self.buf[:nsize]

        for usemv in [False, True]:
            proto = Proto(1, usemv)
            protocols._feed_data_to_buffered_proto(proto, b'12345')
            self.assertEqual(proto.data, b'12345')

            proto = Proto(2, usemv)
            protocols._feed_data_to_buffered_proto(proto, b'12345')
            self.assertEqual(proto.data, b'12345')

            proto = Proto(2, usemv)
            protocols._feed_data_to_buffered_proto(proto, b'1234')
            self.assertEqual(proto.data, b'1234')

            proto = Proto(4, usemv)
            protocols._feed_data_to_buffered_proto(proto, b'1234')
            self.assertEqual(proto.data, b'1234')

            proto = Proto(100, usemv)
            protocols._feed_data_to_buffered_proto(proto, b'12345')
            self.assertEqual(proto.data, b'12345')

            # A zero-length buffer cannot receive anything and must be
            # reported as an error.
            proto = Proto(0, usemv)
            with self.assertRaisesRegex(RuntimeError, 'empty buffer'):
                protocols._feed_data_to_buffered_proto(proto, b'12345')

    def test_start_tls_client_reg_proto_1(self):
        # Client-side upgrade with a regular asyncio.Protocol: send
        # HELLO_MSG in the clear, call loop.start_tls(), exchange data over
        # TLS, then verify the client SSL context can be garbage collected.
        HELLO_MSG = b'1' * self.PAYLOAD_SIZE

        server_context = test_utils.simple_server_sslcontext()
        client_context = test_utils.simple_client_sslcontext()

        def serve(sock):
            sock.settimeout(self.TIMEOUT)

            data = sock.recv_all(len(HELLO_MSG))
            self.assertEqual(len(data), len(HELLO_MSG))

            sock.start_tls(server_context, server_side=True)

            sock.sendall(b'O')
            data = sock.recv_all(len(HELLO_MSG))
            self.assertEqual(len(data), len(HELLO_MSG))

            sock.shutdown(socket.SHUT_RDWR)
            sock.close()

        class ClientProto(asyncio.Protocol):
            def __init__(self, on_data, on_eof):
                self.on_data = on_data
                self.on_eof = on_eof
                self.con_made_cnt = 0

            # The instance parameter is named `proto` so that `self` still
            # refers to the enclosing TestCase inside this method.
            def connection_made(proto, tr):
                proto.con_made_cnt += 1
                # Ensure connection_made gets called only once.
                self.assertEqual(proto.con_made_cnt, 1)

            def data_received(self, data):
                self.on_data.set_result(data)

            def eof_received(self):
                self.on_eof.set_result(True)

        async def client(addr):
            await asyncio.sleep(0.5)

            on_data = self.loop.create_future()
            on_eof = self.loop.create_future()

            tr, proto = await self.loop.create_connection(
                lambda: ClientProto(on_data, on_eof), *addr)

            tr.write(HELLO_MSG)
            new_tr = await self.loop.start_tls(tr, proto, client_context)

            self.assertEqual(await on_data, b'O')
            new_tr.write(HELLO_MSG)
            await on_eof

            new_tr.close()

        with self.tcp_server(serve, timeout=self.TIMEOUT) as srv:
            self.loop.run_until_complete(
                asyncio.wait_for(client(srv.addr),
                                 timeout=support.SHORT_TIMEOUT))

        # No garbage is left if SSL is closed uncleanly
        # (rebinding the name to a weakref drops this frame's strong
        # reference to the context).
        client_context = weakref.ref(client_context)
        support.gc_collect()
        self.assertIsNone(client_context())

    def test_create_connection_memory_leak(self):
        # Connect with ssl=client_context directly via create_connection()
        # and check the context is collectable afterwards even though the
        # protocol keeps a reference to its transport.
        HELLO_MSG = b'1' * self.PAYLOAD_SIZE

        server_context = test_utils.simple_server_sslcontext()
        client_context = test_utils.simple_client_sslcontext()

        def serve(sock):
            sock.settimeout(self.TIMEOUT)

            sock.start_tls(server_context, server_side=True)

            sock.sendall(b'O')
            data = sock.recv_all(len(HELLO_MSG))
            self.assertEqual(len(data), len(HELLO_MSG))

            sock.shutdown(socket.SHUT_RDWR)
            sock.close()

        class ClientProto(asyncio.Protocol):
            def __init__(self, on_data, on_eof):
                self.on_data = on_data
                self.on_eof = on_eof
                self.con_made_cnt = 0

            # `proto` is the instance; `self` is the enclosing TestCase.
            def connection_made(proto, tr):
                # XXX: We assume user stores the transport in protocol
                proto.tr = tr
                proto.con_made_cnt += 1
                # Ensure connection_made gets called only once.
                self.assertEqual(proto.con_made_cnt, 1)

            def data_received(self, data):
                self.on_data.set_result(data)

            def eof_received(self):
                self.on_eof.set_result(True)

        async def client(addr):
            await asyncio.sleep(0.5)

            on_data = self.loop.create_future()
            on_eof = self.loop.create_future()

            tr, proto = await self.loop.create_connection(
                lambda: ClientProto(on_data, on_eof), *addr,
                ssl=client_context)

            self.assertEqual(await on_data, b'O')
            tr.write(HELLO_MSG)
            await on_eof

            tr.close()

        with self.tcp_server(serve, timeout=self.TIMEOUT) as srv:
            self.loop.run_until_complete(
                asyncio.wait_for(client(srv.addr),
                                 timeout=support.SHORT_TIMEOUT))

        # No garbage is left for SSL client from loop.create_connection, even
        # if user stores the SSLTransport in corresponding protocol instance
        client_context = weakref.ref(client_context)
        support.gc_collect()
        self.assertIsNone(client_context())

    @socket_helper.skip_if_tcp_blackhole
    def test_start_tls_client_buf_proto_1(self):
        # Client-side upgrade starting with a BufferedProtocol, then
        # switching to a regular Protocol via set_protocol() on the TLS
        # transport.  connection_made() must fire exactly once overall.
        HELLO_MSG = b'1' * self.PAYLOAD_SIZE

        server_context = test_utils.simple_server_sslcontext()
        client_context = test_utils.simple_client_sslcontext()
        # Incremented by both protocol classes' connection_made().
        client_con_made_calls = 0

        def serve(sock):
            sock.settimeout(self.TIMEOUT)

            data = sock.recv_all(len(HELLO_MSG))
            self.assertEqual(len(data), len(HELLO_MSG))

            sock.start_tls(server_context, server_side=True)

            sock.sendall(b'O')
            data = sock.recv_all(len(HELLO_MSG))
            self.assertEqual(len(data), len(HELLO_MSG))

            sock.sendall(b'2')
            data = sock.recv_all(len(HELLO_MSG))
            self.assertEqual(len(data), len(HELLO_MSG))

            sock.shutdown(socket.SHUT_RDWR)
            sock.close()

        class ClientProtoFirst(asyncio.BufferedProtocol):
            def __init__(self, on_data):
                self.on_data = on_data
                self.buf = bytearray(1)

            def connection_made(self, tr):
                nonlocal client_con_made_calls
                client_con_made_calls += 1

            def get_buffer(self, sizehint):
                return self.buf

            # `slf` is the instance; `self` is the enclosing TestCase.
            def buffer_updated(slf, nsize):
                self.assertEqual(nsize, 1)
                slf.on_data.set_result(bytes(slf.buf[:nsize]))

        class ClientProtoSecond(asyncio.Protocol):
            def __init__(self, on_data, on_eof):
                self.on_data = on_data
                self.on_eof = on_eof
                self.con_made_cnt = 0

            def connection_made(self, tr):
                nonlocal client_con_made_calls
                client_con_made_calls += 1

            def data_received(self, data):
                self.on_data.set_result(data)

            def eof_received(self):
                self.on_eof.set_result(True)

        async def client(addr):
            await asyncio.sleep(0.5)

            on_data1 = self.loop.create_future()
            on_data2 = self.loop.create_future()
            on_eof = self.loop.create_future()

            tr, proto = await self.loop.create_connection(
                lambda: ClientProtoFirst(on_data1), *addr)

            tr.write(HELLO_MSG)
            new_tr = await self.loop.start_tls(tr, proto, client_context)

            self.assertEqual(await on_data1, b'O')
            new_tr.write(HELLO_MSG)

            new_tr.set_protocol(ClientProtoSecond(on_data2, on_eof))
            self.assertEqual(await on_data2, b'2')
            new_tr.write(HELLO_MSG)
            await on_eof

            new_tr.close()

            # connection_made() should be called only once -- when
            # we establish connection for the first time. Start TLS
            # doesn't call connection_made() on application protocols.
            self.assertEqual(client_con_made_calls, 1)

        with self.tcp_server(serve, timeout=self.TIMEOUT) as srv:
            self.loop.run_until_complete(
                asyncio.wait_for(client(srv.addr),
                                 timeout=self.TIMEOUT))

    def test_start_tls_slow_client_cancel(self):
        # The server never performs a TLS handshake, so the client's
        # start_tls() wrapped in wait_for(..., 0.5) must time out.
        HELLO_MSG = b'1' * self.PAYLOAD_SIZE

        client_context = test_utils.simple_client_sslcontext()
        server_waits_on_handshake = self.loop.create_future()

        def serve(sock):
            sock.settimeout(self.TIMEOUT)

            data = sock.recv_all(len(HELLO_MSG))
            self.assertEqual(len(data), len(HELLO_MSG))

            try:
                # serve() is not running on the loop's thread, hence the
                # thread-safe scheduling of the future's result.
                self.loop.call_soon_threadsafe(
                    server_waits_on_handshake.set_result, None)
                data = sock.recv_all(1024 * 1024)
            except ConnectionAbortedError:
                pass
            finally:
                sock.close()

        class ClientProto(asyncio.Protocol):
            def __init__(self, on_data, on_eof):
                self.on_data = on_data
                self.on_eof = on_eof
                self.con_made_cnt = 0

            # `proto` is the instance; `self` is the enclosing TestCase.
            def connection_made(proto, tr):
                proto.con_made_cnt += 1
                # Ensure connection_made gets called only once.
                self.assertEqual(proto.con_made_cnt, 1)

            def data_received(self, data):
                self.on_data.set_result(data)

            def eof_received(self):
                self.on_eof.set_result(True)

        async def client(addr):
            await asyncio.sleep(0.5)

            on_data = self.loop.create_future()
            on_eof = self.loop.create_future()

            tr, proto = await self.loop.create_connection(
                lambda: ClientProto(on_data, on_eof), *addr)

            tr.write(HELLO_MSG)

            await server_waits_on_handshake

            with self.assertRaises(asyncio.TimeoutError):
                await asyncio.wait_for(
                    self.loop.start_tls(tr, proto, client_context),
                    0.5)

        with self.tcp_server(serve, timeout=self.TIMEOUT) as srv:
            self.loop.run_until_complete(
                asyncio.wait_for(client(srv.addr),
                                 timeout=support.SHORT_TIMEOUT))

    @socket_helper.skip_if_tcp_blackhole
    def test_start_tls_server_1(self):
        # Server-side upgrade: the asyncio server sends HELLO_MSG in the
        # clear, upgrades the accepted connection with
        # start_tls(server_side=True), receives HELLO_MSG over TLS and
        # replies with ANSWER.
        HELLO_MSG = b'1' * self.PAYLOAD_SIZE
        ANSWER = b'answer'

        server_context = test_utils.simple_server_sslcontext()
        client_context = test_utils.simple_client_sslcontext()
        # Filled in by client() with what it received over TLS.
        answer = None

        def client(sock, addr):
            nonlocal answer
            sock.settimeout(self.TIMEOUT)

            sock.connect(addr)
            data = sock.recv_all(len(HELLO_MSG))
            self.assertEqual(len(data), len(HELLO_MSG))

            sock.start_tls(client_context)
            sock.sendall(HELLO_MSG)
            answer = sock.recv_all(len(ANSWER))
            sock.close()

        class ServerProto(asyncio.Protocol):
            def __init__(self, on_con, on_con_lost, on_got_hello):
                self.on_con = on_con
                self.on_con_lost = on_con_lost
                self.on_got_hello = on_got_hello
                self.data = b''
                self.transport = None

            def connection_made(self, tr):
                self.transport = tr
                self.on_con.set_result(tr)

            def replace_transport(self, tr):
                self.transport = tr

            def data_received(self, data):
                self.data += data
                if len(self.data) >= len(HELLO_MSG):
                    self.on_got_hello.set_result(None)

            def connection_lost(self, exc):
                self.transport = None
                if exc is None:
                    self.on_con_lost.set_result(None)
                else:
                    self.on_con_lost.set_exception(exc)

        async def main(proto, on_con, on_con_lost, on_got_hello):
            tr = await on_con
            tr.write(HELLO_MSG)

            # Nothing may have been delivered to the protocol before TLS
            # is started.
            self.assertEqual(proto.data, b'')

            new_tr = await self.loop.start_tls(
                tr, proto, server_context,
                server_side=True,
                ssl_handshake_timeout=self.TIMEOUT)
            proto.replace_transport(new_tr)

            await on_got_hello
            new_tr.write(ANSWER)

            await on_con_lost
            self.assertEqual(proto.data, HELLO_MSG)
            new_tr.close()

        async def run_main():
            on_con = self.loop.create_future()
            on_con_lost = self.loop.create_future()
            on_got_hello = self.loop.create_future()
            proto = ServerProto(on_con, on_con_lost, on_got_hello)

            # Port 0 lets the OS pick a free port; the real address is
            # read back from the listening socket.
            server = await self.loop.create_server(
                lambda: proto, '127.0.0.1', 0)
            addr = server.sockets[0].getsockname()

            with self.tcp_client(lambda sock: client(sock, addr),
                                 timeout=self.TIMEOUT):
                await asyncio.wait_for(
                    main(proto, on_con, on_con_lost, on_got_hello),
                    timeout=self.TIMEOUT)

            server.close()
            await server.wait_closed()
            self.assertEqual(answer, ANSWER)

        self.loop.run_until_complete(run_main())

    def test_start_tls_wrong_args(self):
        # start_tls() must raise TypeError both for a non-SSLContext
        # sslcontext argument and for an unsupported (None) transport.
        async def main():
            with self.assertRaisesRegex(TypeError, 'SSLContext, got'):
                await self.loop.start_tls(None, None, None)

            sslctx = test_utils.simple_server_sslcontext()
            with self.assertRaisesRegex(TypeError, 'is not supported'):
                await self.loop.start_tls(None, None, sslctx)

        self.loop.run_until_complete(main())

    def test_handshake_timeout(self):
        # bpo-29970: Check that a connection is aborted if handshake is not
        # completed in timeout period, instead of remaining open indefinitely
        client_sslctx = test_utils.simple_client_sslcontext()

        # Collect every context passed to the loop's exception handler.
        messages = []
        self.loop.set_exception_handler(lambda loop, ctx: messages.append(ctx))

        server_side_aborted = False

        def server(sock):
            # The server only reads and never starts a TLS handshake.
            nonlocal server_side_aborted
            try:
                sock.recv_all(1024 * 1024)
            except ConnectionAbortedError:
                server_side_aborted = True
            finally:
                sock.close()

        async def client(addr):
            # The outer 0.5 s wait_for() is what is expected to expire.
            await asyncio.wait_for(
                self.loop.create_connection(
                    asyncio.Protocol,
                    *addr,
                    ssl=client_sslctx,
                    server_hostname='',
                    ssl_handshake_timeout=support.SHORT_TIMEOUT),
                0.5)

        with self.tcp_server(server,
                             max_clients=1,
                             backlog=1) as srv:

            with self.assertRaises(asyncio.TimeoutError):
                self.loop.run_until_complete(client(srv.addr))

        self.assertTrue(server_side_aborted)

        # Python issue #23197: cancelling a handshake must not raise an
        # exception or log an error, even if the handshake failed
        self.assertEqual(messages, [])

        # The handshake timeout (support.SHORT_TIMEOUT) should be cancelled
        # to free related objects without really waiting for it to expire
        client_sslctx = weakref.ref(client_sslctx)
        support.gc_collect()
        self.assertIsNone(client_sslctx())

    def test_create_connection_ssl_slow_handshake(self):
        # With a server that never handshakes, open_connection() with
        # ssl_handshake_timeout=1.0 must fail with ConnectionAbortedError
        # and nothing must reach the loop's exception handler.
        client_sslctx = test_utils.simple_client_sslcontext()

        messages = []
        self.loop.set_exception_handler(lambda loop, ctx: messages.append(ctx))

        def server(sock):
            try:
                sock.recv_all(1024 * 1024)
            except ConnectionAbortedError:
                pass
            finally:
                sock.close()

        async def client(addr):
            reader, writer = await asyncio.open_connection(
                *addr,
                ssl=client_sslctx,
                server_hostname='',
                ssl_handshake_timeout=1.0)

        with self.tcp_server(server,
                             max_clients=1,
                             backlog=1) as srv:

            with self.assertRaisesRegex(
                    ConnectionAbortedError,
                    r'SSL handshake.*is taking longer'):

                self.loop.run_until_complete(client(srv.addr))

        self.assertEqual(messages, [])

    def test_create_connection_ssl_failed_certificate(self):
        # A client context created with disable_verify=False must make
        # open_connection() fail with ssl.SSLCertVerificationError.
        # The installed exception handler discards loop error reports.
        self.loop.set_exception_handler(lambda loop, ctx: None)

        sslctx = test_utils.simple_server_sslcontext()
        client_sslctx = test_utils.simple_client_sslcontext(
            disable_verify=False)

        def server(sock):
            # Errors on the server side of the failed handshake are
            # expected and deliberately ignored.
            try:
                sock.start_tls(
                    sslctx,
                    server_side=True)
            except ssl.SSLError:
                pass
            except OSError:
                pass
            finally:
                sock.close()

        async def client(addr):
            reader, writer = await asyncio.open_connection(
                *addr,
                ssl=client_sslctx,
                server_hostname='',
                ssl_handshake_timeout=support.LOOPBACK_TIMEOUT)

        with self.tcp_server(server,
                             max_clients=1,
                             backlog=1) as srv:

            with self.assertRaises(ssl.SSLCertVerificationError):
                self.loop.run_until_complete(client(srv.addr))

    def test_start_tls_client_corrupted_ssl(self):
        # After a successful handshake the server writes plain bytes
        # through a duplicate of the socket, corrupting the TLS stream;
        # the client's next read must raise ssl.SSLError.
        self.loop.set_exception_handler(lambda loop, ctx: None)

        sslctx = test_utils.simple_server_sslcontext()
        client_sslctx = test_utils.simple_client_sslcontext()

        def server(sock):
            # Duplicate taken before start_tls() is called on `sock`; it is
            # used below to send bytes outside the TLS session.
            orig_sock = sock.dup()
            try:
                sock.start_tls(
                    sslctx,
                    server_side=True)
                sock.sendall(b'A\n')
                sock.recv_all(1)
                orig_sock.send(b'please corrupt the SSL connection')
            except ssl.SSLError:
                pass
            finally:
                orig_sock.close()
                sock.close()

        async def client(addr):
            reader, writer = await asyncio.open_connection(
                *addr,
                ssl=client_sslctx,
                server_hostname='')

            self.assertEqual(await reader.readline(), b'A\n')
            writer.write(b'B')
            with self.assertRaises(ssl.SSLError):
                await reader.readline()

            writer.close()
            return 'OK'

        with self.tcp_server(server,
                             max_clients=1,
                             backlog=1) as srv:

            res = self.loop.run_until_complete(client(srv.addr))

        self.assertEqual(res, 'OK')
764
765
@unittest.skipIf(ssl is None, 'No ssl module')
class SelectorStartTLSTests(BaseStartTLS, unittest.TestCase):
    """Run the ``BaseStartTLS`` tests with a selector-based event loop."""

    def new_loop(self):
        """Return a new ``asyncio.SelectorEventLoop``."""
        loop = asyncio.SelectorEventLoop()
        return loop
771
772
@unittest.skipIf(ssl is None, 'No ssl module')
@unittest.skipUnless(hasattr(asyncio, 'ProactorEventLoop'), 'Windows only')
class ProactorStartTLSTests(BaseStartTLS, unittest.TestCase):
    """Run the ``BaseStartTLS`` tests with a proactor-based event loop.

    Skipped when ``asyncio`` does not expose ``ProactorEventLoop``.
    """

    def new_loop(self):
        """Return a new ``asyncio.ProactorEventLoop``."""
        loop = asyncio.ProactorEventLoop()
        return loop
779
780
# Run the tests when this module is executed directly as a script.
if __name__ == '__main__':
    unittest.main()