# Copyright 2015 Google Inc. All rights reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
#      http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.

"""Unit tests for oauth2client.contrib.multiprocess_file_storage."""

import contextlib
import datetime
import json
import multiprocessing
import os
import tempfile

import fasteners
import mock
from six import StringIO
import unittest2

from oauth2client import client
from oauth2client.contrib import multiprocess_file_storage

from ..http_mock import HttpMockSequence


# Seconds to wait for a child process to exit before it is terminated.
_CHILD_JOIN_TIMEOUT = 5
# Seconds between liveness checks while waiting for a child to become ready.
_CHILD_READY_POLL_INTERVAL = 0.1


@contextlib.contextmanager
def scoped_child_process(target, **kwargs):
    """Run ``target`` in a child process for the duration of a ``with`` block.

    The child is started with two positional ``multiprocessing.Event``
    arguments, ``(die_event, ready_event)``, followed by ``kwargs``. The
    body of the ``with`` block only runs once the child has set
    ``ready_event``. When the block exits, ``die_event`` is set and the
    child is joined; a child that has not exited after
    ``_CHILD_JOIN_TIMEOUT`` seconds is terminated so it cannot outlive the
    test.

    Args:
        target: callable executed in the child process.
        **kwargs: extra keyword arguments forwarded to ``target``.

    Raises:
        RuntimeError: if the child exits without ever setting
            ``ready_event``.
    """
    die_event = multiprocessing.Event()
    ready_event = multiprocessing.Event()
    process = multiprocessing.Process(
        target=target, args=(die_event, ready_event), kwargs=kwargs)
    process.start()
    try:
        # Poll instead of blocking indefinitely: a child that crashes before
        # signalling readiness would otherwise hang the test run forever.
        while not ready_event.is_set():
            # Re-check the event once the child is known to be dead, since
            # it may have set the event right before exiting.
            if not process.is_alive() and not ready_event.is_set():
                raise RuntimeError(
                    'Child process exited before signalling readiness.')
            ready_event.wait(_CHILD_READY_POLL_INTERVAL)
        yield
    finally:
        die_event.set()
        process.join(_CHILD_JOIN_TIMEOUT)
        if process.is_alive():
            # The child ignored die_event; do not leak it.
            process.terminate()
            process.join()


def _create_test_credentials(expiration=None):
    """Build an ``OAuth2Credentials`` object with fixed test values.

    Args:
        expiration: optional ``datetime.datetime`` used as the token expiry.
            When falsy, the expiry is one hour after the current UTC time.

    Returns:
        A ``client.OAuth2Credentials`` whose access token is ``'foo'``.
    """
    access_token = 'foo'
    client_secret = 'cOuDdkfjxxnv+'
    refresh_token = '1/0/a.df219fjls0'
    token_expiry = expiration or (
        datetime.datetime.utcnow() + datetime.timedelta(seconds=3600))
    token_uri = 'https://www.google.com/accounts/o8/oauth2/token'
    user_agent = 'refresh_checker/1.0'

    credentials = client.OAuth2Credentials(
        access_token, 'test-client-id', client_secret,
        refresh_token, token_expiry, token_uri,
        user_agent)
    return credentials


def _generate_token_response_http(new_token='new_token'):
    """Build a mock HTTP object that answers with one token response.

    Args:
        new_token: access token placed in the JSON response body.

    Returns:
        An ``HttpMockSequence`` holding a single ``200`` response whose body
        is a JSON document with ``access_token`` and ``expires_in`` keys.
    """
    token_response = json.dumps({
        'access_token': new_token,
        'expires_in': '3600',
    })
    http = HttpMockSequence([
        ({'status': '200'}, token_response),
    ])

    return http


