1# This file is dual licensed under the terms of the Apache License, Version
2# 2.0, and the BSD License. See the LICENSE file in the root of this repository
3# for complete details.
4
5from __future__ import absolute_import, division, print_function
6
7import math
8
9from cryptography import utils
10from cryptography.exceptions import (
11    InvalidSignature, UnsupportedAlgorithm, _Reasons
12)
13from cryptography.hazmat.backends.openssl.utils import (
14    _calculate_digest_and_algorithm, _check_not_prehashed,
15    _warn_sign_verify_deprecated
16)
17from cryptography.hazmat.primitives import hashes
18from cryptography.hazmat.primitives.asymmetric import (
19    AsymmetricSignatureContext, AsymmetricVerificationContext, rsa
20)
21from cryptography.hazmat.primitives.asymmetric.padding import (
22    AsymmetricPadding, MGF1, OAEP, PKCS1v15, PSS, calculate_max_pss_salt_length
23)
24from cryptography.hazmat.primitives.asymmetric.rsa import (
25    RSAPrivateKeyWithSerialization, RSAPublicKeyWithSerialization
26)
27
28
def _get_rsa_pss_salt_length(pss, key, hash_algorithm):
    """Resolve the concrete PSS salt length to hand to OpenSSL.

    If ``pss`` requests the ``MAX_LENGTH`` sentinel (either the ``MGF1`` or
    the ``PSS`` spelling), the largest salt that fits ``key`` with
    ``hash_algorithm`` is computed; otherwise the requested value is
    returned unchanged.
    """
    requested = pss._salt_length
    wants_max = requested is MGF1.MAX_LENGTH or requested is PSS.MAX_LENGTH
    if not wants_max:
        return requested
    return calculate_max_pss_salt_length(key, hash_algorithm)
36
37
def _enc_dec_rsa(backend, key, data, padding):
    """Validate ``padding`` and run an RSA encrypt/decrypt of ``data``.

    Only PKCS1v15 and OAEP (with MGF1) paddings are accepted. The actual
    operation is delegated to ``_enc_dec_rsa_pkey_ctx``.

    Raises:
        TypeError: if ``padding`` is not an ``AsymmetricPadding``.
        UnsupportedAlgorithm: for any other padding type, a non-MGF1 mask
            generation function, or an OAEP configuration the backend
            reports as unsupported.
    """
    if not isinstance(padding, AsymmetricPadding):
        raise TypeError("Padding must be an instance of AsymmetricPadding.")

    if isinstance(padding, PKCS1v15):
        pkcs1_enum = backend._lib.RSA_PKCS1_PADDING
        return _enc_dec_rsa_pkey_ctx(backend, key, data, pkcs1_enum, padding)

    if not isinstance(padding, OAEP):
        raise UnsupportedAlgorithm(
            "{0} is not supported by this backend.".format(padding.name),
            _Reasons.UNSUPPORTED_PADDING
        )

    oaep_enum = backend._lib.RSA_PKCS1_OAEP_PADDING
    if not isinstance(padding._mgf, MGF1):
        raise UnsupportedAlgorithm(
            "Only MGF1 is supported by this backend.",
            _Reasons.UNSUPPORTED_MGF
        )
    if not backend.rsa_padding_supported(padding):
        raise UnsupportedAlgorithm(
            "This combination of padding and hash algorithm is not "
            "supported by this backend.",
            _Reasons.UNSUPPORTED_PADDING
        )
    return _enc_dec_rsa_pkey_ctx(backend, key, data, oaep_enum, padding)
