1"""
2Unit tests for CLI entry points.
3"""
4
5from __future__ import print_function
6
7import unittest
8import sys
9import functools
10from contextlib import contextmanager
11
12import os
13from io import StringIO, BytesIO
14
15import rsa
16import rsa.cli
17import rsa.util
18from rsa._compat import PY2
19
20
def make_buffer():
    """Create an in-memory replacement for ``sys.stdout`` / ``sys.stderr``.

    On Python 2 a plain ``BytesIO`` is returned. Otherwise a ``StringIO`` is
    returned that carries an extra ``buffer`` attribute holding a ``BytesIO``,
    so code writing bytes to ``stream.buffer`` has somewhere to write to.
    """
    if not PY2:
        stream = StringIO()
        stream.buffer = BytesIO()
        return stream
    return BytesIO()
27
28
def get_bytes_out(out):
    """Return what was written to a captured output stream.

    :param out: a stream object as created by ``make_buffer()``.
    :return: ``out.getvalue()`` on Python 2, ``out.buffer.getvalue()``
        otherwise.
    """
    # Python 2.x writes 'str' to stdout itself, whereas Python 3.x writes
    # 'bytes' to stdout.buffer; pick the object that received the data.
    written_to = out if PY2 else out.buffer
    return written_to.getvalue()
35
36
@contextmanager
def captured_output():
    """Captures output to stdout and stderr.

    While the ``with`` block runs, ``sys.stdout`` and ``sys.stderr`` are
    replaced by fresh buffers from ``make_buffer()``. The context manager
    yields the ``(stdout_buffer, stderr_buffer)`` pair, and the original
    streams are put back when the block ends, even if it raised.
    """

    fake_stdout = make_buffer()
    fake_stderr = make_buffer()
    real_stdout = sys.stdout
    real_stderr = sys.stderr
    try:
        sys.stdout = fake_stdout
        sys.stderr = fake_stderr
        yield fake_stdout, fake_stderr
    finally:
        sys.stdout = real_stdout
        sys.stderr = real_stderr
48
49
@contextmanager
def cli_args(*new_argv):
    """Updates sys.argv[1:] for a single test.

    Every positional argument is converted with ``str()`` and the resulting
    list replaces ``sys.argv[1:]`` for the duration of the ``with`` block.
    ``sys.argv[0]`` is left untouched. When the block ends (normally or by
    raising), the previous arguments are put back.
    """

    # Save only the part that gets replaced. Saving the complete list and
    # writing it back into sys.argv[1:] would duplicate sys.argv[0] and make
    # sys.argv grow by one element on every use.
    old_args = sys.argv[1:]
    sys.argv[1:] = [str(arg) for arg in new_argv]

    try:
        yield
    finally:
        sys.argv[1:] = old_args
61
62
def remove_if_exists(fname):
    """Removes a file if it exists.

    Nothing happens when ``os.path.exists(fname)`` is false; otherwise the
    path is deleted with ``os.unlink``.
    """

    if not os.path.exists(fname):
        return
    os.unlink(fname)
68
69
def cleanup_files(*filenames):
    """Makes sure the files don't exist when the test runs, and deletes them afterward.

    Returns a decorator. The decorated function is called with its original
    arguments and its return value is passed through; the listed files are
    removed (if present) right before the call and again after it, whether
    or not the call raised.
    """

    def decorator(func):
        @functools.wraps(func)
        def wrapper(*args, **kwargs):
            purge()
            try:
                result = func(*args, **kwargs)
            finally:
                purge()
            return result

        return wrapper

    def purge():
        for path in filenames:
            remove_if_exists(path)

    return decorator
