# Names exported by ``from ... import *``.  The UNIX-socket helpers are
# appended further down, only on platforms that provide socket.AF_UNIX.
__all__ = (
    'StreamReader', 'StreamWriter', 'StreamReaderProtocol',
    'open_connection', 'start_server',
    'IncompleteReadError', 'LimitOverrunError',
)
6
7import socket
8
if hasattr(socket, 'AF_UNIX'):
    # Export the UNIX domain socket helpers only where they exist; their
    # definitions later in this module are guarded by the same check.
    __all__ += ('open_unix_connection', 'start_unix_server')
11
12from . import coroutines
13from . import events
14from . import protocols
15from .log import logger
16from .tasks import sleep
17
18
# Default StreamReader limit: the longest chunk readuntil()/readline() will
# return (separator excluded); reading is paused once more than twice this
# many bytes are buffered.
_DEFAULT_LIMIT = 64 * 1024  # 64 KiB
20
21
class IncompleteReadError(EOFError):
    """
    Raised when the stream ends before the requested data was read.

    Attributes:

    - partial: bytes that were read before the end of stream was reached
    - expected: total number of bytes that were expected (None if unknown)
    """
    def __init__(self, partial, expected):
        message = (f'{len(partial)} bytes read on a total of '
                   f'{expected!r} expected bytes')
        super().__init__(message)
        self.partial = partial
        self.expected = expected

    def __reduce__(self):
        # Rebuild from the constructor arguments so pickling/copying works;
        # self.args only holds the formatted message.
        ctor_args = (self.partial, self.expected)
        return self.__class__, ctor_args
37
38
class LimitOverrunError(Exception):
    """Raised when the buffer limit is hit while searching for a separator.

    Attributes:
    - consumed: total number of bytes that are to be consumed.
    """
    def __init__(self, message, consumed):
        super().__init__(message)
        self.consumed = consumed

    def __reduce__(self):
        # Re-create with both constructor arguments (args only holds message).
        message = self.args[0]
        return self.__class__, (message, self.consumed)
51
52
async def open_connection(host=None, port=None, *,
                          loop=None, limit=_DEFAULT_LIMIT, **kwds):
    """Open a stream connection and return a ``(reader, writer)`` pair.

    Thin wrapper around ``loop.create_connection()``: *host*, *port* and
    any extra keyword arguments are forwarded to it unchanged, while the
    protocol factory is supplied here.  The reader is a StreamReader and
    the writer a StreamWriter; both share one StreamReaderProtocol.

    Extra keyword-only arguments:

    - loop: event loop to use (``events.get_event_loop()`` when omitted).
    - limit: buffer limit handed to the StreamReader.

    (To customize the StreamReader and/or StreamReaderProtocol classes,
    just copy this code -- it is nothing more than a convenience.)
    """
    loop = events.get_event_loop() if loop is None else loop
    stream_reader = StreamReader(limit=limit, loop=loop)
    stream_protocol = StreamReaderProtocol(stream_reader, loop=loop)

    def protocol_factory():
        # The connection reuses the single protocol built above.
        return stream_protocol

    transport, _unused = await loop.create_connection(
        protocol_factory, host, port, **kwds)
    stream_writer = StreamWriter(transport, stream_protocol, stream_reader,
                                 loop)
    return stream_reader, stream_writer
80
81
async def start_server(client_connected_cb, host=None, port=None, *,
                       loop=None, limit=_DEFAULT_LIMIT, **kwds):
    """Start a socket server that calls back for every connected client.

    *client_connected_cb* receives ``(client_reader, client_writer)`` -- a
    StreamReader and a StreamWriter -- for each new connection.  It may be
    a plain function or a coroutine function; a coroutine result is
    scheduled as a Task.

    *host*, *port* and any extra keyword arguments are forwarded to
    ``loop.create_server()``; the protocol factory is supplied here.

    Extra keyword-only arguments:

    - loop: event loop to use (``events.get_event_loop()`` when omitted).
    - limit: buffer limit handed to each connection's StreamReader.

    Returns whatever ``loop.create_server()`` returns, i.e. a Server
    object that can be used to stop the service.
    """
    loop = events.get_event_loop() if loop is None else loop

    def protocol_factory():
        # A fresh reader/protocol pair for every accepted connection.
        return StreamReaderProtocol(StreamReader(limit=limit, loop=loop),
                                    client_connected_cb, loop=loop)

    return await loop.create_server(protocol_factory, host, port, **kwds)
