1 import asyncio
2 import asyncio.sslproto
3 import contextlib
4 import gc
5 import logging
6 import select
7 import socket
8 import sys
9 import tempfile
10 import threading
11 import time
12 import weakref
13 import unittest
14
15 try:
16 import ssl
17 except ImportError:
18 ssl = None
19
20 from test import support
21 from test.test_asyncio import utils as test_utils
22
23
# True when the tests run on macOS.  Payload sizes built from
# BUF_MULTIPLIER are scaled down there (presumably because of platform
# socket-buffer limits -- NOTE(review): confirm the original rationale).
MACOS = sys.platform == 'darwin'
BUF_MULTIPLIER = 64 if MACOS else 1024
26
27
def tearDownModule():
    """Reset asyncio to its default event loop policy after this module."""
    asyncio.set_event_loop_policy(policy=None)
30
31
class MyBaseProto(asyncio.Protocol):
    """Protocol that tracks its lifecycle state and the bytes it receives.

    The state moves INITIAL -> CONNECTED -> (EOF) -> CLOSED, and every
    callback asserts that it runs in a valid state.  If *loop* is given,
    ``connected`` and ``done`` become futures.  connection_made() resolves
    ``connected`` and connection_lost() resolves ``done``.  Without a
    loop, both stay None.
    """

    # Class-level defaults, replaced per instance only when a loop is given.
    connected = None
    done = None

    def __init__(self, loop=None):
        self.transport = None
        self.state = 'INITIAL'
        self.nbytes = 0  # total bytes passed to data_received()
        if loop is not None:
            self.connected = asyncio.Future(loop=loop)
            self.done = asyncio.Future(loop=loop)

    def connection_made(self, transport):
        self.transport = transport
        assert self.state == 'INITIAL', self.state
        self.state = 'CONNECTED'
        # The attribute is either None or a Future; test for None explicitly
        # instead of relying on the Future's truthiness.
        if self.connected is not None:
            self.connected.set_result(None)

    def data_received(self, data):
        assert self.state == 'CONNECTED', self.state
        self.nbytes += len(data)

    def eof_received(self):
        assert self.state == 'CONNECTED', self.state
        self.state = 'EOF'

    def connection_lost(self, exc):
        assert self.state in ('CONNECTED', 'EOF'), self.state
        self.state = 'CLOSED'
        if self.done is not None:
            self.done.set_result(None)
64
65
class MessageOutFilter(logging.Filter):
    """Logging filter that drops every record whose message contains *msg*."""

    def __init__(self, msg):
        # Initialise the base Filter (sets name/nlen) so the instance is a
        # fully formed logging.Filter, then remember the text to suppress.
        super().__init__()
        self.msg = msg

    def filter(self, record):
        # record.msg is whatever was passed to the logging call and need not
        # be a str (e.g. logger.error(exc)); str() keeps ``in`` from raising
        # TypeError inside the logging machinery.
        return self.msg not in str(record.msg)
