1# Copyright 2014 Google Inc. All rights reserved.
2#
3# Licensed under the Apache License, Version 2.0 (the "License");
4# you may not use this file except in compliance with the License.
5# You may obtain a copy of the License at
6#
7#      http://www.apache.org/licenses/LICENSE-2.0
8#
9# Unless required by applicable law or agreed to in writing, software
10# distributed under the License is distributed on an "AS IS" BASIS,
11# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
12# See the License for the specific language governing permissions and
13# limitations under the License.
14
15import base64
16import os
17
18import mock
19import unittest2
20
21from oauth2client import _helpers
22from oauth2client import client
23from oauth2client import crypt
24from oauth2client import service_account
25
26
def data_filename(filename):
    """Return the path of ``filename`` inside the ``data`` directory.

    The ``data`` directory is resolved relative to the directory that
    contains this module.
    """
    module_dir = os.path.dirname(__file__)
    return os.path.join(module_dir, 'data', filename)
29
30
def datafile(filename):
    """Read and return the raw bytes of a file in the ``data`` directory."""
    path = data_filename(filename)
    with open(path, 'rb') as stream:
        contents = stream.read()
    return contents
34
35
class Test__bad_pkcs12_key_as_pem(unittest2.TestCase):
    """Tests for ``crypt._bad_pkcs12_key_as_pem``."""

    def test_fails(self):
        """Calling it with no arguments raises NotImplementedError."""
        self.assertRaises(NotImplementedError, crypt._bad_pkcs12_key_as_pem)
41
42
class Test_pkcs12_key_as_pem(unittest2.TestCase):
    """Tests for ``crypt.pkcs12_key_as_pem``."""

    def _make_svc_account_creds(self, private_key_file='privatekey.p12'):
        """Build service account credentials from a PKCS#12 data file.

        Args:
            private_key_file: Name of the key file inside the ``data``
                directory.

        Returns:
            The object returned by
            ``ServiceAccountCredentials.from_p12_keyfile``, with ``'sub'``
            set in its ``_kwargs``.
        """
        filename = data_filename(private_key_file)
        credentials = (
            service_account.ServiceAccountCredentials.from_p12_keyfile(
                'some_account@example.com', filename,
                scopes='read+write'))
        credentials._kwargs['sub'] = 'joe@example.org'
        return credentials

    def _succeeds_helper(self, password=None):
        """Convert the test PKCS#12 key to PEM and check the result.

        Args:
            password: Password passed to ``crypt.pkcs12_key_as_pem``. When
                ``None``, the password stored on the credentials is used.
        """
        self.assertTrue(client.HAS_OPENSSL)

        credentials = self._make_svc_account_creds()
        if password is None:
            password = credentials._private_key_password
        pem_contents = crypt.pkcs12_key_as_pem(
            credentials._private_key_pkcs12, password)
        pkcs12_key_as_pem = datafile('pem_from_pkcs12.pem')
        pkcs12_key_as_pem = _helpers._parse_pem_key(pkcs12_key_as_pem)
        alternate_pem = datafile('pem_from_pkcs12_alternate.pem')
        # Either of the two stored PEM files is an acceptable result;
        # assertIn reports the actual and expected values on failure.
        self.assertIn(pem_contents, [pkcs12_key_as_pem, alternate_pem])

    def test_succeeds(self):
        """Conversion works with the password stored on the credentials."""
        self._succeeds_helper()

    def test_succeeds_with_unicode_password(self):
        """Conversion works when the password is a unicode string."""
        password = u'notasecret'
        self._succeeds_helper(password)
