1# Copyright 2014 Google Inc. All rights reserved.
2#
3# Licensed under the Apache License, Version 2.0 (the "License");
4# you may not use this file except in compliance with the License.
5# You may obtain a copy of the License at
6#
7#      http://www.apache.org/licenses/LICENSE-2.0
8#
9# Unless required by applicable law or agreed to in writing, software
10# distributed under the License is distributed on an "AS IS" BASIS,
11# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
12# See the License for the specific language governing permissions and
13# limitations under the License.
14
15"""Oauth2client tests.
16
17Unit tests for service account credentials implemented using RSA.
18"""
19
20import datetime
21import json
22import os
23import tempfile
24
25import httplib2
26import mock
27import rsa
28from six import BytesIO
29import unittest2
30
31from oauth2client import client
32from oauth2client import crypt
33from oauth2client import service_account
34from .http_mock import HttpMockSequence
35
36
def data_filename(filename):
    """Return the path of ``filename`` inside the ``data`` directory.

    The ``data`` directory is resolved relative to the directory that
    contains this test module.
    """
    module_dir = os.path.dirname(__file__)
    return os.path.join(module_dir, 'data', filename)
39
40
def datafile(filename):
    """Read a file from the test ``data`` directory and return its bytes."""
    path = data_filename(filename)
    with open(path, 'rb') as handle:
        contents = handle.read()
    return contents
