python (3.12.0)

(root)/
lib/
python3.12/
test/
test_asyncio/
test_streams.py
       1  """Tests for streams.py."""
       2  
       3  import gc
       4  import os
       5  import queue
       6  import pickle
       7  import socket
       8  import sys
       9  import threading
      10  import unittest
      11  from unittest import mock
      12  import warnings
      13  from test.support import socket_helper
      14  try:
      15      import ssl
      16  except ImportError:
      17      ssl = None
      18  
      19  import asyncio
      20  from test.test_asyncio import utils as test_utils
      21  
      22  
      23  def tearDownModule():
      24      asyncio.set_event_loop_policy(None)
      25  
      26  
      27  class StreamTests(test_utils.TestCase):
      28  
      29      DATA = b'line1\nline2\nline3\n'
      30  
      31      def setUp(self):
      32          super().setUp()
      33          self.loop = asyncio.new_event_loop()
      34          self.set_event_loop(self.loop)
      35  
      36      def tearDown(self):
      37          # just in case if we have transport close callbacks
      38          test_utils.run_briefly(self.loop)
      39  
      40          self.loop.close()
      41          gc.collect()
      42          super().tearDown()
      43  
      44      def _basetest_open_connection(self, open_connection_fut):
      45          messages = []
      46          self.loop.set_exception_handler(lambda loop, ctx: messages.append(ctx))
      47          reader, writer = self.loop.run_until_complete(open_connection_fut)
      48          writer.write(b'GET / HTTP/1.0\r\n\r\n')
      49          f = reader.readline()
      50          data = self.loop.run_until_complete(f)
      51          self.assertEqual(data, b'HTTP/1.0 200 OK\r\n')
      52          f = reader.read()
      53          data = self.loop.run_until_complete(f)
      54          self.assertTrue(data.endswith(b'\r\n\r\nTest message'))
      55          writer.close()
      56          self.assertEqual(messages, [])
      57  
      58      def test_open_connection(self):
      59          with test_utils.run_test_server() as httpd:
      60              conn_fut = asyncio.open_connection(*httpd.address)
      61              self._basetest_open_connection(conn_fut)
      62  
      63      @socket_helper.skip_unless_bind_unix_socket
      64      def test_open_unix_connection(self):
      65          with test_utils.run_test_unix_server() as httpd:
      66              conn_fut = asyncio.open_unix_connection(httpd.address)
      67              self._basetest_open_connection(conn_fut)
      68  
      69      def _basetest_open_connection_no_loop_ssl(self, open_connection_fut):
      70          messages = []
      71          self.loop.set_exception_handler(lambda loop, ctx: messages.append(ctx))
      72          try:
      73              reader, writer = self.loop.run_until_complete(open_connection_fut)
      74          finally:
      75              asyncio.set_event_loop(None)
      76          writer.write(b'GET / HTTP/1.0\r\n\r\n')
      77          f = reader.read()
      78          data = self.loop.run_until_complete(f)
      79          self.assertTrue(data.endswith(b'\r\n\r\nTest message'))
      80  
      81          writer.close()
      82          self.assertEqual(messages, [])
      83  
      84      @unittest.skipIf(ssl is None, 'No ssl module')
      85      def test_open_connection_no_loop_ssl(self):
      86          with test_utils.run_test_server(use_ssl=True) as httpd:
      87              conn_fut = asyncio.open_connection(
      88                  *httpd.address,
      89                  ssl=test_utils.dummy_ssl_context())
      90  
      91              self._basetest_open_connection_no_loop_ssl(conn_fut)
      92  
      93      @socket_helper.skip_unless_bind_unix_socket
      94      @unittest.skipIf(ssl is None, 'No ssl module')
      95      def test_open_unix_connection_no_loop_ssl(self):
      96          with test_utils.run_test_unix_server(use_ssl=True) as httpd:
      97              conn_fut = asyncio.open_unix_connection(
      98                  httpd.address,
      99                  ssl=test_utils.dummy_ssl_context(),
     100                  server_hostname='',
     101              )
     102  
     103              self._basetest_open_connection_no_loop_ssl(conn_fut)
     104  
     105      def _basetest_open_connection_error(self, open_connection_fut):
     106          messages = []
     107          self.loop.set_exception_handler(lambda loop, ctx: messages.append(ctx))
     108          reader, writer = self.loop.run_until_complete(open_connection_fut)
     109          writer._protocol.connection_lost(ZeroDivisionError())
     110          f = reader.read()
     111          with self.assertRaises(ZeroDivisionError):
     112              self.loop.run_until_complete(f)
     113          writer.close()
     114          test_utils.run_briefly(self.loop)
     115          self.assertEqual(messages, [])
     116  
     117      def test_open_connection_error(self):
     118          with test_utils.run_test_server() as httpd:
     119              conn_fut = asyncio.open_connection(*httpd.address)
     120              self._basetest_open_connection_error(conn_fut)
     121  
     122      @socket_helper.skip_unless_bind_unix_socket
     123      def test_open_unix_connection_error(self):
     124          with test_utils.run_test_unix_server() as httpd:
     125              conn_fut = asyncio.open_unix_connection(httpd.address)
     126              self._basetest_open_connection_error(conn_fut)
     127  
     128      def test_feed_empty_data(self):
     129          stream = asyncio.StreamReader(loop=self.loop)
     130  
     131          stream.feed_data(b'')
     132          self.assertEqual(b'', stream._buffer)
     133  
     134      def test_feed_nonempty_data(self):
     135          stream = asyncio.StreamReader(loop=self.loop)
     136  
     137          stream.feed_data(self.DATA)
     138          self.assertEqual(self.DATA, stream._buffer)
     139  
     140      def test_read_zero(self):
     141          # Read zero bytes.
     142          stream = asyncio.StreamReader(loop=self.loop)
     143          stream.feed_data(self.DATA)
     144  
     145          data = self.loop.run_until_complete(stream.read(0))
     146          self.assertEqual(b'', data)
     147          self.assertEqual(self.DATA, stream._buffer)
     148  
     149      def test_read(self):
     150          # Read bytes.
     151          stream = asyncio.StreamReader(loop=self.loop)
     152          read_task = self.loop.create_task(stream.read(30))
     153  
     154          def cb():
     155              stream.feed_data(self.DATA)
     156          self.loop.call_soon(cb)
     157  
     158          data = self.loop.run_until_complete(read_task)
     159          self.assertEqual(self.DATA, data)
     160          self.assertEqual(b'', stream._buffer)
     161  
     162      def test_read_line_breaks(self):
     163          # Read bytes without line breaks.
     164          stream = asyncio.StreamReader(loop=self.loop)
     165          stream.feed_data(b'line1')
     166          stream.feed_data(b'line2')
     167  
     168          data = self.loop.run_until_complete(stream.read(5))
     169  
     170          self.assertEqual(b'line1', data)
     171          self.assertEqual(b'line2', stream._buffer)
     172  
     173      def test_read_eof(self):
     174          # Read bytes, stop at eof.
     175          stream = asyncio.StreamReader(loop=self.loop)
     176          read_task = self.loop.create_task(stream.read(1024))
     177  
     178          def cb():
     179              stream.feed_eof()
     180          self.loop.call_soon(cb)
     181  
     182          data = self.loop.run_until_complete(read_task)
     183          self.assertEqual(b'', data)
     184          self.assertEqual(b'', stream._buffer)
     185  
     186      def test_read_until_eof(self):
     187          # Read all bytes until eof.
     188          stream = asyncio.StreamReader(loop=self.loop)
     189          read_task = self.loop.create_task(stream.read(-1))
     190  
     191          def cb():
     192              stream.feed_data(b'chunk1\n')
     193              stream.feed_data(b'chunk2')
     194              stream.feed_eof()
     195          self.loop.call_soon(cb)
     196  
     197          data = self.loop.run_until_complete(read_task)
     198  
     199          self.assertEqual(b'chunk1\nchunk2', data)
     200          self.assertEqual(b'', stream._buffer)
     201  
     202      def test_read_exception(self):
     203          stream = asyncio.StreamReader(loop=self.loop)
     204          stream.feed_data(b'line\n')
     205  
     206          data = self.loop.run_until_complete(stream.read(2))
     207          self.assertEqual(b'li', data)
     208  
     209          stream.set_exception(ValueError())
     210          self.assertRaises(
     211              ValueError, self.loop.run_until_complete, stream.read(2))
     212  
     213      def test_invalid_limit(self):
     214          with self.assertRaisesRegex(ValueError, 'imit'):
     215              asyncio.StreamReader(limit=0, loop=self.loop)
     216  
     217          with self.assertRaisesRegex(ValueError, 'imit'):
     218              asyncio.StreamReader(limit=-1, loop=self.loop)
     219  
     220      def test_read_limit(self):
     221          stream = asyncio.StreamReader(limit=3, loop=self.loop)
     222          stream.feed_data(b'chunk')
     223          data = self.loop.run_until_complete(stream.read(5))
     224          self.assertEqual(b'chunk', data)
     225          self.assertEqual(b'', stream._buffer)
     226  
     227      def test_readline(self):
     228          # Read one line. 'readline' will need to wait for the data
     229          # to come from 'cb'
     230          stream = asyncio.StreamReader(loop=self.loop)
     231          stream.feed_data(b'chunk1 ')
     232          read_task = self.loop.create_task(stream.readline())
     233  
     234          def cb():
     235              stream.feed_data(b'chunk2 ')
     236              stream.feed_data(b'chunk3 ')
     237              stream.feed_data(b'\n chunk4')
     238          self.loop.call_soon(cb)
     239  
     240          line = self.loop.run_until_complete(read_task)
     241          self.assertEqual(b'chunk1 chunk2 chunk3 \n', line)
     242          self.assertEqual(b' chunk4', stream._buffer)
     243  
     244      def test_readline_limit_with_existing_data(self):
     245          # Read one line. The data is in StreamReader's buffer
     246          # before the event loop is run.
     247  
     248          stream = asyncio.StreamReader(limit=3, loop=self.loop)
     249          stream.feed_data(b'li')
     250          stream.feed_data(b'ne1\nline2\n')
     251  
     252          self.assertRaises(
     253              ValueError, self.loop.run_until_complete, stream.readline())
     254          # The buffer should contain the remaining data after exception
     255          self.assertEqual(b'line2\n', stream._buffer)
     256  
     257          stream = asyncio.StreamReader(limit=3, loop=self.loop)
     258          stream.feed_data(b'li')
     259          stream.feed_data(b'ne1')
     260          stream.feed_data(b'li')
     261  
     262          self.assertRaises(
     263              ValueError, self.loop.run_until_complete, stream.readline())
     264          # No b'\n' at the end. The 'limit' is set to 3. So before
     265          # waiting for the new data in buffer, 'readline' will consume
     266          # the entire buffer, and since the length of the consumed data
     267          # is more than 3, it will raise a ValueError. The buffer is
     268          # expected to be empty now.
     269          self.assertEqual(b'', stream._buffer)
     270  
     271      def test_at_eof(self):
     272          stream = asyncio.StreamReader(loop=self.loop)
     273          self.assertFalse(stream.at_eof())
     274  
     275          stream.feed_data(b'some data\n')
     276          self.assertFalse(stream.at_eof())
     277  
     278          self.loop.run_until_complete(stream.readline())
     279          self.assertFalse(stream.at_eof())
     280  
     281          stream.feed_data(b'some data\n')
     282          stream.feed_eof()
     283          self.loop.run_until_complete(stream.readline())
     284          self.assertTrue(stream.at_eof())
     285  
     286      def test_readline_limit(self):
     287          # Read one line. StreamReaders are fed with data after
     288          # their 'readline' methods are called.
     289  
     290          stream = asyncio.StreamReader(limit=7, loop=self.loop)
     291          def cb():
     292              stream.feed_data(b'chunk1')
     293              stream.feed_data(b'chunk2')
     294              stream.feed_data(b'chunk3\n')
     295              stream.feed_eof()
     296          self.loop.call_soon(cb)
     297  
     298          self.assertRaises(
     299              ValueError, self.loop.run_until_complete, stream.readline())
     300          # The buffer had just one line of data, and after raising
     301          # a ValueError it should be empty.
     302          self.assertEqual(b'', stream._buffer)
     303  
     304          stream = asyncio.StreamReader(limit=7, loop=self.loop)
     305          def cb():
     306              stream.feed_data(b'chunk1')
     307              stream.feed_data(b'chunk2\n')
     308              stream.feed_data(b'chunk3\n')
     309              stream.feed_eof()
     310          self.loop.call_soon(cb)
     311  
     312          self.assertRaises(
     313              ValueError, self.loop.run_until_complete, stream.readline())
     314          self.assertEqual(b'chunk3\n', stream._buffer)
     315  
     316          # check strictness of the limit
     317          stream = asyncio.StreamReader(limit=7, loop=self.loop)
     318          stream.feed_data(b'1234567\n')
     319          line = self.loop.run_until_complete(stream.readline())
     320          self.assertEqual(b'1234567\n', line)
     321          self.assertEqual(b'', stream._buffer)
     322  
     323          stream.feed_data(b'12345678\n')
     324          with self.assertRaises(ValueError) as cm:
     325              self.loop.run_until_complete(stream.readline())
     326          self.assertEqual(b'', stream._buffer)
     327  
     328          stream.feed_data(b'12345678')
     329          with self.assertRaises(ValueError) as cm:
     330              self.loop.run_until_complete(stream.readline())
     331          self.assertEqual(b'', stream._buffer)
     332  
     333      def test_readline_nolimit_nowait(self):
     334          # All needed data for the first 'readline' call will be
     335          # in the buffer.
     336          stream = asyncio.StreamReader(loop=self.loop)
     337          stream.feed_data(self.DATA[:6])
     338          stream.feed_data(self.DATA[6:])
     339  
     340          line = self.loop.run_until_complete(stream.readline())
     341  
     342          self.assertEqual(b'line1\n', line)
     343          self.assertEqual(b'line2\nline3\n', stream._buffer)
     344  
     345      def test_readline_eof(self):
     346          stream = asyncio.StreamReader(loop=self.loop)
     347          stream.feed_data(b'some data')
     348          stream.feed_eof()
     349  
     350          line = self.loop.run_until_complete(stream.readline())
     351          self.assertEqual(b'some data', line)
     352  
     353      def test_readline_empty_eof(self):
     354          stream = asyncio.StreamReader(loop=self.loop)
     355          stream.feed_eof()
     356  
     357          line = self.loop.run_until_complete(stream.readline())
     358          self.assertEqual(b'', line)
     359  
     360      def test_readline_read_byte_count(self):
     361          stream = asyncio.StreamReader(loop=self.loop)
     362          stream.feed_data(self.DATA)
     363  
     364          self.loop.run_until_complete(stream.readline())
     365  
     366          data = self.loop.run_until_complete(stream.read(7))
     367  
     368          self.assertEqual(b'line2\nl', data)
     369          self.assertEqual(b'ine3\n', stream._buffer)
     370  
     371      def test_readline_exception(self):
     372          stream = asyncio.StreamReader(loop=self.loop)
     373          stream.feed_data(b'line\n')
     374  
     375          data = self.loop.run_until_complete(stream.readline())
     376          self.assertEqual(b'line\n', data)
     377  
     378          stream.set_exception(ValueError())
     379          self.assertRaises(
     380              ValueError, self.loop.run_until_complete, stream.readline())
     381          self.assertEqual(b'', stream._buffer)
     382  
     383      def test_readuntil_separator(self):
     384          stream = asyncio.StreamReader(loop=self.loop)
     385          with self.assertRaisesRegex(ValueError, 'Separator should be'):
     386              self.loop.run_until_complete(stream.readuntil(separator=b''))
     387  
     388      def test_readuntil_multi_chunks(self):
     389          stream = asyncio.StreamReader(loop=self.loop)
     390  
     391          stream.feed_data(b'lineAAA')
     392          data = self.loop.run_until_complete(stream.readuntil(separator=b'AAA'))
     393          self.assertEqual(b'lineAAA', data)
     394          self.assertEqual(b'', stream._buffer)
     395  
     396          stream.feed_data(b'lineAAA')
     397          data = self.loop.run_until_complete(stream.readuntil(b'AAA'))
     398          self.assertEqual(b'lineAAA', data)
     399          self.assertEqual(b'', stream._buffer)
     400  
     401          stream.feed_data(b'lineAAAxxx')
     402          data = self.loop.run_until_complete(stream.readuntil(b'AAA'))
     403          self.assertEqual(b'lineAAA', data)
     404          self.assertEqual(b'xxx', stream._buffer)
     405  
     406      def test_readuntil_multi_chunks_1(self):
     407          stream = asyncio.StreamReader(loop=self.loop)
     408  
     409          stream.feed_data(b'QWEaa')
     410          stream.feed_data(b'XYaa')
     411          stream.feed_data(b'a')
     412          data = self.loop.run_until_complete(stream.readuntil(b'aaa'))
     413          self.assertEqual(b'QWEaaXYaaa', data)
     414          self.assertEqual(b'', stream._buffer)
     415  
     416          stream.feed_data(b'QWEaa')
     417          stream.feed_data(b'XYa')
     418          stream.feed_data(b'aa')
     419          data = self.loop.run_until_complete(stream.readuntil(b'aaa'))
     420          self.assertEqual(b'QWEaaXYaaa', data)
     421          self.assertEqual(b'', stream._buffer)
     422  
     423          stream.feed_data(b'aaa')
     424          data = self.loop.run_until_complete(stream.readuntil(b'aaa'))
     425          self.assertEqual(b'aaa', data)
     426          self.assertEqual(b'', stream._buffer)
     427  
     428          stream.feed_data(b'Xaaa')
     429          data = self.loop.run_until_complete(stream.readuntil(b'aaa'))
     430          self.assertEqual(b'Xaaa', data)
     431          self.assertEqual(b'', stream._buffer)
     432  
     433          stream.feed_data(b'XXX')
     434          stream.feed_data(b'a')
     435          stream.feed_data(b'a')
     436          stream.feed_data(b'a')
     437          data = self.loop.run_until_complete(stream.readuntil(b'aaa'))
     438          self.assertEqual(b'XXXaaa', data)
     439          self.assertEqual(b'', stream._buffer)
     440  
     441      def test_readuntil_eof(self):
     442          stream = asyncio.StreamReader(loop=self.loop)
     443          data = b'some dataAA'
     444          stream.feed_data(data)
     445          stream.feed_eof()
     446  
     447          with self.assertRaisesRegex(asyncio.IncompleteReadError,
     448                                      'undefined expected bytes') as cm:
     449              self.loop.run_until_complete(stream.readuntil(b'AAA'))
     450          self.assertEqual(cm.exception.partial, data)
     451          self.assertIsNone(cm.exception.expected)
     452          self.assertEqual(b'', stream._buffer)
     453  
     454      def test_readuntil_limit_found_sep(self):
     455          stream = asyncio.StreamReader(loop=self.loop, limit=3)
     456          stream.feed_data(b'some dataAA')
     457          with self.assertRaisesRegex(asyncio.LimitOverrunError,
     458                                      'not found') as cm:
     459              self.loop.run_until_complete(stream.readuntil(b'AAA'))
     460  
     461          self.assertEqual(b'some dataAA', stream._buffer)
     462  
     463          stream.feed_data(b'A')
     464          with self.assertRaisesRegex(asyncio.LimitOverrunError,
     465                                      'is found') as cm:
     466              self.loop.run_until_complete(stream.readuntil(b'AAA'))
     467  
     468          self.assertEqual(b'some dataAAA', stream._buffer)
     469  
     470      def test_readexactly_zero_or_less(self):
     471          # Read exact number of bytes (zero or less).
     472          stream = asyncio.StreamReader(loop=self.loop)
     473          stream.feed_data(self.DATA)
     474  
     475          data = self.loop.run_until_complete(stream.readexactly(0))
     476          self.assertEqual(b'', data)
     477          self.assertEqual(self.DATA, stream._buffer)
     478  
     479          with self.assertRaisesRegex(ValueError, 'less than zero'):
     480              self.loop.run_until_complete(stream.readexactly(-1))
     481          self.assertEqual(self.DATA, stream._buffer)
     482  
     483      def test_readexactly(self):
     484          # Read exact number of bytes.
     485          stream = asyncio.StreamReader(loop=self.loop)
     486  
     487          n = 2 * len(self.DATA)
     488          read_task = self.loop.create_task(stream.readexactly(n))
     489  
     490          def cb():
     491              stream.feed_data(self.DATA)
     492              stream.feed_data(self.DATA)
     493              stream.feed_data(self.DATA)
     494          self.loop.call_soon(cb)
     495  
     496          data = self.loop.run_until_complete(read_task)
     497          self.assertEqual(self.DATA + self.DATA, data)
     498          self.assertEqual(self.DATA, stream._buffer)
     499  
     500      def test_readexactly_limit(self):
     501          stream = asyncio.StreamReader(limit=3, loop=self.loop)
     502          stream.feed_data(b'chunk')
     503          data = self.loop.run_until_complete(stream.readexactly(5))
     504          self.assertEqual(b'chunk', data)
     505          self.assertEqual(b'', stream._buffer)
     506  
     507      def test_readexactly_eof(self):
     508          # Read exact number of bytes (eof).
     509          stream = asyncio.StreamReader(loop=self.loop)
     510          n = 2 * len(self.DATA)
     511          read_task = self.loop.create_task(stream.readexactly(n))
     512  
     513          def cb():
     514              stream.feed_data(self.DATA)
     515              stream.feed_eof()
     516          self.loop.call_soon(cb)
     517  
     518          with self.assertRaises(asyncio.IncompleteReadError) as cm:
     519              self.loop.run_until_complete(read_task)
     520          self.assertEqual(cm.exception.partial, self.DATA)
     521          self.assertEqual(cm.exception.expected, n)
     522          self.assertEqual(str(cm.exception),
     523                           '18 bytes read on a total of 36 expected bytes')
     524          self.assertEqual(b'', stream._buffer)
     525  
     526      def test_readexactly_exception(self):
     527          stream = asyncio.StreamReader(loop=self.loop)
     528          stream.feed_data(b'line\n')
     529  
     530          data = self.loop.run_until_complete(stream.readexactly(2))
     531          self.assertEqual(b'li', data)
     532  
     533          stream.set_exception(ValueError())
     534          self.assertRaises(
     535              ValueError, self.loop.run_until_complete, stream.readexactly(2))
     536  
     537      def test_exception(self):
     538          stream = asyncio.StreamReader(loop=self.loop)
     539          self.assertIsNone(stream.exception())
     540  
     541          exc = ValueError()
     542          stream.set_exception(exc)
     543          self.assertIs(stream.exception(), exc)
     544  
     545      def test_exception_waiter(self):
     546          stream = asyncio.StreamReader(loop=self.loop)
     547  
     548          async def set_err():
     549              stream.set_exception(ValueError())
     550  
     551          t1 = self.loop.create_task(stream.readline())
     552          t2 = self.loop.create_task(set_err())
     553  
     554          self.loop.run_until_complete(asyncio.wait([t1, t2]))
     555  
     556          self.assertRaises(ValueError, t1.result)
     557  
     558      def test_exception_cancel(self):
     559          stream = asyncio.StreamReader(loop=self.loop)
     560  
     561          t = self.loop.create_task(stream.readline())
     562          test_utils.run_briefly(self.loop)
     563          t.cancel()
     564          test_utils.run_briefly(self.loop)
     565          # The following line fails if set_exception() isn't careful.
     566          stream.set_exception(RuntimeError('message'))
     567          test_utils.run_briefly(self.loop)
     568          self.assertIs(stream._waiter, None)
     569  
     570      def test_start_server(self):
     571  
     572          class ESC[4;38;5;81mMyServer:
     573  
     574              def __init__(self, loop):
     575                  self.server = None
     576                  self.loop = loop
     577  
     578              async def handle_client(self, client_reader, client_writer):
     579                  data = await client_reader.readline()
     580                  client_writer.write(data)
     581                  await client_writer.drain()
     582                  client_writer.close()
     583                  await client_writer.wait_closed()
     584  
     585              def start(self):
     586                  sock = socket.create_server(('127.0.0.1', 0))
     587                  self.server = self.loop.run_until_complete(
     588                      asyncio.start_server(self.handle_client,
     589                                           sock=sock))
     590                  return sock.getsockname()
     591  
     592              def handle_client_callback(self, client_reader, client_writer):
     593                  self.loop.create_task(self.handle_client(client_reader,
     594                                                           client_writer))
     595  
     596              def start_callback(self):
     597                  sock = socket.create_server(('127.0.0.1', 0))
     598                  addr = sock.getsockname()
     599                  sock.close()
     600                  self.server = self.loop.run_until_complete(
     601                      asyncio.start_server(self.handle_client_callback,
     602                                           host=addr[0], port=addr[1]))
     603                  return addr
     604  
     605              def stop(self):
     606                  if self.server is not None:
     607                      self.server.close()
     608                      self.loop.run_until_complete(self.server.wait_closed())
     609                      self.server = None
     610  
     611          async def client(addr):
     612              reader, writer = await asyncio.open_connection(*addr)
     613              # send a line
     614              writer.write(b"hello world!\n")
     615              # read it back
     616              msgback = await reader.readline()
     617              writer.close()
     618              await writer.wait_closed()
     619              return msgback
     620  
     621          messages = []
     622          self.loop.set_exception_handler(lambda loop, ctx: messages.append(ctx))
     623  
     624          # test the server variant with a coroutine as client handler
     625          server = MyServer(self.loop)
     626          addr = server.start()
     627          msg = self.loop.run_until_complete(self.loop.create_task(client(addr)))
     628          server.stop()
     629          self.assertEqual(msg, b"hello world!\n")
     630  
     631          # test the server variant with a callback as client handler
     632          server = MyServer(self.loop)
     633          addr = server.start_callback()
     634          msg = self.loop.run_until_complete(self.loop.create_task(client(addr)))
     635          server.stop()
     636          self.assertEqual(msg, b"hello world!\n")
     637  
     638          self.assertEqual(messages, [])
     639  
     640      @socket_helper.skip_unless_bind_unix_socket
     641      def test_start_unix_server(self):
     642  
     643          class ESC[4;38;5;81mMyServer:
     644  
     645              def __init__(self, loop, path):
     646                  self.server = None
     647                  self.loop = loop
     648                  self.path = path
     649  
     650              async def handle_client(self, client_reader, client_writer):
     651                  data = await client_reader.readline()
     652                  client_writer.write(data)
     653                  await client_writer.drain()
     654                  client_writer.close()
     655                  await client_writer.wait_closed()
     656  
     657              def start(self):
     658                  self.server = self.loop.run_until_complete(
     659                      asyncio.start_unix_server(self.handle_client,
     660                                                path=self.path))
     661  
     662              def handle_client_callback(self, client_reader, client_writer):
     663                  self.loop.create_task(self.handle_client(client_reader,
     664                                                           client_writer))
     665  
     666              def start_callback(self):
     667                  start = asyncio.start_unix_server(self.handle_client_callback,
     668                                                    path=self.path)
     669                  self.server = self.loop.run_until_complete(start)
     670  
     671              def stop(self):
     672                  if self.server is not None:
     673                      self.server.close()
     674                      self.loop.run_until_complete(self.server.wait_closed())
     675                      self.server = None
     676  
     677          async def client(path):
     678              reader, writer = await asyncio.open_unix_connection(path)
     679              # send a line
     680              writer.write(b"hello world!\n")
     681              # read it back
     682              msgback = await reader.readline()
     683              writer.close()
     684              await writer.wait_closed()
     685              return msgback
     686  
     687          messages = []
     688          self.loop.set_exception_handler(lambda loop, ctx: messages.append(ctx))
     689  
     690          # test the server variant with a coroutine as client handler
     691          with test_utils.unix_socket_path() as path:
     692              server = MyServer(self.loop, path)
     693              server.start()
     694              msg = self.loop.run_until_complete(
     695                  self.loop.create_task(client(path)))
     696              server.stop()
     697              self.assertEqual(msg, b"hello world!\n")
     698  
     699          # test the server variant with a callback as client handler
     700          with test_utils.unix_socket_path() as path:
     701              server = MyServer(self.loop, path)
     702              server.start_callback()
     703              msg = self.loop.run_until_complete(
     704                  self.loop.create_task(client(path)))
     705              server.stop()
     706              self.assertEqual(msg, b"hello world!\n")
     707  
     708          self.assertEqual(messages, [])
     709  
     710      @unittest.skipIf(ssl is None, 'No ssl module')
     711      def test_start_tls(self):
     712  
     713          class ESC[4;38;5;81mMyServer:
     714  
     715              def __init__(self, loop):
     716                  self.server = None
     717                  self.loop = loop
     718  
     719              async def handle_client(self, client_reader, client_writer):
     720                  data1 = await client_reader.readline()
     721                  client_writer.write(data1)
     722                  await client_writer.drain()
     723                  assert client_writer.get_extra_info('sslcontext') is None
     724                  await client_writer.start_tls(
     725                      test_utils.simple_server_sslcontext())
     726                  assert client_writer.get_extra_info('sslcontext') is not None
     727                  data2 = await client_reader.readline()
     728                  client_writer.write(data2)
     729                  await client_writer.drain()
     730                  client_writer.close()
     731                  await client_writer.wait_closed()
     732  
     733              def start(self):
     734                  sock = socket.create_server(('127.0.0.1', 0))
     735                  self.server = self.loop.run_until_complete(
     736                      asyncio.start_server(self.handle_client,
     737                                           sock=sock))
     738                  return sock.getsockname()
     739  
     740              def stop(self):
     741                  if self.server is not None:
     742                      self.server.close()
     743                      self.loop.run_until_complete(self.server.wait_closed())
     744                      self.server = None
     745  
     746          async def client(addr):
     747              reader, writer = await asyncio.open_connection(*addr)
     748              writer.write(b"hello world 1!\n")
     749              await writer.drain()
     750              msgback1 = await reader.readline()
     751              assert writer.get_extra_info('sslcontext') is None
     752              await writer.start_tls(test_utils.simple_client_sslcontext())
     753              assert writer.get_extra_info('sslcontext') is not None
     754              writer.write(b"hello world 2!\n")
     755              await writer.drain()
     756              msgback2 = await reader.readline()
     757              writer.close()
     758              await writer.wait_closed()
     759              return msgback1, msgback2
     760  
     761          messages = []
     762          self.loop.set_exception_handler(lambda loop, ctx: messages.append(ctx))
     763  
     764          server = MyServer(self.loop)
     765          addr = server.start()
     766          msg1, msg2 = self.loop.run_until_complete(client(addr))
     767          server.stop()
     768  
     769          self.assertEqual(messages, [])
     770          self.assertEqual(msg1, b"hello world 1!\n")
     771          self.assertEqual(msg2, b"hello world 2!\n")
     772  
     773      @unittest.skipIf(sys.platform == 'win32', "Don't have pipes")
     774      def test_read_all_from_pipe_reader(self):
     775          # See asyncio issue 168.  This test is derived from the example
     776          # subprocess_attach_read_pipe.py, but we configure the
     777          # StreamReader's limit so that twice it is less than the size
     778          # of the data writer.  Also we must explicitly attach a child
     779          # watcher to the event loop.
     780  
     781          code = """\
     782  import os, sys
     783  fd = int(sys.argv[1])
     784  os.write(fd, b'data')
     785  os.close(fd)
     786  """
     787          rfd, wfd = os.pipe()
     788          args = [sys.executable, '-c', code, str(wfd)]
     789  
     790          pipe = open(rfd, 'rb', 0)
     791          reader = asyncio.StreamReader(loop=self.loop, limit=1)
     792          protocol = asyncio.StreamReaderProtocol(reader, loop=self.loop)
     793          transport, _ = self.loop.run_until_complete(
     794              self.loop.connect_read_pipe(lambda: protocol, pipe))
     795          with warnings.catch_warnings():
     796              warnings.simplefilter('ignore', DeprecationWarning)
     797              watcher = asyncio.SafeChildWatcher()
     798          watcher.attach_loop(self.loop)
     799          try:
     800              with warnings.catch_warnings():
     801                  warnings.simplefilter('ignore', DeprecationWarning)
     802                  asyncio.set_child_watcher(watcher)
     803              create = asyncio.create_subprocess_exec(
     804                  *args,
     805                  pass_fds={wfd},
     806              )
     807              proc = self.loop.run_until_complete(create)
     808              self.loop.run_until_complete(proc.wait())
     809          finally:
     810              with warnings.catch_warnings():
     811                  warnings.simplefilter('ignore', DeprecationWarning)
     812                  asyncio.set_child_watcher(None)
     813  
     814          os.close(wfd)
     815          data = self.loop.run_until_complete(reader.read(-1))
     816          self.assertEqual(data, b'data')
     817  
     818      def test_streamreader_constructor_without_loop(self):
     819          with self.assertRaisesRegex(RuntimeError, 'no current event loop'):
     820              asyncio.StreamReader()
     821  
     822      def test_streamreader_constructor_use_running_loop(self):
     823          # asyncio issue #184: Ensure that StreamReaderProtocol constructor
     824          # retrieves the current loop if the loop parameter is not set
     825          async def test():
     826              return asyncio.StreamReader()
     827  
     828          reader = self.loop.run_until_complete(test())
     829          self.assertIs(reader._loop, self.loop)
     830  
     831      def test_streamreader_constructor_use_global_loop(self):
     832          # asyncio issue #184: Ensure that StreamReaderProtocol constructor
     833          # retrieves the current loop if the loop parameter is not set
     834          # Deprecated in 3.10, undeprecated in 3.12
     835          self.addCleanup(asyncio.set_event_loop, None)
     836          asyncio.set_event_loop(self.loop)
     837          reader = asyncio.StreamReader()
     838          self.assertIs(reader._loop, self.loop)
     839  
     840  
     841      def test_streamreaderprotocol_constructor_without_loop(self):
     842          reader = mock.Mock()
     843          with self.assertRaisesRegex(RuntimeError, 'no current event loop'):
     844              asyncio.StreamReaderProtocol(reader)
     845  
     846      def test_streamreaderprotocol_constructor_use_running_loop(self):
     847          # asyncio issue #184: Ensure that StreamReaderProtocol constructor
     848          # retrieves the current loop if the loop parameter is not set
     849          reader = mock.Mock()
     850          async def test():
     851              return asyncio.StreamReaderProtocol(reader)
     852          protocol = self.loop.run_until_complete(test())
     853          self.assertIs(protocol._loop, self.loop)
     854  
     855      def test_streamreaderprotocol_constructor_use_global_loop(self):
     856          # asyncio issue #184: Ensure that StreamReaderProtocol constructor
     857          # retrieves the current loop if the loop parameter is not set
     858          # Deprecated in 3.10, undeprecated in 3.12
     859          self.addCleanup(asyncio.set_event_loop, None)
     860          asyncio.set_event_loop(self.loop)
     861          reader = mock.Mock()
     862          protocol = asyncio.StreamReaderProtocol(reader)
     863          self.assertIs(protocol._loop, self.loop)
     864  
     865      def test_multiple_drain(self):
     866          # See https://github.com/python/cpython/issues/74116
     867          drained = 0
     868  
     869          async def drainer(stream):
     870              nonlocal drained
     871              await stream._drain_helper()
     872              drained += 1
     873  
     874          async def main():
     875              loop = asyncio.get_running_loop()
     876              stream = asyncio.streams.FlowControlMixin(loop)
     877              stream.pause_writing()
     878              loop.call_later(0.1, stream.resume_writing)
     879              await asyncio.gather(*[drainer(stream) for _ in range(10)])
     880              self.assertEqual(drained, 10)
     881  
     882          self.loop.run_until_complete(main())
     883  
     884      def test_drain_raises(self):
     885          # See http://bugs.python.org/issue25441
     886  
     887          # This test should not use asyncio for the mock server; the
     888          # whole point of the test is to test for a bug in drain()
     889          # where it never gives up the event loop but the socket is
     890          # closed on the  server side.
     891  
     892          messages = []
     893          self.loop.set_exception_handler(lambda loop, ctx: messages.append(ctx))
     894          q = queue.Queue()
     895  
     896          def server():
     897              # Runs in a separate thread.
     898              with socket.create_server(('localhost', 0)) as sock:
     899                  addr = sock.getsockname()
     900                  q.put(addr)
     901                  clt, _ = sock.accept()
     902                  clt.close()
     903  
     904          async def client(host, port):
     905              reader, writer = await asyncio.open_connection(host, port)
     906  
     907              while True:
     908                  writer.write(b"foo\n")
     909                  await writer.drain()
     910  
     911          # Start the server thread and wait for it to be listening.
     912          thread = threading.Thread(target=server)
     913          thread.daemon = True
     914          thread.start()
     915          addr = q.get()
     916  
     917          # Should not be stuck in an infinite loop.
     918          with self.assertRaises((ConnectionResetError, ConnectionAbortedError,
     919                                  BrokenPipeError)):
     920              self.loop.run_until_complete(client(*addr))
     921  
     922          # Clean up the thread.  (Only on success; on failure, it may
     923          # be stuck in accept().)
     924          thread.join()
     925          self.assertEqual([], messages)
     926  
     927      def test___repr__(self):
     928          stream = asyncio.StreamReader(loop=self.loop)
     929          self.assertEqual("<StreamReader>", repr(stream))
     930  
     931      def test___repr__nondefault_limit(self):
     932          stream = asyncio.StreamReader(loop=self.loop, limit=123)
     933          self.assertEqual("<StreamReader limit=123>", repr(stream))
     934  
     935      def test___repr__eof(self):
     936          stream = asyncio.StreamReader(loop=self.loop)
     937          stream.feed_eof()
     938          self.assertEqual("<StreamReader eof>", repr(stream))
     939  
     940      def test___repr__data(self):
     941          stream = asyncio.StreamReader(loop=self.loop)
     942          stream.feed_data(b'data')
     943          self.assertEqual("<StreamReader 4 bytes>", repr(stream))
     944  
     945      def test___repr__exception(self):
     946          stream = asyncio.StreamReader(loop=self.loop)
     947          exc = RuntimeError()
     948          stream.set_exception(exc)
     949          self.assertEqual("<StreamReader exception=RuntimeError()>",
     950                           repr(stream))
     951  
     952      def test___repr__waiter(self):
     953          stream = asyncio.StreamReader(loop=self.loop)
     954          stream._waiter = asyncio.Future(loop=self.loop)
     955          self.assertRegex(
     956              repr(stream),
     957              r"<StreamReader waiter=<Future pending[\S ]*>>")
     958          stream._waiter.set_result(None)
     959          self.loop.run_until_complete(stream._waiter)
     960          stream._waiter = None
     961          self.assertEqual("<StreamReader>", repr(stream))
     962  
     963      def test___repr__transport(self):
     964          stream = asyncio.StreamReader(loop=self.loop)
     965          stream._transport = mock.Mock()
     966          stream._transport.__repr__ = mock.Mock()
     967          stream._transport.__repr__.return_value = "<Transport>"
     968          self.assertEqual("<StreamReader transport=<Transport>>", repr(stream))
     969  
     970      def test_IncompleteReadError_pickleable(self):
     971          e = asyncio.IncompleteReadError(b'abc', 10)
     972          for proto in range(pickle.HIGHEST_PROTOCOL + 1):
     973              with self.subTest(pickle_protocol=proto):
     974                  e2 = pickle.loads(pickle.dumps(e, protocol=proto))
     975                  self.assertEqual(str(e), str(e2))
     976                  self.assertEqual(e.partial, e2.partial)
     977                  self.assertEqual(e.expected, e2.expected)
     978  
     979      def test_LimitOverrunError_pickleable(self):
     980          e = asyncio.LimitOverrunError('message', 10)
     981          for proto in range(pickle.HIGHEST_PROTOCOL + 1):
     982              with self.subTest(pickle_protocol=proto):
     983                  e2 = pickle.loads(pickle.dumps(e, protocol=proto))
     984                  self.assertEqual(str(e), str(e2))
     985                  self.assertEqual(e.consumed, e2.consumed)
     986  
     987      def test_wait_closed_on_close(self):
     988          with test_utils.run_test_server() as httpd:
     989              rd, wr = self.loop.run_until_complete(
     990                  asyncio.open_connection(*httpd.address))
     991  
     992              wr.write(b'GET / HTTP/1.0\r\n\r\n')
     993              f = rd.readline()
     994              data = self.loop.run_until_complete(f)
     995              self.assertEqual(data, b'HTTP/1.0 200 OK\r\n')
     996              f = rd.read()
     997              data = self.loop.run_until_complete(f)
     998              self.assertTrue(data.endswith(b'\r\n\r\nTest message'))
     999              self.assertFalse(wr.is_closing())
    1000              wr.close()
    1001              self.assertTrue(wr.is_closing())
    1002              self.loop.run_until_complete(wr.wait_closed())
    1003  
    def test_wait_closed_on_close_with_unread_data(self):
        # Only the status line is read before the writer is closed;
        # wait_closed() must still complete.
        with test_utils.run_test_server() as httpd:
            connecting = asyncio.open_connection(*httpd.address)
            reader, writer = self.loop.run_until_complete(connecting)

            writer.write(b'GET / HTTP/1.0\r\n\r\n')
            status_line = self.loop.run_until_complete(reader.readline())
            self.assertEqual(status_line, b'HTTP/1.0 200 OK\r\n')

            writer.close()
            self.loop.run_until_complete(writer.wait_closed())
    1015  
    def test_async_writer_api(self):
        # Run a full request/response exchange from inside a coroutine and
        # check that nothing reaches the loop's exception handler.
        errors = []
        self.loop.set_exception_handler(
            lambda loop, context: errors.append(context))

        async def inner(httpd):
            reader, writer = await asyncio.open_connection(*httpd.address)

            writer.write(b'GET / HTTP/1.0\r\n\r\n')
            status_line = await reader.readline()
            self.assertEqual(status_line, b'HTTP/1.0 200 OK\r\n')
            remainder = await reader.read()
            self.assertTrue(remainder.endswith(b'\r\n\r\nTest message'))

            writer.close()
            await writer.wait_closed()

        with test_utils.run_test_server() as httpd:
            self.loop.run_until_complete(inner(httpd))

        self.assertEqual(errors, [])
    1035  
    def test_async_writer_api_exception_after_close(self):
        # Writing to and draining a writer that was already closed must
        # raise ConnectionResetError, and nothing may reach the loop's
        # exception handler.
        errors = []
        self.loop.set_exception_handler(
            lambda loop, context: errors.append(context))

        async def inner(httpd):
            reader, writer = await asyncio.open_connection(*httpd.address)

            writer.write(b'GET / HTTP/1.0\r\n\r\n')
            status_line = await reader.readline()
            self.assertEqual(status_line, b'HTTP/1.0 200 OK\r\n')
            remainder = await reader.read()
            self.assertTrue(remainder.endswith(b'\r\n\r\nTest message'))

            writer.close()
            with self.assertRaises(ConnectionResetError):
                writer.write(b'data')
                await writer.drain()

        with test_utils.run_test_server() as httpd:
            self.loop.run_until_complete(inner(httpd))

        self.assertEqual(errors, [])
    1057  
    def test_eof_feed_when_closing_writer(self):
        # See http://bugs.python.org/issue35065
        # Once the writer has been closed the reader must be at EOF and a
        # read must return empty bytes, with nothing reported to the
        # loop's exception handler.
        errors = []
        self.loop.set_exception_handler(
            lambda loop, context: errors.append(context))

        with test_utils.run_test_server() as httpd:
            connecting = asyncio.open_connection(*httpd.address)
            reader, writer = self.loop.run_until_complete(connecting)

            writer.close()
            self.loop.run_until_complete(writer.wait_closed())
            self.assertTrue(reader.at_eof())

            leftover = self.loop.run_until_complete(reader.read())
            self.assertEqual(leftover, b'')

        self.assertEqual(errors, [])
    1076  
    1077  
# Run the tests in this module when the file is executed as a script.
if __name__ == '__main__':
    unittest.main()