1# Test the support for SSL and sockets
2
3import sys
4import unittest
5from test import support
6import socket
7import select
8import time
9import datetime
10import gc
11import os
12import errno
13import pprint
14import urllib.request
15import threading
16import traceback
17import asyncore
18import weakref
19import platform
20import functools
21import sysconfig
22try:
23    import ctypes
24except ImportError:
25    ctypes = None
26
# Per the test.support docs, import_module() raises unittest.SkipTest when the
# module cannot be imported, so this whole file is skipped on builds without ssl.
ssl = support.import_module("ssl")


# Sorted protocol constants registered in the private ssl._PROTOCOL_NAMES
# mapping (presumably protocol constant -> display name; confirm).
PROTOCOLS = sorted(ssl._PROTOCOL_NAMES)
HOST = support.HOST
# TLS library flavour / minimum-version flags.
IS_LIBRESSL = ssl.OPENSSL_VERSION.startswith('LibreSSL')
IS_OPENSSL_1_1_0 = not IS_LIBRESSL and ssl.OPENSSL_VERSION_INFO >= (1, 1, 0)
IS_OPENSSL_1_1_1 = not IS_LIBRESSL and ssl.OPENSSL_VERSION_INFO >= (1, 1, 1)
# Build-time cipher configuration; sysconfig returns None when it is unset.
PY_SSL_DEFAULT_CIPHERS = sysconfig.get_config_var('PY_SSL_DEFAULT_CIPHERS')

# Map each legacy PROTOCOL_* constant present in this build to an
# ssl.TLSVersion member.  A pair is skipped when either the protocol constant
# or the TLSVersion member does not exist.
# NOTE(review): PROTOCOL_SSLv23 is paired with TLSVersion.SSLv3 -- presumably
# the lowest version that protocol may negotiate; confirm against the users.
# The loop variables proto/ver stay bound at module level afterwards.
PROTOCOL_TO_TLS_VERSION = {}
for proto, ver in (
    ("PROTOCOL_SSLv23", "SSLv3"),
    ("PROTOCOL_TLSv1", "TLSv1"),
    ("PROTOCOL_TLSv1_1", "TLSv1_1"),
):
    try:
        proto = getattr(ssl, proto)
        ver = getattr(ssl.TLSVersion, ver)
    except AttributeError:
        continue
    PROTOCOL_TO_TLS_VERSION[proto] = ver
49
def data_file(*name):
    """Return the path of a test data file that lives next to this module.

    Each positional argument is one path component, joined onto the
    directory containing this file.
    """
    here = os.path.dirname(__file__)
    return os.path.join(here, *name)
52
53# The custom key and certificate files used in test_ssl are generated
54# using Lib/test/make_ssl_certs.py.
55# Other certificates are simply fetched from the Internet servers they
56# are meant to authenticate.
57
# Certificate/key fixture paths, resolved by data_file() relative to this
# module's directory.  BYTES_* names hold the same paths as os.fsencode()
# bytes.
CERTFILE = data_file("keycert.pem")
BYTES_CERTFILE = os.fsencode(CERTFILE)
ONLYCERT = data_file("ssl_cert.pem")
ONLYKEY = data_file("ssl_key.pem")
BYTES_ONLYCERT = os.fsencode(ONLYCERT)
BYTES_ONLYKEY = os.fsencode(ONLYKEY)
CERTFILE_PROTECTED = data_file("keycert.passwd.pem")
ONLYKEY_PROTECTED = data_file("ssl_key.passwd.pem")
# NOTE(review): presumably the passphrase of the *_PROTECTED files above.
KEY_PASSWORD = "somepass"
CAPATH = data_file("capath")
BYTES_CAPATH = os.fsencode(CAPATH)
CAFILE_NEURONIO = data_file("capath", "4e1295a3.0")
CAFILE_CACERT = data_file("capath", "5ed36f99.0")

# Expected result of ssl._ssl._test_decode_cert(CERTFILE).
CERTFILE_INFO = {
    'issuer': ((('countryName', 'XY'),),
               (('localityName', 'Castle Anthrax'),),
               (('organizationName', 'Python Software Foundation'),),
               (('commonName', 'localhost'),)),
    'notAfter': 'Aug 26 14:23:15 2028 GMT',
    'notBefore': 'Aug 29 14:23:15 2018 GMT',
    'serialNumber': '98A7CF88C74A32ED',
    'subject': ((('countryName', 'XY'),),
             (('localityName', 'Castle Anthrax'),),
             (('organizationName', 'Python Software Foundation'),),
             (('commonName', 'localhost'),)),
    'subjectAltName': (('DNS', 'localhost'),),
    'version': 3
}

# empty CRL
CRLFILE = data_file("revocation.crl")

# Two keys and certs signed by the same CA (for SNI tests)
SIGNED_CERTFILE = data_file("keycert3.pem")
SIGNED_CERTFILE_HOSTNAME = 'localhost'

# Expected result of ssl._ssl._test_decode_cert(SIGNED_CERTFILE).
SIGNED_CERTFILE_INFO = {
    'OCSP': ('http://testca.pythontest.net/testca/ocsp/',),
    'caIssuers': ('http://testca.pythontest.net/testca/pycacert.cer',),
    'crlDistributionPoints': ('http://testca.pythontest.net/testca/revocation.crl',),
    'issuer': ((('countryName', 'XY'),),
            (('organizationName', 'Python Software Foundation CA'),),
            (('commonName', 'our-ca-server'),)),
    'notAfter': 'Jul  7 14:23:16 2028 GMT',
    'notBefore': 'Aug 29 14:23:16 2018 GMT',
    'serialNumber': 'CB2D80995A69525C',
    'subject': ((('countryName', 'XY'),),
             (('localityName', 'Castle Anthrax'),),
             (('organizationName', 'Python Software Foundation'),),
             (('commonName', 'localhost'),)),
    'subjectAltName': (('DNS', 'localhost'),),
    'version': 3
}

SIGNED_CERTFILE2 = data_file("keycert4.pem")
SIGNED_CERTFILE2_HOSTNAME = 'fakehostname'
SIGNED_CERTFILE_ECC = data_file("keycertecc.pem")
SIGNED_CERTFILE_ECC_HOSTNAME = 'localhost-ecc'

# Same certificate as pycacert.pem, but without extra text in file
SIGNING_CA = data_file("capath", "ceff1710.0")
# cert with all kinds of subject alt names
ALLSANFILE = data_file("allsans.pem")
IDNSANSFILE = data_file("idnsans.pem")

REMOTE_HOST = "self-signed.pythontest.net"

# Invalid or malformed fixtures used by the negative tests.
EMPTYCERT = data_file("nullcert.pem")
BADCERT = data_file("badcert.pem")
# A path that is expected not to exist (tests assert errno.ENOENT).
NONEXISTINGCERT = data_file("XXXnonexisting.pem")
BADKEY = data_file("badkey.pem")
NOKIACERT = data_file("nokia.pem")
# Certificate with NUL bytes embedded in its names (CVE-2013-4238).
NULLBYTECERT = data_file("nullbytecert.pem")
# Certificate with an invalid CRL distribution point (CVE-2019-5010).
TALOS_INVALID_CRLDP = data_file("talos-2019-0758.pem")

# NOTE(review): presumably Diffie-Hellman parameters (ffdh3072) -- confirm.
DHFILE = data_file("ffdh3072.pem")
BYTES_DHFILE = os.fsencode(DHFILE)

# Not defined in all versions of OpenSSL
# Missing options fall back to 0, i.e. no option bits set.
OP_NO_COMPRESSION = getattr(ssl, "OP_NO_COMPRESSION", 0)
OP_SINGLE_DH_USE = getattr(ssl, "OP_SINGLE_DH_USE", 0)
OP_SINGLE_ECDH_USE = getattr(ssl, "OP_SINGLE_ECDH_USE", 0)
OP_CIPHER_SERVER_PREFERENCE = getattr(ssl, "OP_CIPHER_SERVER_PREFERENCE", 0)
OP_ENABLE_MIDDLEBOX_COMPAT = getattr(ssl, "OP_ENABLE_MIDDLEBOX_COMPAT", 0)
143
144
def handle_error(prefix):
    """Print *prefix* followed by the traceback of the exception being handled.

    Nothing is written unless the test suite runs in verbose mode.
    """
    if not support.verbose:
        return
    formatted = ' '.join(traceback.format_exception(*sys.exc_info()))
    sys.stdout.write(prefix + formatted)
149
def can_clear_options():
    """Return True when the OpenSSL API level is 0.9.8m or newer."""
    minimum_api = (0, 9, 8, 13, 15)  # 0.9.8m
    return minimum_api <= ssl._OPENSSL_API_VERSION
153
def no_sslv2_implies_sslv3_hello():
    """Return True when the linked OpenSSL is 0.9.7h or newer."""
    minimum_version = (0, 9, 7, 8, 15)  # 0.9.7h
    return minimum_version <= ssl.OPENSSL_VERSION_INFO
157
def have_verify_flags():
    """Return True when the linked OpenSSL is 0.9.8 or newer."""
    minimum_version = (0, 9, 8, 0, 15)  # 0.9.8
    return minimum_version <= ssl.OPENSSL_VERSION_INFO
161
162def _have_secp_curves():
163    if not ssl.HAS_ECDH:
164        return False
165    ctx = ssl.SSLContext(ssl.PROTOCOL_TLS_SERVER)
166    try:
167        ctx.set_ecdh_curve("secp384r1")
168    except ValueError:
169        return False
170    else:
171        return True
172
173
# Evaluated once at import time: True when ECDH is available and the
# secp384r1 curve can be selected (see _have_secp_curves()).
HAVE_SECP_CURVES = _have_secp_curves()
175
176
def utc_offset():  # NOTE: ignore issues like #1647654
    """Return the current local UTC offset in seconds.

    local time = utc time + utc offset; the DST offset (time.altzone) is
    used only while DST is in effect.
    """
    dst_in_effect = time.daylight and time.localtime().tm_isdst > 0
    return -time.altzone if dst_in_effect else -time.timezone
182
def asn1time(cert_time):
    """Normalize an OpenSSL certificate time string for comparison.

    Some versions of OpenSSL ignore seconds (see #18207).  For the 0.9.8i
    API only, the seconds are zeroed, and a zero-padded day is turned back
    into the space-padded form that ASN1_TIME_print() produces.  For every
    other version *cert_time* is returned unchanged.
    """
    if ssl._OPENSSL_API_VERSION != (0, 9, 8, 9, 15):  # not 0.9.8i
        return cert_time

    fmt = "%b %d %H:%M:%S %Y GMT"
    parsed = datetime.datetime.strptime(cert_time, fmt)
    normalized = parsed.replace(second=0).strftime(fmt)
    # %d zero-pads the day, but ASN1_TIME_print() pads it with a space.
    if normalized[4] == "0":
        normalized = f"{normalized[:4]} {normalized[5:]}"
    return normalized
196
197# Issue #9415: Ubuntu hijacks their OpenSSL and forcefully disables SSLv2
def skip_if_broken_ubuntu_ssl(func):
    """Decorate test *func* so it is skipped on a known-broken patched OpenSSL.

    Issue #9415.  When the ssl module exposes PROTOCOL_SSLv2, the returned
    wrapper first tries to create an SSLv2 context.  The test is skipped
    (unittest.SkipTest) only when all of the following hold:
    - that attempt raises ssl.SSLError;
    - the library reports OpenSSL 0.9.8o;
    - the platform identifies as Debian 'squeeze/sid'.
    Otherwise *func* is called normally.  When PROTOCOL_SSLv2 does not
    exist, *func* is returned undecorated.
    """
    if not hasattr(ssl, 'PROTOCOL_SSLv2'):
        return func

    @functools.wraps(func)
    def f(*args, **kwargs):
        try:
            ssl.SSLContext(ssl.PROTOCOL_SSLv2)
        except ssl.SSLError:
            # platform.linux_distribution() was removed in Python 3.8.  Without
            # it the distribution cannot be identified, so do not skip (the
            # old code raised AttributeError here instead).
            linux_distribution = getattr(platform, 'linux_distribution', None)
            if (ssl.OPENSSL_VERSION_INFO == (0, 9, 8, 15, 15) and
                    linux_distribution is not None and
                    linux_distribution() == ('debian', 'squeeze/sid', '')):
                raise unittest.SkipTest("Patched Ubuntu OpenSSL breaks behaviour")
        return func(*args, **kwargs)
    return f
212
# Decorator: skip the decorated test unless this ssl build reports SNI support.
needs_sni = unittest.skipUnless(ssl.HAS_SNI, "SNI support needed for this test")
214
215
216def test_wrap_socket(sock, ssl_version=ssl.PROTOCOL_TLS, *,
217                     cert_reqs=ssl.CERT_NONE, ca_certs=None,
218                     ciphers=None, certfile=None, keyfile=None,
219                     **kwargs):
220    context = ssl.SSLContext(ssl_version)
221    if cert_reqs is not None:
222        if cert_reqs == ssl.CERT_NONE:
223            context.check_hostname = False
224        context.verify_mode = cert_reqs
225    if ca_certs is not None:
226        context.load_verify_locations(ca_certs)
227    if certfile is not None or keyfile is not None:
228        context.load_cert_chain(certfile, keyfile)
229    if ciphers is not None:
230        context.set_ciphers(ciphers)
231    return context.wrap_socket(sock, **kwargs)
232
233
def testing_context(server_cert=SIGNED_CERTFILE):
    """Build a client/server SSLContext pair for *server_cert*.

    Usage::

        client_context, server_context, hostname = testing_context()

    The client context trusts SIGNING_CA.  The server context presents
    *server_cert* and also trusts SIGNING_CA.  *hostname* is the name paired
    with the certificate (SIGNED_CERTFILE_HOSTNAME or
    SIGNED_CERTFILE2_HOSTNAME).

    Raises ValueError when *server_cert* is neither SIGNED_CERTFILE nor
    SIGNED_CERTFILE2.
    """
    known_certs = (
        (SIGNED_CERTFILE, SIGNED_CERTFILE_HOSTNAME),
        (SIGNED_CERTFILE2, SIGNED_CERTFILE2_HOSTNAME),
    )
    for cert_path, cert_hostname in known_certs:
        if server_cert == cert_path:
            hostname = cert_hostname
            break
    else:
        raise ValueError(server_cert)

    client_context = ssl.SSLContext(ssl.PROTOCOL_TLS_CLIENT)
    client_context.load_verify_locations(SIGNING_CA)

    server_context = ssl.SSLContext(ssl.PROTOCOL_TLS_SERVER)
    server_context.load_cert_chain(server_cert)
    server_context.load_verify_locations(SIGNING_CA)

    return client_context, server_context, hostname
254
255
256class BasicSocketTests(unittest.TestCase):
257
258    def test_constants(self):
259        ssl.CERT_NONE
260        ssl.CERT_OPTIONAL
261        ssl.CERT_REQUIRED
262        ssl.OP_CIPHER_SERVER_PREFERENCE
263        ssl.OP_SINGLE_DH_USE
264        if ssl.HAS_ECDH:
265            ssl.OP_SINGLE_ECDH_USE
266        if ssl.OPENSSL_VERSION_INFO >= (1, 0):
267            ssl.OP_NO_COMPRESSION
268        self.assertIn(ssl.HAS_SNI, {True, False})
269        self.assertIn(ssl.HAS_ECDH, {True, False})
270        ssl.OP_NO_SSLv2
271        ssl.OP_NO_SSLv3
272        ssl.OP_NO_TLSv1
273        ssl.OP_NO_TLSv1_3
274        if ssl.OPENSSL_VERSION_INFO >= (1, 0, 1):
275            ssl.OP_NO_TLSv1_1
276            ssl.OP_NO_TLSv1_2
277        self.assertEqual(ssl.PROTOCOL_TLS, ssl.PROTOCOL_SSLv23)
278
279    def test_private_init(self):
280        with self.assertRaisesRegex(TypeError, "public constructor"):
281            with socket.socket() as s:
282                ssl.SSLSocket(s)
283
284    def test_str_for_enums(self):
285        # Make sure that the PROTOCOL_* constants have enum-like string
286        # reprs.
287        proto = ssl.PROTOCOL_TLS
288        self.assertEqual(str(proto), '_SSLMethod.PROTOCOL_TLS')
289        ctx = ssl.SSLContext(proto)
290        self.assertIs(ctx.protocol, proto)
291
292    def test_random(self):
293        v = ssl.RAND_status()
294        if support.verbose:
295            sys.stdout.write("\n RAND_status is %d (%s)\n"
296                             % (v, (v and "sufficient randomness") or
297                                "insufficient randomness"))
298
299        data, is_cryptographic = ssl.RAND_pseudo_bytes(16)
300        self.assertEqual(len(data), 16)
301        self.assertEqual(is_cryptographic, v == 1)
302        if v:
303            data = ssl.RAND_bytes(16)
304            self.assertEqual(len(data), 16)
305        else:
306            self.assertRaises(ssl.SSLError, ssl.RAND_bytes, 16)
307
308        # negative num is invalid
309        self.assertRaises(ValueError, ssl.RAND_bytes, -5)
310        self.assertRaises(ValueError, ssl.RAND_pseudo_bytes, -5)
311
312        if hasattr(ssl, 'RAND_egd'):
313            self.assertRaises(TypeError, ssl.RAND_egd, 1)
314            self.assertRaises(TypeError, ssl.RAND_egd, 'foo', 1)
315        ssl.RAND_add("this is a random string", 75.0)
316        ssl.RAND_add(b"this is a random bytes object", 75.0)
317        ssl.RAND_add(bytearray(b"this is a random bytearray object"), 75.0)
318
    @unittest.skipUnless(os.name == 'posix', 'requires posix')
    def test_random_fork(self):
        # After fork(), the child and parent must not produce the same
        # pseudo-random bytes.
        status = ssl.RAND_status()
        if not status:
            self.fail("OpenSSL's PRNG has insufficient randomness")

        # The child sends its 16 random bytes back through this pipe.
        rfd, wfd = os.pipe()
        pid = os.fork()
        if pid == 0:
            # Child: never return into the test runner.  Success or failure is
            # reported only through the exit code (0 or 1) via os._exit().
            try:
                os.close(rfd)
                child_random = ssl.RAND_pseudo_bytes(16)[0]
                self.assertEqual(len(child_random), 16)
                os.write(wfd, child_random)
                os.close(wfd)
            except BaseException:
                os._exit(1)
            else:
                os._exit(0)
        else:
            # Parent: require a clean child exit, then compare the bytes.
            os.close(wfd)
            self.addCleanup(os.close, rfd)
            _, status = os.waitpid(pid, 0)
            self.assertEqual(status, 0)

            child_random = os.read(rfd, 16)
            self.assertEqual(len(child_random), 16)
            parent_random = ssl.RAND_pseudo_bytes(16)[0]
            self.assertEqual(len(parent_random), 16)

            self.assertNotEqual(child_random, parent_random)
350
    # Show complete diffs on assertion failure; the test_parse_cert* tests
    # compare large nested certificate dicts.
    maxDiff = None
