"""Tests for apitools.base.py.credentials_lib."""

import re

import mock
import six
from six.moves import http_client
import unittest2

from apitools.base.py import credentials_lib
from apitools.base.py import util


def CreateUriValidator(uri_regexp, content=''):
    """Build a fake request callable that checks the requested URI.

    Args:
      uri_regexp: compiled regular expression the requested URI must match
        (checked with ``match``, i.e. anchored at the start of the URI).
      content: body returned when the URI matches.

    Returns:
      A callable ``CheckUri(uri, headers=None)`` that returns a
      ``(response, content)`` pair. ``response`` exposes only a ``status``
      attribute: ``http_client.OK`` when the URI matches, otherwise
      ``http_client.BAD_REQUEST`` together with an explanatory message.
      ``CheckUri`` raises ValueError when the
      ``X-Google-Metadata-Request`` header is absent.
    """
    def CheckUri(uri, headers=None):
        # A missing headers mapping means the required header is missing
        # too; without this, the membership test below would raise an
        # accidental TypeError on None instead of the intended ValueError.
        if headers is None:
            headers = {}
        if 'X-Google-Metadata-Request' not in headers:
            raise ValueError('Missing required header')
        if uri_regexp.match(uri):
            message = content
            status = http_client.OK
        else:
            message = 'Expected uri matching pattern %s' % uri_regexp.pattern
            status = http_client.BAD_REQUEST
        # Minimal ad-hoc response object carrying only a status code.
        return type('HttpResponse', (object,), {'status': status})(), message
    return CheckUri


class CredentialsLibTest(unittest2.TestCase):
    """Tests for credentials_lib.GceAssertionCredentials."""

    def _GetServiceCreds(self, service_account_name=None, scopes=None):
        """Build GCE credentials against mocked metadata and refresh them.

        Args:
          service_account_name: account name forwarded to
            GceAssertionCredentials only when not None; the mocked metadata
            server answers for 'default' otherwise.
          scopes: scopes passed to GceAssertionCredentials; the mocked
            metadata server reports ['scope1'] when this is falsy.
        """
        kwargs = {}
        if service_account_name is not None:
            kwargs['service_account_name'] = service_account_name
        service_account_name = service_account_name or 'default'

        def MockMetadataCalls(request_url):
            # Stand-in for credentials_lib._GceMetadataRequest; answers by
            # URL suffix and fails the test on any unexpected request.
            default_scopes = scopes or ['scope1']
            if request_url.endswith('scopes'):
                return six.StringIO(''.join(default_scopes))
            elif request_url.endswith('service-accounts'):
                return six.StringIO(service_account_name)
            elif request_url.endswith(
                    '/service-accounts/%s/token' % service_account_name):
                return six.StringIO('{"access_token": "token"}')
            self.fail('Unexpected HTTP request to %s' % request_url)

        with mock.patch.object(credentials_lib, '_GceMetadataRequest',
                               side_effect=MockMetadataCalls,
                               autospec=True) as opener_mock:
            with mock.patch.object(util, 'DetectGce',
                                   autospec=True) as mock_detect:
                # Pretend we are running on GCE.
                mock_detect.return_value = True
                validator = CreateUriValidator(
                    re.compile(r'.*/%s/.*' % service_account_name),
                    content='{"access_token": "token"}')
                credentials = credentials_lib.GceAssertionCredentials(
                    scopes, **kwargs)
                self.assertIsNone(credentials._refresh(validator))
            # The whole flow is expected to issue exactly three metadata
            # requests.
            self.assertEqual(3, opener_mock.call_count)

    def testGceServiceAccounts(self):
        """Exercise default account/scopes, explicit scopes, and a named account."""
        scopes = ['scope1']
        self._GetServiceCreds()
        self._GetServiceCreds(scopes=scopes)
        self._GetServiceCreds(service_account_name='my_service_account',
                              scopes=scopes)


class TestGetRunFlowFlags(unittest2.TestCase):
    """Tests for credentials_lib._GetRunFlowFlags."""

    def setUp(self):
        # Tests overwrite the module-level FLAGS; remember the real value.
        self._flags_actual = credentials_lib.FLAGS

    def tearDown(self):
        credentials_lib.FLAGS = self._flags_actual

    def test_with_gflags(self):
        """Command-line flags are honoured when a FLAGS object is present."""
        HOST = 'myhostname'
        PORT = '144169'

        class MockFlags(object):
            auth_host_name = HOST
            auth_host_port = PORT
            auth_local_webserver = False

        credentials_lib.FLAGS = MockFlags
        flags = credentials_lib._GetRunFlowFlags([
            '--auth_host_name=%s' % HOST,
            '--auth_host_port=%s' % PORT,
            '--noauth_local_webserver',
        ])
        self.assertEqual(flags.auth_host_name, HOST)
        self.assertEqual(flags.auth_host_port, PORT)
        self.assertEqual(flags.logging_level, 'ERROR')
        self.assertEqual(flags.noauth_local_webserver, True)

    def test_without_gflags(self):
        """With FLAGS set to None and no arguments, the expected defaults are returned."""
        credentials_lib.FLAGS = None
        flags = credentials_lib._GetRunFlowFlags([])
        self.assertEqual(flags.auth_host_name, 'localhost')
        self.assertEqual(flags.auth_host_port, [8080, 8090])
        self.assertEqual(flags.logging_level, 'ERROR')
        self.assertEqual(flags.noauth_local_webserver, False)