1"""Stream-related things."""
2
3__all__ = ['StreamReader', 'StreamWriter', 'StreamReaderProtocol',
4           'open_connection', 'start_server',
5           'IncompleteReadError',
6           'LimitOverrunError',
7           ]
8
9import socket
10
11if hasattr(socket, 'AF_UNIX'):
12    __all__.extend(['open_unix_connection', 'start_unix_server'])
13
14from . import coroutines
15from . import compat
16from . import events
17from . import protocols
18from .coroutines import coroutine
19from .log import logger
20
21
# Default StreamReader buffer limit, in bytes (64 KiB).
_DEFAULT_LIMIT = 1 << 16
23
24
class IncompleteReadError(EOFError):
    """
    Incomplete read error. Attributes:

    - partial: read bytes string before the end of stream was reached
    - expected: total number of expected bytes (or None if unknown)
    """
    def __init__(self, partial, expected):
        super().__init__("%d bytes read on a total of %r expected bytes"
                         % (len(partial), expected))
        self.partial = partial
        self.expected = expected

    def __reduce__(self):
        # The default BaseException reduction rebuilds the instance from
        # self.args, which only holds the formatted message and therefore
        # does not match the two-argument constructor.  Hand back the real
        # constructor arguments so copy and pickle round-trip correctly.
        return type(self), (self.partial, self.expected)
37
38
class LimitOverrunError(Exception):
    """Reached the buffer limit while looking for a separator.

    Attributes:
    - consumed: total number of to be consumed bytes.
    """
    def __init__(self, message, consumed):
        super().__init__(message)
        self.consumed = consumed

    def __reduce__(self):
        # self.args only carries the message; the constructor also needs
        # `consumed`, so the default reduction cannot rebuild the instance.
        # Supply both arguments so copy and pickle round-trip correctly.
        return type(self), (self.args[0], self.consumed)
48
49
@coroutine
def open_connection(host=None, port=None, *,
                    loop=None, limit=_DEFAULT_LIMIT, **kwds):
    """Connect via loop.create_connection() and return (reader, writer).

    The reader is a StreamReader instance and the writer is a
    StreamWriter instance wrapping the transport that was created.

    host, port and every extra keyword argument are forwarded to
    create_connection(); the protocol factory is supplied here, so it
    cannot be passed by the caller.

    Two keyword-only arguments are handled locally: loop (the event
    loop to use, events.get_event_loop() when omitted) and limit (the
    buffer limit given to the StreamReader).

    (To customize the StreamReader and/or StreamReaderProtocol classes,
    copy this code -- it is only a convenience.)
    """
    loop = events.get_event_loop() if loop is None else loop
    stream_reader = StreamReader(limit=limit, loop=loop)
    stream_protocol = StreamReaderProtocol(stream_reader, loop=loop)

    def protocol_factory():
        # Always hand back the single protocol created above.
        return stream_protocol

    transport, _ = yield from loop.create_connection(
        protocol_factory, host, port, **kwds)
    stream_writer = StreamWriter(transport, stream_protocol,
                                 stream_reader, loop)
    return stream_reader, stream_writer
78
79
@coroutine
def start_server(client_connected_cb, host=None, port=None, *,
                 loop=None, limit=_DEFAULT_LIMIT, **kwds):
    """Start a socket server, call back for each client connected.

    `client_connected_cb` is called with two arguments for every new
    connection: client_reader (a StreamReader) and client_writer (a
    StreamWriter).  It may be a plain callback function or a coroutine;
    a coroutine is automatically converted into a Task.

    host, port and every extra keyword argument are forwarded to
    loop.create_server(); the protocol factory is supplied here, so it
    cannot be passed by the caller.

    Two keyword-only arguments are handled locally: loop (the event
    loop to use, events.get_event_loop() when omitted) and limit (the
    buffer limit given to each StreamReader).

    The return value is the same as loop.create_server(), i.e. a
    Server object which can be used to stop the service.
    """
    loop = events.get_event_loop() if loop is None else loop

    def make_protocol():
        # A fresh reader/protocol pair is built for every connection.
        return StreamReaderProtocol(
            StreamReader(limit=limit, loop=loop),
            client_connected_cb,
            loop=loop)

    server = yield from loop.create_server(make_protocol, host, port, **kwds)
    return server