352
353    def test_parse_cert(self):
354        # note that this uses an 'unofficial' function in _ssl.c,
355        # provided solely for this test, to exercise the certificate
356        # parsing code
357        self.assertEqual(
358            ssl._ssl._test_decode_cert(CERTFILE),
359            CERTFILE_INFO
360        )
361        self.assertEqual(
362            ssl._ssl._test_decode_cert(SIGNED_CERTFILE),
363            SIGNED_CERTFILE_INFO
364        )
365
366        # Issue #13034: the subjectAltName in some certificates
367        # (notably projects.developer.nokia.com:443) wasn't parsed
368        p = ssl._ssl._test_decode_cert(NOKIACERT)
369        if support.verbose:
370            sys.stdout.write("\n" + pprint.pformat(p) + "\n")
371        self.assertEqual(p['subjectAltName'],
372                         (('DNS', 'projects.developer.nokia.com'),
373                          ('DNS', 'projects.forum.nokia.com'))
374                        )
375        # extra OCSP and AIA fields
376        self.assertEqual(p['OCSP'], ('http://ocsp.verisign.com',))
377        self.assertEqual(p['caIssuers'],
378                         ('http://SVRIntl-G3-aia.verisign.com/SVRIntlG3.cer',))
379        self.assertEqual(p['crlDistributionPoints'],
380                         ('http://SVRIntl-G3-crl.verisign.com/SVRIntlG3.crl',))
381
382    def test_parse_cert_CVE_2019_5010(self):
383        p = ssl._ssl._test_decode_cert(TALOS_INVALID_CRLDP)
384        if support.verbose:
385            sys.stdout.write("\n" + pprint.pformat(p) + "\n")
386        self.assertEqual(
387            p,
388            {
389                'issuer': (
390                    (('countryName', 'UK'),), (('commonName', 'cody-ca'),)),
391                'notAfter': 'Jun 14 18:00:58 2028 GMT',
392                'notBefore': 'Jun 18 18:00:58 2018 GMT',
393                'serialNumber': '02',
394                'subject': ((('countryName', 'UK'),),
395                            (('commonName',
396                              'codenomicon-vm-2.test.lal.cisco.com'),)),
397                'subjectAltName': (
398                    ('DNS', 'codenomicon-vm-2.test.lal.cisco.com'),),
399                'version': 3
400            }
401        )
402
403    def test_parse_cert_CVE_2013_4238(self):
404        p = ssl._ssl._test_decode_cert(NULLBYTECERT)
405        if support.verbose:
406            sys.stdout.write("\n" + pprint.pformat(p) + "\n")
407        subject = ((('countryName', 'US'),),
408                   (('stateOrProvinceName', 'Oregon'),),
409                   (('localityName', 'Beaverton'),),
410                   (('organizationName', 'Python Software Foundation'),),
411                   (('organizationalUnitName', 'Python Core Development'),),
412                   (('commonName', 'null.python.org\x00example.org'),),
413                   (('emailAddress', 'python-dev@python.org'),))
414        self.assertEqual(p['subject'], subject)
415        self.assertEqual(p['issuer'], subject)
416        if ssl._OPENSSL_API_VERSION >= (0, 9, 8):
417            san = (('DNS', 'altnull.python.org\x00example.com'),
418                   ('email', 'null@python.org\x00user@example.org'),
419                   ('URI', 'http://null.python.org\x00http://example.org'),
420                   ('IP Address', '192.0.2.1'),
421                   ('IP Address', '2001:DB8:0:0:0:0:0:1\n'))
422        else:
423            # OpenSSL 0.9.7 doesn't support IPv6 addresses in subjectAltName
424            san = (('DNS', 'altnull.python.org\x00example.com'),
425                   ('email', 'null@python.org\x00user@example.org'),
426                   ('URI', 'http://null.python.org\x00http://example.org'),
427                   ('IP Address', '192.0.2.1'),
428                   ('IP Address', '<invalid>'))
429
430        self.assertEqual(p['subjectAltName'], san)
431
432    def test_parse_all_sans(self):
433        p = ssl._ssl._test_decode_cert(ALLSANFILE)
434        self.assertEqual(p['subjectAltName'],
435            (
436                ('DNS', 'allsans'),
437                ('othername', '<unsupported>'),
438                ('othername', '<unsupported>'),
439                ('email', 'user@example.org'),
440                ('DNS', 'www.example.org'),
441                ('DirName',
442                    ((('countryName', 'XY'),),
443                    (('localityName', 'Castle Anthrax'),),
444                    (('organizationName', 'Python Software Foundation'),),
445                    (('commonName', 'dirname example'),))),
446                ('URI', 'https://www.python.org/'),
447                ('IP Address', '127.0.0.1'),
448                ('IP Address', '0:0:0:0:0:0:0:1\n'),
449                ('Registered ID', '1.2.3.4.5')
450            )
451        )
452
453    def test_DER_to_PEM(self):
454        with open(CAFILE_CACERT, 'r') as f:
455            pem = f.read()
456        d1 = ssl.PEM_cert_to_DER_cert(pem)
457        p2 = ssl.DER_cert_to_PEM_cert(d1)
458        d2 = ssl.PEM_cert_to_DER_cert(p2)
459        self.assertEqual(d1, d2)
460        if not p2.startswith(ssl.PEM_HEADER + '\n'):
461            self.fail("DER-to-PEM didn't include correct header:\n%r\n" % p2)
462        if not p2.endswith('\n' + ssl.PEM_FOOTER + '\n'):
463            self.fail("DER-to-PEM didn't include correct footer:\n%r\n" % p2)
464
465    def test_openssl_version(self):
466        n = ssl.OPENSSL_VERSION_NUMBER
467        t = ssl.OPENSSL_VERSION_INFO
468        s = ssl.OPENSSL_VERSION
469        self.assertIsInstance(n, int)
470        self.assertIsInstance(t, tuple)
471        self.assertIsInstance(s, str)
472        # Some sanity checks follow
473        # >= 0.9
474        self.assertGreaterEqual(n, 0x900000)
475        # < 3.0
476        self.assertLess(n, 0x30000000)
477        major, minor, fix, patch, status = t
478        self.assertGreaterEqual(major, 0)
479        self.assertLess(major, 3)
480        self.assertGreaterEqual(minor, 0)
481        self.assertLess(minor, 256)
482        self.assertGreaterEqual(fix, 0)
483        self.assertLess(fix, 256)
484        self.assertGreaterEqual(patch, 0)
485        self.assertLessEqual(patch, 63)
486        self.assertGreaterEqual(status, 0)
487        self.assertLessEqual(status, 15)
488        # Version string as returned by {Open,Libre}SSL, the format might change
489        if IS_LIBRESSL:
490            self.assertTrue(s.startswith("LibreSSL {:d}".format(major)),
491                            (s, t, hex(n)))
492        else:
493            self.assertTrue(s.startswith("OpenSSL {:d}.{:d}.{:d}".format(major, minor, fix)),
494                            (s, t, hex(n)))
495
496    @support.cpython_only
497    def test_refcycle(self):
498        # Issue #7943: an SSL object doesn't create reference cycles with
499        # itself.
500        s = socket.socket(socket.AF_INET)
501        ss = test_wrap_socket(s)
502        wr = weakref.ref(ss)
503        with support.check_warnings(("", ResourceWarning)):
504            del ss
505        self.assertEqual(wr(), None)
506
507    def test_wrapped_unconnected(self):
508        # Methods on an unconnected SSLSocket propagate the original
509        # OSError raise by the underlying socket object.
510        s = socket.socket(socket.AF_INET)
511        with test_wrap_socket(s) as ss:
512            self.assertRaises(OSError, ss.recv, 1)
513            self.assertRaises(OSError, ss.recv_into, bytearray(b'x'))
514            self.assertRaises(OSError, ss.recvfrom, 1)
515            self.assertRaises(OSError, ss.recvfrom_into, bytearray(b'x'), 1)
516            self.assertRaises(OSError, ss.send, b'x')
517            self.assertRaises(OSError, ss.sendto, b'x', ('0.0.0.0', 0))
518            self.assertRaises(NotImplementedError, ss.dup)
519            self.assertRaises(NotImplementedError, ss.sendmsg,
520                              [b'x'], (), 0, ('0.0.0.0', 0))
521            self.assertRaises(NotImplementedError, ss.recvmsg, 100)
522            self.assertRaises(NotImplementedError, ss.recvmsg_into,
523                              [bytearray(100)])
524
525    def test_timeout(self):
526        # Issue #8524: when creating an SSL socket, the timeout of the
527        # original socket should be retained.
528        for timeout in (None, 0.0, 5.0):
529            s = socket.socket(socket.AF_INET)
530            s.settimeout(timeout)
531            with test_wrap_socket(s) as ss:
532                self.assertEqual(timeout, ss.gettimeout())
533
534    def test_errors_sslwrap(self):
535        sock = socket.socket()
536        self.assertRaisesRegex(ValueError,
537                        "certfile must be specified",
538                        ssl.wrap_socket, sock, keyfile=CERTFILE)
539        self.assertRaisesRegex(ValueError,
540                        "certfile must be specified for server-side operations",
541                        ssl.wrap_socket, sock, server_side=True)
542        self.assertRaisesRegex(ValueError,
543                        "certfile must be specified for server-side operations",
544                         ssl.wrap_socket, sock, server_side=True, certfile="")
545        with ssl.wrap_socket(sock, server_side=True, certfile=CERTFILE) as s:
546            self.assertRaisesRegex(ValueError, "can't connect in server-side mode",
547                                     s.connect, (HOST, 8080))
548        with self.assertRaises(OSError) as cm:
549            with socket.socket() as sock:
550                ssl.wrap_socket(sock, certfile=NONEXISTINGCERT)
551        self.assertEqual(cm.exception.errno, errno.ENOENT)
552        with self.assertRaises(OSError) as cm:
553            with socket.socket() as sock:
554                ssl.wrap_socket(sock,
555                    certfile=CERTFILE, keyfile=NONEXISTINGCERT)
556        self.assertEqual(cm.exception.errno, errno.ENOENT)
557        with self.assertRaises(OSError) as cm:
558            with socket.socket() as sock:
559                ssl.wrap_socket(sock,
560                    certfile=NONEXISTINGCERT, keyfile=NONEXISTINGCERT)
561        self.assertEqual(cm.exception.errno, errno.ENOENT)
562
563    def bad_cert_test(self, certfile):
564        """Check that trying to use the given client certificate fails"""
565        certfile = os.path.join(os.path.dirname(__file__) or os.curdir,
566                                   certfile)
567        sock = socket.socket()
568        self.addCleanup(sock.close)
569        with self.assertRaises(ssl.SSLError):
570            test_wrap_socket(sock,
571                             certfile=certfile)
572
573    def test_empty_cert(self):
574        """Wrapping with an empty cert file"""
575        self.bad_cert_test("nullcert.pem")
576
577    def test_malformed_cert(self):
578        """Wrapping with a badly formatted certificate (syntax error)"""
579        self.bad_cert_test("badcert.pem")
580
581    def test_malformed_key(self):
582        """Wrapping with a badly formatted key (syntax error)"""
583        self.bad_cert_test("badkey.pem")
584
585    def test_match_hostname(self):
586        def ok(cert, hostname):
587            ssl.match_hostname(cert, hostname)
588        def fail(cert, hostname):
589            self.assertRaises(ssl.CertificateError,
590                              ssl.match_hostname, cert, hostname)
591
592        # -- Hostname matching --
593
594        cert = {'subject': ((('commonName', 'example.com'),),)}
595        ok(cert, 'example.com')
596        ok(cert, 'ExAmple.cOm')
597        fail(cert, 'www.example.com')
598        fail(cert, '.example.com')
599        fail(cert, 'example.org')
600        fail(cert, 'exampleXcom')
601
602        cert = {'subject': ((('commonName', '*.a.com'),),)}
603        ok(cert, 'foo.a.com')
604        fail(cert, 'bar.foo.a.com')
605        fail(cert, 'a.com')
606        fail(cert, 'Xa.com')
607        fail(cert, '.a.com')
608
609        # only match wildcards when they are the only thing
610        # in left-most segment
611        cert = {'subject': ((('commonName', 'f*.com'),),)}
612        fail(cert, 'foo.com')
613        fail(cert, 'f.com')
614        fail(cert, 'bar.com')
615        fail(cert, 'foo.a.com')
616        fail(cert, 'bar.foo.com')
617
618        # NULL bytes are bad, CVE-2013-4073
619        cert = {'subject': ((('commonName',
620                              'null.python.org\x00example.org'),),)}
621        ok(cert, 'null.python.org\x00example.org') # or raise an error?
622        fail(cert, 'example.org')
623        fail(cert, 'null.python.org')
624
625        # error cases with wildcards
626        cert = {'subject': ((('commonName', '*.*.a.com'),),)}
627        fail(cert, 'bar.foo.a.com')
628        fail(cert, 'a.com')
629        fail(cert, 'Xa.com')
630        fail(cert, '.a.com')
631
632        cert = {'subject': ((('commonName', 'a.*.com'),),)}
633        fail(cert, 'a.foo.com')
634        fail(cert, 'a..com')
635        fail(cert, 'a.com')
636
637        # wildcard doesn't match IDNA prefix 'xn--'
638        idna = 'püthon.python.org'.encode("idna").decode("ascii")
639        cert = {'subject': ((('commonName', idna),),)}
640        ok(cert, idna)
641        cert = {'subject': ((('commonName', 'x*.python.org'),),)}
642        fail(cert, idna)
643        cert = {'subject': ((('commonName', 'xn--p*.python.org'),),)}
644        fail(cert, idna)
645
646        # wildcard in first fragment and  IDNA A-labels in sequent fragments
647        # are supported.
648        idna = 'www*.pythön.org'.encode("idna").decode("ascii")
649        cert = {'subject': ((('commonName', idna),),)}
650        fail(cert, 'www.pythön.org'.encode("idna").decode("ascii"))
651        fail(cert, 'www1.pythön.org'.encode("idna").decode("ascii"))
652        fail(cert, 'ftp.pythön.org'.encode("idna").decode("ascii"))
653        fail(cert, 'pythön.org'.encode("idna").decode("ascii"))
654
655        # Slightly fake real-world example
656        cert = {'notAfter': 'Jun 26 21:41:46 2011 GMT',
657                'subject': ((('commonName', 'linuxfrz.org'),),),
658                'subjectAltName': (('DNS', 'linuxfr.org'),
659                                   ('DNS', 'linuxfr.com'),
660                                   ('othername', '<unsupported>'))}
661        ok(cert, 'linuxfr.org')
662        ok(cert, 'linuxfr.com')
663        # Not a "DNS" entry
664        fail(cert, '<unsupported>')
665        # When there is a subjectAltName, commonName isn't used
666        fail(cert, 'linuxfrz.org')
667
668        # A pristine real-world example
669        cert = {'notAfter': 'Dec 18 23:59:59 2011 GMT',
670                'subject': ((('countryName', 'US'),),
671                            (('stateOrProvinceName', 'California'),),
672                            (('localityName', 'Mountain View'),),
673                            (('organizationName', 'Google Inc'),),
674                            (('commonName', 'mail.google.com'),))}
675        ok(cert, 'mail.google.com')
676        fail(cert, 'gmail.com')
677        # Only commonName is considered
678        fail(cert, 'California')
679
680        # -- IPv4 matching --
681        cert = {'subject': ((('commonName', 'example.com'),),),
682                'subjectAltName': (('DNS', 'example.com'),
683                                   ('IP Address', '10.11.12.13'),
684                                   ('IP Address', '14.15.16.17'))}
685        ok(cert, '10.11.12.13')
686        ok(cert, '14.15.16.17')
687        fail(cert, '14.15.16.18')
688        fail(cert, 'example.net')
689
690        # -- IPv6 matching --
691        if hasattr(socket, 'AF_INET6'):
692            cert = {'subject': ((('commonName', 'example.com'),),),
693                    'subjectAltName': (
694                        ('DNS', 'example.com'),
695                        ('IP Address', '2001:0:0:0:0:0:0:CAFE\n'),
696                        ('IP Address', '2003:0:0:0:0:0:0:BABA\n'))}
697            ok(cert, '2001::cafe')
698            ok(cert, '2003::baba')
699            fail(cert, '2003::bebe')
700            fail(cert, 'example.net')
701
702        # -- Miscellaneous --
703
704        # Neither commonName nor subjectAltName
705        cert = {'notAfter': 'Dec 18 23:59:59 2011 GMT',
706                'subject': ((('countryName', 'US'),),
707                            (('stateOrProvinceName', 'California'),),
708                            (('localityName', 'Mountain View'),),
709                            (('organizationName', 'Google Inc'),))}
710        fail(cert, 'mail.google.com')
711
712        # No DNS entry in subjectAltName but a commonName
713        cert = {'notAfter': 'Dec 18 23:59:59 2099 GMT',
714                'subject': ((('countryName', 'US'),),
715                            (('stateOrProvinceName', 'California'),),
716                            (('localityName', 'Mountain View'),),
717                            (('commonName', 'mail.google.com'),)),
718                'subjectAltName': (('othername', 'blabla'), )}
719        ok(cert, 'mail.google.com')
720
721        # No DNS entry subjectAltName and no commonName
722        cert = {'notAfter': 'Dec 18 23:59:59 2099 GMT',
723                'subject': ((('countryName', 'US'),),
724                            (('stateOrProvinceName', 'California'),),
725                            (('localityName', 'Mountain View'),),
726                            (('organizationName', 'Google Inc'),)),
727                'subjectAltName': (('othername', 'blabla'),)}
728        fail(cert, 'google.com')
729
730        # Empty cert / no cert
731        self.assertRaises(ValueError, ssl.match_hostname, None, 'example.com')
732        self.assertRaises(ValueError, ssl.match_hostname, {}, 'example.com')
733
734        # Issue #17980: avoid denials of service by refusing more than one
735        # wildcard per fragment.
736        cert = {'subject': ((('commonName', 'a*b.example.com'),),)}
737        with self.assertRaisesRegex(
738                ssl.CertificateError,
739                "partial wildcards in leftmost label are not supported"):
740            ssl.match_hostname(cert, 'axxb.example.com')
741
742        cert = {'subject': ((('commonName', 'www.*.example.com'),),)}
743        with self.assertRaisesRegex(
744                ssl.CertificateError,
745                "wildcard can only be present in the leftmost label"):
746            ssl.match_hostname(cert, 'www.sub.example.com')
747
748        cert = {'subject': ((('commonName', 'a*b*.example.com'),),)}
749        with self.assertRaisesRegex(
750                ssl.CertificateError,
751                "too many wildcards"):
752            ssl.match_hostname(cert, 'axxbxxc.example.com')
753
754        cert = {'subject': ((('commonName', '*'),),)}
755        with self.assertRaisesRegex(
756                ssl.CertificateError,
757                "sole wildcard without additional labels are not support"):
758            ssl.match_hostname(cert, 'host')
759
760        cert = {'subject': ((('commonName', '*.com'),),)}
761        with self.assertRaisesRegex(
762                ssl.CertificateError,
763                r"hostname 'com' doesn't match '\*.com'"):
764            ssl.match_hostname(cert, 'com')
765
766        # extra checks for _inet_paton()
767        for invalid in ['1', '', '1.2.3', '256.0.0.1', '127.0.0.1/24']:
768            with self.assertRaises(ValueError):
769                ssl._inet_paton(invalid)
770        for ipaddr in ['127.0.0.1', '192.168.0.1']:
771            self.assertTrue(ssl._inet_paton(ipaddr))
772        if hasattr(socket, 'AF_INET6'):
773            for ipaddr in ['::1', '2001:db8:85a3::8a2e:370:7334']:
774                self.assertTrue(ssl._inet_paton(ipaddr))
775
776    def test_server_side(self):
777        # server_hostname doesn't work for server sockets
778        ctx = ssl.SSLContext(ssl.PROTOCOL_TLS_SERVER)
779        with socket.socket() as sock:
780            self.assertRaises(ValueError, ctx.wrap_socket, sock, True,
781                              server_hostname="some.hostname")
782
783    def test_unknown_channel_binding(self):
784        # should raise ValueError for unknown type
785        s = socket.socket(socket.AF_INET)
786        s.bind(('127.0.0.1', 0))
787        s.listen()
788        c = socket.socket(socket.AF_INET)
789        c.connect(s.getsockname())
790        with test_wrap_socket(c, do_handshake_on_connect=False) as ss:
791            with self.assertRaises(ValueError):
792                ss.get_channel_binding("unknown-type")
793        s.close()
794
795    @unittest.skipUnless("tls-unique" in ssl.CHANNEL_BINDING_TYPES,
796                         "'tls-unique' channel binding not available")
797    def test_tls_unique_channel_binding(self):
798        # unconnected should return None for known type
799        s = socket.socket(socket.AF_INET)
800        with test_wrap_socket(s) as ss:
801            self.assertIsNone(ss.get_channel_binding("tls-unique"))
802        # the same for server-side
803        s = socket.socket(socket.AF_INET)
804        with test_wrap_socket(s, server_side=True, certfile=CERTFILE) as ss:
805            self.assertIsNone(ss.get_channel_binding("tls-unique"))
806
807    def test_dealloc_warn(self):
808        ss = test_wrap_socket(socket.socket(socket.AF_INET))
809        r = repr(ss)
810        with self.assertWarns(ResourceWarning) as cm:
811            ss = None
812            support.gc_collect()
813        self.assertIn(r, str(cm.warning.args[0]))
814
815    def test_get_default_verify_paths(self):
816        paths = ssl.get_default_verify_paths()
817        self.assertEqual(len(paths), 6)
818        self.assertIsInstance(paths, ssl.DefaultVerifyPaths)
819
820        with support.EnvironmentVarGuard() as env:
821            env["SSL_CERT_DIR"] = CAPATH
822            env["SSL_CERT_FILE"] = CERTFILE
823            paths = ssl.get_default_verify_paths()
824            self.assertEqual(paths.cafile, CERTFILE)
825            self.assertEqual(paths.capath, CAPATH)
826
827    @unittest.skipUnless(sys.platform == "win32", "Windows specific")
828    def test_enum_certificates(self):
829        self.assertTrue(ssl.enum_certificates("CA"))
830        self.assertTrue(ssl.enum_certificates("ROOT"))
831
832        self.assertRaises(TypeError, ssl.enum_certificates)
833        self.assertRaises(WindowsError, ssl.enum_certificates, "")
834
835        trust_oids = set()
836        for storename in ("CA", "ROOT"):
837            store = ssl.enum_certificates(storename)
838            self.assertIsInstance(store, list)
839            for element in store:
840                self.assertIsInstance(element, tuple)
841                self.assertEqual(len(element), 3)
842                cert, enc, trust = element
843                self.assertIsInstance(cert, bytes)
844                self.assertIn(enc, {"x509_asn", "pkcs_7_asn"})
845                self.assertIsInstance(trust, (set, bool))
846                if isinstance(trust, set):
847                    trust_oids.update(trust)
848
849        serverAuth = "1.3.6.1.5.5.7.3.1"
850        self.assertIn(serverAuth, trust_oids)
851
852    @unittest.skipUnless(sys.platform == "win32", "Windows specific")
853    def test_enum_crls(self):
854        self.assertTrue(ssl.enum_crls("CA"))
855        self.assertRaises(TypeError, ssl.enum_crls)
856        self.assertRaises(WindowsError, ssl.enum_crls, "")
857
858        crls = ssl.enum_crls("CA")
859        self.assertIsInstance(crls, list)
860        for element in crls:
861            self.assertIsInstance(element, tuple)
862            self.assertEqual(len(element), 2)
863            self.assertIsInstance(element[0], bytes)
864            self.assertIn(element[1], {"x509_asn", "pkcs_7_asn"})
865
866
867    def test_asn1object(self):
868        expected = (129, 'serverAuth', 'TLS Web Server Authentication',
869                    '1.3.6.1.5.5.7.3.1')
870
871        val = ssl._ASN1Object('1.3.6.1.5.5.7.3.1')
872        self.assertEqual(val, expected)
873        self.assertEqual(val.nid, 129)
874        self.assertEqual(val.shortname, 'serverAuth')
875        self.assertEqual(val.longname, 'TLS Web Server Authentication')
876        self.assertEqual(val.oid, '1.3.6.1.5.5.7.3.1')
877        self.assertIsInstance(val, ssl._ASN1Object)
878        self.assertRaises(ValueError, ssl._ASN1Object, 'serverAuth')
879
880        val = ssl._ASN1Object.fromnid(129)
881        self.assertEqual(val, expected)
882        self.assertIsInstance(val, ssl._ASN1Object)
883        self.assertRaises(ValueError, ssl._ASN1Object.fromnid, -1)
884        with self.assertRaisesRegex(ValueError, "unknown NID 100000"):
885            ssl._ASN1Object.fromnid(100000)
886        for i in range(1000):
887            try:
888                obj = ssl._ASN1Object.fromnid(i)
889            except ValueError:
890                pass
891            else:
892                self.assertIsInstance(obj.nid, int)
893                self.assertIsInstance(obj.shortname, str)
894                self.assertIsInstance(obj.longname, str)
895                self.assertIsInstance(obj.oid, (str, type(None)))
896
897        val = ssl._ASN1Object.fromname('TLS Web Server Authentication')
898        self.assertEqual(val, expected)
899        self.assertIsInstance(val, ssl._ASN1Object)
900        self.assertEqual(ssl._ASN1Object.fromname('serverAuth'), expected)
901        self.assertEqual(ssl._ASN1Object.fromname('1.3.6.1.5.5.7.3.1'),
902                         expected)
903        with self.assertRaisesRegex(ValueError, "unknown object 'serverauth'"):
904            ssl._ASN1Object.fromname('serverauth')
905
906    def test_purpose_enum(self):
907        val = ssl._ASN1Object('1.3.6.1.5.5.7.3.1')
908        self.assertIsInstance(ssl.Purpose.SERVER_AUTH, ssl._ASN1Object)
909        self.assertEqual(ssl.Purpose.SERVER_AUTH, val)
910        self.assertEqual(ssl.Purpose.SERVER_AUTH.nid, 129)
911        self.assertEqual(ssl.Purpose.SERVER_AUTH.shortname, 'serverAuth')
912        self.assertEqual(ssl.Purpose.SERVER_AUTH.oid,
913                              '1.3.6.1.5.5.7.3.1')
914
915        val = ssl._ASN1Object('1.3.6.1.5.5.7.3.2')
916        self.assertIsInstance(ssl.Purpose.CLIENT_AUTH, ssl._ASN1Object)
917        self.assertEqual(ssl.Purpose.CLIENT_AUTH, val)
918        self.assertEqual(ssl.Purpose.CLIENT_AUTH.nid, 130)
919        self.assertEqual(ssl.Purpose.CLIENT_AUTH.shortname, 'clientAuth')
920        self.assertEqual(ssl.Purpose.CLIENT_AUTH.oid,
921                              '1.3.6.1.5.5.7.3.2')
922
923    def test_unsupported_dtls(self):
924        s = socket.socket(socket.AF_INET, socket.SOCK_DGRAM)
925        self.addCleanup(s.close)
926        with self.assertRaises(NotImplementedError) as cx:
927            test_wrap_socket(s, cert_reqs=ssl.CERT_NONE)
928        self.assertEqual(str(cx.exception), "only stream sockets are supported")
929        ctx = ssl.SSLContext(ssl.PROTOCOL_TLS_CLIENT)
930        with self.assertRaises(NotImplementedError) as cx:
931            ctx.wrap_socket(s)
932        self.assertEqual(str(cx.exception), "only stream sockets are supported")
933
934    def cert_time_ok(self, timestring, timestamp):
935        self.assertEqual(ssl.cert_time_to_seconds(timestring), timestamp)
936
937    def cert_time_fail(self, timestring):
938        with self.assertRaises(ValueError):
939            ssl.cert_time_to_seconds(timestring)
940
941    @unittest.skipUnless(utc_offset(),
942                         'local time needs to be different from UTC')
943    def test_cert_time_to_seconds_timezone(self):
944        # Issue #19940: ssl.cert_time_to_seconds() returns wrong
945        #               results if local timezone is not UTC
946        self.cert_time_ok("May  9 00:00:00 2007 GMT", 1178668800.0)
947        self.cert_time_ok("Jan  5 09:34:43 2018 GMT", 1515144883.0)
948
949    def test_cert_time_to_seconds(self):
950        timestring = "Jan  5 09:34:43 2018 GMT"
951        ts = 1515144883.0
952        self.cert_time_ok(timestring, ts)
953        # accept keyword parameter, assert its name
954        self.assertEqual(ssl.cert_time_to_seconds(cert_time=timestring), ts)
955        # accept both %e and %d (space or zero generated by strftime)
956        self.cert_time_ok("Jan 05 09:34:43 2018 GMT", ts)
957        # case-insensitive
958        self.cert_time_ok("JaN  5 09:34:43 2018 GmT", ts)
959        self.cert_time_fail("Jan  5 09:34 2018 GMT")     # no seconds
960        self.cert_time_fail("Jan  5 09:34:43 2018")      # no GMT
961        self.cert_time_fail("Jan  5 09:34:43 2018 UTC")  # not GMT timezone
962        self.cert_time_fail("Jan 35 09:34:43 2018 GMT")  # invalid day
963        self.cert_time_fail("Jon  5 09:34:43 2018 GMT")  # invalid month
964        self.cert_time_fail("Jan  5 24:00:00 2018 GMT")  # invalid hour
965        self.cert_time_fail("Jan  5 09:60:43 2018 GMT")  # invalid minute
966
967        newyear_ts = 1230768000.0
968        # leap seconds
969        self.cert_time_ok("Dec 31 23:59:60 2008 GMT", newyear_ts)
970        # same timestamp
971        self.cert_time_ok("Jan  1 00:00:00 2009 GMT", newyear_ts)
972
973        self.cert_time_ok("Jan  5 09:34:59 2018 GMT", 1515144899)
974        #  allow 60th second (even if it is not a leap second)
975        self.cert_time_ok("Jan  5 09:34:60 2018 GMT", 1515144900)
976        #  allow 2nd leap second for compatibility with time.strptime()
977        self.cert_time_ok("Jan  5 09:34:61 2018 GMT", 1515144901)
978        self.cert_time_fail("Jan  5 09:34:62 2018 GMT")  # invalid seconds
979
980        # no special treatment for the special value:
981        #   99991231235959Z (rfc 5280)
982        self.cert_time_ok("Dec 31 23:59:59 9999 GMT", 253402300799.0)
983
984    @support.run_with_locale('LC_ALL', '')
985    def test_cert_time_to_seconds_locale(self):
986        # `cert_time_to_seconds()` should be locale independent
987
988        def local_february_name():
989            return time.strftime('%b', (1, 2, 3, 4, 5, 6, 0, 0, 0))
990
991        if local_february_name().lower() == 'feb':
992            self.skipTest("locale-specific month name needs to be "
993                          "different from C locale")
994
995        # locale-independent
996        self.cert_time_ok("Feb  9 00:00:00 2007 GMT", 1170979200.0)
997        self.cert_time_fail(local_february_name() + "  9 00:00:00 2007 GMT")
998
999    def test_connect_ex_error(self):
1000        server = socket.socket(socket.AF_INET)
1001        self.addCleanup(server.close)
1002        port = support.bind_port(server)  # Reserve port but don't listen
1003        s = test_wrap_socket(socket.socket(socket.AF_INET),
1004                            cert_reqs=ssl.CERT_REQUIRED)
1005        self.addCleanup(s.close)
1006        rc = s.connect_ex((HOST, port))
1007        # Issue #19919: Windows machines or VMs hosted on Windows
1008        # machines sometimes return EWOULDBLOCK.
1009        errors = (
1010            errno.ECONNREFUSED, errno.EHOSTUNREACH, errno.ETIMEDOUT,
1011            errno.EWOULDBLOCK,
1012        )
1013        self.assertIn(rc, errors)
1014
1015
1016class ContextTests(unittest.TestCase):
1017
1018    @skip_if_broken_ubuntu_ssl
1019    def test_constructor(self):
1020        for protocol in PROTOCOLS:
1021            ssl.SSLContext(protocol)
1022        ctx = ssl.SSLContext()
1023        self.assertEqual(ctx.protocol, ssl.PROTOCOL_TLS)
1024        self.assertRaises(ValueError, ssl.SSLContext, -1)
1025        self.assertRaises(ValueError, ssl.SSLContext, 42)
1026
1027    @skip_if_broken_ubuntu_ssl
1028    def test_protocol(self):
1029        for proto in PROTOCOLS:
1030            ctx = ssl.SSLContext(proto)
1031            self.assertEqual(ctx.protocol, proto)
1032
1033    def test_ciphers(self):
1034        ctx = ssl.SSLContext(ssl.PROTOCOL_TLS_CLIENT)
1035        ctx.set_ciphers("ALL")
1036        ctx.set_ciphers("DEFAULT")
1037        with self.assertRaisesRegex(ssl.SSLError, "No cipher can be selected"):
1038            ctx.set_ciphers("^$:,;?*'dorothyx")
1039
1040    @unittest.skipUnless(PY_SSL_DEFAULT_CIPHERS == 1,
1041                         "Test applies only to Python default ciphers")
1042    def test_python_ciphers(self):
1043        ctx = ssl.SSLContext(ssl.PROTOCOL_TLS_CLIENT)
1044        ciphers = ctx.get_ciphers()
1045        for suite in ciphers:
1046            name = suite['name']
1047            self.assertNotIn("PSK", name)
1048            self.assertNotIn("SRP", name)
1049            self.assertNotIn("MD5", name)
1050            self.assertNotIn("RC4", name)
1051            self.assertNotIn("3DES", name)
1052
1053    @unittest.skipIf(ssl.OPENSSL_VERSION_INFO < (1, 0, 2, 0, 0), 'OpenSSL too old')
1054    def test_get_ciphers(self):
1055        ctx = ssl.SSLContext(ssl.PROTOCOL_TLS_CLIENT)
1056        ctx.set_ciphers('AESGCM')
1057        names = set(d['name'] for d in ctx.get_ciphers())
1058        self.assertIn('AES256-GCM-SHA384', names)
1059        self.assertIn('AES128-GCM-SHA256', names)
1060
1061    @skip_if_broken_ubuntu_ssl
1062    def test_options(self):
1063        ctx = ssl.SSLContext(ssl.PROTOCOL_TLS_CLIENT)
1064        # OP_ALL | OP_NO_SSLv2 | OP_NO_SSLv3 is the default value
1065        default = (ssl.OP_ALL | ssl.OP_NO_SSLv2 | ssl.OP_NO_SSLv3)
1066        # SSLContext also enables these by default
1067        default |= (OP_NO_COMPRESSION | OP_CIPHER_SERVER_PREFERENCE |
1068                    OP_SINGLE_DH_USE | OP_SINGLE_ECDH_USE |
1069                    OP_ENABLE_MIDDLEBOX_COMPAT)
1070        self.assertEqual(default, ctx.options)
1071        ctx.options |= ssl.OP_NO_TLSv1
1072        self.assertEqual(default | ssl.OP_NO_TLSv1, ctx.options)
1073        if can_clear_options():
1074            ctx.options = (ctx.options & ~ssl.OP_NO_TLSv1)
1075            self.assertEqual(default, ctx.options)
1076            ctx.options = 0
1077            # Ubuntu has OP_NO_SSLv3 forced on by default
1078            self.assertEqual(0, ctx.options & ~ssl.OP_NO_SSLv3)
1079        else:
1080            with self.assertRaises(ValueError):
1081                ctx.options = 0
1082
1083    def test_verify_mode_protocol(self):
1084        ctx = ssl.SSLContext(ssl.PROTOCOL_TLS)
1085        # Default value
1086        self.assertEqual(ctx.verify_mode, ssl.CERT_NONE)
1087        ctx.verify_mode = ssl.CERT_OPTIONAL
1088        self.assertEqual(ctx.verify_mode, ssl.CERT_OPTIONAL)
1089        ctx.verify_mode = ssl.CERT_REQUIRED
1090        self.assertEqual(ctx.verify_mode, ssl.CERT_REQUIRED)
1091        ctx.verify_mode = ssl.CERT_NONE
1092        self.assertEqual(ctx.verify_mode, ssl.CERT_NONE)
1093        with self.assertRaises(TypeError):
1094            ctx.verify_mode = None
1095        with self.assertRaises(ValueError):
1096            ctx.verify_mode = 42
1097
1098        ctx = ssl.SSLContext(ssl.PROTOCOL_TLS_SERVER)
1099        self.assertEqual(ctx.verify_mode, ssl.CERT_NONE)
1100        self.assertFalse(ctx.check_hostname)
1101
1102        ctx = ssl.SSLContext(ssl.PROTOCOL_TLS_CLIENT)
1103        self.assertEqual(ctx.verify_mode, ssl.CERT_REQUIRED)
1104        self.assertTrue(ctx.check_hostname)
1105
1106    def test_hostname_checks_common_name(self):
1107        ctx = ssl.SSLContext(ssl.PROTOCOL_TLS_CLIENT)
1108        self.assertTrue(ctx.hostname_checks_common_name)
1109        if ssl.HAS_NEVER_CHECK_COMMON_NAME:
1110            ctx.hostname_checks_common_name = True
1111            self.assertTrue(ctx.hostname_checks_common_name)
1112            ctx.hostname_checks_common_name = False
1113            self.assertFalse(ctx.hostname_checks_common_name)
1114            ctx.hostname_checks_common_name = True
1115            self.assertTrue(ctx.hostname_checks_common_name)
1116        else:
1117            with self.assertRaises(AttributeError):
1118                ctx.hostname_checks_common_name = True
1119
1120    @unittest.skipUnless(hasattr(ssl.SSLContext, 'minimum_version'),
1121                         "required OpenSSL 1.1.0g")
1122    def test_min_max_version(self):
1123        ctx = ssl.SSLContext(ssl.PROTOCOL_TLS_SERVER)
1124        # OpenSSL default is MINIMUM_SUPPORTED, however some vendors like
1125        # Fedora override the setting to TLS 1.0.
1126        self.assertIn(
1127            ctx.minimum_version,
1128            {ssl.TLSVersion.MINIMUM_SUPPORTED,
1129             # Fedora 29 uses TLS 1.0 by default
1130             ssl.TLSVersion.TLSv1,
1131             # RHEL 8 uses TLS 1.2 by default
1132             ssl.TLSVersion.TLSv1_2}
1133        )
1134        self.assertEqual(
1135            ctx.maximum_version, ssl.TLSVersion.MAXIMUM_SUPPORTED
1136        )
1137
1138        ctx.minimum_version = ssl.TLSVersion.TLSv1_1
1139        ctx.maximum_version = ssl.TLSVersion.TLSv1_2
1140        self.assertEqual(
1141            ctx.minimum_version, ssl.TLSVersion.TLSv1_1
1142        )
1143        self.assertEqual(
1144            ctx.maximum_version, ssl.TLSVersion.TLSv1_2
1145        )
1146
1147        ctx.minimum_version = ssl.TLSVersion.MINIMUM_SUPPORTED
1148        ctx.maximum_version = ssl.TLSVersion.TLSv1
1149        self.assertEqual(
1150            ctx.minimum_version, ssl.TLSVersion.MINIMUM_SUPPORTED
1151        )
1152        self.assertEqual(
1153            ctx.maximum_version, ssl.TLSVersion.TLSv1
1154        )
1155
1156        ctx.maximum_version = ssl.TLSVersion.MAXIMUM_SUPPORTED
1157        self.assertEqual(
1158            ctx.maximum_version, ssl.TLSVersion.MAXIMUM_SUPPORTED
1159        )
1160
1161        ctx.maximum_version = ssl.TLSVersion.MINIMUM_SUPPORTED
1162        self.assertIn(
1163            ctx.maximum_version,
1164            {ssl.TLSVersion.TLSv1, ssl.TLSVersion.SSLv3}
1165        )
1166
1167        ctx.minimum_version = ssl.TLSVersion.MAXIMUM_SUPPORTED
1168        self.assertIn(
1169            ctx.minimum_version,
1170            {ssl.TLSVersion.TLSv1_2, ssl.TLSVersion.TLSv1_3}
1171        )
1172
1173        with self.assertRaises(ValueError):
1174            ctx.minimum_version = 42
1175
1176        ctx = ssl.SSLContext(ssl.PROTOCOL_TLSv1_1)
1177
1178        self.assertEqual(
1179            ctx.minimum_version, ssl.TLSVersion.MINIMUM_SUPPORTED
1180        )
1181        self.assertEqual(
1182            ctx.maximum_version, ssl.TLSVersion.MAXIMUM_SUPPORTED
1183        )
1184        with self.assertRaises(ValueError):
1185            ctx.minimum_version = ssl.TLSVersion.MINIMUM_SUPPORTED
1186        with self.assertRaises(ValueError):
1187            ctx.maximum_version = ssl.TLSVersion.TLSv1
1188
1189
1190    @unittest.skipUnless(have_verify_flags(),
1191                         "verify_flags need OpenSSL > 0.9.8")
1192    def test_verify_flags(self):
1193        ctx = ssl.SSLContext(ssl.PROTOCOL_TLS_SERVER)
1194        # default value
1195        tf = getattr(ssl, "VERIFY_X509_TRUSTED_FIRST", 0)
1196        self.assertEqual(ctx.verify_flags, ssl.VERIFY_DEFAULT | tf)
1197        ctx.verify_flags = ssl.VERIFY_CRL_CHECK_LEAF
1198        self.assertEqual(ctx.verify_flags, ssl.VERIFY_CRL_CHECK_LEAF)
1199        ctx.verify_flags = ssl.VERIFY_CRL_CHECK_CHAIN
1200        self.assertEqual(ctx.verify_flags, ssl.VERIFY_CRL_CHECK_CHAIN)
1201        ctx.verify_flags = ssl.VERIFY_DEFAULT
1202        self.assertEqual(ctx.verify_flags, ssl.VERIFY_DEFAULT)
1203        # supports any value
1204        ctx.verify_flags = ssl.VERIFY_CRL_CHECK_LEAF | ssl.VERIFY_X509_STRICT
1205        self.assertEqual(ctx.verify_flags,
1206                         ssl.VERIFY_CRL_CHECK_LEAF | ssl.VERIFY_X509_STRICT)
1207        with self.assertRaises(TypeError):
1208            ctx.verify_flags = None
1209
    def test_load_cert_chain(self):
        """load_cert_chain() accepts valid key/cert inputs and rejects bad ones.

        Covered cases:
        - a combined key+cert file;
        - separate key and cert files, as str and as bytes paths;
        - a mismatched key and cert;
        - password-protected keys with str/bytes/bytearray passwords;
        - password callbacks, including ones that return the wrong type or
          an over-long value, or that raise.
        """
        ctx = ssl.SSLContext(ssl.PROTOCOL_TLS_SERVER)
        # Combined key and cert in a single file
        ctx.load_cert_chain(CERTFILE, keyfile=None)
        ctx.load_cert_chain(CERTFILE, keyfile=CERTFILE)
        # keyfile without a certfile is a TypeError.
        self.assertRaises(TypeError, ctx.load_cert_chain, keyfile=CERTFILE)
        # A missing file surfaces as OSError with errno ENOENT.
        with self.assertRaises(OSError) as cm:
            ctx.load_cert_chain(NONEXISTINGCERT)
        self.assertEqual(cm.exception.errno, errno.ENOENT)
        # Bad or empty PEM content raises an SSLError mentioning "PEM lib".
        with self.assertRaisesRegex(ssl.SSLError, "PEM lib"):
            ctx.load_cert_chain(BADCERT)
        with self.assertRaisesRegex(ssl.SSLError, "PEM lib"):
            ctx.load_cert_chain(EMPTYCERT)
        # Separate key and cert
        ctx = ssl.SSLContext(ssl.PROTOCOL_TLS_SERVER)
        ctx.load_cert_chain(ONLYCERT, ONLYKEY)
        ctx.load_cert_chain(certfile=ONLYCERT, keyfile=ONLYKEY)
        ctx.load_cert_chain(certfile=BYTES_ONLYCERT, keyfile=BYTES_ONLYKEY)
        # A cert without its key, a key alone, or the two swapped all fail.
        with self.assertRaisesRegex(ssl.SSLError, "PEM lib"):
            ctx.load_cert_chain(ONLYCERT)
        with self.assertRaisesRegex(ssl.SSLError, "PEM lib"):
            ctx.load_cert_chain(ONLYKEY)
        with self.assertRaisesRegex(ssl.SSLError, "PEM lib"):
            ctx.load_cert_chain(certfile=ONLYKEY, keyfile=ONLYCERT)
        # Mismatching key and cert
        ctx = ssl.SSLContext(ssl.PROTOCOL_TLS_SERVER)
        with self.assertRaisesRegex(ssl.SSLError, "key values mismatch"):
            ctx.load_cert_chain(CAFILE_CACERT, ONLYKEY)
        # Password protected key and cert
        # (the password may be str, bytes or bytearray, passed by keyword or
        # as the third positional argument).
        ctx.load_cert_chain(CERTFILE_PROTECTED, password=KEY_PASSWORD)
        ctx.load_cert_chain(CERTFILE_PROTECTED, password=KEY_PASSWORD.encode())
        ctx.load_cert_chain(CERTFILE_PROTECTED,
                            password=bytearray(KEY_PASSWORD.encode()))
        ctx.load_cert_chain(ONLYCERT, ONLYKEY_PROTECTED, KEY_PASSWORD)
        ctx.load_cert_chain(ONLYCERT, ONLYKEY_PROTECTED, KEY_PASSWORD.encode())
        ctx.load_cert_chain(ONLYCERT, ONLYKEY_PROTECTED,
                            bytearray(KEY_PASSWORD.encode()))
        # A non-string, non-callable password is rejected with TypeError.
        with self.assertRaisesRegex(TypeError, "should be a string"):
            ctx.load_cert_chain(CERTFILE_PROTECTED, password=True)
        # A wrong password fails with SSLError.
        with self.assertRaises(ssl.SSLError):
            ctx.load_cert_chain(CERTFILE_PROTECTED, password="badpass")
        with self.assertRaisesRegex(ValueError, "cannot be longer"):
            # openssl has a fixed limit on the password buffer.
            # PEM_BUFSIZE is generally set to 1kb.
            # Return a string larger than this.
            ctx.load_cert_chain(CERTFILE_PROTECTED, password=b'a' * 102400)
        # Password callback
        # Callbacks returning each accepted password type, plus misbehaving
        # variants (wrong password, oversized value, wrong type, raising).
        def getpass_unicode():
            return KEY_PASSWORD
        def getpass_bytes():
            return KEY_PASSWORD.encode()
        def getpass_bytearray():
            return bytearray(KEY_PASSWORD.encode())
        def getpass_badpass():
            return "badpass"
        def getpass_huge():
            return b'a' * (1024 * 1024)
        def getpass_bad_type():
            return 9
        def getpass_exception():
            raise Exception('getpass error')
        # Any callable is accepted: a callable instance or a bound method.
        class GetPassCallable:
            def __call__(self):
                return KEY_PASSWORD
            def getpass(self):
                return KEY_PASSWORD
        ctx.load_cert_chain(CERTFILE_PROTECTED, password=getpass_unicode)
        ctx.load_cert_chain(CERTFILE_PROTECTED, password=getpass_bytes)
        ctx.load_cert_chain(CERTFILE_PROTECTED, password=getpass_bytearray)
        ctx.load_cert_chain(CERTFILE_PROTECTED, password=GetPassCallable())
        ctx.load_cert_chain(CERTFILE_PROTECTED,
                            password=GetPassCallable().getpass)
        with self.assertRaises(ssl.SSLError):
            ctx.load_cert_chain(CERTFILE_PROTECTED, password=getpass_badpass)
        with self.assertRaisesRegex(ValueError, "cannot be longer"):
            ctx.load_cert_chain(CERTFILE_PROTECTED, password=getpass_huge)
        with self.assertRaisesRegex(TypeError, "must return a string"):
            ctx.load_cert_chain(CERTFILE_PROTECTED, password=getpass_bad_type)
        # An exception raised by the callback propagates to the caller.
        with self.assertRaisesRegex(Exception, "getpass error"):
            ctx.load_cert_chain(CERTFILE_PROTECTED, password=getpass_exception)
        # Make sure the password function isn't called if it isn't needed
        ctx.load_cert_chain(CERTFILE, password=getpass_exception)
