python (3.12.0)
1 __all__ = (
2 'StreamReader', 'StreamWriter', 'StreamReaderProtocol',
3 'open_connection', 'start_server')
4
5 import collections
6 import socket
7 import sys
8 import weakref
9
10 if hasattr(socket, 'AF_UNIX'):
11 __all__ += ('open_unix_connection', 'start_unix_server')
12
13 from . import coroutines
14 from . import events
15 from . import exceptions
16 from . import format_helpers
17 from . import protocols
18 from .log import logger
19 from .tasks import sleep
20
21
22 _DEFAULT_LIMIT = 2 ** 16 # 64 KiB
23
24
async def open_connection(host=None, port=None, *,
                          limit=_DEFAULT_LIMIT, **kwds):
    """A wrapper for create_connection() returning a (reader, writer) pair.

    The reader returned is a StreamReader instance; the writer is a
    StreamWriter instance.

    The arguments are all the usual arguments to create_connection()
    except protocol_factory; most common are positional host and port,
    with various optional keyword arguments following.

    The additional keyword-only argument `limit` sets the buffer limit
    passed to the StreamReader.

    (If you want to customize the StreamReader and/or
    StreamReaderProtocol classes, just copy the code -- there's
    really nothing special here except some convenience.)
    """
    event_loop = events.get_running_loop()
    stream_reader = StreamReader(limit=limit, loop=event_loop)
    proto = StreamReaderProtocol(stream_reader, loop=event_loop)
    transport, _ = await event_loop.create_connection(
        lambda: proto, host, port, **kwds)
    stream_writer = StreamWriter(transport, proto, stream_reader, event_loop)
    return stream_reader, stream_writer
51
52
async def start_server(client_connected_cb, host=None, port=None, *,
                       limit=_DEFAULT_LIMIT, **kwds):
    """Start a socket server, call back for each client connected.

    The first parameter, `client_connected_cb`, takes two parameters:
    client_reader, client_writer. client_reader is a StreamReader
    object, while client_writer is a StreamWriter object. This
    parameter can either be a plain callback function or a coroutine;
    if it is a coroutine, it will be automatically converted into a
    Task.

    The rest of the arguments are all the usual arguments to
    loop.create_server() except protocol_factory; most common are
    positional host and port, with various optional keyword arguments
    following.

    The additional keyword-only argument `limit` sets the buffer limit
    passed to the StreamReader.

    The return value is the same as loop.create_server(), i.e. a
    Server object which can be used to stop the service.
    """
    event_loop = events.get_running_loop()

    def protocol_factory():
        # One fresh reader/protocol pair per accepted connection.
        stream_reader = StreamReader(limit=limit, loop=event_loop)
        return StreamReaderProtocol(stream_reader, client_connected_cb,
                                    loop=event_loop)

    return await event_loop.create_server(
        protocol_factory, host, port, **kwds)
85
86
if hasattr(socket, 'AF_UNIX'):
    # UNIX Domain Sockets are supported on this platform

    async def open_unix_connection(path=None, *,
                                   limit=_DEFAULT_LIMIT, **kwds):
        """Similar to `open_connection` but works with UNIX Domain Sockets."""
        event_loop = events.get_running_loop()

        stream_reader = StreamReader(limit=limit, loop=event_loop)
        proto = StreamReaderProtocol(stream_reader, loop=event_loop)
        transport, _ = await event_loop.create_unix_connection(
            lambda: proto, path, **kwds)
        stream_writer = StreamWriter(transport, proto, stream_reader,
                                     event_loop)
        return stream_reader, stream_writer

    async def start_unix_server(client_connected_cb, path=None, *,
                                limit=_DEFAULT_LIMIT, **kwds):
        """Similar to `start_server` but works with UNIX Domain Sockets."""
        event_loop = events.get_running_loop()

        def protocol_factory():
            # One fresh reader/protocol pair per accepted connection.
            stream_reader = StreamReader(limit=limit, loop=event_loop)
            return StreamReaderProtocol(stream_reader, client_connected_cb,
                                        loop=event_loop)

        return await event_loop.create_unix_server(
            protocol_factory, path, **kwds)