69
70
def _enc_dec_rsa_pkey_ctx(backend, key, data, padding_enum, padding):
    # Perform the RSA encrypt (public key) or decrypt (any other key) of
    # ``data`` through an EVP_PKEY_CTX configured with ``padding_enum``.
    #
    # backend:      OpenSSL backend exposing ``_lib``/``_ffi``.
    # key:          an _RSAPublicKey (encrypt) or private key (decrypt)
    #               carrying an ``_evp_pkey``.
    # padding_enum: OpenSSL RSA padding constant chosen by the caller.
    # padding:      the padding object; for OAEP its MGF1 hash, OAEP hash
    #               and label are applied to the context.
    # Returns the output bytes, trimmed to the length OpenSSL reports.
    # On failure, _handle_rsa_enc_dec_error raises ValueError.
    if isinstance(key, _RSAPublicKey):
        init = backend._lib.EVP_PKEY_encrypt_init
        crypt = backend._lib.EVP_PKEY_encrypt
    else:
        init = backend._lib.EVP_PKEY_decrypt_init
        crypt = backend._lib.EVP_PKEY_decrypt

    pkey_ctx = backend._lib.EVP_PKEY_CTX_new(
        key._evp_pkey, backend._ffi.NULL
    )
    backend.openssl_assert(pkey_ctx != backend._ffi.NULL)
    # Tie the context's lifetime to the cffi object so it is freed on GC.
    pkey_ctx = backend._ffi.gc(pkey_ctx, backend._lib.EVP_PKEY_CTX_free)
    res = init(pkey_ctx)
    backend.openssl_assert(res == 1)
    res = backend._lib.EVP_PKEY_CTX_set_rsa_padding(
        pkey_ctx, padding_enum)
    backend.openssl_assert(res > 0)
    # Output buffer is sized from EVP_PKEY_size for the key.
    buf_size = backend._lib.EVP_PKEY_size(key._evp_pkey)
    backend.openssl_assert(buf_size > 0)
    # The OAEP/MGF1 digests can only be set when the bound OpenSSL exposes
    # the OAEP_MD setters; otherwise OpenSSL's defaults are left in place.
    if (
        isinstance(padding, OAEP) and
        backend._lib.Cryptography_HAS_RSA_OAEP_MD
    ):
        mgf1_md = backend._evp_md_non_null_from_algorithm(
            padding._mgf._algorithm)
        res = backend._lib.EVP_PKEY_CTX_set_rsa_mgf1_md(pkey_ctx, mgf1_md)
        backend.openssl_assert(res > 0)
        oaep_md = backend._evp_md_non_null_from_algorithm(padding._algorithm)
        res = backend._lib.EVP_PKEY_CTX_set_rsa_oaep_md(pkey_ctx, oaep_md)
        backend.openssl_assert(res > 0)

    # An empty or missing label is skipped entirely.
    if (
        isinstance(padding, OAEP) and
        padding._label is not None and
        len(padding._label) > 0
    ):
        # set0_rsa_oaep_label takes ownership of the char * so we need to
        # copy it into some new memory
        labelptr = backend._lib.OPENSSL_malloc(len(padding._label))
        backend.openssl_assert(labelptr != backend._ffi.NULL)
        backend._ffi.memmove(labelptr, padding._label, len(padding._label))
        res = backend._lib.EVP_PKEY_CTX_set0_rsa_oaep_label(
            pkey_ctx, labelptr, len(padding._label)
        )
        backend.openssl_assert(res == 1)

    # outlen starts as the buffer capacity; crypt overwrites it with the
    # number of bytes actually produced.
    outlen = backend._ffi.new("size_t *", buf_size)
    buf = backend._ffi.new("unsigned char[]", buf_size)
    res = crypt(pkey_ctx, buf, outlen, data, len(data))
    if res <= 0:
        _handle_rsa_enc_dec_error(backend, key)

    return backend._ffi.buffer(buf)[:outlen[0]]
125
126
def _handle_rsa_enc_dec_error(backend, key):
    """Translate a failed RSA encrypt/decrypt into a ``ValueError``.

    Drains the OpenSSL error queue and asserts that the first error comes
    from the RSA library and is one of the expected reasons. Any other error
    trips ``openssl_assert`` as an internal error.

    Raises:
        ValueError: always, once the error has been recognised.
    """
    lib = backend._lib
    errors = backend._consume_errors()
    backend.openssl_assert(errors)
    first_error = errors[0]
    backend.openssl_assert(first_error.lib == lib.ERR_LIB_RSA)

    if not isinstance(key, _RSAPublicKey):
        expected_reasons = [
            lib.RSA_R_BLOCK_TYPE_IS_NOT_01,
            lib.RSA_R_BLOCK_TYPE_IS_NOT_02,
            lib.RSA_R_OAEP_DECODING_ERROR,
            # Despite the resemblance to RSA_R_DATA_TOO_LARGE_FOR_KEY_SIZE,
            # this reason is reported on decryption, not encryption.
            lib.RSA_R_DATA_TOO_LARGE_FOR_MODULUS,
        ]
        if lib.Cryptography_HAS_RSA_R_PKCS_DECODING_ERROR:
            expected_reasons.append(lib.RSA_R_PKCS_DECODING_ERROR)
        backend.openssl_assert(first_error.reason in expected_reasons)
        raise ValueError("Decryption failed.")

    # Public-key path: the only expected encryption failure is oversize input.
    backend.openssl_assert(
        first_error.reason == lib.RSA_R_DATA_TOO_LARGE_FOR_KEY_SIZE
    )
    raise ValueError(
        "Data too long for key size. Encrypt less data or use a "
        "larger key size."
    )
