1# Wrapper module for _ssl, providing some additional facilities
2# implemented in Python.  Written by Bill Janssen.
3
4"""This module provides some more Pythonic support for SSL.
5
6Object types:
7
8  SSLSocket -- subtype of socket.socket which does SSL over the socket
9
10Exceptions:
11
12  SSLError -- exception raised for I/O errors
13
14Functions:
15
16  cert_time_to_seconds -- convert time string used for certificate
17                          notBefore and notAfter functions to integer
18                          seconds past the Epoch (the time values
19                          returned from time.time())
20
21  fetch_server_certificate (HOST, PORT) -- fetch the certificate provided
22                          by the server running on HOST at port PORT.  No
23                          validation of the certificate is performed.
24
25Integer constants:
26
27SSL_ERROR_ZERO_RETURN
28SSL_ERROR_WANT_READ
29SSL_ERROR_WANT_WRITE
30SSL_ERROR_WANT_X509_LOOKUP
31SSL_ERROR_SYSCALL
32SSL_ERROR_SSL
33SSL_ERROR_WANT_CONNECT
34
35SSL_ERROR_EOF
36SSL_ERROR_INVALID_ERROR_CODE
37
The following constants define the certificate requirements that one side is
allowing/requiring from the other side:
40
41CERT_NONE - no certificates from the other side are required (or will
42            be looked at if provided)
43CERT_OPTIONAL - certificates are not required, but if provided will be
44                validated, and if validation fails, the connection will
45                also fail
46CERT_REQUIRED - certificates are required, and will be validated, and
47                if validation fails, the connection will also fail
48
49The following constants identify various SSL protocol variants:
50
51PROTOCOL_SSLv2
52PROTOCOL_SSLv3
53PROTOCOL_SSLv23
54PROTOCOL_TLS
55PROTOCOL_TLSv1
56PROTOCOL_TLSv1_1
57PROTOCOL_TLSv1_2
58
59The following constants identify various SSL alert message descriptions as per
60http://www.iana.org/assignments/tls-parameters/tls-parameters.xml#tls-parameters-6
61
62ALERT_DESCRIPTION_CLOSE_NOTIFY
63ALERT_DESCRIPTION_UNEXPECTED_MESSAGE
64ALERT_DESCRIPTION_BAD_RECORD_MAC
65ALERT_DESCRIPTION_RECORD_OVERFLOW
66ALERT_DESCRIPTION_DECOMPRESSION_FAILURE
67ALERT_DESCRIPTION_HANDSHAKE_FAILURE
68ALERT_DESCRIPTION_BAD_CERTIFICATE
69ALERT_DESCRIPTION_UNSUPPORTED_CERTIFICATE
70ALERT_DESCRIPTION_CERTIFICATE_REVOKED
71ALERT_DESCRIPTION_CERTIFICATE_EXPIRED
72ALERT_DESCRIPTION_CERTIFICATE_UNKNOWN
73ALERT_DESCRIPTION_ILLEGAL_PARAMETER
74ALERT_DESCRIPTION_UNKNOWN_CA
75ALERT_DESCRIPTION_ACCESS_DENIED
76ALERT_DESCRIPTION_DECODE_ERROR
77ALERT_DESCRIPTION_DECRYPT_ERROR
78ALERT_DESCRIPTION_PROTOCOL_VERSION
79ALERT_DESCRIPTION_INSUFFICIENT_SECURITY
80ALERT_DESCRIPTION_INTERNAL_ERROR
81ALERT_DESCRIPTION_USER_CANCELLED
82ALERT_DESCRIPTION_NO_RENEGOTIATION
83ALERT_DESCRIPTION_UNSUPPORTED_EXTENSION
84ALERT_DESCRIPTION_CERTIFICATE_UNOBTAINABLE
85ALERT_DESCRIPTION_UNRECOGNIZED_NAME
86ALERT_DESCRIPTION_BAD_CERTIFICATE_STATUS_RESPONSE
87ALERT_DESCRIPTION_BAD_CERTIFICATE_HASH_VALUE
88ALERT_DESCRIPTION_UNKNOWN_PSK_IDENTITY
89"""
90
91import textwrap
92import re
93import sys
94import os
95from collections import namedtuple
96from contextlib import closing
97
98import _ssl             # if we can't import it, let the error propagate
99
100from _ssl import OPENSSL_VERSION_NUMBER, OPENSSL_VERSION_INFO, OPENSSL_VERSION
101from _ssl import _SSLContext
102from _ssl import (
103    SSLError, SSLZeroReturnError, SSLWantReadError, SSLWantWriteError,
104    SSLSyscallError, SSLEOFError,
105    )
106from _ssl import CERT_NONE, CERT_OPTIONAL, CERT_REQUIRED
107from _ssl import txt2obj as _txt2obj, nid2obj as _nid2obj
108from _ssl import RAND_status, RAND_add
109try:
110    from _ssl import RAND_egd
111except ImportError:
112    # LibreSSL does not provide RAND_egd
113    pass
114
def _import_symbols(prefix):
    """Copy every ``_ssl`` attribute whose name starts with *prefix* into
    this module's global namespace, under the same name."""
    module_namespace = globals()
    matching = [name for name in dir(_ssl) if name.startswith(prefix)]
    for name in matching:
        module_namespace[name] = getattr(_ssl, name)
119
120_import_symbols('OP_')
121_import_symbols('ALERT_DESCRIPTION_')
122_import_symbols('SSL_ERROR_')
123_import_symbols('PROTOCOL_')
124_import_symbols('VERIFY_')
125
126from _ssl import HAS_SNI, HAS_ECDH, HAS_NPN, HAS_ALPN, HAS_TLSv1_3
127
128from _ssl import _OPENSSL_API_VERSION
129
# Reverse lookup table: protocol constant value -> constant name.  It is built
# from the PROTOCOL_* names present in this module's globals at this point;
# PROTOCOL_SSLv23 is deliberately left out and (re)bound just below.
_PROTOCOL_NAMES = {value: name for name, value in globals().items()
                   if name.startswith('PROTOCOL_')
                       and name != 'PROTOCOL_SSLv23'}