114
115
116if hasattr(socket, 'AF_UNIX'):
117    # UNIX Domain Sockets are supported on this platform
118
119    @coroutine
120    def open_unix_connection(path=None, *,
121                             loop=None, limit=_DEFAULT_LIMIT, **kwds):
122        """Similar to `open_connection` but works with UNIX Domain Sockets."""
123        if loop is None:
124            loop = events.get_event_loop()
125        reader = StreamReader(limit=limit, loop=loop)
126        protocol = StreamReaderProtocol(reader, loop=loop)
127        transport, _ = yield from loop.create_unix_connection(
128            lambda: protocol, path, **kwds)
129        writer = StreamWriter(transport, protocol, reader, loop)
130        return reader, writer
131
132    @coroutine
133    def start_unix_server(client_connected_cb, path=None, *,
134                          loop=None, limit=_DEFAULT_LIMIT, **kwds):
135        """Similar to `start_server` but works with UNIX Domain Sockets."""
136        if loop is None:
137            loop = events.get_event_loop()
138
139        def factory():
140            reader = StreamReader(limit=limit, loop=loop)
141            protocol = StreamReaderProtocol(reader, client_connected_cb,
142                                            loop=loop)
143            return protocol
144
145        return (yield from loop.create_unix_server(factory, path, **kwds))
146
147
class FlowControlMixin(protocols.Protocol):
    """Reusable flow control logic for StreamWriter.drain().

    This implements the protocol methods pause_writing(),
    resume_writing() and connection_lost().  If the subclass overrides
    these it must call the super methods.

    StreamWriter.drain() must wait for _drain_helper() coroutine.
    """

    def __init__(self, loop=None):
        self._loop = events.get_event_loop() if loop is None else loop
        self._paused = False            # True between pause/resume_writing()
        self._drain_waiter = None       # future awaited by _drain_helper()
        self._connection_lost = False   # set once connection_lost() ran

    def pause_writing(self):
        """Record that writing is paused; drain() will now block."""
        assert not self._paused
        self._paused = True
        if self._loop.get_debug():
            logger.debug("%r pauses writing", self)

    def resume_writing(self):
        """Record that writing resumed and release a pending drain()."""
        assert self._paused
        self._paused = False
        if self._loop.get_debug():
            logger.debug("%r resumes writing", self)

        pending = self._drain_waiter
        if pending is None:
            return
        self._drain_waiter = None
        if not pending.done():
            pending.set_result(None)

    def connection_lost(self, exc):
        """Mark the connection as lost and wake a paused drain().

        The pending waiter, if any, gets `exc` as its exception, or a
        plain None result when `exc` is None.
        """
        self._connection_lost = True
        # Only a paused writer can have somebody waiting in drain().
        if self._paused:
            pending = self._drain_waiter
            if pending is not None:
                self._drain_waiter = None
                if not pending.done():
                    if exc is None:
                        pending.set_result(None)
                    else:
                        pending.set_exception(exc)

    @coroutine
    def _drain_helper(self):
        """Wait until writing is resumed (no-op when not paused).

        Raises ConnectionResetError if the connection was already lost.
        """
        if self._connection_lost:
            raise ConnectionResetError('Connection lost')
        if self._paused:
            previous = self._drain_waiter
            assert previous is None or previous.cancelled()
            fresh = self._loop.create_future()
            self._drain_waiter = fresh
            yield from fresh
212
213
class StreamReaderProtocol(FlowControlMixin, protocols.Protocol):
    """Protocol that forwards transport events to a StreamReader.

    (The adapter is kept separate from StreamReader, rather than making
    StreamReader itself a Protocol subclass, because the StreamReader
    has other potential uses, and to keep users of the StreamReader
    from accidentally calling inappropriate protocol methods.)
    """

    def __init__(self, stream_reader, client_connected_cb=None, loop=None):
        super().__init__(loop=loop)
        self._stream_reader = stream_reader
        self._stream_writer = None
        self._client_connected_cb = client_connected_cb
        self._over_ssl = False

    def connection_made(self, transport):
        """Attach the transport and, if set, invoke the client callback."""
        reader = self._stream_reader
        reader.set_transport(transport)
        self._over_ssl = transport.get_extra_info('sslcontext') is not None

        callback = self._client_connected_cb
        if callback is None:
            return
        writer = StreamWriter(transport, self, reader, self._loop)
        self._stream_writer = writer
        outcome = callback(reader, writer)
        # A coroutine returned by the callback is scheduled as a task.
        if coroutines.iscoroutine(outcome):
            self._loop.create_task(outcome)

    def connection_lost(self, exc):
        """Signal EOF or the error to the reader, then drop references."""
        reader = self._stream_reader
        if reader is not None:
            if exc is None:
                reader.feed_eof()
            else:
                reader.set_exception(exc)
        super().connection_lost(exc)
        self._stream_reader = None
        self._stream_writer = None

    def data_received(self, data):
        """Append incoming data to the reader's buffer."""
        self._stream_reader.feed_data(data)

    def eof_received(self):
        """Mark EOF on the reader; return False over SSL, True otherwise."""
        self._stream_reader.feed_eof()
        # Returning False over SSL prevents a warning in
        # SSLProtocol.eof_received:
        # "returning true from eof_received()
        # has no effect when using ssl"
        return not self._over_ssl
