1# This file is dual licensed under the terms of the Apache License, Version
2# 2.0, and the BSD License. See the LICENSE file in the root of this repository
3# for complete details.
4
5from __future__ import absolute_import, division, print_function
6
7import binascii
8import os
9
10import pytest
11
12from cryptography.exceptions import InvalidTag, UnsupportedAlgorithm, _Reasons
13from cryptography.hazmat.backends.interfaces import CipherBackend
14from cryptography.hazmat.primitives.ciphers.aead import (
15    AESCCM, AESGCM, ChaCha20Poly1305
16)
17
18from .utils import _load_all_params
19from ...utils import (
20    load_nist_ccm_vectors, load_nist_vectors, load_vectors_from_file,
21    raises_unsupported_algorithm
22)
23
24
class FakeData(object):
    """Stand-in payload that only claims to be larger than 2**32 bytes.

    Nothing but ``__len__`` is implemented, so no memory is allocated; the
    ``test_data_too_large`` tests rely on the reported length alone to
    trigger an ``OverflowError`` from ``encrypt``.
    """

    def __len__(self):
        # One byte past the 4 GiB boundary.
        return (1 << 32) + 1
28
29
30def _aead_supported(cls):
31    try:
32        cls(b"0" * 32)
33        return True
34    except UnsupportedAlgorithm:
35        return False
36
37
@pytest.mark.skipif(
    _aead_supported(ChaCha20Poly1305),
    reason="Requires OpenSSL without ChaCha20Poly1305 support"
)
@pytest.mark.requires_backend_interface(interface=CipherBackend)
def test_chacha20poly1305_unsupported_on_older_openssl(backend):
    """ChaCha20Poly1305 construction must fail cleanly when unsupported.

    Runs only when ``_aead_supported`` reports the cipher is unavailable;
    the failure must carry the UNSUPPORTED_CIPHER reason.
    """
    expected_reason = _Reasons.UNSUPPORTED_CIPHER
    with raises_unsupported_algorithm(expected_reason):
        new_key = ChaCha20Poly1305.generate_key()
        ChaCha20Poly1305(new_key)
