1# Wrapper module for _ssl, providing some additional facilities
2# implemented in Python.  Written by Bill Janssen.
3
4"""This module provides some more Pythonic support for SSL.
5
6Object types:
7
8  SSLSocket -- subtype of socket.socket which does SSL over the socket
9
10Exceptions:
11
12  SSLError -- exception raised for I/O errors
13
14Functions:
15
  cert_time_to_seconds -- convert a time string used in the certificate
                          notBefore and notAfter fields to an integer
                          number of seconds past the Epoch (the same kind
                          of value returned by time.time())
20
21  fetch_server_certificate (HOST, PORT) -- fetch the certificate provided
22                          by the server running on HOST at port PORT.  No
23                          validation of the certificate is performed.
24
25Integer constants:
26
27SSL_ERROR_ZERO_RETURN
28SSL_ERROR_WANT_READ
29SSL_ERROR_WANT_WRITE
30SSL_ERROR_WANT_X509_LOOKUP
31SSL_ERROR_SYSCALL
32SSL_ERROR_SSL
33SSL_ERROR_WANT_CONNECT
34
35SSL_ERROR_EOF
36SSL_ERROR_INVALID_ERROR_CODE
37
The following group defines the certificate requirements that one side
allows or requires from the other side:
40
41CERT_NONE - no certificates from the other side are required (or will
42            be looked at if provided)
43CERT_OPTIONAL - certificates are not required, but if provided will be
44                validated, and if validation fails, the connection will
45                also fail
46CERT_REQUIRED - certificates are required, and will be validated, and
47                if validation fails, the connection will also fail
48
49The following constants identify various SSL protocol variants:
50
51PROTOCOL_SSLv2
52PROTOCOL_SSLv3
53PROTOCOL_SSLv23
54PROTOCOL_TLS
55PROTOCOL_TLS_CLIENT
56PROTOCOL_TLS_SERVER
57PROTOCOL_TLSv1
58PROTOCOL_TLSv1_1
59PROTOCOL_TLSv1_2
60
61The following constants identify various SSL alert message descriptions as per
62http://www.iana.org/assignments/tls-parameters/tls-parameters.xml#tls-parameters-6
63
64ALERT_DESCRIPTION_CLOSE_NOTIFY
65ALERT_DESCRIPTION_UNEXPECTED_MESSAGE
66ALERT_DESCRIPTION_BAD_RECORD_MAC
67ALERT_DESCRIPTION_RECORD_OVERFLOW
68ALERT_DESCRIPTION_DECOMPRESSION_FAILURE
69ALERT_DESCRIPTION_HANDSHAKE_FAILURE
70ALERT_DESCRIPTION_BAD_CERTIFICATE
71ALERT_DESCRIPTION_UNSUPPORTED_CERTIFICATE
72ALERT_DESCRIPTION_CERTIFICATE_REVOKED
73ALERT_DESCRIPTION_CERTIFICATE_EXPIRED
74ALERT_DESCRIPTION_CERTIFICATE_UNKNOWN
75ALERT_DESCRIPTION_ILLEGAL_PARAMETER
76ALERT_DESCRIPTION_UNKNOWN_CA
77ALERT_DESCRIPTION_ACCESS_DENIED
78ALERT_DESCRIPTION_DECODE_ERROR
79ALERT_DESCRIPTION_DECRYPT_ERROR
80ALERT_DESCRIPTION_PROTOCOL_VERSION
81ALERT_DESCRIPTION_INSUFFICIENT_SECURITY
82ALERT_DESCRIPTION_INTERNAL_ERROR
83ALERT_DESCRIPTION_USER_CANCELLED
84ALERT_DESCRIPTION_NO_RENEGOTIATION
85ALERT_DESCRIPTION_UNSUPPORTED_EXTENSION
86ALERT_DESCRIPTION_CERTIFICATE_UNOBTAINABLE
87ALERT_DESCRIPTION_UNRECOGNIZED_NAME
88ALERT_DESCRIPTION_BAD_CERTIFICATE_STATUS_RESPONSE
89ALERT_DESCRIPTION_BAD_CERTIFICATE_HASH_VALUE
90ALERT_DESCRIPTION_UNKNOWN_PSK_IDENTITY
91"""
92
93import sys
94import os
95from collections import namedtuple
96from enum import Enum as _Enum, IntEnum as _IntEnum, IntFlag as _IntFlag
97
98import _ssl             # if we can't import it, let the error propagate
99
100from _ssl import OPENSSL_VERSION_NUMBER, OPENSSL_VERSION_INFO, OPENSSL_VERSION
101from _ssl import _SSLContext, MemoryBIO, SSLSession
102from _ssl import (
103    SSLError, SSLZeroReturnError, SSLWantReadError, SSLWantWriteError,
104    SSLSyscallError, SSLEOFError, SSLCertVerificationError
105    )
106from _ssl import txt2obj as _txt2obj, nid2obj as _nid2obj
107from _ssl import RAND_status, RAND_add, RAND_bytes, RAND_pseudo_bytes
108try:
109    from _ssl import RAND_egd
110except ImportError:
111    # LibreSSL does not provide RAND_egd
112    pass
113
114
115from _ssl import (
116    HAS_SNI, HAS_ECDH, HAS_NPN, HAS_ALPN, HAS_SSLv2, HAS_SSLv3, HAS_TLSv1,
117    HAS_TLSv1_1, HAS_TLSv1_2, HAS_TLSv1_3
118)
119from _ssl import _DEFAULT_CIPHERS, _OPENSSL_API_VERSION
120
121
# Turn groups of integer constants exported by _ssl into enum classes.  Each
# _convert() call gathers the _ssl names accepted by the filter, creates an
# enum class of the given name in this module's namespace (later code refers
# to _SSLMethod, Options, VerifyMode, ... directly) and re-exports the
# members as module-level names.

# Protocol selectors.  PROTOCOL_SSLv23 is filtered out here and re-attached
# below as an alias of PROTOCOL_TLS (presumably so PROTOCOL_TLS stays the
# canonical member name for that value).
_IntEnum._convert(
    '_SSLMethod', __name__,
    lambda name: name.startswith('PROTOCOL_') and name != 'PROTOCOL_SSLv23',
    source=_ssl)

# OP_* option bits; an IntFlag so values can be combined with | and &.
_IntFlag._convert(
    'Options', __name__,
    lambda name: name.startswith('OP_'),
    source=_ssl)

# ALERT_DESCRIPTION_* TLS alert descriptions.
_IntEnum._convert(
    'AlertDescription', __name__,
    lambda name: name.startswith('ALERT_DESCRIPTION_'),
    source=_ssl)

# SSL_ERROR_* error codes.
_IntEnum._convert(
    'SSLErrorNumber', __name__,
    lambda name: name.startswith('SSL_ERROR_'),
    source=_ssl)