115
116
if hasattr(socket, 'AF_UNIX'):
    # UNIX Domain Sockets are supported on this platform

    async def open_unix_connection(path=None, *,
                                   loop=None, limit=_DEFAULT_LIMIT, **kwds):
        """Like `open_connection`, but over a UNIX Domain Socket at *path*."""
        loop = events.get_event_loop() if loop is None else loop
        stream_reader = StreamReader(limit=limit, loop=loop)
        stream_protocol = StreamReaderProtocol(stream_reader, loop=loop)

        def protocol_factory():
            return stream_protocol

        transport, _unused = await loop.create_unix_connection(
            protocol_factory, path, **kwds)
        stream_writer = StreamWriter(transport, stream_protocol,
                                     stream_reader, loop)
        return stream_reader, stream_writer

    async def start_unix_server(client_connected_cb, path=None, *,
                                loop=None, limit=_DEFAULT_LIMIT, **kwds):
        """Like `start_server`, but listens on a UNIX Domain Socket."""
        loop = events.get_event_loop() if loop is None else loop

        def protocol_factory():
            # A fresh reader/protocol pair for every accepted connection.
            return StreamReaderProtocol(StreamReader(limit=limit, loop=loop),
                                        client_connected_cb, loop=loop)

        return await loop.create_unix_server(protocol_factory, path, **kwds)
145
146
class FlowControlMixin(protocols.Protocol):
    """Reusable flow control logic for StreamWriter.drain().

    This implements the protocol methods pause_writing(),
    resume_writing() and connection_lost().  If the subclass overrides
    these it must call the super methods.

    StreamWriter.drain() must wait for _drain_helper() coroutine.

    Several coroutines may wait in _drain_helper() at the same time: each
    gets its own future, and all of them are woken together when writing
    resumes or the connection is lost.
    """

    def __init__(self, loop=None):
        if loop is None:
            self._loop = events.get_event_loop()
        else:
            self._loop = loop
        self._paused = False
        # Futures of the coroutines currently blocked in _drain_helper().
        # A collection rather than a single future so that concurrent
        # drain() calls do not trip over / overwrite each other's waiter.
        self._drain_waiters = []
        self._connection_lost = False

    def pause_writing(self):
        assert not self._paused
        self._paused = True
        if self._loop.get_debug():
            logger.debug("%r pauses writing", self)

    def resume_writing(self):
        assert self._paused
        self._paused = False
        if self._loop.get_debug():
            logger.debug("%r resumes writing", self)

        # Wake every pending drain().  Each waiter removes itself from the
        # list in _drain_helper()'s ``finally`` once its coroutine resumes,
        # which happens on a later loop iteration, not during this loop.
        for waiter in self._drain_waiters:
            if not waiter.done():
                waiter.set_result(None)

    def connection_lost(self, exc):
        self._connection_lost = True
        # Wake up the writers if currently paused.
        if not self._paused:
            return
        for waiter in self._drain_waiters:
            if waiter.done():
                continue
            if exc is None:
                waiter.set_result(None)
            else:
                waiter.set_exception(exc)

    async def _drain_helper(self):
        """Block while writing is paused.

        Raises ConnectionResetError if the connection was already lost,
        or the exception passed to connection_lost() while waiting.
        """
        if self._connection_lost:
            raise ConnectionResetError('Connection lost')
        if not self._paused:
            return
        waiter = self._loop.create_future()
        self._drain_waiters.append(waiter)
        try:
            await waiter
        finally:
            # Also runs on cancellation, so no stale waiter is left behind.
            self._drain_waiters.remove(waiter)