44
45
class ServiceAccountCredentialsTests(unittest2.TestCase):
    """Tests for ``service_account.ServiceAccountCredentials``.

    ``setUp`` builds ``self.credentials`` from the PEM key in the test
    ``data`` directory, without any scopes.
    """

    def setUp(self):
        self.client_id = '123'
        self.service_account_email = 'dummy@google.com'
        self.private_key_id = 'ABCDEF'
        self.private_key = datafile('pem_from_pkcs12.pem')
        self.scopes = ['dummy_scope']
        self.signer = crypt.Signer.from_string(self.private_key)
        # ``self.scopes`` is deliberately not passed here, so these
        # credentials still report that scoping is required.
        self.credentials = service_account.ServiceAccountCredentials(
            self.service_account_email,
            self.signer,
            private_key_id=self.private_key_id,
            client_id=self.client_id,
        )

    def test__to_json_override(self):
        """``_to_json`` emits the given data plus class/module/expiry keys."""
        # A placeholder signer is enough; signing is never exercised here.
        signer = object()
        creds = service_account.ServiceAccountCredentials(
            'name@email.com', signer)
        self.assertEqual(creds._signer, signer)
        # Serialize over-ridden data (unrelated to ``creds``).
        to_serialize = {'unrelated': 'data'}
        serialized_str = creds._to_json([], to_serialize.copy())
        serialized_data = json.loads(serialized_str)
        expected_serialized = {
            '_class': 'ServiceAccountCredentials',
            '_module': 'oauth2client.service_account',
            'token_expiry': None,
        }
        expected_serialized.update(to_serialize)
        self.assertEqual(serialized_data, expected_serialized)

    def test_sign_blob(self):
        """``sign_blob`` returns the key id and an RSA signature of the blob."""
        private_key_id, signature = self.credentials.sign_blob('Google')
        self.assertEqual(self.private_key_id, private_key_id)

        pub_key = rsa.PublicKey.load_pkcs1_openssl_pem(
            datafile('publickey_openssl.pem'))

        self.assertTrue(rsa.pkcs1.verify(b'Google', signature, pub_key))

        # The signature must not verify other data, and a bogus signature
        # must not verify the original data.
        with self.assertRaises(rsa.pkcs1.VerificationError):
            rsa.pkcs1.verify(b'Orest', signature, pub_key)
        with self.assertRaises(rsa.pkcs1.VerificationError):
            rsa.pkcs1.verify(b'Google', b'bad signature', pub_key)

    def test_service_account_email(self):
        """The public property exposes the email given to the constructor."""
        self.assertEqual(self.service_account_email,
                         self.credentials.service_account_email)

    @staticmethod
    def _from_json_keyfile_name_helper(payload, scopes=None,
                                       token_uri=None, revoke_uri=None):
        """Dump ``payload`` to a temp JSON file and load credentials from it.

        Calls ``ServiceAccountCredentials.from_json_keyfile_name`` with the
        given scopes/URIs. The temp file is removed even if loading raises.
        """
        # Only the file name is needed, so close mkstemp's OS-level handle
        # immediately and reopen the file in text mode for json.dump.
        filehandle, filename = tempfile.mkstemp()
        os.close(filehandle)
        try:
            with open(filename, 'w') as file_obj:
                json.dump(payload, file_obj)
            return (
                service_account.ServiceAccountCredentials
                .from_json_keyfile_name(
                    filename, scopes=scopes, token_uri=token_uri,
                    revoke_uri=revoke_uri))
        finally:
            os.remove(filename)

    @mock.patch('oauth2client.crypt.Signer.from_string',
                return_value=object())
    def test_from_json_keyfile_name_factory(self, signer_factory):
        """Key-file fields and URIs (from args or file) reach the credentials."""
        client_id = 'id123'
        client_email = 'foo@bar.com'
        private_key_id = 'pkid456'
        private_key = 's3kr3tz'
        payload = {
            'type': client.SERVICE_ACCOUNT,
            'client_id': client_id,
            'client_email': client_email,
            'private_key_id': private_key_id,
            'private_key': private_key,
        }
        scopes = ['foo', 'bar']
        token_uri = 'baz'
        revoke_uri = 'qux'
        # First load: URIs are passed as arguments.
        base_creds = self._from_json_keyfile_name_helper(
            payload, scopes=scopes, token_uri=token_uri, revoke_uri=revoke_uri)
        self.assertEqual(base_creds._signer, signer_factory.return_value)
        signer_factory.assert_called_once_with(private_key)

        # Second load: URIs come from the key file instead of the arguments.
        payload['token_uri'] = token_uri
        payload['revoke_uri'] = revoke_uri
        creds_with_uris_from_file = self._from_json_keyfile_name_helper(
            payload, scopes=scopes)
        for creds in (base_creds, creds_with_uris_from_file):
            self.assertIsInstance(
                creds, service_account.ServiceAccountCredentials)
            self.assertEqual(creds.client_id, client_id)
            self.assertEqual(creds._service_account_email, client_email)
            self.assertEqual(creds._private_key_id, private_key_id)
            self.assertEqual(creds._private_key_pkcs8_pem, private_key)
            # Scopes are stored space-joined.
            self.assertEqual(creds._scopes, ' '.join(scopes))
            self.assertEqual(creds.token_uri, token_uri)
            self.assertEqual(creds.revoke_uri, revoke_uri)

    def test_from_json_keyfile_name_factory_bad_type(self):
        """A key file whose ``type`` is not SERVICE_ACCOUNT -> ValueError."""
        type_ = 'bad-type'
        self.assertNotEqual(type_, client.SERVICE_ACCOUNT)
        payload = {'type': type_}
        with self.assertRaises(ValueError):
            self._from_json_keyfile_name_helper(payload)

    def test_from_json_keyfile_name_factory_missing_field(self):
        """A service-account key file missing required fields -> KeyError."""
        payload = {
            'type': client.SERVICE_ACCOUNT,
            'client_id': 'my-client',
        }
        with self.assertRaises(KeyError):
            self._from_json_keyfile_name_helper(payload)

    def _from_p12_keyfile_helper(self, private_key_password=None, scopes='',
                                 token_uri=None, revoke_uri=None):
        """Load ``privatekey.p12`` by file name and from a buffer; check both.

        ``scopes`` defaults to ``''`` so that ``' '.join(scopes)`` in the
        assertions below yields ``''``.
        """
        service_account_email = 'name@email.com'
        filename = data_filename('privatekey.p12')
        with open(filename, 'rb') as file_obj:
            key_contents = file_obj.read()
        creds_from_filename = (
            service_account.ServiceAccountCredentials.from_p12_keyfile(
                service_account_email, filename,
                private_key_password=private_key_password,
                scopes=scopes, token_uri=token_uri, revoke_uri=revoke_uri))
        creds_from_file_contents = (
            service_account.ServiceAccountCredentials.from_p12_keyfile_buffer(
                service_account_email, BytesIO(key_contents),
                private_key_password=private_key_password,
                scopes=scopes, token_uri=token_uri, revoke_uri=revoke_uri))
        for creds in (creds_from_filename, creds_from_file_contents):
            self.assertIsInstance(
                creds, service_account.ServiceAccountCredentials)
            self.assertIsNone(creds.client_id)
            self.assertEqual(creds._service_account_email,
                             service_account_email)
            self.assertIsNone(creds._private_key_id)
            self.assertIsNone(creds._private_key_pkcs8_pem)
            self.assertEqual(creds._private_key_pkcs12, key_contents)
            # Only check the password when one was supplied explicitly.
            if private_key_password is not None:
                self.assertEqual(creds._private_key_password,
                                 private_key_password)
            self.assertEqual(creds._scopes, ' '.join(scopes))
            self.assertEqual(creds.token_uri, token_uri)
            self.assertEqual(creds.revoke_uri, revoke_uri)

    def _p12_not_implemented_helper(self):
        """Loading a .p12 key with the patched Signer must raise NotImplementedError."""
        service_account_email = 'name@email.com'
        filename = data_filename('privatekey.p12')
        with self.assertRaises(NotImplementedError):
            service_account.ServiceAccountCredentials.from_p12_keyfile(
                service_account_email, filename)

    @mock.patch('oauth2client.crypt.Signer', new=crypt.PyCryptoSigner)
    def test_from_p12_keyfile_with_pycrypto(self):
        """PKCS#12 loading is unsupported when Signer is PyCryptoSigner."""
        self._p12_not_implemented_helper()

    @mock.patch('oauth2client.crypt.Signer', new=crypt.RsaSigner)
    def test_from_p12_keyfile_with_rsa(self):
        """PKCS#12 loading is unsupported when Signer is RsaSigner."""
        self._p12_not_implemented_helper()

    def test_from_p12_keyfile_defaults(self):
        """PKCS#12 loading with all optional arguments left at defaults."""
        self._from_p12_keyfile_helper()

    def test_from_p12_keyfile_explicit(self):
        """PKCS#12 loading with password, scopes and URIs supplied."""
        password = 'notasecret'
        self._from_p12_keyfile_helper(private_key_password=password,
                                      scopes=['foo', 'bar'],
                                      token_uri='baz', revoke_uri='qux')

    def test_create_scoped_required_without_scopes(self):
        """Credentials built without scopes still need to be scoped."""
        self.assertTrue(self.credentials.create_scoped_required())

    def test_create_scoped_required_with_scopes(self):
        """Credentials built with scopes no longer need scoping."""
        signer = object()
        self.credentials = service_account.ServiceAccountCredentials(
            self.service_account_email,
            signer,
            scopes=self.scopes,
            private_key_id=self.private_key_id,
            client_id=self.client_id,
        )
        self.assertFalse(self.credentials.create_scoped_required())

    def test_create_scoped(self):
        """``create_scoped`` returns a new, scoped credentials object."""
        new_credentials = self.credentials.create_scoped(self.scopes)
        self.assertNotEqual(self.credentials, new_credentials)
        self.assertIsInstance(new_credentials,
                              service_account.ServiceAccountCredentials)
        self.assertEqual('dummy_scope', new_credentials._scopes)

    def test_create_delegated(self):
        """``create_delegated`` sets ``sub`` on a copy only."""
        signer = object()
        sub = 'foo@email.com'
        creds = service_account.ServiceAccountCredentials(
            'name@email.com', signer)
        self.assertNotIn('sub', creds._kwargs)
        delegated_creds = creds.create_delegated(sub)
        self.assertEqual(delegated_creds._kwargs['sub'], sub)
        # Make sure the original is unchanged.
        self.assertNotIn('sub', creds._kwargs)

    def test_create_delegated_existing_sub(self):
        """``create_delegated`` overrides an existing ``sub`` on a copy only."""
        signer = object()
        sub1 = 'existing@email.com'
        sub2 = 'new@email.com'
        creds = service_account.ServiceAccountCredentials(
            'name@email.com', signer, sub=sub1)
        self.assertEqual(creds._kwargs['sub'], sub1)
        delegated_creds = creds.create_delegated(sub2)
        self.assertEqual(delegated_creds._kwargs['sub'], sub2)
        # Make sure the original is unchanged.
        self.assertEqual(creds._kwargs['sub'], sub1)

    @mock.patch('oauth2client.client._UTCNOW')
    def test_access_token(self, utcnow):
        """Tokens are fetched, reused while valid, and refreshed once expired.

        The exact number of ``_UTCNOW`` calls and ``sign()`` calls is pinned,
        so the order of statements in this test is significant.
        """
        # Configure the patch.
        seconds = 11
        NOW = datetime.datetime(1992, 12, 31, second=seconds)
        utcnow.return_value = NOW

        # Create a custom credentials with a mock signer.
        signer = mock.MagicMock()
        signed_value = b'signed-content'
        signer.sign = mock.MagicMock(name='sign',
                                     return_value=signed_value)
        credentials = service_account.ServiceAccountCredentials(
            self.service_account_email,
            signer,
            private_key_id=self.private_key_id,
            client_id=self.client_id,
        )

        # Begin testing.
        lifetime = 2  # number of seconds in which the token expires
        EXPIRY_TIME = datetime.datetime(1992, 12, 31,
                                        second=seconds + lifetime)

        token1 = u'first_token'
        token_response_first = {
            'access_token': token1,
            'expires_in': lifetime,
        }
        token2 = u'second_token'
        token_response_second = {
            'access_token': token2,
            'expires_in': lifetime,
        }
        # One canned HTTP response per expected refresh.
        http = HttpMockSequence([
            ({'status': '200'},
             json.dumps(token_response_first).encode('utf-8')),
            ({'status': '200'},
             json.dumps(token_response_second).encode('utf-8')),
        ])

        # Get Access Token, First attempt.
        self.assertIsNone(credentials.access_token)
        self.assertFalse(credentials.access_token_expired)
        self.assertIsNone(credentials.token_expiry)
        token = credentials.get_access_token(http=http)
        self.assertEqual(credentials.token_expiry, EXPIRY_TIME)
        self.assertEqual(token1, token.access_token)
        self.assertEqual(lifetime, token.expires_in)
        self.assertEqual(token_response_first,
                         credentials.token_response)
        # Two utcnow calls are expected:
        # - get_access_token() -> _do_refresh_request (setting expires in)
        # - get_access_token() -> _expires_in()
        expected_utcnow_calls = [mock.call()] * 2
        self.assertEqual(expected_utcnow_calls, utcnow.mock_calls)
        # One call to sign() expected: Actual refresh was needed.
        self.assertEqual(len(signer.sign.mock_calls), 1)

        # Get Access Token, Second Attempt (not expired)
        self.assertEqual(credentials.access_token, token1)
        self.assertFalse(credentials.access_token_expired)
        token = credentials.get_access_token(http=http)
        # Make sure no refresh occurred since the token was not expired.
        self.assertEqual(token1, token.access_token)
        self.assertEqual(lifetime, token.expires_in)
        self.assertEqual(token_response_first, credentials.token_response)
        # Three more utcnow calls are expected:
        # - access_token_expired
        # - get_access_token() -> access_token_expired
        # - get_access_token -> _expires_in
        expected_utcnow_calls = [mock.call()] * (2 + 3)
        self.assertEqual(expected_utcnow_calls, utcnow.mock_calls)
        # No call to sign() expected: the token was not expired.
        self.assertEqual(len(signer.sign.mock_calls), 1 + 0)

        # Get Access Token, Third Attempt (force expiration)
        self.assertEqual(credentials.access_token, token1)
        credentials.token_expiry = NOW  # Manually force expiry.
        self.assertTrue(credentials.access_token_expired)
        token = credentials.get_access_token(http=http)
        # Make sure a refresh occurred, since the token had been expired.
        self.assertEqual(token2, token.access_token)
        self.assertEqual(lifetime, token.expires_in)
        self.assertFalse(credentials.access_token_expired)
        self.assertEqual(token_response_second,
                         credentials.token_response)
        # Five more utcnow calls are expected:
        # - access_token_expired
        # - get_access_token -> access_token_expired
        # - get_access_token -> _do_refresh_request
        # - get_access_token -> _expires_in
        # - access_token_expired
        expected_utcnow_calls = [mock.call()] * (2 + 3 + 5)
        self.assertEqual(expected_utcnow_calls, utcnow.mock_calls)
        # One more call to sign() expected: Actual refresh was needed.
        self.assertEqual(len(signer.sign.mock_calls), 1 + 0 + 1)

        self.assertEqual(credentials.access_token, token2)