# VERIFY_* certificate verification flags (combinable bit flags).
_IntFlag._convert(
    'VerifyFlags', __name__,
    lambda name: name.startswith('VERIFY_'),
    source=_ssl)

# CERT_* peer certificate verification modes.
_IntEnum._convert(
    'VerifyMode', __name__,
    lambda name: name.startswith('CERT_'),
    source=_ssl)

# Re-add PROTOCOL_SSLv23 both as a module name and as an attribute of
# _SSLMethod; both point at the PROTOCOL_TLS member.
PROTOCOL_SSLv23 = _SSLMethod.PROTOCOL_SSLv23 = _SSLMethod.PROTOCOL_TLS
# Reverse lookup table: protocol member -> member name.
_PROTOCOL_NAMES = {value: name for name, value in _SSLMethod.__members__.items()}

# The PROTOCOL_SSLv2 member, or None when _ssl does not provide it.
_SSLv2_IF_EXISTS = getattr(_SSLMethod, 'PROTOCOL_SSLv2', None)
156
157
class TLSVersion(_IntEnum):
    """SSL/TLS protocol versions, as used by SSLContext.minimum_version and
    SSLContext.maximum_version.

    Values are the PROTO_* constants of the _ssl extension.
    MINIMUM_SUPPORTED and MAXIMUM_SUPPORTED are _ssl sentinels, presumably
    meaning "lowest/highest version the linked OpenSSL supports" -- TODO
    confirm against the _ssl documentation.
    """
    MINIMUM_SUPPORTED = _ssl.PROTO_MINIMUM_SUPPORTED
    SSLv3 = _ssl.PROTO_SSLv3
    TLSv1 = _ssl.PROTO_TLSv1
    TLSv1_1 = _ssl.PROTO_TLSv1_1
    TLSv1_2 = _ssl.PROTO_TLSv1_2
    TLSv1_3 = _ssl.PROTO_TLSv1_3
    MAXIMUM_SUPPORTED = _ssl.PROTO_MAXIMUM_SUPPORTED
166
167
168if sys.platform == "win32":
169    from _ssl import enum_certificates, enum_crls
170
171from socket import socket, AF_INET, SOCK_STREAM, create_connection
172from socket import SOL_SOCKET, SO_TYPE
173import socket as _socket
174import base64        # for DER-to-PEM translation
175import errno
176import warnings
177
178
# Historical public alias: socket errors are plain OSError.
socket_error = OSError  # keep that public name in module namespace

# Channel binding types advertised by this module; 'tls-unique' is also the
# default cb_type of get_channel_binding().
CHANNEL_BINDING_TYPES = ['tls-unique']

# True when _ssl exposes HOSTFLAG_NEVER_CHECK_SUBJECT, i.e. when
# SSLContext.hostname_checks_common_name can actually be switched off.
HAS_NEVER_CHECK_COMMON_NAME = hasattr(_ssl, 'HOSTFLAG_NEVER_CHECK_SUBJECT')


# Cipher string for server-side use (going by the name); currently it is
# simply the same as _ssl's default cipher string.
_RESTRICTED_SERVER_CIPHERS = _DEFAULT_CIPHERS

# Hostname mismatches reported by match_hostname()/_dnsname_match() use the
# same exception class as certificate verification failures.
CertificateError = SSLCertVerificationError
189
190
191def _dnsname_match(dn, hostname):
192    """Matching according to RFC 6125, section 6.4.3
193
194    - Hostnames are compared lower case.
195    - For IDNA, both dn and hostname must be encoded as IDN A-label (ACE).
196    - Partial wildcards like 'www*.example.org', multiple wildcards, sole
197      wildcard or wildcards in labels other then the left-most label are not
198      supported and a CertificateError is raised.
199    - A wildcard must match at least one character.
200    """
201    if not dn:
202        return False
203
204    wildcards = dn.count('*')
205    # speed up common case w/o wildcards
206    if not wildcards:
207        return dn.lower() == hostname.lower()
208
209    if wildcards > 1:
210        raise CertificateError(
211            "too many wildcards in certificate DNS name: {!r}.".format(dn))
212
213    dn_leftmost, sep, dn_remainder = dn.partition('.')
214
215    if '*' in dn_remainder:
216        # Only match wildcard in leftmost segment.
217        raise CertificateError(
218            "wildcard can only be present in the leftmost label: "
219            "{!r}.".format(dn))
220
221    if not sep:
222        # no right side
223        raise CertificateError(
224            "sole wildcard without additional labels are not support: "
225            "{!r}.".format(dn))
226
227    if dn_leftmost != '*':
228        # no partial wildcard matching
229        raise CertificateError(
230            "partial wildcards in leftmost label are not supported: "
231            "{!r}.".format(dn))
232
233    hostname_leftmost, sep, hostname_remainder = hostname.partition('.')
234    if not hostname_leftmost or not sep:
235        # wildcard must match at least one char
236        return False
237    return dn_remainder.lower() == hostname_remainder.lower()
238
239
240def _inet_paton(ipname):
241    """Try to convert an IP address to packed binary form
242
243    Supports IPv4 addresses on all platforms and IPv6 on platforms with IPv6
244    support.
245    """
246    # inet_aton() also accepts strings like '1'
247    if ipname.count('.') == 3:
248        try:
249            return _socket.inet_aton(ipname)
250        except OSError:
251            pass
252
253    try:
254        return _socket.inet_pton(_socket.AF_INET6, ipname)
255    except OSError:
256        raise ValueError("{!r} is neither an IPv4 nor an IP6 "
257                         "address.".format(ipname))
258    except AttributeError:
259        # AF_INET6 not available
260        pass
261
262    raise ValueError("{!r} is not an IPv4 address.".format(ipname))
263
264
def _ipaddress_match(ipname, host_ip):
    """Compare a certificate IP address entry with a packed host address.

    *ipname* is the textual address taken from the certificate and
    *host_ip* the packed address of the host being checked.  The match is
    exact; RFC 6125 deliberately leaves IP matching undefined (section
    1.7.2, "Out of Scope").
    """
    # The certificate text may carry trailing whitespace (OpenSSL can append
    # a newline to a subjectAltName IP address), so strip it before packing.
    packed_cert_ip = _inet_paton(ipname.rstrip())
    return packed_cert_ip == host_ip