154
155
def _rsa_sig_determine_padding(backend, key, padding, algorithm):
    """Map a signature padding object to its OpenSSL padding constant.

    Supports PKCS1v15 and PSS (MGF1 only). For PSS, rejects digests too
    large to fit in the key alongside the 2 bytes of PSS encoding overhead.
    The salt length is validated later, by OpenSSL.

    Raises:
        TypeError: if ``padding`` is not an ``AsymmetricPadding``.
        UnsupportedAlgorithm: for unsupported paddings or non-MGF1 masks.
        ValueError: if the digest does not fit the key for PSS.
    """
    if not isinstance(padding, AsymmetricPadding):
        raise TypeError("Expected provider of AsymmetricPadding.")

    key_bytes = backend._lib.EVP_PKEY_size(key._evp_pkey)
    backend.openssl_assert(key_bytes > 0)

    if isinstance(padding, PKCS1v15):
        return backend._lib.RSA_PKCS1_PADDING

    if not isinstance(padding, PSS):
        raise UnsupportedAlgorithm(
            "{0} is not supported by this backend.".format(padding.name),
            _Reasons.UNSUPPORTED_PADDING
        )

    if not isinstance(padding._mgf, MGF1):
        raise UnsupportedAlgorithm(
            "Only MGF1 is supported by this backend.",
            _Reasons.UNSUPPORTED_MGF
        )

    # The digest plus 2 bytes of PSS overhead must fit within the key size.
    if algorithm.digest_size + 2 > key_bytes:
        raise ValueError("Digest too large for key size. Use a larger "
                         "key or different digest.")

    return backend._lib.RSA_PKCS1_PSS_PADDING
186
187
def _rsa_sig_setup(backend, padding, algorithm, key, data, init_func):
    """Build an EVP_PKEY_CTX ready for RSA signing or verification.

    ``init_func`` is the OpenSSL init routine to run on the fresh context
    (e.g. ``EVP_PKEY_sign_init`` or ``EVP_PKEY_verify_init``). The context
    gets the signature digest and padding set, plus the salt length and MGF1
    digest for PSS. ``data`` is accepted for interface compatibility but is
    not used here.

    Raises:
        UnsupportedAlgorithm: if OpenSSL rejects ``algorithm`` as the
            signature digest.
    """
    lib = backend._lib
    ffi = backend._ffi
    padding_enum = _rsa_sig_determine_padding(backend, key, padding, algorithm)
    digest_md = backend._evp_md_non_null_from_algorithm(algorithm)

    ctx = lib.EVP_PKEY_CTX_new(key._evp_pkey, ffi.NULL)
    backend.openssl_assert(ctx != ffi.NULL)
    ctx = ffi.gc(ctx, lib.EVP_PKEY_CTX_free)
    backend.openssl_assert(init_func(ctx) == 1)

    if lib.EVP_PKEY_CTX_set_signature_md(ctx, digest_md) == 0:
        # Clear the queued OpenSSL error before surfacing a Python error.
        backend._consume_errors()
        raise UnsupportedAlgorithm(
            "{0} is not supported by this backend for RSA signing.".format(
                algorithm.name
            ),
            _Reasons.UNSUPPORTED_HASH
        )

    backend.openssl_assert(
        lib.EVP_PKEY_CTX_set_rsa_padding(ctx, padding_enum) > 0
    )

    if not isinstance(padding, PSS):
        return ctx

    salt_len = _get_rsa_pss_salt_length(padding, key, algorithm)
    backend.openssl_assert(
        lib.EVP_PKEY_CTX_set_rsa_pss_saltlen(ctx, salt_len) > 0
    )
    mask_md = backend._evp_md_non_null_from_algorithm(padding._mgf._algorithm)
    backend.openssl_assert(lib.EVP_PKEY_CTX_set_rsa_mgf1_md(ctx, mask_md) > 0)
    return ctx
