1import collections
2import warnings
3try:
4    import ssl
5except ImportError:  # pragma: no cover
6    ssl = None
7
8from . import base_events
9from . import constants
10from . import protocols
11from . import transports
12from .log import logger
13
14
def _create_transport_context(server_side, server_hostname):
    """Build the default SSL context used when a client passes ``ssl=True``.

    A server must always provide its own SSLContext (it needs a
    certificate), so server-side use raises ValueError.  On the client
    side the returned context comes from ssl.create_default_context();
    hostname checking is switched off when no *server_hostname* is given.
    """
    if not server_side:
        # Client side may pass ssl=True to use a default context; in that
        # case the sslcontext passed is None.  create_default_context()
        # gives secure, up-to-date settings for client connections.
        default_ctx = ssl.create_default_context()
        if not server_hostname:
            default_ctx.check_hostname = False
        return default_ctx
    raise ValueError('Server side SSL needs a valid SSLContext')
27
28
# States of an _SSLPipe.  Transitions:
#   UNWRAPPED --do_handshake()--> DO_HANDSHAKE --(handshake done)--> WRAPPED
#   WRAPPED/DO_HANDSHAKE --shutdown()--> SHUTDOWN --(unwrap done)--> UNWRAPPED
_UNWRAPPED = "UNWRAPPED"
_DO_HANDSHAKE = "DO_HANDSHAKE"
_WRAPPED = "WRAPPED"
_SHUTDOWN = "SHUTDOWN"


