1## This file is part of Scapy
2## Copyright (C) 2008 Arnaud Ebalard <arno@natisbad.org>
3##   2015, 2016, 2017 Maxence Tury <maxence.tury@ssi.gouv.fr>
4## This program is published under a GPLv2 license
5
6"""
7PKCS #1 methods as defined in RFC 3447.
8
9We cannot rely solely on the cryptography library, because the openssl package
10used by the cryptography library may not implement the md5-sha1 hash, as with
11Ubuntu or OSX. This is why we reluctantly keep some legacy crypto here.
12"""
13
14from __future__ import absolute_import
15from scapy.compat import *
16
17from scapy.config import conf, crypto_validator
18if conf.crypto_valid:
19    from cryptography import utils
20    from cryptography.exceptions import InvalidSignature, UnsupportedAlgorithm
21    from cryptography.hazmat.backends import default_backend
22    from cryptography.hazmat.primitives import hashes
23    from cryptography.hazmat.primitives.asymmetric import padding
24    from cryptography.hazmat.primitives.hashes import HashAlgorithm
25
26from scapy.utils import randstring, zerofree_randstring, strxor, strand
27from scapy.error import warning
28
29
30#####################################################################
31# Some helpers
32#####################################################################
33
def pkcs_os2ip(s):
    """
    OS2IP conversion function from RFC 3447.

    Input : s        octet string to be converted
    Output: n        corresponding nonnegative integer

    The empty octet string converts to 0 (an empty sum in the RFC
    definition); int() would otherwise reject the empty hex string.
    """
    hexstr = bytes_hex(s)
    if not hexstr:
        return 0
    return int(hexstr, 16)
42
def pkcs_i2osp(n, sLen):
    """
    I2OSP conversion function from RFC 3447.
    The length parameter allows the function to perform the padding needed.

    Input : n        nonnegative integer to be converted
            sLen     intended length of the resulting octet string
    Output: s        corresponding octet string

    Note that the caller is responsible for providing a sufficient sLen:
    unlike RFC 3447, no "integer too large" error is raised when
    n >= 256**sLen; the result is then simply longer than sLen octets.
    """
    # Minimal big-endian hex form of n. 0 has no significant digit, so that
    # I2OSP(0, 0) is the empty string and I2OSP(0, k) is k zero octets.
    hexstr = "%x" % n if n else ""
    # Left-pad to 2*sLen digits, and in any case to an even digit count:
    # hex_bytes() cannot decode an odd-length string, which a plain
    # fixed-width format produces when n needs more than sLen octets and
    # has an odd number of hex digits.
    width = max(2 * sLen, len(hexstr) + len(hexstr) % 2)
    return hex_bytes(hexstr.zfill(width))
57
def pkcs_ilen(n):
    """
    Return the number of octets needed to represent integer n with
    pkcs_i2osp, i.e. ceil(log256(n + 1)) for positive n.

    Zero and negative values need no octet and yield 0.
    """
    if n <= 0:
        return 0
    # Round the bit count up to a whole number of octets.
    return (n.bit_length() + 7) // 8
68
@crypto_validator
def _legacy_pkcs1_v1_5_encode_md5_sha1(M, emLen):
    """
    Legacy EMSA-PKCS1-v1_5 encoding using an MD5-SHA1 digest.

    The MD5 and SHA-1 digests of raw(M) are concatenated (36 octets) and
    appended, without any DigestInfo wrapper, to the header
    0x00 0x01 || PS || 0x00, where PS is a run of 0xff octets filling
    the message up to emLen octets.

    Returns None (after a warning) when emLen < 36 + 11, i.e. when there is
    no room for the digest plus the mandatory minimal padding.
    """
    data = raw(M)
    digests = []
    for algo_name in ("md5", "sha1"):
        ctx = hashes.Hash(_get_hash(algo_name), backend=default_backend())
        ctx.update(data)
        digests.append(ctx.finalize())
    digest = b"".join(digests)

    if emLen < 36 + 11:
        warning("pkcs_emsa_pkcs1_v1_5_encode: "
                "intended encoded message length too short")
        return None

    filler = b'\xff' * (emLen - 36 - 3)
    return b"".join((b'\x00\x01', filler, b'\x00', digest))