364
# Frozen clock readings for the JWT tests below, as POSIX seconds (Tn) and
# the matching naive datetimes (Tn_DATE), plus the expiry of a token minted
# at each instant. T2 is 100 seconds after T1; the tests expect a T1 token to
# still be valid then. T3 is one second past T1's token lifetime.
TOKEN_LIFE = service_account._JWTAccessCredentials._MAX_TOKEN_LIFETIME_SECS

T1 = 42
T2 = T1 + 100
T3 = T1 + TOKEN_LIFE + 1

T1_DATE = datetime.datetime(1970, 1, 1, second=T1)
T2_DATE = T1_DATE + datetime.timedelta(seconds=100)
T3_DATE = T1_DATE + datetime.timedelta(seconds=TOKEN_LIFE + 1)

T1_EXPIRY = T1 + TOKEN_LIFE
T2_EXPIRY = T2 + TOKEN_LIFE
T3_EXPIRY = T3 + TOKEN_LIFE

T1_EXPIRY_DATE = T1_DATE + datetime.timedelta(seconds=TOKEN_LIFE)
T2_EXPIRY_DATE = T2_DATE + datetime.timedelta(seconds=TOKEN_LIFE)
T3_EXPIRY_DATE = T3_DATE + datetime.timedelta(seconds=TOKEN_LIFE)
380
381
class JWTAccessCredentialsTests(unittest2.TestCase):
    """Tests for ``service_account._JWTAccessCredentials``.

    The clock is frozen by patching ``oauth2client.client._UTCNOW`` (and
    ``time.time`` where JWT ``iat``/``exp`` claims are checked) with the
    module-level ``T1``/``T2``/``T3`` fixtures.
    """

    def setUp(self):
        self.client_id = '123'
        self.service_account_email = 'dummy@google.com'
        self.private_key_id = 'ABCDEF'
        self.private_key = datafile('pem_from_pkcs12.pem')
        self.signer = crypt.Signer.from_string(self.private_key)
        self.url = 'https://test.url.com'
        # Credentials with a fixed audience claim. Compare with
        # test_authorize_no_aud, where no token is cached without one.
        self.jwt = service_account._JWTAccessCredentials(
            self.service_account_email, self.signer,
            private_key_id=self.private_key_id, client_id=self.client_id,
            additional_claims={'aud': self.url})

    def _verify_token(self, token, audience):
        """Verify ``token`` with the test public cert; return its payload."""
        return crypt.verify_signed_jwt_with_certs(
            token, {'key': datafile('public_cert.pem')}, audience=audience)

    def _make_mock_request(self):
        """Build a stand-in for ``httplib2.Http.request`` that checks auth.

        The stand-in asserts that it is called for ``self.url``. It also
        asserts that the request carries a ``Bearer`` JWT signed by the test
        key, with audience ``self.url``, issued at ``T1`` for this service
        account. It then answers with an empty 200 response.
        """
        def mock_request(uri, method='GET', body=None, headers=None,
                         redirections=0, connection_type=None):
            self.assertEqual(uri, self.url)
            bearer, token = headers[b'Authorization'].split()
            self.assertEqual(bearer, b'Bearer')
            payload = self._verify_token(token, self.url)
            self.assertEqual(payload['iss'], self.service_account_email)
            self.assertEqual(payload['sub'], self.service_account_email)
            self.assertEqual(payload['iat'], T1)
            self.assertEqual(payload['exp'], T1_EXPIRY)
            return httplib2.Response({'status': '200'}), b''

        return mock_request

    @mock.patch('oauth2client.client._UTCNOW')
    @mock.patch('time.time')
    def test_get_access_token_no_claims(self, time, utcnow):
        """A token is reused within its lifetime and re-minted afterwards."""
        utcnow.return_value = T1_DATE
        time.return_value = T1

        token_info = self.jwt.get_access_token()
        payload = self._verify_token(token_info.access_token, self.url)
        self.assertEqual(payload['iss'], self.service_account_email)
        self.assertEqual(payload['sub'], self.service_account_email)
        self.assertEqual(payload['iat'], T1)
        self.assertEqual(payload['exp'], T1_EXPIRY)
        self.assertEqual(token_info.expires_in, T1_EXPIRY - T1)

        # Verify that we vend the same token after 100 seconds.
        utcnow.return_value = T2_DATE
        token_info = self.jwt.get_access_token()
        payload = self._verify_token(token_info.access_token, self.url)
        self.assertEqual(payload['iat'], T1)
        self.assertEqual(payload['exp'], T1_EXPIRY)
        self.assertEqual(token_info.expires_in, T1_EXPIRY - T2)

        # Verify that we vend a new token after _MAX_TOKEN_LIFETIME_SECS.
        utcnow.return_value = T3_DATE
        time.return_value = T3
        token_info = self.jwt.get_access_token()
        payload = self._verify_token(token_info.access_token, self.url)
        self.assertEqual(payload['iat'], T3)
        self.assertEqual(payload['exp'], T3_EXPIRY)
        self.assertEqual(token_info.expires_in, T3_EXPIRY - T3)

    @mock.patch('oauth2client.client._UTCNOW')
    @mock.patch('time.time')
    def test_get_access_token_additional_claims(self, time, utcnow):
        """Per-call ``additional_claims`` override ``aud`` and ``sub``."""
        utcnow.return_value = T1_DATE
        time.return_value = T1

        token_info = self.jwt.get_access_token(
            additional_claims={'aud': 'https://test2.url.com',
                               'sub': 'dummy2@google.com'
                               })
        payload = self._verify_token(token_info.access_token,
                                     'https://test2.url.com')
        self.assertEqual(payload['iss'], self.service_account_email)
        self.assertEqual(payload['sub'], 'dummy2@google.com')
        self.assertEqual(payload['iat'], T1)
        self.assertEqual(payload['exp'], T1_EXPIRY)
        self.assertEqual(token_info.expires_in, T1_EXPIRY - T1)

    def test_revoke(self):
        """``revoke`` can be called without an HTTP object."""
        self.jwt.revoke(None)

    def test_create_scoped_required(self):
        """JWT access credentials report that scoping is required."""
        self.assertTrue(self.jwt.create_scoped_required())

    def test_create_scoped(self):
        """Scoping yields a new ``ServiceAccountCredentials`` instance."""
        self.jwt._private_key_pkcs12 = ''
        self.jwt._private_key_password = ''

        new_credentials = self.jwt.create_scoped('dummy_scope')
        self.assertNotEqual(self.jwt, new_credentials)
        self.assertIsInstance(
            new_credentials, service_account.ServiceAccountCredentials)
        self.assertEqual('dummy_scope', new_credentials._scopes)

    @mock.patch('oauth2client.client._UTCNOW')
    @mock.patch('time.time')
    def test_authorize_success(self, time, utcnow):
        """An authorized Http sends a valid Bearer JWT on each request."""
        utcnow.return_value = T1_DATE
        time.return_value = T1

        h = httplib2.Http()
        h.request = self._make_mock_request()
        self.jwt.authorize(h)
        h.request(self.url)

        # Ensure we use the cached token.
        utcnow.return_value = T2_DATE
        h.request(self.url)

    @mock.patch('oauth2client.client._UTCNOW')
    @mock.patch('time.time')
    def test_authorize_no_aud(self, time, utcnow):
        """Without an ``aud`` claim, the request URI is the audience.

        The token is not stored on the credentials in this case.
        """
        utcnow.return_value = T1_DATE
        time.return_value = T1

        jwt = service_account._JWTAccessCredentials(
            self.service_account_email, self.signer,
            private_key_id=self.private_key_id, client_id=self.client_id)

        h = httplib2.Http()
        h.request = self._make_mock_request()
        jwt.authorize(h)
        h.request(self.url)

        # Ensure we do not cache the token.
        self.assertIsNone(jwt.access_token)

    @mock.patch('oauth2client.client._UTCNOW')
    def test_authorize_stale_token(self, utcnow):
        """An expired cached token is replaced on the next request."""
        utcnow.return_value = T1_DATE
        # Create an initial token.
        h = HttpMockSequence([({'status': '200'}, b''),
                              ({'status': '200'}, b'')])
        self.jwt.authorize(h)
        h.request(self.url)
        token_1 = self.jwt.access_token

        # Expire the token.
        utcnow.return_value = T3_DATE
        h.request(self.url)
        token_2 = self.jwt.access_token
        self.assertEqual(self.jwt.token_expiry, T3_EXPIRY_DATE)
        self.assertNotEqual(token_1, token_2)

    @mock.patch('oauth2client.client._UTCNOW')
    def test_authorize_401(self, utcnow):
        """A 401 response forces a new token and the request is retried."""
        utcnow.return_value = T1_DATE

        h = HttpMockSequence([
            ({'status': '200'}, b''),
            ({'status': '401'}, b''),
            ({'status': '200'}, b'')])
        self.jwt.authorize(h)
        h.request(self.url)
        token_1 = self.jwt.access_token

        utcnow.return_value = T2_DATE
        self.assertEqual(h.request(self.url)[0].status, 200)
        token_2 = self.jwt.access_token
        # Check the 401 forced a new token.
        self.assertNotEqual(token_1, token_2)

    @mock.patch('oauth2client.client._UTCNOW')
    def test_refresh(self, utcnow):
        """``refresh`` mints a new token that expires TOKEN_LIFE from now."""
        utcnow.return_value = T1_DATE
        token_1 = self.jwt.access_token

        utcnow.return_value = T2_DATE
        self.jwt.refresh(None)
        token_2 = self.jwt.access_token
        self.assertEqual(self.jwt.token_expiry, T2_EXPIRY_DATE)
        self.assertNotEqual(token_1, token_2)
581