114
115
class FlowControlMixin(protocols.Protocol):
    """Reusable flow control logic for StreamWriter.drain().

    This implements the protocol methods pause_writing(),
    resume_writing() and connection_lost(). If the subclass overrides
    these it must call the super methods.

    StreamWriter.drain() must wait for _drain_helper() coroutine.
    """

    def __init__(self, loop=None):
        self._loop = events.get_event_loop() if loop is None else loop
        self._paused = False
        self._drain_waiters = collections.deque()
        self._connection_lost = False

    def pause_writing(self):
        assert not self._paused
        self._paused = True
        if self._loop.get_debug():
            logger.debug("%r pauses writing", self)

    def resume_writing(self):
        assert self._paused
        self._paused = False
        if self._loop.get_debug():
            logger.debug("%r resumes writing", self)

        # Release every coroutine currently blocked in _drain_helper().
        for fut in self._drain_waiters:
            if fut.done():
                continue
            fut.set_result(None)

    def connection_lost(self, exc):
        self._connection_lost = True
        # Only paused writers can be stuck in drain(); nothing to do otherwise.
        if not self._paused:
            return

        for fut in self._drain_waiters:
            if fut.done():
                continue
            if exc is None:
                fut.set_result(None)
            else:
                fut.set_exception(exc)

    async def _drain_helper(self):
        # Fail fast once the connection is gone.
        if self._connection_lost:
            raise ConnectionResetError('Connection lost')
        if not self._paused:
            return
        # Park on a fresh future until resume_writing()/connection_lost()
        # completes it; always unregister, even on cancellation.
        fut = self._loop.create_future()
        self._drain_waiters.append(fut)
        try:
            await fut
        finally:
            self._drain_waiters.remove(fut)

    def _get_close_waiter(self, stream):
        # Subclasses must provide the future awaited by wait_closed().
        raise NotImplementedError
178
179
class StreamReaderProtocol(FlowControlMixin, protocols.Protocol):
    """Helper class to adapt between Protocol and StreamReader.

    (This is a helper class instead of making StreamReader itself a
    Protocol subclass, because the StreamReader has other potential
    uses, and to prevent the user of the StreamReader to accidentally
    call inappropriate methods of the protocol.)
    """

    # Overwritten per-instance in __init__ from the reader when the
    # reader was created with loop debug enabled.
    _source_traceback = None

    def __init__(self, stream_reader, client_connected_cb=None, loop=None):
        super().__init__(loop=loop)
        if stream_reader is not None:
            # Hold the reader weakly so an abandoned stream can be
            # garbage collected; see connection_made()'s reject path.
            self._stream_reader_wr = weakref.ref(stream_reader)
            self._source_traceback = stream_reader._source_traceback
        else:
            self._stream_reader_wr = None
        if client_connected_cb is not None:
            # This is a stream created by the `create_server()` function.
            # Keep a strong reference to the reader until a connection
            # is established.
            self._strong_reader = stream_reader
        self._reject_connection = False
        self._stream_writer = None
        self._task = None
        self._transport = None
        self._client_connected_cb = client_connected_cb
        self._over_ssl = False
        self._closed = self._loop.create_future()

    @property
    def _stream_reader(self):
        """Dereference the reader weakref; None if it was collected."""
        if self._stream_reader_wr is None:
            return None
        return self._stream_reader_wr()

    def _replace_writer(self, writer):
        """Adopt *writer* (and its transport) after a start_tls() upgrade."""
        # Fix: dropped an unused local binding of self._loop here.
        transport = writer.transport
        self._stream_writer = writer
        self._transport = transport
        self._over_ssl = transport.get_extra_info('sslcontext') is not None

    def connection_made(self, transport):
        if self._reject_connection:
            # The stream object was garbage collected before the
            # connection was established: report and abort.
            context = {
                'message': ('An open stream was garbage collected prior to '
                            'establishing network connection; '
                            'call "stream.close()" explicitly.')
            }
            if self._source_traceback:
                context['source_traceback'] = self._source_traceback
            self._loop.call_exception_handler(context)
            transport.abort()
            return
        self._transport = transport
        reader = self._stream_reader
        if reader is not None:
            reader.set_transport(transport)
        self._over_ssl = transport.get_extra_info('sslcontext') is not None
        if self._client_connected_cb is not None:
            # Server-side stream: build the writer and invoke the
            # user callback (wrapping a coroutine result in a Task).
            self._stream_writer = StreamWriter(transport, self,
                                               reader,
                                               self._loop)
            res = self._client_connected_cb(reader,
                                            self._stream_writer)
            if coroutines.iscoroutine(res):
                self._task = self._loop.create_task(res)
            # Connection established: the weakref alone is enough now.
            self._strong_reader = None

    def connection_lost(self, exc):
        reader = self._stream_reader
        if reader is not None:
            # Propagate EOF or the error to any pending read*() call.
            if exc is None:
                reader.feed_eof()
            else:
                reader.set_exception(exc)
        if not self._closed.done():
            if exc is None:
                self._closed.set_result(None)
            else:
                self._closed.set_exception(exc)
        super().connection_lost(exc)
        # Break reference cycles so the protocol can be collected.
        self._stream_reader_wr = None
        self._stream_writer = None
        self._task = None
        self._transport = None

    def data_received(self, data):
        reader = self._stream_reader
        if reader is not None:
            reader.feed_data(data)

    def eof_received(self):
        reader = self._stream_reader
        if reader is not None:
            reader.feed_eof()
        if self._over_ssl:
            # Prevent a warning in SSLProtocol.eof_received:
            # "returning true from eof_received()
            # has no effect when using ssl"
            return False
        # Keep the transport open for writing after the peer's EOF.
        return True

    def _get_close_waiter(self, stream):
        """Future awaited by StreamWriter.wait_closed()."""
        return self._closed

    def __del__(self):
        # Prevent reports about unhandled exceptions.
        # Better than self._closed._log_traceback = False hack
        try:
            closed = self._closed
        except AttributeError:
            pass  # failed constructor
        else:
            if closed.done() and not closed.cancelled():
                closed.exception()
