python (3.12.0)
1 import asyncio
2 import contextlib
3 import gc
4 import logging
5 import select
6 import socket
7 import sys
8 import tempfile
9 import threading
10 import time
11 import weakref
12 import unittest
13
14 try:
15 import ssl
16 except ImportError:
17 ssl = None
18
19 from test import support
20 from test.test_asyncio import utils as test_utils
21
22
# Platform flag plus a multiplier applied to the test payload sizes
# (a smaller multiplier is used on macOS).
MACOS = sys.platform == 'darwin'
BUF_MULTIPLIER = 64 if MACOS else 1024
25
26
def tearDownModule():
    """Reset the global event loop policy after this module's tests."""
    default_policy = None  # None asks asyncio to restore its default policy
    asyncio.set_event_loop_policy(default_policy)
29
30
class MyBaseProto(asyncio.Protocol):
    """Protocol that records its lifecycle state and received byte count.

    ``state`` moves INITIAL -> CONNECTED -> (EOF ->) CLOSED, and every
    callback asserts it was invoked in a permitted state.  When *loop* is
    given, ``connected`` and ``done`` are futures on that loop, resolved by
    connection_made() and connection_lost() respectively; otherwise they
    keep the class-level default of None.
    """

    connected = None
    done = None

    def __init__(self, loop=None):
        self.transport = None
        self.state = 'INITIAL'
        self.nbytes = 0
        if loop is None:
            return
        self.connected = asyncio.Future(loop=loop)
        self.done = asyncio.Future(loop=loop)

    def connection_made(self, transport):
        self.transport = transport
        assert self.state == 'INITIAL', self.state
        self.state = 'CONNECTED'
        waiter = self.connected
        if waiter:
            waiter.set_result(None)

    def data_received(self, data):
        assert self.state == 'CONNECTED', self.state
        self.nbytes = self.nbytes + len(data)

    def eof_received(self):
        assert self.state == 'CONNECTED', self.state
        self.state = 'EOF'

    def connection_lost(self, exc):
        assert self.state in ('CONNECTED', 'EOF'), self.state
        self.state = 'CLOSED'
        waiter = self.done
        if waiter:
            waiter.set_result(None)
63
64
class MessageOutFilter(logging.Filter):
    """Logging filter that drops records whose message contains *msg*.

    The check is a substring test against the record's unformatted
    ``msg``.  Non-string messages (e.g. an exception object passed as the
    message) are compared via ``str()``, so such records do not make the
    filter -- and therefore the logging call -- raise TypeError.
    """

    def __init__(self, msg):
        super().__init__()
        self.msg = msg

    def filter(self, record):
        # Returning False drops the record; True lets it through.
        return self.msg not in str(record.msg)
