1# This file is dual licensed under the terms of the Apache License, Version
2# 2.0, and the BSD License. See the LICENSE file in the root of this repository
3# for complete details.
4
5from __future__ import absolute_import, division, print_function
6
7import abc
8try:
9    # Only available in math in 3.5+
10    from math import gcd
11except ImportError:
12    from fractions import gcd
13
14import six
15
16from cryptography import utils
17from cryptography.exceptions import UnsupportedAlgorithm, _Reasons
18from cryptography.hazmat.backends.interfaces import RSABackend
19
20
@six.add_metaclass(abc.ABCMeta)
class RSAPrivateKey(object):
    """
    Abstract interface that every RSA private key implementation provides.
    """

    @abc.abstractmethod
    def decrypt(self, ciphertext, padding):
        """
        Recover the plaintext from ``ciphertext`` using ``padding``.
        """

    @abc.abstractproperty
    def key_size(self):
        """
        Size of the public modulus, measured in bits.
        """

    @abc.abstractmethod
    def public_key(self):
        """
        Return the RSAPublicKey that pairs with this private key.
        """

    @abc.abstractmethod
    def sign(self, data, padding, algorithm):
        """
        Produce a signature over ``data``.
        """

    @abc.abstractmethod
    def signer(self, padding, algorithm):
        """
        Return an AsymmetricSignatureContext that can be used to sign data.
        """
52
53
@six.add_metaclass(abc.ABCMeta)
class RSAPrivateKeyWithSerialization(RSAPrivateKey):
    """
    RSA private key interface that can additionally expose its numbers and
    serialize itself.
    """

    @abc.abstractmethod
    def private_bytes(self, encoding, format, encryption_algorithm):
        """
        Serialize the key and return the resulting bytes.
        """

    @abc.abstractmethod
    def private_numbers(self):
        """
        Return the RSAPrivateNumbers describing this key.
        """
67
68
@six.add_metaclass(abc.ABCMeta)
class RSAPublicKey(object):
    """
    Abstract interface that every RSA public key implementation provides.
    """

    @abc.abstractmethod
    def encrypt(self, plaintext, padding):
        """
        Encrypt ``plaintext`` using ``padding`` and return the ciphertext.
        """

    @abc.abstractproperty
    def key_size(self):
        """
        Size of the public modulus, measured in bits.
        """

    @abc.abstractmethod
    def public_bytes(self, encoding, format):
        """
        Serialize the key and return the resulting bytes.
        """

    @abc.abstractmethod
    def public_numbers(self):
        """
        Return the RSAPublicNumbers describing this key.
        """

    @abc.abstractmethod
    def verifier(self, signature, padding, algorithm):
        """
        Return an AsymmetricVerificationContext for checking ``signature``.
        """

    @abc.abstractmethod
    def verify(self, signature, data, padding, algorithm):
        """
        Check that ``signature`` is valid for ``data``.
        """
106
107
# RSAPublicKey already declares public_numbers() and public_bytes(), so the
# "WithSerialization" name is a plain alias of it rather than a subclass
# (contrast RSAPrivateKeyWithSerialization, which adds abstract methods).
# NOTE(review): presumably retained so code using this name keeps working --
# confirm before removing.
RSAPublicKeyWithSerialization = RSAPublicKey
109
110
def generate_private_key(public_exponent, key_size, backend):
    """
    Create a new RSA private key using ``backend``.

    ``public_exponent`` and ``key_size`` are checked by
    ``_verify_rsa_parameters`` (which raises ``ValueError``) before the
    backend is asked to generate anything.

    :raises UnsupportedAlgorithm: if ``backend`` is not an ``RSABackend``.
    """
    if isinstance(backend, RSABackend):
        _verify_rsa_parameters(public_exponent, key_size)
        return backend.generate_rsa_private_key(public_exponent, key_size)

    raise UnsupportedAlgorithm(
        "Backend object does not implement RSABackend.",
        _Reasons.BACKEND_MISSING_INTERFACE
    )
