1# This file is dual licensed under the terms of the Apache License, Version
2# 2.0, and the BSD License. See the LICENSE file in the root of this repository
3# for complete details.
4
5from __future__ import absolute_import, division, print_function
6
7import binascii
8import itertools
9import os
10
11import pytest
12
13from cryptography.exceptions import (
14    AlreadyFinalized, AlreadyUpdated, InvalidSignature, InvalidTag,
15    NotYetFinalized
16)
17from cryptography.hazmat.primitives import hashes, hmac
18from cryptography.hazmat.primitives.asymmetric import rsa
19from cryptography.hazmat.primitives.ciphers import Cipher
20from cryptography.hazmat.primitives.kdf.hkdf import HKDF, HKDFExpand
21from cryptography.hazmat.primitives.kdf.kbkdf import (
22    CounterLocation, KBKDFHMAC, Mode
23)
24from cryptography.hazmat.primitives.kdf.pbkdf2 import PBKDF2HMAC
25
26from ...utils import load_vectors_from_file
27
28
def _load_all_params(path, file_names, param_loader):
    """Load the vectors of several files and return them as one flat list.

    Every entry of ``file_names`` is joined onto ``path`` and passed, together
    with ``param_loader``, to ``load_vectors_from_file``. The resulting
    vectors are concatenated in file order.
    """
    return [
        vector
        for file_name in file_names
        for vector in load_vectors_from_file(
            os.path.join(path, file_name), param_loader
        )
    ]
36
37
def generate_encrypt_test(param_loader, path, file_names, cipher_factory,
                          mode_factory):
    """Build a parametrized test that runs ``encrypt_test`` per vector.

    Vectors are read from ``file_names`` under ``path`` with
    ``param_loader``. The returned function takes ``self`` so it can be
    attached to a test class; ``backend`` is presumably a pytest fixture.
    """
    vectors = _load_all_params(path, file_names, param_loader)
    parametrize = pytest.mark.parametrize("params", vectors)

    def test_encryption(self, backend, params):
        encrypt_test(backend, cipher_factory, mode_factory, params)

    return parametrize(test_encryption)
47
48
def encrypt_test(backend, cipher_factory, mode_factory, params):
    """Round-trip one known-answer vector through a block cipher mode.

    ``cipher_factory`` and ``mode_factory`` receive the whole vector as
    keyword arguments. The backend must report the combination as supported,
    encrypting the hex ``plaintext`` must give the hex ``ciphertext``, and
    decrypting that ciphertext must give the plaintext back.
    """
    assert backend.cipher_supported(
        cipher_factory(**params), mode_factory(**params)
    )

    pt_hex = params["plaintext"]
    ct_hex = params["ciphertext"]
    cipher = Cipher(
        cipher_factory(**params), mode_factory(**params), backend=backend
    )

    enc = cipher.encryptor()
    produced = enc.update(binascii.unhexlify(pt_hex)) + enc.finalize()
    assert produced == binascii.unhexlify(ct_hex)

    dec = cipher.decryptor()
    recovered = dec.update(binascii.unhexlify(ct_hex)) + dec.finalize()
    assert recovered == binascii.unhexlify(pt_hex)
69
70
def generate_aead_test(param_loader, path, file_names, cipher_factory,
                       mode_factory):
    """Build a parametrized test that runs ``aead_test`` per vector.

    Vectors are read from ``file_names`` under ``path`` with
    ``param_loader``. The returned function takes ``self`` so it can be
    attached to a test class; ``backend`` is presumably a pytest fixture.
    """
    vectors = _load_all_params(path, file_names, param_loader)
    parametrize = pytest.mark.parametrize("params", vectors)

    def test_aead(self, backend, params):
        aead_test(backend, cipher_factory, mode_factory, params)

    return parametrize(test_aead)