# PROTOCOL_SSLv23 is kept as another name for PROTOCOL_TLS.
PROTOCOL_SSLv23 = PROTOCOL_TLS

# PROTOCOL_SSLv2 may not have been imported from _ssl; remember its value (or
# None) so later code can compare against it without risking a NameError.
try:
    _SSLv2_IF_EXISTS = PROTOCOL_SSLv2
except NameError:
    _SSLv2_IF_EXISTS = None
139
140from socket import socket, _fileobject, _delegate_methods, error as socket_error
141if sys.platform == "win32":
142    from _ssl import enum_certificates, enum_crls
143
144from socket import socket, AF_INET, SOCK_STREAM, create_connection
145from socket import SOL_SOCKET, SO_TYPE
146import base64        # for DER-to-PEM translation
147import errno
148import warnings
149
# 'tls-unique' is listed only when the _ssl extension reports support for it.
CHANNEL_BINDING_TYPES = ['tls-unique'] if _ssl.HAS_TLS_UNIQUE else []
154
155
# Disable weak or insecure ciphers by default
# (OpenSSL's default setting is 'DEFAULT:!aNULL:!eNULL')
# Enable a better set of ciphers by default
# This list has been explicitly chosen to:
#   * List the TLS 1.3 ChaCha20 and AES-GCM cipher suites first
#   * Prefer cipher suites that offer perfect forward secrecy (DHE/ECDHE)
#   * Prefer ECDHE over DHE for better performance
#   * Prefer AEAD over CBC for better performance and security
#   * Prefer AES-GCM over ChaCha20 because most platforms have AES-NI
#     (ChaCha20 needs OpenSSL 1.1.0 or patched 1.0.2)
#   * Prefer any AES-GCM and ChaCha20 over any AES-CBC for better
#     performance and security
#   * Then use HIGH cipher suites as a fallback
#   * Disable NULL authentication, NULL encryption, 3DES and MD5 MACs
#     for security reasons
# The value is a single colon-separated cipher string; SSLContext.__new__
# passes it to set_ciphers().
_DEFAULT_CIPHERS = (
    'TLS13-AES-256-GCM-SHA384:TLS13-CHACHA20-POLY1305-SHA256:'
    'TLS13-AES-128-GCM-SHA256:'
    'ECDH+AESGCM:ECDH+CHACHA20:DH+AESGCM:DH+CHACHA20:ECDH+AES256:DH+AES256:'
    'ECDH+AES128:DH+AES:ECDH+HIGH:DH+HIGH:RSA+AESGCM:RSA+AES:RSA+HIGH:'
    '!aNULL:!eNULL:!MD5:!3DES'
    )
178
# Restricted and more secure ciphers for the server side
# This list has been explicitly chosen to:
#   * List the TLS 1.3 ChaCha20 and AES-GCM cipher suites first
#   * Prefer cipher suites that offer perfect forward secrecy (DHE/ECDHE)
#   * Prefer ECDHE over DHE for better performance
#   * Prefer AEAD over CBC for better performance and security
#   * Prefer AES-GCM over ChaCha20 because most platforms have AES-NI
#   * Prefer any AES-GCM and ChaCha20 over any AES-CBC for better
#     performance and security
#   * Then use HIGH cipher suites as a fallback
#   * Disable NULL authentication, NULL encryption, MD5 MACs, DSS, RC4, and
#     3DES for security reasons
# The value is a single colon-separated cipher string; create_default_context()
# applies it for Purpose.CLIENT_AUTH.
_RESTRICTED_SERVER_CIPHERS = (
    'TLS13-AES-256-GCM-SHA384:TLS13-CHACHA20-POLY1305-SHA256:'
    'TLS13-AES-128-GCM-SHA256:'
    'ECDH+AESGCM:ECDH+CHACHA20:DH+AESGCM:DH+CHACHA20:ECDH+AES256:DH+AES256:'
    'ECDH+AES128:DH+AES:ECDH+HIGH:DH+HIGH:RSA+AESGCM:RSA+AES:RSA+HIGH:'
    '!aNULL:!eNULL:!MD5:!DSS:!RC4:!3DES'
)
198
199
class CertificateError(ValueError):
    """Raised by the hostname-matching helpers of this module when a
    certificate name is unusable or does not match the requested host."""
202
203
def _dnsname_match(dn, hostname, max_wildcards=1):
    """Tell whether *hostname* matches the certificate DNS name *dn*.

    Matching follows RFC 6125, section 6.4.3
    (http://tools.ietf.org/html/rfc6125#section-6.4.3).

    Returns a false value (``False`` or ``None``) when *dn* is empty or does
    not match, and a true value (``True`` or a regex match object) when it
    does.  Raises CertificateError when the left-most label of *dn* contains
    more than *max_wildcards* ``'*'`` characters.
    """
    if not dn:
        return False

    labels = dn.split(r'.')
    first_label, other_labels = labels[0], labels[1:]

    star_count = first_label.count('*')
    if star_count > max_wildcards:
        # Issue #17980: avoid denials of service by refusing more
        # than one wildcard per fragment.  A survey of established
        # policy among SSL implementations showed it to be a
        # reasonable choice.
        raise CertificateError(
            "too many wildcards in certificate DNS name: " + repr(dn))

    # Fast path for the common case: no wildcard at all.
    if star_count == 0:
        return dn.lower() == hostname.lower()

    # RFC 6125, section 6.4.3, subitem 1: a wildcard is only honoured in the
    # left-most label, so only that label gets special treatment.
    if first_label == '*':
        # A lone '*' stands for one non-empty, dot-free label.
        first_pattern = '[^.]+'
    elif first_label.startswith('xn--') or hostname.startswith('xn--'):
        # RFC 6125, section 6.4.3, subitem 3: no wildcard matching inside
        # an A-label or U-label of an internationalized domain name, so the
        # label is matched literally.
        first_pattern = re.escape(first_label)
    else:
        # Embedded '*' (e.g. www*) matches any dot-free string.
        first_pattern = re.escape(first_label).replace(r'\*', '[^.]*')

    # The remaining labels are matched literally; wildcards there are ignored.
    patterns = [first_pattern]
    patterns.extend(re.escape(label) for label in other_labels)

    regex = re.compile(r'\A' + r'\.'.join(patterns) + r'\Z', re.IGNORECASE)
    return regex.match(hostname)