219
220
def _rsa_sig_sign(backend, padding, algorithm, private_key, data):
    """Sign the pre-computed digest ``data`` with ``private_key``.

    Args:
        backend: OpenSSL backend exposing ``_lib``/``_ffi``.
        padding: PKCS1v15 or PSS padding object.
        algorithm: hash algorithm that produced ``data``.
        private_key: key object carrying an ``_evp_pkey``.
        data: the message digest bytes to sign.

    Returns:
        The signature bytes.

    Raises:
        ValueError: if OpenSSL reports that the PSS salt or the digest is
            too large for the key.
    """
    pkey_ctx = _rsa_sig_setup(
        backend, padding, algorithm, private_key, data,
        backend._lib.EVP_PKEY_sign_init
    )
    buflen = backend._ffi.new("size_t *")
    # Calling with a NULL output buffer asks OpenSSL for the buffer size
    # required for the signature.
    res = backend._lib.EVP_PKEY_sign(
        pkey_ctx,
        backend._ffi.NULL,
        buflen,
        data,
        len(data)
    )
    backend.openssl_assert(res == 1)
    buf = backend._ffi.new("unsigned char[]", buflen[0])
    res = backend._lib.EVP_PKEY_sign(
        pkey_ctx, buf, buflen, data, len(data))
    if res != 1:
        errors = backend._consume_errors()
        # Assert there is an error to inspect rather than failing with an
        # IndexError on an empty queue.
        backend.openssl_assert(errors)
        backend.openssl_assert(errors[0].lib == backend._lib.ERR_LIB_RSA)
        if (
            errors[0].reason ==
            backend._lib.RSA_R_DATA_TOO_LARGE_FOR_KEY_SIZE
        ):
            reason = ("Salt length too long for key size. Try using "
                      "MAX_LENGTH instead.")
        else:
            backend.openssl_assert(
                errors[0].reason ==
                backend._lib.RSA_R_DIGEST_TOO_BIG_FOR_RSA_KEY
            )
            reason = "Digest too large for key size. Use a larger key."
        raise ValueError(reason)

    # The second call updates buflen with the number of bytes actually
    # written, so return exactly that many.
    return backend._ffi.buffer(buf)[:buflen[0]]
256
257
def _rsa_sig_verify(backend, padding, algorithm, public_key, signature, data):
    """Check ``signature`` over the pre-computed digest ``data``.

    Returns ``None`` on success.

    Raises:
        InvalidSignature: if OpenSSL reports the signature does not match.
    """
    verify_ctx = _rsa_sig_setup(
        backend, padding, algorithm, public_key, data,
        backend._lib.EVP_PKEY_verify_init
    )
    outcome = backend._lib.EVP_PKEY_verify(
        verify_ctx, signature, len(signature), data, len(data)
    )
    # A negative result is an internal OpenSSL error, not a bad signature,
    # so it is treated as a hard failure rather than InvalidSignature.
    backend.openssl_assert(outcome >= 0)
    if outcome != 0:
        return None
    backend._consume_errors()
    raise InvalidSignature
273
274
@utils.register_interface(AsymmetricSignatureContext)
class _RSASignatureContext(object):
    """Incremental RSA signing context returned by ``_RSAPrivateKey.signer``.

    Message bytes passed to :meth:`update` are hashed with ``algorithm``.
    :meth:`finalize` signs the resulting digest.
    """

    def __init__(self, backend, private_key, padding, algorithm):
        self._backend = backend
        self._private_key = private_key

        # _rsa_sig_setup re-validates these when finalize() runs. The check
        # is repeated up front so that an invalid padding/algorithm pair
        # still fails at construction, which this API has always done.
        _rsa_sig_determine_padding(backend, private_key, padding, algorithm)
        self._padding = padding
        self._algorithm = algorithm
        self._hash_ctx = hashes.Hash(algorithm, backend)

    def update(self, data):
        """Feed more message ``data`` into the running hash."""
        self._hash_ctx.update(data)

    def finalize(self):
        """Finish hashing and return the signature over the digest."""
        digest = self._hash_ctx.finalize()
        return _rsa_sig_sign(
            self._backend,
            self._padding,
            self._algorithm,
            self._private_key,
            digest,
        )