46
47
@pytest.mark.skipif(
    not _aead_supported(ChaCha20Poly1305),
    reason="Does not support ChaCha20Poly1305"
)
@pytest.mark.requires_backend_interface(interface=CipherBackend)
class TestChaCha20Poly1305(object):
    """Argument validation and known-answer tests for ChaCha20Poly1305.

    Every test requests the ``backend`` fixture. The class-level
    ``requires_backend_interface`` marker is presumably enforced through
    that fixture (conftest not visible here — confirm), so tests that omit
    it would silently bypass the marker.
    """

    def test_data_too_large(self, backend):
        key = ChaCha20Poly1305.generate_key()
        chacha = ChaCha20Poly1305(key)
        nonce = b"0" * 12

        # FakeData only reports a length of 2**32 + 1, so the size limit
        # must be enforced from len() alone, for both data and AAD.
        with pytest.raises(OverflowError):
            chacha.encrypt(nonce, FakeData(), b"")

        with pytest.raises(OverflowError):
            chacha.encrypt(nonce, b"", FakeData())

    def test_generate_key(self, backend):
        key = ChaCha20Poly1305.generate_key()
        assert len(key) == 32

    def test_bad_key(self, backend):
        with pytest.raises(TypeError):
            ChaCha20Poly1305(object())

        with pytest.raises(ValueError):
            ChaCha20Poly1305(b"0" * 31)

    @pytest.mark.parametrize(
        ("nonce", "data", "associated_data"),
        [
            [object(), b"data", b""],
            [b"0" * 12, object(), b""],
            [b"0" * 12, b"data", object()]
        ]
    )
    def test_params_not_bytes_encrypt(self, nonce, data, associated_data,
                                      backend):
        key = ChaCha20Poly1305.generate_key()
        chacha = ChaCha20Poly1305(key)
        with pytest.raises(TypeError):
            chacha.encrypt(nonce, data, associated_data)

        with pytest.raises(TypeError):
            chacha.decrypt(nonce, data, associated_data)

    def test_nonce_not_12_bytes(self, backend):
        key = ChaCha20Poly1305.generate_key()
        chacha = ChaCha20Poly1305(key)
        with pytest.raises(ValueError):
            chacha.encrypt(b"00", b"hello", b"")

        with pytest.raises(ValueError):
            chacha.decrypt(b"00", b"hello", b"")

    def test_decrypt_data_too_short(self, backend):
        key = ChaCha20Poly1305.generate_key()
        chacha = ChaCha20Poly1305(key)
        # A single byte cannot even hold the authentication tag.
        with pytest.raises(InvalidTag):
            chacha.decrypt(b"0" * 12, b"0", None)

    def test_associated_data_none_equal_to_empty_bytestring(self, backend):
        key = ChaCha20Poly1305.generate_key()
        chacha = ChaCha20Poly1305(key)
        nonce = os.urandom(12)
        ct1 = chacha.encrypt(nonce, b"some_data", None)
        ct2 = chacha.encrypt(nonce, b"some_data", b"")
        assert ct1 == ct2
        pt1 = chacha.decrypt(nonce, ct1, None)
        pt2 = chacha.decrypt(nonce, ct2, b"")
        assert pt1 == pt2

    @pytest.mark.parametrize(
        "vector",
        load_vectors_from_file(
            os.path.join("ciphers", "ChaCha20Poly1305", "openssl.txt"),
            load_nist_vectors
        )
    )
    def test_openssl_vectors(self, vector, backend):
        key = binascii.unhexlify(vector["key"])
        nonce = binascii.unhexlify(vector["iv"])
        aad = binascii.unhexlify(vector["aad"])
        tag = binascii.unhexlify(vector["tag"])
        pt = binascii.unhexlify(vector["plaintext"])
        ct = binascii.unhexlify(vector["ciphertext"])
        chacha = ChaCha20Poly1305(key)
        # Vectors marked CIPHERFINAL_ERROR must fail authentication.
        if vector.get("result") == b"CIPHERFINAL_ERROR":
            with pytest.raises(InvalidTag):
                chacha.decrypt(nonce, ct + tag, aad)
        else:
            computed_pt = chacha.decrypt(nonce, ct + tag, aad)
            assert computed_pt == pt
            computed_ct = chacha.encrypt(nonce, pt, aad)
            assert computed_ct == ct + tag

    @pytest.mark.parametrize(
        "vector",
        load_vectors_from_file(
            os.path.join("ciphers", "ChaCha20Poly1305", "boringssl.txt"),
            load_nist_vectors
        )
    )
    def test_boringssl_vectors(self, vector, backend):
        key = binascii.unhexlify(vector["key"])
        nonce = binascii.unhexlify(vector["nonce"])
        # "ad" and "in" are either a double-quoted literal or hex.
        if vector["ad"].startswith(b'"'):
            aad = vector["ad"][1:-1]
        else:
            aad = binascii.unhexlify(vector["ad"])
        tag = binascii.unhexlify(vector["tag"])
        if vector["in"].startswith(b'"'):
            pt = vector["in"][1:-1]
        else:
            pt = binascii.unhexlify(vector["in"])
        ct = binascii.unhexlify(vector["ct"].strip(b'"'))
        chacha = ChaCha20Poly1305(key)
        computed_pt = chacha.decrypt(nonce, ct + tag, aad)
        assert computed_pt == pt
        computed_ct = chacha.encrypt(nonce, pt, aad)
        assert computed_ct == ct + tag

    def test_buffer_protocol(self, backend):
        key = ChaCha20Poly1305.generate_key()
        chacha = ChaCha20Poly1305(key)
        pt = b"encrypt me"
        ad = b"additional"
        nonce = os.urandom(12)
        ct = chacha.encrypt(nonce, pt, ad)
        computed_pt = chacha.decrypt(nonce, ct, ad)
        assert computed_pt == pt
        # bytearray key and nonce must behave exactly like bytes.
        chacha2 = ChaCha20Poly1305(bytearray(key))
        ct2 = chacha2.encrypt(bytearray(nonce), pt, ad)
        assert ct2 == ct
        computed_pt2 = chacha2.decrypt(bytearray(nonce), ct2, ad)
        assert computed_pt2 == pt
184
185
@pytest.mark.skipif(
    _aead_supported(AESCCM),
    reason="Requires OpenSSL without AES-CCM support"
)
@pytest.mark.requires_backend_interface(interface=CipherBackend)
def test_aesccm_unsupported_on_older_openssl(backend):
    """AESCCM construction must fail cleanly when the backend lacks CCM.

    Runs only when ``_aead_supported`` reports the mode is unavailable;
    the failure must carry the UNSUPPORTED_CIPHER reason.
    """
    expected_reason = _Reasons.UNSUPPORTED_CIPHER
    with raises_unsupported_algorithm(expected_reason):
        new_key = AESCCM.generate_key(128)
        AESCCM(new_key)