253
254
def match_hostname(cert, hostname):
    """Check that *cert* is valid for *hostname*.

    *cert* is a certificate in the decoded form returned by
    SSLSocket.getpeercert().  RFC 2818 and RFC 6125 rules are followed, but
    IP addresses are not accepted for *hostname*.

    Returns None on success.  Raises ValueError when *cert* is empty or
    missing, and CertificateError when no name in the certificate matches.
    """
    if not cert:
        raise ValueError("empty or no certificate, match_hostname needs a "
                         "SSL socket or SSL context with either "
                         "CERT_OPTIONAL or CERT_REQUIRED")

    candidates = []
    for kind, name in cert.get('subjectAltName', ()):
        if kind != 'DNS':
            continue
        if _dnsname_match(name, hostname):
            return
        candidates.append(name)

    if not candidates:
        # The subject is only checked when there is no dNSName entry
        # in subjectAltName
        for rdn in cert.get('subject', ()):
            for kind, name in rdn:
                # XXX according to RFC 2818, the most specific Common Name
                # must be used.
                if kind != 'commonName':
                    continue
                if _dnsname_match(name, hostname):
                    return
                candidates.append(name)

    if not candidates:
        raise CertificateError("no appropriate commonName or "
            "subjectAltName fields were found")
    if len(candidates) == 1:
        raise CertificateError("hostname %r "
            "doesn't match %r"
            % (hostname, candidates[0]))
    raise CertificateError("hostname %r "
        "doesn't match either of %s"
        % (hostname, ', '.join(map(repr, candidates))))
296
297
# Result type of get_default_verify_paths(): the resolved cafile and capath
# (each may be None), followed by the four values reported by
# _ssl.get_default_verify_paths().
DefaultVerifyPaths = namedtuple("DefaultVerifyPaths",
    "cafile capath openssl_cafile_env openssl_cafile openssl_capath_env "
    "openssl_capath")
301
def get_default_verify_paths():
    """Return a DefaultVerifyPaths tuple with the default cafile and capath.

    The values come from _ssl.get_default_verify_paths(); an environment
    variable named by that result takes precedence over the matching built-in
    path.  ``cafile`` is None unless it is an existing file, and ``capath`` is
    None unless it is an existing directory.
    """
    parts = _ssl.get_default_verify_paths()

    # environment vars shadow paths
    env = os.environ
    cafile = env.get(parts[0], parts[1])
    capath = env.get(parts[2], parts[3])

    if not os.path.isfile(cafile):
        cafile = None
    if not os.path.isdir(capath):
        capath = None
    return DefaultVerifyPaths(cafile, capath, *parts)
314
315
class _ASN1Object(namedtuple("_ASN1Object", "nid shortname longname oid")):
    """ASN.1 object identifier lookup.

    Instances are (nid, shortname, longname, oid) tuples filled in from the
    lookup helpers imported from ``_ssl``.
    """
    __slots__ = ()

    def __new__(cls, oid):
        """Create an _ASN1Object from a dotted OID string."""
        fields = _txt2obj(oid, name=False)
        return super(_ASN1Object, cls).__new__(cls, *fields)

    @classmethod
    def fromnid(cls, nid):
        """Create _ASN1Object from OpenSSL numeric ID
        """
        fields = _nid2obj(nid)
        return super(_ASN1Object, cls).__new__(cls, *fields)

    @classmethod
    def fromname(cls, name):
        """Create _ASN1Object from short name, long name or OID
        """
        fields = _txt2obj(name, name=True)
        return super(_ASN1Object, cls).__new__(cls, *fields)
335
336
class Purpose(_ASN1Object):
    """SSLContext purpose flags with X509v3 Extended Key Usage objects
    """