80
81
def aead_test(backend, cipher_factory, mode_factory, params):
    """Run one AEAD known-answer vector against ``backend``.

    ``params`` holds hex-encoded ``key``, ``iv``, ``tag``, ``aad`` and ``ct``.

    * If ``params["fail"]`` is ``True``, the vector is a negative one.
      Decrypting ``ct`` with the given tag must raise ``InvalidTag`` on
      ``finalize()``.
    * Otherwise ``params["pt"]`` (hex) is required. Encryption must reproduce
      both ``ct`` and the (possibly truncated) tag. Decryption with that tag
      must recover the plaintext.

    ``cipher_factory`` is called with the raw key. ``mode_factory`` is called
    with the raw IV, a tag (``None`` when encrypting) and, optionally, the
    minimum tag length.
    """
    key = binascii.unhexlify(params["key"])
    iv = binascii.unhexlify(params["iv"])
    tag = binascii.unhexlify(params["tag"])
    aad = binascii.unhexlify(params["aad"])
    ciphertext = binascii.unhexlify(params["ct"])

    if params.get("fail") is True:
        # Negative vector: only decryption applies, and it must be rejected
        # when the tag is checked in finalize().
        cipher = Cipher(
            cipher_factory(key),
            mode_factory(iv, tag, len(tag)),
            backend
        )
        decryptor = cipher.decryptor()
        decryptor.authenticate_additional_data(aad)
        decryptor.update(ciphertext)
        with pytest.raises(InvalidTag):
            decryptor.finalize()
        return

    # A positive vector must carry a plaintext. Read it here so that a
    # missing "pt" fails with a clear KeyError rather than an unbound
    # local variable further down.
    plaintext = binascii.unhexlify(params["pt"])
    tag_len = len(tag)

    cipher = Cipher(cipher_factory(key), mode_factory(iv, None), backend)
    encryptor = cipher.encryptor()
    encryptor.authenticate_additional_data(aad)
    actual_ciphertext = encryptor.update(plaintext)
    actual_ciphertext += encryptor.finalize()
    assert actual_ciphertext == ciphertext
    # The vector's tag may be shorter than the full tag, so compare a prefix.
    assert encryptor.tag[:tag_len] == tag

    cipher = Cipher(
        cipher_factory(key),
        mode_factory(iv, tag, min_tag_length=tag_len),
        backend
    )
    decryptor = cipher.decryptor()
    decryptor.authenticate_additional_data(aad)
    actual_plaintext = decryptor.update(ciphertext)
    actual_plaintext += decryptor.finalize()
    assert actual_plaintext == plaintext
124
125
def generate_stream_encryption_test(param_loader, path, file_names,
                                    cipher_factory):
    """Build a parametrized test that runs ``stream_encryption_test``.

    One case is generated per vector loaded from ``file_names`` under
    ``path`` via ``param_loader``. The returned function takes ``self`` so it
    can be attached to a test class.
    """
    vectors = _load_all_params(path, file_names, param_loader)
    parametrize = pytest.mark.parametrize("params", vectors)

    def test_stream_encryption(self, backend, params):
        stream_encryption_test(backend, cipher_factory, params)

    return parametrize(test_stream_encryption)
134
135
def stream_encryption_test(backend, cipher_factory, params):
    """Check a stream cipher against a vector that starts at an offset.

    ``int(params["offset"])`` zero bytes are fed to both the encryptor and
    the decryptor first, and their output is discarded. Only then is the
    hex-encoded ``plaintext``/``ciphertext`` pair processed and compared.
    """
    pt_hex = params["plaintext"]
    ct_hex = params["ciphertext"]
    offset = params["offset"]
    cipher = Cipher(cipher_factory(**params), None, backend=backend)

    encryptor = cipher.encryptor()
    # Skip ahead by `offset` bytes; the output of this update is unused.
    filler = b"\x00" * int(offset)
    encryptor.update(filler)
    produced = encryptor.update(binascii.unhexlify(pt_hex))
    produced += encryptor.finalize()
    assert produced == binascii.unhexlify(ct_hex)

    decryptor = cipher.decryptor()
    decryptor.update(filler)
    recovered = decryptor.update(binascii.unhexlify(ct_hex))
    recovered += decryptor.finalize()
    assert recovered == binascii.unhexlify(pt_hex)
152
153
def generate_hash_test(param_loader, path, file_names, hash_cls):
    """Build a parametrized test that runs ``hash_test`` per vector.

    ``hash_cls`` is forwarded unchanged as the algorithm argument of
    ``hash_test``.
    """
    vectors = _load_all_params(path, file_names, param_loader)
    parametrize = pytest.mark.parametrize("params", vectors)

    def test_hash(self, backend, params):
        hash_test(backend, hash_cls, params)

    return parametrize(test_hash)