274
275
def match_hostname(cert, hostname):
    """Check that the decoded certificate *cert* is valid for *hostname*.

    *cert* is the dictionary returned by SSLSocket.getpeercert().  The rules
    of RFC 2818 and RFC 6125 are followed.  When *hostname* parses as an IP
    address it is compared against subjectAltName IP Address entries instead
    of dNSName entries.  IPv4 works on all platforms; IPv6 needs platform
    support (AF_INET6 and inet_pton).

    Returns None when the certificate matches; raises CertificateError when
    it does not, and ValueError when *cert* is empty or missing.
    """
    if not cert:
        raise ValueError("empty or no certificate, match_hostname needs a "
                         "SSL socket or SSL context with either "
                         "CERT_OPTIONAL or CERT_REQUIRED")

    try:
        host_ip = _inet_paton(hostname)
    except ValueError:
        # Not an IP address (common case)
        host_ip = None

    # Every DNS / IP Address SAN entry seen (and failed); used both to decide
    # whether to fall back to commonName and to build the error message.
    candidates = []
    for kind, entry in cert.get('subjectAltName', ()):
        if kind == 'DNS':
            if host_ip is None and _dnsname_match(entry, hostname):
                return
        elif kind == 'IP Address':
            if host_ip is not None and _ipaddress_match(entry, host_ip):
                return
        else:
            continue
        candidates.append(entry)

    if not candidates:
        # The subject is consulted only when subjectAltName contributed no
        # DNS or IP Address entries.
        for rdn in cert.get('subject', ()):
            for attr, entry in rdn:
                if attr != 'commonName':
                    continue
                # XXX according to RFC 2818, the most specific Common Name
                # must be used.
                if _dnsname_match(entry, hostname):
                    return
                candidates.append(entry)

    if not candidates:
        raise CertificateError("no appropriate commonName or "
                               "subjectAltName fields were found")
    if len(candidates) == 1:
        raise CertificateError("hostname %r doesn't match %r"
                               % (hostname, candidates[0]))
    raise CertificateError("hostname %r doesn't match either of %s"
                           % (hostname, ', '.join(map(repr, candidates))))
331
332
# Result type of get_default_verify_paths(): the usable cafile and capath
# (None when the path does not exist) followed by the four raw values from
# _ssl.get_default_verify_paths() (environment variable names and the
# OpenSSL default paths).
DefaultVerifyPaths = namedtuple("DefaultVerifyPaths",
    "cafile capath openssl_cafile_env openssl_cafile openssl_capath_env "
    "openssl_capath")
336
def get_default_verify_paths():
    """Return a DefaultVerifyPaths with the default CA file and directory.

    An environment variable named by OpenSSL overrides the corresponding
    compiled-in path.  cafile/capath are None unless the resulting path is an
    existing file/directory; the raw OpenSSL values are appended unchanged.
    """
    openssl_paths = _ssl.get_default_verify_paths()
    cafile_var, cafile_default = openssl_paths[0], openssl_paths[1]
    capath_var, capath_default = openssl_paths[2], openssl_paths[3]

    # environment variables shadow the compiled-in defaults
    cafile = os.environ.get(cafile_var, cafile_default)
    capath = os.environ.get(capath_var, capath_default)

    usable_cafile = cafile if os.path.isfile(cafile) else None
    usable_capath = capath if os.path.isdir(capath) else None
    return DefaultVerifyPaths(usable_cafile, usable_capath, *openssl_paths)
349
350
351class _ASN1Object(namedtuple("_ASN1Object", "nid shortname longname oid")):
352    """ASN.1 object identifier lookup
353    """
354    __slots__ = ()
355
356    def __new__(cls, oid):
357        return super().__new__(cls, *_txt2obj(oid, name=False))
358
359    @classmethod
360    def fromnid(cls, nid):
361        """Create _ASN1Object from OpenSSL numeric ID
362        """
363        return super().__new__(cls, *_nid2obj(nid))
364
365    @classmethod
366    def fromname(cls, name):
367        """Create _ASN1Object from short name, long name or OID
368        """
369        return super().__new__(cls, *_txt2obj(name, name=True))
370
371
class Purpose(_ASN1Object, _Enum):
    """SSLContext purpose flags with X509v3 Extended Key Usage objects

    Each member is an _ASN1Object (nid, shortname, longname, oid) built from
    the dotted OID string given as its value.  create_default_context() and
    SSLContext.load_default_certs() take a member as *purpose*, with
    SERVER_AUTH as the default.
    """
    # Extended Key Usage OID for TLS server authentication.
    SERVER_AUTH = '1.3.6.1.5.5.7.3.1'
    # Extended Key Usage OID for TLS client authentication.
    CLIENT_AUTH = '1.3.6.1.5.5.7.3.2'