# The predefined purposes are attached after the class statement because each
# one is itself a Purpose instance, built from the OID given here.
Purpose.SERVER_AUTH = Purpose('1.3.6.1.5.5.7.3.1')
Purpose.CLIENT_AUTH = Purpose('1.3.6.1.5.5.7.3.2')
343
344
class SSLContext(_SSLContext):
    """An SSLContext holds various SSL-related configuration options and
    data, such as certificates and possibly a private key."""

    __slots__ = ('protocol', '__weakref__')
    # Names of the Windows certificate stores read by load_default_certs().
    _windows_cert_stores = ("CA", "ROOT")

    def __new__(cls, protocol, *args, **kwargs):
        """Create the context and install the module's default cipher list,
        unless *protocol* is the SSLv2 protocol constant."""
        context = _SSLContext.__new__(cls, protocol)
        if protocol != _SSLv2_IF_EXISTS:
            context.set_ciphers(_DEFAULT_CIPHERS)
        return context

    def __init__(self, protocol):
        """Remember the *protocol* the context was created with."""
        self.protocol = protocol

    def wrap_socket(self, sock, server_side=False,
                    do_handshake_on_connect=True,
                    suppress_ragged_eofs=True,
                    server_hostname=None):
        """Return an SSLSocket wrapping *sock* that uses this context."""
        return SSLSocket(
            sock=sock,
            server_side=server_side,
            do_handshake_on_connect=do_handshake_on_connect,
            suppress_ragged_eofs=suppress_ragged_eofs,
            server_hostname=server_hostname,
            _context=self,
        )

    def set_npn_protocols(self, npn_protocols):
        """Register the NPN protocol names to advertise.

        Each name is ASCII-encoded and stored as one length byte followed by
        the encoded name.  Raises SSLError when an encoded name is empty or
        longer than 255 bytes.
        """
        wire = bytearray()
        for name in npn_protocols:
            encoded = name.encode('ascii')
            size = len(encoded)
            if not 1 <= size <= 255:
                raise SSLError('NPN protocols must be 1 to 255 in length')
            wire.append(size)
            wire.extend(encoded)

        self._set_npn_protocols(wire)

    def set_alpn_protocols(self, alpn_protocols):
        """Register the ALPN protocol names to advertise.

        Each name is ASCII-encoded and stored as one length byte followed by
        the encoded name.  Raises SSLError when an encoded name is empty or
        longer than 255 bytes.
        """
        wire = bytearray()
        for name in alpn_protocols:
            encoded = name.encode('ascii')
            size = len(encoded)
            if not 1 <= size <= 255:
                raise SSLError('ALPN protocols must be 1 to 255 in length')
            wire.append(size)
            wire.extend(encoded)

        self._set_alpn_protocols(wire)

    def _load_windows_store_certs(self, storename, purpose):
        """Load the certificates of Windows store *storename* that are
        trusted for *purpose* and return them as one bytearray.

        A failure to enumerate the store (OSError) only produces a warning.
        """
        collected = bytearray()
        try:
            for cert, encoding, trust in enum_certificates(storename):
                # CA certs are never PKCS#7 encoded
                if encoding != "x509_asn":
                    continue
                if trust is True or purpose.oid in trust:
                    collected.extend(cert)
        except OSError:
            warnings.warn("unable to enumerate Windows certificate store")
        if collected:
            self.load_verify_locations(cadata=collected)
        return collected

    def load_default_certs(self, purpose=Purpose.SERVER_AUTH):
        """Load the default CA certificates for *purpose*.

        On win32 the stores named in ``_windows_cert_stores`` are read first;
        on every platform set_default_verify_paths() is then called.  Raises
        TypeError when *purpose* is not an _ASN1Object.
        """
        if not isinstance(purpose, _ASN1Object):
            raise TypeError(purpose)
        on_windows = sys.platform == "win32"
        if on_windows:
            for storename in self._windows_cert_stores:
                self._load_windows_store_certs(storename, purpose)
        self.set_default_verify_paths()
414
415
def create_default_context(purpose=Purpose.SERVER_AUTH, cafile=None,
                           capath=None, cadata=None):
    """Create a SSLContext object with default settings.

    NOTE: The protocol and settings may change anytime without prior
          deprecation. The values represent a fair balance between maximum
          compatibility and security.

    Raises TypeError when *purpose* is not an _ASN1Object.
    """
    if not isinstance(purpose, _ASN1Object):
        raise TypeError(purpose)

    context = SSLContext(PROTOCOL_TLS)

    hardening = (
        # SSLv2 considered harmful.
        OP_NO_SSLv2,
        # SSLv3 has problematic security and is only required for really
        # old clients such as IE6 on Windows XP
        OP_NO_SSLv3,
        # disable compression to prevent CRIME attacks (OpenSSL 1.0+)
        getattr(_ssl, "OP_NO_COMPRESSION", 0),
    )
    for flag in hardening:
        context.options |= flag

    if purpose == Purpose.SERVER_AUTH:
        # verify certs and host name in client mode
        context.verify_mode = CERT_REQUIRED
        context.check_hostname = True
    elif purpose == Purpose.CLIENT_AUTH:
        server_flags = (
            # Prefer the server's ciphers by default so that we get
            # stronger encryption
            "OP_CIPHER_SERVER_PREFERENCE",
            # Use single use keys in order to improve forward secrecy
            "OP_SINGLE_DH_USE",
            "OP_SINGLE_ECDH_USE",
        )
        for flag_name in server_flags:
            context.options |= getattr(_ssl, flag_name, 0)

        # disallow ciphers with known vulnerabilities
        context.set_ciphers(_RESTRICTED_SERVER_CIPHERS)

    explicit_ca = cafile or capath or cadata
    if explicit_ca:
        context.load_verify_locations(cafile, capath, cadata)
    elif context.verify_mode != CERT_NONE:
        # no explicit cafile, capath or cadata but the verify mode is
        # CERT_OPTIONAL or CERT_REQUIRED. Let's try to load default system
        # root CA certificates for the given purpose. This may fail silently.
        context.load_default_certs(purpose)
    return context
463
def _create_unverified_context(protocol=PROTOCOL_TLS, cert_reqs=None,
                           check_hostname=False, purpose=Purpose.SERVER_AUTH,
                           certfile=None, keyfile=None,
                           cafile=None, capath=None, cadata=None):
    """Create a SSLContext object for Python stdlib modules

    All Python stdlib modules shall use this function to create SSLContext
    objects in order to keep common settings in one place. The configuration
    is less restrictive than create_default_context()'s to increase backward
    compatibility.

    Raises TypeError when *purpose* is not an _ASN1Object and ValueError when
    *keyfile* is given without *certfile*.
    """
    if not isinstance(purpose, _ASN1Object):
        raise TypeError(purpose)

    context = SSLContext(protocol)
    for flag in (
            # SSLv2 considered harmful.
            OP_NO_SSLv2,
            # SSLv3 has problematic security and is only required for
            # really old clients such as IE6 on Windows XP
            OP_NO_SSLv3):
        context.options |= flag

    if cert_reqs is not None:
        context.verify_mode = cert_reqs
    context.check_hostname = check_hostname

    if certfile:
        context.load_cert_chain(certfile, keyfile)
    elif keyfile:
        # a key without a certificate cannot be loaded
        raise ValueError("certfile must be specified")

    # load CA root certs
    explicit_ca = cafile or capath or cadata
    if explicit_ca:
        context.load_verify_locations(cafile, capath, cadata)
    elif context.verify_mode != CERT_NONE:
        # no explicit cafile, capath or cadata but the verify mode is
        # CERT_OPTIONAL or CERT_REQUIRED. Let's try to load default system
        # root CA certificates for the given purpose. This may fail silently.
        context.load_default_certs(purpose)

    return context