1292
1293    def test_load_verify_locations(self):
1294        ctx = ssl.SSLContext(ssl.PROTOCOL_TLS_SERVER)
1295        ctx.load_verify_locations(CERTFILE)
1296        ctx.load_verify_locations(cafile=CERTFILE, capath=None)
1297        ctx.load_verify_locations(BYTES_CERTFILE)
1298        ctx.load_verify_locations(cafile=BYTES_CERTFILE, capath=None)
1299        self.assertRaises(TypeError, ctx.load_verify_locations)
1300        self.assertRaises(TypeError, ctx.load_verify_locations, None, None, None)
1301        with self.assertRaises(OSError) as cm:
1302            ctx.load_verify_locations(NONEXISTINGCERT)
1303        self.assertEqual(cm.exception.errno, errno.ENOENT)
1304        with self.assertRaisesRegex(ssl.SSLError, "PEM lib"):
1305            ctx.load_verify_locations(BADCERT)
1306        ctx.load_verify_locations(CERTFILE, CAPATH)
1307        ctx.load_verify_locations(CERTFILE, capath=BYTES_CAPATH)
1308
1309        # Issue #10989: crash if the second argument type is invalid
1310        self.assertRaises(TypeError, ctx.load_verify_locations, None, True)
1311
1312    def test_load_verify_cadata(self):
1313        # test cadata
1314        with open(CAFILE_CACERT) as f:
1315            cacert_pem = f.read()
1316        cacert_der = ssl.PEM_cert_to_DER_cert(cacert_pem)
1317        with open(CAFILE_NEURONIO) as f:
1318            neuronio_pem = f.read()
1319        neuronio_der = ssl.PEM_cert_to_DER_cert(neuronio_pem)
1320
1321        # test PEM
1322        ctx = ssl.SSLContext(ssl.PROTOCOL_TLS_CLIENT)
1323        self.assertEqual(ctx.cert_store_stats()["x509_ca"], 0)
1324        ctx.load_verify_locations(cadata=cacert_pem)
1325        self.assertEqual(ctx.cert_store_stats()["x509_ca"], 1)
1326        ctx.load_verify_locations(cadata=neuronio_pem)
1327        self.assertEqual(ctx.cert_store_stats()["x509_ca"], 2)
1328        # cert already in hash table
1329        ctx.load_verify_locations(cadata=neuronio_pem)
1330        self.assertEqual(ctx.cert_store_stats()["x509_ca"], 2)
1331
1332        # combined
1333        ctx = ssl.SSLContext(ssl.PROTOCOL_TLS_CLIENT)
1334        combined = "\n".join((cacert_pem, neuronio_pem))
1335        ctx.load_verify_locations(cadata=combined)
1336        self.assertEqual(ctx.cert_store_stats()["x509_ca"], 2)
1337
1338        # with junk around the certs
1339        ctx = ssl.SSLContext(ssl.PROTOCOL_TLS_CLIENT)
1340        combined = ["head", cacert_pem, "other", neuronio_pem, "again",
1341                    neuronio_pem, "tail"]
1342        ctx.load_verify_locations(cadata="\n".join(combined))
1343        self.assertEqual(ctx.cert_store_stats()["x509_ca"], 2)
1344
1345        # test DER
1346        ctx = ssl.SSLContext(ssl.PROTOCOL_TLS_CLIENT)
1347        ctx.load_verify_locations(cadata=cacert_der)
1348        ctx.load_verify_locations(cadata=neuronio_der)
1349        self.assertEqual(ctx.cert_store_stats()["x509_ca"], 2)
1350        # cert already in hash table
1351        ctx.load_verify_locations(cadata=cacert_der)
1352        self.assertEqual(ctx.cert_store_stats()["x509_ca"], 2)
1353
1354        # combined
1355        ctx = ssl.SSLContext(ssl.PROTOCOL_TLS_CLIENT)
1356        combined = b"".join((cacert_der, neuronio_der))
1357        ctx.load_verify_locations(cadata=combined)
1358        self.assertEqual(ctx.cert_store_stats()["x509_ca"], 2)
1359
1360        # error cases
1361        ctx = ssl.SSLContext(ssl.PROTOCOL_TLS_CLIENT)
1362        self.assertRaises(TypeError, ctx.load_verify_locations, cadata=object)
1363
1364        with self.assertRaisesRegex(ssl.SSLError, "no start line"):
1365            ctx.load_verify_locations(cadata="broken")
1366        with self.assertRaisesRegex(ssl.SSLError, "not enough data"):
1367            ctx.load_verify_locations(cadata=b"broken")
1368
1369
1370    def test_load_dh_params(self):
1371        ctx = ssl.SSLContext(ssl.PROTOCOL_TLS_SERVER)
1372        ctx.load_dh_params(DHFILE)
1373        if os.name != 'nt':
1374            ctx.load_dh_params(BYTES_DHFILE)
1375        self.assertRaises(TypeError, ctx.load_dh_params)
1376        self.assertRaises(TypeError, ctx.load_dh_params, None)
1377        with self.assertRaises(FileNotFoundError) as cm:
1378            ctx.load_dh_params(NONEXISTINGCERT)
1379        self.assertEqual(cm.exception.errno, errno.ENOENT)
1380        with self.assertRaises(ssl.SSLError) as cm:
1381            ctx.load_dh_params(CERTFILE)
1382
1383    @skip_if_broken_ubuntu_ssl
1384    def test_session_stats(self):
1385        for proto in PROTOCOLS:
1386            ctx = ssl.SSLContext(proto)
1387            self.assertEqual(ctx.session_stats(), {
1388                'number': 0,
1389                'connect': 0,
1390                'connect_good': 0,
1391                'connect_renegotiate': 0,
1392                'accept': 0,
1393                'accept_good': 0,
1394                'accept_renegotiate': 0,
1395                'hits': 0,
1396                'misses': 0,
1397                'timeouts': 0,
1398                'cache_full': 0,
1399            })
1400
1401    def test_set_default_verify_paths(self):
1402        # There's not much we can do to test that it acts as expected,
1403        # so just check it doesn't crash or raise an exception.
1404        ctx = ssl.SSLContext(ssl.PROTOCOL_TLS_CLIENT)
1405        ctx.set_default_verify_paths()
1406
1407    @unittest.skipUnless(ssl.HAS_ECDH, "ECDH disabled on this OpenSSL build")
1408    def test_set_ecdh_curve(self):
1409        ctx = ssl.SSLContext(ssl.PROTOCOL_TLS_SERVER)
1410        ctx.set_ecdh_curve("prime256v1")
1411        ctx.set_ecdh_curve(b"prime256v1")
1412        self.assertRaises(TypeError, ctx.set_ecdh_curve)
1413        self.assertRaises(TypeError, ctx.set_ecdh_curve, None)
1414        self.assertRaises(ValueError, ctx.set_ecdh_curve, "foo")
1415        self.assertRaises(ValueError, ctx.set_ecdh_curve, b"foo")
1416
1417    @needs_sni
1418    def test_sni_callback(self):
1419        ctx = ssl.SSLContext(ssl.PROTOCOL_TLS_SERVER)
1420
1421        # set_servername_callback expects a callable, or None
1422        self.assertRaises(TypeError, ctx.set_servername_callback)
1423        self.assertRaises(TypeError, ctx.set_servername_callback, 4)
1424        self.assertRaises(TypeError, ctx.set_servername_callback, "")
1425        self.assertRaises(TypeError, ctx.set_servername_callback, ctx)
1426
1427        def dummycallback(sock, servername, ctx):
1428            pass
1429        ctx.set_servername_callback(None)
1430        ctx.set_servername_callback(dummycallback)
1431
1432    @needs_sni
1433    def test_sni_callback_refcycle(self):
1434        # Reference cycles through the servername callback are detected
1435        # and cleared.
1436        ctx = ssl.SSLContext(ssl.PROTOCOL_TLS_SERVER)
1437        def dummycallback(sock, servername, ctx, cycle=ctx):
1438            pass
1439        ctx.set_servername_callback(dummycallback)
1440        wr = weakref.ref(ctx)
1441        del ctx, dummycallback
1442        gc.collect()
1443        self.assertIs(wr(), None)
1444
1445    def test_cert_store_stats(self):
1446        ctx = ssl.SSLContext(ssl.PROTOCOL_TLS_CLIENT)
1447        self.assertEqual(ctx.cert_store_stats(),
1448            {'x509_ca': 0, 'crl': 0, 'x509': 0})
1449        ctx.load_cert_chain(CERTFILE)
1450        self.assertEqual(ctx.cert_store_stats(),
1451            {'x509_ca': 0, 'crl': 0, 'x509': 0})
1452        ctx.load_verify_locations(CERTFILE)
1453        self.assertEqual(ctx.cert_store_stats(),
1454            {'x509_ca': 0, 'crl': 0, 'x509': 1})
1455        ctx.load_verify_locations(CAFILE_CACERT)
1456        self.assertEqual(ctx.cert_store_stats(),
1457            {'x509_ca': 1, 'crl': 0, 'x509': 2})
1458
1459    def test_get_ca_certs(self):
1460        ctx = ssl.SSLContext(ssl.PROTOCOL_TLS_CLIENT)
1461        self.assertEqual(ctx.get_ca_certs(), [])
1462        # CERTFILE is not flagged as X509v3 Basic Constraints: CA:TRUE
1463        ctx.load_verify_locations(CERTFILE)
1464        self.assertEqual(ctx.get_ca_certs(), [])
1465        # but CAFILE_CACERT is a CA cert
1466        ctx.load_verify_locations(CAFILE_CACERT)
1467        self.assertEqual(ctx.get_ca_certs(),
1468            [{'issuer': ((('organizationName', 'Root CA'),),
1469                         (('organizationalUnitName', 'http://www.cacert.org'),),
1470                         (('commonName', 'CA Cert Signing Authority'),),
1471                         (('emailAddress', 'support@cacert.org'),)),
1472              'notAfter': asn1time('Mar 29 12:29:49 2033 GMT'),
1473              'notBefore': asn1time('Mar 30 12:29:49 2003 GMT'),
1474              'serialNumber': '00',
1475              'crlDistributionPoints': ('https://www.cacert.org/revoke.crl',),
1476              'subject': ((('organizationName', 'Root CA'),),
1477                          (('organizationalUnitName', 'http://www.cacert.org'),),
1478                          (('commonName', 'CA Cert Signing Authority'),),
1479                          (('emailAddress', 'support@cacert.org'),)),
1480              'version': 3}])
1481
1482        with open(CAFILE_CACERT) as f:
1483            pem = f.read()
1484        der = ssl.PEM_cert_to_DER_cert(pem)
1485        self.assertEqual(ctx.get_ca_certs(True), [der])
1486
1487    def test_load_default_certs(self):
1488        ctx = ssl.SSLContext(ssl.PROTOCOL_TLS_CLIENT)
1489        ctx.load_default_certs()
1490
1491        ctx = ssl.SSLContext(ssl.PROTOCOL_TLS_CLIENT)
1492        ctx.load_default_certs(ssl.Purpose.SERVER_AUTH)
1493        ctx.load_default_certs()
1494
1495        ctx = ssl.SSLContext(ssl.PROTOCOL_TLS_CLIENT)
1496        ctx.load_default_certs(ssl.Purpose.CLIENT_AUTH)
1497
1498        ctx = ssl.SSLContext(ssl.PROTOCOL_TLS_CLIENT)
1499        self.assertRaises(TypeError, ctx.load_default_certs, None)
1500        self.assertRaises(TypeError, ctx.load_default_certs, 'SERVER_AUTH')
1501
1502    @unittest.skipIf(sys.platform == "win32", "not-Windows specific")
1503    @unittest.skipIf(IS_LIBRESSL, "LibreSSL doesn't support env vars")
1504    def test_load_default_certs_env(self):
1505        ctx = ssl.SSLContext(ssl.PROTOCOL_TLS_CLIENT)
1506        with support.EnvironmentVarGuard() as env:
1507            env["SSL_CERT_DIR"] = CAPATH
1508            env["SSL_CERT_FILE"] = CERTFILE
1509            ctx.load_default_certs()
1510            self.assertEqual(ctx.cert_store_stats(), {"crl": 0, "x509": 1, "x509_ca": 0})
1511
1512    @unittest.skipUnless(sys.platform == "win32", "Windows specific")
1513    @unittest.skipIf(hasattr(sys, "gettotalrefcount"), "Debug build does not share environment between CRTs")
1514    def test_load_default_certs_env_windows(self):
1515        ctx = ssl.SSLContext(ssl.PROTOCOL_TLS_CLIENT)
1516        ctx.load_default_certs()
1517        stats = ctx.cert_store_stats()
1518
1519        ctx = ssl.SSLContext(ssl.PROTOCOL_TLS_CLIENT)
1520        with support.EnvironmentVarGuard() as env:
1521            env["SSL_CERT_DIR"] = CAPATH
1522            env["SSL_CERT_FILE"] = CERTFILE
1523            ctx.load_default_certs()
1524            stats["x509"] += 1
1525            self.assertEqual(ctx.cert_store_stats(), stats)
1526
1527    def _assert_context_options(self, ctx):
1528        self.assertEqual(ctx.options & ssl.OP_NO_SSLv2, ssl.OP_NO_SSLv2)
1529        if OP_NO_COMPRESSION != 0:
1530            self.assertEqual(ctx.options & OP_NO_COMPRESSION,
1531                             OP_NO_COMPRESSION)
1532        if OP_SINGLE_DH_USE != 0:
1533            self.assertEqual(ctx.options & OP_SINGLE_DH_USE,
1534                             OP_SINGLE_DH_USE)
1535        if OP_SINGLE_ECDH_USE != 0:
1536            self.assertEqual(ctx.options & OP_SINGLE_ECDH_USE,
1537                             OP_SINGLE_ECDH_USE)
1538        if OP_CIPHER_SERVER_PREFERENCE != 0:
1539            self.assertEqual(ctx.options & OP_CIPHER_SERVER_PREFERENCE,
1540                             OP_CIPHER_SERVER_PREFERENCE)
1541
1542    def test_create_default_context(self):
1543        ctx = ssl.create_default_context()
1544
1545        self.assertEqual(ctx.protocol, ssl.PROTOCOL_TLS)
1546        self.assertEqual(ctx.verify_mode, ssl.CERT_REQUIRED)
1547        self.assertTrue(ctx.check_hostname)
1548        self._assert_context_options(ctx)
1549
1550        with open(SIGNING_CA) as f:
1551            cadata = f.read()
1552        ctx = ssl.create_default_context(cafile=SIGNING_CA, capath=CAPATH,
1553                                         cadata=cadata)
1554        self.assertEqual(ctx.protocol, ssl.PROTOCOL_TLS)
1555        self.assertEqual(ctx.verify_mode, ssl.CERT_REQUIRED)
1556        self._assert_context_options(ctx)
1557
1558        ctx = ssl.create_default_context(ssl.Purpose.CLIENT_AUTH)
1559        self.assertEqual(ctx.protocol, ssl.PROTOCOL_TLS)
1560        self.assertEqual(ctx.verify_mode, ssl.CERT_NONE)
1561        self._assert_context_options(ctx)
1562
1563    def test__create_stdlib_context(self):
1564        ctx = ssl._create_stdlib_context()
1565        self.assertEqual(ctx.protocol, ssl.PROTOCOL_TLS)
1566        self.assertEqual(ctx.verify_mode, ssl.CERT_NONE)
1567        self.assertFalse(ctx.check_hostname)
1568        self._assert_context_options(ctx)
1569
1570        ctx = ssl._create_stdlib_context(ssl.PROTOCOL_TLSv1)
1571        self.assertEqual(ctx.protocol, ssl.PROTOCOL_TLSv1)
1572        self.assertEqual(ctx.verify_mode, ssl.CERT_NONE)
1573        self._assert_context_options(ctx)
1574
1575        ctx = ssl._create_stdlib_context(ssl.PROTOCOL_TLSv1,
1576                                         cert_reqs=ssl.CERT_REQUIRED,
1577                                         check_hostname=True)
1578        self.assertEqual(ctx.protocol, ssl.PROTOCOL_TLSv1)
1579        self.assertEqual(ctx.verify_mode, ssl.CERT_REQUIRED)
1580        self.assertTrue(ctx.check_hostname)
1581        self._assert_context_options(ctx)
1582
1583        ctx = ssl._create_stdlib_context(purpose=ssl.Purpose.CLIENT_AUTH)
1584        self.assertEqual(ctx.protocol, ssl.PROTOCOL_TLS)
1585        self.assertEqual(ctx.verify_mode, ssl.CERT_NONE)
1586        self._assert_context_options(ctx)
1587
1588    def test_check_hostname(self):
1589        ctx = ssl.SSLContext(ssl.PROTOCOL_TLS)
1590        self.assertFalse(ctx.check_hostname)
1591        self.assertEqual(ctx.verify_mode, ssl.CERT_NONE)
1592
1593        # Auto set CERT_REQUIRED
1594        ctx.check_hostname = True
1595        self.assertTrue(ctx.check_hostname)
1596        self.assertEqual(ctx.verify_mode, ssl.CERT_REQUIRED)
1597        ctx.check_hostname = False
1598        ctx.verify_mode = ssl.CERT_REQUIRED
1599        self.assertFalse(ctx.check_hostname)
1600        self.assertEqual(ctx.verify_mode, ssl.CERT_REQUIRED)
1601
1602        # Changing verify_mode does not affect check_hostname
1603        ctx.check_hostname = False
1604        ctx.verify_mode = ssl.CERT_NONE
1605        ctx.check_hostname = False
1606        self.assertFalse(ctx.check_hostname)
1607        self.assertEqual(ctx.verify_mode, ssl.CERT_NONE)
1608        # Auto set
1609        ctx.check_hostname = True
1610        self.assertTrue(ctx.check_hostname)
1611        self.assertEqual(ctx.verify_mode, ssl.CERT_REQUIRED)
1612
1613        ctx.check_hostname = False
1614        ctx.verify_mode = ssl.CERT_OPTIONAL
1615        ctx.check_hostname = False
1616        self.assertFalse(ctx.check_hostname)
1617        self.assertEqual(ctx.verify_mode, ssl.CERT_OPTIONAL)
1618        # keep CERT_OPTIONAL
1619        ctx.check_hostname = True
1620        self.assertTrue(ctx.check_hostname)
1621        self.assertEqual(ctx.verify_mode, ssl.CERT_OPTIONAL)
1622
1623        # Cannot set CERT_NONE with check_hostname enabled
1624        with self.assertRaises(ValueError):
1625            ctx.verify_mode = ssl.CERT_NONE
1626        ctx.check_hostname = False
1627        self.assertFalse(ctx.check_hostname)
1628        ctx.verify_mode = ssl.CERT_NONE
1629        self.assertEqual(ctx.verify_mode, ssl.CERT_NONE)
1630
1631    def test_context_client_server(self):
1632        # PROTOCOL_TLS_CLIENT has sane defaults
1633        ctx = ssl.SSLContext(ssl.PROTOCOL_TLS_CLIENT)
1634        self.assertTrue(ctx.check_hostname)
1635        self.assertEqual(ctx.verify_mode, ssl.CERT_REQUIRED)
1636
1637        # PROTOCOL_TLS_SERVER has different but also sane defaults
1638        ctx = ssl.SSLContext(ssl.PROTOCOL_TLS_SERVER)
1639        self.assertFalse(ctx.check_hostname)
1640        self.assertEqual(ctx.verify_mode, ssl.CERT_NONE)
1641
1642    def test_context_custom_class(self):
1643        class MySSLSocket(ssl.SSLSocket):
1644            pass
1645
1646        class MySSLObject(ssl.SSLObject):
1647            pass
1648
1649        ctx = ssl.SSLContext(ssl.PROTOCOL_TLS_SERVER)
1650        ctx.sslsocket_class = MySSLSocket
1651        ctx.sslobject_class = MySSLObject
1652
1653        with ctx.wrap_socket(socket.socket(), server_side=True) as sock:
1654            self.assertIsInstance(sock, MySSLSocket)
1655        obj = ctx.wrap_bio(ssl.MemoryBIO(), ssl.MemoryBIO())
1656        self.assertIsInstance(obj, MySSLObject)
1657
1658
class SSLErrorTests(unittest.TestCase):
    """Tests for ssl.SSLError, its subclasses and their attributes."""

    def test_str(self):
        # The str() of a SSLError doesn't include the errno
        e = ssl.SSLError(1, "foo")
        self.assertEqual(str(e), "foo")
        self.assertEqual(e.errno, 1)
        # Same for a subclass
        e = ssl.SSLZeroReturnError(1, "foo")
        self.assertEqual(str(e), "foo")
        self.assertEqual(e.errno, 1)

    def test_lib_reason(self):
        # Test the library and reason attributes
        ctx = ssl.SSLContext(ssl.PROTOCOL_TLS_CLIENT)
        with self.assertRaises(ssl.SSLError) as cm:
            # CERTFILE (keycert.pem) is not a DH-parameters file, so the PEM
            # layer rejects it.
            ctx.load_dh_params(CERTFILE)
        self.assertEqual(cm.exception.library, 'PEM')
        self.assertEqual(cm.exception.reason, 'NO_START_LINE')
        msg = str(cm.exception)
        self.assertTrue(msg.startswith("[PEM: NO_START_LINE] no start line"), msg)

    def test_subclass(self):
        # Check that the appropriate SSLError subclass is raised
        # (this only tests one of them)
        ctx = ssl.SSLContext(ssl.PROTOCOL_TLS_CLIENT)
        ctx.check_hostname = False
        ctx.verify_mode = ssl.CERT_NONE
        with socket.socket() as listener:
            listener.bind(("127.0.0.1", 0))
            listener.listen()
            raw_client = socket.socket()
            raw_client.connect(listener.getsockname())
            raw_client.setblocking(False)
            # The listener never accepts or replies, so the non-blocking
            # handshake must stop with SSLWantReadError.  Distinct names keep
            # the listening socket from being rebound inside its own `with`.
            with ctx.wrap_socket(raw_client, False,
                                 do_handshake_on_connect=False) as tls_client:
                with self.assertRaises(ssl.SSLWantReadError) as cm:
                    tls_client.do_handshake()
                msg = str(cm.exception)
                self.assertTrue(msg.startswith("The operation did not complete (read)"), msg)
                # For compatibility
                self.assertEqual(cm.exception.errno, ssl.SSL_ERROR_WANT_READ)

    def test_bad_server_hostname(self):
        # Empty names and names with a leading dot are ValueErrors; an
        # embedded NUL character is a TypeError.
        ctx = ssl.create_default_context()
        with self.assertRaises(ValueError):
            ctx.wrap_bio(ssl.MemoryBIO(), ssl.MemoryBIO(),
                         server_hostname="")
        with self.assertRaises(ValueError):
            ctx.wrap_bio(ssl.MemoryBIO(), ssl.MemoryBIO(),
                         server_hostname=".example.org")
        with self.assertRaises(TypeError):
            ctx.wrap_bio(ssl.MemoryBIO(), ssl.MemoryBIO(),
                         server_hostname="example.org\x00evil.com")
