1#  Copyright 2011 Sybren A. Stüvel <sybren@stuvel.eu>
2#
3#  Licensed under the Apache License, Version 2.0 (the "License");
4#  you may not use this file except in compliance with the License.
5#  You may obtain a copy of the License at
6#
7#      https://www.apache.org/licenses/LICENSE-2.0
8#
9#  Unless required by applicable law or agreed to in writing, software
10#  distributed under the License is distributed on an "AS IS" BASIS,
11#  WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
12#  See the License for the specific language governing permissions and
13#  limitations under the License.
14
15"""Unittest for saving and loading keys."""
16
17import base64
18import os.path
19import pickle
20import unittest
21import warnings
22from unittest import mock
23
24import rsa.key
25
# Base64 text of a PKCS#1 private key, plus the DER bytes it decodes to.
B64PRIV_DER = b'MC4CAQACBQDeKYlRAgMBAAECBQDHn4npAgMA/icCAwDfxwIDANcXAgInbwIDAMZt'
PRIVATE_DER = base64.b64decode(B64PRIV_DER)

# Base64 text of a PKCS#1 public key, plus the DER bytes it decodes to.
B64PUB_DER = b'MAwCBQDeKYlRAgMBAAE='
PUBLIC_DER = base64.b64decode(B64PUB_DER)

# Private key PEM surrounded by unrelated text and carrying a "Comment:"
# line inside the armor; used as input when loading.
PRIVATE_PEM = (
    b'-----BEGIN CONFUSING STUFF-----\n'
    b'Cruft before the key\n'
    b'\n'
    b'-----BEGIN RSA PRIVATE KEY-----\n'
    b'Comment: something blah\n'
    b'\n'
    + B64PRIV_DER +
    b'\n'
    b'-----END RSA PRIVATE KEY-----\n'
    b'\n'
    b'Stuff after the key\n'
    b'-----END CONFUSING STUFF-----\n'
)

# Private key PEM without any extra text; compared against saved output.
CLEAN_PRIVATE_PEM = (
    b'-----BEGIN RSA PRIVATE KEY-----\n'
    + B64PRIV_DER +
    b'\n'
    b'-----END RSA PRIVATE KEY-----\n'
)

# Public key PEM surrounded by unrelated text and carrying a "Comment:"
# line inside the armor; used as input when loading.
PUBLIC_PEM = (
    b'-----BEGIN CONFUSING STUFF-----\n'
    b'Cruft before the key\n'
    b'\n'
    b'-----BEGIN RSA PUBLIC KEY-----\n'
    b'Comment: something blah\n'
    b'\n'
    + B64PUB_DER +
    b'\n'
    b'-----END RSA PUBLIC KEY-----\n'
    b'\n'
    b'Stuff after the key\n'
    b'-----END CONFUSING STUFF-----\n'
)

# Public key PEM without any extra text; compared against saved output.
CLEAN_PUBLIC_PEM = (
    b'-----BEGIN RSA PUBLIC KEY-----\n'
    + B64PUB_DER +
    b'\n'
    b'-----END RSA PUBLIC KEY-----\n'
)
71
72
class DerTest(unittest.TestCase):
    """Checks for saving keys to, and loading keys from, DER."""

    def test_load_private_key(self):
        """Load a private key from DER and compare it with the expected key."""

        loaded = rsa.key.PrivateKey.load_pkcs1(PRIVATE_DER, 'DER')
        reference = rsa.key.PrivateKey(3727264081, 65537, 3349121513, 65063, 57287)

        self.assertEqual(reference, loaded)
        # exp1, exp2 and coef must hold the expected values as well.
        for attr, value in (('exp1', 55063), ('exp2', 10095), ('coef', 50797)):
            self.assertEqual(getattr(loaded, attr), value)

    @mock.patch('pyasn1.codec.der.decoder.decode')
    def test_load_malformed_private_key(self, der_decode):
        """Load a private key whose decoded exp2 field is invalid.

        The DER decoder is patched, so the bytes handed to ``load_pkcs1``
        are irrelevant; only the mocked return value matters. Every load
        is expected to emit a ``UserWarning`` mentioning 'malformed', and
        the resulting key must still expose the correct exp1/exp2/coef.
        """

        # The second-to-last entry is the invalid exp2 value (0).
        decoded_fields = [0, 3727264081, 65537, 3349121513, 65063, 57287, 55063, 0, 50797]
        der_decode.return_value = (decoded_fields, 0)

        with warnings.catch_warnings(record=True) as caught:
            # Make sure no warning is suppressed by an existing filter.
            warnings.simplefilter('always')

            # Load the key three times, keeping every result.
            loaded_keys = [
                rsa.key.PrivateKey.load_pkcs1(PRIVATE_DER, 'DER')
                for _ in range(3)
            ]

            # One warning per load.
            self.assertEqual(3, len(caught))

            for entry in caught:
                self.assertTrue(issubclass(entry.category, UserWarning))
                self.assertIn('malformed', str(entry.message))

        # The key from the final load must carry the correct values.
        last_key = loaded_keys[-1]
        self.assertEqual(last_key.exp1, 55063)
        self.assertEqual(last_key.exp2, 10095)
        self.assertEqual(last_key.coef, 50797)

    def test_save_private_key(self):
        """Save a private key as DER and compare it with the fixture bytes."""

        private_key = rsa.key.PrivateKey(3727264081, 65537, 3349121513, 65063, 57287)
        encoded = private_key.save_pkcs1('DER')

        self.assertIsInstance(encoded, bytes)
        self.assertEqual(PRIVATE_DER, encoded)

    def test_load_public_key(self):
        """Load a public key from DER and compare it with the expected key."""

        loaded = rsa.key.PublicKey.load_pkcs1(PUBLIC_DER, 'DER')

        self.assertEqual(rsa.key.PublicKey(3727264081, 65537), loaded)

    def test_save_public_key(self):
        """Save a public key as DER and compare it with the fixture bytes."""

        public_key = rsa.key.PublicKey(3727264081, 65537)
        encoded = public_key.save_pkcs1('DER')

        self.assertIsInstance(encoded, bytes)
        self.assertEqual(PUBLIC_DER, encoded)