class MultiprocessStorageBehaviorTests(unittest2.TestCase):
    """Tests of ``MultiprocessFileStorage`` through its storage interface."""

    def setUp(self):
        """Create an empty temporary file to back the storage."""
        filehandle, self.filename = tempfile.mkstemp(
            'oauth2client_test.data')
        os.close(filehandle)

    def tearDown(self):
        """Remove the temporary data file and its companion lock file."""
        # Each path is removed independently so that a missing data file
        # does not leave the lock file behind.
        for path in (self.filename, '{0}.lock'.format(self.filename)):
            try:
                os.unlink(path)
            except OSError:  # pragma: NO COVER
                pass

    def test_basic_operations(self):
        """put(), get() and delete() round-trip a credential."""
        credentials = _create_test_credentials()

        store = multiprocess_file_storage.MultiprocessFileStorage(
            self.filename, 'basic')

        # Save credentials
        store.put(credentials)
        credentials = store.get()

        self.assertIsNotNone(credentials)
        self.assertEqual('foo', credentials.access_token)

        # Reset internal cache, ensure credentials were saved.
        store._backend._credentials = {}
        credentials = store.get()

        self.assertIsNotNone(credentials)
        self.assertEqual('foo', credentials.access_token)

        # Delete credentials
        store.delete()
        credentials = store.get()

        self.assertIsNone(credentials)

    def test_single_process_refresh(self):
        """A refreshed access token is visible through the store."""
        store = multiprocess_file_storage.MultiprocessFileStorage(
            self.filename, 'single-process')
        credentials = _create_test_credentials()
        credentials.set_store(store)

        http = _generate_token_response_http()
        credentials.refresh(http)
        self.assertEqual(credentials.access_token, 'new_token')

        retrieved = store.get()
        self.assertEqual(retrieved.access_token, 'new_token')

    def test_multi_process_refresh(self):
        """Two processes refreshing the same credential refresh only once."""
        # This will test that two processes attempting to refresh credentials
        # will only refresh once.
        store = multiprocess_file_storage.MultiprocessFileStorage(
            self.filename, 'multi-process')
        credentials = _create_test_credentials()
        credentials.set_store(store)
        store.put(credentials)

        def child_process_func(
                die_event, ready_event, check_event):  # pragma: NO COVER
            store = multiprocess_file_storage.MultiprocessFileStorage(
                self.filename, 'multi-process')

            credentials = store.get()
            self.assertIsNotNone(credentials)

            # Make sure this thread gets to refresh first.
            original_acquire_lock = store.acquire_lock

            def replacement_acquire_lock(*args, **kwargs):
                result = original_acquire_lock(*args, **kwargs)
                ready_event.set()
                check_event.wait()
                return result

            credentials.store.acquire_lock = replacement_acquire_lock

            http = _generate_token_response_http('b')
            credentials.refresh(http)

            self.assertEqual(credentials.access_token, 'b')

        check_event = multiprocessing.Event()
        with scoped_child_process(child_process_func, check_event=check_event):
            # The lock should be currently held by the child process.
            self.assertFalse(
                store._backend._process_lock.acquire(blocking=False))
            check_event.set()

            # The child process will refresh first, so we should end up
            # with 'b' as the token.
            http = mock.Mock()
            credentials.refresh(http=http)
            self.assertEqual(credentials.access_token, 'b')
            self.assertFalse(http.request.called)

        retrieved = store.get()
        self.assertEqual(retrieved.access_token, 'b')

    def test_read_only_file_fail_lock(self):
        """put() while another process holds the lock goes read-only."""
        credentials = _create_test_credentials()

        # Grab the lock in another process, preventing this process from
        # acquiring the lock.
        def child_process(die_event, ready_event):  # pragma: NO COVER
            lock = fasteners.InterProcessLock(
                '{0}.lock'.format(self.filename))
            with lock:
                ready_event.set()
                die_event.wait()

        with scoped_child_process(child_process):
            store = multiprocess_file_storage.MultiprocessFileStorage(
                self.filename, 'fail-lock')
            store.put(credentials)
            self.assertTrue(store._backend._read_only)

        # These credentials should still be in the store's memory-only cache.
        self.assertIsNotNone(store.get())