class _SSLPipe(object):
    """An SSL "Pipe".

    An SSL pipe allows you to communicate with an SSL/TLS protocol instance
    through memory buffers. It can be used to implement a security layer for an
    existing connection where you don't have access to the connection's file
    descriptor, or for some reason you don't want to use it.

    An SSL pipe can be in "wrapped" and "unwrapped" mode. In unwrapped mode,
    data is passed through untransformed. In wrapped mode, application level
    data is encrypted to SSL record level data and vice versa. The SSL record
    level is the lowest level in the SSL protocol suite and is what travels
    as-is over the wire.

    An _SSLPipe initially is in "unwrapped" mode. To start SSL, call
    do_handshake(). To shut SSL down again, call shutdown().
    """

    max_size = 256 * 1024   # Buffer size passed to read()

    def __init__(self, context, server_side, server_hostname=None):
        """
        The *context* argument specifies the ssl.SSLContext to use.

        The *server_side* argument indicates whether this is a server side or
        client side transport.

        The optional *server_hostname* argument can be used to specify the
        hostname you are connecting to. You may only specify this parameter if
        the _ssl module supports Server Name Indication (SNI).
        """
        self._context = context
        self._server_side = server_side
        self._server_hostname = server_hostname
        self._state = _UNWRAPPED
        # Record-level data received from the peer / to be sent to the peer.
        self._incoming = ssl.MemoryBIO()
        self._outgoing = ssl.MemoryBIO()
        # ssl.SSLObject; created by do_handshake(), cleared when unwrapped.
        self._sslobj = None
        self._need_ssldata = False
        self._handshake_cb = None
        self._shutdown_cb = None

    @property
    def context(self):
        """The SSL context passed to the constructor."""
        return self._context

    @property
    def ssl_object(self):
        """The internal ssl.SSLObject instance.

        Return None if the pipe is not wrapped.
        """
        return self._sslobj

    @property
    def need_ssldata(self):
        """Whether more record level data is needed to complete a handshake
        that is currently in progress."""
        return self._need_ssldata

    @property
    def wrapped(self):
        """
        Whether a security layer is currently in effect.

        Return False during handshake.
        """
        return self._state == _WRAPPED

    def do_handshake(self, callback=None):
        """Start the SSL handshake.

        Return a list of ssldata buffers (record level data that must be
        sent to the remote SSL instance).

        The optional *callback* argument can be used to install a callback that
        will be called when the handshake is complete. The callback will be
        called with None if successful, else an exception instance.

        Raise RuntimeError if a handshake was already started.
        """
        if self._state != _UNWRAPPED:
            raise RuntimeError('handshake in progress or completed')
        self._sslobj = self._context.wrap_bio(
            self._incoming, self._outgoing,
            server_side=self._server_side,
            server_hostname=self._server_hostname)
        self._state = _DO_HANDSHAKE
        self._handshake_cb = callback
        ssldata, appdata = self.feed_ssldata(b'', only_handshake=True)
        assert len(appdata) == 0
        return ssldata

    def shutdown(self, callback=None):
        """Start the SSL shutdown sequence.

        Return a list of ssldata buffers (record level data that must be
        sent to the remote SSL instance).

        The optional *callback* argument can be used to install a callback that
        will be called when the shutdown is complete. The callback will be
        called without arguments.

        Raise RuntimeError if no security layer is present or a shutdown is
        already in progress.
        """
        if self._state == _UNWRAPPED:
            raise RuntimeError('no security layer present')
        if self._state == _SHUTDOWN:
            raise RuntimeError('shutdown in progress')
        assert self._state in (_WRAPPED, _DO_HANDSHAKE)
        self._state = _SHUTDOWN
        self._shutdown_cb = callback
        ssldata, appdata = self.feed_ssldata(b'')
        assert appdata == [] or appdata == [b'']
        return ssldata

    def feed_eof(self):
        """Send a potentially "ragged" EOF.

        This method will raise an SSL_ERROR_EOF exception if the EOF is
        unexpected.
        """
        self._incoming.write_eof()
        _, appdata = self.feed_ssldata(b'')
        assert appdata == [] or appdata == [b'']

    def feed_ssldata(self, data, only_handshake=False):
        """Feed SSL record level data into the pipe.

        The data must be a bytes instance. It is OK to send an empty bytes
        instance. This can be used to get ssldata for a handshake initiated by
        this endpoint.

        If *only_handshake* is true and this call completes the handshake,
        no application data is read; any record level data the handshake
        produced is still returned.

        Return a (ssldata, appdata) tuple. The ssldata element is a list of
        buffers containing SSL data that needs to be sent to the remote SSL.

        The appdata element is a list of buffers containing plaintext data that
        needs to be forwarded to the application. The appdata list may contain
        an empty buffer indicating an SSL "close_notify" alert. This alert must
        be acknowledged by calling shutdown().
        """
        if self._state == _UNWRAPPED:
            # If unwrapped, pass plaintext data straight through.
            if data:
                appdata = [data]
            else:
                appdata = []
            return ([], appdata)

        self._need_ssldata = False
        if data:
            self._incoming.write(data)

        ssldata = []
        appdata = []
        try:
            if self._state == _DO_HANDSHAKE:
                # Call do_handshake() until it doesn't raise anymore.
                self._sslobj.do_handshake()
                self._state = _WRAPPED
                if self._handshake_cb:
                    self._handshake_cb(None)
                if only_handshake:
                    # Skip reading application data, but do not leave
                    # records produced by the completed handshake stuck
                    # in the outgoing BIO.
                    if self._outgoing.pending:
                        ssldata.append(self._outgoing.read())
                    return (ssldata, appdata)
                # Handshake done: execute the wrapped block

            if self._state == _WRAPPED:
                # Main state: read data from SSL until close_notify
                while True:
                    chunk = self._sslobj.read(self.max_size)
                    appdata.append(chunk)
                    if not chunk:  # close_notify
                        break

            elif self._state == _SHUTDOWN:
                # Call shutdown() until it doesn't raise anymore.
                self._sslobj.unwrap()
                self._sslobj = None
                self._state = _UNWRAPPED
                if self._shutdown_cb:
                    self._shutdown_cb()

            elif self._state == _UNWRAPPED:
                # Drain possible plaintext data after close_notify.
                appdata.append(self._incoming.read())
        except (ssl.SSLError, ssl.CertificateError) as exc:
            exc_errno = getattr(exc, 'errno', None)
            if exc_errno not in (
                    ssl.SSL_ERROR_WANT_READ, ssl.SSL_ERROR_WANT_WRITE,
                    ssl.SSL_ERROR_SYSCALL):
                # Real failure: report it to the handshake callback (if a
                # handshake is running) and propagate to the caller.
                if self._state == _DO_HANDSHAKE and self._handshake_cb:
                    self._handshake_cb(exc)
                raise
            # "Want" conditions are not errors: remember whether the SSL
            # layer is waiting for more incoming record level data.
            self._need_ssldata = (exc_errno == ssl.SSL_ERROR_WANT_READ)

        # Check for record level data that needs to be sent back.
        # Happens for the initial handshake and renegotiations.
        if self._outgoing.pending:
            ssldata.append(self._outgoing.read())
        return (ssldata, appdata)

    def feed_appdata(self, data, offset=0):
        """Feed plaintext data into the pipe.

        Return an (ssldata, offset) tuple. The ssldata element is a list of
        buffers containing record level data that needs to be sent to the
        remote SSL instance. The offset is the number of plaintext bytes that
        were processed, which may be less than the length of data.

        NOTE: In case of short writes, this call MUST be retried with the SAME
        buffer passed into the *data* argument (i.e. the id() must be the
        same). This is an OpenSSL requirement. A further particularity is that
        a short write will always have offset == 0, because the _ssl module
        does not enable partial writes. And even though the offset is zero,
        there will still be encrypted data in ssldata.
        """
        assert 0 <= offset <= len(data)
        if self._state == _UNWRAPPED:
            # pass through data in unwrapped mode
            if offset < len(data):
                ssldata = [data[offset:]]
            else:
                ssldata = []
            return (ssldata, len(data))

        ssldata = []
        view = memoryview(data)
        while True:
            self._need_ssldata = False
            try:
                if offset < len(view):
                    offset += self._sslobj.write(view[offset:])
            except ssl.SSLError as exc:
                # It is not allowed to call write() after unwrap() until the
                # close_notify is acknowledged. We return the condition to the
                # caller as a short write.
                exc_errno = getattr(exc, 'errno', None)
                if exc.reason == 'PROTOCOL_IS_SHUTDOWN':
                    exc_errno = exc.errno = ssl.SSL_ERROR_WANT_READ
                if exc_errno not in (ssl.SSL_ERROR_WANT_READ,
                                     ssl.SSL_ERROR_WANT_WRITE,
                                     ssl.SSL_ERROR_SYSCALL):
                    raise
                self._need_ssldata = (exc_errno == ssl.SSL_ERROR_WANT_READ)

            # See if there's any record level data back for us.
            if self._outgoing.pending:
                ssldata.append(self._outgoing.read())
            if offset == len(view) or self._need_ssldata:
                break
        return (ssldata, offset)