263
264
class StreamWriter:
    """Wraps a Transport.

    write(), writelines(), write_eof(), can_write_eof(),
    get_extra_info() and close() are forwarded to the transport.  On
    top of that there is drain(), a coroutine to wait on for flow
    control, and a transport property which references the Transport
    directly.
    """

    def __init__(self, transport, protocol, reader, loop):
        # drain() expects that the reader has an exception() method
        assert reader is None or isinstance(reader, StreamReader)
        self._transport = transport
        self._protocol = protocol
        self._reader = reader
        self._loop = loop

    def __repr__(self):
        parts = [self.__class__.__name__, 'transport=%r' % self._transport]
        reader = self._reader
        if reader is not None:
            parts.append('reader=%r' % reader)
        return '<%s>' % ' '.join(parts)

    @property
    def transport(self):
        """The wrapped transport."""
        return self._transport

    def write(self, data):
        """Forward `data` to the transport's write()."""
        self._transport.write(data)

    def writelines(self, data):
        """Forward `data` to the transport's writelines()."""
        self._transport.writelines(data)

    def write_eof(self):
        """Forward to the transport's write_eof() and return its result."""
        return self._transport.write_eof()

    def can_write_eof(self):
        """Return the transport's can_write_eof() result."""
        return self._transport.can_write_eof()

    def close(self):
        """Forward to the transport's close() and return its result."""
        return self._transport.close()

    def get_extra_info(self, name, default=None):
        """Return the transport's get_extra_info(name, default) result."""
        return self._transport.get_extra_info(name, default)

    @coroutine
    def drain(self):
        """Flush the write buffer.

        The intended use is to write

          w.write(data)
          yield from w.drain()

        An exception recorded on the associated reader is re-raised
        here before any waiting happens.
        """
        reader = self._reader
        if reader is not None:
            error = reader.exception()
            if error is not None:
                raise error
        transport = self._transport
        if transport is not None and transport.is_closing():
            # Yield to the event loop so connection_lost() may be
            # called.  Without this, _drain_helper() would return
            # immediately, and code that calls
            #     write(...); yield from drain()
            # in a loop would never call connection_lost(), so it
            # would not see an error when the socket is closed.
            yield
        yield from self._protocol._drain_helper()
