# Copyright 2014 Google Inc. All rights reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
#     http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.

"""Oauth2client tests.

Unit tests for service account credentials implemented using RSA.
"""

import datetime
import json
import os
import tempfile

import httplib2
import mock
import rsa
from six import BytesIO
import unittest2

from oauth2client import client
from oauth2client import crypt
from oauth2client import service_account
from .http_mock import HttpMockSequence


def data_filename(filename):
    """Return the path of ``filename`` in the ``data`` dir beside this file."""
    return os.path.join(os.path.dirname(__file__), 'data', filename)


def datafile(filename):
    """Return the raw bytes of ``filename`` from the test ``data`` dir."""
    with open(data_filename(filename), 'rb') as file_obj:
        return file_obj.read()


class ServiceAccountCredentialsTests(unittest2.TestCase):
    """Tests for ``service_account.ServiceAccountCredentials``."""

    def setUp(self):
        self.client_id = '123'
        self.service_account_email = 'dummy@google.com'
        self.private_key_id = 'ABCDEF'
        self.private_key = datafile('pem_from_pkcs12.pem')
        self.scopes = ['dummy_scope']
        self.signer = crypt.Signer.from_string(self.private_key)
        self.credentials = service_account.ServiceAccountCredentials(
            self.service_account_email,
            self.signer,
            private_key_id=self.private_key_id,
            client_id=self.client_id,
        )

    def test__to_json_override(self):
        signer = object()
        creds = service_account.ServiceAccountCredentials(
            'name@email.com', signer)
        self.assertEqual(creds._signer, signer)
        # Serialize over-ridden data (unrelated to ``creds``).
        to_serialize = {'unrelated': 'data'}
        serialized_str = creds._to_json([], to_serialize.copy())
        serialized_data = json.loads(serialized_str)
        expected_serialized = {
            '_class': 'ServiceAccountCredentials',
            '_module': 'oauth2client.service_account',
            'token_expiry': None,
        }
        expected_serialized.update(to_serialize)
        self.assertEqual(serialized_data, expected_serialized)

    def test_sign_blob(self):
        private_key_id, signature = self.credentials.sign_blob('Google')
        self.assertEqual(self.private_key_id, private_key_id)

        pub_key = rsa.PublicKey.load_pkcs1_openssl_pem(
            datafile('publickey_openssl.pem'))

        self.assertTrue(rsa.pkcs1.verify(b'Google', signature, pub_key))

        # A different message or a corrupted signature must not verify.
        with self.assertRaises(rsa.pkcs1.VerificationError):
            rsa.pkcs1.verify(b'Orest', signature, pub_key)
        with self.assertRaises(rsa.pkcs1.VerificationError):
            rsa.pkcs1.verify(b'Google', b'bad signature', pub_key)

    def test_service_account_email(self):
        self.assertEqual(self.service_account_email,
                         self.credentials.service_account_email)

    @staticmethod
    def _from_json_keyfile_name_helper(payload, scopes=None,
                                       token_uri=None, revoke_uri=None):
        """Write ``payload`` as JSON to a temp file and load creds from it.

        The temporary file is always removed, even if loading raises.
        """
        filehandle, filename = tempfile.mkstemp()
        os.close(filehandle)
        try:
            with open(filename, 'w') as file_obj:
                json.dump(payload, file_obj)
            return (
                service_account.ServiceAccountCredentials
                .from_json_keyfile_name(
                    filename, scopes=scopes, token_uri=token_uri,
                    revoke_uri=revoke_uri))
        finally:
            os.remove(filename)

    @mock.patch('oauth2client.crypt.Signer.from_string',
                return_value=object())
    def test_from_json_keyfile_name_factory(self, signer_factory):
        client_id = 'id123'
        client_email = 'foo@bar.com'
        private_key_id = 'pkid456'
        private_key = 's3kr3tz'
        payload = {
            'type': client.SERVICE_ACCOUNT,
            'client_id': client_id,
            'client_email': client_email,
            'private_key_id': private_key_id,
            'private_key': private_key,
        }
        scopes = ['foo', 'bar']
        token_uri = 'baz'
        revoke_uri = 'qux'
        # URIs passed explicitly as arguments.
        base_creds = self._from_json_keyfile_name_helper(
            payload, scopes=scopes, token_uri=token_uri, revoke_uri=revoke_uri)
        self.assertEqual(base_creds._signer, signer_factory.return_value)
        signer_factory.assert_called_once_with(private_key)

        # URIs read from the keyfile itself.
        payload['token_uri'] = token_uri
        payload['revoke_uri'] = revoke_uri
        creds_with_uris_from_file = self._from_json_keyfile_name_helper(
            payload, scopes=scopes)
        for creds in (base_creds, creds_with_uris_from_file):
            self.assertIsInstance(
                creds, service_account.ServiceAccountCredentials)
            self.assertEqual(creds.client_id, client_id)
            self.assertEqual(creds._service_account_email, client_email)
            self.assertEqual(creds._private_key_id, private_key_id)
            self.assertEqual(creds._private_key_pkcs8_pem, private_key)
            self.assertEqual(creds._scopes, ' '.join(scopes))
            self.assertEqual(creds.token_uri, token_uri)
            self.assertEqual(creds.revoke_uri, revoke_uri)

    def test_from_json_keyfile_name_factory_bad_type(self):
        type_ = 'bad-type'
        self.assertNotEqual(type_, client.SERVICE_ACCOUNT)
        payload = {'type': type_}
        with self.assertRaises(ValueError):
            self._from_json_keyfile_name_helper(payload)

    def test_from_json_keyfile_name_factory_missing_field(self):
        payload = {
            'type': client.SERVICE_ACCOUNT,
            'client_id': 'my-client',
        }
        with self.assertRaises(KeyError):
            self._from_json_keyfile_name_helper(payload)

    def _from_p12_keyfile_helper(self, private_key_password=None, scopes='',
                                 token_uri=None, revoke_uri=None):
        """Load creds from ``privatekey.p12`` by filename and by buffer.

        Checks that both factory paths produce equivalent credentials.
        """
        service_account_email = 'name@email.com'
        filename = data_filename('privatekey.p12')
        with open(filename, 'rb') as file_obj:
            key_contents = file_obj.read()
        creds_from_filename = (
            service_account.ServiceAccountCredentials.from_p12_keyfile(
                service_account_email, filename,
                private_key_password=private_key_password,
                scopes=scopes, token_uri=token_uri, revoke_uri=revoke_uri))
        creds_from_file_contents = (
            service_account.ServiceAccountCredentials.from_p12_keyfile_buffer(
                service_account_email, BytesIO(key_contents),
                private_key_password=private_key_password,
                scopes=scopes, token_uri=token_uri, revoke_uri=revoke_uri))
        for creds in (creds_from_filename, creds_from_file_contents):
            self.assertIsInstance(
                creds, service_account.ServiceAccountCredentials)
            self.assertIsNone(creds.client_id)
            self.assertEqual(creds._service_account_email,
                             service_account_email)
            self.assertIsNone(creds._private_key_id)
            self.assertIsNone(creds._private_key_pkcs8_pem)
            self.assertEqual(creds._private_key_pkcs12, key_contents)
            if private_key_password is not None:
                self.assertEqual(creds._private_key_password,
                                 private_key_password)
            self.assertEqual(creds._scopes, ' '.join(scopes))
            self.assertEqual(creds.token_uri, token_uri)
            self.assertEqual(creds.revoke_uri, revoke_uri)

    def _p12_not_implemented_helper(self):
        """Assert that loading a PKCS#12 keyfile raises NotImplementedError."""
        service_account_email = 'name@email.com'
        filename = data_filename('privatekey.p12')
        with self.assertRaises(NotImplementedError):
            service_account.ServiceAccountCredentials.from_p12_keyfile(
                service_account_email, filename)

    @mock.patch('oauth2client.crypt.Signer', new=crypt.PyCryptoSigner)
    def test_from_p12_keyfile_with_pycrypto(self):
        self._p12_not_implemented_helper()

    @mock.patch('oauth2client.crypt.Signer', new=crypt.RsaSigner)
    def test_from_p12_keyfile_with_rsa(self):
        self._p12_not_implemented_helper()

    def test_from_p12_keyfile_defaults(self):
        self._from_p12_keyfile_helper()

    def test_from_p12_keyfile_explicit(self):
        password = 'notasecret'
        self._from_p12_keyfile_helper(private_key_password=password,
                                      scopes=['foo', 'bar'],
                                      token_uri='baz', revoke_uri='qux')

    def test_create_scoped_required_without_scopes(self):
        self.assertTrue(self.credentials.create_scoped_required())

    def test_create_scoped_required_with_scopes(self):
        signer = object()
        # Use a local so the shared ``self.credentials`` fixture is untouched.
        credentials = service_account.ServiceAccountCredentials(
            self.service_account_email,
            signer,
            scopes=self.scopes,
            private_key_id=self.private_key_id,
            client_id=self.client_id,
        )
        self.assertFalse(credentials.create_scoped_required())

    def test_create_scoped(self):
        new_credentials = self.credentials.create_scoped(self.scopes)
        self.assertNotEqual(self.credentials, new_credentials)
        self.assertIsInstance(new_credentials,
                              service_account.ServiceAccountCredentials)
        self.assertEqual('dummy_scope', new_credentials._scopes)

    def test_create_delegated(self):
        signer = object()
        sub = 'foo@email.com'
        creds = service_account.ServiceAccountCredentials(
            'name@email.com', signer)
        self.assertNotIn('sub', creds._kwargs)
        delegated_creds = creds.create_delegated(sub)
        self.assertEqual(delegated_creds._kwargs['sub'], sub)
        # Make sure the original is unchanged.
        self.assertNotIn('sub', creds._kwargs)

    def test_create_delegated_existing_sub(self):
        signer = object()
        sub1 = 'existing@email.com'
        sub2 = 'new@email.com'
        creds = service_account.ServiceAccountCredentials(
            'name@email.com', signer, sub=sub1)
        self.assertEqual(creds._kwargs['sub'], sub1)
        delegated_creds = creds.create_delegated(sub2)
        self.assertEqual(delegated_creds._kwargs['sub'], sub2)
        # Make sure the original is unchanged.
        self.assertEqual(creds._kwargs['sub'], sub1)

    @mock.patch('oauth2client.client._UTCNOW')
    def test_access_token(self, utcnow):
        # Configure the patch.
        seconds = 11
        NOW = datetime.datetime(1992, 12, 31, second=seconds)
        utcnow.return_value = NOW

        # Create a custom credentials with a mock signer.
        signer = mock.MagicMock()
        signed_value = b'signed-content'
        signer.sign = mock.MagicMock(name='sign',
                                     return_value=signed_value)
        credentials = service_account.ServiceAccountCredentials(
            self.service_account_email,
            signer,
            private_key_id=self.private_key_id,
            client_id=self.client_id,
        )

        # Begin testing.
        lifetime = 2  # number of seconds in which the token expires
        EXPIRY_TIME = datetime.datetime(1992, 12, 31,
                                        second=seconds + lifetime)

        token1 = u'first_token'
        token_response_first = {
            'access_token': token1,
            'expires_in': lifetime,
        }
        token2 = u'second_token'
        token_response_second = {
            'access_token': token2,
            'expires_in': lifetime,
        }
        http = HttpMockSequence([
            ({'status': '200'},
             json.dumps(token_response_first).encode('utf-8')),
            ({'status': '200'},
             json.dumps(token_response_second).encode('utf-8')),
        ])

        # Get Access Token, First attempt.
        self.assertIsNone(credentials.access_token)
        self.assertFalse(credentials.access_token_expired)
        self.assertIsNone(credentials.token_expiry)
        token = credentials.get_access_token(http=http)
        self.assertEqual(credentials.token_expiry, EXPIRY_TIME)
        self.assertEqual(token1, token.access_token)
        self.assertEqual(lifetime, token.expires_in)
        self.assertEqual(token_response_first,
                         credentials.token_response)
        # Two utcnow calls are expected:
        # - get_access_token() -> _do_refresh_request (setting expires in)
        # - get_access_token() -> _expires_in()
        expected_utcnow_calls = [mock.call()] * 2
        self.assertEqual(expected_utcnow_calls, utcnow.mock_calls)
        # One call to sign() expected: Actual refresh was needed.
        self.assertEqual(len(signer.sign.mock_calls), 1)

        # Get Access Token, Second Attempt (not expired)
        self.assertEqual(credentials.access_token, token1)
        self.assertFalse(credentials.access_token_expired)
        token = credentials.get_access_token(http=http)
        # Make sure no refresh occurred since the token was not expired.
        self.assertEqual(token1, token.access_token)
        self.assertEqual(lifetime, token.expires_in)
        self.assertEqual(token_response_first, credentials.token_response)
        # Three more utcnow calls are expected:
        # - access_token_expired
        # - get_access_token() -> access_token_expired
        # - get_access_token -> _expires_in
        expected_utcnow_calls = [mock.call()] * (2 + 3)
        self.assertEqual(expected_utcnow_calls, utcnow.mock_calls)
        # No call to sign() expected: the token was not expired.
        self.assertEqual(len(signer.sign.mock_calls), 1 + 0)

        # Get Access Token, Third Attempt (force expiration)
        self.assertEqual(credentials.access_token, token1)
        credentials.token_expiry = NOW  # Manually force expiry.
        self.assertTrue(credentials.access_token_expired)
        token = credentials.get_access_token(http=http)
        # Make sure refresh occurred since the token was expired.
        self.assertEqual(token2, token.access_token)
        self.assertEqual(lifetime, token.expires_in)
        self.assertFalse(credentials.access_token_expired)
        self.assertEqual(token_response_second,
                         credentials.token_response)
        # Five more utcnow calls are expected:
        # - access_token_expired
        # - get_access_token -> access_token_expired
        # - get_access_token -> _do_refresh_request
        # - get_access_token -> _expires_in
        # - access_token_expired
        expected_utcnow_calls = [mock.call()] * (2 + 3 + 5)
        self.assertEqual(expected_utcnow_calls, utcnow.mock_calls)
        # One more call to sign() expected: Actual refresh was needed.
        self.assertEqual(len(signer.sign.mock_calls), 1 + 0 + 1)

        self.assertEqual(credentials.access_token, token2)


