1# This file is dual licensed under the terms of the Apache License, Version
2# 2.0, and the BSD License. See the LICENSE file in the root of this repository
3# for complete details.
4
5from __future__ import absolute_import, division, print_function
6
7import os
8
9import pytest
10
11from cryptography import x509
12from cryptography.hazmat.backends.interfaces import DERSerializationBackend
13from cryptography.hazmat.primitives.serialization import load_pem_private_key
14from cryptography.hazmat.primitives.serialization.pkcs12 import (
15    load_key_and_certificates
16)
17
18from .utils import load_vectors_from_file
19
20
@pytest.mark.requires_backend_interface(interface=DERSerializationBackend)
class TestPKCS12(object):
    """Tests for ``load_key_and_certificates`` against PKCS#12 test vectors."""

    @staticmethod
    def _load_ca_cert(backend):
        """Load the PEM CA certificate expected inside the PKCS#12 vectors."""
        return load_vectors_from_file(
            os.path.join("x509", "custom", "ca", "ca.pem"),
            lambda pemfile: x509.load_pem_x509_certificate(
                pemfile.read(), backend
            ), mode="rb"
        )

    @staticmethod
    def _load_ca_key(backend):
        """Load the unencrypted PEM CA private key expected in the vectors."""
        return load_vectors_from_file(
            os.path.join("x509", "custom", "ca", "ca_key.pem"),
            lambda pemfile: load_pem_private_key(
                pemfile.read(), None, backend
            ), mode="rb"
        )

    @staticmethod
    def _read_pkcs12_bytes(filename):
        """Return the raw DER bytes of the vector ``pkcs12/<filename>``."""
        return load_vectors_from_file(
            os.path.join("pkcs12", filename),
            lambda derfile: derfile.read(), mode="rb"
        )

    @staticmethod
    def _load_pkcs12(filename, password, backend):
        """Parse ``pkcs12/<filename>`` with *password*.

        Returns the ``(key, cert, additional_certs)`` tuple produced by
        ``load_key_and_certificates``.
        """
        return load_vectors_from_file(
            os.path.join("pkcs12", filename),
            lambda derfile: load_key_and_certificates(
                derfile.read(), password, backend
            ), mode="rb"
        )

    @pytest.mark.parametrize(
        ("filename", "password"),
        [
            ("cert-key-aes256cbc.p12", b"cryptography"),
            ("cert-none-key-none.p12", b"cryptography"),
            ("cert-rc2-key-3des.p12", b"cryptography"),
            ("no-password.p12", None),
        ]
    )
    def test_load_pkcs12_ec_keys(self, filename, password, backend):
        """Each vector yields the CA key and cert and no extra certs."""
        cert = self._load_ca_cert(backend)
        key = self._load_ca_key(backend)
        parsed_key, parsed_cert, parsed_more_certs = self._load_pkcs12(
            filename, password, backend
        )
        assert parsed_cert == cert
        assert parsed_key.private_numbers() == key.private_numbers()
        assert parsed_more_certs == []

    def test_load_pkcs12_cert_only(self, backend):
        """A key-less bundle exposes its certificate via additional certs."""
        cert = self._load_ca_cert(backend)
        parsed_key, parsed_cert, parsed_more_certs = self._load_pkcs12(
            "cert-aes256cbc-no-key.p12", b"cryptography", backend
        )
        assert parsed_cert is None
        assert parsed_key is None
        assert parsed_more_certs == [cert]

    def test_load_pkcs12_key_only(self, backend):
        """A certificate-less bundle yields only the private key."""
        key = self._load_ca_key(backend)
        parsed_key, parsed_cert, parsed_more_certs = self._load_pkcs12(
            "no-cert-key-aes256cbc.p12", b"cryptography", backend
        )
        assert parsed_key.private_numbers() == key.private_numbers()
        assert parsed_cert is None
        assert parsed_more_certs == []

    def test_non_bytes(self, backend):
        """A password that is neither bytes nor None raises TypeError."""
        with pytest.raises(TypeError):
            load_key_and_certificates(
                b"irrelevant", object(), backend
            )

    def test_not_a_pkcs12(self, backend):
        """Data that is not a PKCS#12 structure raises ValueError."""
        with pytest.raises(ValueError):
            load_key_and_certificates(
                b"invalid", b"pass", backend
            )

    def test_invalid_password(self, backend):
        """A wrong password raises ValueError."""
        # Read the vector outside ``pytest.raises`` so a failure while
        # loading the file cannot be mistaken for the expected error.
        data = self._read_pkcs12_bytes("cert-key-aes256cbc.p12")
        with pytest.raises(ValueError):
            load_key_and_certificates(data, b"invalid", backend)

    def test_buffer_protocol(self, backend):
        """bytearray inputs (data and password) are accepted like bytes."""
        cert = self._load_ca_cert(backend)
        key = self._load_ca_key(backend)
        p12buffer = bytearray(
            self._read_pkcs12_bytes("cert-key-aes256cbc.p12")
        )
        parsed_key, parsed_cert, parsed_more_certs = load_key_and_certificates(
            p12buffer, bytearray(b"cryptography"), backend
        )
        # This vector is the same one checked in test_load_pkcs12_ec_keys,
        # so compare contents rather than merely checking for non-None.
        assert parsed_cert == cert
        assert parsed_key.private_numbers() == key.private_numbers()
        assert parsed_more_certs == []
124