86
87
88#####################################################################
89# Hash and padding helpers
90#####################################################################
91
_get_hash = None
if conf.crypto_valid:

    # First, we add the "md5-sha1" hash from openssl to python-cryptography.
    # Subclassing HashAlgorithm directly (rather than decorating a plain
    # class with utils.register_interface, which newer cryptography
    # releases deprecate) yields a genuine HashAlgorithm on every version.
    class MD5_SHA1(HashAlgorithm):
        """Concatenated MD5 and SHA-1 digests (16 + 20 = 36 octets)."""
        name = "md5-sha1"
        digest_size = 36
        block_size = 64

    # Hash names accepted by this module, mapped to cryptography classes.
    _hashes = {
            "md5"      : hashes.MD5,
            "sha1"     : hashes.SHA1,
            "sha224"   : hashes.SHA224,
            "sha256"   : hashes.SHA256,
            "sha384"   : hashes.SHA384,
            "sha512"   : hashes.SHA512,
            "md5-sha1" : MD5_SHA1
            }

    def _get_hash(hashStr):
        """
        Return a new instance of the hash algorithm named hashStr.

        Raises KeyError if hashStr is not one of the keys of _hashes.
        """
        try:
            return _hashes[hashStr]()
        except KeyError:
            raise KeyError("Unknown hash function %s" % hashStr)


    def _get_padding(padStr, mgf=padding.MGF1, h=None, label=None):
        """
        Build the cryptography padding object selected by padStr.

        padStr: "pkcs" (PKCS #1 v1.5), "pss" or "oaep".
        mgf:    mask generation function class, instantiated with h.
        h:      hash algorithm *instance* (padding.MGF1 and padding.OAEP
                reject hash classes); defaults to a SHA-256 instance.
        label:  OAEP label.

        Returns None, after a warning, for an unknown padStr.
        """
        if h is None:
            h = hashes.SHA256()
        if padStr == "pkcs":
            return padding.PKCS1v15()
        elif padStr == "pss":
            # Use the digest size rather than padding.PSS.MAX_LENGTH: TLS 1.3
            # (RFC 8446, section 4.2.3) requires the salt length to equal
            # the digest output length.
            return padding.PSS(mgf=mgf(h), salt_length=h.digest_size)
        elif padStr == "oaep":
            return padding.OAEP(mgf=mgf(h), algorithm=h, label=label)
        else:
            warning("Key.encrypt(): Unknown padding type (%s)", padStr)
            return None
131
132
133#####################################################################
134# Asymmetric Cryptography wrappers
135#####################################################################
136
# Make sure that default values are consistent across the whole TLS module,
# lest they be explicitly set to None between cert.py and pkcs1.py.
139
class _EncryptAndVerifyRSA(object):
    """
    Mixin providing RSA public-key operations: encryption and signature
    verification.

    The host object must provide `pubkey` (a python-cryptography RSA public
    key). The legacy MD5-SHA1 path also reads `_modulus`, `_pubExp` and
    `_modulusLen` (presumably the modulus size in bits, since it is turned
    into an octet count below).
    """

    @crypto_validator
    def encrypt(self, m, t="pkcs", h="sha256", mgf=None, L=None):
        """
        Encrypt m with padding t ("pkcs", "pss" or "oaep"), hash named h,
        mask generation function mgf (default padding.MGF1) and OAEP label L.
        """
        mgf = mgf or padding.MGF1
        h = _get_hash(h)
        pad = _get_padding(t, mgf, h, L)
        return self.pubkey.encrypt(m, pad)

    @crypto_validator
    def verify(self, M, S, t="pkcs", h="sha256", mgf=None, L=None):
        """
        Verify signature S over message M.

        t selects the padding ("pkcs" or "pss"), h names the hash, mgf
        defaults to padding.MGF1. Returns True if the signature is valid,
        False otherwise.

        If the backend does not support the hash, PKCS #1 v1.5 with MD5-SHA1
        is verified in pure Python; any other unsupported combination raises
        UnsupportedAlgorithm.
        """
        M = raw(M)
        mgf = mgf or padding.MGF1
        h = _get_hash(h)
        pad = _get_padding(t, mgf, h, L)
        try:
            self.pubkey.verify(S, M, pad, h)
        except InvalidSignature:
            return False
        except UnsupportedAlgorithm:
            # h is a hash instance at this point: compare its name. Only the
            # PKCS #1 v1.5 + MD5-SHA1 combination has a legacy fallback.
            if t != "pkcs" or h.name != "md5-sha1":
                raise UnsupportedAlgorithm("RSA verification with %s"
                                           % h.name)
            # The legacy helper reports its verdict through its return value
            # (it never raises InvalidSignature), so it must be propagated.
            return bool(self._legacy_verify_md5_sha1(M, S))
        return True

    def _legacy_verify_md5_sha1(self, M, S):
        """
        Pure-Python RSASSA-PKCS1-v1_5 verification with an MD5-SHA1 digest.

        Returns True when S is a valid signature of M, False otherwise.
        """
        # k is the modulus length in octets, rounded up as in RFC 3447.
        k = (self._modulusLen + 7) // 8
        if len(S) != k:
            warning("invalid signature (len(S) != k)")
            return False
        s = pkcs_os2ip(S)
        n = self._modulus
        if isinstance(s, int) and six.PY2:
            s = long(s)
        if (six.PY2 and not isinstance(s, long)) or s > n-1:
            warning("Key._rsaep() expects a long between 0 and n-1")
            return False
        m = pow(s, self._pubExp, n)
        EM = pkcs_i2osp(m, k)
        EMPrime = _legacy_pkcs1_v1_5_encode_md5_sha1(M, k)
        if EMPrime is None:
            warning("Key._rsassa_pkcs1_v1_5_verify(): unable to encode.")
            return False
        return EM == EMPrime