1713
1714
class MemoryBIOTests(unittest.TestCase):
    """ssl.MemoryBIO used as a plain in-memory byte FIFO."""

    def test_read_write(self):
        bio = ssl.MemoryBIO()
        # A single write is returned whole; afterwards the buffer is empty.
        bio.write(b'foo')
        self.assertEqual(bio.read(), b'foo')
        self.assertEqual(bio.read(), b'')
        # Consecutive writes are concatenated.
        for chunk in (b'foo', b'bar'):
            bio.write(chunk)
        self.assertEqual(bio.read(), b'foobar')
        self.assertEqual(bio.read(), b'')
        # A bounded read returns at most the requested number of bytes.
        bio.write(b'baz')
        for size, expected in ((2, b'ba'), (1, b'z'), (1, b'')):
            self.assertEqual(bio.read(size), expected)

    def test_eof(self):
        bio = ssl.MemoryBIO()
        self.assertFalse(bio.eof)
        # Reading an empty buffer does not by itself signal EOF.
        self.assertEqual(bio.read(), b'')
        self.assertFalse(bio.eof)
        bio.write(b'foo')
        self.assertFalse(bio.eof)
        # After write_eof(), eof becomes true only once the data is drained.
        bio.write_eof()
        self.assertFalse(bio.eof)
        self.assertEqual(bio.read(2), b'fo')
        self.assertFalse(bio.eof)
        self.assertEqual(bio.read(1), b'o')
        self.assertTrue(bio.eof)
        self.assertEqual(bio.read(), b'')
        self.assertTrue(bio.eof)

    def test_pending(self):
        bio = ssl.MemoryBIO()
        self.assertEqual(bio.pending, 0)
        bio.write(b'foo')
        self.assertEqual(bio.pending, 3)
        # Every one-byte read shrinks the pending count by one...
        for remaining in (2, 1, 0):
            bio.read(1)
            self.assertEqual(bio.pending, remaining)
        # ...and every one-byte write grows it by one.
        for total in (1, 2, 3):
            bio.write(b'x')
            self.assertEqual(bio.pending, total)
        bio.read()
        self.assertEqual(bio.pending, 0)

    def test_buffer_types(self):
        # write() accepts bytes, bytearray and memoryview alike.
        bio = ssl.MemoryBIO()
        for payload in (b'foo', bytearray(b'bar'), memoryview(b'baz')):
            bio.write(payload)
            self.assertEqual(bio.read(), bytes(payload))

    def test_error_types(self):
        # Objects that are not bytes-like are rejected with TypeError.
        bio = ssl.MemoryBIO()
        for bad_value in ('foo', None, True, 1):
            self.assertRaises(TypeError, bio.write, bad_value)
1776
1777
class SSLObjectTests(unittest.TestCase):
    """Tests for ssl.SSLObject instances driven purely through memory BIOs."""

    def test_private_init(self):
        # SSLObject has no public constructor; it must come from wrap_bio().
        mem_bio = ssl.MemoryBIO()
        with self.assertRaisesRegex(TypeError, "public constructor"):
            ssl.SSLObject(mem_bio, mem_bio)

    def test_unwrap(self):
        client_ctx, server_ctx, hostname = testing_context()
        client_in, client_out = ssl.MemoryBIO(), ssl.MemoryBIO()
        server_in, server_out = ssl.MemoryBIO(), ssl.MemoryBIO()
        client = client_ctx.wrap_bio(client_in, client_out,
                                     server_hostname=hostname)
        server = server_ctx.wrap_bio(server_in, server_out, server_side=True)

        # Pump handshake records between the two endpoints for a few rounds.
        # Either side may ask for more input while the exchange is ongoing.
        pipes = ((client, client_out, server_in),
                 (server, server_out, client_in))
        for _ in range(5):
            for endpoint, source, sink in pipes:
                try:
                    endpoint.do_handshake()
                except ssl.SSLWantReadError:
                    pass
                if source.pending:
                    sink.write(source.read())
        # Both handshakes are done now, so neither call may raise.
        client.do_handshake()
        server.do_handshake()

        # A unilateral unwrap sends close-notify and then waits for the
        # peer's close-notify, hence WantReadError.
        with self.assertRaises(ssl.SSLWantReadError):
            client.unwrap()

        # The server reads the client's close-notify and completes at once.
        server_in.write(client_out.read())
        server.unwrap()

        # Once the client has the server's close-notify, it completes too.
        client_in.write(server_out.read())
        client.unwrap()
1825
class SimpleBackgroundTests(unittest.TestCase):
    """Tests that connect to a simple server running in the background"""

    def setUp(self):
        # Run a ThreadedEchoServer presenting SIGNED_CERTFILE for the
        # duration of each test; its context-manager exit is registered as
        # cleanup so it runs even when the test fails.
        server = ThreadedEchoServer(SIGNED_CERTFILE)
        self.server_addr = (HOST, server.port)
        server.__enter__()
        self.addCleanup(server.__exit__, None, None, None)

    def test_connect(self):
        # Without verification getpeercert() returns an empty dict.
        with test_wrap_socket(socket.socket(socket.AF_INET),
                            cert_reqs=ssl.CERT_NONE) as s:
            s.connect(self.server_addr)
            self.assertEqual({}, s.getpeercert())
            self.assertFalse(s.server_side)

        # this should succeed because we specify the root cert
        with test_wrap_socket(socket.socket(socket.AF_INET),
                            cert_reqs=ssl.CERT_REQUIRED,
                            ca_certs=SIGNING_CA) as s:
            s.connect(self.server_addr)
            self.assertTrue(s.getpeercert())
            self.assertFalse(s.server_side)

    def test_connect_fail(self):
        # This should fail because we have no verification certs. Connection
        # failure crashes ThreadedEchoServer, so run this in an independent
        # test method.
        s = test_wrap_socket(socket.socket(socket.AF_INET),
                            cert_reqs=ssl.CERT_REQUIRED)
        self.addCleanup(s.close)
        self.assertRaisesRegex(ssl.SSLError, "certificate verify failed",
                               s.connect, self.server_addr)

    def test_connect_ex(self):
        # Issue #11326: check connect_ex() implementation
        s = test_wrap_socket(socket.socket(socket.AF_INET),
                            cert_reqs=ssl.CERT_REQUIRED,
                            ca_certs=SIGNING_CA)
        self.addCleanup(s.close)
        self.assertEqual(0, s.connect_ex(self.server_addr))
        self.assertTrue(s.getpeercert())

    def test_non_blocking_connect_ex(self):
        # Issue #11326: non-blocking connect_ex() should allow handshake
        # to proceed after the socket gets ready.
        s = test_wrap_socket(socket.socket(socket.AF_INET),
                            cert_reqs=ssl.CERT_REQUIRED,
                            ca_certs=SIGNING_CA,
                            do_handshake_on_connect=False)
        self.addCleanup(s.close)
        s.setblocking(False)
        rc = s.connect_ex(self.server_addr)
        # EWOULDBLOCK under Windows, EINPROGRESS elsewhere
        self.assertIn(rc, (0, errno.EINPROGRESS, errno.EWOULDBLOCK))
        # Wait for connect to finish
        select.select([], [s], [], 5.0)
        # Non-blocking handshake
        while True:
            try:
                s.do_handshake()
                break
            except ssl.SSLWantReadError:
                select.select([s], [], [], 5.0)
            except ssl.SSLWantWriteError:
                select.select([], [s], [], 5.0)
        # SSL established
        self.assertTrue(s.getpeercert())

    def test_connect_with_context(self):
        # Same as test_connect, but with a separately created context
        ctx = ssl.SSLContext(ssl.PROTOCOL_TLS)
        with ctx.wrap_socket(socket.socket(socket.AF_INET)) as s:
            s.connect(self.server_addr)
            self.assertEqual({}, s.getpeercert())
        # Same with a server hostname
        with ctx.wrap_socket(socket.socket(socket.AF_INET),
                            server_hostname="dummy") as s:
            s.connect(self.server_addr)
        ctx.verify_mode = ssl.CERT_REQUIRED
        # This should succeed because we specify the root cert
        ctx.load_verify_locations(SIGNING_CA)
        with ctx.wrap_socket(socket.socket(socket.AF_INET)) as s:
            s.connect(self.server_addr)
            cert = s.getpeercert()
            self.assertTrue(cert)

    def test_connect_with_context_fail(self):
        # This should fail because we have no verification certs. Connection
        # failure crashes ThreadedEchoServer, so run this in an independent
        # test method.
        ctx = ssl.SSLContext(ssl.PROTOCOL_TLS)
        ctx.verify_mode = ssl.CERT_REQUIRED
        s = ctx.wrap_socket(socket.socket(socket.AF_INET))
        self.addCleanup(s.close)
        self.assertRaisesRegex(ssl.SSLError, "certificate verify failed",
                                s.connect, self.server_addr)

    def test_connect_capath(self):
        # Verify server certificates using the `capath` argument
        # NOTE: the subject hashing algorithm has been changed between
        # OpenSSL 0.9.8n and 1.0.0, as a result the capath directory must
        # contain both versions of each certificate (same content, different
        # filename) for this test to be portable across OpenSSL releases.
        ctx = ssl.SSLContext(ssl.PROTOCOL_TLS)
        ctx.verify_mode = ssl.CERT_REQUIRED
        ctx.load_verify_locations(capath=CAPATH)
        with ctx.wrap_socket(socket.socket(socket.AF_INET)) as s:
            s.connect(self.server_addr)
            cert = s.getpeercert()
            self.assertTrue(cert)

        # Same with a bytes `capath` argument
        ctx = ssl.SSLContext(ssl.PROTOCOL_TLS)
        ctx.verify_mode = ssl.CERT_REQUIRED
        ctx.load_verify_locations(capath=BYTES_CAPATH)
        with ctx.wrap_socket(socket.socket(socket.AF_INET)) as s:
            s.connect(self.server_addr)
            cert = s.getpeercert()
            self.assertTrue(cert)

    def test_connect_cadata(self):
        # Verify the server using in-memory CA data, first as PEM text and
        # then as DER bytes.
        with open(SIGNING_CA) as f:
            pem = f.read()
        der = ssl.PEM_cert_to_DER_cert(pem)
        ctx = ssl.SSLContext(ssl.PROTOCOL_TLS)
        ctx.verify_mode = ssl.CERT_REQUIRED
        ctx.load_verify_locations(cadata=pem)
        with ctx.wrap_socket(socket.socket(socket.AF_INET)) as s:
            s.connect(self.server_addr)
            cert = s.getpeercert()
            self.assertTrue(cert)

        # same with DER
        ctx = ssl.SSLContext(ssl.PROTOCOL_TLS)
        ctx.verify_mode = ssl.CERT_REQUIRED
        ctx.load_verify_locations(cadata=der)
        with ctx.wrap_socket(socket.socket(socket.AF_INET)) as s:
            s.connect(self.server_addr)
            cert = s.getpeercert()
            self.assertTrue(cert)

    @unittest.skipIf(os.name == "nt", "Can't use a socket as a file under Windows")
    def test_makefile_close(self):
        # Issue #5238: creating a file-like object with makefile() shouldn't
        # delay closing the underlying "real socket" (here tested with its
        # file descriptor, hence skipping the test under Windows).
        ss = test_wrap_socket(socket.socket(socket.AF_INET))
        ss.connect(self.server_addr)
        fd = ss.fileno()
        f = ss.makefile()
        f.close()
        # The fd is still open
        os.read(fd, 0)
        # Closing the SSL socket should close the fd too
        ss.close()
        gc.collect()
        with self.assertRaises(OSError) as e:
            os.read(fd, 0)
        self.assertEqual(e.exception.errno, errno.EBADF)

    def test_non_blocking_handshake(self):
        # Connect in blocking mode, then complete the TLS handshake on a
        # non-blocking socket, waiting with select() as OpenSSL requests.
        s = socket.socket(socket.AF_INET)
        s.connect(self.server_addr)
        s.setblocking(False)
        s = test_wrap_socket(s,
                            cert_reqs=ssl.CERT_NONE,
                            do_handshake_on_connect=False)
        self.addCleanup(s.close)
        count = 0
        while True:
            try:
                count += 1
                s.do_handshake()
                break
            except ssl.SSLWantReadError:
                select.select([s], [], [])
            except ssl.SSLWantWriteError:
                select.select([], [s], [])
        if support.verbose:
            sys.stdout.write("\nNeeded %d calls to do_handshake() to establish session.\n" % count)

    def test_get_server_certificate(self):
        _test_get_server_certificate(self, *self.server_addr, cert=SIGNING_CA)

    def test_get_server_certificate_fail(self):
        # Connection failure crashes ThreadedEchoServer, so run this in an
        # independent test method
        _test_get_server_certificate_fail(self, *self.server_addr)

    def test_ciphers(self):
        # Valid cipher strings work; an unparsable one must raise either when
        # wrapping or when connecting.
        with test_wrap_socket(socket.socket(socket.AF_INET),
                             cert_reqs=ssl.CERT_NONE, ciphers="ALL") as s:
            s.connect(self.server_addr)
        with test_wrap_socket(socket.socket(socket.AF_INET),
                             cert_reqs=ssl.CERT_NONE, ciphers="DEFAULT") as s:
            s.connect(self.server_addr)
        # Error checking can happen at instantiation or when connecting
        with self.assertRaisesRegex(ssl.SSLError, "No cipher can be selected"):
            with socket.socket(socket.AF_INET) as sock:
                s = test_wrap_socket(sock,
                                    cert_reqs=ssl.CERT_NONE, ciphers="^$:,;?*'dorothyx")
                s.connect(self.server_addr)

    def test_get_ca_certs_capath(self):
        # capath certs are loaded on request
        ctx = ssl.SSLContext(ssl.PROTOCOL_TLS_CLIENT)
        ctx.load_verify_locations(capath=CAPATH)
        self.assertEqual(ctx.get_ca_certs(), [])
        with ctx.wrap_socket(socket.socket(socket.AF_INET),
                             server_hostname='localhost') as s:
            s.connect(self.server_addr)
            cert = s.getpeercert()
            self.assertTrue(cert)
        self.assertEqual(len(ctx.get_ca_certs()), 1)

    @needs_sni
    def test_context_setget(self):
        # Check that the context of a connected socket can be replaced.
        ctx1 = ssl.SSLContext(ssl.PROTOCOL_TLS_CLIENT)
        ctx1.load_verify_locations(capath=CAPATH)
        ctx2 = ssl.SSLContext(ssl.PROTOCOL_TLS_CLIENT)
        ctx2.load_verify_locations(capath=CAPATH)
        s = socket.socket(socket.AF_INET)
        with ctx1.wrap_socket(s, server_hostname='localhost') as ss:
            ss.connect(self.server_addr)
            self.assertIs(ss.context, ctx1)
            self.assertIs(ss._sslobj.context, ctx1)
            ss.context = ctx2
            self.assertIs(ss.context, ctx2)
            self.assertIs(ss._sslobj.context, ctx2)

    def ssl_io_loop(self, sock, incoming, outgoing, func, *args, **kwargs):
        """Drive a memory-BIO SSL operation to completion over *sock*.

        ``func(*args)`` is called repeatedly.  After every call, any bytes
        queued in *outgoing* are sent on *sock*.  When the call raised
        SSL_ERROR_WANT_READ, more bytes are received from *sock* into
        *incoming*; if the peer closed the connection, EOF is signalled on
        *incoming* instead.  SSL errors other than WANT_READ/WANT_WRITE
        propagate unchanged.

        Returns the result of the first call that does not raise.  The only
        keyword argument read is ``timeout`` (seconds, default 10).  The test
        fails if the loop is still running after it, but the deadline is only
        checked between iterations, so a blocking ``sock.recv()`` is not
        bounded by it.
        """
        timeout = kwargs.get('timeout', 10)
        deadline = time.monotonic() + timeout
        count = 0
        while True:
            if time.monotonic() > deadline:
                self.fail("timeout")
            # Named ssl_errno (not errno) so that the module-level `errno`
            # import used elsewhere in this class is not shadowed.
            ssl_errno = None
            count += 1
            try:
                ret = func(*args)
            except ssl.SSLError as e:
                if e.errno not in (ssl.SSL_ERROR_WANT_READ,
                                   ssl.SSL_ERROR_WANT_WRITE):
                    raise
                ssl_errno = e.errno
            # Get any data from the outgoing BIO irrespective of any error, and
            # send it to the socket.
            buf = outgoing.read()
            sock.sendall(buf)
            # If there's no error, we're done. For WANT_READ, we need to get
            # data from the socket and put it in the incoming BIO.
            if ssl_errno is None:
                break
            elif ssl_errno == ssl.SSL_ERROR_WANT_READ:
                buf = sock.recv(32768)
                if buf:
                    incoming.write(buf)
                else:
                    incoming.write_eof()
        if support.verbose:
            sys.stdout.write("Needed %d calls to complete %s().\n"
                             % (count, func.__name__))
        return ret

    def test_bio_handshake(self):
        # Full client handshake and shutdown through memory BIOs, bridged to
        # the echo server's socket by ssl_io_loop().
        sock = socket.socket(socket.AF_INET)
        self.addCleanup(sock.close)
        sock.connect(self.server_addr)
        incoming = ssl.MemoryBIO()
        outgoing = ssl.MemoryBIO()
        ctx = ssl.SSLContext(ssl.PROTOCOL_TLS_CLIENT)
        self.assertTrue(ctx.check_hostname)
        self.assertEqual(ctx.verify_mode, ssl.CERT_REQUIRED)
        ctx.load_verify_locations(SIGNING_CA)
        sslobj = ctx.wrap_bio(incoming, outgoing, False,
                              SIGNED_CERTFILE_HOSTNAME)
        self.assertIs(sslobj._sslobj.owner, sslobj)
        self.assertIsNone(sslobj.cipher())
        self.assertIsNone(sslobj.version())
        self.assertIsNotNone(sslobj.shared_ciphers())
        self.assertRaises(ValueError, sslobj.getpeercert)
        if 'tls-unique' in ssl.CHANNEL_BINDING_TYPES:
            self.assertIsNone(sslobj.get_channel_binding('tls-unique'))
        self.ssl_io_loop(sock, incoming, outgoing, sslobj.do_handshake)
        self.assertTrue(sslobj.cipher())
        self.assertIsNotNone(sslobj.shared_ciphers())
        self.assertIsNotNone(sslobj.version())
        self.assertTrue(sslobj.getpeercert())
        if 'tls-unique' in ssl.CHANNEL_BINDING_TYPES:
            self.assertTrue(sslobj.get_channel_binding('tls-unique'))
        try:
            self.ssl_io_loop(sock, incoming, outgoing, sslobj.unwrap)
        except ssl.SSLSyscallError:
            # If the server shuts down the TCP connection without sending a
            # secure shutdown message, this is reported as SSL_ERROR_SYSCALL
            pass
        self.assertRaises(ssl.SSLError, sslobj.write, b'foo')

    def test_bio_read_write_data(self):
        # Send a line through a memory-BIO SSL object; the echo server
        # replies with the lowercased line.
        sock = socket.socket(socket.AF_INET)
        self.addCleanup(sock.close)
        sock.connect(self.server_addr)
        incoming = ssl.MemoryBIO()
        outgoing = ssl.MemoryBIO()
        ctx = ssl.SSLContext(ssl.PROTOCOL_TLS)
        ctx.verify_mode = ssl.CERT_NONE
        sslobj = ctx.wrap_bio(incoming, outgoing, False)
        self.ssl_io_loop(sock, incoming, outgoing, sslobj.do_handshake)
        req = b'FOO\n'
        self.ssl_io_loop(sock, incoming, outgoing, sslobj.write, req)
        buf = self.ssl_io_loop(sock, incoming, outgoing, sslobj.read, 1024)
        self.assertEqual(buf, b'foo\n')
        self.ssl_io_loop(sock, incoming, outgoing, sslobj.unwrap)
