1#!/usr/bin/env python2
2#
3# Copyright 2016 Dirkjan Ochtman.
4#
5# Permission to use, copy, modify, and/or distribute this software for any
6# purpose with or without fee is hereby granted, provided that the above
7# copyright notice and this permission notice appear in all copies.
8#
9# THE SOFTWARE IS PROVIDED "AS IS" AND THE AUTHORS DISCLAIM ALL WARRANTIES
10# WITH REGARD TO THIS SOFTWARE INCLUDING ALL IMPLIED WARRANTIES OF
11# MERCHANTABILITY AND FITNESS. IN NO EVENT SHALL THE AUTHORS BE LIABLE FOR ANY
12# SPECIAL, DIRECT, INDIRECT, OR CONSEQUENTIAL DAMAGES OR ANY DAMAGES
13# WHATSOEVER RESULTING FROM LOSS OF USE, DATA OR PROFITS, WHETHER IN AN ACTION
14# OF CONTRACT, NEGLIGENCE OR OTHER TORTIOUS ACTION, ARISING OUT OF OR IN
15# CONNECTION WITH THE USE OR PERFORMANCE OF THIS SOFTWARE.
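'''
Script to generate *ring* test files for RSA signature test vectors (PKCS#1
v1.5 or PSS padding; signing or verification tests) from the NIST FIPS 186-4
test vectors. Takes as its single command-line argument the path to the test
vector file (tested with SigGen15_186-3.txt). The padding algorithm ("PSS" or
"15") and the test type ("Gen" for signing, "Ver" for verification) are
inferred from the input file's name.

Requires the cryptography library from pyca.
'''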
23
24from cryptography.hazmat.backends import default_backend
25from cryptography.hazmat.primitives import serialization, hashes
26from cryptography.hazmat.primitives.asymmetric import rsa, padding
27import hashlib
28import sys, copy
29import codecs
30
# For each digest name, half of the digest's output size in bits. main()
# compares this against twice the number of hex digits in a PSS `SaltVal`
# field (i.e. the salt size in bits divided by two), so a salt is accepted
# only when it is exactly as long as the digest output.
DIGEST_OUTPUT_LENGTHS = dict(
    SHA1=80,
    SHA256=128,
    SHA384=192,
    SHA512=256,
)

# When True, the debug() calls that pass this flag report why each skipped
# test case was skipped.
DEBUG = False
40
def debug(str, flag):
    '''Write the message `str` plus a newline to stderr and flush it
    immediately, but only when `flag` is true; otherwise do nothing.'''
    if not flag:
        return
    stream = sys.stderr
    stream.write(str + "\n")
    stream.flush()
45
def decode_hex(s):
    '''Decode the hex string `s` into raw bytes.

    Every byte must be written with two hex digits; odd-length input is
    rejected by the codec (see from_hex for a tolerant variant).
    '''
    hex_decoder = codecs.getdecoder("hex_codec")
    decoded, _consumed = hex_decoder(s)
    return decoded
49
def from_hex(hex):
    '''Decode a hex string into bytes, tolerating a missing leading "0".

    Some fields in the input files are encoded without a leading "0", but
    `decode_hex` requires every byte to be encoded with two hex digits, so an
    odd-length string is left-padded with a single "0" before decoding.
    '''
    if len(hex) % 2:
        hex = "0" + hex
    return decode_hex(hex)
54
def to_hex(bytes):
    '''Return the lowercase hex encoding of `bytes`, two digits per byte.

    The input is wrapped in `bytearray` so that iterating it yields integers
    on both Python 2 (where iterating a `str` yields 1-character strings,
    which the '{:02x}' format rejects) and Python 3.
    '''
    return ''.join('{:02x}'.format(b) for b in bytearray(bytes))
57
def reformat_hex(hex):
    '''Normalise a hex field from the input files for *ring*'s test format.

    Some fields in the input files are encoded without a leading "0", but the
    *ring* test framework requires every byte to be encoded with two hex
    digits. Round-tripping through bytes pads odd-length input with a leading
    zero and emits lowercase digits.
    '''
    raw = from_hex(hex)
    return to_hex(raw)
62
def parse(fn, last_field):
    '''Parse an input test vector file into a list of self-contained test cases.

    Blank lines and lines whose first non-whitespace character is "#"
    (comments) or "[" (section headers) are skipped. Every other line must
    have the form "name = value"; name and value are stripped of surrounding
    whitespace. Values accumulate in a dict, and each time `last_field` is
    seen that dict is appended as one case. The next case starts from a copy
    of it, so fields that are not repeated for every case carry over to the
    following cases until redefined. This depends on `last_field` being the
    last value in each test case.

    Raises ValueError for a non-blank, non-comment line that has no "=".
    '''
    cases = []
    cur = {}
    with open(fn) as f:
        for ln in f:
            stripped = ln.strip()
            # Skip blank lines, comments and "[...]" section headers, even
            # when they are indented.
            if not stripped or stripped[0] in '#[':
                continue
            name, val = stripped.split('=', 1)
            name = name.strip()
            cur[name] = val.strip()
            if name == last_field:
                cases.append(cur)
                # Continue from a copy so already-appended cases are never
                # mutated while shared fields still carry over.
                cur = dict(cur)
    return cases
81
def print_sign_test(case, n, e, d, padding_alg):
    '''Print one signing test case in the format used by *ring* test files.

    The RSA private key is rebuilt from the modulus `n`, public exponent `e`
    and private exponent `d`, serialized as a DER RSAPrivateKey, and printed
    together with the vector's digest name, message, (for PSS) salt and
    signature, followed by "Result = Pass" and a blank line.

    `case` is a dict produced by parse(); it must contain 'SHAAlg', 'Msg' and
    'S', plus 'SaltVal' when `padding_alg` is "PSS". `padding_alg` must be
    "PKCS#1 1.5" or "PSS"; any other value terminates the script with an
    error message on stderr and a non-zero exit status.

    Raises ValueError if, for PKCS#1 v1.5, the signature recomputed with the
    rebuilt key differs from the one in the test vector.
    '''
    # stdout is the generated test file, so report errors on stderr and exit
    # with a failure status instead of printing into the output.
    if padding_alg not in ('PKCS#1 1.5', 'PSS'):
        sys.exit("Invalid padding algorithm")

    # Recover the prime factors and CRT numbers.
    p, q = rsa.rsa_recover_prime_factors(n, e, d)
    # *ring* requires p > q. Don't rely on the order in which cryptography
    # returns the factors (it may differ between versions); normalise it.
    p, q = max(p, q), min(p, q)
    dmp1 = rsa.rsa_crt_dmp1(d, p)
    dmq1 = rsa.rsa_crt_dmq1(d, q)
    iqmp = rsa.rsa_crt_iqmp(p, q)

    # Create a private key instance.
    pub = rsa.RSAPublicNumbers(e, n)
    priv = rsa.RSAPrivateNumbers(p, q, d, dmp1, dmq1, iqmp, pub)
    key = priv.private_key(default_backend())

    # Some fields in the input files omit a leading "0"; from_hex and
    # reformat_hex tolerate that, whereas decode_hex rejects odd lengths.
    msg = from_hex(case['Msg'])
    expected_sig = reformat_hex(case['S'])

    if padding_alg == 'PKCS#1 1.5':
        # Recalculate and compare the signature to validate our processing.
        # Compare normalised hex, and use an explicit check rather than
        # `assert`, which is skipped under `python -O`.
        sig = key.sign(msg, padding.PKCS1v15(),
                       getattr(hashes, case['SHAAlg'])())
        if to_hex(sig) != expected_sig:
            raise ValueError(
                "Recomputed %s signature does not match the test vector"
                % case['SHAAlg'])
    # PSS signatures are randomised, so they can't be recomputed and compared.

    # Serialize the private key in DER format.
    der = key.private_bytes(serialization.Encoding.DER,
                            serialization.PrivateFormat.TraditionalOpenSSL,
                            serialization.NoEncryption())

    # Print the test case data in the format used by *ring* test files.
    print('Digest = %s' % case['SHAAlg'])
    print('Key = %s' % to_hex(der))
    print('Msg = %s' % reformat_hex(case['Msg']))

    if padding_alg == "PSS":
        print('Salt = %s' % reformat_hex(case['SaltVal']))

    print('Sig = %s' % expected_sig)
    print('Result = Pass')
    print('')
129
def print_verify_test(case, n, e):
    '''Print one verification test case in the format used by *ring* test files.

    The public key built from modulus `n` and exponent `e` is serialized as a
    DER PKCS#1 RSAPublicKey. The digest name, message, signature and expected
    result are taken from `case` (a dict produced by parse()), and the case
    is terminated by a blank line.
    '''
    # Build a public key instance from (n, e) and encode it as DER.
    public_key = rsa.RSAPublicNumbers(e, n).public_key(default_backend())
    encoded_key = public_key.public_bytes(serialization.Encoding.DER,
                                          serialization.PublicFormat.PKCS1)

    # One "Name = value" line per field; hex fields are normalised so every
    # byte has two digits.
    print('Digest = %s' % case['SHAAlg'])
    print('Key = %s' % to_hex(encoded_key))
    print('Msg = %s' % reformat_hex(case['Msg']))
    print('Sig = %s' % reformat_hex(case['S']))
    print('Result = %s' % case['Result'])
    print('')
145
def main(fn, test_type, padding_alg):
    '''Print a *ring* test file for the NIST test vector file `fn` to stdout.

    `test_type` is "sign" or "verify"; `padding_alg` is "PKCS#1 1.5" or
    "PSS". A header comment, including the SHA-384 digest of the input file,
    is printed first, followed by every supported test case from `fn`.
    Unsupported cases are skipped (with the reason written to stderr when
    DEBUG is set), and the number of cases written is reported on stderr.

    An invalid `test_type` terminates the script with an error message on
    stderr and a non-zero exit status before anything is printed.
    '''
    # Validate up front: stdout is the generated test file, so the error must
    # not end up in it, and the failure must not exit with status 0.
    if test_type not in ("sign", "verify"):
        sys.exit("Invalid test_type: %s" % test_type)

    with open(fn, 'rb') as f:
        input_file_digest = hashlib.sha384(f.read()).hexdigest()

    # File header
    print("# RSA %(padding_alg)s Test Vectors for FIPS 186-4 from %(fn)s in" %
          {"fn": fn, "padding_alg": padding_alg})
    print("# http://csrc.nist.gov/groups/STM/cavp/documents/dss/186-3rsatestvectors.zip")
    print("# accessible from")
    print("# http://csrc.nist.gov/groups/STM/cavp/digital-signatures.html#test-vectors")
    print("# with SHA-384 digest %s" % (input_file_digest))
    print("# filtered and reformatted using %s." % __file__)
    print("#")
    print("# Digest = SHAAlg.")
    if test_type == "verify":
        print("# Key is (n, e) encoded in an ASN.1 (DER) sequence.")
    else:
        print("# Key is an ASN.1 (DER) RSAPrivateKey.")
    print("# Sig = S.")
    # print('') rather than print(): under Python 2 a bare print() emits "()".
    print('')

    # Each test type has a different field as the last entry per case.
    # For verify tests, "Result" is always the last field; for signing tests
    # it depends on the padding used.
    if test_type == "verify":
        last_field = "Result"
    elif padding_alg == "PSS":
        last_field = "SaltVal"
    else:
        last_field = "S"

    num_cases = 0
    for case in parse(fn, last_field):
        if case['SHAAlg'] == 'SHA224':
            # SHA224 not supported in *ring*.
            debug("Skipping due to use of SHA224", DEBUG)
            continue

        if padding_alg == "PSS":
            if case['SHAAlg'] == 'SHA1':
                # SHA-1 with PSS not supported in *ring*.
                debug("Skipping due to use of SHA1 and PSS.", DEBUG)
                continue

            # *ring* only supports PSS where the salt length is equal to the
            # output length of the hash algorithm. SaltVal is hex, so twice
            # its length is the salt size in bits / 2, the unit used by
            # DIGEST_OUTPUT_LENGTHS.
            if len(case['SaltVal']) * 2 != DIGEST_OUTPUT_LENGTHS[case['SHAAlg']]:
                debug("Skipping due to unsupported salt length.", DEBUG)
                continue

        # Read public key components.
        n = int(case['n'], 16)
        e = int(case['e'], 16)
        modulus_bytes = n.bit_length() // 8

        if test_type == 'sign':
            if modulus_bytes < 2048 // 8:
                debug("Skipping due to modulus length (too small).", DEBUG)
                continue
            if n.bit_length() > 4096:
                debug("Skipping due to modulus length (too large).", DEBUG)
                continue

            # The private exponent is only needed for signing tests.
            d = int(case['d'], 16)
            print_sign_test(case, n, e, d, padding_alg)
        else:
            legacy = case['SHAAlg'] in ["SHA1", "SHA256", "SHA512"]
            if (modulus_bytes < 2048 // 8 and not legacy) or modulus_bytes < 1024 // 8:
                debug("Skipping due to modulus length (too small).", DEBUG)
                continue
            print_verify_test(case, n, e)

        num_cases += 1

    debug("%d test cases output." % num_cases, True)
224
if __name__ == '__main__':
    import os

    # Errors go to stderr with a non-zero exit status (via sys.exit(message)),
    # keeping stdout reserved for the generated test file.
    if len(sys.argv) != 2:
        sys.exit("Usage:\n python %s <filename>" % sys.argv[0])

    fn = sys.argv[1]
    # Infer the padding and test type from the file name only, so directory
    # names elsewhere in the path cannot influence the result.
    base_name = os.path.basename(fn)

    if 'PSS' in base_name:
        pad_alg = 'PSS'
    elif '15' in base_name:
        pad_alg = 'PKCS#1 1.5'
    else:
        sys.exit("Could not determine padding algorithm.")

    if 'Gen' in base_name:
        test_type = 'sign'
    elif 'Ver' in base_name:
        test_type = 'verify'
    else:
        sys.exit("Could not determine test type.")

    main(fn, test_type, pad_alg)
247