282
283
class _SSLProtocolTransport(transports._FlowControlMixin,
                            transports.Transport):
    """Transport handed to the application protocol by SSLProtocol.

    Writes are routed through the owning SSLProtocol (which encrypts them);
    reading and flow-control calls are delegated to the SSLProtocol's
    underlying low-level transport.
    """

    _sendfile_compatible = constants._SendfileMode.FALLBACK

    def __init__(self, loop, ssl_protocol):
        # The base-class initializers are not called: flow-control state is
        # delegated to the underlying transport (see _protocol_paused and
        # set_write_buffer_limits below).
        self._loop = loop
        # SSLProtocol instance
        self._ssl_protocol = ssl_protocol
        self._closed = False

    def get_extra_info(self, name, default=None):
        """Get optional transport information."""
        return self._ssl_protocol._get_extra_info(name, default)

    def set_protocol(self, protocol):
        self._ssl_protocol._set_app_protocol(protocol)

    def get_protocol(self):
        return self._ssl_protocol._app_protocol

    def is_closing(self):
        return self._closed

    def close(self):
        """Close the transport.

        Buffered data will be flushed asynchronously.  No more data
        will be received.  After all buffered data is flushed, the
        protocol's connection_lost() method will (eventually) be called
        with None as its argument.  Calling close() again is a no-op.
        """
        if not self._closed:
            self._closed = True
            self._ssl_protocol._start_shutdown()

    def __del__(self, _warn=warnings.warn):
        # warnings.warn is bound as a default argument so that it is still
        # reachable if this runs while module globals are being torn down
        # at interpreter shutdown.
        if not self._closed:
            _warn(f"unclosed transport {self!r}", ResourceWarning,
                  source=self)
            self.close()

    def is_reading(self):
        tr = self._ssl_protocol._transport
        if tr is None:
            raise RuntimeError('SSL transport has not been initialized yet')
        return tr.is_reading()

    def pause_reading(self):
        """Pause the receiving end.

        No data will be passed to the protocol's data_received()
        method until resume_reading() is called.
        """
        self._ssl_protocol._transport.pause_reading()

    def resume_reading(self):
        """Resume the receiving end.

        Data received will once again be passed to the protocol's
        data_received() method.
        """
        self._ssl_protocol._transport.resume_reading()

    def set_write_buffer_limits(self, high=None, low=None):
        """Set the high- and low-water limits for write flow control.

        These two values control when to call the protocol's
        pause_writing() and resume_writing() methods.  If specified,
        the low-water limit must be less than or equal to the
        high-water limit.  Neither value can be negative.

        The defaults are implementation-specific.  If only the
        high-water limit is given, the low-water limit defaults to an
        implementation-specific value less than or equal to the
        high-water limit.  Setting high to zero forces low to zero as
        well, and causes pause_writing() to be called whenever the
        buffer becomes non-empty.  Setting low to zero causes
        resume_writing() to be called only once the buffer is empty.
        Use of zero for either limit is generally sub-optimal as it
        reduces opportunities for doing I/O and computation
        concurrently.
        """
        self._ssl_protocol._transport.set_write_buffer_limits(high, low)

    def get_write_buffer_size(self):
        """Return the current size of the write buffer."""
        return self._ssl_protocol._transport.get_write_buffer_size()

    @property
    def _protocol_paused(self):
        # Required for sendfile fallback pause_writing/resume_writing logic
        return self._ssl_protocol._transport._protocol_paused

    def write(self, data):
        """Write some data bytes to the transport.

        This does not block; it buffers the data and arranges for it
        to be sent out asynchronously.

        Raise TypeError if *data* is not bytes, bytearray or memoryview.
        Empty data is ignored.
        """
        if not isinstance(data, (bytes, bytearray, memoryview)):
            raise TypeError(f"data: expecting a bytes-like instance, "
                            f"got {type(data).__name__}")
        if not data:
            return
        self._ssl_protocol._write_appdata(data)

    def can_write_eof(self):
        """Return True if this transport supports write_eof(), False if not."""
        return False

    def abort(self):
        """Close the transport immediately.

        Buffered data will be lost.  No more data will be received.
        The protocol's connection_lost() method will (eventually) be
        called with None as its argument.
        """
        # Mark closed first so that a failure inside _abort() does not
        # leave the transport looking open (and re-closed by __del__).
        self._closed = True
        self._ssl_protocol._abort()