377
378
class SSLContext(_SSLContext):
    """An SSLContext holds various SSL-related configuration options and
    data, such as certificates and possibly a private key.

    This Python layer over the C _SSLContext presents integer attributes
    (protocol, options, verify_flags, verify_mode and, when available,
    minimum_version/maximum_version) as this module's enum types, and
    creates wrapped objects through sslsocket_class / sslobject_class."""
    # Windows certificate store names read by load_default_certs() on win32.
    _windows_cert_stores = ("CA", "ROOT")

    sslsocket_class = None  # SSLSocket is assigned later.
    sslobject_class = None  # SSLObject is assigned later.

    def __new__(cls, protocol=PROTOCOL_TLS, *args, **kwargs):
        # Only *protocol* reaches the C constructor; extra positional or
        # keyword arguments are accepted but not passed on.
        self = _SSLContext.__new__(cls, protocol)
        return self

    def _encode_hostname(self, hostname):
        """Return *hostname* as an ASCII str, or None if it is None.

        A str is IDNA-encoded (internationalized labels become their ASCII
        form) and decoded back to str; any other value is decoded as ASCII,
        so it is expected to be bytes-like.
        """
        if hostname is None:
            return None
        elif isinstance(hostname, str):
            return hostname.encode('idna').decode('ascii')
        else:
            return hostname.decode('ascii')

    def wrap_socket(self, sock, server_side=False,
                    do_handshake_on_connect=True,
                    suppress_ragged_eofs=True,
                    server_hostname=None, session=None):
        """Wrap the stream socket *sock* and return a sslsocket_class
        instance bound to this context (built by its _create())."""
        # SSLSocket class handles server_hostname encoding before it calls
        # ctx._wrap_socket()
        return self.sslsocket_class._create(
            sock=sock,
            server_side=server_side,
            do_handshake_on_connect=do_handshake_on_connect,
            suppress_ragged_eofs=suppress_ragged_eofs,
            server_hostname=server_hostname,
            context=self,
            session=session
        )

    def wrap_bio(self, incoming, outgoing, server_side=False,
                 server_hostname=None, session=None):
        """Return a sslobject_class instance that runs SSL over the memory
        BIOs *incoming* and *outgoing* instead of over a socket."""
        # Need to encode server_hostname here because _wrap_bio() can only
        # handle ASCII str.
        return self.sslobject_class._create(
            incoming, outgoing, server_side=server_side,
            server_hostname=self._encode_hostname(server_hostname),
            session=session, context=self,
        )

    def set_npn_protocols(self, npn_protocols):
        """Advertise *npn_protocols* (an iterable of str) via NPN.

        Each name is ASCII-encoded and stored with a one-byte length
        prefix.  Raises SSLError if a name encodes to 0 or more than 255
        bytes.
        """
        protos = bytearray()
        for protocol in npn_protocols:
            b = bytes(protocol, 'ascii')
            if len(b) == 0 or len(b) > 255:
                raise SSLError('NPN protocols must be 1 to 255 in length')
            # Serialized form: length byte followed by the protocol bytes.
            protos.append(len(b))
            protos.extend(b)

        self._set_npn_protocols(protos)

    def set_servername_callback(self, server_name_callback):
        """Install *server_name_callback* as the SNI callback, or remove
        the callback when it is None.

        The callback is invoked as callback(sslobj, servername, sslctx) with
        servername normalized by _encode_hostname().  Raises TypeError if
        *server_name_callback* is neither None nor callable.
        """
        if server_name_callback is None:
            self.sni_callback = None
        else:
            if not callable(server_name_callback):
                raise TypeError("not a callable object")

            # Adapter stored as sni_callback: normalizes the server name,
            # then forwards all three arguments to the user's callback.
            def shim_cb(sslobj, servername, sslctx):
                servername = self._encode_hostname(servername)
                return server_name_callback(sslobj, servername, sslctx)

            self.sni_callback = shim_cb

    def set_alpn_protocols(self, alpn_protocols):
        """Advertise *alpn_protocols* (an iterable of str) via ALPN.

        Each name is ASCII-encoded and stored with a one-byte length
        prefix.  Raises SSLError if a name encodes to 0 or more than 255
        bytes.
        """
        protos = bytearray()
        for protocol in alpn_protocols:
            b = bytes(protocol, 'ascii')
            if len(b) == 0 or len(b) > 255:
                raise SSLError('ALPN protocols must be 1 to 255 in length')
            # Serialized form: length byte followed by the protocol bytes.
            protos.append(len(b))
            protos.extend(b)

        self._set_alpn_protocols(protos)

    def _load_windows_store_certs(self, storename, purpose):
        """Load CA certificates from the Windows store *storename*.

        Only "x509_asn" entries whose trust is True (presumably: trusted for
        all purposes) or contains purpose.oid are loaded.  Returns the
        concatenated certificate bytes, possibly empty.  enum_certificates
        is only imported on win32, so this only works there.
        """
        certs = bytearray()
        try:
            for cert, encoding, trust in enum_certificates(storename):
                # CA certs are never PKCS#7 encoded
                if encoding == "x509_asn":
                    if trust is True or purpose.oid in trust:
                        certs.extend(cert)
        except PermissionError:
            # Best effort: an inaccessible store only produces a warning.
            warnings.warn("unable to enumerate Windows certificate store")
        if certs:
            self.load_verify_locations(cadata=certs)
        return certs

    def load_default_certs(self, purpose=Purpose.SERVER_AUTH):
        """Load default CA certificates for *purpose*.

        On win32 the stores named in _windows_cert_stores are read first;
        set_default_verify_paths() is then called on every platform.
        Raises TypeError if *purpose* is not an _ASN1Object.
        """
        if not isinstance(purpose, _ASN1Object):
            raise TypeError(purpose)
        if sys.platform == "win32":
            for storename in self._windows_cert_stores:
                self._load_windows_store_certs(storename, purpose)
        self.set_default_verify_paths()

    # The version properties are only defined when the C _SSLContext has them.
    if hasattr(_SSLContext, 'minimum_version'):
        @property
        def minimum_version(self):
            return TLSVersion(super().minimum_version)

        @minimum_version.setter
        def minimum_version(self, value):
            # Allowing SSLv3 as the minimum also clears OP_NO_SSLv3, which
            # would otherwise still exclude it.
            if value == TLSVersion.SSLv3:
                self.options &= ~Options.OP_NO_SSLv3
            # "super().minimum_version = value" does not reach the parent
            # property, so call the base descriptor's __set__ explicitly.
            super(SSLContext, SSLContext).minimum_version.__set__(self, value)

        @property
        def maximum_version(self):
            return TLSVersion(super().maximum_version)

        @maximum_version.setter
        def maximum_version(self, value):
            super(SSLContext, SSLContext).maximum_version.__set__(self, value)

    # Present the raw option bits as an Options flag.
    @property
    def options(self):
        return Options(super().options)

    @options.setter
    def options(self, value):
        super(SSLContext, SSLContext).options.__set__(self, value)

    if hasattr(_ssl, 'HOSTFLAG_NEVER_CHECK_SUBJECT'):
        # True unless HOSTFLAG_NEVER_CHECK_SUBJECT is set in _host_flags.
        @property
        def hostname_checks_common_name(self):
            ncs = self._host_flags & _ssl.HOSTFLAG_NEVER_CHECK_SUBJECT
            return ncs != _ssl.HOSTFLAG_NEVER_CHECK_SUBJECT

        @hostname_checks_common_name.setter
        def hostname_checks_common_name(self, value):
            if value:
                self._host_flags &= ~_ssl.HOSTFLAG_NEVER_CHECK_SUBJECT
            else:
                self._host_flags |= _ssl.HOSTFLAG_NEVER_CHECK_SUBJECT
    else:
        # Without the host flag the subject is always checked; read-only.
        @property
        def hostname_checks_common_name(self):
            return True

    @property
    def protocol(self):
        return _SSLMethod(super().protocol)

    @property
    def verify_flags(self):
        return VerifyFlags(super().verify_flags)

    @verify_flags.setter
    def verify_flags(self, value):
        super(SSLContext, SSLContext).verify_flags.__set__(self, value)

    @property
    def verify_mode(self):
        # Fall back to the raw value if it is not a known VerifyMode member.
        value = super().verify_mode
        try:
            return VerifyMode(value)
        except ValueError:
            return value

    @verify_mode.setter
    def verify_mode(self, value):
        super(SSLContext, SSLContext).verify_mode.__set__(self, value)
549
550
def create_default_context(purpose=Purpose.SERVER_AUTH, *, cafile=None,
                           capath=None, cadata=None):
    """Create a SSLContext object with default settings.

    NOTE: The protocol and settings may change anytime without prior
          deprecation. The values represent a fair balance between maximum
          compatibility and security.

    For Purpose.SERVER_AUTH (client-side use) the peer certificate is
    required and the hostname is checked.  CA certificates come from
    *cafile*/*capath*/*cadata* when any is given, otherwise from the system
    defaults whenever verification is enabled.  Raises TypeError if
    *purpose* is not an _ASN1Object.
    """
    if not isinstance(purpose, _ASN1Object):
        raise TypeError(purpose)

    # SSLContext already enables OP_NO_SSLv2, OP_NO_SSLv3, OP_NO_COMPRESSION,
    # OP_CIPHER_SERVER_PREFERENCE, OP_SINGLE_DH_USE and OP_SINGLE_ECDH_USE.
    ctx = SSLContext(PROTOCOL_TLS)

    if purpose == Purpose.SERVER_AUTH:
        # Client mode: demand a server certificate, then its host name.
        ctx.verify_mode = CERT_REQUIRED
        ctx.check_hostname = True

    if cafile or capath or cadata:
        ctx.load_verify_locations(cafile, capath, cadata)
    elif ctx.verify_mode != CERT_NONE:
        # Verification is on but no CA material was supplied: try the
        # system's default root certificates (this may fail silently).
        ctx.load_default_certs(purpose)
    return ctx