120
121
122def _verify_rsa_parameters(public_exponent, key_size):
123    if public_exponent < 3:
124        raise ValueError("public_exponent must be >= 3.")
125
126    if public_exponent & 1 == 0:
127        raise ValueError("public_exponent must be odd.")
128
129    if key_size < 512:
130        raise ValueError("key_size must be at least 512-bits.")
131
132
133def _check_private_key_components(p, q, private_exponent, dmp1, dmq1, iqmp,
134                                  public_exponent, modulus):
135    if modulus < 3:
136        raise ValueError("modulus must be >= 3.")
137
138    if p >= modulus:
139        raise ValueError("p must be < modulus.")
140
141    if q >= modulus:
142        raise ValueError("q must be < modulus.")
143
144    if dmp1 >= modulus:
145        raise ValueError("dmp1 must be < modulus.")
146
147    if dmq1 >= modulus:
148        raise ValueError("dmq1 must be < modulus.")
149
150    if iqmp >= modulus:
151        raise ValueError("iqmp must be < modulus.")
152
153    if private_exponent >= modulus:
154        raise ValueError("private_exponent must be < modulus.")
155
156    if public_exponent < 3 or public_exponent >= modulus:
157        raise ValueError("public_exponent must be >= 3 and < modulus.")
158
159    if public_exponent & 1 == 0:
160        raise ValueError("public_exponent must be odd.")
161
162    if dmp1 & 1 == 0:
163        raise ValueError("dmp1 must be odd.")
164
165    if dmq1 & 1 == 0:
166        raise ValueError("dmq1 must be odd.")
167
168    if p * q != modulus:
169        raise ValueError("p*q must equal modulus.")
170
171
172def _check_public_key_components(e, n):
173    if n < 3:
174        raise ValueError("n must be >= 3.")
175
176    if e < 3 or e >= n:
177        raise ValueError("e must be >= 3 and < n.")
178
179    if e & 1 == 0:
180        raise ValueError("e must be odd.")
181
182
183def _modinv(e, m):
184    """
185    Modular Multiplicative Inverse. Returns x such that: (x*e) mod m == 1
186    """
187    x1, y1, x2, y2 = 1, 0, 0, 1
188    a, b = e, m
189    while b > 0:
190        q, r = divmod(a, b)
191        xn, yn = x1 - q * x2, y1 - q * y2
192        a, b, x1, y1, x2, y2 = b, r, x2, y2, xn, yn
193    return x1 % m
194
195
def rsa_crt_iqmp(p, q):
    """
    Return the CRT coefficient ``iqmp``, i.e. the inverse of ``q`` modulo
    ``p`` (``(q ** -1) % p``), for the RSA primes ``p`` and ``q``.
    """
    iqmp = _modinv(q, p)
    return iqmp
201
202
def rsa_crt_dmp1(private_exponent, p):
    """
    Return the CRT exponent ``dmp1``: the private exponent ``d`` reduced
    modulo ``p - 1``.
    """
    p_minus_one = p - 1
    return private_exponent % p_minus_one
209
210
def rsa_crt_dmq1(private_exponent, q):
    """
    Return the CRT exponent ``dmq1``: the private exponent ``d`` reduced
    modulo ``q - 1``.
    """
    q_minus_one = q - 1
    return private_exponent % q_minus_one
217
218
# Exclusive upper bound on the candidate base ``a`` that
# rsa_recover_prime_factors tries when searching for a factor of n.
# Candidates start at 2 and advance by 2, so only about half this many bases
# are actually tried (2, 4, ..., 998 for the value below).
_MAX_RECOVERY_ATTEMPTS = 1000
223
224
def rsa_recover_prime_factors(n, e, d):
    """
    Compute factors p and q from the private exponent d. We assume that n has
    no more than two factors. This function is adapted from code in PyCrypto.

    :returns: a ``(p, q)`` tuple ordered so that ``p >= q``.
    :raises ValueError: if ``d * e - 1`` is not positive, or if no factor is
        found among the candidate bases bounded by ``_MAX_RECOVERY_ATTEMPTS``.
    """
    # See 8.2.2(i) in Handbook of Applied Cryptography.
    ktot = d * e - 1
    if ktot <= 0:
        # For a valid key d*e - 1 is a positive multiple of phi(n). Reject
        # anything else up front: ktot == 0 (e.g. d == e == 1) would make the
        # power-of-two stripping loop below spin forever, because 0 stays
        # even no matter how often it is halved.
        raise ValueError("Unable to compute factors p and q from exponent d.")
    # The quantity d*e-1 is a multiple of phi(n), even,
    # and can be represented as t*2^s. Strip the factors of two to get t.
    t = ktot
    while t % 2 == 0:
        t = t // 2
    # Try candidate bases a = 2, 4, 6, ... in turn.
    # The algorithm is non-deterministic, but there is a 50% chance
    # any candidate a leads to successful factoring.
    # See "Digitalized Signatures and Public Key Functions as Intractable
    # as Factorization", M. Rabin, 1979
    spotted = False
    a = 2
    while not spotted and a < _MAX_RECOVERY_ATTEMPTS:
        k = t
        # Cycle through all values a^{t*2^i}=a^k
        while k < ktot:
            cand = pow(a, k, n)
            # Check if a^k is a non-trivial square root of 1 (mod n).
            if cand != 1 and cand != (n - 1) and pow(cand, 2, n) == 1:
                # We have found a number such that (cand-1)(cand+1)=0 (mod n)
                # with neither term divisible by n, so each term shares a
                # non-trivial factor with n.
                p = gcd(cand + 1, n)
                spotted = True
                break
            k *= 2
        # This value was not any good... let's try another!
        a += 2
    if not spotted:
        raise ValueError("Unable to compute factors p and q from exponent d.")
    # Found !
    q, r = divmod(n, p)
    # p came from gcd(..., n), so it always divides n exactly; this is an
    # internal invariant, not input validation.
    assert r == 0
    p, q = sorted((p, q), reverse=True)
    return (p, q)