403
404
class SSLProtocol(protocols.Protocol):
    """SSL protocol.

    Implementation of SSL on top of a socket using incoming and outgoing
    buffers which are ssl.MemoryBIO objects.

    This object is the protocol of the low-level transport; the application
    protocol talks to an _SSLProtocolTransport whose writes are routed back
    into this object for encryption.
    """

    def __init__(self, loop, app_protocol, sslcontext, waiter,
                 server_side=False, server_hostname=None,
                 call_connection_made=True,
                 ssl_handshake_timeout=None):
        """Create the SSL protocol.

        *sslcontext* may be falsy, in which case a default client context is
        created (server-side use then raises ValueError).  *waiter* is
        resolved (result or exception) when the handshake finishes or the
        connection is lost first.  *server_hostname* is ignored on the server
        side.  *ssl_handshake_timeout* defaults to
        constants.SSL_HANDSHAKE_TIMEOUT and must be positive.

        Raise RuntimeError if the ssl module is unavailable.
        """
        if ssl is None:
            raise RuntimeError('stdlib ssl module not available')

        if ssl_handshake_timeout is None:
            ssl_handshake_timeout = constants.SSL_HANDSHAKE_TIMEOUT
        elif ssl_handshake_timeout <= 0:
            raise ValueError(
                f"ssl_handshake_timeout should be a positive number, "
                f"got {ssl_handshake_timeout}")

        if not sslcontext:
            sslcontext = _create_transport_context(
                server_side, server_hostname)

        self._server_side = server_side
        if server_hostname and not server_side:
            self._server_hostname = server_hostname
        else:
            self._server_hostname = None
        self._sslcontext = sslcontext
        # SSL-specific extra info. More info are set when the handshake
        # completes.
        self._extra = dict(sslcontext=sslcontext)

        # App data write buffering
        self._write_backlog = collections.deque()
        self._write_buffer_size = 0

        self._waiter = waiter
        self._loop = loop
        self._set_app_protocol(app_protocol)
        self._app_transport = _SSLProtocolTransport(self._loop, self)
        # _SSLPipe instance (None until the connection is made)
        self._sslpipe = None
        self._session_established = False
        self._in_handshake = False
        self._in_shutdown = False
        # transport, ex: SelectorSocketTransport
        self._transport = None
        self._call_connection_made = call_connection_made
        self._ssl_handshake_timeout = ssl_handshake_timeout
        # Set by _start_handshake(); initialized here so that methods which
        # may run before (or without) a handshake can test them safely.
        self._handshake_start_time = None
        self._handshake_timeout_handle = None

    def _set_app_protocol(self, app_protocol):
        self._app_protocol = app_protocol
        self._app_protocol_is_buffer = \
            isinstance(app_protocol, protocols.BufferedProtocol)

    def _wakeup_waiter(self, exc=None):
        """Resolve the waiter future once (exception if *exc* is given)."""
        if self._waiter is None:
            return
        if not self._waiter.cancelled():
            if exc is not None:
                self._waiter.set_exception(exc)
            else:
                self._waiter.set_result(None)
        self._waiter = None

    def connection_made(self, transport):
        """Called when the low-level connection is made.

        Start the SSL handshake.
        """
        self._transport = transport
        self._sslpipe = _SSLPipe(self._sslcontext,
                                 self._server_side,
                                 self._server_hostname)
        self._start_handshake()

    def connection_lost(self, exc):
        """Called when the low-level connection is lost or closed.

        The argument is an exception object or None (the latter
        meaning a regular EOF is received or the connection was
        aborted or closed).
        """
        if self._session_established:
            self._session_established = False
            self._loop.call_soon(self._app_protocol.connection_lost, exc)
        else:
            # Most likely an exception occurred while in SSL handshake.
            # Just mark the app transport as closed so that its __del__
            # doesn't complain.
            if self._app_transport is not None:
                self._app_transport._closed = True
        self._transport = None
        self._app_transport = None
        # The handshake can no longer complete: stop the pending timeout so
        # it neither fires on a dead connection nor keeps ``self`` alive.
        if self._handshake_timeout_handle is not None:
            self._handshake_timeout_handle.cancel()
            self._handshake_timeout_handle = None
        # Drop the pipe (it holds callbacks bound to self); every method
        # that uses it already treats None as "closed".
        self._sslpipe = None
        self._wakeup_waiter(exc)

    def pause_writing(self):
        """Called when the low-level transport's buffer goes over
        the high-water mark.
        """
        self._app_protocol.pause_writing()

    def resume_writing(self):
        """Called when the low-level transport's buffer drains below
        the low-water mark.
        """
        self._app_protocol.resume_writing()

    def data_received(self, data):
        """Called when some SSL data is received.

        The argument is a bytes object.
        """
        if self._sslpipe is None:
            # transport closing, sslpipe is destroyed
            return

        try:
            ssldata, appdata = self._sslpipe.feed_ssldata(data)
        except Exception as e:
            self._fatal_error(e, 'SSL error in data received')
            return

        # Records the SSL layer wants sent back (handshake, renegotiation).
        for chunk in ssldata:
            self._transport.write(chunk)

        for chunk in appdata:
            if chunk:
                try:
                    if self._app_protocol_is_buffer:
                        protocols._feed_data_to_buffered_proto(
                            self._app_protocol, chunk)
                    else:
                        self._app_protocol.data_received(chunk)
                except Exception as ex:
                    self._fatal_error(
                        ex, 'application protocol failed to receive SSL data')
                    return
            else:
                # An empty chunk is the peer's close_notify: acknowledge it.
                self._start_shutdown()
                break

    def eof_received(self):
        """Called when the other end of the low-level stream
        is half-closed.

        If this returns a false value (including None), the transport
        will close itself.  If it returns a true value, closing the
        transport is up to the protocol.
        """
        try:
            if self._loop.get_debug():
                logger.debug("%r received EOF", self)

            self._wakeup_waiter(ConnectionResetError)

            if not self._in_handshake:
                keep_open = self._app_protocol.eof_received()
                if keep_open:
                    logger.warning('returning true from eof_received() '
                                   'has no effect when using ssl')
        finally:
            self._transport.close()

    def _get_extra_info(self, name, default=None):
        """Look up *name* in SSL extra info, then the low-level transport."""
        if name in self._extra:
            return self._extra[name]
        elif self._transport is not None:
            return self._transport.get_extra_info(name, default)
        else:
            return default

    def _start_shutdown(self):
        """Begin a graceful SSL shutdown, or abort if still handshaking."""
        if self._in_shutdown:
            return
        if self._in_handshake:
            self._abort()
        else:
            self._in_shutdown = True
            # (b'', 0) in the backlog means "start the SSL shutdown".
            self._write_appdata(b'')

    def _write_appdata(self, data):
        """Queue application *data* for encryption and try to send it."""
        self._write_backlog.append((data, 0))
        self._write_buffer_size += len(data)
        self._process_write_backlog()

    def _start_handshake(self):
        if self._loop.get_debug():
            logger.debug("%r starts SSL handshake", self)
            self._handshake_start_time = self._loop.time()
        else:
            self._handshake_start_time = None
        self._in_handshake = True
        # (b'', 1) is a special value in _process_write_backlog() to do
        # the SSL handshake
        self._write_backlog.append((b'', 1))
        self._handshake_timeout_handle = \
            self._loop.call_later(self._ssl_handshake_timeout,
                                  self._check_handshake_timeout)
        self._process_write_backlog()

    def _check_handshake_timeout(self):
        if self._in_handshake is True:
            msg = (
                f"SSL handshake is taking longer than "
                f"{self._ssl_handshake_timeout} seconds: "
                f"aborting the connection"
            )
            self._fatal_error(ConnectionAbortedError(msg))

    def _on_handshake_complete(self, handshake_exc):
        """Handshake callback installed on the _SSLPipe.

        *handshake_exc* is None on success, else the exception raised by
        the handshake.  On failure the connection is torn down through
        _fatal_error(); on success extra info is recorded, the application
        protocol is connected and the waiter is woken up.
        """
        self._in_handshake = False
        if self._handshake_timeout_handle is not None:
            self._handshake_timeout_handle.cancel()

        sslobj = self._sslpipe.ssl_object
        try:
            if handshake_exc is not None:
                raise handshake_exc

            peercert = sslobj.getpeercert()
        except Exception as exc:
            if isinstance(exc, ssl.CertificateError):
                msg = 'SSL handshake failed on verifying the certificate'
            else:
                msg = 'SSL handshake failed'
            self._fatal_error(exc, msg)
            return

        # The start time is only recorded when debug mode was on at the
        # start of the handshake; debug may have been enabled since.
        if self._loop.get_debug() and self._handshake_start_time is not None:
            dt = self._loop.time() - self._handshake_start_time
            logger.debug("%r: SSL handshake took %.1f ms", self, dt * 1e3)

        # Add extra info that becomes available after handshake.
        self._extra.update(peercert=peercert,
                           cipher=sslobj.cipher(),
                           compression=sslobj.compression(),
                           ssl_object=sslobj,
                           )
        if self._call_connection_made:
            self._app_protocol.connection_made(self._app_transport)
        self._wakeup_waiter()
        self._session_established = True
        # In case transport.write() was already called. Don't call
        # immediately _process_write_backlog(), but schedule it:
        # _on_handshake_complete() can be called indirectly from
        # _process_write_backlog(), and _process_write_backlog() is not
        # reentrant.
        self._loop.call_soon(self._process_write_backlog)

    def _process_write_backlog(self):
        """Try to make progress on the write backlog.

        Each backlog entry is a ``(data, offset)`` pair:

        * non-empty *data*: application data, fed from *offset* onwards;
        * ``(b'', 1)``: start the SSL handshake (see _start_handshake());
        * ``(b'', 0)``: start the SSL shutdown (see _start_shutdown()).

        Processing stops at the first short write, which means the SSL
        layer first needs more incoming record level data.
        """
        if self._transport is None or self._sslpipe is None:
            return

        try:
            # Visit only the entries present on entry; the head entry is
            # deleted as soon as it has been fully processed.
            for _ in range(len(self._write_backlog)):
                if self._sslpipe is None:
                    # A shutdown completed synchronously and _finalize()
                    # dropped the pipe: nothing more can be encrypted.
                    break
                data, offset = self._write_backlog[0]
                if data:
                    ssldata, offset = self._sslpipe.feed_appdata(data, offset)
                elif offset:
                    ssldata = self._sslpipe.do_handshake(
                        self._on_handshake_complete)
                    offset = 1
                else:
                    ssldata = self._sslpipe.shutdown(self._finalize)
                    offset = 1

                for chunk in ssldata:
                    self._transport.write(chunk)

                if offset < len(data):
                    self._write_backlog[0] = (data, offset)
                    # A short write means that a write is blocked on a read
                    # We need to enable reading if it is paused!
                    assert self._sslpipe.need_ssldata
                    if not self._transport.is_reading():
                        self._transport.resume_reading()
                    break

                # An entire chunk from the backlog was processed. We can
                # delete it and reduce the outstanding buffer size.
                del self._write_backlog[0]
                self._write_buffer_size -= len(data)
        except Exception as exc:
            if self._in_handshake:
                # Exceptions will be re-raised in _on_handshake_complete.
                self._on_handshake_complete(exc)
            else:
                self._fatal_error(exc, 'Fatal error on SSL transport')

    def _fatal_error(self, exc, message='Fatal error on transport'):
        """Report *exc* (unless it is an ignorable connection error) and
        force-close the low-level transport."""
        if isinstance(exc, base_events._FATAL_ERROR_IGNORE):
            if self._loop.get_debug():
                logger.debug("%r: %s", self, message, exc_info=True)
        else:
            self._loop.call_exception_handler({
                'message': message,
                'exception': exc,
                'transport': self._transport,
                'protocol': self,
            })
        if self._transport:
            self._transport._force_close(exc)

    def _finalize(self):
        """Drop the SSL pipe and close the low-level transport."""
        self._sslpipe = None

        if self._transport is not None:
            self._transport.close()

    def _abort(self):
        """Abort the low-level transport, then finalize regardless."""
        try:
            if self._transport is not None:
                self._transport.abort()
        finally:
            self._finalize()
724