# Fixed points in time (seconds since the epoch and matching naive datetimes)
# used to drive the patched clocks in the JWT tests below.
TOKEN_LIFE = service_account._JWTAccessCredentials._MAX_TOKEN_LIFETIME_SECS
T1 = 42
T1_DATE = datetime.datetime(1970, 1, 1, second=T1)
T1_EXPIRY = T1 + TOKEN_LIFE
T1_EXPIRY_DATE = T1_DATE + datetime.timedelta(seconds=TOKEN_LIFE)

# T2 is 100 seconds after T1, i.e. still within T1's token lifetime.
T2 = T1 + 100
T2_DATE = T1_DATE + datetime.timedelta(seconds=100)
T2_EXPIRY = T2 + TOKEN_LIFE
T2_EXPIRY_DATE = T2_DATE + datetime.timedelta(seconds=TOKEN_LIFE)

# T3 is one second past T1's expiry, so a token minted at T1 is stale.
T3 = T1 + TOKEN_LIFE + 1
T3_DATE = T1_DATE + datetime.timedelta(seconds=TOKEN_LIFE + 1)
T3_EXPIRY = T3 + TOKEN_LIFE
T3_EXPIRY_DATE = T3_DATE + datetime.timedelta(seconds=TOKEN_LIFE)


class JWTAccessCredentialsTests(unittest2.TestCase):
    """Tests for ``service_account._JWTAccessCredentials``."""

    def setUp(self):
        self.client_id = '123'
        self.service_account_email = 'dummy@google.com'
        self.private_key_id = 'ABCDEF'
        self.private_key = datafile('pem_from_pkcs12.pem')
        self.signer = crypt.Signer.from_string(self.private_key)
        self.url = 'https://test.url.com'
        self.jwt = service_account._JWTAccessCredentials(
            self.service_account_email, self.signer,
            private_key_id=self.private_key_id, client_id=self.client_id,
            additional_claims={'aud': self.url})

    @mock.patch('oauth2client.client._UTCNOW')
    @mock.patch('time.time')
    def test_get_access_token_no_claims(self, time, utcnow):
        utcnow.return_value = T1_DATE
        time.return_value = T1

        token_info = self.jwt.get_access_token()
        payload = crypt.verify_signed_jwt_with_certs(
            token_info.access_token,
            {'key': datafile('public_cert.pem')}, audience=self.url)
        self.assertEqual(payload['iss'], self.service_account_email)
        self.assertEqual(payload['sub'], self.service_account_email)
        self.assertEqual(payload['iat'], T1)
        self.assertEqual(payload['exp'], T1_EXPIRY)
        self.assertEqual(token_info.expires_in, T1_EXPIRY - T1)

        # Verify that we vend the same token after 100 seconds
        utcnow.return_value = T2_DATE
        token_info = self.jwt.get_access_token()
        payload = crypt.verify_signed_jwt_with_certs(
            token_info.access_token,
            {'key': datafile('public_cert.pem')}, audience=self.url)
        self.assertEqual(payload['iat'], T1)
        self.assertEqual(payload['exp'], T1_EXPIRY)
        self.assertEqual(token_info.expires_in, T1_EXPIRY - T2)

        # Verify that we vend a new token after _MAX_TOKEN_LIFETIME_SECS
        utcnow.return_value = T3_DATE
        time.return_value = T3
        token_info = self.jwt.get_access_token()
        payload = crypt.verify_signed_jwt_with_certs(
            token_info.access_token,
            {'key': datafile('public_cert.pem')}, audience=self.url)
        expires_in = token_info.expires_in
        self.assertEqual(payload['iat'], T3)
        self.assertEqual(payload['exp'], T3_EXPIRY)
        self.assertEqual(expires_in, T3_EXPIRY - T3)

    @mock.patch('oauth2client.client._UTCNOW')
    @mock.patch('time.time')
    def test_get_access_token_additional_claims(self, time, utcnow):
        utcnow.return_value = T1_DATE
        time.return_value = T1

        token_info = self.jwt.get_access_token(
            additional_claims={'aud': 'https://test2.url.com',
                               'sub': 'dummy2@google.com'
                               })
        payload = crypt.verify_signed_jwt_with_certs(
            token_info.access_token,
            {'key': datafile('public_cert.pem')},
            audience='https://test2.url.com')
        expires_in = token_info.expires_in
        self.assertEqual(payload['iss'], self.service_account_email)
        self.assertEqual(payload['sub'], 'dummy2@google.com')
        self.assertEqual(payload['iat'], T1)
        self.assertEqual(payload['exp'], T1_EXPIRY)
        self.assertEqual(expires_in, T1_EXPIRY - T1)

    def test_revoke(self):
        # Only checks that revoke(None) completes without raising.
        self.jwt.revoke(None)

    def test_create_scoped_required(self):
        self.assertTrue(self.jwt.create_scoped_required())

    def test_create_scoped(self):
        self.jwt._private_key_pkcs12 = ''
        self.jwt._private_key_password = ''

        new_credentials = self.jwt.create_scoped('dummy_scope')
        self.assertNotEqual(self.jwt, new_credentials)
        self.assertIsInstance(
            new_credentials, service_account.ServiceAccountCredentials)
        self.assertEqual('dummy_scope', new_credentials._scopes)

    @mock.patch('oauth2client.client._UTCNOW')
    @mock.patch('time.time')
    def test_authorize_success(self, time, utcnow):
        utcnow.return_value = T1_DATE
        time.return_value = T1

        def mock_request(uri, method='GET', body=None, headers=None,
                         redirections=0, connection_type=None):
            # Stand-in for httplib2.Http.request: validates the bearer JWT.
            self.assertEqual(uri, self.url)
            bearer, token = headers[b'Authorization'].split()
            payload = crypt.verify_signed_jwt_with_certs(
                token,
                {'key': datafile('public_cert.pem')},
                audience=self.url)
            self.assertEqual(payload['iss'], self.service_account_email)
            self.assertEqual(payload['sub'], self.service_account_email)
            self.assertEqual(payload['iat'], T1)
            self.assertEqual(payload['exp'], T1_EXPIRY)
            self.assertEqual(bearer, b'Bearer')
            return (httplib2.Response({'status': '200'}), b'')

        h = httplib2.Http()
        h.request = mock_request
        self.jwt.authorize(h)
        h.request(self.url)

        # Ensure we use the cached token
        utcnow.return_value = T2_DATE
        h.request(self.url)

    @mock.patch('oauth2client.client._UTCNOW')
    @mock.patch('time.time')
    def test_authorize_no_aud(self, time, utcnow):
        utcnow.return_value = T1_DATE
        time.return_value = T1

        jwt = service_account._JWTAccessCredentials(
            self.service_account_email, self.signer,
            private_key_id=self.private_key_id, client_id=self.client_id)

        def mock_request(uri, method='GET', body=None, headers=None,
                         redirections=0, connection_type=None):
            # Stand-in for httplib2.Http.request; the JWT audience is
            # expected to be the request URI since no 'aud' claim was set.
            self.assertEqual(uri, self.url)
            bearer, token = headers[b'Authorization'].split()
            payload = crypt.verify_signed_jwt_with_certs(
                token,
                {'key': datafile('public_cert.pem')},
                audience=self.url)
            self.assertEqual(payload['iss'], self.service_account_email)
            self.assertEqual(payload['sub'], self.service_account_email)
            self.assertEqual(payload['iat'], T1)
            self.assertEqual(payload['exp'], T1_EXPIRY)
            self.assertEqual(bearer, b'Bearer')
            return httplib2.Response({'status': '200'}), b''

        h = httplib2.Http()
        h.request = mock_request
        jwt.authorize(h)
        h.request(self.url)

        # Ensure we do not cache the token
        self.assertIsNone(jwt.access_token)

    @mock.patch('oauth2client.client._UTCNOW')
    def test_authorize_stale_token(self, utcnow):
        utcnow.return_value = T1_DATE
        # Create an initial token
        h = HttpMockSequence([({'status': '200'}, b''),
                              ({'status': '200'}, b'')])
        self.jwt.authorize(h)
        h.request(self.url)
        token_1 = self.jwt.access_token

        # Expire the token
        utcnow.return_value = T3_DATE
        h.request(self.url)
        token_2 = self.jwt.access_token
        self.assertEqual(self.jwt.token_expiry, T3_EXPIRY_DATE)
        self.assertNotEqual(token_1, token_2)

    @mock.patch('oauth2client.client._UTCNOW')
    def test_authorize_401(self, utcnow):
        utcnow.return_value = T1_DATE

        h = HttpMockSequence([
            ({'status': '200'}, b''),
            ({'status': '401'}, b''),
            ({'status': '200'}, b'')])
        self.jwt.authorize(h)
        h.request(self.url)
        token_1 = self.jwt.access_token

        utcnow.return_value = T2_DATE
        self.assertEqual(h.request(self.url)[0].status, 200)
        token_2 = self.jwt.access_token
        # Check the 401 forced a new token
        self.assertNotEqual(token_1, token_2)

    @mock.patch('oauth2client.client._UTCNOW')
    def test_refresh(self, utcnow):
        utcnow.return_value = T1_DATE
        token_1 = self.jwt.access_token

        utcnow.return_value = T2_DATE
        self.jwt.refresh(None)
        token_2 = self.jwt.access_token
        self.assertEqual(self.jwt.token_expiry, T2_EXPIRY_DATE)
        self.assertNotEqual(token_1, token_2)