298
299
class StreamWriter:
    """Wraps a Transport.

    This exposes write(), writelines(), [can_]write_eof(),
    get_extra_info() and close(). It adds drain() which returns an
    optional Future on which you can wait for flow control. It also
    adds a transport property which references the Transport
    directly.
    """

    def __init__(self, transport, protocol, reader, loop):
        # drain() expects that the reader has an exception() method
        assert reader is None or isinstance(reader, StreamReader)
        self._transport = transport
        self._protocol = protocol
        self._reader = reader
        self._loop = loop
        self._complete_fut = self._loop.create_future()
        self._complete_fut.set_result(None)

    def __repr__(self):
        parts = [self.__class__.__name__, f'transport={self._transport!r}']
        if self._reader is not None:
            parts.append(f'reader={self._reader!r}')
        joined = ' '.join(parts)
        return f'<{joined}>'

    @property
    def transport(self):
        """The underlying Transport object."""
        return self._transport

    def write(self, data):
        self._transport.write(data)

    def writelines(self, data):
        self._transport.writelines(data)

    def write_eof(self):
        return self._transport.write_eof()

    def can_write_eof(self):
        return self._transport.can_write_eof()

    def close(self):
        return self._transport.close()

    def is_closing(self):
        return self._transport.is_closing()

    async def wait_closed(self):
        await self._protocol._get_close_waiter(self)

    def get_extra_info(self, name, default=None):
        return self._transport.get_extra_info(name, default)

    async def drain(self):
        """Flush the write buffer.

        The intended use is to write

          w.write(data)
          await w.drain()
        """
        if self._reader is not None:
            pending_exc = self._reader.exception()
            if pending_exc is not None:
                raise pending_exc
        if self._transport.is_closing():
            # Wait for protocol.connection_lost() call
            # Raise connection closing error if any,
            # ConnectionResetError otherwise
            # Yield to the event loop so connection_lost() may be
            # called. Without this, _drain_helper() would return
            # immediately, and code that calls
            #     write(...); await drain()
            # in a loop would never call connection_lost(), so it
            # would not see an error when the socket is closed.
            await sleep(0)
        await self._protocol._drain_helper()

    async def start_tls(self, sslcontext, *,
                        server_hostname=None,
                        ssl_handshake_timeout=None,
                        ssl_shutdown_timeout=None):
        """Upgrade an existing stream-based connection to TLS."""
        protocol = self._protocol
        # A client_connected_cb means this stream came from start_server().
        server_side = protocol._client_connected_cb is not None
        await self.drain()
        tls_transport = await self._loop.start_tls(  # type: ignore
            self._transport, protocol, sslcontext,
            server_side=server_side, server_hostname=server_hostname,
            ssl_handshake_timeout=ssl_handshake_timeout,
            ssl_shutdown_timeout=ssl_shutdown_timeout)
        self._transport = tls_transport
        protocol._replace_writer(self)

    def __del__(self):
        # Close the transport if the writer is collected while it is
        # still open.
        if self._transport.is_closing():
            return
        self.close()