504
505# Backwards compatibility alias, even though it's not a public name.
506_create_stdlib_context = _create_unverified_context
507
508# PEP 493: Verify HTTPS by default, but allow envvar to override that
509_https_verify_envvar = 'PYTHONHTTPSVERIFY'
510
def _get_https_context_factory():
    """Pick the context factory used for HTTPS by default.

    Returns _create_unverified_context when the environment is honoured
    (``sys.flags.ignore_environment`` is false) and the variable named by
    ``_https_verify_envvar`` is exactly ``'0'``; create_default_context
    otherwise.
    """
    if sys.flags.ignore_environment:
        return create_default_context
    if os.environ.get(_https_verify_envvar) == '0':
        return _create_unverified_context
    return create_default_context
517
518_create_default_https_context = _get_https_context_factory()
519
520# PEP 493: "private" API to configure HTTPS defaults without monkeypatching
def _https_verify_certificates(enable=True):
    """Verify server HTTPS certificates by default?

    Rebinds the module-level ``_create_default_https_context`` to
    create_default_context when *enable* is true, and to
    _create_unverified_context otherwise.
    """
    global _create_default_https_context
    _create_default_https_context = (
        create_default_context if enable else _create_unverified_context)
528
529
530class SSLSocket(socket):
531    """This class implements a subtype of socket.socket that wraps
532    the underlying OS socket in an SSL context when necessary, and
533    provides read and write methods over that channel."""
534
    def __init__(self, sock=None, keyfile=None, certfile=None,
                 server_side=False, cert_reqs=CERT_NONE,
                 ssl_version=PROTOCOL_TLS, ca_certs=None,
                 do_handshake_on_connect=True,
                 family=AF_INET, type=SOCK_STREAM, proto=0, fileno=None,
                 suppress_ragged_eofs=True, npn_protocols=None, ciphers=None,
                 server_hostname=None,
                 _context=None):
        """Wrap the existing socket *sock* for SSL use.

        When *_context* is given it is used as is.  Otherwise a new
        SSLContext is built from *ssl_version*, *cert_reqs*, *ca_certs*,
        *certfile*/*keyfile*, *npn_protocols* and *ciphers*, and those
        arguments are also stored as attributes of the same name (only in
        that case).

        If *sock* is already connected, the SSL object is created at once
        and, when *do_handshake_on_connect* is true, the handshake is run;
        a failure there closes this object before the error propagates.

        *family*, *type*, *proto* and *fileno* are accepted but not
        referenced in this method.

        Raises ValueError for inconsistent arguments (server side without
        certfile, keyfile without certfile, server_hostname in server mode,
        check_hostname without server_hostname, handshake-on-connect with a
        non-blocking socket) and NotImplementedError for non-stream sockets.
        """

        # Counter consulted by close(): the socket is really closed only
        # when it is below 1.
        self._makefile_refs = 0
        if _context:
            self._context = _context
        else:
            if server_side and not certfile:
                raise ValueError("certfile must be specified for server-side "
                                 "operations")
            if keyfile and not certfile:
                raise ValueError("certfile must be specified")
            # With only a certfile given, the same file is used for the key.
            if certfile and not keyfile:
                keyfile = certfile
            self._context = SSLContext(ssl_version)
            self._context.verify_mode = cert_reqs
            if ca_certs:
                self._context.load_verify_locations(ca_certs)
            if certfile:
                self._context.load_cert_chain(certfile, keyfile)
            if npn_protocols:
                self._context.set_npn_protocols(npn_protocols)
            if ciphers:
                self._context.set_ciphers(ciphers)
            self.keyfile = keyfile
            self.certfile = certfile
            self.cert_reqs = cert_reqs
            self.ssl_version = ssl_version
            self.ca_certs = ca_certs
            self.ciphers = ciphers
        # Can't use sock.type as other flags (such as SOCK_NONBLOCK) get
        # mixed in.
        if sock.getsockopt(SOL_SOCKET, SO_TYPE) != SOCK_STREAM:
            raise NotImplementedError("only stream sockets are supported")
        socket.__init__(self, _sock=sock._sock)
        # The initializer for socket overrides the methods send(), recv(), etc.
        # in the instance, which we don't need -- but we want to provide the
        # methods defined in SSLSocket.
        for attr in _delegate_methods:
            try:
                delattr(self, attr)
            except AttributeError:
                pass
        if server_side and server_hostname:
            raise ValueError("server_hostname can only be specified "
                             "in client mode")
        if self._context.check_hostname and not server_hostname:
            raise ValueError("check_hostname requires server_hostname")
        self.server_side = server_side
        self.server_hostname = server_hostname
        self.do_handshake_on_connect = do_handshake_on_connect
        self.suppress_ragged_eofs = suppress_ragged_eofs

        # See if we are connected
        try:
            self.getpeername()
        except socket_error as e:
            # ENOTCONN simply means "not connected yet"; anything else is
            # a real error.
            if e.errno != errno.ENOTCONN:
                raise
            connected = False
        else:
            connected = True

        self._closed = False
        self._sslobj = None
        self._connected = connected
        if connected:
            # create the SSL object
            try:
                self._sslobj = self._context._wrap_socket(self._sock, server_side,
                                                          server_hostname, ssl_sock=self)
                if do_handshake_on_connect:
                    timeout = self.gettimeout()
                    if timeout == 0.0:
                        # non-blocking
                        raise ValueError("do_handshake_on_connect should not be specified for non-blocking sockets")
                    self.do_handshake()

            except (OSError, ValueError):
                # Close this wrapper before letting the error propagate.
                self.close()
                raise
