# Copyright 2015 Google Inc. All Rights Reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
#    http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.

"""Tests for oauth2client.contrib.devshell."""

import datetime
import json
import os
import socket
import threading

import mock
import unittest2

from oauth2client import _helpers
from oauth2client import client
from oauth2client.contrib import devshell

# A dummy value to use for the expires_in field
# in CredentialInfoResponse.
EXPIRES_IN = 1000
DEFAULT_CREDENTIAL_JSON = json.dumps([
    'joe@example.com',
    'fooproj',
    'sometoken',
    EXPIRES_IN
])


class TestCredentialInfoResponse(unittest2.TestCase):
    """Parsing of the JSON list returned by the devshell server."""

    def test_constructor_with_non_list(self):
        json_non_list = '{}'
        with self.assertRaises(ValueError):
            devshell.CredentialInfoResponse(json_non_list)

    def test_constructor_with_bad_json(self):
        bad_json = '{BADJSON'
        with self.assertRaises(ValueError):
            devshell.CredentialInfoResponse(bad_json)

    def test_constructor_empty_list(self):
        info_response = devshell.CredentialInfoResponse('[]')
        self.assertIsNone(info_response.user_email)
        self.assertIsNone(info_response.project_id)
        self.assertIsNone(info_response.access_token)
        self.assertIsNone(info_response.expires_in)

    def test_constructor_full_list(self):
        user_email = 'user_email'
        project_id = 'project_id'
        access_token = 'access_token'
        expires_in = 1
        json_string = json.dumps(
            [user_email, project_id, access_token, expires_in])
        info_response = devshell.CredentialInfoResponse(json_string)
        self.assertEqual(info_response.user_email, user_email)
        self.assertEqual(info_response.project_id, project_id)
        self.assertEqual(info_response.access_token, access_token)
        self.assertEqual(info_response.expires_in, expires_in)


class Test_SendRecv(unittest2.TestCase):
    """Unit tests for devshell._SendRecv with the environment/socket mocked."""

    def test_port_zero(self):
        with mock.patch('oauth2client.contrib.devshell.os') as os_mod:
            os_mod.getenv = mock.MagicMock(name='getenv', return_value=0)
            with self.assertRaises(devshell.NoDevshellServer):
                devshell._SendRecv()
            os_mod.getenv.assert_called_once_with(devshell.DEVSHELL_ENV, 0)

    def test_no_newline_in_received_header(self):
        non_zero_port = 1
        sock = mock.MagicMock()

        header_without_newline = ''
        # MagicMock returns the same child for every recv(...) call, so this
        # configures the decode() of whatever _SendRecv receives.
        sock.recv(6).decode = mock.MagicMock(
            name='decode', return_value=header_without_newline)

        with mock.patch('oauth2client.contrib.devshell.os') as os_mod:
            os_mod.getenv = mock.MagicMock(name='getenv',
                                           return_value=non_zero_port)
            # Named socket_mod so the real ``socket`` module is not shadowed.
            with mock.patch(
                    'oauth2client.contrib.devshell.socket') as socket_mod:
                socket_mod.socket = mock.MagicMock(name='socket',
                                                   return_value=sock)
                with self.assertRaises(devshell.CommunicationError):
                    devshell._SendRecv()
                os_mod.getenv.assert_called_once_with(devshell.DEVSHELL_ENV,
                                                      0)
                socket_mod.socket.assert_called_once_with()
                sock.recv(6).decode.assert_called_once_with()

                data = devshell.CREDENTIAL_INFO_REQUEST_JSON
                msg = _helpers._to_bytes(
                    '{0}\n{1}'.format(len(data), data), encoding='utf-8')
                expected_sock_calls = [
                    mock.call.recv(6),  # From the set-up above
                    mock.call.connect(('localhost', non_zero_port)),
                    mock.call.sendall(msg),
                    mock.call.recv(6),
                    mock.call.recv(6),  # From the check above
                ]
                self.assertEqual(sock.method_calls, expected_sock_calls)


class _AuthReferenceServer(threading.Thread):
    """Single-shot fake devshell credential server running in a thread.

    Used as a context manager: entering binds a listening socket on an
    ephemeral ``localhost`` port, exports that port through
    ``devshell.DEVSHELL_ENV`` and starts a daemon thread that answers exactly
    one request with ``response`` framed as ``"<length>\\n<payload>"``.
    Exiting removes the environment variable and closes the listening socket.

    Attributes:
        response: payload sent back to the client (defaults to
            DEFAULT_CREDENTIAL_JSON when a falsy value is given).
        bad_request: set to True when the received request payload differs
            from ``devshell.CREDENTIAL_INFO_REQUEST_JSON``.
    """

    def __init__(self, response=None):
        super(_AuthReferenceServer, self).__init__(None)
        self.response = response or DEFAULT_CREDENTIAL_JSON
        self.bad_request = False

    def __enter__(self):
        return self.start_server()

    def start_server(self):
        """Bind, publish the port via the environment and start serving."""
        self._socket = socket.socket(socket.AF_INET, socket.SOCK_STREAM)
        self._socket.bind(('localhost', 0))
        port = self._socket.getsockname()[1]
        os.environ[devshell.DEVSHELL_ENV] = str(port)
        self._socket.listen(0)
        self.daemon = True
        self.start()
        return self

    def __exit__(self, e_type, value, traceback):
        self.stop_server()

    def stop_server(self):
        """Unpublish the port and close the listening socket."""
        del os.environ[devshell.DEVSHELL_ENV]
        self._socket.close()

    def run(self):
        """Accept one connection, validate the request and send the reply."""
        conn = None
        try:
            # Do not set the timeout on the socket, leave it in the blocking
            # mode as setting the timeout seems to cause spurious EAGAIN
            # errors on OSX.
            self._socket.settimeout(None)

            conn, unused_addr = self._socket.accept()
            # The header is "<length>\n" followed by the start of the body.
            header = conn.recv(6).decode()
            len_str, request = header.split('\n', 1)
            to_read = int(len_str) - len(request)
            if to_read > 0:
                request += _helpers._from_bytes(
                    conn.recv(to_read, socket.MSG_WAITALL))
            if request != devshell.CREDENTIAL_INFO_REQUEST_JSON:
                self.bad_request = True
            response_len = len(self.response)
            reply = '{0}\n{1}'.format(response_len, self.response)
            conn.sendall(reply.encode())
        finally:
            # accept() may have failed, in which case there is no connection
            # to close; closing None would mask the original error.
            if conn is not None:
                conn.close()