580
def _create_unverified_context(protocol=PROTOCOL_TLS, *, cert_reqs=CERT_NONE,
                               check_hostname=False,
                               purpose=Purpose.SERVER_AUTH,
                               certfile=None, keyfile=None,
                               cafile=None, capath=None, cadata=None):
    """Create a SSLContext object for Python stdlib modules

    All Python stdlib modules shall use this function to create SSLContext
    objects in order to keep common settings in one place. The configuration
    is less restrictive than create_default_context()'s to increase backward
    compatibility: by default no certificate is required and the hostname
    is not checked.

    Raises TypeError if *purpose* is not an _ASN1Object, and ValueError if
    *keyfile* is given without *certfile*.
    """
    if not isinstance(purpose, _ASN1Object):
        raise TypeError(purpose)

    # SSLContext already enables OP_NO_SSLv2, OP_NO_SSLv3, OP_NO_COMPRESSION,
    # OP_CIPHER_SERVER_PREFERENCE, OP_SINGLE_DH_USE and OP_SINGLE_ECDH_USE.
    ctx = SSLContext(protocol)

    # Hostname checking is turned off *before* verify_mode changes and on
    # only *after* it (presumably because the context rejects CERT_NONE
    # while check_hostname is enabled).
    if check_hostname:
        if cert_reqs is not None:
            ctx.verify_mode = cert_reqs
        ctx.check_hostname = True
    else:
        ctx.check_hostname = False
        if cert_reqs is not None:
            ctx.verify_mode = cert_reqs

    if certfile:
        ctx.load_cert_chain(certfile, keyfile)
    elif keyfile:
        raise ValueError("certfile must be specified")

    # load CA root certs
    if cafile or capath or cadata:
        ctx.load_verify_locations(cafile, capath, cadata)
    elif ctx.verify_mode != CERT_NONE:
        # Verification is on but no CA material was supplied: try the
        # system's default root certificates (this may fail silently).
        ctx.load_default_certs(purpose)

    return ctx
622
# Used by http.client if no context is explicitly passed.
# It points at create_default_context, so certificate and hostname
# verification are enabled for the default Purpose.SERVER_AUTH.
_create_default_https_context = create_default_context


# Backwards compatibility alias, even though it's not a public name.
# Older name of _create_unverified_context (CERT_NONE, no hostname check by
# default).
_create_stdlib_context = _create_unverified_context
629
630
class SSLObject:
    """SSL connection state on top of an OpenSSL SSL object, without any
    network IO of its own.

    All bytes move through two memory "BIO" objects (OpenSSL's IO
    abstraction layer) that the caller is responsible for pumping.

    There is no public constructor: instances come from
    ``SSLContext.wrap_bio``.  The class targets framework authors who build
    asynchronous SSL IO on top of memory buffers.

    Unlike ``SSLSocket``, it does not provide:

     * Any form of network IO, including methods such as ``recv`` and ``send``.
     * The ``do_handshake_on_connect`` and ``suppress_ragged_eofs`` machinery.
    """

    def __init__(self, *args, **kwargs):
        class_name = self.__class__.__name__
        raise TypeError(
            f"{class_name} does not have a public constructor. "
            f"Instances are returned by SSLContext.wrap_bio()."
        )

    @classmethod
    def _create(cls, incoming, outgoing, server_side=False,
                 server_hostname=None, session=None, context=None):
        """Allocate an instance (bypassing __init__) and attach the C-level
        SSL object that reads *incoming* and writes *outgoing*."""
        obj = cls.__new__(cls)
        obj._sslobj = context._wrap_bio(
            incoming, outgoing,
            server_side=server_side,
            server_hostname=server_hostname,
            owner=obj,
            session=session,
        )
        return obj

    @property
    def context(self):
        """The SSLContext currently used by this connection."""
        return self._sslobj.context

    @context.setter
    def context(self, ctx):
        self._sslobj.context = ctx

    @property
    def session(self):
        """The SSLSession of a client-side connection."""
        return self._sslobj.session

    @session.setter
    def session(self, session):
        self._sslobj.session = session

    @property
    def session_reused(self):
        """Whether the handshake reused an earlier client session."""
        return self._sslobj.session_reused

    @property
    def server_side(self):
        """True when this object plays the server role."""
        return self._sslobj.server_side

    @property
    def server_hostname(self):
        """The server hostname used for SNI, or ``None`` when no server
        hostname is set."""
        return self._sslobj.server_hostname

    def read(self, len=1024, buffer=None):
        """Read at most *len* decrypted bytes.

        Without *buffer* the bytes are returned; with *buffer* they are
        stored into it and the number of bytes read is returned.
        """
        if buffer is None:
            return self._sslobj.read(len)
        return self._sslobj.read(len, buffer)

    def write(self, data):
        """Encrypt *data* (any object supporting the buffer interface) and
        return the number of bytes written."""
        return self._sslobj.write(data)

    def getpeercert(self, binary_form=False):
        """Return the peer's certificate.

        None means the peer provided no certificate; {} means a certificate
        was provided but not validated.
        """
        return self._sslobj.getpeercert(binary_form)

    def selected_npn_protocol(self):
        """Return the NPN protocol agreed on, or ``None`` when nothing was
        negotiated or NPN is unsupported by either peer."""
        if not _ssl.HAS_NPN:
            return None
        return self._sslobj.selected_npn_protocol()

    def selected_alpn_protocol(self):
        """Return the ALPN protocol agreed on, or ``None`` when nothing was
        negotiated or ALPN is unsupported by either peer."""
        if not _ssl.HAS_ALPN:
            return None
        return self._sslobj.selected_alpn_protocol()

    def cipher(self):
        """Return the active cipher as ``(name, ssl_version, secret_bits)``."""
        return self._sslobj.cipher()

    def shared_ciphers(self):
        """Return the ciphers the client shared during the handshake, or
        None when this is not a valid server connection."""
        return self._sslobj.shared_ciphers()

    def compression(self):
        """Return the compression algorithm in use, or ``None`` when no
        compression was negotiated or a peer does not support it."""
        return self._sslobj.compression()

    def pending(self):
        """Return how many decrypted bytes can be read without blocking."""
        return self._sslobj.pending()

    def do_handshake(self):
        """Run (or continue) the SSL/TLS handshake."""
        self._sslobj.do_handshake()

    def unwrap(self):
        """Run the SSL shutdown handshake."""
        return self._sslobj.shutdown()

    def get_channel_binding(self, cb_type="tls-unique"):
        """Return channel binding data for the connection as bytes, or None
        when it is not available yet (e.g. before the handshake).  An
        unsupported *cb_type* raises ValueError."""
        return self._sslobj.get_channel_binding(cb_type)

    def version(self):
        """Return a string naming the protocol version in use."""
        return self._sslobj.version()

    def verify_client_post_handshake(self):
        """Ask the C-level object to request a client certificate after the
        handshake (presumably TLS 1.3 post-handshake authentication)."""
        return self._sslobj.verify_client_post_handshake()