622
    @property
    def context(self):
        """The SSLContext currently associated with this socket."""
        return self._context

    @context.setter
    def context(self, ctx):
        # Keep the Python-level reference and the underlying SSL object in
        # sync.  NOTE(review): self._sslobj is None while the socket is not
        # wrapped, in which case the second assignment raises AttributeError.
        self._context = ctx
        self._sslobj.context = ctx
631
632    def dup(self):
633        raise NotImplemented("Can't dup() %s instances" %
634                             self.__class__.__name__)
635
    def _checkClosed(self, msg=None):
        """Hook called at the start of most operations; does nothing here.

        *msg* is accepted but unused.
        """
        # raise an exception here if you wish to check for spurious closes
        pass
639
640    def _check_connected(self):
641        if not self._connected:
642            # getpeername() will raise ENOTCONN if the socket is really
643            # not connected; note that we can be connected even without
644            # _connected being set, e.g. if connect() first returned
645            # EAGAIN.
646            self.getpeername()
647
648    def read(self, len=1024, buffer=None):
649        """Read up to LEN bytes and return them.
650        Return zero-length string on EOF."""
651
652        self._checkClosed()
653        if not self._sslobj:
654            raise ValueError("Read on closed or unwrapped SSL socket.")
655        try:
656            if buffer is not None:
657                v = self._sslobj.read(len, buffer)
658            else:
659                v = self._sslobj.read(len)
660            return v
661        except SSLError as x:
662            if x.args[0] == SSL_ERROR_EOF and self.suppress_ragged_eofs:
663                if buffer is not None:
664                    return 0
665                else:
666                    return b''
667            else:
668                raise
669
670    def write(self, data):
671        """Write DATA to the underlying SSL channel.  Returns
672        number of bytes of DATA actually transmitted."""
673
674        self._checkClosed()
675        if not self._sslobj:
676            raise ValueError("Write on closed or unwrapped SSL socket.")
677        return self._sslobj.write(data)
678
    def getpeercert(self, binary_form=False):
        """Returns a formatted version of the data in the
        certificate provided by the other end of the SSL channel.
        Return None if no certificate was provided, {} if a
        certificate was provided, but not validated.

        *binary_form* is passed through unchanged to the SSL object's
        peer_certificate() call."""

        self._checkClosed()
        # Propagates getpeername()'s error when the socket is not connected.
        self._check_connected()
        return self._sslobj.peer_certificate(binary_form)
688
689    def selected_npn_protocol(self):
690        self._checkClosed()
691        if not self._sslobj or not _ssl.HAS_NPN:
692            return None
693        else:
694            return self._sslobj.selected_npn_protocol()
695
696    def selected_alpn_protocol(self):
697        self._checkClosed()
698        if not self._sslobj or not _ssl.HAS_ALPN:
699            return None
700        else:
701            return self._sslobj.selected_alpn_protocol()
702
703    def cipher(self):
704        self._checkClosed()
705        if not self._sslobj:
706            return None
707        else:
708            return self._sslobj.cipher()
709
710    def compression(self):
711        self._checkClosed()
712        if not self._sslobj:
713            return None
714        else:
715            return self._sslobj.compression()
716
717    def send(self, data, flags=0):
718        self._checkClosed()
719        if self._sslobj:
720            if flags != 0:
721                raise ValueError(
722                    "non-zero flags not allowed in calls to send() on %s" %
723                    self.__class__)
724            try:
725                v = self._sslobj.write(data)
726            except SSLError as x:
727                if x.args[0] == SSL_ERROR_WANT_READ:
728                    return 0
729                elif x.args[0] == SSL_ERROR_WANT_WRITE:
730                    return 0
731                else:
732                    raise
733            else:
734                return v
735        else:
736            return self._sock.send(data, flags)
737
738    def sendto(self, data, flags_or_addr, addr=None):
739        self._checkClosed()
740        if self._sslobj:
741            raise ValueError("sendto not allowed on instances of %s" %
742                             self.__class__)
743        elif addr is None:
744            return self._sock.sendto(data, flags_or_addr)
745        else:
746            return self._sock.sendto(data, flags_or_addr, addr)
747
748
749    def sendall(self, data, flags=0):
750        self._checkClosed()
751        if self._sslobj:
752            if flags != 0:
753                raise ValueError(
754                    "non-zero flags not allowed in calls to sendall() on %s" %
755                    self.__class__)
756            amount = len(data)
757            count = 0
758            while (count < amount):
759                v = self.send(data[count:])
760                count += v
761            return amount
762        else:
763            return socket.sendall(self, data, flags)
764
765    def recv(self, buflen=1024, flags=0):
766        self._checkClosed()
767        if self._sslobj:
768            if flags != 0:
769                raise ValueError(
770                    "non-zero flags not allowed in calls to recv() on %s" %
771                    self.__class__)
772            return self.read(buflen)
773        else:
774            return self._sock.recv(buflen, flags)
775
776    def recv_into(self, buffer, nbytes=None, flags=0):
777        self._checkClosed()
778        if buffer and (nbytes is None):
779            nbytes = len(buffer)
780        elif nbytes is None:
781            nbytes = 1024
782        if self._sslobj:
783            if flags != 0:
784                raise ValueError(
785                  "non-zero flags not allowed in calls to recv_into() on %s" %
786                  self.__class__)
787            return self.read(nbytes, buffer)
788        else:
789            return self._sock.recv_into(buffer, nbytes, flags)
790
791    def recvfrom(self, buflen=1024, flags=0):
792        self._checkClosed()
793        if self._sslobj:
794            raise ValueError("recvfrom not allowed on instances of %s" %
795                             self.__class__)
796        else:
797            return self._sock.recvfrom(buflen, flags)
798
799    def recvfrom_into(self, buffer, nbytes=None, flags=0):
800        self._checkClosed()
801        if self._sslobj:
802            raise ValueError("recvfrom_into not allowed on instances of %s" %
803                             self.__class__)
804        else:
805            return self._sock.recvfrom_into(buffer, nbytes, flags)
806
807
808    def pending(self):
809        self._checkClosed()
810        if self._sslobj:
811            return self._sslobj.pending()
812        else:
813            return 0
814
    def shutdown(self, how):
        """Drop the SSL object, then shut down the plain socket with *how*."""
        self._checkClosed()
        # After this the socket counts as unwrapped for every other method.
        self._sslobj = None
        socket.shutdown(self, how)