161
162
def hash_test(backend, algorithm, params):
    """Hash one message and compare it with the expected digest.

    ``params`` is a ``(message_hex, digest_hex)`` pair. Spaces in the
    expected digest are removed and it is lower-cased before decoding.
    """
    message_hex, digest_hex = params
    hasher = hashes.Hash(algorithm, backend=backend)
    hasher.update(binascii.unhexlify(message_hex))
    normalized = digest_hex.replace(" ", "").lower().encode("ascii")
    digest = hasher.finalize()
    assert digest == binascii.unhexlify(normalized)
169
170
def generate_base_hash_test(algorithm, digest_size):
    """Return a test method that runs ``base_hash_test`` for ``algorithm``.

    ``digest_size`` is the value the hash's ``digest_size`` must report.
    """
    def test_base_hash(self, backend):
        base_hash_test(backend, algorithm, digest_size)

    return test_base_hash
175
176
def base_hash_test(backend, algorithm, digest_size):
    """Check ``digest_size`` and the semantics of ``Hash.copy()``.

    A copy must be a different object with a different context. Copying
    after some input and then feeding the same suffix to both objects must
    give equal digests.
    """
    original = hashes.Hash(algorithm, backend=backend)
    assert original.algorithm.digest_size == digest_size

    duplicate = original.copy()
    assert original != duplicate
    assert original._ctx != duplicate._ctx

    original.update(b"abc")
    snapshot = original.copy()
    snapshot.update(b"123")
    original.update(b"123")
    assert snapshot.finalize() == original.finalize()
189
190
def generate_base_hmac_test(hash_cls):
    """Return a test method that runs ``base_hmac_test`` for ``hash_cls``."""
    def test_base_hmac(self, backend):
        base_hmac_test(backend, hash_cls)

    return test_base_hmac
195
196
def base_hmac_test(backend, algorithm):
    """Check that ``HMAC.copy()`` yields a distinct object and context.

    The key is the single byte decoded from the hex string ``b"ab"``.
    """
    key = binascii.unhexlify(b"ab")
    mac = hmac.HMAC(key, algorithm, backend=backend)
    clone = mac.copy()
    assert mac != clone
    assert mac._ctx != clone._ctx
203
204
def generate_hmac_test(param_loader, path, file_names, algorithm):
    """Build a parametrized test that runs ``hmac_test`` per vector.

    Vectors are read from ``file_names`` under ``path`` with
    ``param_loader`` and checked with ``algorithm``.
    """
    vectors = _load_all_params(path, file_names, param_loader)
    parametrize = pytest.mark.parametrize("params", vectors)

    def test_hmac(self, backend, params):
        hmac_test(backend, algorithm, params)

    return parametrize(test_hmac)
212
213
def hmac_test(backend, algorithm, params):
    """Compute one HMAC and compare it with the expected value.

    ``params`` is a ``(message_hex, mac_hex, key_hex)`` triple. ``mac_hex``
    is a ``str`` that is ASCII-encoded before being decoded from hex.
    """
    message_hex, mac_hex, key_hex = params
    mac = hmac.HMAC(binascii.unhexlify(key_hex), algorithm, backend=backend)
    mac.update(binascii.unhexlify(message_hex))
    actual = mac.finalize()
    assert actual == binascii.unhexlify(mac_hex.encode("ascii"))
219
220
def generate_pbkdf2_test(param_loader, path, file_names, algorithm):
    """Build a parametrized test that runs ``pbkdf2_test`` per vector.

    Vectors are read from ``file_names`` under ``path`` with
    ``param_loader`` and checked with ``algorithm``.
    """
    vectors = _load_all_params(path, file_names, param_loader)
    parametrize = pytest.mark.parametrize("params", vectors)

    def test_pbkdf2(self, backend, params):
        pbkdf2_test(backend, algorithm, params)

    return parametrize(test_pbkdf2)