73
74
class Test__verify_signature(unittest2.TestCase):
    """Tests for ``crypt._verify_signature`` with ``Verifier`` mocked."""

    def test_success_single_cert(self):
        """A single cert whose verifier accepts the signature succeeds."""
        cert_value = 'cert-value'
        certs = [cert_value]
        message = object()
        signature = object()

        verifier = mock.MagicMock()
        verifier.verify = mock.MagicMock(name='verify', return_value=True)
        with mock.patch('oauth2client.crypt.Verifier') as Verifier:
            Verifier.from_string = mock.MagicMock(name='from_string',
                                                  return_value=verifier)
            result = crypt._verify_signature(message, signature, certs)
            self.assertIsNone(result)

            # Make sure our mocks were called as expected.
            Verifier.from_string.assert_called_once_with(cert_value,
                                                         is_x509_cert=True)
            verifier.verify.assert_called_once_with(message, signature)

    def test_success_multiple_certs(self):
        """Verification succeeds when only the last of three certs matches."""
        cert_value1 = 'cert-value1'
        cert_value2 = 'cert-value2'
        cert_value3 = 'cert-value3'
        certs = [cert_value1, cert_value2, cert_value3]
        message = object()
        signature = object()

        verifier = mock.MagicMock()
        # Use side_effect to force all 3 cert values to be used by failing
        # to verify on the first two.
        verifier.verify = mock.MagicMock(name='verify',
                                         side_effect=[False, False, True])
        with mock.patch('oauth2client.crypt.Verifier') as Verifier:
            Verifier.from_string = mock.MagicMock(name='from_string',
                                                  return_value=verifier)
            result = crypt._verify_signature(message, signature, certs)
            self.assertIsNone(result)

            # Make sure our mocks were called three times.
            expected_from_string_calls = [
                mock.call(cert_value1, is_x509_cert=True),
                mock.call(cert_value2, is_x509_cert=True),
                mock.call(cert_value3, is_x509_cert=True),
            ]
            self.assertEqual(Verifier.from_string.mock_calls,
                             expected_from_string_calls)
            expected_verify_calls = [mock.call(message, signature)] * 3
            self.assertEqual(verifier.verify.mock_calls,
                             expected_verify_calls)

    def test_failure(self):
        """AppIdentityError is raised when the only cert fails to verify."""
        cert_value = 'cert-value'
        certs = [cert_value]
        message = object()
        signature = object()

        verifier = mock.MagicMock()
        verifier.verify = mock.MagicMock(name='verify', return_value=False)
        with mock.patch('oauth2client.crypt.Verifier') as Verifier:
            Verifier.from_string = mock.MagicMock(name='from_string',
                                                  return_value=verifier)
            with self.assertRaises(crypt.AppIdentityError):
                crypt._verify_signature(message, signature, certs)

            # Make sure our mocks were called as expected.
            Verifier.from_string.assert_called_once_with(cert_value,
                                                         is_x509_cert=True)
            verifier.verify.assert_called_once_with(message, signature)
145
146
class Test__check_audience(unittest2.TestCase):
    """Tests for ``crypt._check_audience``."""

    def test_null_audience(self):
        """A ``None`` audience returns ``None`` without raising."""
        result = crypt._check_audience(None, None)
        self.assertIsNone(result)

    def test_success(self):
        """A payload whose 'aud' equals the audience passes the check."""
        audience = 'audience'
        payload_dict = {'aud': audience}
        result = crypt._check_audience(payload_dict, audience)
        # No exception and no result.
        self.assertIsNone(result)

    def test_missing_aud(self):
        """A payload without an 'aud' key raises AppIdentityError."""
        audience = 'audience'
        payload_dict = {}
        with self.assertRaises(crypt.AppIdentityError):
            crypt._check_audience(payload_dict, audience)

    def test_wrong_aud(self):
        """A payload whose 'aud' differs raises AppIdentityError."""
        audience1 = 'audience1'
        audience2 = 'audience2'
        self.assertNotEqual(audience1, audience2)
        payload_dict = {'aud': audience1}
        with self.assertRaises(crypt.AppIdentityError):
            crypt._check_audience(payload_dict, audience2)