398
399
class StreamReader:
    """Buffered byte-stream reader fed by a protocol.

    Data is pushed in via feed_data()/feed_eof() (normally called by
    StreamReaderProtocol) and consumed with the read*() coroutines.
    Flow control: when the internal buffer grows past twice the
    configured limit the transport is paused (see feed_data()); it is
    resumed once the buffer drains back to the limit
    (_maybe_resume_transport()) or when a reader must wait for more
    data (_wait_for_data()).
    """

    # Class-level default; replaced per-instance in __init__ with the
    # creation call stack when the loop runs in debug mode.
    _source_traceback = None

    def __init__(self, limit=_DEFAULT_LIMIT, loop=None):
        # The line length limit is a security feature;
        # it also doubles as half the buffer limit.

        if limit <= 0:
            raise ValueError('Limit cannot be <= 0')

        self._limit = limit
        if loop is None:
            self._loop = events.get_event_loop()
        else:
            self._loop = loop
        self._buffer = bytearray()
        self._eof = False    # Whether we're done.
        self._waiter = None  # A future used by _wait_for_data()
        self._exception = None
        self._transport = None
        self._paused = False
        if self._loop.get_debug():
            self._source_traceback = format_helpers.extract_stack(
                sys._getframe(1))

    def __repr__(self):
        info = ['StreamReader']
        if self._buffer:
            info.append(f'{len(self._buffer)} bytes')
        if self._eof:
            info.append('eof')
        if self._limit != _DEFAULT_LIMIT:
            info.append(f'limit={self._limit}')
        if self._waiter:
            info.append(f'waiter={self._waiter!r}')
        if self._exception:
            info.append(f'exception={self._exception!r}')
        if self._transport:
            info.append(f'transport={self._transport!r}')
        if self._paused:
            info.append('paused')
        return '<{}>'.format(' '.join(info))

    def exception(self):
        """Return the exception set with set_exception(), or None."""
        return self._exception

    def set_exception(self, exc):
        """Store *exc* and wake up a pending read*() call with it."""
        self._exception = exc

        waiter = self._waiter
        if waiter is not None:
            self._waiter = None
            if not waiter.cancelled():
                waiter.set_exception(exc)

    def _wakeup_waiter(self):
        """Wakeup read*() functions waiting for data or EOF."""
        waiter = self._waiter
        if waiter is not None:
            self._waiter = None
            if not waiter.cancelled():
                waiter.set_result(None)

    def set_transport(self, transport):
        """Attach the transport used for flow control; may only be set once."""
        assert self._transport is None, 'Transport already set'
        self._transport = transport

    def _maybe_resume_transport(self):
        # Resume reading once the buffer has drained back to the limit.
        if self._paused and len(self._buffer) <= self._limit:
            self._paused = False
            self._transport.resume_reading()

    def feed_eof(self):
        """Mark the stream as ended and wake up any pending reader."""
        self._eof = True
        self._wakeup_waiter()

    def at_eof(self):
        """Return True if the buffer is empty and 'feed_eof' was called."""
        return self._eof and not self._buffer

    def feed_data(self, data):
        """Append *data* to the buffer and wake up any pending reader.

        Pauses the transport when the buffer exceeds twice the limit.
        Must not be called after feed_eof().
        """
        assert not self._eof, 'feed_data after feed_eof'

        if not data:
            return

        self._buffer.extend(data)
        self._wakeup_waiter()

        if (self._transport is not None and
                not self._paused and
                len(self._buffer) > 2 * self._limit):
            try:
                self._transport.pause_reading()
            except NotImplementedError:
                # The transport can't be paused.
                # We'll just have to buffer all data.
                # Forget the transport so we don't keep trying.
                self._transport = None
            else:
                self._paused = True

    async def _wait_for_data(self, func_name):
        """Wait until feed_data() or feed_eof() is called.

        If stream was paused, automatically resume it.
        """
        # StreamReader uses a future to link the protocol feed_data() method
        # to a read coroutine. Running two read coroutines at the same time
        # would have an unexpected behaviour. It would not be possible to
        # know which coroutine would get the next data.
        if self._waiter is not None:
            raise RuntimeError(
                f'{func_name}() called while another coroutine is '
                f'already waiting for incoming data')

        assert not self._eof, '_wait_for_data after EOF'

        # Waiting for data while paused will make deadlock, so prevent it.
        # This is essential for readexactly(n) for case when n > self._limit.
        if self._paused:
            self._paused = False
            self._transport.resume_reading()

        self._waiter = self._loop.create_future()
        try:
            await self._waiter
        finally:
            self._waiter = None

    async def readline(self):
        """Read chunk of data from the stream until newline (b'\\n') is found.

        On success, return chunk that ends with newline. If only partial
        line can be read due to EOF, return incomplete line without
        terminating newline. When EOF was reached while no bytes read, empty
        bytes object is returned.

        If limit is reached, ValueError will be raised. In that case, if
        newline was found, complete line including newline will be removed
        from internal buffer. Else, internal buffer will be cleared. Limit is
        compared against part of the line without newline.

        If stream was paused, this function will automatically resume it if
        needed.
        """
        sep = b'\n'
        seplen = len(sep)
        try:
            line = await self.readuntil(sep)
        except exceptions.IncompleteReadError as e:
            # EOF before a newline: return whatever was read.
            return e.partial
        except exceptions.LimitOverrunError as e:
            # Discard the over-long line (up to and including the
            # newline if one was found) before reporting the error.
            if self._buffer.startswith(sep, e.consumed):
                del self._buffer[:e.consumed + seplen]
            else:
                self._buffer.clear()
            self._maybe_resume_transport()
            raise ValueError(e.args[0])
        return line

    async def readuntil(self, separator=b'\n'):
        """Read data from the stream until ``separator`` is found.

        On success, the data and separator will be removed from the
        internal buffer (consumed). Returned data will include the
        separator at the end.

        Configured stream limit is used to check result. Limit sets the
        maximal length of data that can be returned, not counting the
        separator.

        If an EOF occurs and the complete separator is still not found,
        an IncompleteReadError exception will be raised, and the internal
        buffer will be reset. The IncompleteReadError.partial attribute
        may contain the separator partially.

        If the data cannot be read because of over limit, a
        LimitOverrunError exception will be raised, and the data
        will be left in the internal buffer, so it can be read again.
        """
        seplen = len(separator)
        if seplen == 0:
            raise ValueError('Separator should be at least one-byte string')

        if self._exception is not None:
            raise self._exception

        # Consume whole buffer except last bytes, which length is
        # one less than seplen. Let's check corner cases with
        # separator='SEPARATOR':
        # * we have received almost complete separator (without last
        #   byte). i.e buffer='some textSEPARATO'. In this case we
        #   can safely consume len(separator) - 1 bytes.
        # * last byte of buffer is first byte of separator, i.e.
        #   buffer='abcdefghijklmnopqrS'. We may safely consume
        #   everything except that last byte, but this require to
        #   analyze bytes of buffer that match partial separator.
        #   This is slow and/or require FSM. For this case our
        #   implementation is not optimal, since require rescanning
        #   of data that is known to not belong to separator. In
        #   real world, separator will not be so long to notice
        #   performance problems. Even when reading MIME-encoded
        #   messages :)

        # `offset` is the number of bytes from the beginning of the buffer
        # where there is no occurrence of `separator`.
        offset = 0

        # Loop until we find `separator` in the buffer, exceed the buffer size,
        # or an EOF has happened.
        while True:
            buflen = len(self._buffer)

            # Check if we now have enough data in the buffer for `separator` to
            # fit.
            if buflen - offset >= seplen:
                isep = self._buffer.find(separator, offset)

                if isep != -1:
                    # `separator` is in the buffer. `isep` will be used later
                    # to retrieve the data.
                    break

                # see upper comment for explanation.
                offset = buflen + 1 - seplen
                if offset > self._limit:
                    raise exceptions.LimitOverrunError(
                        'Separator is not found, and chunk exceed the limit',
                        offset)

            # Complete message (with full separator) may be present in buffer
            # even when EOF flag is set. This may happen when the last chunk
            # adds data which makes separator be found. That's why we check for
            # EOF *after* inspecting the buffer.
            if self._eof:
                chunk = bytes(self._buffer)
                self._buffer.clear()
                raise exceptions.IncompleteReadError(chunk, None)

            # _wait_for_data() will resume reading if stream was paused.
            await self._wait_for_data('readuntil')

        if isep > self._limit:
            raise exceptions.LimitOverrunError(
                'Separator is found, but chunk is longer than limit', isep)

        chunk = self._buffer[:isep + seplen]
        del self._buffer[:isep + seplen]
        self._maybe_resume_transport()
        return bytes(chunk)

    async def read(self, n=-1):
        """Read up to `n` bytes from the stream.

        If `n` is not provided or set to -1,
        read until EOF, then return all read bytes.
        If EOF was received and the internal buffer is empty,
        return an empty bytes object.

        If `n` is 0, return an empty bytes object immediately.

        If `n` is positive, return at most `n` available bytes
        as soon as at least 1 byte is available in the internal buffer.
        If EOF is received before any byte is read, return an empty
        bytes object.

        Returned value is not limited with limit, configured at stream
        creation.

        If stream was paused, this function will automatically resume it if
        needed.
        """

        if self._exception is not None:
            raise self._exception

        if n == 0:
            return b''

        if n < 0:
            # This used to just loop creating a new waiter hoping to
            # collect everything in self._buffer, but that would
            # deadlock if the subprocess sends more than self.limit
            # bytes.  So just call self.read(self._limit) until EOF.
            blocks = []
            while True:
                block = await self.read(self._limit)
                if not block:
                    break
                blocks.append(block)
            return b''.join(blocks)

        if not self._buffer and not self._eof:
            await self._wait_for_data('read')

        # This will work right even if buffer is less than n bytes
        data = bytes(memoryview(self._buffer)[:n])
        del self._buffer[:n]

        self._maybe_resume_transport()
        return data

    async def readexactly(self, n):
        """Read exactly `n` bytes.

        Raise an IncompleteReadError if EOF is reached before `n` bytes can be
        read. The IncompleteReadError.partial attribute of the exception will
        contain the partial read bytes.

        if n is zero, return empty bytes object.

        Returned value is not limited with limit, configured at stream
        creation.

        If stream was paused, this function will automatically resume it if
        needed.
        """
        if n < 0:
            raise ValueError('readexactly size can not be less than zero')

        if self._exception is not None:
            raise self._exception

        if n == 0:
            return b''

        while len(self._buffer) < n:
            if self._eof:
                incomplete = bytes(self._buffer)
                self._buffer.clear()
                raise exceptions.IncompleteReadError(incomplete, n)

            await self._wait_for_data('readexactly')

        if len(self._buffer) == n:
            # Exact fit: hand over the whole buffer without slicing.
            data = bytes(self._buffer)
            self._buffer.clear()
        else:
            data = bytes(memoryview(self._buffer)[:n])
            del self._buffer[:n]
        self._maybe_resume_transport()
        return data

    def __aiter__(self):
        return self

    async def __anext__(self):
        """Return the next line; stop async iteration at EOF."""
        val = await self.readline()
        if val == b'':
            raise StopAsyncIteration
        return val