2144
2145
class NetworkedTests(unittest.TestCase):
    """Tests that talk to real remote hosts (guarded by transient_internet)."""

    def test_timeout_connect_ex(self):
        # Issue #12065: on a timeout, connect_ex() should return the original
        # errno (mimicking the behaviour of non-SSL sockets).
        with support.transient_internet(REMOTE_HOST):
            s = test_wrap_socket(socket.socket(socket.AF_INET),
                                cert_reqs=ssl.CERT_REQUIRED,
                                do_handshake_on_connect=False)
            self.addCleanup(s.close)
            # A tiny timeout makes the non-blocking connect almost certainly
            # still be in progress when connect_ex() returns.
            s.settimeout(0.0000001)
            rc = s.connect_ex((REMOTE_HOST, 443))
            if rc == 0:
                self.skipTest("REMOTE_HOST responded too quickly")
            elif rc == errno.ENETUNREACH:
                # connect_ex() reports failures through its return value
                # instead of raising, so the transient_internet() context
                # manager never sees this condition and cannot convert it
                # into a skip on its own.
                self.skipTest("Network unreachable.")
            self.assertIn(rc, (errno.EAGAIN, errno.EWOULDBLOCK))

    @unittest.skipUnless(support.IPV6_ENABLED, 'Needs IPv6')
    def test_get_server_certificate_ipv6(self):
        # Fetch a real certificate over IPv6, then check that verifying it
        # against the local test CERTFILE fails.
        with support.transient_internet('ipv6.google.com'):
            _test_get_server_certificate(self, 'ipv6.google.com', 443)
            _test_get_server_certificate_fail(self, 'ipv6.google.com', 443)
2167
2168
def _test_get_server_certificate(test, host, port, cert=None):
    """Fetch the certificate of host:port twice and fail *test* if empty.

    The first fetch passes no CA bundle; the second passes ``ca_certs=cert``.
    When running verbosely, the PEM from the second fetch is printed.
    """
    address = (host, port)
    pem = None
    for extra_kwargs in ({}, {'ca_certs': cert}):
        pem = ssl.get_server_certificate(address, **extra_kwargs)
        if not pem:
            test.fail("No server certificate on %s:%s!" % (host, port))
    if support.verbose:
        sys.stdout.write("\nVerified certificate for %s:%s is\n%s\n" % (host, port, pem))
2179
def _test_get_server_certificate_fail(test, host, port):
    """Require that fetching host:port's certificate with the local test
    CERTFILE as the CA bundle raises ssl.SSLError; otherwise fail *test*."""
    try:
        pem = ssl.get_server_certificate((host, port), ca_certs=CERTFILE)
    except ssl.SSLError as exc:
        # Expected outcome: verification against CERTFILE fails.
        if support.verbose:
            sys.stdout.write("%s\n" % exc)
        return
    test.fail("Got server certificate %s for %s:%s!" % (pem, host, port))
2189
2190
2191from test.ssl_servers import make_https_server
2192
class ThreadedEchoServer(threading.Thread):
    """Threaded test server that echoes received data back lower-cased.

    The server listens on a socket bound with support.bind_port() and
    serves accepted connections one at a time: run() starts a
    ConnectionHandler thread per connection and joins it before accepting
    the next one.  Per-connection results are collected in
    selected_npn_protocols, selected_alpn_protocols, shared_ciphers and
    conn_errors (errors are stored as strings) for tests to inspect.

    Typical use is as a context manager: __enter__ starts the thread and
    waits until the socket is listening; __exit__ stops and joins it.
    """

    class ConnectionHandler(threading.Thread):

        """A mildly complicated class, because we want it to work both
        with and without the SSL wrapper around the socket connection, so
        that we can test the STARTTLS functionality."""

        def __init__(self, server, connsock, addr):
            self.server = server
            self.running = False
            self.sock = connsock
            self.addr = addr
            self.sock.setblocking(1)
            # TLS-wrapped socket once wrap_conn() succeeds; None while the
            # connection is plaintext (before STARTTLS / after ENDTLS).
            self.sslconn = None
            threading.Thread.__init__(self)
            self.daemon = True

        def wrap_conn(self):
            """Run the server-side TLS handshake on self.sock.

            Returns True on success, after recording the negotiated NPN/ALPN
            protocols and the shared ciphers on the server.  On failure the
            error text is appended to server.conn_errors, the connection is
            closed and False is returned; SSLError and other OSErrors also
            stop the whole server.
            """
            try:
                self.sslconn = self.server.context.wrap_socket(
                    self.sock, server_side=True)
                self.server.selected_npn_protocols.append(self.sslconn.selected_npn_protocol())
                self.server.selected_alpn_protocols.append(self.sslconn.selected_alpn_protocol())
            except (ConnectionResetError, BrokenPipeError) as e:
                # We treat ConnectionResetError as though it were an
                # SSLError - OpenSSL on Ubuntu abruptly closes the
                # connection when asked to use an unsupported protocol.
                #
                # BrokenPipeError is raised in TLS 1.3 mode, when OpenSSL
                # tries to send session tickets after handshake.
                # https://github.com/openssl/openssl/issues/6342
                self.server.conn_errors.append(str(e))
                if self.server.chatty:
                    handle_error("\n server:  bad connection attempt from " + repr(self.addr) + ":\n")
                self.running = False
                self.close()
                return False
            except (ssl.SSLError, OSError) as e:
                # OSError may occur with wrong protocols, e.g. both
                # sides use PROTOCOL_TLS_SERVER.
                #
                # XXX Various errors can have happened here, for example
                # a mismatching protocol version, an invalid certificate,
                # or a low-level bug. This should be made more discriminating.
                #
                # bpo-31323: Store the exception as string to prevent
                # a reference leak: server -> conn_errors -> exception
                # -> traceback -> self (ConnectionHandler) -> server
                self.server.conn_errors.append(str(e))
                if self.server.chatty:
                    handle_error("\n server:  bad connection attempt from " + repr(self.addr) + ":\n")
                self.running = False
                self.server.stop()
                self.close()
                return False
            else:
                self.server.shared_ciphers.append(self.sslconn.shared_ciphers())
                if self.server.context.verify_mode == ssl.CERT_REQUIRED:
                    cert = self.sslconn.getpeercert()
                    if support.verbose and self.server.chatty:
                        sys.stdout.write(" client cert is " + pprint.pformat(cert) + "\n")
                    cert_binary = self.sslconn.getpeercert(True)
                    if support.verbose and self.server.chatty:
                        sys.stdout.write(" cert binary is " + str(len(cert_binary)) + " bytes\n")
                cipher = self.sslconn.cipher()
                if support.verbose and self.server.chatty:
                    sys.stdout.write(" server: connection cipher is now " + str(cipher) + "\n")
                    sys.stdout.write(" server: selected protocol is now "
                            + str(self.sslconn.selected_npn_protocol()) + "\n")
                return True

        def read(self):
            """Read from the TLS layer if active, else up to 1024 raw bytes."""
            if self.sslconn:
                return self.sslconn.read()
            else:
                return self.sock.recv(1024)

        def write(self, bytes):
            """Write to the TLS layer if active, else send on the raw socket."""
            if self.sslconn:
                return self.sslconn.write(bytes)
            else:
                return self.sock.send(bytes)

        def close(self):
            """Close the TLS socket if active, else the raw socket."""
            if self.sslconn:
                self.sslconn.close()
            else:
                self.sock.close()

        def run(self):
            """Serve one connection until EOF, b'over' or an error.

            Unless the server is a STARTTLS server, the connection is
            TLS-wrapped first.  Each received line is either a control
            command (STARTTLS, ENDTLS, CB tls-unique, PHA, HASCERT, GETCERT)
            or echoed back lower-cased.
            """
            self.running = True
            if not self.server.starttls_server:
                if not self.wrap_conn():
                    return
            while self.running:
                try:
                    msg = self.read()
                    stripped = msg.strip()
                    if not stripped:
                        # eof, so quit this handler
                        self.running = False
                        # The peer may hang up while the connection is still
                        # plaintext (STARTTLS server before STARTTLS, or after
                        # ENDTLS); there is no TLS layer to unwrap then, and
                        # calling unwrap() on None would raise AttributeError
                        # and leak the socket.
                        if self.sslconn is not None:
                            try:
                                self.sock = self.sslconn.unwrap()
                            except OSError:
                                # Many tests shut the TCP connection down
                                # without an SSL shutdown. This causes
                                # unwrap() to raise OSError with errno=0!
                                pass
                            else:
                                self.sslconn = None
                        self.close()
                    elif stripped == b'over':
                        if support.verbose and self.server.connectionchatty:
                            sys.stdout.write(" server: client closed connection\n")
                        self.close()
                        return
                    elif (self.server.starttls_server and
                          stripped == b'STARTTLS'):
                        if support.verbose and self.server.connectionchatty:
                            sys.stdout.write(" server: read STARTTLS from client, sending OK...\n")
                        self.write(b"OK\n")
                        if not self.wrap_conn():
                            return
                    elif (self.server.starttls_server and self.sslconn
                          and stripped == b'ENDTLS'):
                        if support.verbose and self.server.connectionchatty:
                            sys.stdout.write(" server: read ENDTLS from client, sending OK...\n")
                        self.write(b"OK\n")
                        self.sock = self.sslconn.unwrap()
                        self.sslconn = None
                        if support.verbose and self.server.connectionchatty:
                            sys.stdout.write(" server: connection is now unencrypted...\n")
                    elif stripped == b'CB tls-unique':
                        if support.verbose and self.server.connectionchatty:
                            sys.stdout.write(" server: read CB tls-unique from client, sending our CB data...\n")
                        data = self.sslconn.get_channel_binding("tls-unique")
                        self.write(repr(data).encode("us-ascii") + b"\n")
                    elif stripped == b'PHA':
                        if support.verbose and self.server.connectionchatty:
                            sys.stdout.write(" server: initiating post handshake auth\n")
                        try:
                            self.sslconn.verify_client_post_handshake()
                        except ssl.SSLError as e:
                            self.write(repr(e).encode("us-ascii") + b"\n")
                        else:
                            self.write(b"OK\n")
                    elif stripped == b'HASCERT':
                        if self.sslconn.getpeercert() is not None:
                            self.write(b'TRUE\n')
                        else:
                            self.write(b'FALSE\n')
                    elif stripped == b'GETCERT':
                        cert = self.sslconn.getpeercert()
                        self.write(repr(cert).encode("us-ascii") + b"\n")
                    else:
                        if (support.verbose and
                            self.server.connectionchatty):
                            ctype = (self.sslconn and "encrypted") or "unencrypted"
                            sys.stdout.write(" server: read %r (%s), sending back %r (%s)...\n"
                                             % (msg, ctype, msg.lower(), ctype))
                        self.write(msg.lower())
                except ConnectionResetError:
                    # XXX: OpenSSL 1.1.1 sometimes raises ConnectionResetError
                    # when connection is not shut down gracefully.
                    if self.server.chatty and support.verbose:
                        sys.stdout.write(
                            " Connection reset by peer: {}\n".format(
                                self.addr)
                        )
                    self.close()
                    self.running = False
                except OSError:
                    if self.server.chatty:
                        handle_error("Test server failure:\n")
                    self.close()
                    self.running = False

                    # normally, we'd just stop here, but for the test
                    # harness, we want to stop the server
                    self.server.stop()

    def __init__(self, certificate=None, ssl_version=None,
                 certreqs=None, cacerts=None,
                 chatty=True, connectionchatty=False, starttls_server=False,
                 npn_protocols=None, alpn_protocols=None,
                 ciphers=None, context=None):
        """Create the server and bind its listening socket.

        If *context* is given it is used as-is and the certificate, protocol
        and cipher arguments are ignored.  Otherwise a new SSLContext is
        built from *ssl_version* (default PROTOCOL_TLS_SERVER) and *certreqs*
        (default CERT_NONE), with the optional CA file, certificate chain,
        NPN/ALPN protocols and cipher string applied.
        """
        if context:
            self.context = context
        else:
            self.context = ssl.SSLContext(ssl_version
                                          if ssl_version is not None
                                          else ssl.PROTOCOL_TLS_SERVER)
            self.context.verify_mode = (certreqs if certreqs is not None
                                        else ssl.CERT_NONE)
            if cacerts:
                self.context.load_verify_locations(cacerts)
            if certificate:
                self.context.load_cert_chain(certificate)
            if npn_protocols:
                self.context.set_npn_protocols(npn_protocols)
            if alpn_protocols:
                self.context.set_alpn_protocols(alpn_protocols)
            if ciphers:
                self.context.set_ciphers(ciphers)
        self.chatty = chatty
        self.connectionchatty = connectionchatty
        self.starttls_server = starttls_server
        self.sock = socket.socket()
        self.port = support.bind_port(self.sock)
        self.flag = None
        self.active = False
        self.selected_npn_protocols = []
        self.selected_alpn_protocols = []
        self.shared_ciphers = []
        self.conn_errors = []
        threading.Thread.__init__(self)
        self.daemon = True

    def __enter__(self):
        # Start the thread and block until run() reports it is listening.
        self.start(threading.Event())
        self.flag.wait()
        return self

    def __exit__(self, *args):
        self.stop()
        self.join()

    def start(self, flag=None):
        """Start the thread; *flag* (an Event) is set once listening."""
        self.flag = flag
        threading.Thread.start(self)

    def run(self):
        """Accept and serve connections (one at a time) until stop()."""
        # Short accept timeout so the loop re-checks self.active often.
        self.sock.settimeout(0.05)
        self.sock.listen()
        self.active = True
        if self.flag:
            # signal an event
            self.flag.set()
        while self.active:
            try:
                newconn, connaddr = self.sock.accept()
                if support.verbose and self.chatty:
                    sys.stdout.write(' server:  new connection from '
                                     + repr(connaddr) + '\n')
                handler = self.ConnectionHandler(self, newconn, connaddr)
                handler.start()
                handler.join()
            except socket.timeout:
                pass
            except KeyboardInterrupt:
                self.stop()
            except BaseException as e:
                if support.verbose and self.chatty:
                    sys.stdout.write(
                        ' connection handling failed: ' + repr(e) + '\n')

        self.sock.close()

    def stop(self):
        """Ask run() to exit after its current accept/handler iteration."""
        self.active = False
2454
class AsyncoreEchoServer(threading.Thread):
    """Thread that runs an asyncore-based TLS echo server.

    The listening dispatcher (EchoServer) is created and bound in __init__;
    run() keeps calling asyncore.loop() until stop() is called.  Use it as a
    context manager to start the thread and to stop/join it afterwards.
    """

    # this one's based on asyncore.dispatcher

    class EchoServer (asyncore.dispatcher):
        """Listening dispatcher that creates a ConnectionHandler per client."""

        class ConnectionHandler(asyncore.dispatcher_with_send):
            """Per-connection dispatcher: completes a non-blocking
            server-side TLS handshake, then echoes data back lower-cased."""

            def __init__(self, conn, certfile):
                # Wrap the accepted socket server-side but defer the
                # handshake; it is advanced from the event loop instead.
                self.socket = test_wrap_socket(conn, server_side=True,
                                              certfile=certfile,
                                              do_handshake_on_connect=False)
                asyncore.dispatcher_with_send.__init__(self, self.socket)
                # True until do_handshake() completes successfully.
                self._ssl_accepting = True
                self._do_ssl_handshake()

            def readable(self):
                # Drain data already buffered inside the SSL layer
                # (pending() > 0) before reporting readiness; presumably
                # such buffered bytes would not wake the poller on their
                # own.  Always reports the socket as readable.
                if isinstance(self.socket, ssl.SSLSocket):
                    while self.socket.pending() > 0:
                        self.handle_read_event()
                return True

            def _do_ssl_handshake(self):
                # Advance the non-blocking handshake by one step; on success
                # switch the handler into echo mode.
                try:
                    self.socket.do_handshake()
                except (ssl.SSLWantReadError, ssl.SSLWantWriteError):
                    # Not finished yet; handle_read() retries on the next
                    # read event while _ssl_accepting is still True.
                    return
                except ssl.SSLEOFError:
                    return self.handle_close()
                except ssl.SSLError:
                    # SSLError subclasses OSError; re-raise it so the
                    # generic OSError clause below does not swallow it.
                    raise
                except OSError as err:
                    if err.args[0] == errno.ECONNABORTED:
                        return self.handle_close()
                    # NOTE(review): any other OSError is silently ignored
                    # here and _ssl_accepting stays True.
                else:
                    self._ssl_accepting = False

            def handle_read(self):
                # During the handshake, read events just advance it;
                # afterwards, echo up to 1024 bytes back lower-cased and
                # close on EOF (empty read).
                if self._ssl_accepting:
                    self._do_ssl_handshake()
                else:
                    data = self.recv(1024)
                    if support.verbose:
                        sys.stdout.write(" server:  read %s from client\n" % repr(data))
                    if not data:
                        self.close()
                    else:
                        self.send(data.lower())

            def handle_close(self):
                self.close()
                if support.verbose:
                    sys.stdout.write(" server:  closed connection %s\n" % self.socket)

            def handle_error(self):
                """Propagate errors instead of asyncore's default handling.

                The bare ``raise`` re-raises the exception currently being
                handled; presumably asyncore calls this from inside an
                ``except`` block, as its dispatcher loop does.
                """
                raise

        def __init__(self, certfile):
            # Bind a fresh IPv4 TCP socket via support.bind_port() (the
            # chosen port is exposed as self.port) and start listening.
            self.certfile = certfile
            sock = socket.socket(socket.AF_INET, socket.SOCK_STREAM)
            self.port = support.bind_port(sock, '')
            asyncore.dispatcher.__init__(self, sock)
            self.listen(5)

        def handle_accepted(self, sock_obj, addr):
            # The handler object is not stored here; it registers itself
            # with asyncore (see the socket_map cleanup in __exit__ below).
            if support.verbose:
                sys.stdout.write(" server:  new connection from %s:%s\n" %addr)
            self.ConnectionHandler(sock_obj, self.certfile)

        def handle_error(self):
            # Re-raise the active exception instead of logging it.
            raise

    def __init__(self, certfile):
        # *certfile* is the server certificate/key used for every
        # accepted connection.
        self.flag = None
        self.active = False
        self.server = self.EchoServer(certfile)
        self.port = self.server.port
        threading.Thread.__init__(self)
        self.daemon = True

    def __str__(self):
        return "<%s %s>" % (self.__class__.__name__, self.server)

    def __enter__(self):
        # Start the thread and wait until run() has set the flag.
        self.start(threading.Event())
        self.flag.wait()
        return self

    def __exit__(self, *args):
        if support.verbose:
            sys.stdout.write(" cleanup: stopping server.\n")
        self.stop()
        if support.verbose:
            sys.stdout.write(" cleanup: joining server thread.\n")
        self.join()
        if support.verbose:
            sys.stdout.write(" cleanup: successfully joined.\n")
        # make sure that ConnectionHandler is removed from socket_map
        asyncore.close_all(ignore_all=True)

    def start (self, flag=None):
        # *flag* (a threading.Event) is set by run() once it is running.
        self.flag = flag
        threading.Thread.start(self)

    def run(self):
        self.active = True
        if self.flag:
            self.flag.set()
        # asyncore.loop(1) uses a 1-second poll timeout, so self.active is
        # re-checked at least that often.
        while self.active:
            try:
                asyncore.loop(1)
            except:
                # NOTE(review): bare except swallows every error (including
                # KeyboardInterrupt) and simply keeps looping while active.
                pass

    def stop(self):
        # Make run() exit and close the listening dispatcher.
        self.active = False
        self.server.close()
2572
def server_params_test(client_context, server_context, indata=b"FOO\n",
                       chatty=True, connectionchatty=False, sni_name=None,
                       session=None):
    """
    Launch a server, connect a client to it and try various reads
    and writes.

    *indata* is written three times (as bytes, bytearray and memoryview);
    each reply must equal ``indata.lower()`` or AssertionError is raised.
    Returns a dict with the client's view of the connection (compression,
    cipher, peer certificate, ALPN/NPN selection, version, session state)
    plus the ALPN/NPN selections and shared ciphers the server recorded.
    """
    def log(fmt, *args):
        # Client-side progress output, only when both switches are on.
        if connectionchatty and support.verbose:
            sys.stdout.write(fmt % args)

    server = ThreadedEchoServer(context=server_context,
                                chatty=chatty,
                                connectionchatty=False)
    stats = {}
    with server:
        with client_context.wrap_socket(socket.socket(),
                                        server_hostname=sni_name,
                                        session=session) as conn:
            conn.connect((HOST, server.port))
            payloads = (indata, bytearray(indata), memoryview(indata))
            for payload in payloads:
                log(" client:  sending %r...\n", indata)
                conn.write(payload)
                echoed = conn.read()
                log(" client:  read %r\n", echoed)
                expected = indata.lower()
                if echoed != expected:
                    raise AssertionError(
                        "bad data <<%r>> (%d) received; expected <<%r>> (%d)\n"
                        % (echoed[:20], len(echoed),
                           indata[:20].lower(), len(indata)))
            conn.write(b"over\n")
            log(" client:  closing connection.\n")
            stats = {
                'compression': conn.compression(),
                'cipher': conn.cipher(),
                'peercert': conn.getpeercert(),
                'client_alpn_protocol': conn.selected_alpn_protocol(),
                'client_npn_protocol': conn.selected_npn_protocol(),
                'version': conn.version(),
                'session_reused': conn.session_reused,
                'session': conn.session,
            }
            conn.close()
        stats['server_alpn_protocols'] = server.selected_alpn_protocols
        stats['server_npn_protocols'] = server.selected_npn_protocols
        stats['server_shared_ciphers'] = server.shared_ciphers
    return stats
2622
def try_protocol_combo(server_protocol, client_protocol, expect_success,
                       certsreqs=None, server_options=0, client_options=0):
    """
    Try to SSL-connect using *client_protocol* to *server_protocol*.
    If *expect_success* is true, assert that the connection succeeds,
    if it's false, assert that the connection fails.
    Also, if *expect_success* is a string, assert that it is the protocol
    version actually used by the connection.

    *certsreqs* (default CERT_NONE) is applied as verify_mode on both
    sides; *server_options* / *client_options* are OR-ed into the
    respective context options.
    """
    if certsreqs is None:
        certsreqs = ssl.CERT_NONE
    certtype_names = {
        ssl.CERT_NONE: "CERT_NONE",
        ssl.CERT_OPTIONAL: "CERT_OPTIONAL",
        ssl.CERT_REQUIRED: "CERT_REQUIRED",
    }
    certtype = certtype_names[certsreqs]
    if support.verbose:
        # Expected failures are shown wrapped in braces.
        formatstr = " %s->%s %s\n" if expect_success else " {%s->%s} %s\n"
        names = (ssl.get_protocol_name(client_protocol),
                 ssl.get_protocol_name(server_protocol),
                 certtype)
        sys.stdout.write(formatstr % names)

    client_context = ssl.SSLContext(client_protocol)
    client_context.options |= client_options
    server_context = ssl.SSLContext(server_protocol)
    server_context.options |= server_options

    # If OpenSSL configuration is strict and requires more recent TLS
    # version, we have to change the minimum to test old TLS versions.
    # SSLContext.minimum_version is only available on recent OpenSSL
    # (setter added in OpenSSL 1.1.0, getter added in OpenSSL 1.1.1)
    min_version = PROTOCOL_TO_TLS_VERSION.get(client_protocol)
    must_lower_minimum = (
        min_version is not None
        and hasattr(server_context, 'minimum_version')
        and server_protocol == ssl.PROTOCOL_TLS
        and server_context.minimum_version > min_version
    )
    if must_lower_minimum:
        server_context.minimum_version = min_version

    # NOTE: we must enable "ALL" ciphers on the client, otherwise an
    # SSLv23 client will send an SSLv3 hello (rather than SSLv2)
    # starting from OpenSSL 1.0.0 (see issue #8322).
    if client_context.protocol == ssl.PROTOCOL_TLS:
        client_context.set_ciphers("ALL")

    for context in (client_context, server_context):
        context.verify_mode = certsreqs
        context.load_cert_chain(SIGNED_CERTFILE)
        context.load_verify_locations(SIGNING_CA)

    try:
        stats = server_params_test(client_context, server_context,
                                   chatty=False, connectionchatty=False)
    except ssl.SSLError:
        # Protocol mismatch can result in an SSLError ...
        if expect_success:
            raise
        return
    except OSError as exc:
        # ... or in a "Connection reset by peer" error.
        if expect_success or exc.errno != errno.ECONNRESET:
            raise
        return

    if not expect_success:
        raise AssertionError(
            "Client protocol %s succeeded with server protocol %s!"
            % (ssl.get_protocol_name(client_protocol),
               ssl.get_protocol_name(server_protocol)))
    if expect_success is not True and expect_success != stats['version']:
        raise AssertionError("version mismatch: expected %r, got %r"
                             % (expect_success, stats['version']))