782
783
784class SSLSocket(socket):
785    """This class implements a subtype of socket.socket that wraps
786    the underlying OS socket in an SSL context when necessary, and
787    provides read and write methods over that channel. """
788
789    def __init__(self, *args, **kwargs):
790        raise TypeError(
791            f"{self.__class__.__name__} does not have a public "
792            f"constructor. Instances are returned by "
793            f"SSLContext.wrap_socket()."
794        )
795
    @classmethod
    def _create(cls, sock, server_side=False, do_handshake_on_connect=True,
                suppress_ragged_eofs=True, server_hostname=None,
                context=None, session=None):
        """Build an SSLSocket that takes over *sock*'s file descriptor.

        Used by SSLContext.wrap_socket().  *sock* must be a SOCK_STREAM
        socket; it is detached, so afterwards only the returned object owns
        the descriptor.  If the socket is already connected, the SSL object
        is created right away and, when *do_handshake_on_connect* is true,
        the handshake is run (refused for non-blocking sockets).

        Raises NotImplementedError for non-stream sockets, and ValueError
        for server_hostname/session on the server side, for check_hostname
        without server_hostname, and for do_handshake_on_connect on a
        non-blocking socket.
        """
        if sock.getsockopt(SOL_SOCKET, SO_TYPE) != SOCK_STREAM:
            raise NotImplementedError("only stream sockets are supported")
        if server_side:
            if server_hostname:
                raise ValueError("server_hostname can only be specified "
                                 "in client mode")
            if session is not None:
                raise ValueError("session can only be specified in "
                                 "client mode")
        if context.check_hostname and not server_hostname:
            raise ValueError("check_hostname requires server_hostname")

        # Create the new object on the same descriptor, copy the timeout,
        # then detach the original so it no longer owns (or closes) the fd.
        kwargs = dict(
            family=sock.family, type=sock.type, proto=sock.proto,
            fileno=sock.fileno()
        )
        self = cls.__new__(cls, **kwargs)
        # SSLSocket.__init__ always raises, so initialise via socket's.
        super(SSLSocket, self).__init__(**kwargs)
        self.settimeout(sock.gettimeout())
        sock.detach()

        self._context = context
        self._session = session
        self._closed = False
        self._sslobj = None
        self.server_side = server_side
        self.server_hostname = context._encode_hostname(server_hostname)
        self.do_handshake_on_connect = do_handshake_on_connect
        self.suppress_ragged_eofs = suppress_ragged_eofs

        # See if we are connected
        try:
            self.getpeername()
        except OSError as e:
            # ENOTCONN just means "not connected yet"; other errors are real.
            # NOTE(review): on that re-raise path the new socket (which now
            # owns the fd) is not closed explicitly.
            if e.errno != errno.ENOTCONN:
                raise
            connected = False
        else:
            connected = True

        self._connected = connected
        if connected:
            # create the SSL object
            try:
                self._sslobj = self._context._wrap_socket(
                    self, server_side, self.server_hostname,
                    owner=self, session=self._session,
                )
                if do_handshake_on_connect:
                    timeout = self.gettimeout()
                    if timeout == 0.0:
                        # non-blocking
                        raise ValueError("do_handshake_on_connect should not be specified for non-blocking sockets")
                    self.do_handshake()
            except (OSError, ValueError):
                # Release the descriptor before propagating the error.
                self.close()
                raise
        return self
