1# This file is dual licensed under the terms of the Apache License, Version
2# 2.0, and the BSD License. See the LICENSE file in the root of this repository
3# for complete details.
4
5from __future__ import absolute_import, division, print_function
6
7from cryptography.exceptions import InvalidTag
8
9
# Values for the ``operation`` argument of the AEAD helpers below. The
# direction flag handed to EVP_CipherInit_ex is derived as
# int(operation == _ENCRYPT), so _ENCRYPT selects encryption and _DECRYPT
# selects decryption.
_ENCRYPT = 1
_DECRYPT = 0
12
13
def _aead_cipher_name(cipher):
    """Return the OpenSSL cipher name (as ``bytes``) for an AEAD object.

    ChaCha20Poly1305 maps to a fixed name; AES-CCM and AES-GCM names embed
    the key size in bits, derived from the length of ``cipher._key``.
    """
    from cryptography.hazmat.primitives.ciphers.aead import (
        AESCCM, AESGCM, ChaCha20Poly1305
    )
    if isinstance(cipher, ChaCha20Poly1305):
        return b"chacha20-poly1305"

    if isinstance(cipher, AESCCM):
        template = "aes-{0}-ccm"
    else:
        # Anything that is neither ChaCha20Poly1305 nor AESCCM must be GCM.
        assert isinstance(cipher, AESGCM)
        template = "aes-{0}-gcm"
    key_bits = len(cipher._key) * 8
    return template.format(key_bits).encode("ascii")
25
26
def _aead_setup(backend, cipher_name, key, nonce, tag, tag_len, operation):
    """Create an OpenSSL ``EVP_CIPHER_CTX`` configured for an AEAD operation.

    :param backend: OpenSSL backend exposing ``_lib``/``_ffi`` bindings.
    :param cipher_name: OpenSSL cipher name as ``bytes``
        (e.g. ``b"aes-128-ccm"``).
    :param key: Key bytes; ``len(key)`` is set as the context key length.
    :param nonce: Nonce bytes; ``len(nonce)`` is set as the IV length.
    :param tag: Expected authentication tag; only used when decrypting.
    :param tag_len: Tag length in bytes; only used when encrypting with CCM.
    :param operation: ``_ENCRYPT`` or ``_DECRYPT``.
    :returns: A context (released through ``ffi.gc``) with the key and nonce
        loaded, ready for AAD and data processing.
    """
    evp_cipher = backend._lib.EVP_get_cipherbyname(cipher_name)
    backend.openssl_assert(evp_cipher != backend._ffi.NULL)
    ctx = backend._lib.EVP_CIPHER_CTX_new()
    ctx = backend._ffi.gc(ctx, backend._lib.EVP_CIPHER_CTX_free)
    # Select cipher and direction only; the key and nonce are loaded by a
    # second EVP_CipherInit_ex call once the IV/tag lengths are configured.
    res = backend._lib.EVP_CipherInit_ex(
        ctx, evp_cipher,
        backend._ffi.NULL,
        backend._ffi.NULL,
        backend._ffi.NULL,
        int(operation == _ENCRYPT)
    )
    backend.openssl_assert(res != 0)
    res = backend._lib.EVP_CIPHER_CTX_set_key_length(ctx, len(key))
    backend.openssl_assert(res != 0)
    res = backend._lib.EVP_CIPHER_CTX_ctrl(
        ctx, backend._lib.EVP_CTRL_AEAD_SET_IVLEN, len(nonce),
        backend._ffi.NULL
    )
    backend.openssl_assert(res != 0)
    if operation == _DECRYPT:
        # Hand OpenSSL the expected tag so it can authenticate the input.
        res = backend._lib.EVP_CIPHER_CTX_ctrl(
            ctx, backend._lib.EVP_CTRL_AEAD_SET_TAG, len(tag), tag
        )
        backend.openssl_assert(res != 0)
    elif cipher_name.endswith(b"-ccm"):
        # When encrypting, only CCM needs the tag length fixed up front (with
        # a NULL tag pointer). OpenSSL rejects EVP_CTRL_AEAD_SET_TAG for GCM
        # in encrypt mode, so the call is limited to CCM and its result is
        # now checked instead of being silently discarded.
        res = backend._lib.EVP_CIPHER_CTX_ctrl(
            ctx, backend._lib.EVP_CTRL_AEAD_SET_TAG, tag_len, backend._ffi.NULL
        )
        backend.openssl_assert(res != 0)

    nonce_ptr = backend._ffi.from_buffer(nonce)
    key_ptr = backend._ffi.from_buffer(key)
    res = backend._lib.EVP_CipherInit_ex(
        ctx,
        backend._ffi.NULL,
        backend._ffi.NULL,
        key_ptr,
        nonce_ptr,
        int(operation == _ENCRYPT)
    )
    backend.openssl_assert(res != 0)
    return ctx
69
70
def _set_length(backend, ctx, data_len):
    """Tell the cipher context the total message length in advance.

    Issues ``EVP_CipherUpdate`` with NULL input and output buffers and
    ``data_len`` as the length; callers use this for CCM before any AAD or
    data is processed.
    """
    ffi = backend._ffi
    out_len = ffi.new("int *")
    status = backend._lib.EVP_CipherUpdate(
        ctx, ffi.NULL, out_len, ffi.NULL, data_len
    )
    backend.openssl_assert(status != 0)