300
301
@utils.register_interface(AsymmetricVerificationContext)
class _RSAVerificationContext(object):
    """Incremental RSA verification context returned by
    ``_RSAPublicKey.verifier``.

    Message bytes passed to :meth:`update` are hashed with ``algorithm``.
    :meth:`verify` checks ``signature`` against the resulting digest.
    """

    def __init__(self, backend, public_key, signature, padding, algorithm):
        self._backend = backend
        self._public_key = public_key
        self._signature = signature
        # _rsa_sig_setup re-validates these when verify() runs. The check is
        # repeated here so that an invalid padding/algorithm pair still
        # fails at construction, which this API has always done.
        _rsa_sig_determine_padding(backend, public_key, padding, algorithm)
        self._padding = padding
        self._algorithm = algorithm
        self._hash_ctx = hashes.Hash(self._algorithm, self._backend)

    def update(self, data):
        """Feed more message ``data`` into the running hash."""
        self._hash_ctx.update(data)

    def verify(self):
        """Finish hashing and verify the signature over the digest.

        Raises:
            InvalidSignature: if the signature does not match.
        """
        return _rsa_sig_verify(
            self._backend,
            self._padding,
            self._algorithm,
            self._public_key,
            self._signature,
            self._hash_ctx.finalize()
        )
330
331
@utils.register_interface(RSAPrivateKeyWithSerialization)
class _RSAPrivateKey(object):
    """RSA private key backed by OpenSSL.

    Stores the underlying ``RSA *`` (``rsa_cdata``) and the ``EVP_PKEY *``
    (``evp_pkey``) built from it. The key size in bits is read from the
    modulus once, at construction.
    """

    def __init__(self, backend, rsa_cdata, evp_pkey):
        self._backend = backend
        self._rsa_cdata = rsa_cdata
        self._evp_pkey = evp_pkey

        ffi = backend._ffi
        modulus = ffi.new("BIGNUM **")
        # Only n is needed here; e and d are skipped by passing NULL.
        backend._lib.RSA_get0_key(rsa_cdata, modulus, ffi.NULL, ffi.NULL)
        backend.openssl_assert(modulus[0] != ffi.NULL)
        self._key_size = backend._lib.BN_num_bits(modulus[0])

    key_size = utils.read_only_property("_key_size")

    def signer(self, padding, algorithm):
        """Return a deprecated incremental signing context."""
        _warn_sign_verify_deprecated()
        _check_not_prehashed(algorithm)
        return _RSASignatureContext(self._backend, self, padding, algorithm)

    def decrypt(self, ciphertext, padding):
        """Decrypt ``ciphertext``, which must be exactly key-size bytes."""
        expected_len = int(math.ceil(self.key_size / 8.0))
        if len(ciphertext) != expected_len:
            raise ValueError("Ciphertext length must be equal to key size.")
        return _enc_dec_rsa(self._backend, self, ciphertext, padding)

    def public_key(self):
        """Return the matching public key as an ``_RSAPublicKey``."""
        backend = self._backend
        lib = backend._lib
        dup = lib.RSAPublicKey_dup(self._rsa_cdata)
        backend.openssl_assert(dup != backend._ffi.NULL)
        dup = backend._ffi.gc(dup, lib.RSA_free)
        blinding = lib.RSA_blinding_on(dup, backend._ffi.NULL)
        backend.openssl_assert(blinding == 1)
        evp_pkey = backend._rsa_cdata_to_evp_pkey(dup)
        return _RSAPublicKey(backend, dup, evp_pkey)

    def private_numbers(self):
        """Return every private-key component as ``RSAPrivateNumbers``."""
        backend = self._backend
        ffi = backend._ffi
        lib = backend._lib
        n, e, d, p, q, dmp1, dmq1, iqmp = [
            ffi.new("BIGNUM **") for _ in range(8)
        ]

        lib.RSA_get0_key(self._rsa_cdata, n, e, d)
        for slot in (n, e, d):
            backend.openssl_assert(slot[0] != ffi.NULL)

        lib.RSA_get0_factors(self._rsa_cdata, p, q)
        for slot in (p, q):
            backend.openssl_assert(slot[0] != ffi.NULL)

        lib.RSA_get0_crt_params(self._rsa_cdata, dmp1, dmq1, iqmp)
        for slot in (dmp1, dmq1, iqmp):
            backend.openssl_assert(slot[0] != ffi.NULL)

        to_int = backend._bn_to_int
        return rsa.RSAPrivateNumbers(
            p=to_int(p[0]),
            q=to_int(q[0]),
            d=to_int(d[0]),
            dmp1=to_int(dmp1[0]),
            dmq1=to_int(dmq1[0]),
            iqmp=to_int(iqmp[0]),
            public_numbers=rsa.RSAPublicNumbers(
                e=to_int(e[0]),
                n=to_int(n[0]),
            )
        )

    def private_bytes(self, encoding, format, encryption_algorithm):
        """Serialize the private key via the backend's common serializer."""
        return self._backend._private_key_bytes(
            encoding,
            format,
            encryption_algorithm,
            self._evp_pkey,
            self._rsa_cdata
        )

    def sign(self, data, padding, algorithm):
        """Sign ``data``, hashing it first unless it is already prehashed."""
        digest, hash_alg = _calculate_digest_and_algorithm(
            self._backend, data, algorithm
        )
        return _rsa_sig_sign(self._backend, padding, hash_alg, self, digest)
