__all__ = (
    'StreamReader', 'StreamWriter', 'StreamReaderProtocol',
    'open_connection', 'start_server')

import collections
import socket
import sys
import warnings
import weakref

if hasattr(socket, 'AF_UNIX'):
    __all__ += ('open_unix_connection', 'start_unix_server')

from . import coroutines
from . import events
from . import exceptions
from . import format_helpers
from . import protocols
from .log import logger
from .tasks import sleep


_DEFAULT_LIMIT = 2 ** 16  # 64 KiB


async def open_connection(host=None, port=None, *,
                          limit=_DEFAULT_LIMIT, **kwds):
    """A wrapper for create_connection() returning a (reader, writer) pair.

    The reader returned is a StreamReader instance; the writer is a
    StreamWriter instance.

    The arguments are all the usual arguments to create_connection()
    except protocol_factory; most common are positional host and port,
    with various optional keyword arguments following.

    The additional optional keyword argument is limit (to set the
    buffer limit passed to the StreamReader).

    (If you want to customize the StreamReader and/or
    StreamReaderProtocol classes, just copy the code -- there's
    really nothing special here except some convenience.)
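
    A minimal usage sketch (the host and port are placeholders; this
    assumes a line-oriented echo server is listening there):

        import asyncio

        async def main():
            reader, writer = await asyncio.open_connection(
                '127.0.0.1', 8888)
            writer.write(b'ping\n')
            await writer.drain()
            line = await reader.readline()
            writer.close()
            await writer.wait_closed()

        asyncio.run(main())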
44 """
45 loop = events.get_running_loop()
46 reader = StreamReader(limit=limit, loop=loop)
47 protocol = StreamReaderProtocol(reader, loop=loop)
48 transport, _ = await loop.create_connection(
49 lambda: protocol, host, port, **kwds)
50 writer = StreamWriter(transport, protocol, reader, loop)
51 return reader, writer
52
53
async def start_server(client_connected_cb, host=None, port=None, *,
                       limit=_DEFAULT_LIMIT, **kwds):
    """Start a socket server, call back for each client connected.

    The first parameter, `client_connected_cb`, takes two parameters:
    client_reader, client_writer.  client_reader is a StreamReader
    object, while client_writer is a StreamWriter object.  This
    parameter can either be a plain callback function or a coroutine
    function; if it is a coroutine function, the coroutine it returns
    will be automatically converted into a Task.

    The rest of the arguments are all the usual arguments to
    loop.create_server() except protocol_factory; most common are
    positional host and port, with various optional keyword arguments
    following.

    The additional optional keyword argument is limit (to set the
    buffer limit passed to the StreamReader).

    The return value is the same as loop.create_server(), i.e. a
    Server object which can be used to stop the service.
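
    A minimal echo-server sketch (the host and port are placeholders;
    handle() is an illustrative callback):

        import asyncio

        async def handle(reader, writer):
            data = await reader.readline()
            writer.write(data)
            await writer.drain()
            writer.close()
            await writer.wait_closed()

        async def main():
            server = await asyncio.start_server(handle, '127.0.0.1', 8888)
            async with server:
                await server.serve_forever()

        asyncio.run(main())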
75 """
76 loop = events.get_running_loop()
77
78 def factory():
79 reader = StreamReader(limit=limit, loop=loop)
80 protocol = StreamReaderProtocol(reader, client_connected_cb,
81 loop=loop)
82 return protocol
83
84 return await loop.create_server(factory, host, port, **kwds)
85
86
if hasattr(socket, 'AF_UNIX'):
    # UNIX Domain Sockets are supported on this platform

    async def open_unix_connection(path=None, *,
                                   limit=_DEFAULT_LIMIT, **kwds):
        """Similar to `open_connection` but works with UNIX Domain Sockets."""
        loop = events.get_running_loop()

        reader = StreamReader(limit=limit, loop=loop)
        protocol = StreamReaderProtocol(reader, loop=loop)
        transport, _ = await loop.create_unix_connection(
            lambda: protocol, path, **kwds)
        writer = StreamWriter(transport, protocol, reader, loop)
        return reader, writer

    async def start_unix_server(client_connected_cb, path=None, *,
                                limit=_DEFAULT_LIMIT, **kwds):
        """Similar to `start_server` but works with UNIX Domain Sockets."""
        loop = events.get_running_loop()

        def factory():
            reader = StreamReader(limit=limit, loop=loop)
            protocol = StreamReaderProtocol(reader, client_connected_cb,
                                            loop=loop)
            return protocol

        return await loop.create_unix_server(factory, path, **kwds)


class FlowControlMixin(protocols.Protocol):
    """Reusable flow control logic for StreamWriter.drain().

    This implements the protocol methods pause_writing(),
    resume_writing() and connection_lost().  If the subclass overrides
    these it must call the super methods.

    StreamWriter.drain() must await the _drain_helper() coroutine.
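
    A minimal subclass sketch (LoggingProtocol is an illustrative name,
    not an asyncio class):

        class LoggingProtocol(FlowControlMixin):
            def connection_lost(self, exc):
                logger.debug("connection lost: %r", exc)
                super().connection_lost(exc)  # required, per the contract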
124 """
125
126 def __init__(self, loop=None):
127 if loop is None:
128 self._loop = events._get_event_loop(stacklevel=4)
129 else:
130 self._loop = loop
131 self._paused = False
132 self._drain_waiters = collections.deque()
133 self._connection_lost = False
134
135 def pause_writing(self):
136 assert not self._paused
137 self._paused = True
138 if self._loop.get_debug():
139 logger.debug("%r pauses writing", self)
140
141 def resume_writing(self):
142 assert self._paused
143 self._paused = False
144 if self._loop.get_debug():
145 logger.debug("%r resumes writing", self)
146
147 for waiter in self._drain_waiters:
148 if not waiter.done():
149 waiter.set_result(None)
150
151 def connection_lost(self, exc):
152 self._connection_lost = True
153 # Wake up the writer(s) if currently paused.
154 if not self._paused:
155 return
156
157 for waiter in self._drain_waiters:
158 if not waiter.done():
159 if exc is None:
160 waiter.set_result(None)
161 else:
162 waiter.set_exception(exc)
163
164 async def _drain_helper(self):
165 if self._connection_lost:
166 raise ConnectionResetError('Connection lost')
167 if not self._paused:
168 return
169 waiter = self._loop.create_future()
170 self._drain_waiters.append(waiter)
171 try:
172 await waiter
173 finally:
174 self._drain_waiters.remove(waiter)
175
176 def _get_close_waiter(self, stream):
177 raise NotImplementedError
178
179
class StreamReaderProtocol(FlowControlMixin, protocols.Protocol):
    """Helper class to adapt between Protocol and StreamReader.

    (This is a helper class instead of making StreamReader itself a
    Protocol subclass, because the StreamReader has other potential
    uses, and to prevent the user of the StreamReader from accidentally
    calling inappropriate methods of the protocol.)
    """

    _source_traceback = None

    def __init__(self, stream_reader, client_connected_cb=None, loop=None):
        super().__init__(loop=loop)
        if stream_reader is not None:
            self._stream_reader_wr = weakref.ref(stream_reader)
            self._source_traceback = stream_reader._source_traceback
        else:
            self._stream_reader_wr = None
        if client_connected_cb is not None:
            # This is a stream created by the `start_server()` function.
            # Keep a strong reference to the reader until a connection
            # is established.
            self._strong_reader = stream_reader
        self._reject_connection = False
        self._stream_writer = None
        self._task = None
        self._transport = None
        self._client_connected_cb = client_connected_cb
        self._over_ssl = False
        self._closed = self._loop.create_future()

    @property
    def _stream_reader(self):
        if self._stream_reader_wr is None:
            return None
        return self._stream_reader_wr()

    def _replace_writer(self, writer):
        transport = writer.transport
        self._stream_writer = writer
        self._transport = transport
        self._over_ssl = transport.get_extra_info('sslcontext') is not None

    def connection_made(self, transport):
        if self._reject_connection:
            context = {
                'message': ('An open stream was garbage collected prior to '
                            'establishing network connection; '
                            'call "stream.close()" explicitly.')
            }
            if self._source_traceback:
                context['source_traceback'] = self._source_traceback
            self._loop.call_exception_handler(context)
            transport.abort()
            return
        self._transport = transport
        reader = self._stream_reader
        if reader is not None:
            reader.set_transport(transport)
        self._over_ssl = transport.get_extra_info('sslcontext') is not None
        if self._client_connected_cb is not None:
            self._stream_writer = StreamWriter(transport, self,
                                               reader,
                                               self._loop)
            res = self._client_connected_cb(reader,
                                            self._stream_writer)
            if coroutines.iscoroutine(res):
                def callback(task):
                    exc = task.exception()
                    if exc is not None:
                        self._loop.call_exception_handler({
                            'message': 'Unhandled exception in client_connected_cb',
                            'exception': exc,
                            'transport': transport,
                        })
                        transport.close()

                self._task = self._loop.create_task(res)
                self._task.add_done_callback(callback)

            self._strong_reader = None

    def connection_lost(self, exc):
        reader = self._stream_reader
        if reader is not None:
            if exc is None:
                reader.feed_eof()
            else:
                reader.set_exception(exc)
        if not self._closed.done():
            if exc is None:
                self._closed.set_result(None)
            else:
                self._closed.set_exception(exc)
        super().connection_lost(exc)
        self._stream_reader_wr = None
        self._stream_writer = None
        self._task = None
        self._transport = None

    def data_received(self, data):
        reader = self._stream_reader
        if reader is not None:
            reader.feed_data(data)

    def eof_received(self):
        reader = self._stream_reader
        if reader is not None:
            reader.feed_eof()
        if self._over_ssl:
            # Prevent a warning in SSLProtocol.eof_received:
            # "returning true from eof_received()
            # has no effect when using ssl"
            return False
        return True

    def _get_close_waiter(self, stream):
        return self._closed

    def __del__(self):
        # Prevent reports about unhandled exceptions.
        # Better than the self._closed._log_traceback = False hack.
        try:
            closed = self._closed
        except AttributeError:
            pass  # failed constructor
        else:
            if closed.done() and not closed.cancelled():
                closed.exception()


class StreamWriter:
    """Wraps a Transport.

    This exposes write(), writelines(), [can_]write_eof(),
    get_extra_info() and close().  It adds drain() which can be
    awaited to wait for flow control.  It also adds a transport
    property which references the Transport directly.
    """

    def __init__(self, transport, protocol, reader, loop):
        self._transport = transport
        self._protocol = protocol
        # drain() expects that the reader has an exception() method
        assert reader is None or isinstance(reader, StreamReader)
        self._reader = reader
        self._loop = loop
        self._complete_fut = self._loop.create_future()
        self._complete_fut.set_result(None)

    def __repr__(self):
        info = [self.__class__.__name__, f'transport={self._transport!r}']
        if self._reader is not None:
            info.append(f'reader={self._reader!r}')
        return '<{}>'.format(' '.join(info))

    @property
    def transport(self):
        return self._transport

    def write(self, data):
        self._transport.write(data)

    def writelines(self, data):
        self._transport.writelines(data)

    def write_eof(self):
        return self._transport.write_eof()

    def can_write_eof(self):
        return self._transport.can_write_eof()

    def close(self):
        return self._transport.close()

    def is_closing(self):
        return self._transport.is_closing()

    async def wait_closed(self):
        await self._protocol._get_close_waiter(self)

    def get_extra_info(self, name, default=None):
        return self._transport.get_extra_info(name, default)

    async def drain(self):
        """Flush the write buffer.

        The intended use is to write

          w.write(data)
          await w.drain()
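
        A fuller sketch (assumes `chunks` is an iterable of bytes
        objects; awaiting drain() after each write applies flow
        control):

            for chunk in chunks:
                w.write(chunk)
                await w.drain()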
373 """
374 if self._reader is not None:
375 exc = self._reader.exception()
376 if exc is not None:
377 raise exc
378 if self._transport.is_closing():
379 # Wait for protocol.connection_lost() call
380 # Raise connection closing error if any,
381 # ConnectionResetError otherwise
382 # Yield to the event loop so connection_lost() may be
383 # called. Without this, _drain_helper() would return
384 # immediately, and code that calls
385 # write(...); await drain()
386 # in a loop would never call connection_lost(), so it
387 # would not see an error when the socket is closed.
388 await sleep(0)
389 await self._protocol._drain_helper()

    async def start_tls(self, sslcontext, *,
                        server_hostname=None,
                        ssl_handshake_timeout=None):
        """Upgrade an existing stream-based connection to TLS."""
        server_side = self._protocol._client_connected_cb is not None
        protocol = self._protocol
        await self.drain()
        new_transport = await self._loop.start_tls(  # type: ignore
            self._transport, protocol, sslcontext,
            server_side=server_side, server_hostname=server_hostname,
            ssl_handshake_timeout=ssl_handshake_timeout)
        self._transport = new_transport
        protocol._replace_writer(self)
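
    # A minimal upgrade sketch (illustrative; assumes `writer` came from
    # open_connection() and `ctx` is an ssl.SSLContext built by the
    # caller):
    #
    #     await writer.start_tls(ctx, server_hostname='example.com')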

    def __del__(self):
        if not self._transport.is_closing():
            if self._loop.is_closed():
                warnings.warn("loop is closed", ResourceWarning)
            else:
                self.close()
                warnings.warn(f"unclosed {self!r}", ResourceWarning)


class StreamReader:

    _source_traceback = None

    def __init__(self, limit=_DEFAULT_LIMIT, loop=None):
        # The line length limit is a security feature;
        # it also doubles as half the buffer limit.

        if limit <= 0:
            raise ValueError('Limit cannot be <= 0')

        self._limit = limit
        if loop is None:
            self._loop = events._get_event_loop()
        else:
            self._loop = loop
        self._buffer = bytearray()
        self._eof = False    # Whether we're done.
        self._waiter = None  # A future used by _wait_for_data()
        self._exception = None
        self._transport = None
        self._paused = False
        if self._loop.get_debug():
            self._source_traceback = format_helpers.extract_stack(
                sys._getframe(1))

    def __repr__(self):
        info = ['StreamReader']
        if self._buffer:
            info.append(f'{len(self._buffer)} bytes')
        if self._eof:
            info.append('eof')
        if self._limit != _DEFAULT_LIMIT:
            info.append(f'limit={self._limit}')
        if self._waiter:
            info.append(f'waiter={self._waiter!r}')
        if self._exception:
            info.append(f'exception={self._exception!r}')
        if self._transport:
            info.append(f'transport={self._transport!r}')
        if self._paused:
            info.append('paused')
        return '<{}>'.format(' '.join(info))

    def exception(self):
        return self._exception

    def set_exception(self, exc):
        self._exception = exc

        waiter = self._waiter
        if waiter is not None:
            self._waiter = None
            if not waiter.cancelled():
                waiter.set_exception(exc)

    def _wakeup_waiter(self):
        """Wake up read*() functions waiting for data or EOF."""
        waiter = self._waiter
        if waiter is not None:
            self._waiter = None
            if not waiter.cancelled():
                waiter.set_result(None)

    def set_transport(self, transport):
        assert self._transport is None, 'Transport already set'
        self._transport = transport

    def _maybe_resume_transport(self):
        if self._paused and len(self._buffer) <= self._limit:
            self._paused = False
            self._transport.resume_reading()

    def feed_eof(self):
        self._eof = True
        self._wakeup_waiter()

    def at_eof(self):
        """Return True if the buffer is empty and 'feed_eof' was called."""
        return self._eof and not self._buffer

    def feed_data(self, data):
        assert not self._eof, 'feed_data after feed_eof'

        if not data:
            return

        self._buffer.extend(data)
        self._wakeup_waiter()

        if (self._transport is not None and
                not self._paused and
                len(self._buffer) > 2 * self._limit):
            try:
                self._transport.pause_reading()
            except NotImplementedError:
                # The transport can't be paused.
                # We'll just have to buffer all data.
                # Forget the transport so we don't keep trying.
                self._transport = None
            else:
                self._paused = True

    async def _wait_for_data(self, func_name):
        """Wait until feed_data() or feed_eof() is called.

        If the stream was paused, automatically resume it.
        """
        # StreamReader uses a future to link the protocol feed_data() method
        # to a read coroutine.  Running two read coroutines at the same time
        # would have unexpected behaviour: it would not be possible to know
        # which coroutine would get the next data.
        if self._waiter is not None:
            raise RuntimeError(
                f'{func_name}() called while another coroutine is '
                f'already waiting for incoming data')

        assert not self._eof, '_wait_for_data after EOF'

        # Waiting for data while paused would deadlock, so prevent it.
        # This is essential for readexactly(n) when n > self._limit.
        if self._paused:
            self._paused = False
            self._transport.resume_reading()

        self._waiter = self._loop.create_future()
        try:
            await self._waiter
        finally:
            self._waiter = None

    async def readline(self):
        """Read a chunk of data from the stream until newline (b'\n') is found.

        On success, return the chunk ending with the newline.  If only a
        partial line can be read due to EOF, return the incomplete line
        without the terminating newline.  If EOF is reached before any
        bytes are read, return an empty bytes object.

        If the limit is reached, ValueError is raised.  In that case, if
        the newline was found, the complete line including the newline is
        removed from the internal buffer; otherwise, the internal buffer
        is cleared.  The limit is compared against the part of the line
        without the newline.

        If the stream was paused, this function will automatically
        resume it if needed.
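
        A minimal sketch (assumes `reader` came from open_connection()):

            line = await reader.readline()
            if not line:
                ...  # an empty bytes object here means EOF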
559 """
560 sep = b'\n'
561 seplen = len(sep)
562 try:
563 line = await self.readuntil(sep)
564 except exceptions.IncompleteReadError as e:
565 return e.partial
566 except exceptions.LimitOverrunError as e:
567 if self._buffer.startswith(sep, e.consumed):
568 del self._buffer[:e.consumed + seplen]
569 else:
570 self._buffer.clear()
571 self._maybe_resume_transport()
572 raise ValueError(e.args[0])
573 return line
574
    async def readuntil(self, separator=b'\n'):
        """Read data from the stream until ``separator`` is found.

        On success, the data and separator will be removed from the
        internal buffer (consumed).  Returned data will include the
        separator at the end.

        The configured stream limit is used to check the result.  The
        limit sets the maximum length of data that can be returned, not
        counting the separator.

        If an EOF occurs and the complete separator is still not found,
        an IncompleteReadError exception will be raised, and the internal
        buffer will be reset.  The IncompleteReadError.partial attribute
        may contain a portion of the separator.

        If the data read would exceed the limit, a LimitOverrunError
        exception will be raised, and the data will be left in the
        internal buffer, so it can be read again.
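
        A minimal sketch (assumes `reader` came from open_connection()
        and the peer sends CRLF-terminated records):

            record = await reader.readuntil(b'\r\n')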
594 """
595 seplen = len(separator)
596 if seplen == 0:
597 raise ValueError('Separator should be at least one-byte string')
598
599 if self._exception is not None:
600 raise self._exception
601
602 # Consume whole buffer except last bytes, which length is
603 # one less than seplen. Let's check corner cases with
604 # separator='SEPARATOR':
605 # * we have received almost complete separator (without last
606 # byte). i.e buffer='some textSEPARATO'. In this case we
607 # can safely consume len(separator) - 1 bytes.
608 # * last byte of buffer is first byte of separator, i.e.
609 # buffer='abcdefghijklmnopqrS'. We may safely consume
610 # everything except that last byte, but this require to
611 # analyze bytes of buffer that match partial separator.
612 # This is slow and/or require FSM. For this case our
613 # implementation is not optimal, since require rescanning
614 # of data that is known to not belong to separator. In
615 # real world, separator will not be so long to notice
616 # performance problems. Even when reading MIME-encoded
617 # messages :)
618
619 # `offset` is the number of bytes from the beginning of the buffer
620 # where there is no occurrence of `separator`.
621 offset = 0
622
623 # Loop until we find `separator` in the buffer, exceed the buffer size,
624 # or an EOF has happened.
625 while True:
626 buflen = len(self._buffer)
627
628 # Check if we now have enough data in the buffer for `separator` to
629 # fit.
630 if buflen - offset >= seplen:
631 isep = self._buffer.find(separator, offset)
632
633 if isep != -1:
634 # `separator` is in the buffer. `isep` will be used later
635 # to retrieve the data.
636 break
637
638 # see upper comment for explanation.
639 offset = buflen + 1 - seplen
640 if offset > self._limit:
641 raise exceptions.LimitOverrunError(
642 'Separator is not found, and chunk exceed the limit',
643 offset)
644
645 # Complete message (with full separator) may be present in buffer
646 # even when EOF flag is set. This may happen when the last chunk
647 # adds data which makes separator be found. That's why we check for
648 # EOF *ater* inspecting the buffer.
649 if self._eof:
650 chunk = bytes(self._buffer)
651 self._buffer.clear()
652 raise exceptions.IncompleteReadError(chunk, None)
653
654 # _wait_for_data() will resume reading if stream was paused.
655 await self._wait_for_data('readuntil')
656
657 if isep > self._limit:
658 raise exceptions.LimitOverrunError(
659 'Separator is found, but chunk is longer than limit', isep)
660
661 chunk = self._buffer[:isep + seplen]
662 del self._buffer[:isep + seplen]
663 self._maybe_resume_transport()
664 return bytes(chunk)
665
    async def read(self, n=-1):
        """Read up to `n` bytes from the stream.

        If `n` is not provided or set to -1, read until EOF, then
        return all read bytes.  If EOF was received and the internal
        buffer is empty, return an empty bytes object.

        If `n` is 0, return an empty bytes object immediately.

        If `n` is positive, return at most `n` available bytes as soon
        as at least 1 byte is available in the internal buffer.  If EOF
        is received before any byte is read, return an empty bytes
        object.

        The returned value is not limited by the stream limit configured
        at stream creation.

        If the stream was paused, this function will automatically
        resume it if needed.
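
        A minimal sketch (assumes `reader` came from open_connection()):

            chunk = await reader.read(4096)  # at most 4096 bytes
            rest = await reader.read()       # everything until EOF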
686 """
687
688 if self._exception is not None:
689 raise self._exception
690
691 if n == 0:
692 return b''
693
694 if n < 0:
695 # This used to just loop creating a new waiter hoping to
696 # collect everything in self._buffer, but that would
697 # deadlock if the subprocess sends more than self.limit
698 # bytes. So just call self.read(self._limit) until EOF.
699 blocks = []
700 while True:
701 block = await self.read(self._limit)
702 if not block:
703 break
704 blocks.append(block)
705 return b''.join(blocks)
706
707 if not self._buffer and not self._eof:
708 await self._wait_for_data('read')
709
710 # This will work right even if buffer is less than n bytes
711 data = bytes(self._buffer[:n])
712 del self._buffer[:n]
713
714 self._maybe_resume_transport()
715 return data
716
    async def readexactly(self, n):
        """Read exactly `n` bytes.

        Raise an IncompleteReadError if EOF is reached before `n` bytes
        can be read.  The IncompleteReadError.partial attribute of the
        exception will contain the partially read bytes.

        If n is zero, return an empty bytes object.

        The returned value is not limited by the stream limit configured
        at stream creation.

        If the stream was paused, this function will automatically
        resume it if needed.
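
        A minimal sketch (assumes `reader` came from open_connection()
        and the peer sends a 4-byte big-endian length prefix before each
        message):

            header = await reader.readexactly(4)
            body = await reader.readexactly(int.from_bytes(header, 'big'))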
731 """
732 if n < 0:
733 raise ValueError('readexactly size can not be less than zero')
734
735 if self._exception is not None:
736 raise self._exception
737
738 if n == 0:
739 return b''
740
741 while len(self._buffer) < n:
742 if self._eof:
743 incomplete = bytes(self._buffer)
744 self._buffer.clear()
745 raise exceptions.IncompleteReadError(incomplete, n)
746
747 await self._wait_for_data('readexactly')
748
749 if len(self._buffer) == n:
750 data = bytes(self._buffer)
751 self._buffer.clear()
752 else:
753 data = bytes(self._buffer[:n])
754 del self._buffer[:n]
755 self._maybe_resume_transport()
756 return data
757
    def __aiter__(self):
        return self

    async def __anext__(self):
        val = await self.readline()
        if val == b'':
            raise StopAsyncIteration
        return val
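
# StreamReader objects are async iterables yielding lines; a minimal
# sketch (assumes `reader` came from open_connection(); handle() is a
# placeholder for application code):
#
#     async for line in reader:
#         handle(line)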