185
186
class _DecryptAndSignRSA(object):
    """
    Mixin providing RSA private-key operations: decryption and signing.

    The host object must provide `key` (a python-cryptography RSA private
    key). The legacy MD5-SHA1 path also reads `_modulus` and `_modulusLen`
    (presumably the modulus size in bits, since it is turned into an
    octet count below).
    """

    @crypto_validator
    def decrypt(self, C, t="pkcs", h="sha256", mgf=None, L=None):
        """
        Decrypt C with padding t ("pkcs", "pss" or "oaep"), hash named h,
        mask generation function mgf (default padding.MGF1) and OAEP label L.
        """
        mgf = mgf or padding.MGF1
        h = _get_hash(h)
        pad = _get_padding(t, mgf, h, L)
        return self.key.decrypt(C, pad)

    @crypto_validator
    def sign(self, M, t="pkcs", h="sha256", mgf=None, L=None):
        """
        Sign message M with padding t ("pkcs" or "pss") and hash named h.

        If the backend does not support the hash, a PKCS #1 v1.5 / MD5-SHA1
        signature is computed in pure Python; any other unsupported
        combination raises UnsupportedAlgorithm.
        """
        M = raw(M)
        mgf = mgf or padding.MGF1
        h = _get_hash(h)
        pad = _get_padding(t, mgf, h, L)
        try:
            return self.key.sign(M, pad, h)
        except UnsupportedAlgorithm:
            # h is a hash instance at this point: compare its name. Only the
            # PKCS #1 v1.5 + MD5-SHA1 combination has a legacy fallback;
            # anything else must not silently produce an MD5-SHA1 signature.
            if t != "pkcs" or h.name != "md5-sha1":
                raise UnsupportedAlgorithm("RSA signature with %s" % h.name)
            return self._legacy_sign_md5_sha1(M)

    def _legacy_sign_md5_sha1(self, M):
        """
        Pure-Python RSASSA-PKCS1-v1_5 signature with an MD5-SHA1 digest.

        Returns the k-octet signature, or None (after a warning) if encoding
        fails.
        """
        M = raw(M)
        # k is the modulus length in octets, rounded up as in RFC 3447.
        k = (self._modulusLen + 7) // 8
        EM = _legacy_pkcs1_v1_5_encode_md5_sha1(M, k)
        if EM is None:
            warning("Key._rsassa_pkcs1_v1_5_sign(): unable to encode")
            return None
        m = pkcs_os2ip(EM)
        n = self._modulus
        if isinstance(m, int) and six.PY2:
            m = long(m)
        if (six.PY2 and not isinstance(m, long)) or m > n-1:
            warning("Key._rsaep() expects a long between 0 and n-1")
            return None
        privExp = self.key.private_numbers().d
        s = pow(m, privExp, n)
        return pkcs_i2osp(s, k)
226