89
90
class AbstractCliTest(unittest.TestCase):
    """Base class for the CLI tests.

    Generates one 512-bit key pair per test class and stores it in
    ``<ClassName>.pub`` / ``<ClassName>.key`` in the current directory, so
    tests can pass those file names to the CLI entry points.
    """

    @classmethod
    def setUpClass(cls):
        # Ensure there is a key to use
        cls.pub_key, cls.priv_key = rsa.newkeys(512)
        cls.pub_fname = '%s.pub' % cls.__name__
        cls.priv_fname = '%s.key' % cls.__name__

        key_files = (
            (cls.pub_fname, cls.pub_key),
            (cls.priv_fname, cls.priv_key),
        )
        for fname, key in key_files:
            with open(fname, 'wb') as outfile:
                outfile.write(key.save_pkcs1())

    @classmethod
    def tearDownClass(cls):
        # The attributes only exist once setUpClass got far enough to set them.
        for attr_name in ('pub_fname', 'priv_fname'):
            if hasattr(cls, attr_name):
                remove_if_exists(getattr(cls, attr_name))

    def assertExits(self, status_code, func, *args, **kwargs):
        """Fail unless ``func(*args, **kwargs)`` raises ``SystemExit`` with the given code.

        :param status_code: value the ``code`` attribute of the raised
            ``SystemExit`` must compare equal to.
        :param func: callable to invoke.
        """
        try:
            func(*args, **kwargs)
        except SystemExit as exit_exc:
            if not status_code == exit_exc.code:
                self.fail('SystemExit() raised by %r, but exited with code %r, expected %r' % (
                    func, exit_exc.code, status_code))
            return

        # Only reached when func returned without raising SystemExit.
        self.fail('SystemExit() not raised by %r' % func)
122
123
class KeygenTest(AbstractCliTest):
    """Tests for the ``rsa.cli.keygen`` entry point."""

    def test_keygen_no_args(self):
        # Without any arguments keygen is expected to exit with status 1.
        with cli_args():
            self.assertExits(1, rsa.cli.keygen)

    def test_keygen_priv_stdout(self):
        with captured_output() as (out, err):
            with cli_args(128):
                rsa.cli.keygen()

        lines = get_bytes_out(out).splitlines()
        self.assertEqual(b'-----BEGIN RSA PRIVATE KEY-----', lines[0])
        self.assertEqual(b'-----END RSA PRIVATE KEY-----', lines[-1])

        # The key size should be shown on stderr
        self.assertIn('128-bit key', err.getvalue())

    @cleanup_files('test_cli_privkey_out.pem')
    def test_keygen_priv_out_pem(self):
        with captured_output() as (_, err):
            with cli_args('--out=test_cli_privkey_out.pem', '--form=PEM', 128):
                rsa.cli.keygen()

        stderr_text = err.getvalue()

        # The key size should be shown on stderr
        self.assertIn('128-bit key', stderr_text)

        # The output file should be shown on stderr
        self.assertIn('test_cli_privkey_out.pem', stderr_text)

        # If we can load the file as PEM, it's good enough.
        with open('test_cli_privkey_out.pem', 'rb') as pemfile:
            rsa.PrivateKey.load_pkcs1(pemfile.read())

    @cleanup_files('test_cli_privkey_out.der')
    def test_keygen_priv_out_der(self):
        with captured_output() as (_, err):
            with cli_args('--out=test_cli_privkey_out.der', '--form=DER', 128):
                rsa.cli.keygen()

        stderr_text = err.getvalue()

        # The key size should be shown on stderr
        self.assertIn('128-bit key', stderr_text)

        # The output file should be shown on stderr
        self.assertIn('test_cli_privkey_out.der', stderr_text)

        # If we can load the file as der, it's good enough.
        with open('test_cli_privkey_out.der', 'rb') as derfile:
            rsa.PrivateKey.load_pkcs1(derfile.read(), format='DER')

    @cleanup_files('test_cli_privkey_out.pem', 'test_cli_pubkey_out.pem')
    def test_keygen_pub_out_pem(self):
        with captured_output() as (_, err):
            with cli_args('--out=test_cli_privkey_out.pem',
                          '--pubout=test_cli_pubkey_out.pem',
                          '--form=PEM', 256):
                rsa.cli.keygen()

        stderr_text = err.getvalue()

        # The key size should be shown on stderr
        self.assertIn('256-bit key', stderr_text)

        # The output files should be shown on stderr
        self.assertIn('test_cli_privkey_out.pem', stderr_text)
        self.assertIn('test_cli_pubkey_out.pem', stderr_text)

        # If we can load the file as PEM, it's good enough.
        with open('test_cli_pubkey_out.pem', 'rb') as pemfile:
            rsa.PublicKey.load_pkcs1(pemfile.read())