819
820    def close(self):
821        if self._makefile_refs < 1:
822            self._sslobj = None
823            socket.close(self)
824        else:
825            self._makefile_refs -= 1
826
827    def unwrap(self):
828        if self._sslobj:
829            s = self._sslobj.shutdown()
830            self._sslobj = None
831            return s
832        else:
833            raise ValueError("No SSL wrapper around " + str(self))
834
    def _real_close(self):
        """Clear the SSL object reference, then run ``socket._real_close``
        on this socket."""
        self._sslobj = None
        socket._real_close(self)
838
839    def do_handshake(self, block=False):
840        """Perform a TLS/SSL handshake."""
841        self._check_connected()
842        timeout = self.gettimeout()
843        try:
844            if timeout == 0.0 and block:
845                self.settimeout(None)
846            self._sslobj.do_handshake()
847        finally:
848            self.settimeout(timeout)
849
850        if self.context.check_hostname:
851            if not self.server_hostname:
852                raise ValueError("check_hostname needs server_hostname "
853                                 "argument")
854            match_hostname(self.getpeercert(), self.server_hostname)
855
856    def _real_connect(self, addr, connect_ex):
857        if self.server_side:
858            raise ValueError("can't connect in server-side mode")
859        # Here we assume that the socket is client-side, and not
860        # connected at the time of the call.  We connect it, then wrap it.
861        if self._connected:
862            raise ValueError("attempt to connect already-connected SSLSocket!")
863        self._sslobj = self.context._wrap_socket(self._sock, False, self.server_hostname, ssl_sock=self)
864        try:
865            if connect_ex:
866                rc = socket.connect_ex(self, addr)
867            else:
868                rc = None
869                socket.connect(self, addr)
870            if not rc:
871                self._connected = True
872                if self.do_handshake_on_connect:
873                    self.do_handshake()
874            return rc
875        except (OSError, ValueError):
876            self._sslobj = None
877            raise
878
    def connect(self, addr):
        """Connects to remote ADDR, and then wraps the connection in
        an SSL channel.

        Delegates to ``_real_connect(addr, False)``; nothing is returned."""
        self._real_connect(addr, False)
883
    def connect_ex(self, addr):
        """Connects to remote ADDR, and then wraps the connection in
        an SSL channel.

        Delegates to ``_real_connect(addr, True)`` and, unlike connect(),
        returns that call's result."""
        return self._real_connect(addr, True)
888
889    def accept(self):
890        """Accepts a new connection from a remote client, and returns
891        a tuple containing that new connection wrapped with a server-side
892        SSL channel, and the address of the remote client."""
893
894        newsock, addr = socket.accept(self)
895        newsock = self.context.wrap_socket(newsock,
896                    do_handshake_on_connect=self.do_handshake_on_connect,
897                    suppress_ragged_eofs=self.suppress_ragged_eofs,
898                    server_side=True)
899        return newsock, addr
900
    def makefile(self, mode='r', bufsize=-1):

        """Make and return a file-like object that
        works with the SSL connection.  Just use the code
        from the socket module.

        Each call increments ``self._makefile_refs``; ``mode`` and
        ``bufsize`` are passed through to ``_fileobject``."""

        self._makefile_refs += 1
        # close=True so as to decrement the reference count when done with
        # the file-like object.
        return _fileobject(self, mode, bufsize, close=True)
911
912    def get_channel_binding(self, cb_type="tls-unique"):
913        """Get channel binding data for current connection.  Raise ValueError
914        if the requested `cb_type` is not supported.  Return bytes of the data
915        or None if the data is not available (e.g. before the handshake).
916        """
917        if cb_type not in CHANNEL_BINDING_TYPES:
918            raise ValueError("Unsupported channel binding type")
919        if cb_type != "tls-unique":
920            raise NotImplementedError(
921                            "{0} channel binding type not implemented"
922                            .format(cb_type))
923        if self._sslobj is None:
924            return None
925        return self._sslobj.tls_unique_cb()
926
927    def version(self):
928        """
929        Return a string identifying the protocol version used by the
930        current SSL channel, or None if there is no established channel.
931        """
932        if self._sslobj is None:
933            return None
934        return self._sslobj.version()
935
936
def wrap_socket(sock, keyfile=None, certfile=None,
                server_side=False, cert_reqs=CERT_NONE,
                ssl_version=PROTOCOL_TLS, ca_certs=None,
                do_handshake_on_connect=True,
                suppress_ragged_eofs=True,
                ciphers=None):
    """Create and return an SSLSocket wrapping ``sock``.

    Every argument is forwarded unchanged, by keyword, to the SSLSocket
    constructor; this function adds no behaviour of its own.
    """

    return SSLSocket(sock=sock, keyfile=keyfile, certfile=certfile,
                     server_side=server_side, cert_reqs=cert_reqs,
                     ssl_version=ssl_version, ca_certs=ca_certs,
                     do_handshake_on_connect=do_handshake_on_connect,
                     suppress_ragged_eofs=suppress_ragged_eofs,
                     ciphers=ciphers)