194
195
@pytest.mark.skipif(
    not _aead_supported(AESCCM),
    reason="Does not support AESCCM"
)
@pytest.mark.requires_backend_interface(interface=CipherBackend)
class TestAESCCM(object):
    """Argument validation and NIST known-answer tests for AESCCM.

    Every test requests the ``backend`` fixture so the class-level
    ``requires_backend_interface`` marker is presumably applied uniformly
    (enforcement lives in conftest, not visible here — confirm).
    """

    def test_data_too_large(self, backend):
        key = AESCCM.generate_key(128)
        aesccm = AESCCM(key)
        nonce = b"0" * 12

        # FakeData only reports a length of 2**32 + 1, so the size limit
        # must be enforced from len() alone, for both data and AAD.
        with pytest.raises(OverflowError):
            aesccm.encrypt(nonce, FakeData(), b"")

        with pytest.raises(OverflowError):
            aesccm.encrypt(nonce, b"", FakeData())

    def test_default_tag_length(self, backend):
        key = AESCCM.generate_key(128)
        aesccm = AESCCM(key)
        nonce = os.urandom(12)
        pt = b"hello"
        ct = aesccm.encrypt(nonce, pt, None)
        assert len(ct) == len(pt) + 16

    def test_invalid_tag_length(self, backend):
        key = AESCCM.generate_key(128)
        with pytest.raises(ValueError):
            AESCCM(key, tag_length=7)

        with pytest.raises(ValueError):
            AESCCM(key, tag_length=2)

        with pytest.raises(TypeError):
            AESCCM(key, tag_length="notanint")

    def test_invalid_nonce_length(self, backend):
        key = AESCCM.generate_key(128)
        aesccm = AESCCM(key)
        pt = b"hello"
        # 14 and 6 bytes both fall outside the 7-13 byte CCM nonce range.
        nonce = os.urandom(14)
        with pytest.raises(ValueError):
            aesccm.encrypt(nonce, pt, None)

        with pytest.raises(ValueError):
            aesccm.encrypt(nonce[:6], pt, None)

        # decrypt must reject the same nonces. The ciphertext is longer
        # than the 16-byte default tag so only the nonce can be at fault.
        ct = b"0" * 32
        with pytest.raises(ValueError):
            aesccm.decrypt(nonce, ct, None)

        with pytest.raises(ValueError):
            aesccm.decrypt(nonce[:6], ct, None)

    @pytest.mark.parametrize(
        "vector",
        _load_all_params(
            os.path.join("ciphers", "AES", "CCM"),
            [
                "DVPT128.rsp", "DVPT192.rsp", "DVPT256.rsp",
                "VADT128.rsp", "VADT192.rsp", "VADT256.rsp",
                "VNT128.rsp", "VNT192.rsp", "VNT256.rsp",
                "VPT128.rsp", "VPT192.rsp", "VPT256.rsp",
            ],
            load_nist_ccm_vectors
        )
    )
    def test_vectors(self, vector, backend):
        key = binascii.unhexlify(vector["key"])
        nonce = binascii.unhexlify(vector["nonce"])
        # The hex fields may be longer than the declared lengths, so
        # truncate to "alen"/"plen".
        adata = binascii.unhexlify(vector["adata"])[:vector["alen"]]
        ct = binascii.unhexlify(vector["ct"])
        pt = binascii.unhexlify(vector["payload"])[:vector["plen"]]
        aesccm = AESCCM(key, vector["tlen"])
        if vector.get('fail'):
            with pytest.raises(InvalidTag):
                aesccm.decrypt(nonce, ct, adata)
        else:
            computed_pt = aesccm.decrypt(nonce, ct, adata)
            assert computed_pt == pt
            assert aesccm.encrypt(nonce, pt, adata) == ct

    def test_roundtrip(self, backend):
        key = AESCCM.generate_key(128)
        aesccm = AESCCM(key)
        pt = b"encrypt me"
        ad = b"additional"
        nonce = os.urandom(12)
        ct = aesccm.encrypt(nonce, pt, ad)
        computed_pt = aesccm.decrypt(nonce, ct, ad)
        assert computed_pt == pt

    def test_nonce_too_long(self, backend):
        key = AESCCM.generate_key(128)
        aesccm = AESCCM(key)
        pt = b"encrypt me" * 6600
        # pt can be no more than 65536 bytes when nonce is 13 bytes
        # (only 15 - 13 = 2 bytes remain for the length field); pt here
        # is 66000 bytes.
        nonce = os.urandom(13)
        with pytest.raises(ValueError):
            aesccm.encrypt(nonce, pt, None)

    @pytest.mark.parametrize(
        ("nonce", "data", "associated_data"),
        [
            [object(), b"data", b""],
            [b"0" * 12, object(), b""],
            [b"0" * 12, b"data", object()],
        ]
    )
    def test_params_not_bytes(self, nonce, data, associated_data, backend):
        key = AESCCM.generate_key(128)
        aesccm = AESCCM(key)
        with pytest.raises(TypeError):
            aesccm.encrypt(nonce, data, associated_data)

        # Match the ChaCha20Poly1305 and AESGCM tests: decrypt must
        # type-check its arguments as well.
        with pytest.raises(TypeError):
            aesccm.decrypt(nonce, data, associated_data)

    def test_bad_key(self, backend):
        with pytest.raises(TypeError):
            AESCCM(object())

        with pytest.raises(ValueError):
            AESCCM(b"0" * 31)

    def test_bad_generate_key(self, backend):
        with pytest.raises(TypeError):
            AESCCM.generate_key(object())

        with pytest.raises(ValueError):
            AESCCM.generate_key(129)

    def test_associated_data_none_equal_to_empty_bytestring(self, backend):
        key = AESCCM.generate_key(128)
        aesccm = AESCCM(key)
        nonce = os.urandom(12)
        ct1 = aesccm.encrypt(nonce, b"some_data", None)
        ct2 = aesccm.encrypt(nonce, b"some_data", b"")
        assert ct1 == ct2
        pt1 = aesccm.decrypt(nonce, ct1, None)
        pt2 = aesccm.decrypt(nonce, ct2, b"")
        assert pt1 == pt2

    def test_decrypt_data_too_short(self, backend):
        key = AESCCM.generate_key(128)
        aesccm = AESCCM(key)
        # A single byte cannot even hold the authentication tag.
        with pytest.raises(InvalidTag):
            aesccm.decrypt(b"0" * 12, b"0", None)

    def test_buffer_protocol(self, backend):
        key = AESCCM.generate_key(128)
        aesccm = AESCCM(key)
        pt = b"encrypt me"
        ad = b"additional"
        nonce = os.urandom(12)
        ct = aesccm.encrypt(nonce, pt, ad)
        computed_pt = aesccm.decrypt(nonce, ct, ad)
        assert computed_pt == pt
        # bytearray key and nonce must behave exactly like bytes.
        aesccm2 = AESCCM(bytearray(key))
        ct2 = aesccm2.encrypt(bytearray(nonce), pt, ad)
        assert ct2 == ct
        computed_pt2 = aesccm2.decrypt(bytearray(nonce), ct2, ad)
        assert computed_pt2 == pt