191
192
class EncryptDecryptTest(AbstractCliTest):
    """Tests for the ``rsa.cli.encrypt`` and ``rsa.cli.decrypt`` entry points."""

    def test_empty_decrypt(self):
        with cli_args():
            self.assertExits(1, rsa.cli.decrypt)

    def test_empty_encrypt(self):
        with cli_args():
            self.assertExits(1, rsa.cli.encrypt)

    @cleanup_files('encrypted.txt', 'cleartext.txt')
    def test_encrypt_decrypt(self):
        message = b'Hello cleartext RSA users!'
        with open('cleartext.txt', 'wb') as cleartext_file:
            cleartext_file.write(message)

        # Encrypt with the public key of this test class...
        with cli_args('-i', 'cleartext.txt', '--out=encrypted.txt', self.pub_fname), \
                captured_output():
            rsa.cli.encrypt()

        # ...and decrypt with the matching private key.
        with cli_args('-i', 'encrypted.txt', self.priv_fname), \
                captured_output() as (stdout_buf, _):
            rsa.cli.decrypt()

        # We should have the original cleartext on stdout now.
        self.assertEqual(message, get_bytes_out(stdout_buf))

    @cleanup_files('encrypted.txt', 'cleartext.txt')
    def test_encrypt_decrypt_unhappy(self):
        with open('cleartext.txt', 'wb') as cleartext_file:
            cleartext_file.write(b'Hello cleartext RSA users!')

        with cli_args('-i', 'cleartext.txt', '--out=encrypted.txt', self.pub_fname), \
                captured_output():
            rsa.cli.encrypt()

        # Change a few bytes in the encrypted stream.
        with open('encrypted.txt', 'r+b') as encrypted_file:
            encrypted_file.seek(40)
            encrypted_file.write(b'hahaha')

        # Decrypting the tampered file must raise DecryptionError.
        with cli_args('-i', 'encrypted.txt', self.priv_fname), \
                captured_output():
            with self.assertRaises(rsa.DecryptionError):
                rsa.cli.decrypt()
236
237
class SignVerifyTest(AbstractCliTest):
    """Tests for the ``rsa.cli.sign`` and ``rsa.cli.verify`` entry points."""

    def test_empty_verify(self):
        with cli_args():
            self.assertExits(1, rsa.cli.verify)

    def test_empty_sign(self):
        with cli_args():
            self.assertExits(1, rsa.cli.sign)

    @cleanup_files('signature.txt', 'cleartext.txt')
    def test_sign_verify(self):
        with open('cleartext.txt', 'wb') as outfile:
            outfile.write(b'Hello RSA users!')

        with cli_args('-i', 'cleartext.txt', '--out=signature.txt', self.priv_fname, 'SHA-256'):
            with captured_output():
                rsa.cli.sign()

        # verify() must return normally here; a SystemExit would fail the test.
        with cli_args('-i', 'cleartext.txt', self.pub_fname, 'signature.txt'):
            with captured_output() as (out, _):
                rsa.cli.verify()

        # NOTE(review): this checks that the message is NOT on stdout;
        # presumably the CLI reports success on stderr -- confirm against
        # rsa.cli before changing this assertion.
        self.assertNotIn(b'Verification OK', get_bytes_out(out))

    @cleanup_files('signature.txt', 'cleartext.txt')
    def test_sign_verify_unhappy(self):
        with open('cleartext.txt', 'wb') as outfile:
            outfile.write(b'Hello RSA users!')

        with cli_args('-i', 'cleartext.txt', '--out=signature.txt', self.priv_fname, 'SHA-256'):
            with captured_output():
                rsa.cli.sign()

        # Change a few bytes in the cleartext file.
        with open('cleartext.txt', 'r+b') as cleartext_file:
            cleartext_file.seek(6)
            cleartext_file.write(b'DSA')

        # The signature no longer matches, so verify() has to exit with the
        # failure message as its exit code.
        with cli_args('-i', 'cleartext.txt', self.pub_fname, 'signature.txt'):
            with captured_output():
                self.assertExits('Verification failed.', rsa.cli.verify)
279
280
class PrivatePublicTest(AbstractCliTest):
    """Test CLI command to convert a private to a public key."""

    @cleanup_files('test_private_to_public.pem')
    def test_private_to_public(self):
        # Convert the private key file of this test class into a public key file.
        with cli_args('-i', self.priv_fname, '-o', 'test_private_to_public.pem'), \
                captured_output():
            rsa.util.private_to_public()

        # Check that the key is indeed valid.
        with open('test_private_to_public.pem', 'rb') as pemfile:
            pem_bytes = pemfile.read()
        public_key = rsa.PublicKey.load_pkcs1(pem_bytes)

        # The public key must carry the same modulus and exponent.
        self.assertEqual(self.priv_key.n, public_key.n)
        self.assertEqual(self.priv_key.e, public_key.e)
297