python (3.12.0)
1 """Tests for streams.py."""
2
3 import gc
4 import os
5 import queue
6 import pickle
7 import socket
8 import sys
9 import threading
10 import unittest
11 from unittest import mock
12 import warnings
13 from test.support import socket_helper
14 try:
15 import ssl
16 except ImportError:
17 ssl = None
18
19 import asyncio
20 from test.test_asyncio import utils as test_utils
21
22
def tearDownModule():
    # Passing None makes asyncio fall back to its default policy, undoing
    # anything the tests in this module installed.
    default_policy = None
    asyncio.set_event_loop_policy(default_policy)
25
26
27 class ESC[4;38;5;81mStreamTests(ESC[4;38;5;149mtest_utilsESC[4;38;5;149m.ESC[4;38;5;149mTestCase):
28
# Three newline-terminated lines (18 bytes in total) shared by the
# read / readline / readexactly tests.
DATA = (
    b'line1\n'
    b'line2\n'
    b'line3\n'
)
30
def setUp(self):
    # Each test gets its own private loop; the base class's
    # set_event_loop() registers it (and arranges its cleanup).
    super().setUp()
    fresh_loop = asyncio.new_event_loop()
    self.loop = fresh_loop
    self.set_event_loop(fresh_loop)
35
def tearDown(self):
    # Let any still-scheduled callbacks (e.g. transport close callbacks)
    # run once before the loop is closed underneath them.
    loop = self.loop
    test_utils.run_briefly(loop)

    loop.close()
    gc.collect()
    super().tearDown()
43
def _basetest_open_connection(self, open_connection_fut):
    # Shared body: send an HTTP/1.0 request over the stream pair produced by
    # `open_connection_fut`, check the test server's canned reply, and make
    # sure nothing was reported to the loop's exception handler.
    errors = []

    def record(loop, context):
        errors.append(context)

    self.loop.set_exception_handler(record)
    run = self.loop.run_until_complete

    reader, writer = run(open_connection_fut)
    writer.write(b'GET / HTTP/1.0\r\n\r\n')
    status_line = run(reader.readline())
    self.assertEqual(status_line, b'HTTP/1.0 200 OK\r\n')
    remainder = run(reader.read())
    self.assertTrue(remainder.endswith(b'\r\n\r\nTest message'))
    writer.close()
    self.assertEqual(errors, [])
57
def test_open_connection(self):
    # Plain TCP variant of the shared HTTP round-trip check.
    with test_utils.run_test_server() as httpd:
        pending = asyncio.open_connection(*httpd.address)
        self._basetest_open_connection(pending)
62
@socket_helper.skip_unless_bind_unix_socket
def test_open_unix_connection(self):
    # Unix-domain-socket variant of the shared HTTP round-trip check.
    with test_utils.run_test_unix_server() as httpd:
        pending = asyncio.open_unix_connection(httpd.address)
        self._basetest_open_connection(pending)
68
def _basetest_open_connection_no_loop_ssl(self, open_connection_fut):
    # Shared body: establish the (TLS) connection, then clear the current
    # event loop and finish the HTTP exchange by driving self.loop directly.
    errors = []
    self.loop.set_exception_handler(lambda loop, ctx: errors.append(ctx))
    try:
        reader, writer = self.loop.run_until_complete(open_connection_fut)
    finally:
        # From here on no current event loop is set.
        asyncio.set_event_loop(None)

    writer.write(b'GET / HTTP/1.0\r\n\r\n')
    body = self.loop.run_until_complete(reader.read())
    self.assertTrue(body.endswith(b'\r\n\r\nTest message'))

    writer.close()
    self.assertEqual(errors, [])
83
@unittest.skipIf(ssl is None, 'No ssl module')
def test_open_connection_no_loop_ssl(self):
    # TLS over TCP, finished without a current event loop.
    with test_utils.run_test_server(use_ssl=True) as httpd:
        pending = asyncio.open_connection(
            *httpd.address, ssl=test_utils.dummy_ssl_context())
        self._basetest_open_connection_no_loop_ssl(pending)
92
@socket_helper.skip_unless_bind_unix_socket
@unittest.skipIf(ssl is None, 'No ssl module')
def test_open_unix_connection_no_loop_ssl(self):
    # TLS over a Unix socket; there is no hostname to verify, hence ''.
    with test_utils.run_test_unix_server(use_ssl=True) as httpd:
        client_ctx = test_utils.dummy_ssl_context()
        pending = asyncio.open_unix_connection(
            httpd.address, ssl=client_ctx, server_hostname='')
        self._basetest_open_connection_no_loop_ssl(pending)
104
def _basetest_open_connection_error(self, open_connection_fut):
    # Shared body: an exception handed to the protocol's connection_lost()
    # must be re-raised by a subsequent reader.read(), and nothing may be
    # reported to the loop's exception handler.
    captured = []
    self.loop.set_exception_handler(lambda loop, ctx: captured.append(ctx))
    reader, writer = self.loop.run_until_complete(open_connection_fut)

    # Simulate the connection dying with an arbitrary exception type.
    writer._protocol.connection_lost(ZeroDivisionError())
    pending_read = reader.read()
    with self.assertRaises(ZeroDivisionError):
        self.loop.run_until_complete(pending_read)

    writer.close()
    test_utils.run_briefly(self.loop)
    self.assertEqual(captured, [])
116
def test_open_connection_error(self):
    # TCP variant of the injected connection_lost() error check.
    with test_utils.run_test_server() as httpd:
        pending = asyncio.open_connection(*httpd.address)
        self._basetest_open_connection_error(pending)
121
@socket_helper.skip_unless_bind_unix_socket
def test_open_unix_connection_error(self):
    # Unix-socket variant of the injected connection_lost() error check.
    with test_utils.run_test_unix_server() as httpd:
        pending = asyncio.open_unix_connection(httpd.address)
        self._basetest_open_connection_error(pending)
127
128 def test_feed_empty_data(self):
129 stream = asyncio.StreamReader(loop=self.loop)
130
131 stream.feed_data(b'')
132 self.assertEqual(b'', stream._buffer)
133
134 def test_feed_nonempty_data(self):
135 stream = asyncio.StreamReader(loop=self.loop)
136
137 stream.feed_data(self.DATA)
138 self.assertEqual(self.DATA, stream._buffer)
139
140 def test_read_zero(self):
141 # Read zero bytes.
142 stream = asyncio.StreamReader(loop=self.loop)
143 stream.feed_data(self.DATA)
144
145 data = self.loop.run_until_complete(stream.read(0))
146 self.assertEqual(b'', data)
147 self.assertEqual(self.DATA, stream._buffer)
148
149 def test_read(self):
150 # Read bytes.
151 stream = asyncio.StreamReader(loop=self.loop)
152 read_task = self.loop.create_task(stream.read(30))
153
154 def cb():
155 stream.feed_data(self.DATA)
156 self.loop.call_soon(cb)
157
158 data = self.loop.run_until_complete(read_task)
159 self.assertEqual(self.DATA, data)
160 self.assertEqual(b'', stream._buffer)
161
162 def test_read_line_breaks(self):
163 # Read bytes without line breaks.
164 stream = asyncio.StreamReader(loop=self.loop)
165 stream.feed_data(b'line1')
166 stream.feed_data(b'line2')
167
168 data = self.loop.run_until_complete(stream.read(5))
169
170 self.assertEqual(b'line1', data)
171 self.assertEqual(b'line2', stream._buffer)
172
173 def test_read_eof(self):
174 # Read bytes, stop at eof.
175 stream = asyncio.StreamReader(loop=self.loop)
176 read_task = self.loop.create_task(stream.read(1024))
177
178 def cb():
179 stream.feed_eof()
180 self.loop.call_soon(cb)
181
182 data = self.loop.run_until_complete(read_task)
183 self.assertEqual(b'', data)
184 self.assertEqual(b'', stream._buffer)
185
186 def test_read_until_eof(self):
187 # Read all bytes until eof.
188 stream = asyncio.StreamReader(loop=self.loop)
189 read_task = self.loop.create_task(stream.read(-1))
190
191 def cb():
192 stream.feed_data(b'chunk1\n')
193 stream.feed_data(b'chunk2')
194 stream.feed_eof()
195 self.loop.call_soon(cb)
196
197 data = self.loop.run_until_complete(read_task)
198
199 self.assertEqual(b'chunk1\nchunk2', data)
200 self.assertEqual(b'', stream._buffer)
201
202 def test_read_exception(self):
203 stream = asyncio.StreamReader(loop=self.loop)
204 stream.feed_data(b'line\n')
205
206 data = self.loop.run_until_complete(stream.read(2))
207 self.assertEqual(b'li', data)
208
209 stream.set_exception(ValueError())
210 self.assertRaises(
211 ValueError, self.loop.run_until_complete, stream.read(2))
212
213 def test_invalid_limit(self):
214 with self.assertRaisesRegex(ValueError, 'imit'):
215 asyncio.StreamReader(limit=0, loop=self.loop)
216
217 with self.assertRaisesRegex(ValueError, 'imit'):
218 asyncio.StreamReader(limit=-1, loop=self.loop)
219
220 def test_read_limit(self):
221 stream = asyncio.StreamReader(limit=3, loop=self.loop)
222 stream.feed_data(b'chunk')
223 data = self.loop.run_until_complete(stream.read(5))
224 self.assertEqual(b'chunk', data)
225 self.assertEqual(b'', stream._buffer)
226
227 def test_readline(self):
228 # Read one line. 'readline' will need to wait for the data
229 # to come from 'cb'
230 stream = asyncio.StreamReader(loop=self.loop)
231 stream.feed_data(b'chunk1 ')
232 read_task = self.loop.create_task(stream.readline())
233
234 def cb():
235 stream.feed_data(b'chunk2 ')
236 stream.feed_data(b'chunk3 ')
237 stream.feed_data(b'\n chunk4')
238 self.loop.call_soon(cb)
239
240 line = self.loop.run_until_complete(read_task)
241 self.assertEqual(b'chunk1 chunk2 chunk3 \n', line)
242 self.assertEqual(b' chunk4', stream._buffer)
243
244 def test_readline_limit_with_existing_data(self):
245 # Read one line. The data is in StreamReader's buffer
246 # before the event loop is run.
247
248 stream = asyncio.StreamReader(limit=3, loop=self.loop)
249 stream.feed_data(b'li')
250 stream.feed_data(b'ne1\nline2\n')
251
252 self.assertRaises(
253 ValueError, self.loop.run_until_complete, stream.readline())
254 # The buffer should contain the remaining data after exception
255 self.assertEqual(b'line2\n', stream._buffer)
256
257 stream = asyncio.StreamReader(limit=3, loop=self.loop)
258 stream.feed_data(b'li')
259 stream.feed_data(b'ne1')
260 stream.feed_data(b'li')
261
262 self.assertRaises(
263 ValueError, self.loop.run_until_complete, stream.readline())
264 # No b'\n' at the end. The 'limit' is set to 3. So before
265 # waiting for the new data in buffer, 'readline' will consume
266 # the entire buffer, and since the length of the consumed data
267 # is more than 3, it will raise a ValueError. The buffer is
268 # expected to be empty now.
269 self.assertEqual(b'', stream._buffer)
270
271 def test_at_eof(self):
272 stream = asyncio.StreamReader(loop=self.loop)
273 self.assertFalse(stream.at_eof())
274
275 stream.feed_data(b'some data\n')
276 self.assertFalse(stream.at_eof())
277
278 self.loop.run_until_complete(stream.readline())
279 self.assertFalse(stream.at_eof())
280
281 stream.feed_data(b'some data\n')
282 stream.feed_eof()
283 self.loop.run_until_complete(stream.readline())
284 self.assertTrue(stream.at_eof())
285
286 def test_readline_limit(self):
287 # Read one line. StreamReaders are fed with data after
288 # their 'readline' methods are called.
289
290 stream = asyncio.StreamReader(limit=7, loop=self.loop)
291 def cb():
292 stream.feed_data(b'chunk1')
293 stream.feed_data(b'chunk2')
294 stream.feed_data(b'chunk3\n')
295 stream.feed_eof()
296 self.loop.call_soon(cb)
297
298 self.assertRaises(
299 ValueError, self.loop.run_until_complete, stream.readline())
300 # The buffer had just one line of data, and after raising
301 # a ValueError it should be empty.
302 self.assertEqual(b'', stream._buffer)
303
304 stream = asyncio.StreamReader(limit=7, loop=self.loop)
305 def cb():
306 stream.feed_data(b'chunk1')
307 stream.feed_data(b'chunk2\n')
308 stream.feed_data(b'chunk3\n')
309 stream.feed_eof()
310 self.loop.call_soon(cb)
311
312 self.assertRaises(
313 ValueError, self.loop.run_until_complete, stream.readline())
314 self.assertEqual(b'chunk3\n', stream._buffer)
315
316 # check strictness of the limit
317 stream = asyncio.StreamReader(limit=7, loop=self.loop)
318 stream.feed_data(b'1234567\n')
319 line = self.loop.run_until_complete(stream.readline())
320 self.assertEqual(b'1234567\n', line)
321 self.assertEqual(b'', stream._buffer)
322
323 stream.feed_data(b'12345678\n')
324 with self.assertRaises(ValueError) as cm:
325 self.loop.run_until_complete(stream.readline())
326 self.assertEqual(b'', stream._buffer)
327
328 stream.feed_data(b'12345678')
329 with self.assertRaises(ValueError) as cm:
330 self.loop.run_until_complete(stream.readline())
331 self.assertEqual(b'', stream._buffer)
332
333 def test_readline_nolimit_nowait(self):
334 # All needed data for the first 'readline' call will be
335 # in the buffer.
336 stream = asyncio.StreamReader(loop=self.loop)
337 stream.feed_data(self.DATA[:6])
338 stream.feed_data(self.DATA[6:])
339
340 line = self.loop.run_until_complete(stream.readline())
341
342 self.assertEqual(b'line1\n', line)
343 self.assertEqual(b'line2\nline3\n', stream._buffer)
344
345 def test_readline_eof(self):
346 stream = asyncio.StreamReader(loop=self.loop)
347 stream.feed_data(b'some data')
348 stream.feed_eof()
349
350 line = self.loop.run_until_complete(stream.readline())
351 self.assertEqual(b'some data', line)
352
353 def test_readline_empty_eof(self):
354 stream = asyncio.StreamReader(loop=self.loop)
355 stream.feed_eof()
356
357 line = self.loop.run_until_complete(stream.readline())
358 self.assertEqual(b'', line)
359
360 def test_readline_read_byte_count(self):
361 stream = asyncio.StreamReader(loop=self.loop)
362 stream.feed_data(self.DATA)
363
364 self.loop.run_until_complete(stream.readline())
365
366 data = self.loop.run_until_complete(stream.read(7))
367
368 self.assertEqual(b'line2\nl', data)
369 self.assertEqual(b'ine3\n', stream._buffer)
370
371 def test_readline_exception(self):
372 stream = asyncio.StreamReader(loop=self.loop)
373 stream.feed_data(b'line\n')
374
375 data = self.loop.run_until_complete(stream.readline())
376 self.assertEqual(b'line\n', data)
377
378 stream.set_exception(ValueError())
379 self.assertRaises(
380 ValueError, self.loop.run_until_complete, stream.readline())
381 self.assertEqual(b'', stream._buffer)
382
383 def test_readuntil_separator(self):
384 stream = asyncio.StreamReader(loop=self.loop)
385 with self.assertRaisesRegex(ValueError, 'Separator should be'):
386 self.loop.run_until_complete(stream.readuntil(separator=b''))
387
388 def test_readuntil_multi_chunks(self):
389 stream = asyncio.StreamReader(loop=self.loop)
390
391 stream.feed_data(b'lineAAA')
392 data = self.loop.run_until_complete(stream.readuntil(separator=b'AAA'))
393 self.assertEqual(b'lineAAA', data)
394 self.assertEqual(b'', stream._buffer)
395
396 stream.feed_data(b'lineAAA')
397 data = self.loop.run_until_complete(stream.readuntil(b'AAA'))
398 self.assertEqual(b'lineAAA', data)
399 self.assertEqual(b'', stream._buffer)
400
401 stream.feed_data(b'lineAAAxxx')
402 data = self.loop.run_until_complete(stream.readuntil(b'AAA'))
403 self.assertEqual(b'lineAAA', data)
404 self.assertEqual(b'xxx', stream._buffer)
405
406 def test_readuntil_multi_chunks_1(self):
407 stream = asyncio.StreamReader(loop=self.loop)
408
409 stream.feed_data(b'QWEaa')
410 stream.feed_data(b'XYaa')
411 stream.feed_data(b'a')
412 data = self.loop.run_until_complete(stream.readuntil(b'aaa'))
413 self.assertEqual(b'QWEaaXYaaa', data)
414 self.assertEqual(b'', stream._buffer)
415
416 stream.feed_data(b'QWEaa')
417 stream.feed_data(b'XYa')
418 stream.feed_data(b'aa')
419 data = self.loop.run_until_complete(stream.readuntil(b'aaa'))
420 self.assertEqual(b'QWEaaXYaaa', data)
421 self.assertEqual(b'', stream._buffer)
422
423 stream.feed_data(b'aaa')
424 data = self.loop.run_until_complete(stream.readuntil(b'aaa'))
425 self.assertEqual(b'aaa', data)
426 self.assertEqual(b'', stream._buffer)
427
428 stream.feed_data(b'Xaaa')
429 data = self.loop.run_until_complete(stream.readuntil(b'aaa'))
430 self.assertEqual(b'Xaaa', data)
431 self.assertEqual(b'', stream._buffer)
432
433 stream.feed_data(b'XXX')
434 stream.feed_data(b'a')
435 stream.feed_data(b'a')
436 stream.feed_data(b'a')
437 data = self.loop.run_until_complete(stream.readuntil(b'aaa'))
438 self.assertEqual(b'XXXaaa', data)
439 self.assertEqual(b'', stream._buffer)
440
441 def test_readuntil_eof(self):
442 stream = asyncio.StreamReader(loop=self.loop)
443 data = b'some dataAA'
444 stream.feed_data(data)
445 stream.feed_eof()
446
447 with self.assertRaisesRegex(asyncio.IncompleteReadError,
448 'undefined expected bytes') as cm:
449 self.loop.run_until_complete(stream.readuntil(b'AAA'))
450 self.assertEqual(cm.exception.partial, data)
451 self.assertIsNone(cm.exception.expected)
452 self.assertEqual(b'', stream._buffer)
453
454 def test_readuntil_limit_found_sep(self):
455 stream = asyncio.StreamReader(loop=self.loop, limit=3)
456 stream.feed_data(b'some dataAA')
457 with self.assertRaisesRegex(asyncio.LimitOverrunError,
458 'not found') as cm:
459 self.loop.run_until_complete(stream.readuntil(b'AAA'))
460
461 self.assertEqual(b'some dataAA', stream._buffer)
462
463 stream.feed_data(b'A')
464 with self.assertRaisesRegex(asyncio.LimitOverrunError,
465 'is found') as cm:
466 self.loop.run_until_complete(stream.readuntil(b'AAA'))
467
468 self.assertEqual(b'some dataAAA', stream._buffer)
469
470 def test_readexactly_zero_or_less(self):
471 # Read exact number of bytes (zero or less).
472 stream = asyncio.StreamReader(loop=self.loop)
473 stream.feed_data(self.DATA)
474
475 data = self.loop.run_until_complete(stream.readexactly(0))
476 self.assertEqual(b'', data)
477 self.assertEqual(self.DATA, stream._buffer)
478
479 with self.assertRaisesRegex(ValueError, 'less than zero'):
480 self.loop.run_until_complete(stream.readexactly(-1))
481 self.assertEqual(self.DATA, stream._buffer)
482
483 def test_readexactly(self):
484 # Read exact number of bytes.
485 stream = asyncio.StreamReader(loop=self.loop)
486
487 n = 2 * len(self.DATA)
488 read_task = self.loop.create_task(stream.readexactly(n))
489
490 def cb():
491 stream.feed_data(self.DATA)
492 stream.feed_data(self.DATA)
493 stream.feed_data(self.DATA)
494 self.loop.call_soon(cb)
495
496 data = self.loop.run_until_complete(read_task)
497 self.assertEqual(self.DATA + self.DATA, data)
498 self.assertEqual(self.DATA, stream._buffer)
499
500 def test_readexactly_limit(self):
501 stream = asyncio.StreamReader(limit=3, loop=self.loop)
502 stream.feed_data(b'chunk')
503 data = self.loop.run_until_complete(stream.readexactly(5))
504 self.assertEqual(b'chunk', data)
505 self.assertEqual(b'', stream._buffer)
506
507 def test_readexactly_eof(self):
508 # Read exact number of bytes (eof).
509 stream = asyncio.StreamReader(loop=self.loop)
510 n = 2 * len(self.DATA)
511 read_task = self.loop.create_task(stream.readexactly(n))
512
513 def cb():
514 stream.feed_data(self.DATA)
515 stream.feed_eof()
516 self.loop.call_soon(cb)
517
518 with self.assertRaises(asyncio.IncompleteReadError) as cm:
519 self.loop.run_until_complete(read_task)
520 self.assertEqual(cm.exception.partial, self.DATA)
521 self.assertEqual(cm.exception.expected, n)
522 self.assertEqual(str(cm.exception),
523 '18 bytes read on a total of 36 expected bytes')
524 self.assertEqual(b'', stream._buffer)
525
526 def test_readexactly_exception(self):
527 stream = asyncio.StreamReader(loop=self.loop)
528 stream.feed_data(b'line\n')
529
530 data = self.loop.run_until_complete(stream.readexactly(2))
531 self.assertEqual(b'li', data)
532
533 stream.set_exception(ValueError())
534 self.assertRaises(
535 ValueError, self.loop.run_until_complete, stream.readexactly(2))
536
537 def test_exception(self):
538 stream = asyncio.StreamReader(loop=self.loop)
539 self.assertIsNone(stream.exception())
540
541 exc = ValueError()
542 stream.set_exception(exc)
543 self.assertIs(stream.exception(), exc)
544
545 def test_exception_waiter(self):
546 stream = asyncio.StreamReader(loop=self.loop)
547
548 async def set_err():
549 stream.set_exception(ValueError())
550
551 t1 = self.loop.create_task(stream.readline())
552 t2 = self.loop.create_task(set_err())
553
554 self.loop.run_until_complete(asyncio.wait([t1, t2]))
555
556 self.assertRaises(ValueError, t1.result)
557
def test_exception_cancel(self):
    # set_exception() must cope with a waiter whose readline() task has
    # already been cancelled, and must leave no waiter behind.
    reader = asyncio.StreamReader(loop=self.loop)

    pending = self.loop.create_task(reader.readline())
    test_utils.run_briefly(self.loop)
    pending.cancel()
    test_utils.run_briefly(self.loop)

    reader.set_exception(RuntimeError('message'))
    test_utils.run_briefly(self.loop)
    self.assertIs(reader._waiter, None)
569
def test_start_server(self):
    # Echo one line through asyncio.start_server(): first with a coroutine
    # client handler, then with a plain callback that schedules it.

    class MyServer:

        def __init__(self, loop):
            self.server = None
            self.loop = loop

        async def handle_client(self, client_reader, client_writer):
            # Echo a single line back, then close our side cleanly.
            line = await client_reader.readline()
            client_writer.write(line)
            await client_writer.drain()
            client_writer.close()
            await client_writer.wait_closed()

        def start(self):
            # Serve on a socket we bind ourselves (port 0 = any free port).
            listener = socket.create_server(('127.0.0.1', 0))
            serving = asyncio.start_server(self.handle_client, sock=listener)
            self.server = self.loop.run_until_complete(serving)
            return listener.getsockname()

        def handle_client_callback(self, client_reader, client_writer):
            # Non-coroutine handler: schedule the echo coroutine ourselves.
            self.loop.create_task(
                self.handle_client(client_reader, client_writer))

        def start_callback(self):
            # Find a free port, release it, and let start_server() bind the
            # host/port pair itself.
            probe = socket.create_server(('127.0.0.1', 0))
            addr = probe.getsockname()
            probe.close()
            serving = asyncio.start_server(self.handle_client_callback,
                                           host=addr[0], port=addr[1])
            self.server = self.loop.run_until_complete(serving)
            return addr

        def stop(self):
            if self.server is None:
                return
            self.server.close()
            self.loop.run_until_complete(self.server.wait_closed())
            self.server = None

    async def client(addr):
        reader, writer = await asyncio.open_connection(*addr)
        writer.write(b"hello world!\n")      # send a line...
        echoed = await reader.readline()     # ...and read it back
        writer.close()
        await writer.wait_closed()
        return echoed

    messages = []
    self.loop.set_exception_handler(lambda loop, ctx: messages.append(ctx))

    # Coroutine handler first, then the callback handler.
    for launch in (MyServer.start, MyServer.start_callback):
        server = MyServer(self.loop)
        addr = launch(server)
        msg = self.loop.run_until_complete(
            self.loop.create_task(client(addr)))
        server.stop()
        self.assertEqual(msg, b"hello world!\n")

    self.assertEqual(messages, [])
639
@socket_helper.skip_unless_bind_unix_socket
def test_start_unix_server(self):
    # Echo one line through asyncio.start_unix_server(): first with a
    # coroutine client handler, then with a plain callback that schedules it.

    class MyServer:

        def __init__(self, loop, path):
            self.server = None
            self.loop = loop
            self.path = path

        async def handle_client(self, client_reader, client_writer):
            # Echo a single line back, then close our side cleanly.
            line = await client_reader.readline()
            client_writer.write(line)
            await client_writer.drain()
            client_writer.close()
            await client_writer.wait_closed()

        def start(self):
            serving = asyncio.start_unix_server(self.handle_client,
                                                path=self.path)
            self.server = self.loop.run_until_complete(serving)

        def handle_client_callback(self, client_reader, client_writer):
            # Non-coroutine handler: schedule the echo coroutine ourselves.
            self.loop.create_task(
                self.handle_client(client_reader, client_writer))

        def start_callback(self):
            serving = asyncio.start_unix_server(self.handle_client_callback,
                                                path=self.path)
            self.server = self.loop.run_until_complete(serving)

        def stop(self):
            if self.server is None:
                return
            self.server.close()
            self.loop.run_until_complete(self.server.wait_closed())
            self.server = None

    async def client(path):
        reader, writer = await asyncio.open_unix_connection(path)
        writer.write(b"hello world!\n")      # send a line...
        echoed = await reader.readline()     # ...and read it back
        writer.close()
        await writer.wait_closed()
        return echoed

    messages = []
    self.loop.set_exception_handler(lambda loop, ctx: messages.append(ctx))

    # Coroutine handler first, then the callback handler; each run gets its
    # own socket path.
    for launch in (MyServer.start, MyServer.start_callback):
        with test_utils.unix_socket_path() as path:
            server = MyServer(self.loop, path)
            launch(server)
            msg = self.loop.run_until_complete(
                self.loop.create_task(client(path)))
            server.stop()
        self.assertEqual(msg, b"hello world!\n")

    self.assertEqual(messages, [])
709
@unittest.skipIf(ssl is None, 'No ssl module')
def test_start_tls(self):
    # Upgrade an established plain-text connection to TLS on both ends with
    # StreamWriter.start_tls(), echoing one line before and one after.

    class MyServer:

        def __init__(self, loop):
            self.server = None
            self.loop = loop

        async def handle_client(self, client_reader, client_writer):
            # Plain-text echo first...
            plain = await client_reader.readline()
            client_writer.write(plain)
            await client_writer.drain()
            # ...then switch this same connection to TLS and echo again.
            assert client_writer.get_extra_info('sslcontext') is None
            await client_writer.start_tls(
                test_utils.simple_server_sslcontext())
            assert client_writer.get_extra_info('sslcontext') is not None
            secure = await client_reader.readline()
            client_writer.write(secure)
            await client_writer.drain()
            client_writer.close()
            await client_writer.wait_closed()

        def start(self):
            listener = socket.create_server(('127.0.0.1', 0))
            serving = asyncio.start_server(self.handle_client, sock=listener)
            self.server = self.loop.run_until_complete(serving)
            return listener.getsockname()

        def stop(self):
            if self.server is None:
                return
            self.server.close()
            self.loop.run_until_complete(self.server.wait_closed())
            self.server = None

    async def client(addr):
        reader, writer = await asyncio.open_connection(*addr)
        writer.write(b"hello world 1!\n")
        await writer.drain()
        plain_echo = await reader.readline()
        assert writer.get_extra_info('sslcontext') is None
        await writer.start_tls(test_utils.simple_client_sslcontext())
        assert writer.get_extra_info('sslcontext') is not None
        writer.write(b"hello world 2!\n")
        await writer.drain()
        secure_echo = await reader.readline()
        writer.close()
        await writer.wait_closed()
        return plain_echo, secure_echo

    messages = []
    self.loop.set_exception_handler(lambda loop, ctx: messages.append(ctx))

    server = MyServer(self.loop)
    addr = server.start()
    msg1, msg2 = self.loop.run_until_complete(client(addr))
    server.stop()

    self.assertEqual(messages, [])
    self.assertEqual(msg1, b"hello world 1!\n")
    self.assertEqual(msg2, b"hello world 2!\n")
772
@unittest.skipIf(sys.platform == 'win32', "Don't have pipes")
def test_read_all_from_pipe_reader(self):
    # See asyncio issue 168. This test is derived from the example
    # subprocess_attach_read_pipe.py, but we configure the
    # StreamReader's limit so that twice it is less than the size
    # of the data written. Also we must explicitly attach a child
    # watcher to the event loop.

    # Child script: write four bytes to the inherited fd, then close it.
    code = """\
import os, sys
fd = int(sys.argv[1])
os.write(fd, b'data')
os.close(fd)
"""
    rfd, wfd = os.pipe()
    args = [sys.executable, '-c', code, str(wfd)]

    # The read end is wrapped in an unbuffered file object and connected to
    # the loop as a read pipe that feeds `reader`.
    pipe = open(rfd, 'rb', 0)
    reader = asyncio.StreamReader(loop=self.loop, limit=1)
    protocol = asyncio.StreamReaderProtocol(reader, loop=self.loop)
    transport, _ = self.loop.run_until_complete(
        self.loop.connect_read_pipe(lambda: protocol, pipe))
    with warnings.catch_warnings():
        warnings.simplefilter('ignore', DeprecationWarning)
        watcher = asyncio.SafeChildWatcher()
    watcher.attach_loop(self.loop)
    try:
        with warnings.catch_warnings():
            warnings.simplefilter('ignore', DeprecationWarning)
            asyncio.set_child_watcher(watcher)
        create = asyncio.create_subprocess_exec(
            *args,
            pass_fds={wfd},
        )
        proc = self.loop.run_until_complete(create)
        self.loop.run_until_complete(proc.wait())
    finally:
        with warnings.catch_warnings():
            warnings.simplefilter('ignore', DeprecationWarning)
            asyncio.set_child_watcher(None)
        # Close the parent's copy of the write end on every path, so a
        # failure while spawning or waiting for the child cannot leak the
        # descriptor.  On success this is also what lets the pipe reach EOF:
        # the child has already closed its own copy.
        os.close(wfd)

    data = self.loop.run_until_complete(reader.read(-1))
    self.assertEqual(data, b'data')
817
818 def test_streamreader_constructor_without_loop(self):
819 with self.assertRaisesRegex(RuntimeError, 'no current event loop'):
820 asyncio.StreamReader()
821
822 def test_streamreader_constructor_use_running_loop(self):
823 # asyncio issue #184: Ensure that StreamReaderProtocol constructor
824 # retrieves the current loop if the loop parameter is not set
825 async def test():
826 return asyncio.StreamReader()
827
828 reader = self.loop.run_until_complete(test())
829 self.assertIs(reader._loop, self.loop)
830
831 def test_streamreader_constructor_use_global_loop(self):
832 # asyncio issue #184: Ensure that StreamReaderProtocol constructor
833 # retrieves the current loop if the loop parameter is not set
834 # Deprecated in 3.10, undeprecated in 3.12
835 self.addCleanup(asyncio.set_event_loop, None)
836 asyncio.set_event_loop(self.loop)
837 reader = asyncio.StreamReader()
838 self.assertIs(reader._loop, self.loop)
839
840
841 def test_streamreaderprotocol_constructor_without_loop(self):
842 reader = mock.Mock()
843 with self.assertRaisesRegex(RuntimeError, 'no current event loop'):
844 asyncio.StreamReaderProtocol(reader)
845
846 def test_streamreaderprotocol_constructor_use_running_loop(self):
847 # asyncio issue #184: Ensure that StreamReaderProtocol constructor
848 # retrieves the current loop if the loop parameter is not set
849 reader = mock.Mock()
850 async def test():
851 return asyncio.StreamReaderProtocol(reader)
852 protocol = self.loop.run_until_complete(test())
853 self.assertIs(protocol._loop, self.loop)
854
855 def test_streamreaderprotocol_constructor_use_global_loop(self):
856 # asyncio issue #184: Ensure that StreamReaderProtocol constructor
857 # retrieves the current loop if the loop parameter is not set
858 # Deprecated in 3.10, undeprecated in 3.12
859 self.addCleanup(asyncio.set_event_loop, None)
860 asyncio.set_event_loop(self.loop)
861 reader = mock.Mock()
862 protocol = asyncio.StreamReaderProtocol(reader)
863 self.assertIs(protocol._loop, self.loop)
864
865 def test_multiple_drain(self):
866 # See https://github.com/python/cpython/issues/74116
867 drained = 0
868
869 async def drainer(stream):
870 nonlocal drained
871 await stream._drain_helper()
872 drained += 1
873
874 async def main():
875 loop = asyncio.get_running_loop()
876 stream = asyncio.streams.FlowControlMixin(loop)
877 stream.pause_writing()
878 loop.call_later(0.1, stream.resume_writing)
879 await asyncio.gather(*[drainer(stream) for _ in range(10)])
880 self.assertEqual(drained, 10)
881
882 self.loop.run_until_complete(main())
883
def test_drain_raises(self):
    # See http://bugs.python.org/issue25441
    #
    # The peer is a plain blocking socket served from a thread rather than
    # an asyncio server.  The bug under test is a drain() that never gives
    # up the event loop even though the server side has closed the socket.
    messages = []
    self.loop.set_exception_handler(lambda loop, ctx: messages.append(ctx))
    addr_queue = queue.Queue()

    def accept_once_and_hang_up():
        # Worker thread: publish the listening address, accept a single
        # connection and close it immediately.
        with socket.create_server(('localhost', 0)) as listener:
            addr_queue.put(listener.getsockname())
            peer, _ = listener.accept()
            peer.close()

    async def write_forever(host, port):
        reader, writer = await asyncio.open_connection(host, port)
        while True:
            writer.write(b"foo\n")
            await writer.drain()

    worker = threading.Thread(target=accept_once_and_hang_up, daemon=True)
    worker.start()
    addr = addr_queue.get()     # blocks until the server is listening

    # drain() must eventually raise rather than loop forever.
    with self.assertRaises((ConnectionResetError, ConnectionAbortedError,
                            BrokenPipeError)):
        self.loop.run_until_complete(write_forever(*addr))

    # Only joined on success; on failure the thread may still be blocked in
    # accept(), and being a daemon it cannot block interpreter exit.
    worker.join()
    self.assertEqual([], messages)
926
927 def test___repr__(self):
928 stream = asyncio.StreamReader(loop=self.loop)
929 self.assertEqual("<StreamReader>", repr(stream))
930
931 def test___repr__nondefault_limit(self):
932 stream = asyncio.StreamReader(loop=self.loop, limit=123)
933 self.assertEqual("<StreamReader limit=123>", repr(stream))
934
935 def test___repr__eof(self):
936 stream = asyncio.StreamReader(loop=self.loop)
937 stream.feed_eof()
938 self.assertEqual("<StreamReader eof>", repr(stream))
939
940 def test___repr__data(self):
941 stream = asyncio.StreamReader(loop=self.loop)
942 stream.feed_data(b'data')
943 self.assertEqual("<StreamReader 4 bytes>", repr(stream))
944
945 def test___repr__exception(self):
946 stream = asyncio.StreamReader(loop=self.loop)
947 exc = RuntimeError()
948 stream.set_exception(exc)
949 self.assertEqual("<StreamReader exception=RuntimeError()>",
950 repr(stream))
951
952 def test___repr__waiter(self):
953 stream = asyncio.StreamReader(loop=self.loop)
954 stream._waiter = asyncio.Future(loop=self.loop)
955 self.assertRegex(
956 repr(stream),
957 r"<StreamReader waiter=<Future pending[\S ]*>>")
958 stream._waiter.set_result(None)
959 self.loop.run_until_complete(stream._waiter)
960 stream._waiter = None
961 self.assertEqual("<StreamReader>", repr(stream))
962
def test___repr__transport(self):
    """The transport's own repr is embedded in the reader's repr."""
    reader = asyncio.StreamReader(loop=self.loop)
    transport = mock.Mock()
    transport.__repr__ = mock.Mock(return_value="<Transport>")
    reader._transport = transport
    self.assertEqual("<StreamReader transport=<Transport>>", repr(reader))
969
def test_IncompleteReadError_pickleable(self):
    """IncompleteReadError round-trips through pickle at every protocol.

    The message, the partial bytes and the expected count must all
    survive serialization.
    """
    original = asyncio.IncompleteReadError(b'abc', 10)
    protocols = range(pickle.HIGHEST_PROTOCOL + 1)
    for protocol in protocols:
        with self.subTest(pickle_protocol=protocol):
            blob = pickle.dumps(original, protocol=protocol)
            restored = pickle.loads(blob)
            self.assertEqual(str(original), str(restored))
            self.assertEqual(original.partial, restored.partial)
            self.assertEqual(original.expected, restored.expected)
978
def test_LimitOverrunError_pickleable(self):
    """LimitOverrunError round-trips through pickle at every protocol.

    Both the message and the consumed count must survive serialization.
    """
    original = asyncio.LimitOverrunError('message', 10)
    protocols = range(pickle.HIGHEST_PROTOCOL + 1)
    for protocol in protocols:
        with self.subTest(pickle_protocol=protocol):
            blob = pickle.dumps(original, protocol=protocol)
            restored = pickle.loads(blob)
            self.assertEqual(str(original), str(restored))
            self.assertEqual(original.consumed, restored.consumed)
986
def test_wait_closed_on_close(self):
    """close() followed by wait_closed() succeeds on a fully-read connection.

    The writer must not report is_closing() before close() and must report
    it afterwards. Any context passed to the loop's exception handler while
    the exchange runs is treated as a failure.
    """
    messages = []
    self.loop.set_exception_handler(lambda loop, ctx: messages.append(ctx))

    with test_utils.run_test_server() as httpd:
        rd, wr = self.loop.run_until_complete(
            asyncio.open_connection(*httpd.address))

        wr.write(b'GET / HTTP/1.0\r\n\r\n')
        f = rd.readline()
        data = self.loop.run_until_complete(f)
        self.assertEqual(data, b'HTTP/1.0 200 OK\r\n')
        f = rd.read()
        data = self.loop.run_until_complete(f)
        self.assertTrue(data.endswith(b'\r\n\r\nTest message'))
        self.assertFalse(wr.is_closing())
        wr.close()
        self.assertTrue(wr.is_closing())
        self.loop.run_until_complete(wr.wait_closed())

    # Errors during close/wait_closed are only reported to the loop's
    # exception handler, not raised, so check that nothing was recorded.
    self.assertEqual(messages, [])
1003
def test_wait_closed_on_close_with_unread_data(self):
    """close() followed by wait_closed() succeeds after a partial read.

    Only the status line of the response is read before the writer is
    closed; the rest of the response is never consumed by the test. Any
    context passed to the loop's exception handler is treated as a failure.
    """
    messages = []
    self.loop.set_exception_handler(lambda loop, ctx: messages.append(ctx))

    with test_utils.run_test_server() as httpd:
        rd, wr = self.loop.run_until_complete(
            asyncio.open_connection(*httpd.address))

        wr.write(b'GET / HTTP/1.0\r\n\r\n')
        f = rd.readline()
        data = self.loop.run_until_complete(f)
        self.assertEqual(data, b'HTTP/1.0 200 OK\r\n')
        # Close without reading the remainder of the response.
        wr.close()
        self.loop.run_until_complete(wr.wait_closed())

    # Errors during close/wait_closed are only reported to the loop's
    # exception handler, not raised, so check that nothing was recorded.
    self.assertEqual(messages, [])
1015
def test_async_writer_api(self):
    """Drive a full request/response exchange with the awaitable writer API.

    The whole exchange, including close() and wait_closed(), runs inside a
    single coroutine; the loop's exception handler must never be invoked.
    """
    reported = []

    def record(loop, context):
        reported.append(context)

    self.loop.set_exception_handler(record)

    async def exchange(server):
        reader, writer = await asyncio.open_connection(*server.address)

        writer.write(b'GET / HTTP/1.0\r\n\r\n')
        status = await reader.readline()
        self.assertEqual(status, b'HTTP/1.0 200 OK\r\n')
        body = await reader.read()
        self.assertTrue(body.endswith(b'\r\n\r\nTest message'))
        writer.close()
        await writer.wait_closed()

    with test_utils.run_test_server() as httpd:
        self.loop.run_until_complete(exchange(httpd))

    self.assertEqual(reported, [])
1035
def test_async_writer_api_exception_after_close(self):
    """Writing and draining after close() raises ConnectionResetError.

    The loop's exception handler must not be invoked along the way.
    """
    reported = []

    def record(loop, context):
        reported.append(context)

    self.loop.set_exception_handler(record)

    async def exchange(server):
        reader, writer = await asyncio.open_connection(*server.address)

        writer.write(b'GET / HTTP/1.0\r\n\r\n')
        status = await reader.readline()
        self.assertEqual(status, b'HTTP/1.0 200 OK\r\n')
        body = await reader.read()
        self.assertTrue(body.endswith(b'\r\n\r\nTest message'))
        writer.close()
        # The assertion spans both the write and the drain, so the error
        # may be raised by either call.
        with self.assertRaises(ConnectionResetError):
            writer.write(b'data')
            await writer.drain()

    with test_utils.run_test_server() as httpd:
        self.loop.run_until_complete(exchange(httpd))

    self.assertEqual(reported, [])
1057
def test_eof_feed_when_closing_writer(self):
    """Closing the writer leaves the paired reader at EOF.

    Regression test for http://bugs.python.org/issue35065. After close()
    and wait_closed(), the reader must report EOF and read() must return
    no data; the loop's exception handler must not be invoked.
    """
    reported = []
    self.loop.set_exception_handler(
        lambda loop, ctx: reported.append(ctx))

    with test_utils.run_test_server() as httpd:
        reader, writer = self.loop.run_until_complete(
            asyncio.open_connection(*httpd.address))

        writer.close()
        self.loop.run_until_complete(writer.wait_closed())
        self.assertTrue(reader.at_eof())
        remaining = self.loop.run_until_complete(reader.read())
        self.assertEqual(remaining, b'')

    self.assertEqual(reported, [])
1076
1077
if __name__ == '__main__':
    # Executed only when this file is run as a script, not on import:
    # hand control to unittest's command-line runner.
    unittest.main()