class MultiprocessStorageUnitTests(unittest2.TestCase):
    """Tests of the module-level helpers and the storage backend."""

    def setUp(self):
        """Create an empty temporary file to back the storage."""
        filehandle, self.filename = tempfile.mkstemp(
            'oauth2client_test.data')
        os.close(filehandle)

    def tearDown(self):
        """Remove the temporary data file and its companion lock file."""
        # Several tests unlink the data file themselves; removing each path
        # independently still cleans up the lock file in that case.
        for path in (self.filename, '{0}.lock'.format(self.filename)):
            try:
                os.unlink(path)
            except OSError:  # pragma: NO COVER
                pass

    def test__create_file_if_needed(self):
        """The helper reports whether it had to create the file."""
        self.assertFalse(
            multiprocess_file_storage._create_file_if_needed(self.filename))
        os.unlink(self.filename)
        self.assertTrue(
            multiprocess_file_storage._create_file_if_needed(self.filename))
        self.assertTrue(
            os.path.exists(self.filename))

    def test__get_backend(self):
        """The same filename yields the same backend object."""
        backend_one = multiprocess_file_storage._get_backend('file_a')
        backend_two = multiprocess_file_storage._get_backend('file_a')
        backend_three = multiprocess_file_storage._get_backend('file_b')

        self.assertIs(backend_one, backend_two)
        self.assertIsNot(backend_one, backend_three)

    def test__read_write_credentials_file(self):
        """Written credentials load back; invalid entries are skipped."""
        credentials = _create_test_credentials()
        contents = StringIO()

        multiprocess_file_storage._write_credentials_file(
            contents, {'key': credentials})

        contents.seek(0)
        data = json.load(contents)
        self.assertEqual(data['file_version'], 2)
        self.assertTrue(data['credentials']['key'])

        # Read it back.
        contents.seek(0)
        results = multiprocess_file_storage._load_credentials_file(contents)
        self.assertEqual(
            results['key'].access_token, credentials.access_token)

        # Add an invalid credential and try reading it back. It should ignore
        # the invalid one but still load the valid one.
        data['credentials']['invalid'] = '123'
        results = multiprocess_file_storage._load_credentials_file(
            StringIO(json.dumps(data)))
        self.assertNotIn('invalid', results)
        self.assertEqual(
            results['key'].access_token, credentials.access_token)

    def test__load_credentials_file_invalid_json(self):
        """Malformed JSON loads as an empty credentials dict."""
        contents = StringIO('{[')
        self.assertEqual(
            multiprocess_file_storage._load_credentials_file(contents), {})

    def test__load_credentials_file_no_file_version(self):
        """A document without ``file_version`` loads as an empty dict."""
        contents = StringIO('{}')
        self.assertEqual(
            multiprocess_file_storage._load_credentials_file(contents), {})

    def test__load_credentials_file_bad_file_version(self):
        """A document with ``file_version`` 1 loads as an empty dict."""
        contents = StringIO(json.dumps({'file_version': 1}))
        self.assertEqual(
            multiprocess_file_storage._load_credentials_file(contents), {})

    def test__load_credentials_no_open_file(self):
        """_load_credentials() must not update the cache in this state."""
        backend = multiprocess_file_storage._get_backend(self.filename)
        backend._credentials = mock.Mock()
        # Any attempt to update the cached credentials fails the test.
        backend._credentials.update.side_effect = AssertionError()
        backend._load_credentials()

    def test_acquire_lock_nonexistent_file(self):
        """No file is opened when the process lock cannot be acquired."""
        backend = multiprocess_file_storage._get_backend(self.filename)
        os.unlink(self.filename)
        backend._process_lock = mock.Mock()
        backend._process_lock.acquire.return_value = False
        backend.acquire_lock()
        self.assertIsNone(backend._file)

    def test_release_lock_with_no_file(self):
        """release_lock() completes for a read-only backend with no file."""
        backend = multiprocess_file_storage._get_backend(self.filename)
        backend._file = None
        backend._read_only = True
        backend._thread_lock.acquire()
        backend.release_lock()

    def test__refresh_predicate(self):
        """Invalid or expired credentials are flagged for refresh."""
        backend = multiprocess_file_storage._get_backend(self.filename)

        credentials = _create_test_credentials()
        self.assertFalse(backend._refresh_predicate(credentials))

        credentials.invalid = True
        self.assertTrue(backend._refresh_predicate(credentials))

        credentials = _create_test_credentials(
            expiration=(
                datetime.datetime.utcnow() - datetime.timedelta(seconds=3600)))
        self.assertTrue(backend._refresh_predicate(credentials))


if __name__ == '__main__':  # pragma: NO COVER
    unittest2.main()