858
859    @property
860    def context(self):
861        return self._context
862
863    @context.setter
864    def context(self, ctx):
865        self._context = ctx
866        self._sslobj.context = ctx
867
868    @property
869    def session(self):
870        """The SSLSession for client socket."""
871        if self._sslobj is not None:
872            return self._sslobj.session
873
874    @session.setter
875    def session(self, session):
876        self._session = session
877        if self._sslobj is not None:
878            self._sslobj.session = session
879
880    @property
881    def session_reused(self):
882        """Was the client session reused during handshake"""
883        if self._sslobj is not None:
884            return self._sslobj.session_reused
885
886    def dup(self):
887        raise NotImplementedError("Can't dup() %s instances" %
888                                  self.__class__.__name__)
889
890    def _checkClosed(self, msg=None):
891        # raise an exception here if you wish to check for spurious closes
892        pass
893
894    def _check_connected(self):
895        if not self._connected:
896            # getpeername() will raise ENOTCONN if the socket is really
897            # not connected; note that we can be connected even without
898            # _connected being set, e.g. if connect() first returned
899            # EAGAIN.
900            self.getpeername()
901
902    def read(self, len=1024, buffer=None):
903        """Read up to LEN bytes and return them.
904        Return zero-length string on EOF."""
905
906        self._checkClosed()
907        if self._sslobj is None:
908            raise ValueError("Read on closed or unwrapped SSL socket.")
909        try:
910            if buffer is not None:
911                return self._sslobj.read(len, buffer)
912            else:
913                return self._sslobj.read(len)
914        except SSLError as x:
915            if x.args[0] == SSL_ERROR_EOF and self.suppress_ragged_eofs:
916                if buffer is not None:
917                    return 0
918                else:
919                    return b''
920            else:
921                raise
922
923    def write(self, data):
924        """Write DATA to the underlying SSL channel.  Returns
925        number of bytes of DATA actually transmitted."""
926
927        self._checkClosed()
928        if self._sslobj is None:
929            raise ValueError("Write on closed or unwrapped SSL socket.")
930        return self._sslobj.write(data)
931
932    def getpeercert(self, binary_form=False):
933        """Returns a formatted version of the data in the
934        certificate provided by the other end of the SSL channel.
935        Return None if no certificate was provided, {} if a
936        certificate was provided, but not validated."""
937
938        self._checkClosed()
939        self._check_connected()
940        return self._sslobj.getpeercert(binary_form)
941
942    def selected_npn_protocol(self):
943        self._checkClosed()
944        if self._sslobj is None or not _ssl.HAS_NPN:
945            return None
946        else:
947            return self._sslobj.selected_npn_protocol()
948
949    def selected_alpn_protocol(self):
950        self._checkClosed()
951        if self._sslobj is None or not _ssl.HAS_ALPN:
952            return None
953        else:
954            return self._sslobj.selected_alpn_protocol()
955
956    def cipher(self):
957        self._checkClosed()
958        if self._sslobj is None:
959            return None
960        else:
961            return self._sslobj.cipher()
962
963    def shared_ciphers(self):
964        self._checkClosed()
965        if self._sslobj is None:
966            return None
967        else:
968            return self._sslobj.shared_ciphers()
969
970    def compression(self):
971        self._checkClosed()
972        if self._sslobj is None:
973            return None
974        else:
975            return self._sslobj.compression()
976
977    def send(self, data, flags=0):
978        self._checkClosed()
979        if self._sslobj is not None:
980            if flags != 0:
981                raise ValueError(
982                    "non-zero flags not allowed in calls to send() on %s" %
983                    self.__class__)
984            return self._sslobj.write(data)
985        else:
986            return super().send(data, flags)
987
988    def sendto(self, data, flags_or_addr, addr=None):
989        self._checkClosed()
990        if self._sslobj is not None:
991            raise ValueError("sendto not allowed on instances of %s" %
992                             self.__class__)
993        elif addr is None:
994            return super().sendto(data, flags_or_addr)
995        else:
996            return super().sendto(data, flags_or_addr, addr)
997
998    def sendmsg(self, *args, **kwargs):
999        # Ensure programs don't send data unencrypted if they try to
1000        # use this method.
1001        raise NotImplementedError("sendmsg not allowed on instances of %s" %
1002                                  self.__class__)
1003
1004    def sendall(self, data, flags=0):
1005        self._checkClosed()
1006        if self._sslobj is not None:
1007            if flags != 0:
1008                raise ValueError(
1009                    "non-zero flags not allowed in calls to sendall() on %s" %
1010                    self.__class__)
1011            count = 0
1012            with memoryview(data) as view, view.cast("B") as byte_view:
1013                amount = len(byte_view)
1014                while count < amount:
1015                    v = self.send(byte_view[count:])
1016                    count += v
1017        else:
1018            return super().sendall(data, flags)
1019
1020    def sendfile(self, file, offset=0, count=None):
1021        """Send a file, possibly by using os.sendfile() if this is a
1022        clear-text socket.  Return the total number of bytes sent.
1023        """
1024        if self._sslobj is not None:
1025            return self._sendfile_use_send(file, offset, count)
1026        else:
1027            # os.sendfile() works with plain sockets only
1028            return super().sendfile(file, offset, count)
1029
1030    def recv(self, buflen=1024, flags=0):
1031        self._checkClosed()
1032        if self._sslobj is not None:
1033            if flags != 0:
1034                raise ValueError(
1035                    "non-zero flags not allowed in calls to recv() on %s" %
1036                    self.__class__)
1037            return self.read(buflen)
1038        else:
1039            return super().recv(buflen, flags)
1040
1041    def recv_into(self, buffer, nbytes=None, flags=0):
1042        self._checkClosed()
1043        if buffer and (nbytes is None):
1044            nbytes = len(buffer)
1045        elif nbytes is None:
1046            nbytes = 1024
1047        if self._sslobj is not None:
1048            if flags != 0:
1049                raise ValueError(
1050                  "non-zero flags not allowed in calls to recv_into() on %s" %
1051                  self.__class__)
1052            return self.read(nbytes, buffer)
1053        else:
1054            return super().recv_into(buffer, nbytes, flags)
1055
1056    def recvfrom(self, buflen=1024, flags=0):
1057        self._checkClosed()
1058        if self._sslobj is not None:
1059            raise ValueError("recvfrom not allowed on instances of %s" %
1060                             self.__class__)
1061        else:
1062            return super().recvfrom(buflen, flags)
1063
1064    def recvfrom_into(self, buffer, nbytes=None, flags=0):
1065        self._checkClosed()
1066        if self._sslobj is not None:
1067            raise ValueError("recvfrom_into not allowed on instances of %s" %
1068                             self.__class__)
1069        else:
1070            return super().recvfrom_into(buffer, nbytes, flags)
1071
1072    def recvmsg(self, *args, **kwargs):
1073        raise NotImplementedError("recvmsg not allowed on instances of %s" %
1074                                  self.__class__)
1075
1076    def recvmsg_into(self, *args, **kwargs):
1077        raise NotImplementedError("recvmsg_into not allowed on instances of "
1078                                  "%s" % self.__class__)
1079
1080    def pending(self):
1081        self._checkClosed()
1082        if self._sslobj is not None:
1083            return self._sslobj.pending()
1084        else:
1085            return 0
1086
1087    def shutdown(self, how):
1088        self._checkClosed()
1089        self._sslobj = None
1090        super().shutdown(how)
1091
1092    def unwrap(self):
1093        if self._sslobj:
1094            s = self._sslobj.shutdown()
1095            self._sslobj = None
1096            return s
1097        else:
1098            raise ValueError("No SSL wrapper around " + str(self))
1099
1100    def verify_client_post_handshake(self):
1101        if self._sslobj:
1102            return self._sslobj.verify_client_post_handshake()
1103        else:
1104            raise ValueError("No SSL wrapper around " + str(self))
1105
1106    def _real_close(self):
1107        self._sslobj = None
1108        super()._real_close()
1109
1110    def do_handshake(self, block=False):
1111        """Perform a TLS/SSL handshake."""
1112        self._check_connected()
1113        timeout = self.gettimeout()
1114        try:
1115            if timeout == 0.0 and block:
1116                self.settimeout(None)
1117            self._sslobj.do_handshake()
1118        finally:
1119            self.settimeout(timeout)
1120
1121    def _real_connect(self, addr, connect_ex):
1122        if self.server_side:
1123            raise ValueError("can't connect in server-side mode")
1124        # Here we assume that the socket is client-side, and not
1125        # connected at the time of the call.  We connect it, then wrap it.
1126        if self._connected or self._sslobj is not None:
1127            raise ValueError("attempt to connect already-connected SSLSocket!")
1128        self._sslobj = self.context._wrap_socket(
1129            self, False, self.server_hostname,
1130            owner=self, session=self._session
1131        )
1132        try:
1133            if connect_ex:
1134                rc = super().connect_ex(addr)
1135            else:
1136                rc = None
1137                super().connect(addr)
1138            if not rc:
1139                self._connected = True
1140                if self.do_handshake_on_connect:
1141                    self.do_handshake()
1142            return rc
1143        except (OSError, ValueError):
1144            self._sslobj = None
1145            raise
1146
1147    def connect(self, addr):
1148        """Connects to remote ADDR, and then wraps the connection in
1149        an SSL channel."""
1150        self._real_connect(addr, False)
1151
1152    def connect_ex(self, addr):
1153        """Connects to remote ADDR, and then wraps the connection in
1154        an SSL channel."""
1155        return self._real_connect(addr, True)
1156
1157    def accept(self):
1158        """Accepts a new connection from a remote client, and returns
1159        a tuple containing that new connection wrapped with a server-side
1160        SSL channel, and the address of the remote client."""
1161
1162        newsock, addr = super().accept()
1163        newsock = self.context.wrap_socket(newsock,
1164                    do_handshake_on_connect=self.do_handshake_on_connect,
1165                    suppress_ragged_eofs=self.suppress_ragged_eofs,
1166                    server_side=True)
1167        return newsock, addr
1168
1169    def get_channel_binding(self, cb_type="tls-unique"):
1170        """Get channel binding data for current connection.  Raise ValueError
1171        if the requested `cb_type` is not supported.  Return bytes of the data
1172        or None if the data is not available (e.g. before the handshake).
1173        """
1174        if self._sslobj is not None:
1175            return self._sslobj.get_channel_binding(cb_type)
1176        else:
1177            if cb_type not in CHANNEL_BINDING_TYPES:
1178                raise ValueError(
1179                    "{0} channel binding type not implemented".format(cb_type)
1180                )
1181            return None
1182
1183    def version(self):
1184        """
1185        Return a string identifying the protocol version used by the
1186        current SSL channel, or None if there is no established channel.
1187        """
1188        if self._sslobj is not None:
1189            return self._sslobj.version()
1190        else:
1191            return None
1192
1193
# Python has no forward declarations, so these class references can only be
# attached to SSLContext once SSLSocket and SSLObject both exist.
SSLContext.sslsocket_class, SSLContext.sslobject_class = SSLSocket, SSLObject
1197
1198
def wrap_socket(sock, keyfile=None, certfile=None,
                server_side=False, cert_reqs=CERT_NONE,
                ssl_version=PROTOCOL_TLS, ca_certs=None,
                do_handshake_on_connect=True,
                suppress_ragged_eofs=True,
                ciphers=None):
    """Wrap *sock* using a new SSLContext configured from the arguments.

    Raises ValueError when a server-side socket or a *keyfile* is requested
    without a *certfile*.
    """
    if not certfile:
        if server_side:
            raise ValueError("certfile must be specified for server-side "
                             "operations")
        if keyfile:
            raise ValueError("certfile must be specified")

    context = SSLContext(ssl_version)
    context.verify_mode = cert_reqs
    # Optional settings are applied only when the caller supplied them.
    if ca_certs:
        context.load_verify_locations(ca_certs)
    if certfile:
        context.load_cert_chain(certfile, keyfile)
    if ciphers:
        context.set_ciphers(ciphers)

    return context.wrap_socket(
        sock=sock,
        server_side=server_side,
        do_handshake_on_connect=do_handshake_on_connect,
        suppress_ragged_eofs=suppress_ragged_eofs,
    )