74
75
76 @unittest.skipIf(ssl is None, 'No ssl module')
77 class ESC[4;38;5;81mTestSSL(ESC[4;38;5;149mtest_utilsESC[4;38;5;149m.ESC[4;38;5;149mTestCase):
78
79 PAYLOAD_SIZE = 1024 * 100
80 TIMEOUT = support.LONG_TIMEOUT
81
def setUp(self):
    """Give each test a brand-new event loop that is closed on cleanup."""
    super().setUp()
    fresh_loop = asyncio.new_event_loop()
    self.loop = fresh_loop
    self.set_event_loop(fresh_loop)
    self.addCleanup(fresh_loop.close)
87
def tearDown(self):
    """Drain pending loop callbacks, run cleanups now, then force a GC pass."""
    # Transports may still have close callbacks queued; give them one
    # iteration, but only while the loop is still usable.
    loop_is_closed = self.loop.is_closed()
    if not loop_is_closed:
        test_utils.run_briefly(self.loop)

    self.doCleanups()
    support.gc_collect()
    super().tearDown()
96
def tcp_server(self, server_prog, *,
               family=socket.AF_INET,
               addr=None,
               timeout=support.SHORT_TIMEOUT,
               backlog=1,
               max_clients=10):
    """Create a listening socket and wrap it in a TestThreadedServer.

    *server_prog* is handed to TestThreadedServer together with the bound
    socket, *timeout* and *max_clients*.  When *addr* is None, an
    AF_UNIX server gets a fresh temporary path.  Other families bind to
    127.0.0.1 on an ephemeral port.

    Raises RuntimeError if *timeout* is None or not positive (only
    blocking sockets are supported).  Errors from configuring, binding or
    listening propagate after the socket has been closed.
    """
    # Validate arguments before allocating the socket, so that a bad
    # timeout cannot leak an open file descriptor.
    if timeout is None:
        raise RuntimeError('timeout is required')
    if timeout <= 0:
        raise RuntimeError('only blocking sockets are supported')

    if addr is None:
        if family == getattr(socket, "AF_UNIX", None):
            # Reserve a unique path; the temp file is removed on exit so
            # that bind() can create the socket file there.
            with tempfile.NamedTemporaryFile() as tmp:
                addr = tmp.name
        else:
            addr = ('127.0.0.1', 0)

    sock = socket.socket(family, socket.SOCK_STREAM)
    try:
        sock.settimeout(timeout)
        sock.bind(addr)
        sock.listen(backlog)
    except BaseException:
        # Close on any failure, then re-raise with the original traceback.
        sock.close()
        raise

    return TestThreadedServer(
        self, sock, server_prog, timeout, max_clients)
128
def tcp_client(self, client_prog,
               family=socket.AF_INET,
               timeout=support.SHORT_TIMEOUT):
    """Create a blocking client socket wrapped in a TestThreadedClient.

    The returned object runs *client_prog* (see TestThreadedClient).

    Raises RuntimeError if *timeout* is None or not positive (only
    blocking sockets are supported).
    """
    # Validate first so that an invalid timeout does not leak a socket.
    if timeout is None:
        raise RuntimeError('timeout is required')
    if timeout <= 0:
        raise RuntimeError('only blocking sockets are supported')

    sock = socket.socket(family, socket.SOCK_STREAM)
    try:
        sock.settimeout(timeout)
    except BaseException:
        sock.close()
        raise

    return TestThreadedClient(
        self, sock, client_prog, timeout)
143
def unix_server(self, *args, **kwargs):
    """Same as tcp_server(), but listening on an AF_UNIX socket."""
    unix_family = socket.AF_UNIX
    return self.tcp_server(*args, family=unix_family, **kwargs)
146
def unix_client(self, *args, **kwargs):
    """Same as tcp_client(), but using an AF_UNIX socket."""
    unix_family = socket.AF_UNIX
    return self.tcp_client(*args, family=unix_family, **kwargs)
149
def _create_server_ssl_context(self, certfile, keyfile=None):
    """Return a server-side TLS context loaded with *certfile*/*keyfile*."""
    server_ctx = ssl.SSLContext(ssl.PROTOCOL_TLS_SERVER)
    # Explicitly refuse SSLv2.
    server_ctx.options = server_ctx.options | ssl.OP_NO_SSLv2
    server_ctx.load_cert_chain(certfile, keyfile)
    return server_ctx
155
def _create_client_ssl_context(self, *, disable_verify=True):
    """Return a default client TLS context with hostname checking off.

    Certificate verification is also turned off unless *disable_verify*
    is false.
    """
    client_ctx = ssl.create_default_context()
    # check_hostname must be cleared before verify_mode may be CERT_NONE.
    client_ctx.check_hostname = False
    if not disable_verify:
        return client_ctx
    client_ctx.verify_mode = ssl.CERT_NONE
    return client_ctx
162
@contextlib.contextmanager
def _silence_eof_received_warning(self):
    """While active, drop 'asyncio' log records mentioning the SSL eof warning."""
    # TODO This warning has to be fixed in asyncio.
    asyncio_logger = logging.getLogger('asyncio')
    eof_filter = MessageOutFilter('has no effect when using ssl')
    asyncio_logger.addFilter(eof_filter)
    try:
        yield
    finally:
        asyncio_logger.removeFilter(eof_filter)
173
def _abort_socket_test(self, ex):
    """Stop the event loop and fail the test with *ex*.

    The failure is reported even if stopping the loop raises.
    """
    try:
        running_loop = self.loop
        running_loop.stop()
    finally:
        self.fail(ex)
179
def new_loop(self):
    """Create and return a new asyncio event loop."""
    created = asyncio.new_event_loop()
    return created
182
def new_policy(self):
    """Return a new instance of asyncio's default event loop policy."""
    policy_factory = asyncio.DefaultEventLoopPolicy
    return policy_factory()
185
async def wait_closed(self, obj):
    """Await StreamWriter.wait_closed() for *obj*, tolerating connection errors.

    Objects that are not StreamWriters are ignored.
    """
    if isinstance(obj, asyncio.StreamWriter):
        # BrokenPipeError is itself a ConnectionError subclass.
        with contextlib.suppress(BrokenPipeError, ConnectionError):
            await obj.wait_closed()
193
def test_create_server_ssl_1(self):
    """Serve TOTAL_CNT concurrent TLS clients through asyncio.start_server().

    Each threaded client sends A_DATA and expects b'OK', then sends B_DATA
    and expects b'SPAM'.  The handler writes b'SPAM' with writelines()
    over bytes, bytearray and memoryview pieces.
    """
    CNT = 0  # number of clients that were successful
    TOTAL_CNT = 25  # total number of clients that test will create
    TIMEOUT = support.LONG_TIMEOUT  # timeout for this test

    A_DATA = b'A' * 1024 * BUF_MULTIPLIER
    B_DATA = b'B' * 1024 * BUF_MULTIPLIER

    sslctx = self._create_server_ssl_context(
        test_utils.ONLYCERT, test_utils.ONLYKEY
    )
    client_sslctx = self._create_client_ssl_context()

    clients = []

    async def handle_client(reader, writer):
        nonlocal CNT

        data = await reader.readexactly(len(A_DATA))
        self.assertEqual(data, A_DATA)
        writer.write(b'OK')

        data = await reader.readexactly(len(B_DATA))
        self.assertEqual(data, B_DATA)
        writer.writelines([b'SP', bytearray(b'A'), memoryview(b'M')])

        await writer.drain()
        writer.close()

        CNT += 1

    async def test_client(addr):
        fut = asyncio.Future()

        def prog(sock):
            # Runs in the client thread; report the outcome back to the
            # loop thread-safely through *fut*.
            try:
                sock.starttls(client_sslctx)
                sock.connect(addr)
                sock.send(A_DATA)

                data = sock.recv_all(2)
                self.assertEqual(data, b'OK')

                sock.send(B_DATA)
                data = sock.recv_all(4)
                self.assertEqual(data, b'SPAM')

                sock.close()

            except Exception as ex:
                self.loop.call_soon_threadsafe(fut.set_exception, ex)
            else:
                self.loop.call_soon_threadsafe(fut.set_result, None)

        client = self.tcp_client(prog)
        client.start()
        clients.append(client)

        await fut

    async def start_server():
        srv = await asyncio.start_server(
            handle_client,
            '127.0.0.1', 0,
            family=socket.AF_INET,
            ssl=sslctx,
            ssl_handshake_timeout=support.SHORT_TIMEOUT)

        try:
            srv_socks = srv.sockets
            self.assertTrue(srv_socks)

            addr = srv_socks[0].getsockname()

            tasks = []
            for _ in range(TOTAL_CNT):
                tasks.append(test_client(addr))

            await asyncio.wait_for(asyncio.gather(*tasks), TIMEOUT)

        finally:
            self.loop.call_soon(srv.close)
            await srv.wait_closed()

    try:
        with self._silence_eof_received_warning():
            self.loop.run_until_complete(start_server())

        self.assertEqual(CNT, TOTAL_CNT)
    finally:
        # Stop the client threads on failure too, so a failing run does
        # not leave them behind for later tests.
        for client in clients:
            client.stop()
288
def test_create_connection_ssl_1(self):
    """Open TOTAL_CNT concurrent TLS client connections to a threaded server.

    The exchange (A_DATA -> b'OK', B_DATA -> b'SPAM') runs twice: once
    letting open_connection() connect by host/port, and once handing it
    an already-connected socket via ``sock=``.
    """
    # Reset to the default exception handler.
    self.loop.set_exception_handler(None)

    CNT = 0
    TOTAL_CNT = 25

    A_DATA = b'A' * 1024 * BUF_MULTIPLIER
    B_DATA = b'B' * 1024 * BUF_MULTIPLIER

    sslctx = self._create_server_ssl_context(
        test_utils.ONLYCERT,
        test_utils.ONLYKEY
    )
    client_sslctx = self._create_client_ssl_context()

    def server(sock):
        sock.starttls(
            sslctx,
            server_side=True)

        data = sock.recv_all(len(A_DATA))
        self.assertEqual(data, A_DATA)
        sock.send(b'OK')

        data = sock.recv_all(len(B_DATA))
        self.assertEqual(data, B_DATA)
        sock.send(b'SPAM')

        sock.close()

    async def client(addr):
        reader, writer = await asyncio.open_connection(
            *addr,
            ssl=client_sslctx,
            server_hostname='',
            ssl_handshake_timeout=support.SHORT_TIMEOUT)

        writer.write(A_DATA)
        self.assertEqual(await reader.readexactly(2), b'OK')

        writer.write(B_DATA)
        self.assertEqual(await reader.readexactly(4), b'SPAM')

        nonlocal CNT
        CNT += 1

        writer.close()
        await self.wait_closed(writer)

    async def client_sock(addr):
        sock = socket.socket()
        # Always release the socket, even if connecting or the TLS
        # exchange fails part-way.
        try:
            sock.connect(addr)
            reader, writer = await asyncio.open_connection(
                sock=sock,
                ssl=client_sslctx,
                server_hostname='')

            writer.write(A_DATA)
            self.assertEqual(await reader.readexactly(2), b'OK')

            writer.write(B_DATA)
            self.assertEqual(await reader.readexactly(4), b'SPAM')

            nonlocal CNT
            CNT += 1

            writer.close()
            await self.wait_closed(writer)
        finally:
            sock.close()

    def run(coro):
        nonlocal CNT
        CNT = 0

        async def _gather(*tasks):
            # trampoline: gather() must be called with the loop running
            return await asyncio.gather(*tasks)

        with self.tcp_server(server,
                             max_clients=TOTAL_CNT,
                             backlog=TOTAL_CNT) as srv:
            tasks = []
            for _ in range(TOTAL_CNT):
                tasks.append(coro(srv.addr))

            self.loop.run_until_complete(_gather(*tasks))

        self.assertEqual(CNT, TOTAL_CNT)

    with self._silence_eof_received_warning():
        run(client)

    with self._silence_eof_received_warning():
        run(client_sock)
386
def test_create_connection_ssl_slow_handshake(self):
    """A peer that never answers the TLS handshake must cause
    ConnectionAbortedError once ssl_handshake_timeout expires."""
    client_sslctx = self._create_client_ssl_context()

    # silence error logger
    self.loop.set_exception_handler(lambda *args: None)

    def server(sock):
        # Only read from the client; never reply to its handshake.
        try:
            sock.recv_all(1024 * 1024)
        except ConnectionAbortedError:
            pass
        finally:
            sock.close()

    async def client(addr):
        _, writer = await asyncio.open_connection(
            *addr,
            ssl=client_sslctx,
            server_hostname='',
            ssl_handshake_timeout=1.0)
        writer.close()
        await self.wait_closed(writer)

    expected_msg = r'SSL handshake.*is taking longer'
    with self.tcp_server(server, max_clients=1, backlog=1) as srv, \
            self.assertRaisesRegex(ConnectionAbortedError, expected_msg):
        self.loop.run_until_complete(client(srv.addr))
419
def test_create_connection_ssl_failed_certificate(self):
    """A client that verifies certificates must reject this server's cert.

    The client context keeps verification on (disable_verify=False), so
    open_connection() is expected to raise ssl.SSLCertVerificationError.
    """
    # silence error logger
    self.loop.set_exception_handler(lambda *args: None)

    sslctx = self._create_server_ssl_context(
        test_utils.ONLYCERT,
        test_utils.ONLYKEY
    )
    client_sslctx = self._create_client_ssl_context(disable_verify=False)

    def server(sock):
        # The handshake is expected to fail on the server side too, once
        # the client aborts it.  Any TLS/OS error is ignored here.  (The
        # former argument-less sock.connect() call was meaningless on an
        # accepted socket and would raise an uncaught TypeError.)
        try:
            sock.starttls(
                sslctx,
                server_side=True)
        except (ssl.SSLError, OSError):
            pass
        finally:
            sock.close()

    async def client(addr):
        reader, writer = await asyncio.open_connection(
            *addr,
            ssl=client_sslctx,
            server_hostname='',
            ssl_handshake_timeout=support.SHORT_TIMEOUT)
        writer.close()
        await self.wait_closed(writer)

    with self.tcp_server(server,
                         max_clients=1,
                         backlog=1) as srv:

        with self.assertRaises(ssl.SSLCertVerificationError):
            self.loop.run_until_complete(client(srv.addr))
456
def test_ssl_handshake_timeout(self):
    # bpo-29970: Check that a connection is aborted if handshake is not
    # completed in timeout period, instead of remaining open indefinitely
    client_sslctx = test_utils.simple_client_sslcontext()

    # silence error logger, but keep every context for the final check
    messages = []
    self.loop.set_exception_handler(
        lambda loop, ctx: messages.append(ctx))

    server_side_aborted = False

    def server(sock):
        nonlocal server_side_aborted
        try:
            sock.recv_all(1024 * 1024)
        except ConnectionAbortedError:
            server_side_aborted = True
        finally:
            sock.close()

    async def client(addr):
        connecting = self.loop.create_connection(
            asyncio.Protocol,
            *addr,
            ssl=client_sslctx,
            server_hostname='',
            ssl_handshake_timeout=10.0)
        # The outer 0.5 s limit fires long before the 10 s handshake timeout.
        await asyncio.wait_for(connecting, 0.5)

    with self.tcp_server(server, max_clients=1, backlog=1) as srv, \
            self.assertRaises(asyncio.TimeoutError):
        self.loop.run_until_complete(client(srv.addr))

    self.assertTrue(server_side_aborted)

    # Python issue #23197: cancelling a handshake must not raise an
    # exception or log an error, even if the handshake failed
    self.assertEqual(messages, [])
499
def test_ssl_handshake_connection_lost(self):
    # #246: make sure that no connection_lost() is called before
    # connection_made() is called first
    """A TLS handshake broken by the peer must raise ConnectionResetError
    without invoking either protocol callback."""

    client_sslctx = test_utils.simple_client_sslcontext()

    # silence error logger
    self.loop.set_exception_handler(lambda loop, ctx: None)

    connection_made_called = False
    connection_lost_called = False

    def server(sock):
        sock.recv(1024)
        # break the connection during handshake
        sock.close()

    class ClientProto(asyncio.Protocol):
        def connection_made(self, transport):
            nonlocal connection_made_called
            connection_made_called = True

        def connection_lost(self, exc):
            nonlocal connection_lost_called
            connection_lost_called = True

    async def client(addr):
        # (A stray trailing comma previously turned this into a
        # one-element tuple expression statement.)
        await self.loop.create_connection(
            ClientProto,
            *addr,
            ssl=client_sslctx,
            server_hostname='')

    with self.tcp_server(server,
                         max_clients=1,
                         backlog=1) as srv:

        with self.assertRaises(ConnectionResetError):
            self.loop.run_until_complete(client(srv.addr))

    if connection_lost_called:
        if connection_made_called:
            self.fail("unexpected call to connection_lost()")
        else:
            self.fail("unexpected call to connection_lost() without "
                      "calling connection_made()")
    elif connection_made_called:
        self.fail("unexpected call to connection_made()")
548
def test_ssl_connect_accepted_socket(self):
    """Run test_connect_accepted_socket() with TLS on both ends.

    Previously the contexts were built but never used, so the test
    exercised nothing.
    """
    server_context = ssl.SSLContext(ssl.PROTOCOL_TLS_SERVER)
    server_context.load_cert_chain(test_utils.ONLYCERT, test_utils.ONLYKEY)
    server_context.check_hostname = False
    server_context.verify_mode = ssl.CERT_NONE

    # The client needs a client-side context: a PROTOCOL_TLS_SERVER
    # context cannot create client sockets.  check_hostname must be
    # cleared before verify_mode can be set to CERT_NONE.
    client_context = ssl.SSLContext(ssl.PROTOCOL_TLS_CLIENT)
    client_context.check_hostname = False
    client_context.verify_mode = ssl.CERT_NONE

    self.test_connect_accepted_socket(server_context, client_context)
561
def test_connect_accepted_socket(self, server_ssl=None, client_ssl=None):
    """Hand an already-accepted socket to loop.connect_accepted_socket().

    A client thread connects (wrapped with *client_ssl* when given), sends
    ``message`` and stores the reply.  The server protocol answers each
    data chunk with ``expected_response`` and stops the loop once the
    connection is lost.  With *server_ssl* the resulting transport must be
    an SSL transport.
    """
    loop = self.loop

    class MyProto(MyBaseProto):

        def connection_lost(self, exc):
            super().connection_lost(exc)
            loop.call_soon(loop.stop)

        def data_received(self, data):
            super().data_received(data)
            self.transport.write(expected_response)

    message = b'test data'
    response = None
    expected_response = b'roger'

    def client():
        nonlocal response
        csock = None
        try:
            csock = socket.socket(socket.AF_INET)
            if client_ssl is not None:
                csock = client_ssl.wrap_socket(csock)
            csock.connect(addr)
            csock.sendall(message)
            response = csock.recv(99)
        except Exception as exc:
            print(
                "Failure in client thread in test_connect_accepted_socket",
                exc)
        finally:
            # Close on both the success and the failure path.
            if csock is not None:
                csock.close()

    lsock = socket.socket(socket.AF_INET)
    try:
        # A timeout turns "client never connected" into an error rather
        # than a hang in accept().
        lsock.settimeout(support.SHORT_TIMEOUT)
        lsock.bind(('127.0.0.1', 0))
        lsock.listen(1)
        addr = lsock.getsockname()

        thread = threading.Thread(target=client, daemon=True)
        thread.start()

        conn, _ = lsock.accept()
        try:
            proto = MyProto(loop=loop)

            extras = {}
            if server_ssl:
                extras = dict(ssl_handshake_timeout=support.SHORT_TIMEOUT)

            f = loop.create_task(
                loop.connect_accepted_socket(
                    (lambda: proto), conn, ssl=server_ssl,
                    **extras))
            loop.run_forever()
        finally:
            conn.close()
    finally:
        lsock.close()

    # Allow a slow (e.g. TLS) client more than one second to finish.
    thread.join(support.SHORT_TIMEOUT)
    self.assertFalse(thread.is_alive())
    self.assertEqual(proto.state, 'CLOSED')
    self.assertEqual(proto.nbytes, len(message))
    self.assertEqual(response, expected_response)
    tr, _ = f.result()

    if server_ssl:
        self.assertIn('SSL', tr.__class__.__name__)

    tr.close()
    # let it close
    self.loop.run_until_complete(asyncio.sleep(0.1))
631
def test_start_tls_client_corrupted_ssl(self):
    """Plain bytes injected beneath an established TLS session must make
    the client's next read fail with ssl.SSLError."""
    self.loop.set_exception_handler(lambda loop, ctx: None)

    sslctx = test_utils.simple_server_sslcontext()
    client_sslctx = test_utils.simple_client_sslcontext()

    def server(sock):
        # Duplicate the descriptor before the TLS upgrade, so raw bytes can
        # later be written underneath the encrypted stream.
        raw_sock = sock.dup()
        try:
            sock.starttls(sslctx, server_side=True)
            sock.sendall(b'A\n')
            sock.recv_all(1)
            raw_sock.send(b'please corrupt the SSL connection')
        except ssl.SSLError:
            pass
        finally:
            sock.close()
            raw_sock.close()

    async def client(addr):
        reader, writer = await asyncio.open_connection(
            *addr, ssl=client_sslctx, server_hostname='')

        self.assertEqual(await reader.readline(), b'A\n')
        writer.write(b'B')
        with self.assertRaises(ssl.SSLError):
            await reader.readline()

        writer.close()
        with contextlib.suppress(ssl.SSLError):
            await self.wait_closed(writer)
        return 'OK'

    with self.tcp_server(server, max_clients=1, backlog=1) as srv:
        res = self.loop.run_until_complete(client(srv.addr))

    self.assertEqual(res, 'OK')
677
def test_start_tls_client_reg_proto_1(self):
    """Upgrade a plain client connection with loop.start_tls().

    The client sends HELLO_MSG in clear text, upgrades, waits for the
    server's b'O', sends HELLO_MSG again over TLS, then waits for EOF once
    the server unwraps.  A regular (non-buffered) Protocol is used.
    """
    HELLO_MSG = b'1' * self.PAYLOAD_SIZE
    hello_len = len(HELLO_MSG)

    server_context = test_utils.simple_server_sslcontext()
    client_context = test_utils.simple_client_sslcontext()

    test_case = self

    def serve(sock):
        sock.settimeout(self.TIMEOUT)

        # clear-text phase
        cleartext = sock.recv_all(hello_len)
        self.assertEqual(len(cleartext), hello_len)

        sock.starttls(server_context, server_side=True)

        # TLS phase
        sock.sendall(b'O')
        ciphertext_payload = sock.recv_all(hello_len)
        self.assertEqual(len(ciphertext_payload), hello_len)

        sock.unwrap()
        sock.close()

    class ClientProto(asyncio.Protocol):
        def __init__(self, on_data, on_eof):
            self.on_data = on_data
            self.on_eof = on_eof
            self.con_made_cnt = 0

        def connection_made(self, tr):
            self.con_made_cnt += 1
            # Ensure connection_made gets called only once.
            test_case.assertEqual(self.con_made_cnt, 1)

        def data_received(self, data):
            self.on_data.set_result(data)

        def eof_received(self):
            self.on_eof.set_result(True)

    async def client(addr):
        await asyncio.sleep(0.5)

        on_data = self.loop.create_future()
        on_eof = self.loop.create_future()

        plain_tr, proto = await self.loop.create_connection(
            lambda: ClientProto(on_data, on_eof), *addr)

        plain_tr.write(HELLO_MSG)
        tls_tr = await self.loop.start_tls(plain_tr, proto, client_context)

        self.assertEqual(await on_data, b'O')
        tls_tr.write(HELLO_MSG)
        await on_eof

        tls_tr.close()

    with self.tcp_server(serve, timeout=self.TIMEOUT) as srv:
        guarded = asyncio.wait_for(client(srv.addr),
                                   timeout=support.SHORT_TIMEOUT)
        self.loop.run_until_complete(guarded)
738
def test_create_connection_memory_leak(self):
    """Check that the client SSLContext can be freed after the connection.

    The client protocol deliberately keeps a reference to its transport
    (``proto.tr``).  After the connection finishes, dropping the test's
    own reference to the client context must leave nothing else keeping
    it alive.
    """
    HELLO_MSG = b'1' * self.PAYLOAD_SIZE

    server_context = self._create_server_ssl_context(
        test_utils.ONLYCERT, test_utils.ONLYKEY)
    client_context = self._create_client_ssl_context()

    def serve(sock):
        sock.settimeout(self.TIMEOUT)

        sock.starttls(server_context, server_side=True)

        sock.sendall(b'O')
        data = sock.recv_all(len(HELLO_MSG))
        self.assertEqual(len(data), len(HELLO_MSG))

        # Orderly TLS shutdown; the client sees this as EOF.
        sock.unwrap()
        sock.close()

    class ClientProto(asyncio.Protocol):
        def __init__(self, on_data, on_eof):
            self.on_data = on_data
            self.on_eof = on_eof
            self.con_made_cnt = 0

        # ``proto`` is the protocol instance; ``self`` is the enclosing
        # TestCase, captured by closure for assertEqual().
        def connection_made(proto, tr):
            # XXX: We assume user stores the transport in protocol
            proto.tr = tr
            proto.con_made_cnt += 1
            # Ensure connection_made gets called only once.
            self.assertEqual(proto.con_made_cnt, 1)

        def data_received(self, data):
            self.on_data.set_result(data)

        def eof_received(self):
            self.on_eof.set_result(True)

    async def client(addr):
        await asyncio.sleep(0.5)

        on_data = self.loop.create_future()
        on_eof = self.loop.create_future()

        tr, proto = await self.loop.create_connection(
            lambda: ClientProto(on_data, on_eof), *addr,
            ssl=client_context)

        self.assertEqual(await on_data, b'O')
        tr.write(HELLO_MSG)
        await on_eof

        tr.close()

    with self.tcp_server(serve, timeout=self.TIMEOUT) as srv:
        self.loop.run_until_complete(
            asyncio.wait_for(client(srv.addr),
                             timeout=support.SHORT_TIMEOUT))

    # No garbage is left for SSL client from loop.create_connection, even
    # if user stores the SSLTransport in corresponding protocol instance
    # Rebinding the name also replaces the closure cell that client()
    # reads, so the test's own strong reference is gone here.  Any
    # remaining reference would have to come from the connection
    # machinery.  NOTE(review): this relies on prompt reference-count
    # freeing; no gc pass runs before the check.
    client_context = weakref.ref(client_context)
    self.assertIsNone(client_context())
802
def test_start_tls_client_buf_proto_1(self):
    """start_tls() with a BufferedProtocol, then swap in a regular Protocol.

    connection_made() must run exactly once overall.  Neither start_tls()
    nor set_protocol() may call it again.
    """
    HELLO_MSG = b'1' * self.PAYLOAD_SIZE

    server_context = test_utils.simple_server_sslcontext()
    client_context = test_utils.simple_client_sslcontext()

    client_con_made_calls = 0

    def note_connection_made():
        nonlocal client_con_made_calls
        client_con_made_calls += 1

    def serve(sock):
        sock.settimeout(self.TIMEOUT)

        def expect_hello():
            received = sock.recv_all(len(HELLO_MSG))
            self.assertEqual(len(received), len(HELLO_MSG))

        expect_hello()  # sent in clear text
        sock.starttls(server_context, server_side=True)

        # One reply per protocol the client installs, each answered with
        # another HELLO_MSG over TLS.
        for reply in (b'O', b'2'):
            sock.sendall(reply)
            expect_hello()

        sock.unwrap()
        sock.close()

    class ClientProtoFirst(asyncio.BufferedProtocol):
        def __init__(self, on_data):
            self.on_data = on_data
            self.buf = bytearray(1)  # deliberately one byte at a time

        def connection_made(self, tr):
            note_connection_made()

        def get_buffer(self, sizehint):
            return self.buf

        def buffer_updated(self, nsize):
            assert nsize == 1
            self.on_data.set_result(bytes(self.buf[:nsize]))

        def eof_received(self):
            pass

    class ClientProtoSecond(asyncio.Protocol):
        def __init__(self, on_data, on_eof):
            self.on_data = on_data
            self.on_eof = on_eof
            self.con_made_cnt = 0

        def connection_made(self, tr):
            note_connection_made()

        def data_received(self, data):
            self.on_data.set_result(data)

        def eof_received(self):
            self.on_eof.set_result(True)

    async def client(addr):
        await asyncio.sleep(0.5)

        loop = self.loop
        on_data1 = loop.create_future()
        on_data2 = loop.create_future()
        on_eof = loop.create_future()

        plain_tr, proto = await loop.create_connection(
            lambda: ClientProtoFirst(on_data1), *addr)

        plain_tr.write(HELLO_MSG)
        tls_tr = await loop.start_tls(plain_tr, proto, client_context)

        self.assertEqual(await on_data1, b'O')
        tls_tr.write(HELLO_MSG)

        tls_tr.set_protocol(ClientProtoSecond(on_data2, on_eof))
        self.assertEqual(await on_data2, b'2')
        tls_tr.write(HELLO_MSG)
        await on_eof

        tls_tr.close()

        # connection_made() should be called only once -- when
        # we establish connection for the first time. Start TLS
        # doesn't call connection_made() on application protocols.
        self.assertEqual(client_con_made_calls, 1)

    with self.tcp_server(serve, timeout=self.TIMEOUT) as srv:
        guarded = asyncio.wait_for(client(srv.addr), timeout=self.TIMEOUT)
        self.loop.run_until_complete(guarded)
897
def test_start_tls_slow_client_cancel(self):
    """start_tls() against a server that never completes the handshake
    must be cancellable: wait_for() raises TimeoutError."""
    HELLO_MSG = b'1' * self.PAYLOAD_SIZE

    client_context = test_utils.simple_client_sslcontext()
    server_waits_on_handshake = self.loop.create_future()

    def serve(sock):
        sock.settimeout(self.TIMEOUT)

        hello = sock.recv_all(len(HELLO_MSG))
        self.assertEqual(len(hello), len(HELLO_MSG))

        try:
            # Tell the client we're ready, then read (and discard) what it
            # sends without ever answering.
            self.loop.call_soon_threadsafe(
                server_waits_on_handshake.set_result, None)
            sock.recv_all(1024 * 1024)
        except ConnectionAbortedError:
            pass
        finally:
            sock.close()

    test_case = self

    class ClientProto(asyncio.Protocol):
        def __init__(self, on_data, on_eof):
            self.on_data = on_data
            self.on_eof = on_eof
            self.con_made_cnt = 0

        def connection_made(self, tr):
            self.con_made_cnt += 1
            # Ensure connection_made gets called only once.
            test_case.assertEqual(self.con_made_cnt, 1)

        def data_received(self, data):
            self.on_data.set_result(data)

        def eof_received(self):
            self.on_eof.set_result(True)

    async def client(addr):
        await asyncio.sleep(0.5)

        on_data = self.loop.create_future()
        on_eof = self.loop.create_future()

        transport, proto = await self.loop.create_connection(
            lambda: ClientProto(on_data, on_eof), *addr)

        transport.write(HELLO_MSG)

        await server_waits_on_handshake

        upgrading = self.loop.start_tls(transport, proto, client_context)
        with self.assertRaises(asyncio.TimeoutError):
            await asyncio.wait_for(upgrading, 0.5)

    with self.tcp_server(serve, timeout=self.TIMEOUT) as srv:
        guarded = asyncio.wait_for(client(srv.addr),
                                   timeout=support.SHORT_TIMEOUT)
        self.loop.run_until_complete(guarded)
958
def test_start_tls_server_1(self):
    """Server-side start_tls(): upgrade an accepted connection to TLS.

    The server writes HELLO_MSG in clear text, then upgrades with
    server_side=True.  After the client unwraps and disconnects, the
    server must have received exactly HELLO_MSG over TLS.
    """
    HELLO_MSG = b'1' * self.PAYLOAD_SIZE

    server_context = test_utils.simple_server_sslcontext()
    client_context = test_utils.simple_client_sslcontext()

    def client(sock, addr):
        sock.settimeout(self.TIMEOUT)

        sock.connect(addr)
        greeting = sock.recv_all(len(HELLO_MSG))
        self.assertEqual(len(greeting), len(HELLO_MSG))

        sock.starttls(client_context)
        sock.sendall(HELLO_MSG)

        sock.unwrap()
        sock.close()

    class ServerProto(asyncio.Protocol):
        def __init__(self, on_con, on_eof, on_con_lost):
            self.on_con = on_con
            self.on_eof = on_eof
            self.on_con_lost = on_con_lost
            self.data = b''

        def connection_made(self, tr):
            self.on_con.set_result(tr)

        def data_received(self, data):
            self.data += data

        def eof_received(self):
            self.on_eof.set_result(1)

        def connection_lost(self, exc):
            if exc is not None:
                self.on_con_lost.set_exception(exc)
            else:
                self.on_con_lost.set_result(None)

    async def main(proto, on_con, on_eof, on_con_lost):
        plain_tr = await on_con
        plain_tr.write(HELLO_MSG)

        # The client sends nothing before upgrading.
        self.assertEqual(proto.data, b'')

        tls_tr = await self.loop.start_tls(
            plain_tr, proto, server_context,
            server_side=True,
            ssl_handshake_timeout=self.TIMEOUT)

        await on_eof
        await on_con_lost
        self.assertEqual(proto.data, HELLO_MSG)
        tls_tr.close()

    async def run_main():
        loop = self.loop
        on_con, on_eof, on_con_lost = [loop.create_future() for _ in range(3)]
        proto = ServerProto(on_con, on_eof, on_con_lost)

        server = await loop.create_server(
            lambda: proto, '127.0.0.1', 0)
        addr = server.sockets[0].getsockname()

        with self.tcp_client(lambda sock: client(sock, addr),
                             timeout=self.TIMEOUT):
            await asyncio.wait_for(
                main(proto, on_con, on_eof, on_con_lost),
                timeout=self.TIMEOUT)

        server.close()
        await server.wait_closed()

    self.loop.run_until_complete(run_main())
1036
def test_create_server_ssl_over_ssl(self):
    """Run a second TLS session inside an outer TLS server connection.

    create_server() accepts TLS with sslctx_1.  ServerProtocol then
    upgrades the already-encrypted transport again with start_tls(sslctx_2).
    Each threaded client does the outer handshake with starttls() and
    drives the inner one by hand through an SSLObject over two MemoryBIOs.
    It then exchanges A_DATA -> b'OK' and B_DATA -> b'SPAM'.
    """
    CNT = 0  # number of clients that were successful
    TOTAL_CNT = 25  # total number of clients that test will create
    TIMEOUT = support.LONG_TIMEOUT  # timeout for this test

    A_DATA = b'A' * 1024 * BUF_MULTIPLIER
    B_DATA = b'B' * 1024 * BUF_MULTIPLIER

    sslctx_1 = self._create_server_ssl_context(
        test_utils.ONLYCERT, test_utils.ONLYKEY)
    client_sslctx_1 = self._create_client_ssl_context()
    sslctx_2 = self._create_server_ssl_context(
        test_utils.ONLYCERT, test_utils.ONLYKEY)
    client_sslctx_2 = self._create_client_ssl_context()

    clients = []

    async def handle_client(reader, writer):
        nonlocal CNT

        data = await reader.readexactly(len(A_DATA))
        self.assertEqual(data, A_DATA)
        writer.write(b'OK')

        data = await reader.readexactly(len(B_DATA))
        self.assertEqual(data, B_DATA)
        writer.writelines([b'SP', bytearray(b'A'), memoryview(b'M')])

        await writer.drain()
        writer.close()

        CNT += 1

    class ServerProtocol(asyncio.StreamReaderProtocol):
        def connection_made(self, transport):
            # Capture super() here: zero-argument super() inside cb() would
            # bind to cb's own first argument instead of this protocol.
            super_ = super()
            # Pause now so no bytes reach the stream reader before the
            # scheduled start_tls() task takes over the transport.
            transport.pause_reading()
            fut = self._loop.create_task(self._loop.start_tls(
                transport, self, sslctx_2, server_side=True))

            def cb(_):
                try:
                    tr = fut.result()
                except Exception as ex:
                    super_.connection_lost(ex)
                else:
                    super_.connection_made(tr)
            fut.add_done_callback(cb)

    def server_protocol_factory():
        reader = asyncio.StreamReader()
        protocol = ServerProtocol(reader, handle_client)
        return protocol

    async def test_client(addr):
        fut = asyncio.Future()

        def prog(sock):
            try:
                sock.connect(addr)
                sock.starttls(client_sslctx_1)

                # because wrap_socket() doesn't work correctly on
                # SSLSocket, we have to do the 2nd level SSL manually
                incoming = ssl.MemoryBIO()
                outgoing = ssl.MemoryBIO()
                sslobj = client_sslctx_2.wrap_bio(incoming, outgoing)

                def do(func, *args):
                    # Pump bytes between the socket and the BIOs until
                    # *func* stops asking for more input.
                    while True:
                        try:
                            rv = func(*args)
                            break
                        except ssl.SSLWantReadError:
                            if outgoing.pending:
                                sock.sendall(outgoing.read())
                            chunk = sock.recv(65536)
                            if chunk:
                                incoming.write(chunk)
                            else:
                                # The peer closed the stream.  Signal EOF
                                # so the SSL object raises instead of this
                                # loop spinning forever on empty reads.
                                incoming.write_eof()
                    if outgoing.pending:
                        sock.sendall(outgoing.read())
                    return rv

                do(sslobj.do_handshake)

                do(sslobj.write, A_DATA)
                data = do(sslobj.read, 2)
                self.assertEqual(data, b'OK')

                do(sslobj.write, B_DATA)
                data = b''
                while True:
                    chunk = do(sslobj.read, 4)
                    if not chunk:
                        break
                    data += chunk
                self.assertEqual(data, b'SPAM')

                do(sslobj.unwrap)
                sock.close()

            except Exception as ex:
                self.loop.call_soon_threadsafe(fut.set_exception, ex)
                sock.close()
            else:
                self.loop.call_soon_threadsafe(fut.set_result, None)

        client = self.tcp_client(prog)
        client.start()
        clients.append(client)

        await fut

    async def start_server():
        srv = await self.loop.create_server(
            server_protocol_factory,
            '127.0.0.1', 0,
            family=socket.AF_INET,
            ssl=sslctx_1)

        try:
            srv_socks = srv.sockets
            self.assertTrue(srv_socks)

            addr = srv_socks[0].getsockname()

            tasks = []
            for _ in range(TOTAL_CNT):
                tasks.append(test_client(addr))

            await asyncio.wait_for(asyncio.gather(*tasks), TIMEOUT)

        finally:
            self.loop.call_soon(srv.close)
            await srv.wait_closed()

    try:
        with self._silence_eof_received_warning():
            self.loop.run_until_complete(start_server())

        self.assertEqual(CNT, TOTAL_CNT)
    finally:
        # Stop the client threads on failure too, so a failing run does
        # not leave them behind for later tests.
        for client in clients:
            client.stop()
1181
def test_shutdown_cleanly(self):
    """After the server's TLS unwrap, each client must read a clean EOF.

    The server answers A_DATA with b'OK' and then unwraps.  The client's
    following read() must return b'' rather than raise.
    """
    CNT = 0
    TOTAL_CNT = 25

    A_DATA = b'A' * 1024 * BUF_MULTIPLIER

    sslctx = self._create_server_ssl_context(
        test_utils.ONLYCERT, test_utils.ONLYKEY)
    client_sslctx = self._create_client_ssl_context()

    def server(sock):
        sock.starttls(
            sslctx,
            server_side=True)

        data = sock.recv_all(len(A_DATA))
        self.assertEqual(data, A_DATA)
        sock.send(b'OK')

        sock.unwrap()

        sock.close()

    async def client(addr):
        reader, writer = await asyncio.open_connection(
            *addr,
            ssl=client_sslctx,
            server_hostname='',
            ssl_handshake_timeout=support.SHORT_TIMEOUT)

        writer.write(A_DATA)
        self.assertEqual(await reader.readexactly(2), b'OK')

        # The server's unwrap() must appear as a clean EOF.
        self.assertEqual(await reader.read(), b'')

        nonlocal CNT
        CNT += 1

        writer.close()
        await self.wait_closed(writer)

    def run(coro):
        nonlocal CNT
        CNT = 0

        async def _gather(*tasks):
            # trampoline: gather() must be called with the loop running
            return await asyncio.gather(*tasks)

        with self.tcp_server(server,
                             max_clients=TOTAL_CNT,
                             backlog=TOTAL_CNT) as srv:
            tasks = []
            for _ in range(TOTAL_CNT):
                tasks.append(coro(srv.addr))

            self.loop.run_until_complete(
                _gather(*tasks))

        self.assertEqual(CNT, TOTAL_CNT)

    with self._silence_eof_received_warning():
        run(client)
1247
    def test_flush_before_shutdown(self):
        """Writes queued while the SSL protocol is paused are flushed on close.

        The client pauses writing on its SSL protocol, queues
        ``CHUNK * SIZE`` bytes, closes the writer and only then resumes
        writing.  The server must still receive every byte.
        """
        CHUNK = 1024 * 128
        SIZE = 32

        sslctx = self._create_server_ssl_context(
            test_utils.ONLYCERT, test_utils.ONLYKEY)
        client_sslctx = self._create_client_ssl_context()

        # Created by client(); completed from the server thread by run().
        future = None

        def server(sock):
            sock.starttls(sslctx, server_side=True)
            self.assertEqual(sock.recv_all(4), b'ping')
            sock.send(b'pong')
            time.sleep(0.5)  # hopefully lets the client's writes back up in the TCP buffer
            data = sock.recv_all(CHUNK * SIZE)
            self.assertEqual(len(data), CHUNK * SIZE)
            sock.close()

        def run(meth):
            # Wrap a server program so its outcome (exception or success)
            # is delivered to ``future`` on the event loop's thread.
            def wrapper(sock):
                try:
                    meth(sock)
                except Exception as ex:
                    self.loop.call_soon_threadsafe(future.set_exception, ex)
                else:
                    self.loop.call_soon_threadsafe(future.set_result, None)
            return wrapper

        async def client(addr):
            nonlocal future
            future = self.loop.create_future()
            reader, writer = await asyncio.open_connection(
                *addr,
                ssl=client_sslctx,
                server_hostname='')
            # Private attribute: the SSL protocol behind the transport.
            sslprotocol = writer.transport._ssl_protocol
            writer.write(b'ping')
            data = await reader.readexactly(4)
            self.assertEqual(data, b'pong')

            # Queue all data while writing is paused, and request the close
            # before resuming, so the close has to wait for the flush.
            sslprotocol.pause_writing()
            for _ in range(SIZE):
                writer.write(b'x' * CHUNK)

            writer.close()
            sslprotocol.resume_writing()

            await self.wait_closed(writer)
            try:
                data = await reader.read()
                self.assertEqual(data, b'')
            except ConnectionResetError:
                pass
            # Re-raises any assertion failure from the server thread.
            await future

        with self.tcp_server(run(server)) as srv:
            self.loop.run_until_complete(client(srv.addr))
1306
    def test_remote_shutdown_receives_trailing_data(self):
        """Queued client writes still reach a server that is shutting down.

        The same client runs twice.  The first server starts a TLS shutdown
        (close_notify) without waiting for the reply.  The second server only
        half-closes the TCP connection.  In both cases the server must still
        receive all ``CHUNK * SIZE`` bytes the client had queued.
        """
        CHUNK = 1024 * 128
        SIZE = 32

        sslctx = self._create_server_ssl_context(
            test_utils.ONLYCERT,
            test_utils.ONLYKEY
        )
        client_sslctx = self._create_client_ssl_context()
        # Re-created by each client() run; completed from the server thread.
        future = None

        def server(sock):
            # Drive TLS by hand over memory BIOs so the server controls
            # exactly when close_notify is sent and when it is read.
            incoming = ssl.MemoryBIO()
            outgoing = ssl.MemoryBIO()
            sslobj = sslctx.wrap_bio(incoming, outgoing, server_side=True)

            # Handshake: flush pending output and feed socket input until
            # do_handshake() stops asking for more data.
            while True:
                try:
                    sslobj.do_handshake()
                except ssl.SSLWantReadError:
                    if outgoing.pending:
                        sock.send(outgoing.read())
                    incoming.write(sock.recv(16384))
                else:
                    if outgoing.pending:
                        sock.send(outgoing.read())
                    break

            # Read the 4-byte ping, pulling socket data as needed.
            while True:
                try:
                    data = sslobj.read(4)
                except ssl.SSLWantReadError:
                    incoming.write(sock.recv(16384))
                else:
                    break

            self.assertEqual(data, b'ping')
            sslobj.write(b'pong')
            sock.send(outgoing.read())

            time.sleep(0.2)  # wait for the peer to fill its backlog

            # send close_notify but don't wait for response
            # (unwrap() is expected to raise SSLWantReadError here because
            # the peer's close_notify has not been read yet)
            with self.assertRaises(ssl.SSLWantReadError):
                sslobj.unwrap()
            sock.send(outgoing.read())

            # should receive all data
            # (the loop ends when the client closes TLS cleanly)
            data_len = 0
            while True:
                try:
                    chunk = len(sslobj.read(16384))
                    data_len += chunk
                except ssl.SSLWantReadError:
                    incoming.write(sock.recv(16384))
                except ssl.SSLZeroReturnError:
                    break

            self.assertEqual(data_len, CHUNK * SIZE)

            # verify that close_notify is received
            sslobj.unwrap()

            sock.close()

        def eof_server(sock):
            sock.starttls(sslctx, server_side=True)
            self.assertEqual(sock.recv_all(4), b'ping')
            sock.send(b'pong')

            time.sleep(0.2)  # wait for the peer to fill its backlog

            # send EOF
            # (a TCP half-close; NOTE(review): presumably no TLS
            # close_notify is written on this path -- confirm)
            sock.shutdown(socket.SHUT_WR)

            # should receive all data
            data = sock.recv_all(CHUNK * SIZE)
            self.assertEqual(len(data), CHUNK * SIZE)

            sock.close()

        async def client(addr):
            nonlocal future
            future = self.loop.create_future()

            reader, writer = await asyncio.open_connection(
                *addr,
                ssl=client_sslctx,
                server_hostname='')
            writer.write(b'ping')
            data = await reader.readexactly(4)
            self.assertEqual(data, b'pong')

            # fill write backlog in a hacky way - renegotiation won't help
            # (uses a private test hook on the SSL transport)
            for _ in range(SIZE):
                writer.transport._test__append_write_backlog(b'x' * CHUNK)

            try:
                data = await reader.read()
                self.assertEqual(data, b'')
            except (BrokenPipeError, ConnectionResetError):
                pass

            # Re-raises any assertion failure from the server thread.
            await future

            writer.close()
            await self.wait_closed(writer)

        def run(meth):
            # Wrap a server program so its outcome (exception or success)
            # is delivered to the current ``future`` on the loop's thread.
            def wrapper(sock):
                try:
                    meth(sock)
                except Exception as ex:
                    self.loop.call_soon_threadsafe(future.set_exception, ex)
                else:
                    self.loop.call_soon_threadsafe(future.set_result, None)
            return wrapper

        with self.tcp_server(run(server)) as srv:
            self.loop.run_until_complete(client(srv.addr))

        with self.tcp_server(run(eof_server)) as srv:
            self.loop.run_until_complete(client(srv.addr))
1430
1431 def test_connect_timeout_warning(self):
1432 s = socket.socket(socket.AF_INET)
1433 s.bind(('127.0.0.1', 0))
1434 addr = s.getsockname()
1435
1436 async def test():
1437 try:
1438 await asyncio.wait_for(
1439 self.loop.create_connection(asyncio.Protocol,
1440 *addr, ssl=True),
1441 0.1)
1442 except (ConnectionRefusedError, asyncio.TimeoutError):
1443 pass
1444 else:
1445 self.fail('TimeoutError is not raised')
1446
1447 with s:
1448 try:
1449 with self.assertWarns(ResourceWarning) as cm:
1450 self.loop.run_until_complete(test())
1451 gc.collect()
1452 gc.collect()
1453 gc.collect()
1454 except AssertionError as e:
1455 self.assertEqual(str(e), 'ResourceWarning not triggered')
1456 else:
1457 self.fail('Unexpected ResourceWarning: {}'.format(cm.warning))
1458
1459 def test_handshake_timeout_handler_leak(self):
1460 s = socket.socket(socket.AF_INET)
1461 s.bind(('127.0.0.1', 0))
1462 s.listen(1)
1463 addr = s.getsockname()
1464
1465 async def test(ctx):
1466 try:
1467 await asyncio.wait_for(
1468 self.loop.create_connection(asyncio.Protocol, *addr,
1469 ssl=ctx),
1470 0.1)
1471 except (ConnectionRefusedError, asyncio.TimeoutError):
1472 pass
1473 else:
1474 self.fail('TimeoutError is not raised')
1475
1476 with s:
1477 ctx = ssl.create_default_context()
1478 self.loop.run_until_complete(test(ctx))
1479 ctx = weakref.ref(ctx)
1480
1481 # SSLProtocol should be DECREF to 0
1482 self.assertIsNone(ctx())
1483
1484 def test_shutdown_timeout_handler_leak(self):
1485 loop = self.loop
1486
1487 def server(sock):
1488 sslctx = self._create_server_ssl_context(
1489 test_utils.ONLYCERT,
1490 test_utils.ONLYKEY
1491 )
1492 sock = sslctx.wrap_socket(sock, server_side=True)
1493 sock.recv(32)
1494 sock.close()
1495
1496 class ESC[4;38;5;81mProtocol(ESC[4;38;5;149masyncioESC[4;38;5;149m.ESC[4;38;5;149mProtocol):
1497 def __init__(self):
1498 self.fut = asyncio.Future(loop=loop)
1499
1500 def connection_lost(self, exc):
1501 self.fut.set_result(None)
1502
1503 async def client(addr, ctx):
1504 tr, pr = await loop.create_connection(Protocol, *addr, ssl=ctx)
1505 tr.close()
1506 await pr.fut
1507
1508 with self.tcp_server(server) as srv:
1509 ctx = self._create_client_ssl_context()
1510 loop.run_until_complete(client(srv.addr, ctx))
1511 ctx = weakref.ref(ctx)
1512
1513 # asyncio has no shutdown timeout, but it ends up with a circular
1514 # reference loop - not ideal (introduces gc glitches), but at least
1515 # not leaking
1516 gc.collect()
1517 gc.collect()
1518 gc.collect()
1519
1520 # SSLProtocol should be DECREF to 0
1521 self.assertIsNone(ctx())
1522
1523 def test_shutdown_timeout_handler_not_set(self):
1524 loop = self.loop
1525 eof = asyncio.Event()
1526 extra = None
1527
1528 def server(sock):
1529 sslctx = self._create_server_ssl_context(
1530 test_utils.ONLYCERT,
1531 test_utils.ONLYKEY
1532 )
1533 sock = sslctx.wrap_socket(sock, server_side=True)
1534 sock.send(b'hello')
1535 assert sock.recv(1024) == b'world'
1536 sock.send(b'extra bytes')
1537 # sending EOF here
1538 sock.shutdown(socket.SHUT_WR)
1539 loop.call_soon_threadsafe(eof.set)
1540 # make sure we have enough time to reproduce the issue
1541 assert sock.recv(1024) == b''
1542 sock.close()
1543
1544 class ESC[4;38;5;81mProtocol(ESC[4;38;5;149masyncioESC[4;38;5;149m.ESC[4;38;5;149mProtocol):
1545 def __init__(self):
1546 self.fut = asyncio.Future(loop=loop)
1547 self.transport = None
1548
1549 def connection_made(self, transport):
1550 self.transport = transport
1551
1552 def data_received(self, data):
1553 if data == b'hello':
1554 self.transport.write(b'world')
1555 # pause reading would make incoming data stay in the sslobj
1556 self.transport.pause_reading()
1557 else:
1558 nonlocal extra
1559 extra = data
1560
1561 def connection_lost(self, exc):
1562 if exc is None:
1563 self.fut.set_result(None)
1564 else:
1565 self.fut.set_exception(exc)
1566
1567 async def client(addr):
1568 ctx = self._create_client_ssl_context()
1569 tr, pr = await loop.create_connection(Protocol, *addr, ssl=ctx)
1570 await eof.wait()
1571 tr.resume_reading()
1572 await pr.fut
1573 tr.close()
1574 assert extra == b'extra bytes'
1575
1576 with self.tcp_server(server) as srv:
1577 loop.run_until_complete(client(srv.addr))
1578
1579
1580 ###############################################################################
1581 # Socket Testing Utilities
1582 ###############################################################################
1583
1584
class TestSocketWrapper:
    """Proxy around a blocking socket handed to the threaded test peers.

    Attributes not defined here are looked up on the wrapped socket.
    :meth:`starttls` replaces the wrapped socket with an SSL socket.
    """

    def __init__(self, sock):
        self.__sock = sock

    def recv_all(self, n):
        """Receive exactly *n* bytes from the wrapped socket.

        Returns a ``bytes`` object of length *n*, or ``b''`` when *n* <= 0.
        Raises ConnectionAbortedError if ``recv()`` returns ``b''`` before
        *n* bytes have arrived.
        """
        # Accumulate in a bytearray. ``bytes += bytes`` copies the whole
        # buffer on every chunk, which is quadratic for the multi-megabyte
        # payloads some tests read in a single call.
        buf = bytearray()
        while len(buf) < n:
            data = self.recv(n - len(buf))
            if data == b'':
                raise ConnectionAbortedError
            buf += data
        return bytes(buf)

    def starttls(self, ssl_context, *,
                 server_side=False,
                 server_hostname=None,
                 do_handshake_on_connect=True):
        """Wrap the current socket with *ssl_context* and use it from now on.

        On the server side, ``do_handshake()`` is also called explicitly
        before returning.
        """

        assert isinstance(ssl_context, ssl.SSLContext)

        ssl_sock = ssl_context.wrap_socket(
            self.__sock, server_side=server_side,
            server_hostname=server_hostname,
            do_handshake_on_connect=do_handshake_on_connect)

        if server_side:
            ssl_sock.do_handshake()

        # NOTE(review): presumably harmless for the connection because the
        # SSL socket has taken over the underlying descriptor -- confirm.
        self.__sock.close()
        self.__sock = ssl_sock

    def __getattr__(self, name):
        # Only reached for names not found normally: delegate to the socket.
        return getattr(self.__sock, name)

    def __repr__(self):
        return '<{} {!r}>'.format(type(self).__name__, self.__sock)
1622
1623
class SocketThread(threading.Thread):
    """Common base for the threaded test client and server.

    Works as a context manager.  Entering starts the thread.  Leaving calls
    :meth:`stop`, which clears ``_active`` (the flag the server loop polls)
    and joins the thread.
    """

    def stop(self):
        """Ask the thread to finish, then block until it has exited."""
        self._active = False
        self.join()

    def __enter__(self):
        self.start()
        return self

    def __exit__(self, *exc):
        # Returns None, so exceptions raised in the with-body propagate.
        self.stop()
1636
1637
class TestThreadedClient(SocketThread):
    """Daemon thread that runs ``prog`` against a client socket.

    ``prog`` receives the socket wrapped in a TestSocketWrapper.  Any
    exception other than KeyboardInterrupt/SystemExit is reported to the
    owning test through ``test._abort_socket_test``.
    """

    def __init__(self, test, sock, prog, timeout):
        threading.Thread.__init__(self, None, None, 'test-client')
        self.daemon = True

        # NOTE(review): _timeout is stored but not read by this class.
        self._timeout = timeout
        self._sock = sock
        self._active = True
        self._prog = prog
        self._test = test

    def run(self):
        try:
            self._prog(TestSocketWrapper(self._sock))
        except (KeyboardInterrupt, SystemExit):
            raise
        except BaseException as ex:
            self._test._abort_socket_test(ex)
1657
1658
class TestThreadedServer(SocketThread):
    """Daemon thread that accepts connections and runs ``prog`` on each.

    Connections are handled one at a time, in the order accepted.  The
    server stops when ``max_clients`` connections have been handled, when
    ``stop()`` is called, or when a handler raises.  A handler exception
    clears ``_active``, is reported to the owning test through
    ``test._abort_socket_test`` and is then re-raised, ending this thread.
    """

    def __init__(self, test, sock, prog, timeout, max_clients):
        threading.Thread.__init__(self, None, None, 'test-server')
        self.daemon = True

        self._clients = 0
        # NOTE(review): not read or updated anywhere in this class.
        self._finished_clients = 0
        self._max_clients = max_clients
        self._timeout = timeout
        self._sock = sock
        self._active = True

        self._prog = prog

        # Wake-up pair for the select() loop: stop() writes to _s2, and
        # _run() returns as soon as _s1 becomes readable.
        self._s1, self._s2 = socket.socketpair()
        self._s1.setblocking(False)

        self._test = test

    def stop(self):
        """Wake the accept loop, then clear ``_active`` and join."""
        try:
            # run() closes _s2 when the thread exits; skip the wake-up in
            # that case.
            if self._s2 and self._s2.fileno() != -1:
                try:
                    self._s2.send(b'stop')
                except OSError:
                    pass
        finally:
            super().stop()

    def run(self):
        try:
            # The listening socket is closed once the loop exits.
            with self._sock:
                self._sock.setblocking(False)
                self._run()
        finally:
            self._s1.close()
            self._s2.close()

    def _run(self):
        while self._active:
            if self._clients >= self._max_clients:
                return

            # Wait for a connection or a stop() wake-up.  The timeout lets
            # the loop re-check _active periodically.
            readable, _, _ = select.select(
                [self._sock, self._s1], [], [], self._timeout)

            if self._s1 in readable:
                return

            if self._sock in readable:
                try:
                    conn, addr = self._sock.accept()
                except BlockingIOError:
                    # Nothing to accept after all; go back to select().
                    continue
                except socket.timeout:
                    if not self._active:
                        return
                    else:
                        raise
                else:
                    self._clients += 1
                    # Blocking I/O on the accepted connection times out.
                    conn.settimeout(self._timeout)
                    try:
                        with conn:
                            self._handle_client(conn)
                    except (KeyboardInterrupt, SystemExit):
                        raise
                    except BaseException as ex:
                        self._active = False
                        try:
                            raise
                        finally:
                            # Report to the test even though the exception
                            # is re-raised and ends this thread.
                            self._test._abort_socket_test(ex)

    def _handle_client(self, sock):
        """Run the server program on one accepted connection."""
        self._prog(TestSocketWrapper(sock))

    @property
    def addr(self):
        """Address of the listening socket, as returned by getsockname()."""
        return self._sock.getsockname()