210
211
class StreamReaderProtocol(FlowControlMixin, protocols.Protocol):
    """Helper class to adapt between Protocol and StreamReader.

    (This is a helper class instead of making StreamReader itself a
    Protocol subclass, because the StreamReader has other potential
    uses, and to prevent the user of the StreamReader to accidentally
    call inappropriate methods of the protocol.)

    When *client_connected_cb* is given, connection_made() creates a
    StreamWriter and calls ``client_connected_cb(reader, writer)``; if that
    returns a coroutine, it is run as a Task.
    """

    def __init__(self, stream_reader, client_connected_cb=None, loop=None):
        super().__init__(loop=loop)
        self._stream_reader = stream_reader
        self._stream_writer = None
        self._client_connected_cb = client_connected_cb
        # Task running a coroutine client_connected_cb.  Kept referenced
        # here because asyncio documents that tasks must be strongly
        # referenced to avoid being garbage collected mid-execution.
        self._task = None
        self._over_ssl = False
        # Resolved by connection_lost(); awaited by StreamWriter.wait_closed().
        self._closed = self._loop.create_future()

    def connection_made(self, transport):
        self._stream_reader.set_transport(transport)
        self._over_ssl = transport.get_extra_info('sslcontext') is not None
        if self._client_connected_cb is not None:
            self._stream_writer = StreamWriter(transport, self,
                                               self._stream_reader,
                                               self._loop)
            res = self._client_connected_cb(self._stream_reader,
                                            self._stream_writer)
            if coroutines.iscoroutine(res):
                self._task = self._loop.create_task(res)

    def connection_lost(self, exc):
        if self._stream_reader is not None:
            if exc is None:
                self._stream_reader.feed_eof()
            else:
                self._stream_reader.set_exception(exc)
        if not self._closed.done():
            if exc is None:
                self._closed.set_result(None)
            else:
                self._closed.set_exception(exc)
        super().connection_lost(exc)
        self._stream_reader = None
        self._stream_writer = None

    def data_received(self, data):
        self._stream_reader.feed_data(data)

    def eof_received(self):
        self._stream_reader.feed_eof()
        if self._over_ssl:
            # Prevent a warning in SSLProtocol.eof_received:
            # "returning true from eof_received()
            # has no effect when using ssl"
            return False
        return True

    def __del__(self):
        # Prevent reports about unhandled exceptions.
        # Better than self._closed._log_traceback = False hack
        try:
            closed = self._closed
        except AttributeError:
            # __init__ failed before _closed was created (e.g. the event
            # loop lookup in super().__init__() raised); nothing to silence.
            return
        if closed.done() and not closed.cancelled():
            closed.exception()
274
275
class StreamWriter:
    """Thin wrapper around a Transport.

    Forwards write(), writelines(), write_eof(), can_write_eof(),
    get_extra_info(), close() and is_closing() to the transport, and adds
    the coroutines drain() (wait for flow control) and wait_closed().
    The wrapped Transport is available via the ``transport`` property.
    """

    def __init__(self, transport, protocol, reader, loop):
        self._transport = transport
        self._protocol = protocol
        # drain() calls reader.exception(), so only a StreamReader (or no
        # reader at all) is accepted here.
        assert reader is None or isinstance(reader, StreamReader)
        self._reader = reader
        self._loop = loop

    def __repr__(self):
        parts = [type(self).__name__, f'transport={self._transport!r}']
        if self._reader is not None:
            parts.append(f'reader={self._reader!r}')
        return '<' + ' '.join(parts) + '>'

    @property
    def transport(self):
        return self._transport

    def write(self, data):
        self._transport.write(data)

    def writelines(self, data):
        self._transport.writelines(data)

    def write_eof(self):
        return self._transport.write_eof()

    def can_write_eof(self):
        return self._transport.can_write_eof()

    def close(self):
        return self._transport.close()

    def is_closing(self):
        return self._transport.is_closing()

    async def wait_closed(self):
        await self._protocol._closed

    def get_extra_info(self, name, default=None):
        return self._transport.get_extra_info(name, default)

    async def drain(self):
        """Flush the write buffer.

        The intended use is to write

          w.write(data)
          await w.drain()

        Raises the reader's pending exception, if there is one.
        """
        reader = self._reader
        if reader is not None:
            pending = reader.exception()
            if pending is not None:
                raise pending
        if self._transport.is_closing():
            # Give the event loop one turn so connection_lost() can run.
            # Otherwise _drain_helper() would return at once and a
            #     write(...); await drain()
            # loop would never notice that the socket had been closed.
            await sleep(0, loop=self._loop)
        await self._protocol._drain_helper()