1224
1225# some utility functions
1226
def cert_time_to_seconds(cert_time):
    """Return the time in seconds since the Epoch, given the timestring
    representing the "notBefore" or "notAfter" date from a certificate
    in ``"%b %d %H:%M:%S %Y %Z"`` strptime format (C locale).

    "notBefore" or "notAfter" dates must use UTC (RFC 5280).

    Month is one of: Jan Feb Mar Apr May Jun Jul Aug Sep Oct Nov Dec
    UTC should be specified as GMT (see ASN1_TIME_print())

    Raises ValueError if the month abbreviation or the rest of the string
    does not match the expected format.
    """
    from time import strptime
    from calendar import timegm

    # The month is matched against a fixed English table instead of %b,
    # which strptime would interpret according to the current locale.
    months = (
        "Jan","Feb","Mar","Apr","May","Jun",
        "Jul","Aug","Sep","Oct","Nov","Dec"
    )
    time_format = ' %d %H:%M:%S %Y GMT' # NOTE: no month, fixed GMT
    try:
        month_number = months.index(cert_time[:3].title()) + 1
    except ValueError:
        # Suppress the "tuple.index(x): x not in tuple" context; this
        # message already states what was wrong.
        raise ValueError('time data %r does not match '
                         'format "%%b%s"' % (cert_time, time_format)) from None
    else:
        # found valid month
        tt = strptime(cert_time[3:], time_format)
        # return an integer, the previous mktime()-based implementation
        # returned a float (fractional seconds are always zero here).
        return timegm((tt[0], month_number) + tt[2:6])
1256
PEM_HEADER = "-----BEGIN CERTIFICATE-----"
PEM_FOOTER = "-----END CERTIFICATE-----"

def DER_cert_to_PEM_cert(der_cert_bytes):
    """Takes a certificate in binary DER format and returns the
    PEM version of it as a string.

    The base64 body is split into 64-character lines between PEM_HEADER
    and PEM_FOOTER, and the result ends with a newline."""

    f = str(base64.standard_b64encode(der_cert_bytes), 'ASCII', 'strict')
    ss = [PEM_HEADER]
    ss += [f[i:i+64] for i in range(0, len(f), 64)]
    ss.append(PEM_FOOTER + '\n')
    return '\n'.join(ss)

def PEM_cert_to_DER_cert(pem_cert_string):
    """Takes a certificate in ASCII PEM format and returns the
    DER-encoded version of it as a byte sequence.

    Whitespace surrounding the PEM block is ignored.  Raises ValueError
    if the text does not start with PEM_HEADER or end with PEM_FOOTER."""

    # Strip once and validate both markers on the same text, so that
    # surrounding whitespace is tolerated at either end (e.g. a leading
    # blank line), not only after the footer.
    pem = pem_cert_string.strip()
    if not pem.startswith(PEM_HEADER):
        raise ValueError("Invalid PEM encoding; must start with %s"
                         % PEM_HEADER)
    if not pem.endswith(PEM_FOOTER):
        raise ValueError("Invalid PEM encoding; must end with %s"
                         % PEM_FOOTER)
    d = pem[len(PEM_HEADER):-len(PEM_FOOTER)]
    # decodebytes() discards the newlines separating the base64 lines.
    return base64.decodebytes(d.encode('ASCII', 'strict'))
1282
def get_server_certificate(addr, ssl_version=PROTOCOL_TLS, ca_certs=None):
    """Retrieve the certificate from the server at the specified address,
    and return it as a PEM-encoded string.
    If 'ca_certs' is specified, validate the server cert against it.
    If 'ssl_version' is specified, use it in the connection attempt.

    The host part of *addr* is passed as server_hostname, so servers that
    select a certificate by SNI return the one for that host."""

    host, port = addr
    if ca_certs is not None:
        cert_reqs = CERT_REQUIRED
    else:
        cert_reqs = CERT_NONE
    context = _create_stdlib_context(ssl_version,
                                     cert_reqs=cert_reqs,
                                     cafile=ca_certs)
    with create_connection(addr) as sock:
        # Without server_hostname no SNI is sent and virtual-hosted servers
        # answer with their default certificate.  It is also required when
        # the context has check_hostname enabled.
        with context.wrap_socket(sock, server_hostname=host) as sslsock:
            dercert = sslsock.getpeercert(True)
    return DER_cert_to_PEM_cert(dercert)
1301
def get_protocol_name(protocol_code):
    """Return the name registered for *protocol_code*, or '<unknown>'."""
    unknown = '<unknown>'
    return _PROTOCOL_NAMES.get(protocol_code, unknown)
1304