81
82
def _process_aad(backend, ctx, associated_data):
    """Feed ``associated_data`` into the context as additional data.

    The output buffer is NULL, so nothing is produced or returned; the
    OpenSSL status is checked with ``backend.openssl_assert``.
    """
    ffi = backend._ffi
    aad_len = len(associated_data)
    unused_len = ffi.new("int *")
    status = backend._lib.EVP_CipherUpdate(
        ctx, ffi.NULL, unused_len, associated_data, aad_len
    )
    backend.openssl_assert(status != 0)
89
90
def _process_data(backend, ctx, data):
    """Run ``data`` through the cipher and return the bytes it produced.

    The output buffer is sized to ``len(data)``; only the number of bytes
    OpenSSL reports as written is returned.
    """
    ffi = backend._ffi
    data_len = len(data)
    written = ffi.new("int *")
    out_buf = ffi.new("unsigned char[]", data_len)
    status = backend._lib.EVP_CipherUpdate(
        ctx, out_buf, written, data, data_len
    )
    backend.openssl_assert(status != 0)
    return ffi.buffer(out_buf, written[0])[:]
97
98
def _encrypt(backend, cipher, nonce, data, associated_data, tag_length):
    """Encrypt ``data`` and return the ciphertext with the tag appended.

    :param cipher: AEAD object whose ``_key`` and type select the cipher.
    :param nonce: Nonce bytes for this message.
    :param data: Plaintext bytes.
    :param associated_data: Bytes authenticated but not encrypted.
    :param tag_length: Number of tag bytes appended to the output.
    :returns: ``ciphertext + tag`` as bytes.
    """
    from cryptography.hazmat.primitives.ciphers.aead import AESCCM
    ffi = backend._ffi
    lib = backend._lib
    ctx = _aead_setup(
        backend, _aead_cipher_name(cipher), cipher._key, nonce, None,
        tag_length, _ENCRYPT,
    )
    # CCM must be told the message length before anything is processed;
    # doing this for the other AEADs is an error, hence the type check.
    if isinstance(cipher, AESCCM):
        _set_length(backend, ctx, len(data))

    _process_aad(backend, ctx, associated_data)
    ciphertext = _process_data(backend, ctx, data)

    final_len = ffi.new("int *")
    status = lib.EVP_CipherFinal_ex(ctx, ffi.NULL, final_len)
    backend.openssl_assert(status != 0)
    # Finalisation is expected to emit no further bytes.
    backend.openssl_assert(final_len[0] == 0)

    tag_buf = ffi.new("unsigned char[]", tag_length)
    status = lib.EVP_CIPHER_CTX_ctrl(
        ctx, lib.EVP_CTRL_AEAD_GET_TAG, tag_length, tag_buf
    )
    backend.openssl_assert(status != 0)
    return ciphertext + ffi.buffer(tag_buf)[:]
124
125
def _decrypt(backend, cipher, nonce, data, associated_data, tag_length):
    """Authenticate and decrypt ``data`` (ciphertext followed by its tag).

    :param cipher: AEAD object whose ``_key`` and type select the cipher.
    :param nonce: Nonce bytes used for encryption.
    :param data: Ciphertext with the ``tag_length``-byte tag appended.
    :param associated_data: Bytes that were authenticated with the message.
    :param tag_length: Length of the trailing tag in bytes.
    :returns: The plaintext bytes.
    :raises InvalidTag: if ``data`` is too short to hold the tag, or OpenSSL
        reports an authentication failure.
    """
    from cryptography.hazmat.primitives.ciphers.aead import AESCCM
    if len(data) < tag_length:
        raise InvalidTag
    # Peel the trailing tag off the ciphertext.
    tag, data = data[-tag_length:], data[:-tag_length]
    ffi = backend._ffi
    lib = backend._lib
    ctx = _aead_setup(
        backend, _aead_cipher_name(cipher), cipher._key, nonce, tag,
        tag_length, _DECRYPT,
    )
    is_ccm = isinstance(cipher, AESCCM)
    # CCM must be told the message length before anything is processed;
    # doing this for the other AEADs is an error.
    if is_ccm:
        _set_length(backend, ctx, len(data))

    _process_aad(backend, ctx, associated_data)

    if not is_ccm:
        # Non-CCM modes report a tag mismatch from the Final call.
        plaintext = _process_data(backend, ctx, data)
        final_len = ffi.new("int *")
        status = lib.EVP_CipherFinal_ex(ctx, ffi.NULL, final_len)
        if status == 0:
            backend._consume_errors()
            raise InvalidTag
        return plaintext

    # CCM reports a tag mismatch from Update itself; Final is not needed.
    written = ffi.new("int *")
    out_buf = ffi.new("unsigned char[]", len(data))
    status = lib.EVP_CipherUpdate(ctx, out_buf, written, data, len(data))
    if status != 1:
        backend._consume_errors()
        raise InvalidTag
    return ffi.buffer(out_buf, written[0])[:]
162