class DevshellCredentialsTests(unittest2.TestCase):
    """End-to-end tests of DevshellCredentials against the fake server."""

    def test_signals_no_server(self):
        with self.assertRaises(devshell.NoDevshellServer):
            devshell.DevshellCredentials()

    def test_bad_message_to_mock_server(self):
        request_content = devshell.CREDENTIAL_INFO_REQUEST_JSON + 'extrastuff'
        request_message = _helpers._to_bytes(
            '{0}\n{1}'.format(len(request_content), request_content))
        response_message = 'foobar'
        with _AuthReferenceServer(response_message) as auth_server:
            self.assertFalse(auth_server.bad_request)
            sock = socket.socket()
            # try/finally rather than ``with``: Python 2 sockets are not
            # context managers.
            try:
                port = int(os.getenv(devshell.DEVSHELL_ENV, 0))
                sock.connect(('localhost', port))
                sock.sendall(request_message)

                # Mimic the receive part of _SendRecv
                header = sock.recv(6).decode()
                len_str, result = header.split('\n', 1)
                to_read = int(len_str) - len(result)
                result += sock.recv(to_read, socket.MSG_WAITALL).decode()
            finally:
                sock.close()

            self.assertTrue(auth_server.bad_request)
            self.assertEqual(result, response_message)

    def test_request_response(self):
        with _AuthReferenceServer():
            response = devshell._SendRecv()
            self.assertEqual(response.user_email, 'joe@example.com')
            self.assertEqual(response.project_id, 'fooproj')
            self.assertEqual(response.access_token, 'sometoken')

    def test_no_refresh_token(self):
        with _AuthReferenceServer():
            creds = devshell.DevshellCredentials()
            self.assertIsNone(creds.refresh_token)

    @mock.patch('oauth2client.client._UTCNOW')
    def test_reads_credentials(self, utcnow):
        now = datetime.datetime(1992, 12, 31)
        utcnow.return_value = now
        with _AuthReferenceServer():
            creds = devshell.DevshellCredentials()
            self.assertEqual('joe@example.com', creds.user_email)
            self.assertEqual('fooproj', creds.project_id)
            self.assertEqual('sometoken', creds.access_token)
            self.assertEqual(
                now + datetime.timedelta(seconds=EXPIRES_IN),
                creds.token_expiry)
            utcnow.assert_called_once_with()

    def test_handles_skipped_fields(self):
        with _AuthReferenceServer('["joe@example.com"]'):
            creds = devshell.DevshellCredentials()
            self.assertEqual('joe@example.com', creds.user_email)
            self.assertIsNone(creds.project_id)
            self.assertIsNone(creds.access_token)
            self.assertIsNone(creds.token_expiry)

    def test_handles_tiny_response(self):
        with _AuthReferenceServer('[]'):
            creds = devshell.DevshellCredentials()
            self.assertIsNone(creds.user_email)
            self.assertIsNone(creds.project_id)
            self.assertIsNone(creds.access_token)

    def test_handles_ignores_extra_fields(self):
        with _AuthReferenceServer(
                '["joe@example.com", "fooproj", "sometoken", 1, "extra"]'):
            creds = devshell.DevshellCredentials()
            self.assertEqual('joe@example.com', creds.user_email)
            self.assertEqual('fooproj', creds.project_id)
            self.assertEqual('sometoken', creds.access_token)

    def test_refuses_to_save_to_well_known_file(self):
        # Pretend every directory exists so the save path is reached; the
        # patch is undone automatically even if the assertion fails.
        with mock.patch('os.path.isdir', return_value=True):
            with _AuthReferenceServer():
                creds = devshell.DevshellCredentials()
                with self.assertRaises(NotImplementedError):
                    client.save_to_well_known_file(creds)

    def test_from_json(self):
        with self.assertRaises(NotImplementedError):
            devshell.DevshellCredentials.from_json(None)

    def test_serialization_data(self):
        with _AuthReferenceServer('[]'):
            credentials = devshell.DevshellCredentials()
            with self.assertRaises(NotImplementedError):
                getattr(credentials, 'serialization_data')