2692
2693
2694class ThreadedTests(unittest.TestCase):
2695
2696    @skip_if_broken_ubuntu_ssl
2697    def test_echo(self):
2698        """Basic test of an SSL client connecting to a server"""
2699        if support.verbose:
2700            sys.stdout.write("\n")
2701        for protocol in PROTOCOLS:
2702            if protocol in {ssl.PROTOCOL_TLS_CLIENT, ssl.PROTOCOL_TLS_SERVER}:
2703                continue
2704            with self.subTest(protocol=ssl._PROTOCOL_NAMES[protocol]):
2705                context = ssl.SSLContext(protocol)
2706                context.load_cert_chain(CERTFILE)
2707                server_params_test(context, context,
2708                                   chatty=True, connectionchatty=True)
2709
2710        client_context, server_context, hostname = testing_context()
2711
2712        with self.subTest(client=ssl.PROTOCOL_TLS_CLIENT, server=ssl.PROTOCOL_TLS_SERVER):
2713            server_params_test(client_context=client_context,
2714                               server_context=server_context,
2715                               chatty=True, connectionchatty=True,
2716                               sni_name=hostname)
2717
2718        client_context.check_hostname = False
2719        with self.subTest(client=ssl.PROTOCOL_TLS_SERVER, server=ssl.PROTOCOL_TLS_CLIENT):
2720            with self.assertRaises(ssl.SSLError) as e:
2721                server_params_test(client_context=server_context,
2722                                   server_context=client_context,
2723                                   chatty=True, connectionchatty=True,
2724                                   sni_name=hostname)
2725            self.assertIn('called a function you should not call',
2726                          str(e.exception))
2727
2728        with self.subTest(client=ssl.PROTOCOL_TLS_SERVER, server=ssl.PROTOCOL_TLS_SERVER):
2729            with self.assertRaises(ssl.SSLError) as e:
2730                server_params_test(client_context=server_context,
2731                                   server_context=server_context,
2732                                   chatty=True, connectionchatty=True)
2733            self.assertIn('called a function you should not call',
2734                          str(e.exception))
2735
2736        with self.subTest(client=ssl.PROTOCOL_TLS_CLIENT, server=ssl.PROTOCOL_TLS_CLIENT):
2737            with self.assertRaises(ssl.SSLError) as e:
2738                server_params_test(client_context=server_context,
2739                                   server_context=client_context,
2740                                   chatty=True, connectionchatty=True)
2741            self.assertIn('called a function you should not call',
2742                          str(e.exception))
2743
2744    def test_getpeercert(self):
2745        if support.verbose:
2746            sys.stdout.write("\n")
2747
2748        client_context, server_context, hostname = testing_context()
2749        server = ThreadedEchoServer(context=server_context, chatty=False)
2750        with server:
2751            with client_context.wrap_socket(socket.socket(),
2752                                            do_handshake_on_connect=False,
2753                                            server_hostname=hostname) as s:
2754                s.connect((HOST, server.port))
2755                # getpeercert() raise ValueError while the handshake isn't
2756                # done.
2757                with self.assertRaises(ValueError):
2758                    s.getpeercert()
2759                s.do_handshake()
2760                cert = s.getpeercert()
2761                self.assertTrue(cert, "Can't get peer certificate.")
2762                cipher = s.cipher()
2763                if support.verbose:
2764                    sys.stdout.write(pprint.pformat(cert) + '\n')
2765                    sys.stdout.write("Connection cipher is " + str(cipher) + '.\n')
2766                if 'subject' not in cert:
2767                    self.fail("No subject field in certificate: %s." %
2768                              pprint.pformat(cert))
2769                if ((('organizationName', 'Python Software Foundation'),)
2770                    not in cert['subject']):
2771                    self.fail(
2772                        "Missing or invalid 'organizationName' field in certificate subject; "
2773                        "should be 'Python Software Foundation'.")
2774                self.assertIn('notBefore', cert)
2775                self.assertIn('notAfter', cert)
2776                before = ssl.cert_time_to_seconds(cert['notBefore'])
2777                after = ssl.cert_time_to_seconds(cert['notAfter'])
2778                self.assertLess(before, after)
2779
2780    @unittest.skipUnless(have_verify_flags(),
2781                        "verify_flags need OpenSSL > 0.9.8")
2782    def test_crl_check(self):
2783        if support.verbose:
2784            sys.stdout.write("\n")
2785
2786        client_context, server_context, hostname = testing_context()
2787
2788        tf = getattr(ssl, "VERIFY_X509_TRUSTED_FIRST", 0)
2789        self.assertEqual(client_context.verify_flags, ssl.VERIFY_DEFAULT | tf)
2790
2791        # VERIFY_DEFAULT should pass
2792        server = ThreadedEchoServer(context=server_context, chatty=True)
2793        with server:
2794            with client_context.wrap_socket(socket.socket(),
2795                                            server_hostname=hostname) as s:
2796                s.connect((HOST, server.port))
2797                cert = s.getpeercert()
2798                self.assertTrue(cert, "Can't get peer certificate.")
2799
2800        # VERIFY_CRL_CHECK_LEAF without a loaded CRL file fails
2801        client_context.verify_flags |= ssl.VERIFY_CRL_CHECK_LEAF
2802
2803        server = ThreadedEchoServer(context=server_context, chatty=True)
2804        with server:
2805            with client_context.wrap_socket(socket.socket(),
2806                                            server_hostname=hostname) as s:
2807                with self.assertRaisesRegex(ssl.SSLError,
2808                                            "certificate verify failed"):
2809                    s.connect((HOST, server.port))
2810
2811        # now load a CRL file. The CRL file is signed by the CA.
2812        client_context.load_verify_locations(CRLFILE)
2813
2814        server = ThreadedEchoServer(context=server_context, chatty=True)
2815        with server:
2816            with client_context.wrap_socket(socket.socket(),
2817                                            server_hostname=hostname) as s:
2818                s.connect((HOST, server.port))
2819                cert = s.getpeercert()
2820                self.assertTrue(cert, "Can't get peer certificate.")
2821
2822    def test_check_hostname(self):
2823        if support.verbose:
2824            sys.stdout.write("\n")
2825
2826        client_context, server_context, hostname = testing_context()
2827
2828        # correct hostname should verify
2829        server = ThreadedEchoServer(context=server_context, chatty=True)
2830        with server:
2831            with client_context.wrap_socket(socket.socket(),
2832                                            server_hostname=hostname) as s:
2833                s.connect((HOST, server.port))
2834                cert = s.getpeercert()
2835                self.assertTrue(cert, "Can't get peer certificate.")
2836
2837        # incorrect hostname should raise an exception
2838        server = ThreadedEchoServer(context=server_context, chatty=True)
2839        with server:
2840            with client_context.wrap_socket(socket.socket(),
2841                                            server_hostname="invalid") as s:
2842                with self.assertRaisesRegex(
2843                        ssl.CertificateError,
2844                        "Hostname mismatch, certificate is not valid for 'invalid'."):
2845                    s.connect((HOST, server.port))
2846
2847        # missing server_hostname arg should cause an exception, too
2848        server = ThreadedEchoServer(context=server_context, chatty=True)
2849        with server:
2850            with socket.socket() as s:
2851                with self.assertRaisesRegex(ValueError,
2852                                            "check_hostname requires server_hostname"):
2853                    client_context.wrap_socket(s)
2854
2855    def test_ecc_cert(self):
2856        client_context = ssl.SSLContext(ssl.PROTOCOL_TLS_CLIENT)
2857        client_context.load_verify_locations(SIGNING_CA)
2858        client_context.set_ciphers('ECDHE:ECDSA:!NULL:!aRSA')
2859        hostname = SIGNED_CERTFILE_ECC_HOSTNAME
2860
2861        server_context = ssl.SSLContext(ssl.PROTOCOL_TLS_SERVER)
2862        # load ECC cert
2863        server_context.load_cert_chain(SIGNED_CERTFILE_ECC)
2864
2865        # correct hostname should verify
2866        server = ThreadedEchoServer(context=server_context, chatty=True)
2867        with server:
2868            with client_context.wrap_socket(socket.socket(),
2869                                            server_hostname=hostname) as s:
2870                s.connect((HOST, server.port))
2871                cert = s.getpeercert()
2872                self.assertTrue(cert, "Can't get peer certificate.")
2873                cipher = s.cipher()[0].split('-')
2874                self.assertTrue(cipher[:2], ('ECDHE', 'ECDSA'))
2875
2876    def test_dual_rsa_ecc(self):
2877        client_context = ssl.SSLContext(ssl.PROTOCOL_TLS_CLIENT)
2878        client_context.load_verify_locations(SIGNING_CA)
2879        # TODO: fix TLSv1.3 once SSLContext can restrict signature
2880        #       algorithms.
2881        client_context.options |= ssl.OP_NO_TLSv1_3
2882        # only ECDSA certs
2883        client_context.set_ciphers('ECDHE:ECDSA:!NULL:!aRSA')
2884        hostname = SIGNED_CERTFILE_ECC_HOSTNAME
2885
2886        server_context = ssl.SSLContext(ssl.PROTOCOL_TLS_SERVER)
2887        # load ECC and RSA key/cert pairs
2888        server_context.load_cert_chain(SIGNED_CERTFILE_ECC)
2889        server_context.load_cert_chain(SIGNED_CERTFILE)
2890
2891        # correct hostname should verify
2892        server = ThreadedEchoServer(context=server_context, chatty=True)
2893        with server:
2894            with client_context.wrap_socket(socket.socket(),
2895                                            server_hostname=hostname) as s:
2896                s.connect((HOST, server.port))
2897                cert = s.getpeercert()
2898                self.assertTrue(cert, "Can't get peer certificate.")
2899                cipher = s.cipher()[0].split('-')
2900                self.assertTrue(cipher[:2], ('ECDHE', 'ECDSA'))
2901
2902    def test_check_hostname_idn(self):
2903        if support.verbose:
2904            sys.stdout.write("\n")
2905
2906        server_context = ssl.SSLContext(ssl.PROTOCOL_TLS_SERVER)
2907        server_context.load_cert_chain(IDNSANSFILE)
2908
2909        context = ssl.SSLContext(ssl.PROTOCOL_TLS_CLIENT)
2910        context.verify_mode = ssl.CERT_REQUIRED
2911        context.check_hostname = True
2912        context.load_verify_locations(SIGNING_CA)
2913
2914        # correct hostname should verify, when specified in several
2915        # different ways
2916        idn_hostnames = [
2917            ('könig.idn.pythontest.net',
2918             'xn--knig-5qa.idn.pythontest.net'),
2919            ('xn--knig-5qa.idn.pythontest.net',
2920             'xn--knig-5qa.idn.pythontest.net'),
2921            (b'xn--knig-5qa.idn.pythontest.net',
2922             'xn--knig-5qa.idn.pythontest.net'),
2923
2924            ('königsgäßchen.idna2003.pythontest.net',
2925             'xn--knigsgsschen-lcb0w.idna2003.pythontest.net'),
2926            ('xn--knigsgsschen-lcb0w.idna2003.pythontest.net',
2927             'xn--knigsgsschen-lcb0w.idna2003.pythontest.net'),
2928            (b'xn--knigsgsschen-lcb0w.idna2003.pythontest.net',
2929             'xn--knigsgsschen-lcb0w.idna2003.pythontest.net'),
2930
2931            # ('königsgäßchen.idna2008.pythontest.net',
2932            #  'xn--knigsgchen-b4a3dun.idna2008.pythontest.net'),
2933            ('xn--knigsgchen-b4a3dun.idna2008.pythontest.net',
2934             'xn--knigsgchen-b4a3dun.idna2008.pythontest.net'),
2935            (b'xn--knigsgchen-b4a3dun.idna2008.pythontest.net',
2936             'xn--knigsgchen-b4a3dun.idna2008.pythontest.net'),
2937
2938        ]
2939        for server_hostname, expected_hostname in idn_hostnames:
2940            server = ThreadedEchoServer(context=server_context, chatty=True)
2941            with server:
2942                with context.wrap_socket(socket.socket(),
2943                                         server_hostname=server_hostname) as s:
2944                    self.assertEqual(s.server_hostname, expected_hostname)
2945                    s.connect((HOST, server.port))
2946                    cert = s.getpeercert()
2947                    self.assertEqual(s.server_hostname, expected_hostname)
2948                    self.assertTrue(cert, "Can't get peer certificate.")
2949
2950        # incorrect hostname should raise an exception
2951        server = ThreadedEchoServer(context=server_context, chatty=True)
2952        with server:
2953            with context.wrap_socket(socket.socket(),
2954                                     server_hostname="python.example.org") as s:
2955                with self.assertRaises(ssl.CertificateError):
2956                    s.connect((HOST, server.port))
2957
2958    def test_wrong_cert_tls12(self):
2959        """Connecting when the server rejects the client's certificate
2960
2961        Launch a server with CERT_REQUIRED, and check that trying to
2962        connect to it with a wrong client certificate fails.
2963        """
2964        client_context, server_context, hostname = testing_context()
2965        # load client cert that is not signed by trusted CA
2966        client_context.load_cert_chain(CERTFILE)
2967        # require TLS client authentication
2968        server_context.verify_mode = ssl.CERT_REQUIRED
2969        # TLS 1.3 has different handshake
2970        client_context.maximum_version = ssl.TLSVersion.TLSv1_2
2971
2972        server = ThreadedEchoServer(
2973            context=server_context, chatty=True, connectionchatty=True,
2974        )
2975
2976        with server, \
2977                client_context.wrap_socket(socket.socket(),
2978                                           server_hostname=hostname) as s:
2979            try:
2980                # Expect either an SSL error about the server rejecting
2981                # the connection, or a low-level connection reset (which
2982                # sometimes happens on Windows)
2983                s.connect((HOST, server.port))
2984            except ssl.SSLError as e:
2985                if support.verbose:
2986                    sys.stdout.write("\nSSLError is %r\n" % e)
2987            except OSError as e:
2988                if e.errno != errno.ECONNRESET:
2989                    raise
2990                if support.verbose:
2991                    sys.stdout.write("\nsocket.error is %r\n" % e)
2992            else:
2993                self.fail("Use of invalid cert should have failed!")
2994
2995    @unittest.skipUnless(ssl.HAS_TLSv1_3, "Test needs TLS 1.3")
2996    def test_wrong_cert_tls13(self):
2997        client_context, server_context, hostname = testing_context()
2998        # load client cert that is not signed by trusted CA
2999        client_context.load_cert_chain(CERTFILE)
3000        server_context.verify_mode = ssl.CERT_REQUIRED
3001        server_context.minimum_version = ssl.TLSVersion.TLSv1_3
3002        client_context.minimum_version = ssl.TLSVersion.TLSv1_3
3003
3004        server = ThreadedEchoServer(
3005            context=server_context, chatty=True, connectionchatty=True,
3006        )
3007        with server, \
3008             client_context.wrap_socket(socket.socket(),
3009                                        server_hostname=hostname) as s:
3010            # TLS 1.3 perform client cert exchange after handshake
3011            s.connect((HOST, server.port))
3012            try:
3013                s.write(b'data')
3014                s.read(4)
3015            except ssl.SSLError as e:
3016                if support.verbose:
3017                    sys.stdout.write("\nSSLError is %r\n" % e)
3018            except OSError as e:
3019                if e.errno != errno.ECONNRESET:
3020                    raise
3021                if support.verbose:
3022                    sys.stdout.write("\nsocket.error is %r\n" % e)
3023            else:
3024                self.fail("Use of invalid cert should have failed!")
3025
3026    def test_rude_shutdown(self):
3027        """A brutal shutdown of an SSL server should raise an OSError
3028        in the client when attempting handshake.
3029        """
3030        listener_ready = threading.Event()
3031        listener_gone = threading.Event()
3032
3033        s = socket.socket()
3034        port = support.bind_port(s, HOST)
3035
3036        # `listener` runs in a thread.  It sits in an accept() until
3037        # the main thread connects.  Then it rudely closes the socket,
3038        # and sets Event `listener_gone` to let the main thread know
3039        # the socket is gone.
3040        def listener():
3041            s.listen()
3042            listener_ready.set()
3043            newsock, addr = s.accept()
3044            newsock.close()
3045            s.close()
3046            listener_gone.set()
3047
3048        def connector():
3049            listener_ready.wait()
3050            with socket.socket() as c:
3051                c.connect((HOST, port))
3052                listener_gone.wait()
3053                try:
3054                    ssl_sock = test_wrap_socket(c)
3055                except OSError:
3056                    pass
3057                else:
3058                    self.fail('connecting to closed SSL socket should have failed')
3059
3060        t = threading.Thread(target=listener)
3061        t.start()
3062        try:
3063            connector()
3064        finally:
3065            t.join()
3066
3067    def test_ssl_cert_verify_error(self):
3068        if support.verbose:
3069            sys.stdout.write("\n")
3070
3071        server_context = ssl.SSLContext(ssl.PROTOCOL_TLS_SERVER)
3072        server_context.load_cert_chain(SIGNED_CERTFILE)
3073
3074        context = ssl.SSLContext(ssl.PROTOCOL_TLS_CLIENT)
3075
3076        server = ThreadedEchoServer(context=server_context, chatty=True)
3077        with server:
3078            with context.wrap_socket(socket.socket(),
3079                                     server_hostname=SIGNED_CERTFILE_HOSTNAME) as s:
3080                try:
3081                    s.connect((HOST, server.port))
3082                except ssl.SSLError as e:
3083                    msg = 'unable to get local issuer certificate'
3084                    self.assertIsInstance(e, ssl.SSLCertVerificationError)
3085                    self.assertEqual(e.verify_code, 20)
3086                    self.assertEqual(e.verify_message, msg)
3087                    self.assertIn(msg, repr(e))
3088                    self.assertIn('certificate verify failed', repr(e))
3089
3090    @skip_if_broken_ubuntu_ssl
3091    @unittest.skipUnless(hasattr(ssl, 'PROTOCOL_SSLv2'),
3092                         "OpenSSL is compiled without SSLv2 support")
3093    def test_protocol_sslv2(self):
3094        """Connecting to an SSLv2 server with various client options"""
3095        if support.verbose:
3096            sys.stdout.write("\n")
3097        try_protocol_combo(ssl.PROTOCOL_SSLv2, ssl.PROTOCOL_SSLv2, True)
3098        try_protocol_combo(ssl.PROTOCOL_SSLv2, ssl.PROTOCOL_SSLv2, True, ssl.CERT_OPTIONAL)
3099        try_protocol_combo(ssl.PROTOCOL_SSLv2, ssl.PROTOCOL_SSLv2, True, ssl.CERT_REQUIRED)
3100        try_protocol_combo(ssl.PROTOCOL_SSLv2, ssl.PROTOCOL_TLS, False)
3101        if hasattr(ssl, 'PROTOCOL_SSLv3'):
3102            try_protocol_combo(ssl.PROTOCOL_SSLv2, ssl.PROTOCOL_SSLv3, False)
3103        try_protocol_combo(ssl.PROTOCOL_SSLv2, ssl.PROTOCOL_TLSv1, False)
3104        # SSLv23 client with specific SSL options
3105        if no_sslv2_implies_sslv3_hello():
3106            # No SSLv2 => client will use an SSLv3 hello on recent OpenSSLs
3107            try_protocol_combo(ssl.PROTOCOL_SSLv2, ssl.PROTOCOL_TLS, False,
3108                               client_options=ssl.OP_NO_SSLv2)
3109        try_protocol_combo(ssl.PROTOCOL_SSLv2, ssl.PROTOCOL_TLS, False,
3110                           client_options=ssl.OP_NO_SSLv3)
3111        try_protocol_combo(ssl.PROTOCOL_SSLv2, ssl.PROTOCOL_TLS, False,
3112                           client_options=ssl.OP_NO_TLSv1)
3113
3114    @skip_if_broken_ubuntu_ssl
3115    def test_PROTOCOL_TLS(self):
3116        """Connecting to an SSLv23 server with various client options"""
3117        if support.verbose:
3118            sys.stdout.write("\n")
3119        if hasattr(ssl, 'PROTOCOL_SSLv2'):
3120            try:
3121                try_protocol_combo(ssl.PROTOCOL_TLS, ssl.PROTOCOL_SSLv2, True)
3122            except OSError as x:
3123                # this fails on some older versions of OpenSSL (0.9.7l, for instance)
3124                if support.verbose:
3125                    sys.stdout.write(
3126                        " SSL2 client to SSL23 server test unexpectedly failed:\n %s\n"
3127                        % str(x))
3128        if hasattr(ssl, 'PROTOCOL_SSLv3'):
3129            try_protocol_combo(ssl.PROTOCOL_TLS, ssl.PROTOCOL_SSLv3, False)
3130        try_protocol_combo(ssl.PROTOCOL_TLS, ssl.PROTOCOL_TLS, True)
3131        try_protocol_combo(ssl.PROTOCOL_TLS, ssl.PROTOCOL_TLSv1, 'TLSv1')
3132
3133        if hasattr(ssl, 'PROTOCOL_SSLv3'):
3134            try_protocol_combo(ssl.PROTOCOL_TLS, ssl.PROTOCOL_SSLv3, False, ssl.CERT_OPTIONAL)
3135        try_protocol_combo(ssl.PROTOCOL_TLS, ssl.PROTOCOL_TLS, True, ssl.CERT_OPTIONAL)
3136        try_protocol_combo(ssl.PROTOCOL_TLS, ssl.PROTOCOL_TLSv1, 'TLSv1', ssl.CERT_OPTIONAL)
3137
3138        if hasattr(ssl, 'PROTOCOL_SSLv3'):
3139            try_protocol_combo(ssl.PROTOCOL_TLS, ssl.PROTOCOL_SSLv3, False, ssl.CERT_REQUIRED)
3140        try_protocol_combo(ssl.PROTOCOL_TLS, ssl.PROTOCOL_TLS, True, ssl.CERT_REQUIRED)
3141        try_protocol_combo(ssl.PROTOCOL_TLS, ssl.PROTOCOL_TLSv1, 'TLSv1', ssl.CERT_REQUIRED)
3142
3143        # Server with specific SSL options
3144        if hasattr(ssl, 'PROTOCOL_SSLv3'):
3145            try_protocol_combo(ssl.PROTOCOL_TLS, ssl.PROTOCOL_SSLv3, False,
3146                           server_options=ssl.OP_NO_SSLv3)
3147        # Will choose TLSv1
3148        try_protocol_combo(ssl.PROTOCOL_TLS, ssl.PROTOCOL_TLS, True,
3149                           server_options=ssl.OP_NO_SSLv2 | ssl.OP_NO_SSLv3)
3150        try_protocol_combo(ssl.PROTOCOL_TLS, ssl.PROTOCOL_TLSv1, False,
3151                           server_options=ssl.OP_NO_TLSv1)
3152
3153
3154    @skip_if_broken_ubuntu_ssl
3155    @unittest.skipUnless(hasattr(ssl, 'PROTOCOL_SSLv3'),
3156                         "OpenSSL is compiled without SSLv3 support")
3157    def test_protocol_sslv3(self):
3158        """Connecting to an SSLv3 server with various client options"""
3159        if support.verbose:
3160            sys.stdout.write("\n")
3161        try_protocol_combo(ssl.PROTOCOL_SSLv3, ssl.PROTOCOL_SSLv3, 'SSLv3')
3162        try_protocol_combo(ssl.PROTOCOL_SSLv3, ssl.PROTOCOL_SSLv3, 'SSLv3', ssl.CERT_OPTIONAL)
3163        try_protocol_combo(ssl.PROTOCOL_SSLv3, ssl.PROTOCOL_SSLv3, 'SSLv3', ssl.CERT_REQUIRED)
3164        if hasattr(ssl, 'PROTOCOL_SSLv2'):
3165            try_protocol_combo(ssl.PROTOCOL_SSLv3, ssl.PROTOCOL_SSLv2, False)
3166        try_protocol_combo(ssl.PROTOCOL_SSLv3, ssl.PROTOCOL_TLS, False,
3167                           client_options=ssl.OP_NO_SSLv3)
3168        try_protocol_combo(ssl.PROTOCOL_SSLv3, ssl.PROTOCOL_TLSv1, False)
3169        if no_sslv2_implies_sslv3_hello():
3170            # No SSLv2 => client will use an SSLv3 hello on recent OpenSSLs
3171            try_protocol_combo(ssl.PROTOCOL_SSLv3, ssl.PROTOCOL_TLS,
3172                               False, client_options=ssl.OP_NO_SSLv2)
3173
3174    @skip_if_broken_ubuntu_ssl
3175    def test_protocol_tlsv1(self):
3176        """Connecting to a TLSv1 server with various client options"""
3177        if support.verbose:
3178            sys.stdout.write("\n")
3179        try_protocol_combo(ssl.PROTOCOL_TLSv1, ssl.PROTOCOL_TLSv1, 'TLSv1')
3180        try_protocol_combo(ssl.PROTOCOL_TLSv1, ssl.PROTOCOL_TLSv1, 'TLSv1', ssl.CERT_OPTIONAL)
3181        try_protocol_combo(ssl.PROTOCOL_TLSv1, ssl.PROTOCOL_TLSv1, 'TLSv1', ssl.CERT_REQUIRED)
3182        if hasattr(ssl, 'PROTOCOL_SSLv2'):
3183            try_protocol_combo(ssl.PROTOCOL_TLSv1, ssl.PROTOCOL_SSLv2, False)
3184        if hasattr(ssl, 'PROTOCOL_SSLv3'):
3185            try_protocol_combo(ssl.PROTOCOL_TLSv1, ssl.PROTOCOL_SSLv3, False)
3186        try_protocol_combo(ssl.PROTOCOL_TLSv1, ssl.PROTOCOL_TLS, False,
3187                           client_options=ssl.OP_NO_TLSv1)
3188
3189    @skip_if_broken_ubuntu_ssl
3190    @unittest.skipUnless(hasattr(ssl, "PROTOCOL_TLSv1_1"),
3191                         "TLS version 1.1 not supported.")
3192    def test_protocol_tlsv1_1(self):
3193        """Connecting to a TLSv1.1 server with various client options.
3194           Testing against older TLS versions."""
3195        if support.verbose:
3196            sys.stdout.write("\n")
3197        try_protocol_combo(ssl.PROTOCOL_TLSv1_1, ssl.PROTOCOL_TLSv1_1, 'TLSv1.1')
3198        if hasattr(ssl, 'PROTOCOL_SSLv2'):
3199            try_protocol_combo(ssl.PROTOCOL_TLSv1_1, ssl.PROTOCOL_SSLv2, False)
3200        if hasattr(ssl, 'PROTOCOL_SSLv3'):
3201            try_protocol_combo(ssl.PROTOCOL_TLSv1_1, ssl.PROTOCOL_SSLv3, False)
3202        try_protocol_combo(ssl.PROTOCOL_TLSv1_1, ssl.PROTOCOL_TLS, False,
3203                           client_options=ssl.OP_NO_TLSv1_1)
3204
3205        try_protocol_combo(ssl.PROTOCOL_TLS, ssl.PROTOCOL_TLSv1_1, 'TLSv1.1')
3206        try_protocol_combo(ssl.PROTOCOL_TLSv1_1, ssl.PROTOCOL_TLSv1, False)
3207        try_protocol_combo(ssl.PROTOCOL_TLSv1, ssl.PROTOCOL_TLSv1_1, False)
3208
3209    @skip_if_broken_ubuntu_ssl
3210    @unittest.skipUnless(hasattr(ssl, "PROTOCOL_TLSv1_2"),
3211                         "TLS version 1.2 not supported.")
3212    def test_protocol_tlsv1_2(self):
3213        """Connecting to a TLSv1.2 server with various client options.
3214           Testing against older TLS versions."""
3215        if support.verbose:
3216            sys.stdout.write("\n")
3217        try_protocol_combo(ssl.PROTOCOL_TLSv1_2, ssl.PROTOCOL_TLSv1_2, 'TLSv1.2',
3218                           server_options=ssl.OP_NO_SSLv3|ssl.OP_NO_SSLv2,
3219                           client_options=ssl.OP_NO_SSLv3|ssl.OP_NO_SSLv2,)
3220        if hasattr(ssl, 'PROTOCOL_SSLv2'):
3221            try_protocol_combo(ssl.PROTOCOL_TLSv1_2, ssl.PROTOCOL_SSLv2, False)
3222        if hasattr(ssl, 'PROTOCOL_SSLv3'):
3223            try_protocol_combo(ssl.PROTOCOL_TLSv1_2, ssl.PROTOCOL_SSLv3, False)
3224        try_protocol_combo(ssl.PROTOCOL_TLSv1_2, ssl.PROTOCOL_TLS, False,
3225                           client_options=ssl.OP_NO_TLSv1_2)
3226
3227        try_protocol_combo(ssl.PROTOCOL_TLS, ssl.PROTOCOL_TLSv1_2, 'TLSv1.2')
3228        try_protocol_combo(ssl.PROTOCOL_TLSv1_2, ssl.PROTOCOL_TLSv1, False)
3229        try_protocol_combo(ssl.PROTOCOL_TLSv1, ssl.PROTOCOL_TLSv1_2, False)
3230        try_protocol_combo(ssl.PROTOCOL_TLSv1_2, ssl.PROTOCOL_TLSv1_1, False)
3231        try_protocol_combo(ssl.PROTOCOL_TLSv1_1, ssl.PROTOCOL_TLSv1_2, False)
3232
3233    def test_starttls(self):
3234        """Switching from clear text to encrypted and back again."""
3235        msgs = (b"msg 1", b"MSG 2", b"STARTTLS", b"MSG 3", b"msg 4", b"ENDTLS", b"msg 5", b"msg 6")
3236
3237        server = ThreadedEchoServer(CERTFILE,
3238                                    starttls_server=True,
3239                                    chatty=True,
3240                                    connectionchatty=True)
3241        wrapped = False
3242        with server:
3243            s = socket.socket()
3244            s.setblocking(1)
3245            s.connect((HOST, server.port))
3246            if support.verbose:
3247                sys.stdout.write("\n")
3248            for indata in msgs:
3249                if support.verbose:
3250                    sys.stdout.write(
3251                        " client:  sending %r...\n" % indata)
3252                if wrapped:
3253                    conn.write(indata)
3254                    outdata = conn.read()
3255                else:
3256                    s.send(indata)
3257                    outdata = s.recv(1024)
3258                msg = outdata.strip().lower()
3259                if indata == b"STARTTLS" and msg.startswith(b"ok"):
3260                    # STARTTLS ok, switch to secure mode
3261                    if support.verbose:
3262                        sys.stdout.write(
3263                            " client:  read %r from server, starting TLS...\n"
3264                            % msg)
3265                    conn = test_wrap_socket(s)
3266                    wrapped = True
3267                elif indata == b"ENDTLS" and msg.startswith(b"ok"):
3268                    # ENDTLS ok, switch back to clear text
3269                    if support.verbose:
3270                        sys.stdout.write(
3271                            " client:  read %r from server, ending TLS...\n"
3272                            % msg)
3273                    s = conn.unwrap()
3274                    wrapped = False
3275                else:
3276                    if support.verbose:
3277                        sys.stdout.write(
3278                            " client:  read %r from server\n" % msg)
3279            if support.verbose:
3280                sys.stdout.write(" client:  closing connection.\n")
3281            if wrapped:
3282                conn.write(b"over\n")
3283            else:
3284                s.send(b"over\n")
3285            if wrapped:
3286                conn.close()
3287            else:
3288                s.close()
3289
3290    def test_socketserver(self):
3291        """Using socketserver to create and manage SSL connections."""
3292        server = make_https_server(self, certfile=SIGNED_CERTFILE)
3293        # try to connect
3294        if support.verbose:
3295            sys.stdout.write('\n')
3296        with open(CERTFILE, 'rb') as f:
3297            d1 = f.read()
3298        d2 = ''
3299        # now fetch the same data from the HTTPS server
3300        url = 'https://localhost:%d/%s' % (
3301            server.port, os.path.split(CERTFILE)[1])
3302        context = ssl.create_default_context(cafile=SIGNING_CA)
3303        f = urllib.request.urlopen(url, context=context)
3304        try:
3305            dlen = f.info().get("content-length")
3306            if dlen and (int(dlen) > 0):
3307                d2 = f.read(int(dlen))
3308                if support.verbose:
3309                    sys.stdout.write(
3310                        " client: read %d bytes from remote server '%s'\n"
3311                        % (len(d2), server))
3312        finally:
3313            f.close()
3314        self.assertEqual(d1, d2)
3315
3316    def test_asyncore_server(self):
3317        """Check the example asyncore integration."""
3318        if support.verbose:
3319            sys.stdout.write("\n")
3320
3321        indata = b"FOO\n"
3322        server = AsyncoreEchoServer(CERTFILE)
3323        with server:
3324            s = test_wrap_socket(socket.socket())
3325            s.connect(('127.0.0.1', server.port))
3326            if support.verbose:
3327                sys.stdout.write(
3328                    " client:  sending %r...\n" % indata)
3329            s.write(indata)
3330            outdata = s.read()
3331            if support.verbose:
3332                sys.stdout.write(" client:  read %r\n" % outdata)
3333            if outdata != indata.lower():
3334                self.fail(
3335                    "bad data <<%r>> (%d) received; expected <<%r>> (%d)\n"
3336                    % (outdata[:20], len(outdata),
3337                       indata[:20].lower(), len(indata)))
3338            s.write(b"over\n")
3339            if support.verbose:
3340                sys.stdout.write(" client:  closing connection.\n")
3341            s.close()
3342            if support.verbose:
3343                sys.stdout.write(" client:  connection closed.\n")
3344
3345    def test_recv_send(self):
3346        """Test recv(), send() and friends."""
3347        if support.verbose:
3348            sys.stdout.write("\n")
3349
3350        server = ThreadedEchoServer(CERTFILE,
3351                                    certreqs=ssl.CERT_NONE,
3352                                    ssl_version=ssl.PROTOCOL_TLS_SERVER,
3353                                    cacerts=CERTFILE,
3354                                    chatty=True,
3355                                    connectionchatty=False)
3356        with server:
3357            s = test_wrap_socket(socket.socket(),
3358                                server_side=False,
3359                                certfile=CERTFILE,
3360                                ca_certs=CERTFILE,
3361                                cert_reqs=ssl.CERT_NONE,
3362                                ssl_version=ssl.PROTOCOL_TLS_CLIENT)
3363            s.connect((HOST, server.port))
3364            # helper methods for standardising recv* method signatures
3365            def _recv_into():
3366                b = bytearray(b"\0"*100)
3367                count = s.recv_into(b)
3368                return b[:count]
3369
3370            def _recvfrom_into():
3371                b = bytearray(b"\0"*100)
3372                count, addr = s.recvfrom_into(b)
3373                return b[:count]
3374
3375            # (name, method, expect success?, *args, return value func)
3376            send_methods = [
3377                ('send', s.send, True, [], len),
3378                ('sendto', s.sendto, False, ["some.address"], len),
3379                ('sendall', s.sendall, True, [], lambda x: None),
3380            ]
3381            # (name, method, whether to expect success, *args)
3382            recv_methods = [
3383                ('recv', s.recv, True, []),
3384                ('recvfrom', s.recvfrom, False, ["some.address"]),
3385                ('recv_into', _recv_into, True, []),
3386                ('recvfrom_into', _recvfrom_into, False, []),
3387            ]
3388            data_prefix = "PREFIX_"
3389
3390            for (meth_name, send_meth, expect_success, args,
3391                    ret_val_meth) in send_methods:
3392                indata = (data_prefix + meth_name).encode('ascii')
3393                try:
3394                    ret = send_meth(indata, *args)
3395                    msg = "sending with {}".format(meth_name)
3396                    self.assertEqual(ret, ret_val_meth(indata), msg=msg)
3397                    outdata = s.read()
3398                    if outdata != indata.lower():
3399                        self.fail(
3400                            "While sending with <<{name:s}>> bad data "
3401                            "<<{outdata:r}>> ({nout:d}) received; "
3402                            "expected <<{indata:r}>> ({nin:d})\n".format(
3403                                name=meth_name, outdata=outdata[:20],
3404                                nout=len(outdata),
3405                                indata=indata[:20], nin=len(indata)
3406                            )
3407                        )
3408                except ValueError as e:
3409                    if expect_success:
3410                        self.fail(
3411                            "Failed to send with method <<{name:s}>>; "
3412                            "expected to succeed.\n".format(name=meth_name)
3413                        )
3414                    if not str(e).startswith(meth_name):
3415                        self.fail(
3416                            "Method <<{name:s}>> failed with unexpected "
3417                            "exception message: {exp:s}\n".format(
3418                                name=meth_name, exp=e
3419                            )
3420                        )
3421
3422            for meth_name, recv_meth, expect_success, args in recv_methods:
3423                indata = (data_prefix + meth_name).encode('ascii')
3424                try:
3425                    s.send(indata)
3426                    outdata = recv_meth(*args)
3427                    if outdata != indata.lower():
3428                        self.fail(
3429                            "While receiving with <<{name:s}>> bad data "
3430                            "<<{outdata:r}>> ({nout:d}) received; "
3431                            "expected <<{indata:r}>> ({nin:d})\n".format(
3432                                name=meth_name, outdata=outdata[:20],
3433                                nout=len(outdata),
3434                                indata=indata[:20], nin=len(indata)
3435                            )
3436                        )
3437                except ValueError as e:
3438                    if expect_success:
3439                        self.fail(
3440                            "Failed to receive with method <<{name:s}>>; "
3441                            "expected to succeed.\n".format(name=meth_name)
3442                        )
3443                    if not str(e).startswith(meth_name):
3444                        self.fail(
3445                            "Method <<{name:s}>> failed with unexpected "
3446                            "exception message: {exp:s}\n".format(
3447                                name=meth_name, exp=e
3448                            )
3449                        )
3450                    # consume data
3451                    s.read()
3452
3453            # read(-1, buffer) is supported, even though read(-1) is not
3454            data = b"data"
3455            s.send(data)
3456            buffer = bytearray(len(data))
3457            self.assertEqual(s.read(-1, buffer), len(data))
3458            self.assertEqual(buffer, data)
3459
3460            # sendall accepts bytes-like objects
3461            if ctypes is not None:
3462                ubyte = ctypes.c_ubyte * len(data)
3463                byteslike = ubyte.from_buffer_copy(data)
3464                s.sendall(byteslike)
3465                self.assertEqual(s.read(), data)
3466
3467            # Make sure sendmsg et al are disallowed to avoid
3468            # inadvertent disclosure of data and/or corruption
3469            # of the encrypted data stream
3470            self.assertRaises(NotImplementedError, s.dup)
3471            self.assertRaises(NotImplementedError, s.sendmsg, [b"data"])
3472            self.assertRaises(NotImplementedError, s.recvmsg, 100)
3473            self.assertRaises(NotImplementedError,
3474                              s.recvmsg_into, [bytearray(100)])
3475            s.write(b"over\n")
3476
3477            self.assertRaises(ValueError, s.recv, -1)
3478            self.assertRaises(ValueError, s.read, -1)
3479
3480            s.close()
3481
3482    def test_recv_zero(self):
3483        server = ThreadedEchoServer(CERTFILE)
3484        server.__enter__()
3485        self.addCleanup(server.__exit__, None, None)
3486        s = socket.create_connection((HOST, server.port))
3487        self.addCleanup(s.close)
3488        s = test_wrap_socket(s, suppress_ragged_eofs=False)
3489        self.addCleanup(s.close)
3490
3491        # recv/read(0) should return no data
3492        s.send(b"data")
3493        self.assertEqual(s.recv(0), b"")
3494        self.assertEqual(s.read(0), b"")
3495        self.assertEqual(s.read(), b"data")
3496
3497        # Should not block if the other end sends no data
3498        s.setblocking(False)
3499        self.assertEqual(s.recv(0), b"")
3500        self.assertEqual(s.recv_into(bytearray()), 0)
3501
3502    def test_nonblocking_send(self):
3503        server = ThreadedEchoServer(CERTFILE,
3504                                    certreqs=ssl.CERT_NONE,
3505                                    ssl_version=ssl.PROTOCOL_TLS_SERVER,
3506                                    cacerts=CERTFILE,
3507                                    chatty=True,
3508                                    connectionchatty=False)
3509        with server:
3510            s = test_wrap_socket(socket.socket(),
3511                                server_side=False,
3512                                certfile=CERTFILE,
3513                                ca_certs=CERTFILE,
3514                                cert_reqs=ssl.CERT_NONE,
3515                                ssl_version=ssl.PROTOCOL_TLS_CLIENT)
3516            s.connect((HOST, server.port))
3517            s.setblocking(False)
3518
3519            # If we keep sending data, at some point the buffers
3520            # will be full and the call will block
3521            buf = bytearray(8192)
3522            def fill_buffer():
3523                while True:
3524                    s.send(buf)
3525            self.assertRaises((ssl.SSLWantWriteError,
3526                               ssl.SSLWantReadError), fill_buffer)
3527
3528            # Now read all the output and discard it
3529            s.setblocking(True)
3530            s.close()
3531
3532    def test_handshake_timeout(self):
3533        # Issue #5103: SSL handshake must respect the socket timeout
3534        server = socket.socket(socket.AF_INET)
3535        host = "127.0.0.1"
3536        port = support.bind_port(server)
3537        started = threading.Event()
3538        finish = False
3539
3540        def serve():
3541            server.listen()
3542            started.set()
3543            conns = []
3544            while not finish:
3545                r, w, e = select.select([server], [], [], 0.1)
3546                if server in r:
3547                    # Let the socket hang around rather than having
3548                    # it closed by garbage collection.
3549                    conns.append(server.accept()[0])
3550            for sock in conns:
3551                sock.close()
3552
3553        t = threading.Thread(target=serve)
3554        t.start()
3555        started.wait()
3556
3557        try:
3558            try:
3559                c = socket.socket(socket.AF_INET)
3560                c.settimeout(0.2)
3561                c.connect((host, port))
3562                # Will attempt handshake and time out
3563                self.assertRaisesRegex(socket.timeout, "timed out",
3564                                       test_wrap_socket, c)
3565            finally:
3566                c.close()
3567            try:
3568                c = socket.socket(socket.AF_INET)
3569                c = test_wrap_socket(c)
3570                c.settimeout(0.2)
3571                # Will attempt handshake and time out
3572                self.assertRaisesRegex(socket.timeout, "timed out",
3573                                       c.connect, (host, port))
3574            finally:
3575                c.close()
3576        finally:
3577            finish = True
3578            t.join()
3579            server.close()
3580
3581    def test_server_accept(self):
3582        # Issue #16357: accept() on a SSLSocket created through
3583        # SSLContext.wrap_socket().
3584        context = ssl.SSLContext(ssl.PROTOCOL_TLS)
3585        context.verify_mode = ssl.CERT_REQUIRED
3586        context.load_verify_locations(SIGNING_CA)
3587        context.load_cert_chain(SIGNED_CERTFILE)
3588        server = socket.socket(socket.AF_INET)
3589        host = "127.0.0.1"
3590        port = support.bind_port(server)
3591        server = context.wrap_socket(server, server_side=True)
3592        self.assertTrue(server.server_side)
3593
3594        evt = threading.Event()
3595        remote = None
3596        peer = None
3597        def serve():
3598            nonlocal remote, peer
3599            server.listen()
3600            # Block on the accept and wait on the connection to close.
3601            evt.set()
3602            remote, peer = server.accept()
3603            remote.send(remote.recv(4))
3604
3605        t = threading.Thread(target=serve)
3606        t.start()
3607        # Client wait until server setup and perform a connect.
3608        evt.wait()
3609        client = context.wrap_socket(socket.socket())
3610        client.connect((host, port))
3611        client.send(b'data')
3612        client.recv()
3613        client_addr = client.getsockname()
3614        client.close()
3615        t.join()
3616        remote.close()
3617        server.close()
3618        # Sanity checks.
3619        self.assertIsInstance(remote, ssl.SSLSocket)
3620        self.assertEqual(peer, client_addr)
3621
3622    def test_getpeercert_enotconn(self):
3623        context = ssl.SSLContext(ssl.PROTOCOL_TLS)
3624        with context.wrap_socket(socket.socket()) as sock:
3625            with self.assertRaises(OSError) as cm:
3626                sock.getpeercert()
3627            self.assertEqual(cm.exception.errno, errno.ENOTCONN)
3628
3629    def test_do_handshake_enotconn(self):
3630        context = ssl.SSLContext(ssl.PROTOCOL_TLS)
3631        with context.wrap_socket(socket.socket()) as sock:
3632            with self.assertRaises(OSError) as cm:
3633                sock.do_handshake()
3634            self.assertEqual(cm.exception.errno, errno.ENOTCONN)
3635
3636    def test_no_shared_ciphers(self):
3637        client_context, server_context, hostname = testing_context()
3638        # OpenSSL enables all TLS 1.3 ciphers, enforce TLS 1.2 for test
3639        client_context.options |= ssl.OP_NO_TLSv1_3
3640        # Force different suites on client and master
3641        client_context.set_ciphers("AES128")
3642        server_context.set_ciphers("AES256")
3643        with ThreadedEchoServer(context=server_context) as server:
3644            with client_context.wrap_socket(socket.socket(),
3645                                            server_hostname=hostname) as s:
3646                with self.assertRaises(OSError):
3647                    s.connect((HOST, server.port))
3648        self.assertIn("no shared cipher", server.conn_errors[0])
3649
3650    def test_version_basic(self):
3651        """
3652        Basic tests for SSLSocket.version().
3653        More tests are done in the test_protocol_*() methods.
3654        """
3655        context = ssl.SSLContext(ssl.PROTOCOL_TLS_CLIENT)
3656        context.check_hostname = False
3657        context.verify_mode = ssl.CERT_NONE
3658        with ThreadedEchoServer(CERTFILE,
3659                                ssl_version=ssl.PROTOCOL_TLS_SERVER,
3660                                chatty=False) as server:
3661            with context.wrap_socket(socket.socket()) as s:
3662                self.assertIs(s.version(), None)
3663                self.assertIs(s._sslobj, None)
3664                s.connect((HOST, server.port))
3665                if IS_OPENSSL_1_1_1 and ssl.HAS_TLSv1_3:
3666                    self.assertEqual(s.version(), 'TLSv1.3')
3667                elif ssl.OPENSSL_VERSION_INFO >= (1, 0, 2):
3668                    self.assertEqual(s.version(), 'TLSv1.2')
3669                else:  # 0.9.8 to 1.0.1
3670                    self.assertIn(s.version(), ('TLSv1', 'TLSv1.2'))
3671            self.assertIs(s._sslobj, None)
3672            self.assertIs(s.version(), None)
3673
3674    @unittest.skipUnless(ssl.HAS_TLSv1_3,
3675                         "test requires TLSv1.3 enabled OpenSSL")
3676    def test_tls1_3(self):
3677        context = ssl.SSLContext(ssl.PROTOCOL_TLS)
3678        context.load_cert_chain(CERTFILE)
3679        context.options |= (
3680            ssl.OP_NO_TLSv1 | ssl.OP_NO_TLSv1_1 | ssl.OP_NO_TLSv1_2
3681        )
3682        with ThreadedEchoServer(context=context) as server:
3683            with context.wrap_socket(socket.socket()) as s:
3684                s.connect((HOST, server.port))
3685                self.assertIn(s.cipher()[0], {
3686                    'TLS_AES_256_GCM_SHA384',
3687                    'TLS_CHACHA20_POLY1305_SHA256',
3688                    'TLS_AES_128_GCM_SHA256',
3689                })
3690                self.assertEqual(s.version(), 'TLSv1.3')
3691
3692    @unittest.skipUnless(hasattr(ssl.SSLContext, 'minimum_version'),
3693                         "required OpenSSL 1.1.0g")
3694    def test_min_max_version(self):
3695        client_context, server_context, hostname = testing_context()
3696        # client TLSv1.0 to 1.2
3697        client_context.minimum_version = ssl.TLSVersion.TLSv1
3698        client_context.maximum_version = ssl.TLSVersion.TLSv1_2
3699        # server only TLSv1.2
3700        server_context.minimum_version = ssl.TLSVersion.TLSv1_2
3701        server_context.maximum_version = ssl.TLSVersion.TLSv1_2
3702
3703        with ThreadedEchoServer(context=server_context) as server:
3704            with client_context.wrap_socket(socket.socket(),
3705                                            server_hostname=hostname) as s:
3706                s.connect((HOST, server.port))
3707                self.assertEqual(s.version(), 'TLSv1.2')
3708
3709        # client 1.0 to 1.2, server 1.0 to 1.1
3710        server_context.minimum_version = ssl.TLSVersion.TLSv1
3711        server_context.maximum_version = ssl.TLSVersion.TLSv1_1
3712
3713        with ThreadedEchoServer(context=server_context) as server:
3714            with client_context.wrap_socket(socket.socket(),
3715                                            server_hostname=hostname) as s:
3716                s.connect((HOST, server.port))
3717                self.assertEqual(s.version(), 'TLSv1.1')
3718
3719        # client 1.0, server 1.2 (mismatch)
3720        server_context.minimum_version = ssl.TLSVersion.TLSv1_2
3721        server_context.maximum_version = ssl.TLSVersion.TLSv1_2
3722        client_context.minimum_version = ssl.TLSVersion.TLSv1
3723        client_context.maximum_version = ssl.TLSVersion.TLSv1
3724        with ThreadedEchoServer(context=server_context) as server:
3725            with client_context.wrap_socket(socket.socket(),
3726                                            server_hostname=hostname) as s:
3727                with self.assertRaises(ssl.SSLError) as e:
3728                    s.connect((HOST, server.port))
3729                self.assertIn("alert", str(e.exception))
3730
3731
3732    @unittest.skipUnless(hasattr(ssl.SSLContext, 'minimum_version'),
3733                         "required OpenSSL 1.1.0g")
3734    @unittest.skipUnless(ssl.HAS_SSLv3, "requires SSLv3 support")
3735    def test_min_max_version_sslv3(self):
3736        client_context, server_context, hostname = testing_context()
3737        server_context.minimum_version = ssl.TLSVersion.SSLv3
3738        client_context.minimum_version = ssl.TLSVersion.SSLv3
3739        client_context.maximum_version = ssl.TLSVersion.SSLv3
3740        with ThreadedEchoServer(context=server_context) as server:
3741            with client_context.wrap_socket(socket.socket(),
3742                                            server_hostname=hostname) as s:
3743                s.connect((HOST, server.port))
3744                self.assertEqual(s.version(), 'SSLv3')
3745
3746    @unittest.skipUnless(ssl.HAS_ECDH, "test requires ECDH-enabled OpenSSL")
3747    def test_default_ecdh_curve(self):
3748        # Issue #21015: elliptic curve-based Diffie Hellman key exchange
3749        # should be enabled by default on SSL contexts.
3750        context = ssl.SSLContext(ssl.PROTOCOL_TLS)
3751        context.load_cert_chain(CERTFILE)
3752        # TLSv1.3 defaults to PFS key agreement and no longer has KEA in
3753        # cipher name.
3754        context.options |= ssl.OP_NO_TLSv1_3
3755        # Prior to OpenSSL 1.0.0, ECDH ciphers have to be enabled
3756        # explicitly using the 'ECCdraft' cipher alias.  Otherwise,
3757        # our default cipher list should prefer ECDH-based ciphers
3758        # automatically.
3759        if ssl.OPENSSL_VERSION_INFO < (1, 0, 0):
3760            context.set_ciphers("ECCdraft:ECDH")
3761        with ThreadedEchoServer(context=context) as server:
3762            with context.wrap_socket(socket.socket()) as s:
3763                s.connect((HOST, server.port))
3764                self.assertIn("ECDH", s.cipher()[0])
3765
3766    @unittest.skipUnless("tls-unique" in ssl.CHANNEL_BINDING_TYPES,
3767                         "'tls-unique' channel binding not available")
3768    def test_tls_unique_channel_binding(self):
3769        """Test tls-unique channel binding."""
3770        if support.verbose:
3771            sys.stdout.write("\n")
3772
3773        client_context, server_context, hostname = testing_context()
3774
3775        server = ThreadedEchoServer(context=server_context,
3776                                    chatty=True,
3777                                    connectionchatty=False)
3778
3779        with server:
3780            with client_context.wrap_socket(
3781                    socket.socket(),
3782                    server_hostname=hostname) as s:
3783                s.connect((HOST, server.port))
3784                # get the data
3785                cb_data = s.get_channel_binding("tls-unique")
3786                if support.verbose:
3787                    sys.stdout.write(
3788                        " got channel binding data: {0!r}\n".format(cb_data))
3789
3790                # check if it is sane
3791                self.assertIsNotNone(cb_data)
3792                if s.version() == 'TLSv1.3':
3793                    self.assertEqual(len(cb_data), 48)
3794                else:
3795                    self.assertEqual(len(cb_data), 12)  # True for TLSv1
3796
3797                # and compare with the peers version
3798                s.write(b"CB tls-unique\n")
3799                peer_data_repr = s.read().strip()
3800                self.assertEqual(peer_data_repr,
3801                                 repr(cb_data).encode("us-ascii"))
3802
3803            # now, again
3804            with client_context.wrap_socket(
3805                    socket.socket(),
3806                    server_hostname=hostname) as s:
3807                s.connect((HOST, server.port))
3808                new_cb_data = s.get_channel_binding("tls-unique")
3809                if support.verbose:
3810                    sys.stdout.write(
3811                        "got another channel binding data: {0!r}\n".format(
3812                            new_cb_data)
3813                    )
3814                # is it really unique
3815                self.assertNotEqual(cb_data, new_cb_data)
3816                self.assertIsNotNone(cb_data)
3817                if s.version() == 'TLSv1.3':
3818                    self.assertEqual(len(cb_data), 48)
3819                else:
3820                    self.assertEqual(len(cb_data), 12)  # True for TLSv1
3821                s.write(b"CB tls-unique\n")
3822                peer_data_repr = s.read().strip()
3823                self.assertEqual(peer_data_repr,
3824                                 repr(new_cb_data).encode("us-ascii"))
3825
3826    def test_compression(self):
3827        client_context, server_context, hostname = testing_context()
3828        stats = server_params_test(client_context, server_context,
3829                                   chatty=True, connectionchatty=True,
3830                                   sni_name=hostname)
3831        if support.verbose:
3832            sys.stdout.write(" got compression: {!r}\n".format(stats['compression']))
3833        self.assertIn(stats['compression'], { None, 'ZLIB', 'RLE' })
3834
3835    @unittest.skipUnless(hasattr(ssl, 'OP_NO_COMPRESSION'),
3836                         "ssl.OP_NO_COMPRESSION needed for this test")
3837    def test_compression_disabled(self):
3838        client_context, server_context, hostname = testing_context()
3839        client_context.options |= ssl.OP_NO_COMPRESSION
3840        server_context.options |= ssl.OP_NO_COMPRESSION
3841        stats = server_params_test(client_context, server_context,
3842                                   chatty=True, connectionchatty=True,
3843                                   sni_name=hostname)
3844        self.assertIs(stats['compression'], None)
3845
3846    def test_dh_params(self):
3847        # Check we can get a connection with ephemeral Diffie-Hellman
3848        client_context, server_context, hostname = testing_context()
3849        # test scenario needs TLS <= 1.2
3850        client_context.options |= ssl.OP_NO_TLSv1_3
3851        server_context.load_dh_params(DHFILE)
3852        server_context.set_ciphers("kEDH")
3853        server_context.options |= ssl.OP_NO_TLSv1_3
3854        stats = server_params_test(client_context, server_context,
3855                                   chatty=True, connectionchatty=True,
3856                                   sni_name=hostname)
3857        cipher = stats["cipher"][0]
3858        parts = cipher.split("-")
3859        if "ADH" not in parts and "EDH" not in parts and "DHE" not in parts:
3860            self.fail("Non-DH cipher: " + cipher[0])
3861
3862    @unittest.skipUnless(HAVE_SECP_CURVES, "needs secp384r1 curve support")
3863    @unittest.skipIf(IS_OPENSSL_1_1_1, "TODO: Test doesn't work on 1.1.1")
3864    def test_ecdh_curve(self):
3865        # server secp384r1, client auto
3866        client_context, server_context, hostname = testing_context()
3867
3868        server_context.set_ecdh_curve("secp384r1")
3869        server_context.set_ciphers("ECDHE:!eNULL:!aNULL")
3870        server_context.options |= ssl.OP_NO_TLSv1 | ssl.OP_NO_TLSv1_1
3871        stats = server_params_test(client_context, server_context,
3872                                   chatty=True, connectionchatty=True,
3873                                   sni_name=hostname)
3874
3875        # server auto, client secp384r1
3876        client_context, server_context, hostname = testing_context()
3877        client_context.set_ecdh_curve("secp384r1")
3878        server_context.set_ciphers("ECDHE:!eNULL:!aNULL")
3879        server_context.options |= ssl.OP_NO_TLSv1 | ssl.OP_NO_TLSv1_1
3880        stats = server_params_test(client_context, server_context,
3881                                   chatty=True, connectionchatty=True,
3882                                   sni_name=hostname)
3883
3884        # server / client curve mismatch
3885        client_context, server_context, hostname = testing_context()
3886        client_context.set_ecdh_curve("prime256v1")
3887        server_context.set_ecdh_curve("secp384r1")
3888        server_context.set_ciphers("ECDHE:!eNULL:!aNULL")
3889        server_context.options |= ssl.OP_NO_TLSv1 | ssl.OP_NO_TLSv1_1
3890        try:
3891            stats = server_params_test(client_context, server_context,
3892                                       chatty=True, connectionchatty=True,
3893                                       sni_name=hostname)
3894        except ssl.SSLError:
3895            pass
3896        else:
3897            # OpenSSL 1.0.2 does not fail although it should.
3898            if IS_OPENSSL_1_1_0:
3899                self.fail("mismatch curve did not fail")
3900
3901    def test_selected_alpn_protocol(self):
3902        # selected_alpn_protocol() is None unless ALPN is used.
3903        client_context, server_context, hostname = testing_context()
3904        stats = server_params_test(client_context, server_context,
3905                                   chatty=True, connectionchatty=True,
3906                                   sni_name=hostname)
3907        self.assertIs(stats['client_alpn_protocol'], None)
3908
3909    @unittest.skipUnless(ssl.HAS_ALPN, "ALPN support required")
3910    def test_selected_alpn_protocol_if_server_uses_alpn(self):
3911        # selected_alpn_protocol() is None unless ALPN is used by the client.
3912        client_context, server_context, hostname = testing_context()
3913        server_context.set_alpn_protocols(['foo', 'bar'])
3914        stats = server_params_test(client_context, server_context,
3915                                   chatty=True, connectionchatty=True,
3916                                   sni_name=hostname)
3917        self.assertIs(stats['client_alpn_protocol'], None)
3918
3919    @unittest.skipUnless(ssl.HAS_ALPN, "ALPN support needed for this test")
3920    def test_alpn_protocols(self):
3921        server_protocols = ['foo', 'bar', 'milkshake']
3922        protocol_tests = [
3923            (['foo', 'bar'], 'foo'),
3924            (['bar', 'foo'], 'foo'),
3925            (['milkshake'], 'milkshake'),
3926            (['http/3.0', 'http/4.0'], None)
3927        ]
3928        for client_protocols, expected in protocol_tests:
3929            client_context, server_context, hostname = testing_context()
3930            server_context.set_alpn_protocols(server_protocols)
3931            client_context.set_alpn_protocols(client_protocols)
3932
3933            try:
3934                stats = server_params_test(client_context,
3935                                           server_context,
3936                                           chatty=True,
3937                                           connectionchatty=True,
3938                                           sni_name=hostname)
3939            except ssl.SSLError as e:
3940                stats = e
3941
3942            if (expected is None and IS_OPENSSL_1_1_0
3943                    and ssl.OPENSSL_VERSION_INFO < (1, 1, 0, 6)):
3944                # OpenSSL 1.1.0 to 1.1.0e raises handshake error
3945                self.assertIsInstance(stats, ssl.SSLError)
3946            else:
3947                msg = "failed trying %s (s) and %s (c).\n" \
3948                    "was expecting %s, but got %%s from the %%s" \
3949                        % (str(server_protocols), str(client_protocols),
3950                            str(expected))
3951                client_result = stats['client_alpn_protocol']
3952                self.assertEqual(client_result, expected,
3953                                 msg % (client_result, "client"))
3954                server_result = stats['server_alpn_protocols'][-1] \
3955                    if len(stats['server_alpn_protocols']) else 'nothing'
3956                self.assertEqual(server_result, expected,
3957                                 msg % (server_result, "server"))
3958
3959    def test_selected_npn_protocol(self):
3960        # selected_npn_protocol() is None unless NPN is used
3961        client_context, server_context, hostname = testing_context()
3962        stats = server_params_test(client_context, server_context,
3963                                   chatty=True, connectionchatty=True,
3964                                   sni_name=hostname)
3965        self.assertIs(stats['client_npn_protocol'], None)
3966
3967    @unittest.skipUnless(ssl.HAS_NPN, "NPN support needed for this test")
3968    def test_npn_protocols(self):
3969        server_protocols = ['http/1.1', 'spdy/2']
3970        protocol_tests = [
3971            (['http/1.1', 'spdy/2'], 'http/1.1'),
3972            (['spdy/2', 'http/1.1'], 'http/1.1'),
3973            (['spdy/2', 'test'], 'spdy/2'),
3974            (['abc', 'def'], 'abc')
3975        ]
3976        for client_protocols, expected in protocol_tests:
3977            client_context, server_context, hostname = testing_context()
3978            server_context.set_npn_protocols(server_protocols)
3979            client_context.set_npn_protocols(client_protocols)
3980            stats = server_params_test(client_context, server_context,
3981                                       chatty=True, connectionchatty=True,
3982                                       sni_name=hostname)
3983            msg = "failed trying %s (s) and %s (c).\n" \
3984                  "was expecting %s, but got %%s from the %%s" \
3985                      % (str(server_protocols), str(client_protocols),
3986                         str(expected))
3987            client_result = stats['client_npn_protocol']
3988            self.assertEqual(client_result, expected, msg % (client_result, "client"))
3989            server_result = stats['server_npn_protocols'][-1] \
3990                if len(stats['server_npn_protocols']) else 'nothing'
3991            self.assertEqual(server_result, expected, msg % (server_result, "server"))
3992
3993    def sni_contexts(self):
3994        server_context = ssl.SSLContext(ssl.PROTOCOL_TLS_SERVER)
3995        server_context.load_cert_chain(SIGNED_CERTFILE)
3996        other_context = ssl.SSLContext(ssl.PROTOCOL_TLS_SERVER)
3997        other_context.load_cert_chain(SIGNED_CERTFILE2)
3998        client_context = ssl.SSLContext(ssl.PROTOCOL_TLS_CLIENT)
3999        client_context.load_verify_locations(SIGNING_CA)
4000        return server_context, other_context, client_context
4001
4002    def check_common_name(self, stats, name):
4003        cert = stats['peercert']
4004        self.assertIn((('commonName', name),), cert['subject'])
4005
4006    @needs_sni
4007    def test_sni_callback(self):
4008        calls = []
4009        server_context, other_context, client_context = self.sni_contexts()
4010
4011        client_context.check_hostname = False
4012
4013        def servername_cb(ssl_sock, server_name, initial_context):
4014            calls.append((server_name, initial_context))
4015            if server_name is not None:
4016                ssl_sock.context = other_context
4017        server_context.set_servername_callback(servername_cb)
4018
4019        stats = server_params_test(client_context, server_context,
4020                                   chatty=True,
4021                                   sni_name='supermessage')
4022        # The hostname was fetched properly, and the certificate was
4023        # changed for the connection.
4024        self.assertEqual(calls, [("supermessage", server_context)])
4025        # CERTFILE4 was selected
4026        self.check_common_name(stats, 'fakehostname')
4027
4028        calls = []
4029        # The callback is called with server_name=None
4030        stats = server_params_test(client_context, server_context,
4031                                   chatty=True,
4032                                   sni_name=None)
4033        self.assertEqual(calls, [(None, server_context)])
4034        self.check_common_name(stats, SIGNED_CERTFILE_HOSTNAME)
4035
4036        # Check disabling the callback
4037        calls = []
4038        server_context.set_servername_callback(None)
4039
4040        stats = server_params_test(client_context, server_context,
4041                                   chatty=True,
4042                                   sni_name='notfunny')
4043        # Certificate didn't change
4044        self.check_common_name(stats, SIGNED_CERTFILE_HOSTNAME)
4045        self.assertEqual(calls, [])
4046
4047    @needs_sni
4048    def test_sni_callback_alert(self):
4049        # Returning a TLS alert is reflected to the connecting client
4050        server_context, other_context, client_context = self.sni_contexts()
4051
4052        def cb_returning_alert(ssl_sock, server_name, initial_context):
4053            return ssl.ALERT_DESCRIPTION_ACCESS_DENIED
4054        server_context.set_servername_callback(cb_returning_alert)
4055        with self.assertRaises(ssl.SSLError) as cm:
4056            stats = server_params_test(client_context, server_context,
4057                                       chatty=False,
4058                                       sni_name='supermessage')
4059        self.assertEqual(cm.exception.reason, 'TLSV1_ALERT_ACCESS_DENIED')
4060
4061    @needs_sni
4062    def test_sni_callback_raising(self):
4063        # Raising fails the connection with a TLS handshake failure alert.
4064        server_context, other_context, client_context = self.sni_contexts()
4065
4066        def cb_raising(ssl_sock, server_name, initial_context):
4067            1/0
4068        server_context.set_servername_callback(cb_raising)
4069
4070        with self.assertRaises(ssl.SSLError) as cm, \
4071             support.captured_stderr() as stderr:
4072            stats = server_params_test(client_context, server_context,
4073                                       chatty=False,
4074                                       sni_name='supermessage')
4075        self.assertEqual(cm.exception.reason, 'SSLV3_ALERT_HANDSHAKE_FAILURE')
4076        self.assertIn("ZeroDivisionError", stderr.getvalue())
4077
4078    @needs_sni
4079    def test_sni_callback_wrong_return_type(self):
4080        # Returning the wrong return type terminates the TLS connection
4081        # with an internal error alert.
4082        server_context, other_context, client_context = self.sni_contexts()
4083
4084        def cb_wrong_return_type(ssl_sock, server_name, initial_context):
4085            return "foo"
4086        server_context.set_servername_callback(cb_wrong_return_type)
4087
4088        with self.assertRaises(ssl.SSLError) as cm, \
4089             support.captured_stderr() as stderr:
4090            stats = server_params_test(client_context, server_context,
4091                                       chatty=False,
4092                                       sni_name='supermessage')
4093        self.assertEqual(cm.exception.reason, 'TLSV1_ALERT_INTERNAL_ERROR')
4094        self.assertIn("TypeError", stderr.getvalue())
4095
4096    def test_shared_ciphers(self):
4097        client_context, server_context, hostname = testing_context()
4098        client_context.set_ciphers("AES128:AES256")
4099        server_context.set_ciphers("AES256")
4100        expected_algs = [
4101            "AES256", "AES-256",
4102            # TLS 1.3 ciphers are always enabled
4103            "TLS_CHACHA20", "TLS_AES",
4104        ]
4105
4106        stats = server_params_test(client_context, server_context,
4107                                   sni_name=hostname)
4108        ciphers = stats['server_shared_ciphers'][0]
4109        self.assertGreater(len(ciphers), 0)
4110        for name, tls_version, bits in ciphers:
4111            if not any(alg in name for alg in expected_algs):
4112                self.fail(name)
4113
4114    def test_read_write_after_close_raises_valuerror(self):
4115        client_context, server_context, hostname = testing_context()
4116        server = ThreadedEchoServer(context=server_context, chatty=False)
4117
4118        with server:
4119            s = client_context.wrap_socket(socket.socket(),
4120                                           server_hostname=hostname)
4121            s.connect((HOST, server.port))
4122            s.close()
4123
4124            self.assertRaises(ValueError, s.read, 1024)
4125            self.assertRaises(ValueError, s.write, b'hello')
4126
4127    def test_sendfile(self):
4128        TEST_DATA = b"x" * 512
4129        with open(support.TESTFN, 'wb') as f:
4130            f.write(TEST_DATA)
4131        self.addCleanup(support.unlink, support.TESTFN)
4132        context = ssl.SSLContext(ssl.PROTOCOL_TLS)
4133        context.verify_mode = ssl.CERT_REQUIRED
4134        context.load_verify_locations(SIGNING_CA)
4135        context.load_cert_chain(SIGNED_CERTFILE)
4136        server = ThreadedEchoServer(context=context, chatty=False)
4137        with server:
4138            with context.wrap_socket(socket.socket()) as s:
4139                s.connect((HOST, server.port))
4140                with open(support.TESTFN, 'rb') as file:
4141                    s.sendfile(file)
4142                    self.assertEqual(s.recv(1024), TEST_DATA)
4143
4144    def test_session(self):
4145        client_context, server_context, hostname = testing_context()
4146        # TODO: sessions aren't compatible with TLSv1.3 yet
4147        client_context.options |= ssl.OP_NO_TLSv1_3
4148
4149        # first connection without session
4150        stats = server_params_test(client_context, server_context,
4151                                   sni_name=hostname)
4152        session = stats['session']
4153        self.assertTrue(session.id)
4154        self.assertGreater(session.time, 0)
4155        self.assertGreater(session.timeout, 0)
4156        self.assertTrue(session.has_ticket)
4157        if ssl.OPENSSL_VERSION_INFO > (1, 0, 1):
4158            self.assertGreater(session.ticket_lifetime_hint, 0)
4159        self.assertFalse(stats['session_reused'])
4160        sess_stat = server_context.session_stats()
4161        self.assertEqual(sess_stat['accept'], 1)
4162        self.assertEqual(sess_stat['hits'], 0)
4163
4164        # reuse session
4165        stats = server_params_test(client_context, server_context,
4166                                   session=session, sni_name=hostname)
4167        sess_stat = server_context.session_stats()
4168        self.assertEqual(sess_stat['accept'], 2)
4169        self.assertEqual(sess_stat['hits'], 1)
4170        self.assertTrue(stats['session_reused'])
4171        session2 = stats['session']
4172        self.assertEqual(session2.id, session.id)
4173        self.assertEqual(session2, session)
4174        self.assertIsNot(session2, session)
4175        self.assertGreaterEqual(session2.time, session.time)
4176        self.assertGreaterEqual(session2.timeout, session.timeout)
4177
4178        # another one without session
4179        stats = server_params_test(client_context, server_context,
4180                                   sni_name=hostname)
4181        self.assertFalse(stats['session_reused'])
4182        session3 = stats['session']
4183        self.assertNotEqual(session3.id, session.id)
4184        self.assertNotEqual(session3, session)
4185        sess_stat = server_context.session_stats()
4186        self.assertEqual(sess_stat['accept'], 3)
4187        self.assertEqual(sess_stat['hits'], 1)
4188
4189        # reuse session again
4190        stats = server_params_test(client_context, server_context,
4191                                   session=session, sni_name=hostname)
4192        self.assertTrue(stats['session_reused'])
4193        session4 = stats['session']
4194        self.assertEqual(session4.id, session.id)
4195        self.assertEqual(session4, session)
4196        self.assertGreaterEqual(session4.time, session.time)
4197        self.assertGreaterEqual(session4.timeout, session.timeout)
4198        sess_stat = server_context.session_stats()
4199        self.assertEqual(sess_stat['accept'], 4)
4200        self.assertEqual(sess_stat['hits'], 2)
4201
4202    def test_session_handling(self):
4203        client_context, server_context, hostname = testing_context()
4204        client_context2, _, _ = testing_context()
4205
4206        # TODO: session reuse does not work with TLSv1.3
4207        client_context.options |= ssl.OP_NO_TLSv1_3
4208        client_context2.options |= ssl.OP_NO_TLSv1_3
4209
4210        server = ThreadedEchoServer(context=server_context, chatty=False)
4211        with server:
4212            with client_context.wrap_socket(socket.socket(),
4213                                            server_hostname=hostname) as s:
4214                # session is None before handshake
4215                self.assertEqual(s.session, None)
4216                self.assertEqual(s.session_reused, None)
4217                s.connect((HOST, server.port))
4218                session = s.session
4219                self.assertTrue(session)
4220                with self.assertRaises(TypeError) as e:
4221                    s.session = object
4222                self.assertEqual(str(e.exception), 'Value is not a SSLSession.')
4223
4224            with client_context.wrap_socket(socket.socket(),
4225                                            server_hostname=hostname) as s:
4226                s.connect((HOST, server.port))
4227                # cannot set session after handshake
4228                with self.assertRaises(ValueError) as e:
4229                    s.session = session
4230                self.assertEqual(str(e.exception),
4231                                 'Cannot set session after handshake.')
4232
4233            with client_context.wrap_socket(socket.socket(),
4234                                            server_hostname=hostname) as s:
4235                # can set session before handshake and before the
4236                # connection was established
4237                s.session = session
4238                s.connect((HOST, server.port))
4239                self.assertEqual(s.session.id, session.id)
4240                self.assertEqual(s.session, session)
4241                self.assertEqual(s.session_reused, True)
4242
4243            with client_context2.wrap_socket(socket.socket(),
4244                                             server_hostname=hostname) as s:
4245                # cannot re-use session with a different SSLContext
4246                with self.assertRaises(ValueError) as e:
4247                    s.session = session
4248                    s.connect((HOST, server.port))
4249                self.assertEqual(str(e.exception),
4250                                 'Session refers to a different SSLContext.')
4251
4252
@unittest.skipUnless(ssl.HAS_TLSv1_3, "Test needs TLS 1.3")
class TestPostHandshakeAuth(unittest.TestCase):
    # Tests for TLS 1.3 post-handshake client authentication (PHA).
    #
    # The ThreadedEchoServer understands a few commands:
    # - 'PHA' asks the server to request a client certificate,
    # - 'HASCERT' reports whether one is available,
    # - 'GETCERT' returns it.
    # The tests drive these commands from the client side.

    def test_pha_setter(self):
        # post_handshake_auth defaults to False. It can be toggled
        # independently of verify_mode, and changing one does not reset
        # the other.
        protocols = [
            ssl.PROTOCOL_TLS, ssl.PROTOCOL_TLS_SERVER, ssl.PROTOCOL_TLS_CLIENT
        ]
        for protocol in protocols:
            ctx = ssl.SSLContext(protocol)
            self.assertEqual(ctx.post_handshake_auth, False)

            ctx.post_handshake_auth = True
            self.assertEqual(ctx.post_handshake_auth, True)

            ctx.verify_mode = ssl.CERT_REQUIRED
            self.assertEqual(ctx.verify_mode, ssl.CERT_REQUIRED)
            self.assertEqual(ctx.post_handshake_auth, True)

            ctx.post_handshake_auth = False
            self.assertEqual(ctx.verify_mode, ssl.CERT_REQUIRED)
            self.assertEqual(ctx.post_handshake_auth, False)

            ctx.verify_mode = ssl.CERT_OPTIONAL
            ctx.post_handshake_auth = True
            self.assertEqual(ctx.verify_mode, ssl.CERT_OPTIONAL)
            self.assertEqual(ctx.post_handshake_auth, True)

    def test_pha_required(self):
        # With PHA enabled on both sides, the client certificate only
        # becomes available after the server requests it via 'PHA'.
        client_context, server_context, hostname = testing_context()
        server_context.post_handshake_auth = True
        server_context.verify_mode = ssl.CERT_REQUIRED
        client_context.post_handshake_auth = True
        client_context.load_cert_chain(SIGNED_CERTFILE)

        server = ThreadedEchoServer(context=server_context, chatty=False)
        with server:
            with client_context.wrap_socket(socket.socket(),
                                            server_hostname=hostname) as s:
                s.connect((HOST, server.port))
                s.write(b'HASCERT')
                self.assertEqual(s.recv(1024), b'FALSE\n')
                s.write(b'PHA')
                self.assertEqual(s.recv(1024), b'OK\n')
                s.write(b'HASCERT')
                self.assertEqual(s.recv(1024), b'TRUE\n')
                # PHA method just returns true when cert is already available
                s.write(b'PHA')
                self.assertEqual(s.recv(1024), b'OK\n')
                s.write(b'GETCERT')
                cert_text = s.recv(4096).decode('us-ascii')
                self.assertIn('Python Software Foundation CA', cert_text)

    def test_pha_required_nocert(self):
        # With CERT_REQUIRED, a client that has no certificate to send is
        # rejected with a "certificate required" alert.
        client_context, server_context, hostname = testing_context()
        server_context.post_handshake_auth = True
        server_context.verify_mode = ssl.CERT_REQUIRED
        client_context.post_handshake_auth = True

        server = ThreadedEchoServer(context=server_context, chatty=False)
        with server:
            with client_context.wrap_socket(socket.socket(),
                                            server_hostname=hostname) as s:
                s.connect((HOST, server.port))
                s.write(b'PHA')
                # receive CertificateRequest
                self.assertEqual(s.recv(1024), b'OK\n')
                # send empty Certificate + Finish
                s.write(b'HASCERT')
                # receive alert
                with self.assertRaisesRegex(
                        ssl.SSLError,
                        'tlsv13 alert certificate required'):
                    s.recv(1024)

    def test_pha_optional(self):
        # CERT_OPTIONAL with a client certificate behaves like the
        # required case: the cert becomes available after 'PHA'.
        if support.verbose:
            sys.stdout.write("\n")

        client_context, server_context, hostname = testing_context()
        server_context.post_handshake_auth = True
        client_context.post_handshake_auth = True
        client_context.load_cert_chain(SIGNED_CERTFILE)

        # check CERT_OPTIONAL
        server_context.verify_mode = ssl.CERT_OPTIONAL
        server = ThreadedEchoServer(context=server_context, chatty=False)
        with server:
            with client_context.wrap_socket(socket.socket(),
                                            server_hostname=hostname) as s:
                s.connect((HOST, server.port))
                s.write(b'HASCERT')
                self.assertEqual(s.recv(1024), b'FALSE\n')
                s.write(b'PHA')
                self.assertEqual(s.recv(1024), b'OK\n')
                s.write(b'HASCERT')
                self.assertEqual(s.recv(1024), b'TRUE\n')

    def test_pha_optional_nocert(self):
        # CERT_OPTIONAL without a client certificate: the 'PHA' request
        # succeeds, but no certificate becomes available.
        if support.verbose:
            sys.stdout.write("\n")

        client_context, server_context, hostname = testing_context()
        server_context.post_handshake_auth = True
        server_context.verify_mode = ssl.CERT_OPTIONAL
        client_context.post_handshake_auth = True

        server = ThreadedEchoServer(context=server_context, chatty=False)
        with server:
            with client_context.wrap_socket(socket.socket(),
                                            server_hostname=hostname) as s:
                s.connect((HOST, server.port))
                s.write(b'HASCERT')
                self.assertEqual(s.recv(1024), b'FALSE\n')
                s.write(b'PHA')
                self.assertEqual(s.recv(1024), b'OK\n')
                # optional doesn't fail when client does not have a cert
                s.write(b'HASCERT')
                self.assertEqual(s.recv(1024), b'FALSE\n')

    def test_pha_no_pha_client(self):
        # Server-side PHA requests fail when the client did not enable PHA.
        # The client itself cannot initiate verify_client_post_handshake().
        client_context, server_context, hostname = testing_context()
        server_context.post_handshake_auth = True
        server_context.verify_mode = ssl.CERT_REQUIRED
        client_context.load_cert_chain(SIGNED_CERTFILE)

        server = ThreadedEchoServer(context=server_context, chatty=False)
        with server:
            with client_context.wrap_socket(socket.socket(),
                                            server_hostname=hostname) as s:
                s.connect((HOST, server.port))
                with self.assertRaisesRegex(ssl.SSLError, 'not server'):
                    s.verify_client_post_handshake()
                s.write(b'PHA')
                self.assertIn(b'extension not received', s.recv(1024))

    def test_pha_no_pha_server(self):
        # server doesn't have PHA enabled, cert is requested in handshake
        client_context, server_context, hostname = testing_context()
        server_context.verify_mode = ssl.CERT_REQUIRED
        client_context.post_handshake_auth = True
        client_context.load_cert_chain(SIGNED_CERTFILE)

        server = ThreadedEchoServer(context=server_context, chatty=False)
        with server:
            with client_context.wrap_socket(socket.socket(),
                                            server_hostname=hostname) as s:
                s.connect((HOST, server.port))
                s.write(b'HASCERT')
                self.assertEqual(s.recv(1024), b'TRUE\n')
                # PHA doesn't fail if there is already a cert
                s.write(b'PHA')
                self.assertEqual(s.recv(1024), b'OK\n')
                s.write(b'HASCERT')
                self.assertEqual(s.recv(1024), b'TRUE\n')

    def test_pha_not_tls13(self):
        # TLS 1.2: the client caps the protocol version, so a post-handshake
        # request on the server side reports WRONG_SSL_VERSION.
        client_context, server_context, hostname = testing_context()
        server_context.verify_mode = ssl.CERT_REQUIRED
        client_context.maximum_version = ssl.TLSVersion.TLSv1_2
        client_context.post_handshake_auth = True
        client_context.load_cert_chain(SIGNED_CERTFILE)

        server = ThreadedEchoServer(context=server_context, chatty=False)
        with server:
            with client_context.wrap_socket(socket.socket(),
                                            server_hostname=hostname) as s:
                s.connect((HOST, server.port))
                # PHA fails for TLS != 1.3
                s.write(b'PHA')
                self.assertIn(b'WRONG_SSL_VERSION', s.recv(1024))