349
350
def _load_gcm_vectors():
    """Return the NIST AES-GCM test vectors that carry a 16-byte tag.

    Vectors are read from the decrypt and external-IV encrypt ``.rsp``
    files for 128-, 192- and 256-bit keys. Tags are hex-encoded, so a
    16-byte tag is 32 characters long; vectors with any other tag length
    are dropped (``TestAESGCM`` compares against the last 16 bytes of the
    AESGCM output).
    """
    gcm_dir = os.path.join("ciphers", "AES", "GCM")
    rsp_files = [
        "gcmDecrypt128.rsp",
        "gcmDecrypt192.rsp",
        "gcmDecrypt256.rsp",
        "gcmEncryptExtIV128.rsp",
        "gcmEncryptExtIV192.rsp",
        "gcmEncryptExtIV256.rsp",
    ]
    full_tag_hex_len = 2 * 16
    loaded = _load_all_params(gcm_dir, rsp_files, load_nist_vectors)
    return [vec for vec in loaded if len(vec["tag"]) == full_tag_hex_len]
365
366
@pytest.mark.requires_backend_interface(interface=CipherBackend)
class TestAESGCM(object):
    """Argument validation and NIST known-answer tests for AESGCM.

    Every test requests the ``backend`` fixture so the class-level
    ``requires_backend_interface`` marker is presumably applied uniformly
    (enforcement lives in conftest, not visible here — confirm).
    """

    def test_data_too_large(self, backend):
        key = AESGCM.generate_key(128)
        aesgcm = AESGCM(key)
        nonce = b"0" * 12

        # FakeData only reports a length of 2**32 + 1, so the size limit
        # must be enforced from len() alone, for both data and AAD.
        with pytest.raises(OverflowError):
            aesgcm.encrypt(nonce, FakeData(), b"")

        with pytest.raises(OverflowError):
            aesgcm.encrypt(nonce, b"", FakeData())

    @pytest.mark.parametrize("vector", _load_gcm_vectors())
    def test_vectors(self, vector, backend):
        key = binascii.unhexlify(vector["key"])
        nonce = binascii.unhexlify(vector["iv"])
        aad = binascii.unhexlify(vector["aad"])
        ct = binascii.unhexlify(vector["ct"])
        # Failing decrypt vectors carry no plaintext.
        pt = binascii.unhexlify(vector.get("pt", b""))
        tag = binascii.unhexlify(vector["tag"])
        aesgcm = AESGCM(key)
        if vector.get("fail") is True:
            with pytest.raises(InvalidTag):
                aesgcm.decrypt(nonce, ct + tag, aad)
        else:
            # AESGCM output is ciphertext followed by a 16-byte tag.
            computed_ct = aesgcm.encrypt(nonce, pt, aad)
            assert computed_ct[:-16] == ct
            assert computed_ct[-16:] == tag
            computed_pt = aesgcm.decrypt(nonce, ct + tag, aad)
            assert computed_pt == pt

    @pytest.mark.parametrize(
        ("nonce", "data", "associated_data"),
        [
            [object(), b"data", b""],
            [b"0" * 12, object(), b""],
            [b"0" * 12, b"data", object()]
        ]
    )
    def test_params_not_bytes(self, nonce, data, associated_data, backend):
        key = AESGCM.generate_key(128)
        aesgcm = AESGCM(key)
        with pytest.raises(TypeError):
            aesgcm.encrypt(nonce, data, associated_data)

        with pytest.raises(TypeError):
            aesgcm.decrypt(nonce, data, associated_data)

    def test_invalid_nonce_length(self, backend):
        key = AESGCM.generate_key(128)
        aesgcm = AESGCM(key)
        with pytest.raises(ValueError):
            aesgcm.encrypt(b"", b"hi", None)

        # decrypt must reject the empty nonce too. The ciphertext is longer
        # than the 16-byte tag so only the nonce can be at fault.
        with pytest.raises(ValueError):
            aesgcm.decrypt(b"", b"0" * 32, None)

    def test_bad_key(self, backend):
        with pytest.raises(TypeError):
            AESGCM(object())

        with pytest.raises(ValueError):
            AESGCM(b"0" * 31)

    def test_bad_generate_key(self, backend):
        with pytest.raises(TypeError):
            AESGCM.generate_key(object())

        with pytest.raises(ValueError):
            AESGCM.generate_key(129)

    def test_associated_data_none_equal_to_empty_bytestring(self, backend):
        key = AESGCM.generate_key(128)
        aesgcm = AESGCM(key)
        nonce = os.urandom(12)
        ct1 = aesgcm.encrypt(nonce, b"some_data", None)
        ct2 = aesgcm.encrypt(nonce, b"some_data", b"")
        assert ct1 == ct2
        pt1 = aesgcm.decrypt(nonce, ct1, None)
        pt2 = aesgcm.decrypt(nonce, ct2, b"")
        assert pt1 == pt2

    def test_buffer_protocol(self, backend):
        key = AESGCM.generate_key(128)
        aesgcm = AESGCM(key)
        pt = b"encrypt me"
        ad = b"additional"
        nonce = os.urandom(12)
        ct = aesgcm.encrypt(nonce, pt, ad)
        computed_pt = aesgcm.decrypt(nonce, ct, ad)
        assert computed_pt == pt
        # bytearray key and nonce must behave exactly like bytes.
        aesgcm2 = AESGCM(bytearray(key))
        ct2 = aesgcm2.encrypt(bytearray(nonce), pt, ad)
        assert ct2 == ct
        computed_pt2 = aesgcm2.decrypt(bytearray(nonce), ct2, ad)
        assert computed_pt2 == pt
461