266
267
class RSAPrivateNumbers(object):
    """
    Container for the integer components of an RSA private key.

    Holds the primes ``p`` and ``q``, the private exponent ``d``, the CRT
    values ``dmp1``, ``dmq1`` and ``iqmp``, and the matching
    ``RSAPublicNumbers``. The constructor checks only the argument types;
    the numeric values themselves are not validated here.
    """

    def __init__(self, p, q, d, dmp1, dmq1, iqmp,
                 public_numbers):
        if not all(
            isinstance(value, six.integer_types)
            for value in (p, q, d, dmp1, dmq1, iqmp)
        ):
            raise TypeError(
                "RSAPrivateNumbers p, q, d, dmp1, dmq1, iqmp arguments must"
                " all be integers."
            )

        if not isinstance(public_numbers, RSAPublicNumbers):
            raise TypeError(
                "RSAPrivateNumbers public_numbers must be an RSAPublicNumbers"
                " instance."
            )

        self._p = p
        self._q = q
        self._d = d
        self._dmp1 = dmp1
        self._dmq1 = dmq1
        self._iqmp = iqmp
        self._public_numbers = public_numbers

    p = utils.read_only_property("_p")
    q = utils.read_only_property("_q")
    d = utils.read_only_property("_d")
    dmp1 = utils.read_only_property("_dmp1")
    dmq1 = utils.read_only_property("_dmq1")
    iqmp = utils.read_only_property("_iqmp")
    public_numbers = utils.read_only_property("_public_numbers")

    def private_key(self, backend):
        """
        Build a private key object from these numbers via ``backend``.
        """
        return backend.load_rsa_private_numbers(self)

    def __eq__(self, other):
        if not isinstance(other, RSAPrivateNumbers):
            return NotImplemented

        return (
            self.p == other.p and
            self.q == other.q and
            self.d == other.d and
            self.dmp1 == other.dmp1 and
            self.dmq1 == other.dmq1 and
            self.iqmp == other.iqmp and
            self.public_numbers == other.public_numbers
        )

    def __ne__(self, other):
        return not self == other

    def __hash__(self):
        # Hashes the same fields that __eq__ compares, keeping the two
        # consistent.
        return hash((
            self.p,
            self.q,
            self.d,
            self.dmp1,
            self.dmq1,
            self.iqmp,
            self.public_numbers,
        ))
336
337
class RSAPublicNumbers(object):
    """
    Container for the integer components of an RSA public key: the public
    exponent ``e`` and the modulus ``n``.
    """

    def __init__(self, e, n):
        if any(not isinstance(part, six.integer_types) for part in (e, n)):
            raise TypeError("RSAPublicNumbers arguments must be integers.")

        self._e = e
        self._n = n

    e = utils.read_only_property("_e")
    n = utils.read_only_property("_n")

    def public_key(self, backend):
        """
        Build a public key object from these numbers via ``backend``.
        """
        return backend.load_rsa_public_numbers(self)

    def __repr__(self):
        return "<RSAPublicNumbers(e={0.e}, n={0.n})>".format(self)

    def __eq__(self, other):
        if isinstance(other, RSAPublicNumbers):
            return (self.e, self.n) == (other.e, other.n)
        return NotImplemented

    def __ne__(self, other):
        return not (self == other)

    def __hash__(self):
        # Same fields as __eq__, so equal instances hash equally.
        key = (self.e, self.n)
        return hash(key)
369