4424
4425
def test_main(verbose=False):
    """Run the SSL test suite through ``support.run_unittest()``.

    When ``support.verbose`` is set, this first prints:
    - the OpenSSL version,
    - a best-effort platform description,
    - a few ssl module flags.

    It then checks that the bundled certificate/key files exist, raising
    ``support.TestFailed`` if one is missing. Finally it runs the test
    classes between ``support.threading_setup()`` and
    ``support.threading_cleanup()``. NetworkedTests is only included when
    the 'network' resource is enabled.

    The *verbose* argument is accepted for interface compatibility but is
    unused; verbosity is taken from ``support.verbose``.
    """
    if support.verbose:
        import warnings
        plats = {}
        # platform.linux_distribution() is deprecated and was removed in
        # Python 3.8, so only consult it when it still exists.
        linux_distribution = getattr(platform, 'linux_distribution', None)
        if linux_distribution is not None:
            plats['Linux'] = linux_distribution
        plats['Mac'] = platform.mac_ver
        plats['Windows'] = platform.win32_ver
        with warnings.catch_warnings():
            warnings.filterwarnings(
                'ignore',
                r'dist\(\) and linux_distribution\(\) '
                'functions are deprecated .*',
                DeprecationWarning,
            )
            # Use the first platform-specific description that is non-empty,
            # falling back to the generic platform string.
            for name, func in plats.items():
                plat = func()
                if plat and plat[0]:
                    plat = '%s %r' % (name, plat)
                    break
            else:
                plat = repr(platform.platform())
        print("test_ssl: testing with %r %r" %
            (ssl.OPENSSL_VERSION, ssl.OPENSSL_VERSION_INFO))
        print("          under %s" % plat)
        print("          HAS_SNI = %r" % ssl.HAS_SNI)
        # %08x zero-pads to eight hex digits after the "0x" prefix; plain
        # %8x would pad with spaces.
        print("          OP_ALL = 0x%08x" % ssl.OP_ALL)
        try:
            print("          OP_NO_TLSv1_1 = 0x%08x" % ssl.OP_NO_TLSv1_1)
        except AttributeError:
            pass

    for filename in [
        CERTFILE, BYTES_CERTFILE,
        ONLYCERT, ONLYKEY, BYTES_ONLYCERT, BYTES_ONLYKEY,
        SIGNED_CERTFILE, SIGNED_CERTFILE2, SIGNING_CA,
        BADCERT, BADKEY, EMPTYCERT]:
        if not os.path.exists(filename):
            raise support.TestFailed("Can't read certificate file %r" % filename)

    tests = [
        ContextTests, BasicSocketTests, SSLErrorTests, MemoryBIOTests,
        SSLObjectTests, SimpleBackgroundTests, ThreadedTests,
        TestPostHandshakeAuth
    ]

    if support.is_resource_enabled('network'):
        tests.append(NetworkedTests)

    thread_info = support.threading_setup()
    try:
        support.run_unittest(*tests)
    finally:
        support.threading_cleanup(*thread_info)
4480
if __name__ == '__main__':
    # Allow running this file directly as a script.
    test_main()
4483