173
174
class Test__verify_time_range(unittest2.TestCase):
    """Tests for ``crypt._verify_time_range``."""

    def _exception_helper(self, payload_dict):
        """Run ``crypt._verify_time_range`` and capture its error, if any.

        Args:
            payload_dict: The payload passed to ``_verify_time_range``.

        Returns:
            The ``crypt.AppIdentityError`` that was raised, or ``None`` when
            the call completed without raising one.
        """
        exception_caught = None
        try:
            crypt._verify_time_range(payload_dict)
        except crypt.AppIdentityError as exc:
            # Rebind: the ``as`` target is cleared after the except block
            # on Python 3.
            exception_caught = exc

        return exception_caught

    def test_without_issued_at(self):
        """A payload with no 'iat' key is rejected."""
        payload_dict = {}
        exception_caught = self._exception_helper(payload_dict)
        self.assertIsNotNone(exception_caught)
        self.assertTrue(str(exception_caught).startswith(
            'No iat field in token'))

    def test_without_expiration(self):
        """A payload with 'iat' but no 'exp' key is rejected."""
        payload_dict = {'iat': 'iat'}
        exception_caught = self._exception_helper(payload_dict)
        self.assertIsNotNone(exception_caught)
        self.assertTrue(str(exception_caught).startswith(
            'No exp field in token'))

    def test_with_bad_token_lifetime(self):
        """An 'exp' beyond now + MAX_TOKEN_LIFETIME_SECS is rejected."""
        current_time = 123456
        payload_dict = {
            'iat': 'iat',
            'exp': current_time + crypt.MAX_TOKEN_LIFETIME_SECS + 1,
        }
        with mock.patch('oauth2client.crypt.time') as time:
            time.time = mock.MagicMock(name='time',
                                       return_value=current_time)

            exception_caught = self._exception_helper(payload_dict)
            self.assertIsNotNone(exception_caught)
            self.assertTrue(str(exception_caught).startswith(
                'exp field too far in future'))

    def test_with_issued_at_in_future(self):
        """An 'iat' beyond now + CLOCK_SKEW_SECS is rejected."""
        current_time = 123456
        payload_dict = {
            'iat': current_time + crypt.CLOCK_SKEW_SECS + 1,
            'exp': current_time + crypt.MAX_TOKEN_LIFETIME_SECS - 1,
        }
        with mock.patch('oauth2client.crypt.time') as time:
            time.time = mock.MagicMock(name='time',
                                       return_value=current_time)

            exception_caught = self._exception_helper(payload_dict)
            self.assertIsNotNone(exception_caught)
            self.assertTrue(str(exception_caught).startswith(
                'Token used too early'))

    def test_with_expiration_in_the_past(self):
        """An 'exp' before now - CLOCK_SKEW_SECS is rejected."""
        current_time = 123456
        payload_dict = {
            'iat': current_time,
            'exp': current_time - crypt.CLOCK_SKEW_SECS - 1,
        }
        with mock.patch('oauth2client.crypt.time') as time:
            time.time = mock.MagicMock(name='time',
                                       return_value=current_time)

            exception_caught = self._exception_helper(payload_dict)
            self.assertIsNotNone(exception_caught)
            self.assertTrue(str(exception_caught).startswith(
                'Token used too late'))

    def test_success(self):
        """An 'iat' of now and an 'exp' within the lifetime is accepted."""
        current_time = 123456
        payload_dict = {
            'iat': current_time,
            'exp': current_time + crypt.MAX_TOKEN_LIFETIME_SECS - 1,
        }
        with mock.patch('oauth2client.crypt.time') as time:
            time.time = mock.MagicMock(name='time',
                                       return_value=current_time)

            exception_caught = self._exception_helper(payload_dict)
            self.assertIsNone(exception_caught)
257
258
class Test_verify_signed_jwt_with_certs(unittest2.TestCase):
    """Tests for ``crypt.verify_signed_jwt_with_certs``."""

    def test_jwt_no_segments(self):
        """An empty token is rejected for having the wrong segment count."""
        with self.assertRaises(crypt.AppIdentityError) as exc_manager:
            crypt.verify_signed_jwt_with_certs(b'', None)

        self.assertTrue(str(exc_manager.exception).startswith(
            'Wrong number of segments in token'))

    def test_jwt_payload_bad_json(self):
        """A token whose payload segment is not valid JSON is rejected."""
        header = signature = b''
        payload = base64.b64encode(b'{BADJSON')
        jwt = b'.'.join([header, payload, signature])

        with self.assertRaises(crypt.AppIdentityError) as exc_manager:
            crypt.verify_signed_jwt_with_certs(jwt, None)

        self.assertTrue(str(exc_manager.exception).startswith(
            'Can\'t parse token'))

    @mock.patch('oauth2client.crypt._check_audience')
    @mock.patch('oauth2client.crypt._verify_time_range')
    @mock.patch('oauth2client.crypt._verify_signature')
    def test_success(self, verify_sig, verify_time, check_aud):
        """A well-formed token returns its payload and runs each check once.

        The signature, time-range and audience checks are replaced by mocks
        (injected bottom decorator first), so only the parsing and the
        arguments handed to those checks are exercised here.
        """
        certs = mock.MagicMock()
        cert_values = object()
        certs.values = mock.MagicMock(name='values',
                                      return_value=cert_values)
        audience = object()

        header = b'header'
        signature_bytes = b'signature'
        signature = base64.b64encode(signature_bytes)
        payload_dict = {'a': 'b'}
        payload = base64.b64encode(b'{"a": "b"}')
        jwt = b'.'.join([header, payload, signature])

        result = crypt.verify_signed_jwt_with_certs(
            jwt, certs, audience=audience)
        self.assertEqual(result, payload_dict)

        message_to_sign = header + b'.' + payload
        verify_sig.assert_called_once_with(
            message_to_sign, signature_bytes, cert_values)
        verify_time.assert_called_once_with(payload_dict)
        check_aud.assert_called_once_with(payload_dict, audience)
        certs.values.assert_called_once_with()
314