349
350
class StreamReader:
    """Buffered byte stream fed by a protocol and consumed by coroutines.

    The producing side (e.g. StreamReaderProtocol) calls feed_data(),
    feed_eof() and set_exception(); consumers await read(), readline(),
    readexactly() or readuntil(), or use ``async for`` to iterate lines.
    Only one coroutine may wait for incoming data at a time.

    Flow control: once more than ``2 * limit`` bytes are buffered the
    transport is asked to pause reading; it is resumed when the buffer
    shrinks back to ``limit`` bytes or fewer, or when a reader has to wait
    for more data.
    """

    def __init__(self, limit=_DEFAULT_LIMIT, loop=None):
        # The line length limit is a security feature;
        # it also doubles as half the buffer limit.

        if limit <= 0:
            raise ValueError('Limit cannot be <= 0')

        self._limit = limit
        if loop is None:
            self._loop = events.get_event_loop()
        else:
            self._loop = loop
        self._buffer = bytearray()
        self._eof = False    # Whether we're done.
        self._waiter = None  # A future used by _wait_for_data()
        self._exception = None
        self._transport = None
        self._paused = False  # Whether we asked the transport to pause.

    def __repr__(self):
        info = ['StreamReader']
        if self._buffer:
            info.append(f'{len(self._buffer)} bytes')
        if self._eof:
            info.append('eof')
        if self._limit != _DEFAULT_LIMIT:
            info.append(f'limit={self._limit}')
        if self._waiter:
            info.append(f'waiter={self._waiter!r}')
        if self._exception:
            info.append(f'exception={self._exception!r}')
        if self._transport:
            info.append(f'transport={self._transport!r}')
        if self._paused:
            info.append('paused')
        return '<{}>'.format(' '.join(info))

    def exception(self):
        """Return the exception set by set_exception(), or None."""
        return self._exception

    def set_exception(self, exc):
        """Record *exc* and fail the pending read waiter, if any.

        Later read(), readuntil() (and so readline()) and readexactly()
        calls raise *exc*.
        """
        self._exception = exc

        waiter = self._waiter
        if waiter is not None:
            self._waiter = None
            if not waiter.cancelled():
                waiter.set_exception(exc)

    def _wakeup_waiter(self):
        """Wakeup read*() functions waiting for data or EOF."""
        waiter = self._waiter
        if waiter is not None:
            self._waiter = None
            if not waiter.cancelled():
                waiter.set_result(None)

    def set_transport(self, transport):
        """Attach the transport used for pause/resume flow control."""
        assert self._transport is None, 'Transport already set'
        self._transport = transport

    def _maybe_resume_transport(self):
        """Resume reading if paused and at most ``limit`` bytes remain."""
        if self._paused and len(self._buffer) <= self._limit:
            self._paused = False
            self._transport.resume_reading()

    def feed_eof(self):
        """Mark the end of the stream and wake up a waiting reader."""
        self._eof = True
        self._wakeup_waiter()

    def at_eof(self):
        """Return True if the buffer is empty and 'feed_eof' was called."""
        return self._eof and not self._buffer

    def feed_data(self, data):
        """Append *data* to the buffer and wake up a waiting reader.

        Empty *data* is ignored.  If the buffer then holds more than
        ``2 * limit`` bytes, ask the transport to pause reading.
        """
        assert not self._eof, 'feed_data after feed_eof'

        if not data:
            return

        self._buffer.extend(data)
        self._wakeup_waiter()

        if (self._transport is not None and
                not self._paused and
                len(self._buffer) > 2 * self._limit):
            try:
                self._transport.pause_reading()
            except NotImplementedError:
                # The transport can't be paused.
                # We'll just have to buffer all data.
                # Forget the transport so we don't keep trying.
                self._transport = None
            else:
                self._paused = True

    async def _wait_for_data(self, func_name):
        """Wait until feed_data() or feed_eof() is called.

        If stream was paused, automatically resume it.
        """
        # StreamReader uses a future to link the protocol feed_data() method
        # to a read coroutine. Running two read coroutines at the same time
        # would have an unexpected behaviour. It would not be possible to
        # know which coroutine would get the next data.
        if self._waiter is not None:
            raise RuntimeError(
                f'{func_name}() called while another coroutine is '
                f'already waiting for incoming data')

        assert not self._eof, '_wait_for_data after EOF'

        # Waiting for data while paused will make deadlock, so prevent it.
        # This is essential for readexactly(n) for case when n > self._limit.
        if self._paused:
            self._paused = False
            self._transport.resume_reading()

        self._waiter = self._loop.create_future()
        try:
            await self._waiter
        finally:
            self._waiter = None

    async def readline(self):
        """Read chunk of data from the stream until newline (b'\n') is found.

        On success, return chunk that ends with newline. If only partial
        line can be read due to EOF, return incomplete line without
        terminating newline. When EOF was reached while no bytes read, empty
        bytes object is returned.

        If limit is reached, ValueError will be raised. In that case, if
        newline was found, complete line including newline will be removed
        from internal buffer. Else, internal buffer will be cleared. Limit is
        compared against part of the line without newline.

        If stream was paused, this function will automatically resume it if
        needed.
        """
        sep = b'\n'
        seplen = len(sep)
        try:
            line = await self.readuntil(sep)
        except IncompleteReadError as e:
            return e.partial
        except LimitOverrunError as e:
            if self._buffer.startswith(sep, e.consumed):
                del self._buffer[:e.consumed + seplen]
            else:
                self._buffer.clear()
            self._maybe_resume_transport()
            raise ValueError(e.args[0])
        return line

    async def readuntil(self, separator=b'\n'):
        """Read data from the stream until ``separator`` is found.

        On success, the data and separator will be removed from the
        internal buffer (consumed). Returned data will include the
        separator at the end.

        Configured stream limit is used to check result. Limit sets the
        maximal length of data that can be returned, not counting the
        separator.

        If an EOF occurs and the complete separator is still not found,
        an IncompleteReadError exception will be raised, and the internal
        buffer will be reset.  The IncompleteReadError.partial attribute
        may contain the separator partially.

        If the data cannot be read because of over limit, a
        LimitOverrunError exception will be raised, and the data
        will be left in the internal buffer, so it can be read again.

        An empty ``separator`` raises ValueError.
        """
        seplen = len(separator)
        if seplen == 0:
            raise ValueError('Separator should be at least one-byte string')

        if self._exception is not None:
            raise self._exception

        # Consume whole buffer except last bytes, whose length is
        # one less than seplen. Let's check corner cases with
        # separator='SEPARATOR':
        # * we have received almost complete separator (without last
        #   byte). i.e buffer='some textSEPARATO'. In this case we
        #   can safely consume len(separator) - 1 bytes.
        # * last byte of buffer is first byte of separator, i.e.
        #   buffer='abcdefghijklmnopqrS'. We may safely consume
        #   everything except that last byte, but this requires
        #   analyzing the bytes of the buffer that match a partial
        #   separator. This is slow and/or requires an FSM. For this
        #   case our implementation is not optimal, since it requires
        #   rescanning data that is known not to belong to the
        #   separator. In the real world, separators will not be long
        #   enough to notice performance problems. Even when reading
        #   MIME-encoded messages :)

        # `offset` is the number of bytes from the beginning of the buffer
        # where there is no occurrence of `separator`.
        offset = 0

        # Loop until we find `separator` in the buffer, exceed the buffer size,
        # or an EOF has happened.
        while True:
            buflen = len(self._buffer)

            # Check if we now have enough data in the buffer for `separator` to
            # fit.
            if buflen - offset >= seplen:
                isep = self._buffer.find(separator, offset)

                if isep != -1:
                    # `separator` is in the buffer. `isep` will be used later
                    # to retrieve the data.
                    break

                # see upper comment for explanation.
                offset = buflen + 1 - seplen
                if offset > self._limit:
                    raise LimitOverrunError(
                        'Separator is not found, and chunk exceed the limit',
                        offset)

            # Complete message (with full separator) may be present in buffer
            # even when EOF flag is set. This may happen when the last chunk
            # adds data which makes separator be found. That's why we check for
            # EOF *after* inspecting the buffer.
            if self._eof:
                chunk = bytes(self._buffer)
                self._buffer.clear()
                raise IncompleteReadError(chunk, None)

            # _wait_for_data() will resume reading if stream was paused.
            await self._wait_for_data('readuntil')

        if isep > self._limit:
            raise LimitOverrunError(
                'Separator is found, but chunk is longer than limit', isep)

        chunk = self._buffer[:isep + seplen]
        del self._buffer[:isep + seplen]
        self._maybe_resume_transport()
        return bytes(chunk)

    async def read(self, n=-1):
        """Read up to `n` bytes from the stream.

        If n is not provided, or set to -1, read until EOF and return all read
        bytes. If the EOF was received and the internal buffer is empty, return
        an empty bytes object.

        If n is zero, return empty bytes object immediately.

        If n is positive, this function tries to read `n` bytes, and may
        return fewer bytes than requested, but at least one byte. If EOF was
        received before any byte is read, this function returns an empty
        bytes object.

        The returned value is not limited by the limit configured at stream
        creation.

        If stream was paused, this function will automatically resume it if
        needed.
        """

        if self._exception is not None:
            raise self._exception

        if n == 0:
            return b''

        if n < 0:
            # This used to just loop creating a new waiter hoping to
            # collect everything in self._buffer, but that would
            # deadlock if the subprocess sends more than self.limit
            # bytes.  So just call self.read(self._limit) until EOF.
            blocks = []
            while True:
                block = await self.read(self._limit)
                if not block:
                    break
                blocks.append(block)
            return b''.join(blocks)

        if not self._buffer and not self._eof:
            await self._wait_for_data('read')

        # This will work right even if buffer is less than n bytes
        data = bytes(self._buffer[:n])
        del self._buffer[:n]

        self._maybe_resume_transport()
        return data

    async def readexactly(self, n):
        """Read exactly `n` bytes.

        Raise an IncompleteReadError if EOF is reached before `n` bytes can be
        read. The IncompleteReadError.partial attribute of the exception will
        contain the partial read bytes.

        If n is zero, return empty bytes object. A negative n raises
        ValueError.

        The returned value is not limited by the limit configured at stream
        creation.

        If stream was paused, this function will automatically resume it if
        needed.
        """
        if n < 0:
            raise ValueError('readexactly size can not be less than zero')

        if self._exception is not None:
            raise self._exception

        if n == 0:
            return b''

        while len(self._buffer) < n:
            if self._eof:
                incomplete = bytes(self._buffer)
                self._buffer.clear()
                raise IncompleteReadError(incomplete, n)

            await self._wait_for_data('readexactly')

        if len(self._buffer) == n:
            data = bytes(self._buffer)
            self._buffer.clear()
        else:
            data = bytes(self._buffer[:n])
            del self._buffer[:n]
        self._maybe_resume_transport()
        return data

    def __aiter__(self):
        """Iterate over the stream line by line (see readline())."""
        return self

    async def __anext__(self):
        """Return the next line; stop when readline() returns b''."""
        val = await self.readline()
        if val == b'':
            raise StopAsyncIteration
        return val
698