334
335
class StreamReader:
    """Byte buffer filled by feed_data()/feed_eof() and drained by read*().

    The read(), readline(), readuntil() and readexactly() coroutines
    consume the buffer and wait for more data when needed.  When a
    transport is attached, its reading is paused once the buffer grows
    beyond twice the limit and resumed as the buffer is consumed.
    """

    def __init__(self, limit=_DEFAULT_LIMIT, loop=None):
        # The line length limit is a security feature;
        # it also doubles as half the buffer limit.
        if limit <= 0:
            raise ValueError('Limit cannot be <= 0')

        self._limit = limit
        self._loop = events.get_event_loop() if loop is None else loop
        self._buffer = bytearray()
        self._eof = False    # Whether we're done.
        self._waiter = None  # A future used by _wait_for_data()
        self._exception = None
        self._transport = None
        self._paused = False

    def __repr__(self):
        parts = ['StreamReader']
        add = parts.append
        if self._buffer:
            add('%d bytes' % len(self._buffer))
        if self._eof:
            add('eof')
        if self._limit != _DEFAULT_LIMIT:
            add('l=%d' % self._limit)
        if self._waiter:
            add('w=%r' % self._waiter)
        if self._exception:
            add('e=%r' % self._exception)
        if self._transport:
            add('t=%r' % self._transport)
        if self._paused:
            add('paused')
        return '<%s>' % ' '.join(parts)

    def exception(self):
        """Return the exception stored by set_exception(), or None."""
        return self._exception

    def set_exception(self, exc):
        """Store `exc` and pass it to a waiting read*() coroutine, if any."""
        self._exception = exc

        pending = self._waiter
        if pending is None:
            return
        self._waiter = None
        if not pending.cancelled():
            pending.set_exception(exc)

    def _wakeup_waiter(self):
        """Wakeup read*() functions waiting for data or EOF."""
        pending = self._waiter
        if pending is None:
            return
        self._waiter = None
        if not pending.cancelled():
            pending.set_result(None)

    def set_transport(self, transport):
        """Attach the transport; it may be set only once."""
        assert self._transport is None, 'Transport already set'
        self._transport = transport

    def _maybe_resume_transport(self):
        """Resume a paused transport once the buffer is within the limit."""
        if not self._paused:
            return
        if len(self._buffer) <= self._limit:
            self._paused = False
            self._transport.resume_reading()

    def feed_eof(self):
        """Mark the end of the stream and wake a waiting reader."""
        self._eof = True
        self._wakeup_waiter()

    def at_eof(self):
        """Return True if the buffer is empty and 'feed_eof' was called."""
        return self._eof and not self._buffer

    def feed_data(self, data):
        """Append `data` to the buffer and wake a waiting reader.

        Empty data is ignored.  With a transport attached, reading is
        paused as soon as the buffer exceeds twice the limit.
        """
        assert not self._eof, 'feed_data after feed_eof'

        if not data:
            return

        self._buffer.extend(data)
        self._wakeup_waiter()

        if self._transport is None or self._paused:
            return
        if len(self._buffer) <= 2 * self._limit:
            return

        try:
            self._transport.pause_reading()
        except NotImplementedError:
            # The transport can't be paused.
            # We'll just have to buffer all data.
            # Forget the transport so we don't keep trying.
            self._transport = None
        else:
            self._paused = True

    @coroutine
    def _wait_for_data(self, func_name):
        """Wait until feed_data() or feed_eof() is called.

        If stream was paused, automatically resume it.
        """
        # A single future links the protocol's feed_data() to the read
        # coroutine.  Two read coroutines running at the same time would
        # behave unexpectedly: it would not be possible to know which of
        # them gets the next data.
        if self._waiter is not None:
            raise RuntimeError('%s() called while another coroutine is '
                               'already waiting for incoming data' % func_name)

        assert not self._eof, '_wait_for_data after EOF'

        # Waiting for data while paused would deadlock, so resume first.
        # This is essential for readexactly(n) when n > self._limit.
        if self._paused:
            self._paused = False
            self._transport.resume_reading()

        fresh = self._loop.create_future()
        self._waiter = fresh
        try:
            yield from fresh
        finally:
            self._waiter = None

    @coroutine
    def readline(self):
        """Read chunk of data from the stream until a newline is found.

        On success, return chunk that ends with newline. If only partial
        line can be read due to EOF, return incomplete line without
        terminating newline. When EOF was reached while no bytes read, empty
        bytes object is returned.

        If limit is reached, ValueError will be raised. In that case, if
        newline was found, complete line including newline will be removed
        from internal buffer. Else, internal buffer will be cleared. Limit is
        compared against part of the line without newline.

        If stream was paused, this function will automatically resume it if
        needed.
        """
        newline = b'\n'
        try:
            line = yield from self.readuntil(newline)
        except IncompleteReadError as incomplete:
            return incomplete.partial
        except LimitOverrunError as overrun:
            consumed = overrun.consumed
            if self._buffer.startswith(newline, consumed):
                # The newline sits right after the over-long part: drop
                # the whole line, newline included.
                del self._buffer[:consumed + len(newline)]
            else:
                self._buffer.clear()
            self._maybe_resume_transport()
            raise ValueError(overrun.args[0])
        return line

    @coroutine
    def readuntil(self, separator=b'\n'):
        """Read data from the stream until ``separator`` is found.

        On success, the data and separator will be removed from the
        internal buffer (consumed). Returned data will include the
        separator at the end.

        Configured stream limit is used to check result. Limit sets the
        maximal length of data that can be returned, not counting the
        separator.

        If an EOF occurs and the complete separator is still not found,
        an IncompleteReadError exception will be raised, and the internal
        buffer will be reset.  The IncompleteReadError.partial attribute
        may contain the separator partially.

        If the data cannot be read because of over limit, a
        LimitOverrunError exception will be raised, and the data
        will be left in the internal buffer, so it can be read again.
        """
        sep_size = len(separator)
        if not sep_size:
            raise ValueError('Separator should be at least one-byte string')

        error = self._exception
        if error is not None:
            raise error

        # After an unsuccessful search everything but the last
        # (sep_size - 1) bytes of the buffer is known not to start a
        # separator, so the next search can begin there.  Corner cases
        # with separator='SEPARATOR':
        # * buffer='some textSEPARATO' (only the last byte is missing):
        #   the trailing len(separator) - 1 bytes must be searched again.
        # * buffer='abcdefghijklmnopqrS' (only the first byte arrived):
        #   everything before that byte could be skipped, but finding
        #   that out means matching partial separators, which is slow
        #   and/or requires an FSM.  So some bytes known not to belong
        #   to the separator get rescanned; real-world separators are
        #   short enough that this does not matter.

        # `scanned` is the number of bytes from the beginning of the
        # buffer in which `separator` is known not to start.
        scanned = 0

        # Loop until `separator` is found, the limit is exceeded, or an
        # EOF has happened.
        while True:
            available = len(self._buffer)

            # Search only when enough unscanned data is buffered for
            # `separator` to fit.
            if available - scanned >= sep_size:
                sep_index = self._buffer.find(separator, scanned)
                if sep_index >= 0:
                    # Found; `sep_index` is used below to cut the data.
                    break

                # See the explanation above.
                scanned = available + 1 - sep_size
                if scanned > self._limit:
                    raise LimitOverrunError(
                        'Separator is not found, and chunk exceed the limit',
                        scanned)

            # A complete message (with full separator) may be present in
            # the buffer even when the EOF flag is set, if the last chunk
            # completed the separator.  That is why EOF is checked
            # *after* inspecting the buffer.
            if self._eof:
                leftover = bytes(self._buffer)
                self._buffer.clear()
                raise IncompleteReadError(leftover, None)

            # _wait_for_data() will resume reading if stream was paused.
            yield from self._wait_for_data('readuntil')

        if sep_index > self._limit:
            raise LimitOverrunError(
                'Separator is found, but chunk is longer than limit',
                sep_index)

        end = sep_index + sep_size
        result = bytes(self._buffer[:end])
        del self._buffer[:end]
        self._maybe_resume_transport()
        return result

    @coroutine
    def read(self, n=-1):
        """Read up to `n` bytes from the stream.

        If n is not provided, or set to -1, read until EOF and return all read
        bytes. If the EOF was received and the internal buffer is empty, return
        an empty bytes object.

        If n is zero, return empty bytes object immediately.

        If n is positive, this function tries to read `n` bytes, and may
        return fewer bytes than requested, but at least one byte. If EOF was
        received before any byte is read, this function returns empty bytes
        object.

        Returned value is not limited with limit, configured at stream
        creation.

        If stream was paused, this function will automatically resume it if
        needed.
        """
        if self._exception is not None:
            raise self._exception

        if n == 0:
            return b''

        if n < 0:
            # This used to just loop creating a new waiter hoping to
            # collect everything in self._buffer, but that would
            # deadlock if the subprocess sends more than self._limit
            # bytes.  So just call self.read(self._limit) until EOF.
            pieces = []
            piece = yield from self.read(self._limit)
            while piece:
                pieces.append(piece)
                piece = yield from self.read(self._limit)
            return b''.join(pieces)

        if not (self._buffer or self._eof):
            yield from self._wait_for_data('read')

        # Slicing works right even if the buffer holds less than n bytes.
        result = bytes(self._buffer[:n])
        del self._buffer[:n]

        self._maybe_resume_transport()
        return result

    @coroutine
    def readexactly(self, n):
        """Read exactly `n` bytes.

        Raise an IncompleteReadError if EOF is reached before `n` bytes can be
        read. The IncompleteReadError.partial attribute of the exception will
        contain the partial read bytes.

        If n is zero, return empty bytes object.

        Returned value is not limited with limit, configured at stream
        creation.

        If stream was paused, this function will automatically resume it if
        needed.
        """
        if n < 0:
            raise ValueError('readexactly size can not be less than zero')

        if self._exception is not None:
            raise self._exception

        if n == 0:
            return b''

        while len(self._buffer) < n:
            if self._eof:
                partial = bytes(self._buffer)
                self._buffer.clear()
                raise IncompleteReadError(partial, n)

            yield from self._wait_for_data('readexactly')

        if len(self._buffer) > n:
            result = bytes(self._buffer[:n])
            del self._buffer[:n]
        else:
            # The buffer holds exactly n bytes: take all of it.
            result = bytes(self._buffer)
            self._buffer.clear()
        self._maybe_resume_transport()
        return result

    if compat.PY35:
        @coroutine
        def __aiter__(self):
            return self

        @coroutine
        def __anext__(self):
            # Iteration yields lines and stops at the first empty read.
            line = yield from self.readline()
            if line == b'':
                raise StopAsyncIteration
            return line

    if compat.PY352:
        # In Python 3.5.2 and greater, __aiter__ should return
        # the asynchronous iterator directly.
        def __aiter__(self):
            return self
696