228
229
def pbkdf2_test(backend, algorithm, params):
    """Derive a key with PBKDF2-HMAC and compare it with the expected value.

    ``params`` supplies ``password`` and ``salt`` (used as-is), and
    ``length`` and ``iterations`` (converted with ``int``). It also supplies
    ``derived_key``, which is compared against ``binascii.hexlify`` of the
    derived key.
    """
    # Password and salt can contain NUL bytes, which the vector files write
    # as a literal "\0". This helper passes both values through unchanged,
    # so any such conversion must already have been done when the vectors
    # were loaded. NOTE(review): confirm the loader performs it.
    kdf = PBKDF2HMAC(
        algorithm,
        int(params["length"]),
        params["salt"],
        int(params["iterations"]),
        backend
    )
    derived_key = kdf.derive(params["password"])
    assert binascii.hexlify(derived_key) == params["derived_key"]
243
244
def generate_aead_exception_test(cipher_factory, mode_factory):
    """Return a test method that runs ``aead_exception_test``."""
    def test_aead_exception(self, backend):
        aead_exception_test(backend, cipher_factory, mode_factory)

    return test_aead_exception
249
250
def aead_exception_test(backend, cipher_factory, mode_factory):
    """Check the call-order errors of AEAD encryptors and decryptors.

    Uses an all-zero 16-byte key and 12-byte IV. For the encryptor:

    * ``tag`` raises ``NotYetFinalized`` before ``finalize()``;
    * adding AAD after ``update()`` raises ``AlreadyUpdated``;
    * after ``finalize()``, adding AAD, ``update()`` and ``finalize()`` all
      raise ``AlreadyFinalized``.

    A decryptor built from a mode that was given a tag must not expose a
    ``tag`` attribute.
    """
    key = binascii.unhexlify(b"0" * 32)
    iv = binascii.unhexlify(b"0" * 24)

    enc_cipher = Cipher(cipher_factory(key), mode_factory(iv), backend)
    encryptor = enc_cipher.encryptor()
    encryptor.update(b"a" * 16)
    with pytest.raises(NotYetFinalized):
        encryptor.tag
    with pytest.raises(AlreadyUpdated):
        encryptor.authenticate_additional_data(b"b" * 16)

    encryptor.finalize()
    # Once finalized, every further call must be refused, in this order.
    post_finalize_calls = (
        lambda: encryptor.authenticate_additional_data(b"b" * 16),
        lambda: encryptor.update(b"b" * 16),
        lambda: encryptor.finalize(),
    )
    for call in post_finalize_calls:
        with pytest.raises(AlreadyFinalized):
            call()

    dec_cipher = Cipher(
        cipher_factory(key), mode_factory(iv, b"0" * 16), backend
    )
    decryptor = dec_cipher.decryptor()
    decryptor.update(b"a" * 16)
    with pytest.raises(AttributeError):
        decryptor.tag
279
280
def generate_aead_tag_exception_test(cipher_factory, mode_factory):
    """Return a test method that runs ``aead_tag_exception_test``."""
    def test_aead_tag_exception(self, backend):
        aead_tag_exception_test(backend, cipher_factory, mode_factory)

    return test_aead_tag_exception
285
286
def aead_tag_exception_test(backend, cipher_factory, mode_factory):
    """Check that invalid AEAD tag configurations raise ``ValueError``.

    Uses an all-zero 16-byte key and 12-byte IV, and checks that:

    * constructing a cipher whose mode has no tag succeeds;
    * a 3-byte tag is rejected by ``mode_factory``;
    * a 6-byte tag combined with ``min_tag_length=2`` is rejected
      (presumably because 2 is below the mode's minimum; confirm);
    * requesting an encryptor from a mode that was given a tag fails.
    """
    key = binascii.unhexlify(b"0" * 32)
    iv = binascii.unhexlify(b"0" * 24)

    # Sanity check only: construction without a tag must not raise. The
    # object itself is not needed afterwards.
    Cipher(cipher_factory(key), mode_factory(iv), backend)

    with pytest.raises(ValueError):
        mode_factory(iv, b"000")

    with pytest.raises(ValueError):
        mode_factory(iv, b"000000", 2)

    cipher = Cipher(
        cipher_factory(key),
        mode_factory(iv, b"0" * 16),
        backend
    )
    with pytest.raises(ValueError):
        cipher.encryptor()