73
74
75 @unittest.skipIf(ssl is None, 'No ssl module')
76 class ESC[4;38;5;81mTestSSL(ESC[4;38;5;149mtest_utilsESC[4;38;5;149m.ESC[4;38;5;149mTestCase):
77
78 PAYLOAD_SIZE = 1024 * 100
79 TIMEOUT = support.LONG_TIMEOUT
80
def setUp(self):
    """Give each test a brand-new event loop that is closed on cleanup."""
    super().setUp()
    loop = asyncio.new_event_loop()
    self.loop = loop
    self.set_event_loop(loop)
    self.addCleanup(loop.close)
86
def tearDown(self):
    """Briefly run the loop if still open, run cleanups, then collect garbage."""
    loop = self.loop
    # Just in case there are transport close callbacks still pending.
    if not loop.is_closed():
        test_utils.run_briefly(loop)

    self.doCleanups()
    support.gc_collect()
    super().tearDown()
95
96 def tcp_server(self, server_prog, *,
97 family=socket.AF_INET,
98 addr=None,
99 timeout=support.SHORT_TIMEOUT,
100 backlog=1,
101 max_clients=10):
102
103 if addr is None:
104 if family == getattr(socket, "AF_UNIX", None):
105 with tempfile.NamedTemporaryFile() as tmp:
106 addr = tmp.name
107 else:
108 addr = ('127.0.0.1', 0)
109
110 sock = socket.socket(family, socket.SOCK_STREAM)
111
112 if timeout is None:
113 raise RuntimeError('timeout is required')
114 if timeout <= 0:
115 raise RuntimeError('only blocking sockets are supported')
116 sock.settimeout(timeout)
117
118 try:
119 sock.bind(addr)
120 sock.listen(backlog)
121 except OSError as ex:
122 sock.close()
123 raise ex
124
125 return TestThreadedServer(
126 self, sock, server_prog, timeout, max_clients)
127
128 def tcp_client(self, client_prog,
129 family=socket.AF_INET,
130 timeout=support.SHORT_TIMEOUT):
131
132 sock = socket.socket(family, socket.SOCK_STREAM)
133
134 if timeout is None:
135 raise RuntimeError('timeout is required')
136 if timeout <= 0:
137 raise RuntimeError('only blocking sockets are supported')
138 sock.settimeout(timeout)
139
140 return TestThreadedClient(
141 self, sock, client_prog, timeout)
142
143 def unix_server(self, *args, **kwargs):
144 return self.tcp_server(*args, family=socket.AF_UNIX, **kwargs)
145
146 def unix_client(self, *args, **kwargs):
147 return self.tcp_client(*args, family=socket.AF_UNIX, **kwargs)
148
def _create_server_ssl_context(self, certfile, keyfile=None):
    """Return a PROTOCOL_TLS_SERVER context loaded with *certfile*/*keyfile*."""
    ctx = ssl.SSLContext(ssl.PROTOCOL_TLS_SERVER)
    ctx.options = ctx.options | ssl.OP_NO_SSLv2
    ctx.load_cert_chain(certfile, keyfile)
    return ctx
154
155 def _create_client_ssl_context(self, *, disable_verify=True):
156 sslcontext = ssl.create_default_context()
157 sslcontext.check_hostname = False
158 if disable_verify:
159 sslcontext.verify_mode = ssl.CERT_NONE
160 return sslcontext
161
@contextlib.contextmanager
def _silence_eof_received_warning(self):
    """Hide 'asyncio' log records containing 'has no effect when using ssl'.

    The filter is attached for the duration of the ``with`` body and is
    always removed afterwards.
    """
    # TODO This warning has to be fixed in asyncio.
    asyncio_logger = logging.getLogger('asyncio')
    eof_filter = MessageOutFilter('has no effect when using ssl')
    asyncio_logger.addFilter(eof_filter)
    try:
        yield
    finally:
        asyncio_logger.removeFilter(eof_filter)
172
173 def _abort_socket_test(self, ex):
174 try:
175 self.loop.stop()
176 finally:
177 self.fail(ex)
178
179 def new_loop(self):
180 return asyncio.new_event_loop()
181
182 def new_policy(self):
183 return asyncio.DefaultEventLoopPolicy()
184
185 async def wait_closed(self, obj):
186 if not isinstance(obj, asyncio.StreamWriter):
187 return
188 try:
189 await obj.wait_closed()
190 except (BrokenPipeError, ConnectionError):
191 pass
192
def test_create_server_ssl_1(self):
    # TOTAL_CNT blocking TLS clients (one TestThreadedClient each) talk
    # concurrently to an asyncio.start_server() TLS server; every
    # connection must complete the A_DATA/OK then B_DATA/SPAM exchange.
    CNT = 0  # number of clients that were successful
    TOTAL_CNT = 25  # total number of clients that test will create
    TIMEOUT = support.LONG_TIMEOUT  # timeout for this test

    A_DATA = b'A' * 1024 * BUF_MULTIPLIER
    B_DATA = b'B' * 1024 * BUF_MULTIPLIER

    sslctx = self._create_server_ssl_context(
        test_utils.ONLYCERT, test_utils.ONLYKEY
    )
    client_sslctx = self._create_client_ssl_context()

    clients = []

    async def handle_client(reader, writer):
        nonlocal CNT

        data = await reader.readexactly(len(A_DATA))
        self.assertEqual(data, A_DATA)
        writer.write(b'OK')

        data = await reader.readexactly(len(B_DATA))
        self.assertEqual(data, B_DATA)
        # Mixed buffer types: writelines() must accept all of them.
        writer.writelines([b'SP', bytearray(b'A'), memoryview(b'M')])

        await writer.drain()
        writer.close()

        CNT += 1

    async def test_client(addr):
        fut = asyncio.Future()

        def prog(sock):
            # Runs in the client thread; the outcome is reported back to
            # the loop thread through call_soon_threadsafe().
            try:
                sock.starttls(client_sslctx)
                sock.connect(addr)
                sock.send(A_DATA)

                data = sock.recv_all(2)
                self.assertEqual(data, b'OK')

                sock.send(B_DATA)
                data = sock.recv_all(4)
                self.assertEqual(data, b'SPAM')

                sock.close()

            except Exception as ex:
                self.loop.call_soon_threadsafe(fut.set_exception, ex)
            else:
                self.loop.call_soon_threadsafe(fut.set_result, None)

        client = self.tcp_client(prog)
        client.start()
        clients.append(client)

        await fut

    async def start_server():
        srv = await asyncio.start_server(
            handle_client,
            '127.0.0.1', 0,
            family=socket.AF_INET,
            ssl=sslctx,
            ssl_handshake_timeout=support.SHORT_TIMEOUT)

        try:
            srv_socks = srv.sockets
            self.assertTrue(srv_socks)

            addr = srv_socks[0].getsockname()

            tasks = []
            for _ in range(TOTAL_CNT):
                tasks.append(test_client(addr))

            await asyncio.wait_for(asyncio.gather(*tasks), TIMEOUT)

        finally:
            self.loop.call_soon(srv.close)
            await srv.wait_closed()

    try:
        with self._silence_eof_received_warning():
            self.loop.run_until_complete(start_server())

        self.assertEqual(CNT, TOTAL_CNT)
    finally:
        # Stop the client threads even when the test fails, so they do
        # not outlive it.
        for client in clients:
            client.stop()
287
def test_create_connection_ssl_1(self):
    # TOTAL_CNT asyncio TLS clients talk concurrently to a threaded
    # blocking TLS server, first via open_connection(host, port) and then
    # via open_connection(sock=...); each must complete the A_DATA/OK,
    # B_DATA/SPAM exchange.
    self.loop.set_exception_handler(None)

    CNT = 0
    TOTAL_CNT = 25

    A_DATA = b'A' * 1024 * BUF_MULTIPLIER
    B_DATA = b'B' * 1024 * BUF_MULTIPLIER

    sslctx = self._create_server_ssl_context(
        test_utils.ONLYCERT,
        test_utils.ONLYKEY
    )
    client_sslctx = self._create_client_ssl_context()

    def server(sock):
        sock.starttls(
            sslctx,
            server_side=True)

        data = sock.recv_all(len(A_DATA))
        self.assertEqual(data, A_DATA)
        sock.send(b'OK')

        data = sock.recv_all(len(B_DATA))
        self.assertEqual(data, B_DATA)
        sock.send(b'SPAM')

        sock.close()

    async def client(addr):
        reader, writer = await asyncio.open_connection(
            *addr,
            ssl=client_sslctx,
            server_hostname='',
            ssl_handshake_timeout=support.SHORT_TIMEOUT)

        writer.write(A_DATA)
        self.assertEqual(await reader.readexactly(2), b'OK')

        writer.write(B_DATA)
        self.assertEqual(await reader.readexactly(4), b'SPAM')

        nonlocal CNT
        CNT += 1

        writer.close()
        await self.wait_closed(writer)

    async def client_sock(addr):
        # NOTE: connect() here is a blocking call made from a coroutine;
        # it relies on the listening server accepting promptly.
        sock = socket.socket()
        sock.connect(addr)
        reader, writer = await asyncio.open_connection(
            sock=sock,
            ssl=client_sslctx,
            server_hostname='')

        writer.write(A_DATA)
        self.assertEqual(await reader.readexactly(2), b'OK')

        writer.write(B_DATA)
        self.assertEqual(await reader.readexactly(4), b'SPAM')

        nonlocal CNT
        CNT += 1

        writer.close()
        await self.wait_closed(writer)
        sock.close()

    def run(coro):
        nonlocal CNT
        CNT = 0

        async def _gather(*tasks):
            # trampoline: gather() is called inside a coroutine so that it
            # runs under self.loop.
            return await asyncio.gather(*tasks)

        with self.tcp_server(server,
                             max_clients=TOTAL_CNT,
                             backlog=TOTAL_CNT) as srv:
            tasks = []
            for _ in range(TOTAL_CNT):
                tasks.append(coro(srv.addr))

            self.loop.run_until_complete(_gather(*tasks))

        self.assertEqual(CNT, TOTAL_CNT)

    with self._silence_eof_received_warning():
        run(client)

    with self._silence_eof_received_warning():
        run(client_sock)
385
def test_create_connection_ssl_slow_handshake(self):
    # The server only reads and never performs a TLS handshake, so a
    # client using ssl_handshake_timeout=1.0 is expected to fail with
    # ConnectionAbortedError ("SSL handshake ... is taking longer ...").
    client_sslctx = self._create_client_ssl_context()

    # silence error logger
    self.loop.set_exception_handler(lambda *args: None)

    def server(sock):
        # Swallow whatever the client sends (its ClientHello); never reply.
        try:
            sock.recv_all(1024 * 1024)
        except ConnectionAbortedError:
            pass
        finally:
            sock.close()

    async def client(addr):
        reader, writer = await asyncio.open_connection(
            *addr,
            ssl=client_sslctx,
            server_hostname='',
            ssl_handshake_timeout=1.0)
        # Only reached if the handshake unexpectedly succeeds.
        writer.close()
        await self.wait_closed(writer)

    with self.tcp_server(server,
                         max_clients=1,
                         backlog=1) as srv:

        with self.assertRaisesRegex(
                ConnectionAbortedError,
                r'SSL handshake.*is taking longer'):

            self.loop.run_until_complete(client(srv.addr))
418
def test_create_connection_ssl_failed_certificate(self):
    # The client keeps certificate verification enabled
    # (disable_verify=False), so connecting to the test server must raise
    # ssl.SSLCertVerificationError.
    # silence error logger
    self.loop.set_exception_handler(lambda *args: None)

    sslctx = self._create_server_ssl_context(
        test_utils.ONLYCERT,
        test_utils.ONLYKEY
    )
    client_sslctx = self._create_client_ssl_context(disable_verify=False)

    def server(sock):
        try:
            sock.starttls(
                sslctx,
                server_side=True)
            # NOTE(review): connect() is called with no address.  If this
            # line were ever reached it would most likely raise TypeError,
            # which the except clause below does not catch.  Presumably
            # starttls() always fails first because the client aborts the
            # handshake -- TODO confirm intent.
            sock.connect()
        except (ssl.SSLError, OSError):
            pass
        finally:
            sock.close()

    async def client(addr):
        reader, writer = await asyncio.open_connection(
            *addr,
            ssl=client_sslctx,
            server_hostname='',
            ssl_handshake_timeout=support.SHORT_TIMEOUT)
        # Only reached if verification unexpectedly succeeds.
        writer.close()
        await self.wait_closed(writer)

    with self.tcp_server(server,
                         max_clients=1,
                         backlog=1) as srv:

        with self.assertRaises(ssl.SSLCertVerificationError):
            self.loop.run_until_complete(client(srv.addr))
455
def test_ssl_handshake_timeout(self):
    # bpo-29970: Check that a connection is aborted if handshake is not
    # completed in timeout period, instead of remaining open indefinitely
    client_sslctx = test_utils.simple_client_sslcontext()

    # silence error logger
    # (contexts passed to the exception handler are collected in
    # `messages` so the test can assert that none were reported)
    messages = []
    self.loop.set_exception_handler(lambda loop, ctx: messages.append(ctx))

    server_side_aborted = False

    def server(sock):
        nonlocal server_side_aborted
        # The server never answers the handshake; it records whether its
        # read ended with ConnectionAbortedError.
        try:
            sock.recv_all(1024 * 1024)
        except ConnectionAbortedError:
            server_side_aborted = True
        finally:
            sock.close()

    async def client(addr):
        # The outer wait_for() (0.5 s) expires long before
        # ssl_handshake_timeout (10 s), so the in-progress handshake is
        # cancelled rather than timed out by the SSL layer.
        await asyncio.wait_for(
            self.loop.create_connection(
                asyncio.Protocol,
                *addr,
                ssl=client_sslctx,
                server_hostname='',
                ssl_handshake_timeout=10.0),
            0.5)

    with self.tcp_server(server,
                         max_clients=1,
                         backlog=1) as srv:

        with self.assertRaises(asyncio.TimeoutError):
            self.loop.run_until_complete(client(srv.addr))

    # Checked after the tcp_server context exits (presumably once the
    # server thread has finished -- confirm in TestThreadedServer).
    self.assertTrue(server_side_aborted)

    # Python issue #23197: cancelling a handshake must not raise an
    # exception or log an error, even if the handshake failed
    self.assertEqual(messages, [])
498
def test_ssl_handshake_connection_lost(self):
    # #246: make sure that no connection_lost() is called before
    # connection_made() is called first
    #
    # The server closes the socket in the middle of the handshake; the
    # client must get ConnectionResetError and neither protocol callback
    # may have been invoked.

    client_sslctx = test_utils.simple_client_sslcontext()

    # silence error logger
    self.loop.set_exception_handler(lambda loop, ctx: None)

    connection_made_called = False
    connection_lost_called = False

    def server(sock):
        sock.recv(1024)
        # break the connection during handshake
        sock.close()

    class ClientProto(asyncio.Protocol):
        def connection_made(self, transport):
            nonlocal connection_made_called
            connection_made_called = True

        def connection_lost(self, exc):
            nonlocal connection_lost_called
            connection_lost_called = True

    async def client(addr):
        await self.loop.create_connection(
            ClientProto,
            *addr,
            ssl=client_sslctx,
            server_hostname='')

    with self.tcp_server(server,
                         max_clients=1,
                         backlog=1) as srv:

        with self.assertRaises(ConnectionResetError):
            self.loop.run_until_complete(client(srv.addr))

    if connection_lost_called:
        if connection_made_called:
            self.fail("unexpected call to connection_lost()")
        else:
            self.fail("unexpected call to connection_lost() without "
                      "calling connection_made()")
    elif connection_made_called:
        self.fail("unexpected call to connection_made()")
547
def test_ssl_connect_accepted_socket(self):
    # Run test_connect_accepted_socket() with TLS on both ends.  Neither
    # side verifies its peer (CERT_NONE).
    server_context = ssl.SSLContext(ssl.PROTOCOL_TLS_SERVER)
    server_context.load_cert_chain(test_utils.ONLYCERT, test_utils.ONLYKEY)
    server_context.check_hostname = False
    server_context.verify_mode = ssl.CERT_NONE

    # The client needs a client-side protocol: a PROTOCOL_TLS_SERVER
    # context cannot wrap a client (server_side=False) socket.
    client_context = ssl.SSLContext(ssl.PROTOCOL_TLS_CLIENT)
    # check_hostname must be disabled before verify_mode can be CERT_NONE.
    client_context.check_hostname = False
    client_context.verify_mode = ssl.CERT_NONE

    # Previously the contexts were built but never used, so this test
    # exercised nothing.
    self.test_connect_accepted_socket(server_context, client_context)
560
def test_connect_accepted_socket(self, server_ssl=None, client_ssl=None):
    # Hand an already-accepted socket to loop.connect_accepted_socket().
    # A client thread connects to a plain listening socket (wrapping its
    # socket with *client_ssl* when given), sends `message` and reads the
    # reply.  The accepted socket is served with *server_ssl*; the
    # protocol answers each received chunk with `expected_response` and
    # stops the loop once the connection is lost.
    loop = self.loop

    class MyProto(MyBaseProto):

        def connection_lost(self, exc):
            super().connection_lost(exc)
            # Ends the loop.run_forever() call below.
            loop.call_soon(loop.stop)

        def data_received(self, data):
            super().data_received(data)
            # `expected_response` is bound later in the enclosing
            # function; it is looked up when this method runs.
            self.transport.write(expected_response)

    lsock = socket.socket(socket.AF_INET)
    lsock.bind(('127.0.0.1', 0))
    lsock.listen(1)
    addr = lsock.getsockname()

    message = b'test data'
    response = None
    expected_response = b'roger'

    def client():
        nonlocal response
        # NOTE: failures in this thread are only printed, not propagated;
        # they surface through the assertions after run_forever().
        try:
            csock = socket.socket(socket.AF_INET)
            if client_ssl is not None:
                csock = client_ssl.wrap_socket(csock)
            csock.connect(addr)
            csock.sendall(message)
            response = csock.recv(99)
            csock.close()
        except Exception as exc:
            print(
                "Failure in client thread in test_connect_accepted_socket",
                exc)

    thread = threading.Thread(target=client, daemon=True)
    thread.start()

    conn, _ = lsock.accept()
    proto = MyProto(loop=loop)
    # NOTE(review): `proto.loop` is not read anywhere in this method.
    proto.loop = loop

    extras = {}
    if server_ssl:
        extras = dict(ssl_handshake_timeout=support.SHORT_TIMEOUT)

    f = loop.create_task(
        loop.connect_accepted_socket(
            (lambda: proto), conn, ssl=server_ssl,
            **extras))
    loop.run_forever()
    conn.close()
    lsock.close()

    thread.join(1)
    self.assertFalse(thread.is_alive())
    self.assertEqual(proto.state, 'CLOSED')
    self.assertEqual(proto.nbytes, len(message))
    self.assertEqual(response, expected_response)
    tr, _ = f.result()

    if server_ssl:
        self.assertIn('SSL', tr.__class__.__name__)

    tr.close()
    # let it close
    self.loop.run_until_complete(asyncio.sleep(0.1))
630
def test_start_tls_client_corrupted_ssl(self):
    # After a good TLS exchange, the server writes raw bytes through a
    # duplicate of its socket; the client's next read is expected to
    # raise ssl.SSLError.
    self.loop.set_exception_handler(lambda loop, ctx: None)

    sslctx = test_utils.simple_server_sslcontext()
    client_sslctx = test_utils.simple_client_sslcontext()

    def server(sock):
        # dup() is taken before starttls(), so writes on orig_sock
        # presumably bypass the TLS layer -- TODO confirm against the
        # socket wrapper used by tcp_server().
        orig_sock = sock.dup()
        try:
            sock.starttls(
                sslctx,
                server_side=True)
            sock.sendall(b'A\n')
            sock.recv_all(1)
            orig_sock.send(b'please corrupt the SSL connection')
        except ssl.SSLError:
            pass
        finally:
            sock.close()
            orig_sock.close()

    async def client(addr):
        reader, writer = await asyncio.open_connection(
            *addr,
            ssl=client_sslctx,
            server_hostname='')

        self.assertEqual(await reader.readline(), b'A\n')
        writer.write(b'B')
        with self.assertRaises(ssl.SSLError):
            await reader.readline()
        writer.close()
        # Waiting on the already-broken TLS stream may raise SSLError too.
        try:
            await self.wait_closed(writer)
        except ssl.SSLError:
            pass
        return 'OK'

    with self.tcp_server(server,
                         max_clients=1,
                         backlog=1) as srv:

        res = self.loop.run_until_complete(client(srv.addr))

    self.assertEqual(res, 'OK')
676
def test_start_tls_client_reg_proto_1(self):
    # The client sends HELLO_MSG in plaintext, upgrades the existing
    # transport with loop.start_tls() (regular asyncio.Protocol), then
    # exchanges data over TLS until the server shuts TLS down.
    # connection_made() must be called exactly once.
    HELLO_MSG = b'1' * self.PAYLOAD_SIZE

    server_context = test_utils.simple_server_sslcontext()
    client_context = test_utils.simple_client_sslcontext()

    def serve(sock):
        sock.settimeout(self.TIMEOUT)

        # Plaintext phase.
        data = sock.recv_all(len(HELLO_MSG))
        self.assertEqual(len(data), len(HELLO_MSG))

        # TLS phase.
        sock.starttls(server_context, server_side=True)

        sock.sendall(b'O')
        data = sock.recv_all(len(HELLO_MSG))
        self.assertEqual(len(data), len(HELLO_MSG))

        # The client awaits eof_received() after this shutdown.
        sock.unwrap()
        sock.close()

    class ClientProto(asyncio.Protocol):
        def __init__(self, on_data, on_eof):
            self.on_data = on_data
            self.on_eof = on_eof
            self.con_made_cnt = 0

        # The first parameter is named `proto` so that `self` inside
        # still refers to the TestCase.
        def connection_made(proto, tr):
            proto.con_made_cnt += 1
            # Ensure connection_made gets called only once.
            self.assertEqual(proto.con_made_cnt, 1)

        def data_received(self, data):
            self.on_data.set_result(data)

        def eof_received(self):
            self.on_eof.set_result(True)

    async def client(addr):
        await asyncio.sleep(0.5)

        on_data = self.loop.create_future()
        on_eof = self.loop.create_future()

        tr, proto = await self.loop.create_connection(
            lambda: ClientProto(on_data, on_eof), *addr)

        tr.write(HELLO_MSG)
        new_tr = await self.loop.start_tls(tr, proto, client_context)

        self.assertEqual(await on_data, b'O')
        new_tr.write(HELLO_MSG)
        await on_eof

        new_tr.close()

    with self.tcp_server(serve, timeout=self.TIMEOUT) as srv:
        self.loop.run_until_complete(
            asyncio.wait_for(client(srv.addr),
                             timeout=support.SHORT_TIMEOUT))
737
def test_create_connection_memory_leak(self):
    # A TLS client created by loop.create_connection(ssl=...) must not
    # keep its SSLContext alive once the test's connection is finished,
    # even though the protocol stores its transport (proto.tr).
    HELLO_MSG = b'1' * self.PAYLOAD_SIZE

    server_context = self._create_server_ssl_context(
        test_utils.ONLYCERT, test_utils.ONLYKEY)
    client_context = self._create_client_ssl_context()

    def serve(sock):
        sock.settimeout(self.TIMEOUT)

        sock.starttls(server_context, server_side=True)

        sock.sendall(b'O')
        data = sock.recv_all(len(HELLO_MSG))
        self.assertEqual(len(data), len(HELLO_MSG))

        sock.unwrap()
        sock.close()

    class ClientProto(asyncio.Protocol):
        def __init__(self, on_data, on_eof):
            self.on_data = on_data
            self.on_eof = on_eof
            self.con_made_cnt = 0

        # `proto` is the protocol instance; `self` is the TestCase.
        def connection_made(proto, tr):
            # XXX: We assume user stores the transport in protocol
            proto.tr = tr
            proto.con_made_cnt += 1
            # Ensure connection_made gets called only once.
            self.assertEqual(proto.con_made_cnt, 1)

        def data_received(self, data):
            self.on_data.set_result(data)

        def eof_received(self):
            self.on_eof.set_result(True)

    async def client(addr):
        await asyncio.sleep(0.5)

        on_data = self.loop.create_future()
        on_eof = self.loop.create_future()

        tr, proto = await self.loop.create_connection(
            lambda: ClientProto(on_data, on_eof), *addr,
            ssl=client_context)

        self.assertEqual(await on_data, b'O')
        tr.write(HELLO_MSG)
        await on_eof

        tr.close()

    with self.tcp_server(serve, timeout=self.TIMEOUT) as srv:
        self.loop.run_until_complete(
            asyncio.wait_for(client(srv.addr),
                             timeout=support.SHORT_TIMEOUT))

    # No garbage is left for SSL client from loop.create_connection, even
    # if user stores the SSLTransport in corresponding protocol instance
    #
    # Rebinding `client_context` also replaces the closure cell used by
    # client(), dropping the test's own strong reference.  No explicit
    # gc collection happens first, so this presumably requires the
    # context to be freed without relying on the cyclic GC.
    client_context = weakref.ref(client_context)
    self.assertIsNone(client_context())
801
def test_start_tls_client_buf_proto_1(self):
    # Like test_start_tls_client_reg_proto_1, but the client starts with
    # a BufferedProtocol (1-byte buffer) and, after the first TLS reply,
    # swaps in a regular Protocol with transport.set_protocol().
    # connection_made() must be called only once in total.
    HELLO_MSG = b'1' * self.PAYLOAD_SIZE

    server_context = test_utils.simple_server_sslcontext()
    client_context = test_utils.simple_client_sslcontext()

    client_con_made_calls = 0

    def serve(sock):
        sock.settimeout(self.TIMEOUT)

        # Plaintext phase.
        data = sock.recv_all(len(HELLO_MSG))
        self.assertEqual(len(data), len(HELLO_MSG))

        sock.starttls(server_context, server_side=True)

        # Reply handled by ClientProtoFirst.
        sock.sendall(b'O')
        data = sock.recv_all(len(HELLO_MSG))
        self.assertEqual(len(data), len(HELLO_MSG))

        # Reply handled by ClientProtoSecond.
        sock.sendall(b'2')
        data = sock.recv_all(len(HELLO_MSG))
        self.assertEqual(len(data), len(HELLO_MSG))

        sock.unwrap()
        sock.close()

    class ClientProtoFirst(asyncio.BufferedProtocol):
        def __init__(self, on_data):
            self.on_data = on_data
            # One-byte buffer: each buffer_updated() call delivers 1 byte.
            self.buf = bytearray(1)

        def connection_made(self, tr):
            nonlocal client_con_made_calls
            client_con_made_calls += 1

        def get_buffer(self, sizehint):
            return self.buf

        def buffer_updated(self, nsize):
            assert nsize == 1
            self.on_data.set_result(bytes(self.buf[:nsize]))

        def eof_received(self):
            pass

    class ClientProtoSecond(asyncio.Protocol):
        def __init__(self, on_data, on_eof):
            self.on_data = on_data
            self.on_eof = on_eof
            self.con_made_cnt = 0

        def connection_made(self, tr):
            nonlocal client_con_made_calls
            client_con_made_calls += 1

        def data_received(self, data):
            self.on_data.set_result(data)

        def eof_received(self):
            self.on_eof.set_result(True)

    async def client(addr):
        await asyncio.sleep(0.5)

        on_data1 = self.loop.create_future()
        on_data2 = self.loop.create_future()
        on_eof = self.loop.create_future()

        tr, proto = await self.loop.create_connection(
            lambda: ClientProtoFirst(on_data1), *addr)

        tr.write(HELLO_MSG)
        new_tr = await self.loop.start_tls(tr, proto, client_context)

        self.assertEqual(await on_data1, b'O')
        new_tr.write(HELLO_MSG)

        new_tr.set_protocol(ClientProtoSecond(on_data2, on_eof))
        self.assertEqual(await on_data2, b'2')
        new_tr.write(HELLO_MSG)
        await on_eof

        new_tr.close()

        # connection_made() should be called only once -- when
        # we establish connection for the first time. Start TLS
        # doesn't call connection_made() on application protocols.
        self.assertEqual(client_con_made_calls, 1)

    with self.tcp_server(serve, timeout=self.TIMEOUT) as srv:
        self.loop.run_until_complete(
            asyncio.wait_for(client(srv.addr),
                             timeout=self.TIMEOUT))
896
def test_start_tls_slow_client_cancel(self):
    # The server never performs TLS, so a client-side loop.start_tls()
    # wrapped in a 0.5 s wait_for() must be cancelled with TimeoutError.
    HELLO_MSG = b'1' * self.PAYLOAD_SIZE

    client_context = test_utils.simple_client_sslcontext()
    # Resolved (from the server thread) once the server has read the
    # plaintext message, so the client starts TLS only after that.
    server_waits_on_handshake = self.loop.create_future()

    def serve(sock):
        sock.settimeout(self.TIMEOUT)

        data = sock.recv_all(len(HELLO_MSG))
        self.assertEqual(len(data), len(HELLO_MSG))

        try:
            self.loop.call_soon_threadsafe(
                server_waits_on_handshake.set_result, None)
            # Read (and ignore) the ClientHello; never answer it.
            data = sock.recv_all(1024 * 1024)
        except ConnectionAbortedError:
            pass
        finally:
            sock.close()

    class ClientProto(asyncio.Protocol):
        def __init__(self, on_data, on_eof):
            self.on_data = on_data
            self.on_eof = on_eof
            self.con_made_cnt = 0

        # `proto` is the protocol instance; `self` is the TestCase.
        def connection_made(proto, tr):
            proto.con_made_cnt += 1
            # Ensure connection_made gets called only once.
            self.assertEqual(proto.con_made_cnt, 1)

        def data_received(self, data):
            self.on_data.set_result(data)

        def eof_received(self):
            self.on_eof.set_result(True)

    async def client(addr):
        await asyncio.sleep(0.5)

        on_data = self.loop.create_future()
        on_eof = self.loop.create_future()

        tr, proto = await self.loop.create_connection(
            lambda: ClientProto(on_data, on_eof), *addr)

        tr.write(HELLO_MSG)

        await server_waits_on_handshake

        with self.assertRaises(asyncio.TimeoutError):
            await asyncio.wait_for(
                self.loop.start_tls(tr, proto, client_context),
                0.5)

    with self.tcp_server(serve, timeout=self.TIMEOUT) as srv:
        self.loop.run_until_complete(
            asyncio.wait_for(client(srv.addr),
                             timeout=support.SHORT_TIMEOUT))
957
def test_start_tls_server_1(self):
    # Server-side upgrade: an asyncio server sends HELLO_MSG in plaintext,
    # then calls loop.start_tls(server_side=True); a threaded blocking
    # client completes TLS, sends HELLO_MSG back over TLS and shuts down.
    HELLO_MSG = b'1' * self.PAYLOAD_SIZE

    server_context = test_utils.simple_server_sslcontext()
    client_context = test_utils.simple_client_sslcontext()

    def client(sock, addr):
        sock.settimeout(self.TIMEOUT)

        sock.connect(addr)
        data = sock.recv_all(len(HELLO_MSG))
        self.assertEqual(len(data), len(HELLO_MSG))

        sock.starttls(client_context)
        sock.sendall(HELLO_MSG)

        sock.unwrap()
        sock.close()

    class ServerProto(asyncio.Protocol):
        def __init__(self, on_con, on_eof, on_con_lost):
            self.on_con = on_con
            self.on_eof = on_eof
            self.on_con_lost = on_con_lost
            self.data = b''

        def connection_made(self, tr):
            self.on_con.set_result(tr)

        def data_received(self, data):
            self.data += data

        def eof_received(self):
            self.on_eof.set_result(1)

        def connection_lost(self, exc):
            if exc is None:
                self.on_con_lost.set_result(None)
            else:
                self.on_con_lost.set_exception(exc)

    async def main(proto, on_con, on_eof, on_con_lost):
        tr = await on_con
        tr.write(HELLO_MSG)

        # Nothing may arrive before TLS is established.
        self.assertEqual(proto.data, b'')

        new_tr = await self.loop.start_tls(
            tr, proto, server_context,
            server_side=True,
            ssl_handshake_timeout=self.TIMEOUT)

        await on_eof
        await on_con_lost
        self.assertEqual(proto.data, HELLO_MSG)
        new_tr.close()

    async def run_main():
        on_con = self.loop.create_future()
        on_eof = self.loop.create_future()
        on_con_lost = self.loop.create_future()
        proto = ServerProto(on_con, on_eof, on_con_lost)

        # The factory returns the single pre-built `proto`, so the test
        # can inspect it directly.
        server = await self.loop.create_server(
            lambda: proto, '127.0.0.1', 0)
        addr = server.sockets[0].getsockname()

        with self.tcp_client(lambda sock: client(sock, addr),
                             timeout=self.TIMEOUT):
            await asyncio.wait_for(
                main(proto, on_con, on_eof, on_con_lost),
                timeout=self.TIMEOUT)

        server.close()
        await server.wait_closed()

    self.loop.run_until_complete(run_main())
1035
def test_create_server_ssl_over_ssl(self):
    # TLS inside TLS: the listening server wraps connections with
    # sslctx_1, and each accepted connection adds a second TLS layer with
    # sslctx_2 via loop.start_tls() before the stream handler runs.  The
    # clients drive the inner layer by hand through MemoryBIOs.
    CNT = 0  # number of clients that were successful
    TOTAL_CNT = 25  # total number of clients that test will create
    TIMEOUT = support.LONG_TIMEOUT  # timeout for this test

    A_DATA = b'A' * 1024 * BUF_MULTIPLIER
    B_DATA = b'B' * 1024 * BUF_MULTIPLIER

    sslctx_1 = self._create_server_ssl_context(
        test_utils.ONLYCERT, test_utils.ONLYKEY)
    client_sslctx_1 = self._create_client_ssl_context()
    sslctx_2 = self._create_server_ssl_context(
        test_utils.ONLYCERT, test_utils.ONLYKEY)
    client_sslctx_2 = self._create_client_ssl_context()

    clients = []

    async def handle_client(reader, writer):
        nonlocal CNT

        data = await reader.readexactly(len(A_DATA))
        self.assertEqual(data, A_DATA)
        writer.write(b'OK')

        data = await reader.readexactly(len(B_DATA))
        self.assertEqual(data, B_DATA)
        writer.writelines([b'SP', bytearray(b'A'), memoryview(b'M')])

        await writer.drain()
        writer.close()

        CNT += 1

    class ServerProtocol(asyncio.StreamReaderProtocol):
        def connection_made(self, transport):
            # Hold off the stream machinery until the inner TLS layer is
            # up: pause the outer transport, run start_tls(), then report
            # the new transport (or the failure) to the base class.
            super_ = super()
            transport.pause_reading()
            fut = self._loop.create_task(self._loop.start_tls(
                transport, self, sslctx_2, server_side=True))

            def cb(_):
                try:
                    tr = fut.result()
                except Exception as ex:
                    super_.connection_lost(ex)
                else:
                    super_.connection_made(tr)
            fut.add_done_callback(cb)

    def server_protocol_factory():
        reader = asyncio.StreamReader()
        protocol = ServerProtocol(reader, handle_client)
        return protocol

    async def test_client(addr):
        fut = asyncio.Future()

        def prog(sock):
            try:
                sock.connect(addr)
                sock.starttls(client_sslctx_1)

                # because wrap_socket() doesn't work correctly on
                # SSLSocket, we have to do the 2nd level SSL manually
                incoming = ssl.MemoryBIO()
                outgoing = ssl.MemoryBIO()
                sslobj = client_sslctx_2.wrap_bio(incoming, outgoing)

                def do(func, *args):
                    # Retry *func* on the in-memory TLS object, shuttling
                    # records between the BIOs and the outer TLS socket
                    # whenever it needs more input.
                    while True:
                        try:
                            rv = func(*args)
                            break
                        except ssl.SSLWantReadError:
                            if outgoing.pending:
                                sock.send(outgoing.read())
                            incoming.write(sock.recv(65536))
                    if outgoing.pending:
                        sock.send(outgoing.read())
                    return rv

                do(sslobj.do_handshake)

                do(sslobj.write, A_DATA)
                data = do(sslobj.read, 2)
                self.assertEqual(data, b'OK')

                do(sslobj.write, B_DATA)
                data = b''
                while True:
                    chunk = do(sslobj.read, 4)
                    if not chunk:
                        break
                    data += chunk
                self.assertEqual(data, b'SPAM')

                do(sslobj.unwrap)
                sock.close()

            except Exception as ex:
                self.loop.call_soon_threadsafe(fut.set_exception, ex)
                sock.close()
            else:
                self.loop.call_soon_threadsafe(fut.set_result, None)

        client = self.tcp_client(prog)
        client.start()
        clients.append(client)

        await fut

    async def start_server():
        srv = await self.loop.create_server(
            server_protocol_factory,
            '127.0.0.1', 0,
            family=socket.AF_INET,
            ssl=sslctx_1)

        try:
            srv_socks = srv.sockets
            self.assertTrue(srv_socks)

            addr = srv_socks[0].getsockname()

            tasks = []
            for _ in range(TOTAL_CNT):
                tasks.append(test_client(addr))

            await asyncio.wait_for(asyncio.gather(*tasks), TIMEOUT)

        finally:
            self.loop.call_soon(srv.close)
            await srv.wait_closed()

    try:
        with self._silence_eof_received_warning():
            self.loop.run_until_complete(start_server())

        self.assertEqual(CNT, TOTAL_CNT)
    finally:
        # Stop the client threads even when the test fails, so they do
        # not outlive it.
        for client in clients:
            client.stop()
1180
def test_shutdown_cleanly(self):
    # The threaded server sends b'OK' and then unwrap()s the TLS layer;
    # each asyncio client must observe a clean EOF (read() == b'').
    CNT = 0
    TOTAL_CNT = 25

    A_DATA = b'A' * 1024 * BUF_MULTIPLIER

    sslctx = self._create_server_ssl_context(
        test_utils.ONLYCERT, test_utils.ONLYKEY)
    client_sslctx = self._create_client_ssl_context()

    def server(sock):
        sock.starttls(
            sslctx,
            server_side=True)

        data = sock.recv_all(len(A_DATA))
        self.assertEqual(data, A_DATA)
        sock.send(b'OK')

        sock.unwrap()

        sock.close()

    async def client(addr):
        reader, writer = await asyncio.open_connection(
            *addr,
            ssl=client_sslctx,
            server_hostname='',
            ssl_handshake_timeout=support.SHORT_TIMEOUT)

        writer.write(A_DATA)
        self.assertEqual(await reader.readexactly(2), b'OK')

        # The server's TLS shutdown must surface as a clean EOF.
        self.assertEqual(await reader.read(), b'')

        nonlocal CNT
        CNT += 1

        writer.close()
        await self.wait_closed(writer)

    def run(coro):
        nonlocal CNT
        CNT = 0

        async def _gather(*tasks):
            return await asyncio.gather(*tasks)

        with self.tcp_server(server,
                             max_clients=TOTAL_CNT,
                             backlog=TOTAL_CNT) as srv:
            tasks = []
            for _ in range(TOTAL_CNT):
                tasks.append(coro(srv.addr))

            self.loop.run_until_complete(
                _gather(*tasks))

        self.assertEqual(CNT, TOTAL_CNT)

    with self._silence_eof_received_warning():
        run(client)
1246
def test_flush_before_shutdown(self):
    # The client pauses writing on its SSL protocol, queues CHUNK * SIZE
    # bytes, calls close() and only then resumes writing.  The server must
    # still receive every queued byte, i.e. close() may not drop pending
    # application data.
    CHUNK = 1024 * 128
    SIZE = 32

    sslctx = self._create_server_ssl_context(
        test_utils.ONLYCERT, test_utils.ONLYKEY)
    client_sslctx = self._create_client_ssl_context()

    # Created by client(); resolved from the server thread via wrapper().
    future = None

    def run(meth):
        # Adapt a blocking server routine so its outcome is reported back
        # to the event loop through `future`.
        def wrapper(sock):
            try:
                meth(sock)
            except Exception as ex:
                self.loop.call_soon_threadsafe(future.set_exception, ex)
            else:
                self.loop.call_soon_threadsafe(future.set_result, None)
        return wrapper

    def server(sock):
        sock.starttls(sslctx, server_side=True)
        self.assertEqual(sock.recv_all(4), b'ping')
        sock.send(b'pong')
        time.sleep(0.5)  # hopefully stuck the TCP buffer
        received = sock.recv_all(CHUNK * SIZE)
        self.assertEqual(len(received), CHUNK * SIZE)
        sock.close()

    async def client(addr):
        nonlocal future
        future = self.loop.create_future()

        reader, writer = await asyncio.open_connection(
            *addr, ssl=client_sslctx, server_hostname='')
        ssl_proto = writer.transport._ssl_protocol

        writer.write(b'ping')
        self.assertEqual(await reader.readexactly(4), b'pong')

        ssl_proto.pause_writing()
        for _ in range(SIZE):
            writer.write(b'x' * CHUNK)

        writer.close()
        ssl_proto.resume_writing()

        await self.wait_closed(writer)
        try:
            self.assertEqual(await reader.read(), b'')
        except ConnectionResetError:
            pass
        await future

    with self.tcp_server(run(server)) as srv:
        self.loop.run_until_complete(client(srv.addr))
1305
def test_remote_shutdown_receives_trailing_data(self):
    # After the server signals end of stream, the client must still flush
    # the CHUNK * SIZE bytes sitting in its write backlog.  Two server
    # variants are exercised: `server` sends a TLS close_notify through a
    # memory-BIO SSLObject, `eof_server` only half-closes TCP (SHUT_WR).
    CHUNK = 1024 * 128
    SIZE = 32

    sslctx = self._create_server_ssl_context(
        test_utils.ONLYCERT,
        test_utils.ONLYKEY
    )
    client_sslctx = self._create_client_ssl_context()
    future = None

    def recv_into_bio(sock, incoming):
        # Move the next piece of ciphertext from the socket into the BIO.
        # On TCP EOF, mark the BIO as finished: otherwise the SSLObject
        # keeps raising SSLWantReadError and the loops below spin forever
        # instead of failing with an SSL EOF error.
        data = sock.recv(16384)
        if data:
            incoming.write(data)
        else:
            incoming.write_eof()

    def server(sock):
        incoming = ssl.MemoryBIO()
        outgoing = ssl.MemoryBIO()
        sslobj = sslctx.wrap_bio(incoming, outgoing, server_side=True)

        # sendall(): outgoing.read() drains the BIO, so a partial send()
        # would silently lose TLS bytes.
        while True:
            try:
                sslobj.do_handshake()
            except ssl.SSLWantReadError:
                if outgoing.pending:
                    sock.sendall(outgoing.read())
                recv_into_bio(sock, incoming)
            else:
                if outgoing.pending:
                    sock.sendall(outgoing.read())
                break

        while True:
            try:
                data = sslobj.read(4)
            except ssl.SSLWantReadError:
                recv_into_bio(sock, incoming)
            else:
                break

        self.assertEqual(data, b'ping')
        sslobj.write(b'pong')
        sock.sendall(outgoing.read())

        time.sleep(0.2)  # wait for the peer to fill its backlog

        # send close_notify but don't wait for response
        with self.assertRaises(ssl.SSLWantReadError):
            sslobj.unwrap()
        sock.sendall(outgoing.read())

        # should receive all data
        data_len = 0
        while True:
            try:
                data_len += len(sslobj.read(16384))
            except ssl.SSLWantReadError:
                recv_into_bio(sock, incoming)
            except ssl.SSLZeroReturnError:
                break

        self.assertEqual(data_len, CHUNK * SIZE)

        # verify that close_notify is received
        sslobj.unwrap()

        sock.close()

    def eof_server(sock):
        sock.starttls(sslctx, server_side=True)
        self.assertEqual(sock.recv_all(4), b'ping')
        sock.send(b'pong')

        time.sleep(0.2)  # wait for the peer to fill its backlog

        # send EOF
        sock.shutdown(socket.SHUT_WR)

        # should receive all data
        data = sock.recv_all(CHUNK * SIZE)
        self.assertEqual(len(data), CHUNK * SIZE)

        sock.close()

    async def client(addr):
        nonlocal future
        future = self.loop.create_future()

        reader, writer = await asyncio.open_connection(
            *addr,
            ssl=client_sslctx,
            server_hostname='')
        writer.write(b'ping')
        data = await reader.readexactly(4)
        self.assertEqual(data, b'pong')

        # fill write backlog in a hacky way - renegotiation won't help
        for _ in range(SIZE):
            writer.transport._test__append_write_backlog(b'x' * CHUNK)

        try:
            data = await reader.read()
            self.assertEqual(data, b'')
        except (BrokenPipeError, ConnectionResetError):
            pass

        await future

        writer.close()
        await self.wait_closed(writer)

    def run(meth):
        # Report the blocking server routine's outcome to the loop.
        def wrapper(sock):
            try:
                meth(sock)
            except Exception as ex:
                self.loop.call_soon_threadsafe(future.set_exception, ex)
            else:
                self.loop.call_soon_threadsafe(future.set_result, None)
        return wrapper

    with self.tcp_server(run(server)) as srv:
        self.loop.run_until_complete(client(srv.addr))

    with self.tcp_server(run(eof_server)) as srv:
        self.loop.run_until_complete(client(srv.addr))
1429
def test_connect_timeout_warning(self):
    # The target socket is bound but never listen()s, so the SSL connect
    # is refused or times out.  Abandoning that attempt must not emit a
    # ResourceWarning, even after forcing garbage collection.
    s = socket.socket(socket.AF_INET)
    s.bind(('127.0.0.1', 0))
    addr = s.getsockname()

    async def test():
        connect = self.loop.create_connection(
            asyncio.Protocol, *addr, ssl=True)
        try:
            await asyncio.wait_for(connect, 0.1)
        except (ConnectionRefusedError, asyncio.TimeoutError):
            return
        self.fail('TimeoutError is not raised')

    with s:
        try:
            with self.assertWarns(ResourceWarning) as cm:
                self.loop.run_until_complete(test())
                for _ in range(3):
                    gc.collect()
        except AssertionError as e:
            # assertWarns failing is the success path here.
            self.assertEqual(str(e), 'ResourceWarning not triggered')
        else:
            self.fail('Unexpected ResourceWarning: {}'.format(cm.warning))
1457
def test_handshake_timeout_handler_leak(self):
    # The listener never accept()s, so no TLS handshake is ever answered
    # and the connect attempt hits the 0.1s timeout.  Afterwards nothing
    # may keep the client SSLContext alive.
    s = socket.socket(socket.AF_INET)
    s.bind(('127.0.0.1', 0))
    s.listen(1)
    addr = s.getsockname()

    async def test(ctx):
        connect = self.loop.create_connection(
            asyncio.Protocol, *addr, ssl=ctx)
        try:
            await asyncio.wait_for(connect, 0.1)
        except (ConnectionRefusedError, asyncio.TimeoutError):
            return
        self.fail('TimeoutError is not raised')

    with s:
        ctx = ssl.create_default_context()
        self.loop.run_until_complete(test(ctx))
        ctx_ref = weakref.ref(ctx)
        del ctx

    # SSLProtocol should be DECREF to 0
    self.assertIsNone(ctx_ref())
1482
def test_shutdown_timeout_handler_leak(self):
    # Closing an SSL transport must not leave the client SSLContext
    # reachable once the connection is gone (after GC).
    loop = self.loop

    def server(sock):
        sslctx = self._create_server_ssl_context(
            test_utils.ONLYCERT, test_utils.ONLYKEY)
        sock = sslctx.wrap_socket(sock, server_side=True)
        sock.recv(32)
        sock.close()

    class Protocol(asyncio.Protocol):
        def __init__(self):
            self.fut = asyncio.Future(loop=loop)

        def connection_lost(self, exc):
            self.fut.set_result(None)

    async def client(addr, ctx):
        transport, proto = await loop.create_connection(
            Protocol, *addr, ssl=ctx)
        transport.close()
        await proto.fut

    with self.tcp_server(server) as srv:
        ctx = self._create_client_ssl_context()
        loop.run_until_complete(client(srv.addr, ctx))
        ctx_ref = weakref.ref(ctx)
        del ctx

    # asyncio has no shutdown timeout, but it ends up with a circular
    # reference loop - not ideal (introduces gc glitches), but at least
    # not leaking
    for _ in range(3):
        gc.collect()

    # SSLProtocol should be DECREF to 0
    self.assertIsNone(ctx_ref())
1521
def test_shutdown_timeout_handler_not_set(self):
    # The client pauses reading right after the server's b'hello', so the
    # server's later b'extra bytes' and its EOF arrive while reading is
    # paused.  After resume_reading() the client must still deliver
    # b'extra bytes' to data_received() before the connection is lost.
    loop = self.loop
    eof = asyncio.Event()
    extra = None

    def server(sock):
        sslctx = self._create_server_ssl_context(
            test_utils.ONLYCERT,
            test_utils.ONLYKEY
        )
        sock = sslctx.wrap_socket(sock, server_side=True)
        sock.send(b'hello')
        # Keep recv() out of `assert` statements: under `python -O` they
        # are stripped, the recv() would never run and the exchange would
        # silently change.
        self.assertEqual(sock.recv(1024), b'world')
        sock.send(b'extra bytes')
        # sending EOF here
        sock.shutdown(socket.SHUT_WR)
        loop.call_soon_threadsafe(eof.set)
        # make sure we have enough time to reproduce the issue
        self.assertEqual(sock.recv(1024), b'')
        sock.close()

    class Protocol(asyncio.Protocol):
        def __init__(self):
            self.fut = asyncio.Future(loop=loop)
            self.transport = None

        def connection_made(self, transport):
            self.transport = transport

        def data_received(self, data):
            if data == b'hello':
                self.transport.write(b'world')
                # pause reading would make incoming data stay in the sslobj
                self.transport.pause_reading()
            else:
                nonlocal extra
                extra = data

        def connection_lost(self, exc):
            if exc is None:
                self.fut.set_result(None)
            else:
                self.fut.set_exception(exc)

    async def client(addr):
        ctx = self._create_client_ssl_context()
        tr, pr = await loop.create_connection(Protocol, *addr, ssl=ctx)
        await eof.wait()
        tr.resume_reading()
        await pr.fut
        tr.close()
        self.assertEqual(extra, b'extra bytes')

    with self.tcp_server(server) as srv:
        loop.run_until_complete(client(srv.addr))
1577
1578
1579 ###############################################################################
1580 # Socket Testing Utilities
1581 ###############################################################################
1582
1583
class TestSocketWrapper:
    """Proxy around a socket used by the threaded test peers.

    Attributes not defined here are forwarded to the wrapped socket, and
    starttls() replaces the wrapped socket with an SSL-wrapped one.
    """

    def __init__(self, sock):
        self.__sock = sock

    def recv_all(self, n):
        """Receive exactly *n* bytes and return them as ``bytes``.

        Raises ConnectionAbortedError if recv() returns b'' (peer closed)
        before *n* bytes have arrived.
        """
        # Accumulate in a bytearray: `bytes += chunk` copies the whole
        # buffer every iteration, which is quadratic for large payloads.
        buf = bytearray()
        while len(buf) < n:
            data = self.recv(n - len(buf))
            if not data:
                raise ConnectionAbortedError
            buf += data
        return bytes(buf)

    def starttls(self, ssl_context, *,
                 server_side=False,
                 server_hostname=None,
                 do_handshake_on_connect=True):
        """Wrap the current socket with *ssl_context* in place.

        On the server side the handshake is driven explicitly.
        """
        assert isinstance(ssl_context, ssl.SSLContext)

        ssl_sock = ssl_context.wrap_socket(
            self.__sock, server_side=server_side,
            server_hostname=server_hostname,
            do_handshake_on_connect=do_handshake_on_connect)

        if server_side:
            ssl_sock.do_handshake()

        # NOTE(review): presumably harmless because wrap_socket() takes
        # over the underlying descriptor from the original socket.
        self.__sock.close()
        self.__sock = ssl_sock

    def __getattr__(self, name):
        return getattr(self.__sock, name)

    def __repr__(self):
        return '<{} {!r}>'.format(type(self).__name__, self.__sock)
1621
1622
class SocketThread(threading.Thread):
    """Thread usable as a context manager: started on entry, stopped on exit.

    stop() clears the ``_active`` flag that subclasses poll, then joins.
    """

    def __enter__(self):
        self.start()
        return self

    def __exit__(self, *exc):
        self.stop()

    def stop(self):
        self._active = False
        self.join()
1635
1636
class TestThreadedClient(SocketThread):
    """Daemon thread that runs *prog* against a client socket.

    Any failure other than KeyboardInterrupt/SystemExit is handed to the
    owning test's ``_abort_socket_test()``.
    """

    def __init__(self, test, sock, prog, timeout):
        super().__init__(name='test-client', daemon=True)
        self._test = test
        self._sock = sock
        self._prog = prog
        self._timeout = timeout
        self._active = True

    def run(self):
        try:
            self._prog(TestSocketWrapper(self._sock))
        except BaseException as exc:
            if isinstance(exc, (KeyboardInterrupt, SystemExit)):
                raise
            self._test._abort_socket_test(exc)
1656
1657
class TestThreadedServer(SocketThread):
    """Daemon thread that accepts up to *max_clients* connections on *sock*.

    Each accepted connection is served sequentially by running *prog* on
    a TestSocketWrapper.  A socketpair lets stop() wake the select() loop.
    """

    def __init__(self, test, sock, prog, timeout, max_clients):
        super().__init__(name='test-server', daemon=True)

        self._test = test
        self._sock = sock
        self._prog = prog
        self._timeout = timeout
        self._max_clients = max_clients
        self._clients = 0
        self._finished_clients = 0
        self._active = True

        # _s1 is watched by select() in _run(); stop() writes to _s2.
        self._s1, self._s2 = socket.socketpair()
        self._s1.setblocking(False)

    def stop(self):
        try:
            waker = self._s2
            if waker and waker.fileno() != -1:
                # run() may already have closed the pair; that's fine.
                with contextlib.suppress(OSError):
                    waker.send(b'stop')
        finally:
            super().stop()

    def run(self):
        try:
            with self._sock:
                self._sock.setblocking(False)
                self._run()
        finally:
            for end in (self._s1, self._s2):
                end.close()

    def _run(self):
        watched = [self._sock, self._s1]
        while self._active and self._clients < self._max_clients:
            readable, _, _ = select.select(watched, [], [], self._timeout)

            if self._s1 in readable:
                # Wake-up from stop().
                return
            if self._sock not in readable:
                # select() timed out; re-check the loop conditions.
                continue

            try:
                conn, _ = self._sock.accept()
            except BlockingIOError:
                continue
            except socket.timeout:
                if self._active:
                    raise
                return

            self._clients += 1
            conn.settimeout(self._timeout)
            self._serve_connection(conn)

    def _serve_connection(self, conn):
        # Run the client program on one connection; on failure, stop
        # accepting, notify the test, and re-raise.
        try:
            with conn:
                self._handle_client(conn)
        except (KeyboardInterrupt, SystemExit):
            raise
        except BaseException as ex:
            self._active = False
            try:
                raise
            finally:
                self._test._abort_socket_test(ex)

    def _handle_client(self, sock):
        self._prog(TestSocketWrapper(sock))

    @property
    def addr(self):
        return self._sock.getsockname()