142
143
class PemTest(unittest.TestCase):
    """Checks for saving keys to, and loading keys from, PEM."""

    def test_load_private_key(self):
        """Load a private key from PEM input that has extra text around it."""

        loaded = rsa.key.PrivateKey.load_pkcs1(PRIVATE_PEM, 'PEM')
        reference = rsa.key.PrivateKey(3727264081, 65537, 3349121513, 65063, 57287)

        self.assertEqual(reference, loaded)
        # exp1, exp2 and coef must hold the expected values as well.
        for attr, value in (('exp1', 55063), ('exp2', 10095), ('coef', 50797)):
            self.assertEqual(getattr(loaded, attr), value)

    def test_save_private_key(self):
        """Save a private key as PEM and compare it with the clean fixture."""

        private_key = rsa.key.PrivateKey(3727264081, 65537, 3349121513, 65063, 57287)
        encoded = private_key.save_pkcs1('PEM')

        self.assertIsInstance(encoded, bytes)
        self.assertEqual(CLEAN_PRIVATE_PEM, encoded)

    def test_load_public_key(self):
        """Load a public key from PEM input that has extra text around it."""

        loaded = rsa.key.PublicKey.load_pkcs1(PUBLIC_PEM, 'PEM')

        self.assertEqual(rsa.key.PublicKey(3727264081, 65537), loaded)

    def test_save_public_key(self):
        """Save a public key as PEM and compare it with the clean fixture."""

        public_key = rsa.key.PublicKey(3727264081, 65537)
        encoded = public_key.save_pkcs1('PEM')

        self.assertIsInstance(encoded, bytes)
        self.assertEqual(CLEAN_PUBLIC_PEM, encoded)

    def test_load_from_disk(self):
        """Load 'private.pem' from this file's directory and check p and q."""

        # The fixture file sits in the same directory as this test module.
        here = os.path.dirname(__file__)
        pem_path = os.path.join(here, 'private.pem')

        with open(pem_path, mode='rb') as pem_file:
            raw_pem = pem_file.read()

        # No format argument is given here; load_pkcs1 uses its default.
        loaded = rsa.key.PrivateKey.load_pkcs1(raw_pem)

        self.assertEqual(15945948582725241569, loaded.p)
        self.assertEqual(14617195220284816877, loaded.q)
194
195
class PickleTest(unittest.TestCase):
    """Checks that keys survive a pickle round trip."""

    def test_private_key(self):
        """Pickle and unpickle a private key; the result must compare equal."""

        original = rsa.key.PrivateKey(3727264081, 65537, 3349121513, 65063, 57287)

        restored = pickle.loads(pickle.dumps(original))

        self.assertEqual(original, restored)

    def test_public_key(self):
        """Pickle and unpickle a public key; the result must compare equal."""

        original = rsa.key.PublicKey(3727264081, 65537)

        restored = pickle.loads(pickle.dumps(original))

        self.assertEqual(original, restored)
213