307
308
def hkdf_derive_test(backend, algorithm, params):
    """Check ``HKDF.derive`` against one vector.

    ``params`` holds the output length ``l``, plus hex-encoded ``salt``,
    ``info``, ``ikm`` and the expected ``okm``. An empty salt or info is
    passed to ``HKDF`` as ``None``.
    """
    length = int(params["l"])
    salt = binascii.unhexlify(params["salt"]) or None
    info = binascii.unhexlify(params["info"]) or None
    kdf = HKDF(algorithm, length, salt=salt, info=info, backend=backend)

    derived = kdf.derive(binascii.unhexlify(params["ikm"]))
    assert derived == binascii.unhexlify(params["okm"])
321
322
def hkdf_extract_test(backend, algorithm, params):
    """Check HKDF's private ``_extract`` step against one vector.

    The ``HKDF`` object is built exactly as in a full derivation. Its
    ``_extract`` output for the hex ``ikm`` must equal the hex ``prk``.
    """
    length = int(params["l"])
    salt = binascii.unhexlify(params["salt"]) or None
    info = binascii.unhexlify(params["info"]) or None
    kdf = HKDF(algorithm, length, salt=salt, info=info, backend=backend)

    pseudo_random_key = kdf._extract(binascii.unhexlify(params["ikm"]))
    assert pseudo_random_key == binascii.unhexlify(params["prk"])
335
336
def hkdf_expand_test(backend, algorithm, params):
    """Check ``HKDFExpand`` against one vector.

    The hex ``prk`` is expanded to ``l`` bytes using the hex ``info`` (an
    empty value becomes ``None``). The result must equal the hex ``okm``.
    """
    length = int(params["l"])
    info = binascii.unhexlify(params["info"]) or None
    expander = HKDFExpand(algorithm, length, info=info, backend=backend)

    derived = expander.derive(binascii.unhexlify(params["prk"]))
    assert derived == binascii.unhexlify(params["okm"])
348
349
def generate_hkdf_test(param_loader, path, file_names, algorithm):
    """Build a test that runs every HKDF check on every vector.

    Each loaded vector is paired with ``hkdf_extract_test``,
    ``hkdf_expand_test`` and ``hkdf_derive_test``, in that order, giving one
    parametrized case per (vector, check) combination.
    """
    vectors = _load_all_params(path, file_names, param_loader)
    checks = (hkdf_extract_test, hkdf_expand_test, hkdf_derive_test)
    cases = [(vector, check) for vector in vectors for check in checks]

    @pytest.mark.parametrize(("params", "hkdf_test"), cases)
    def test_hkdf(self, backend, params, hkdf_test):
        hkdf_test(backend, algorithm, params)

    return test_hkdf
363
364
def generate_kbkdf_counter_mode_test(param_loader, path, file_names):
    """Build a parametrized test that runs ``kbkdf_counter_mode_test``.

    One case is generated per vector loaded from ``file_names`` under
    ``path`` via ``param_loader``.
    """
    vectors = _load_all_params(path, file_names, param_loader)
    parametrize = pytest.mark.parametrize("params", vectors)

    def test_kbkdf(self, backend, params):
        kbkdf_counter_mode_test(backend, params)

    return parametrize(test_kbkdf)