419
420
@utils.register_interface(RSAPublicKeyWithSerialization)
class _RSAPublicKey(object):
    """RSA public key backed by OpenSSL.

    Stores the underlying ``RSA *`` (``rsa_cdata``) and the ``EVP_PKEY *``
    (``evp_pkey``) built from it. The key size in bits is read from the
    modulus once, at construction.
    """

    def __init__(self, backend, rsa_cdata, evp_pkey):
        self._backend = backend
        self._rsa_cdata = rsa_cdata
        self._evp_pkey = evp_pkey

        ffi = backend._ffi
        modulus = ffi.new("BIGNUM **")
        # Only n is needed here; e and d are skipped by passing NULL.
        backend._lib.RSA_get0_key(rsa_cdata, modulus, ffi.NULL, ffi.NULL)
        backend.openssl_assert(modulus[0] != ffi.NULL)
        self._key_size = backend._lib.BN_num_bits(modulus[0])

    key_size = utils.read_only_property("_key_size")

    def verifier(self, signature, padding, algorithm):
        """Return a deprecated incremental verification context."""
        _warn_sign_verify_deprecated()
        utils._check_bytes("signature", signature)

        _check_not_prehashed(algorithm)
        return _RSAVerificationContext(
            self._backend, self, signature, padding, algorithm
        )

    def encrypt(self, plaintext, padding):
        """Encrypt ``plaintext`` with this public key."""
        return _enc_dec_rsa(self._backend, self, plaintext, padding)

    def public_numbers(self):
        """Return ``(e, n)`` as ``RSAPublicNumbers``."""
        backend = self._backend
        ffi = backend._ffi
        modulus = ffi.new("BIGNUM **")
        exponent = ffi.new("BIGNUM **")
        backend._lib.RSA_get0_key(
            self._rsa_cdata, modulus, exponent, ffi.NULL
        )
        backend.openssl_assert(modulus[0] != ffi.NULL)
        backend.openssl_assert(exponent[0] != ffi.NULL)
        return rsa.RSAPublicNumbers(
            e=backend._bn_to_int(exponent[0]),
            n=backend._bn_to_int(modulus[0]),
        )

    def public_bytes(self, encoding, format):
        """Serialize the public key via the backend's common serializer."""
        return self._backend._public_key_bytes(
            encoding,
            format,
            self,
            self._evp_pkey,
            self._rsa_cdata
        )

    def verify(self, signature, data, padding, algorithm):
        """Verify ``signature`` over ``data``; raises InvalidSignature."""
        digest, hash_alg = _calculate_digest_and_algorithm(
            self._backend, data, algorithm
        )
        return _rsa_sig_verify(
            self._backend, padding, hash_alg, self, signature, digest
        )
479