950
951# some utility functions
952
def cert_time_to_seconds(cert_time):
    """Convert a certificate "notBefore"/"notAfter" timestamp to integer
    seconds since the Epoch.

    The input must look like ``"%b %d %H:%M:%S %Y GMT"`` (for example
    ``"Jan  5 09:34:43 2018 GMT"``).  The month abbreviation is matched
    against a fixed English list, case-insensitively, rather than via the
    locale; the rest is parsed with strptime and the zone must be the
    literal ``GMT``.

    Raises ValueError when the month or the remainder does not match.
    """
    from time import strptime
    from calendar import timegm

    month_names = (
        "Jan", "Feb", "Mar", "Apr", "May", "Jun",
        "Jul", "Aug", "Sep", "Oct", "Nov", "Dec",
    )
    rest_format = ' %d %H:%M:%S %Y GMT' # NOTE: no month, fixed GMT

    month_abbr = cert_time[:3].title()
    if month_abbr not in month_names:
        raise ValueError('time data %r does not match '
                         'format "%%b%s"' % (cert_time, rest_format))
    month_number = month_names.index(month_abbr) + 1

    # Everything after the month name is locale-independent.
    parsed = strptime(cert_time[3:], rest_format)
    # timegm() yields an integer; splice our own month number into the
    # year / day / hour / minute / second fields taken from strptime.
    return timegm((parsed[0], month_number) + parsed[2:6])
982
# Delimiter lines that frame the base64 payload of a PEM certificate.
PEM_HEADER = "-----BEGIN CERTIFICATE-----"
PEM_FOOTER = "-----END CERTIFICATE-----"

def DER_cert_to_PEM_cert(der_cert_bytes):
    """Takes a certificate in binary DER format and returns the
    PEM version of it as a string.

    The DER bytes are base64-encoded, wrapped at 64 characters per line,
    and framed by PEM_HEADER / PEM_FOOTER, each on its own line."""

    f = base64.standard_b64encode(der_cert_bytes).decode('ascii')
    return (PEM_HEADER + '\n' +
            textwrap.fill(f, 64) + '\n' +
            PEM_FOOTER + '\n')

def PEM_cert_to_DER_cert(pem_cert_string):
    """Takes a certificate in ASCII PEM format and returns the
    DER-encoded version of it as a byte sequence.

    Leading and trailing whitespace around the PEM text is ignored.
    Raises ValueError when the text does not start with PEM_HEADER or
    does not end with PEM_FOOTER."""
    # base64.decodestring() is a deprecated alias (removed in Python 3.9)
    # that simply calls binascii.a2b_base64(); call that directly.
    from binascii import a2b_base64

    # Strip once so the header check, the footer check and the slice all
    # look at the same text (previously only the header check saw the
    # unstripped string, so leading whitespace was rejected while trailing
    # whitespace was accepted).
    pem = pem_cert_string.strip()
    if not pem.startswith(PEM_HEADER):
        raise ValueError("Invalid PEM encoding; must start with %s"
                         % PEM_HEADER)
    if not pem.endswith(PEM_FOOTER):
        raise ValueError("Invalid PEM encoding; must end with %s"
                         % PEM_FOOTER)
    d = pem[len(PEM_HEADER):-len(PEM_FOOTER)]
    return a2b_base64(d.encode('ASCII', 'strict'))
1007
def get_server_certificate(addr, ssl_version=PROTOCOL_TLS, ca_certs=None):
    """Fetch the certificate presented by the server at ``addr`` and return
    it as a PEM-encoded string.

    ``addr`` must unpack into two items (host, port).  When ``ca_certs`` is
    given, the context is created with CERT_REQUIRED and that file as
    ``cafile``; otherwise CERT_NONE is used.  ``ssl_version`` is passed to
    the context factory.
    """
    # The two names are not used below; the unpacking still rejects an
    # ``addr`` that is not a two-item sequence before any work is done.
    host, port = addr
    cert_reqs = CERT_NONE if ca_certs is None else CERT_REQUIRED
    ctx = _create_stdlib_context(ssl_version,
                                 cert_reqs=cert_reqs,
                                 cafile=ca_certs)
    with closing(create_connection(addr)) as raw_sock:
        with closing(ctx.wrap_socket(raw_sock)) as tls_sock:
            der_bytes = tls_sock.getpeercert(True)
    return DER_cert_to_PEM_cert(der_bytes)
1026
def get_protocol_name(protocol_code):
    """Return the name registered for ``protocol_code`` in _PROTOCOL_NAMES,
    or the string '<unknown>' when the code is not present."""
    try:
        return _PROTOCOL_NAMES[protocol_code]
    except KeyError:
        return '<unknown>'
1029
1030
1031# a replacement for the old socket.ssl function
1032
def sslwrap_simple(sock, keyfile=None, certfile=None):
    """A replacement for the old socket.ssl function.  Designed
    for compatibility with Python 2.5 and earlier.  Will disappear in
    Python 3.0.

    If ``sock`` has a ``_sock`` attribute, that inner object is wrapped
    instead.  A PROTOCOL_SSLv23 context is created, the certificate chain
    is loaded when ``keyfile`` or ``certfile`` is given, and the handshake
    is performed right away only if ``getpeername()`` succeeds on the
    socket.  Returns the object produced by the context's
    ``_wrap_socket``."""
    raw_sock = sock._sock if hasattr(sock, "_sock") else sock

    ctx = SSLContext(PROTOCOL_SSLv23)
    if keyfile or certfile:
        ctx.load_cert_chain(certfile, keyfile)
    ssl_sock = ctx._wrap_socket(raw_sock, server_side=False)

    try:
        raw_sock.getpeername()
    except socket_error:
        # no, no connection yet -- hand back the un-handshaken object
        return ssl_sock

    # yes, do the handshake
    ssl_sock.do_handshake()
    return ssl_sock
1054