372
373
def kbkdf_counter_mode_test(backend, params):
    """Derive a key with counter-mode KBKDF (HMAC PRF) and check the output.

    ``params["prf"]`` selects the HMAC hash, and ``params["ctrlocation"]``
    says where the counter goes. A vector is skipped if it names an unknown
    PRF, a PRF the backend cannot use for HMAC, or an unknown counter
    location.

    ``l`` and ``rlen`` are divided by 8 (presumably bit counts converted to
    bytes). ``fixedinputdata``, ``ki`` and ``ko`` are hex-encoded.
    """
    supported_algorithms = {
        'hmac_sha1': hashes.SHA1,
        'hmac_sha224': hashes.SHA224,
        'hmac_sha256': hashes.SHA256,
        'hmac_sha384': hashes.SHA384,
        'hmac_sha512': hashes.SHA512,
    }

    supported_counter_locations = {
        "before_fixed": CounterLocation.BeforeFixed,
        "after_fixed": CounterLocation.AfterFixed,
    }

    prf = params.get('prf')
    algorithm = supported_algorithms.get(prf)
    if algorithm is None or not backend.hmac_supported(algorithm()):
        pytest.skip("KBKDF does not support algorithm: {0}".format(prf))

    location = params.get("ctrlocation")
    ctr_loc = supported_counter_locations.get(location)
    # Every value in the mapping above is a CounterLocation, so a failed
    # lookup is the only unsupported case; no isinstance() check is needed.
    if ctr_loc is None:
        pytest.skip("Does not support counter location: {0}".format(location))

    # NOTE(review): the positional None arguments are presumably llen, label
    # and context (unused because the fixed input data is passed directly).
    # Confirm against the KBKDFHMAC signature.
    ctrkdf = KBKDFHMAC(
        algorithm(),
        Mode.CounterMode,
        params['l'] // 8,
        params['rlen'] // 8,
        None,
        ctr_loc,
        None,
        None,
        binascii.unhexlify(params['fixedinputdata']),
        backend=backend)

    ko = ctrkdf.derive(binascii.unhexlify(params['ki']))
    assert binascii.hexlify(ko) == params["ko"]
414
415
def generate_rsa_verification_test(param_loader, path, file_names, hash_alg,
                                   pad_factory):
    """Build a parametrized RSA signature-verification test.

    Only vectors whose ``algorithm`` field equals ``hash_alg.name`` in upper
    case are kept. Each of them is checked with ``rsa_verification_test``
    using ``pad_factory``.
    """
    loaded = _load_all_params(path, file_names, param_loader)
    wanted = hash_alg.name.upper()
    vectors = [vector for vector in loaded if vector["algorithm"] == wanted]
    parametrize = pytest.mark.parametrize("params", vectors)

    def test_rsa_verification(self, backend, params):
        rsa_verification_test(backend, params, hash_alg, pad_factory)

    return parametrize(test_rsa_verification)
427
428
def rsa_verification_test(backend, params, hash_alg, pad_factory):
    """Verify one RSA signature vector.

    The public key is built from ``public_exponent`` and ``modulus``. The
    padding comes from ``pad_factory(params, hash_alg)``. The hex ``s`` is
    verified over the hex ``msg``. If ``params["fail"]`` is truthy,
    verification must raise ``InvalidSignature``; otherwise it must succeed.
    """
    public_key = rsa.RSAPublicNumbers(
        e=params["public_exponent"],
        n=params["modulus"]
    ).public_key(backend)
    padding = pad_factory(params, hash_alg)
    signature = binascii.unhexlify(params["s"])
    message = binascii.unhexlify(params["msg"])

    def verify():
        public_key.verify(signature, message, padding, hash_alg)

    if not params["fail"]:
        verify()
        return

    with pytest.raises(InvalidSignature):
        verify()
453
454
def _check_rsa_private_numbers(skey):
    """Assert that ``skey`` holds internally consistent RSA private numbers.

    The public numbers and ``e``, ``n`` and ``d`` must be truthy, and
    ``p * q`` must equal ``n``. ``dmp1``, ``dmq1`` and ``iqmp`` must match
    the values recomputed with the ``rsa.rsa_crt_*`` helpers.
    """
    assert skey
    public = skey.public_numbers
    assert public
    assert public.e
    assert public.n
    assert skey.d

    # The modulus must factor into the stored primes...
    assert skey.p * skey.q == public.n
    # ...and the stored CRT values must match freshly computed ones.
    assert skey.dmp1 == rsa.rsa_crt_dmp1(skey.d, skey.p)
    assert skey.dmq1 == rsa.rsa_crt_dmq1(skey.d, skey.q)
    assert skey.iqmp == rsa.rsa_crt_iqmp(skey.p, skey.q)
466
467
def _check_dsa_private_numbers(skey):
    """Assert that the DSA private value ``x`` matches the public ``y``.

    Checks ``pow(g, x, p) == y``, where ``g`` and ``p`` come from the
    ``parameter_numbers`` of ``skey.public_numbers``.
    """
    assert skey
    public = skey.public_numbers
    group = public.parameter_numbers
    recomputed_y = pow(group.g, skey.x, group.p)
    assert recomputed_y == public.y
473