1# Copyright 2015 Google Inc. All rights reserved.
2#
3# Licensed under the Apache License, Version 2.0 (the "License");
4# you may not use this file except in compliance with the License.
5# You may obtain a copy of the License at
6#
7#      http://www.apache.org/licenses/LICENSE-2.0
8#
9# Unless required by applicable law or agreed to in writing, software
10# distributed under the License is distributed on an "AS IS" BASIS,
11# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
12# See the License for the specific language governing permissions and
13# limitations under the License.
14
"""Unit tests for oauth2client.contrib.multiprocess_file_storage."""
16
17import contextlib
18import datetime
19import json
20import multiprocessing
21import os
22import tempfile
23
24import fasteners
25import mock
26from six import StringIO
27import unittest2
28
29from oauth2client import client
30from oauth2client.contrib import multiprocess_file_storage
31
32from ..http_mock import HttpMockSequence
33
34
@contextlib.contextmanager
def scoped_child_process(target, **kwargs):
    """Run ``target`` in a child process for the duration of a ``with`` block.

    The child is started as ``target(die_event, ready_event, **kwargs)``.
    The body of the ``with`` statement runs only after the child sets
    ``ready_event``. On exit from the block, ``die_event`` is set so the
    child can shut down, and the child is joined.

    Args:
        target: Callable to run in the child process.
        **kwargs: Extra keyword arguments forwarded to ``target``.

    Raises:
        RuntimeError: If the child process exits before setting
            ``ready_event``.
    """
    die_event = multiprocessing.Event()
    ready_event = multiprocessing.Event()
    process = multiprocessing.Process(
        target=target, args=(die_event, ready_event), kwargs=kwargs)
    process.start()
    try:
        # Poll rather than wait indefinitely: a child that crashes (or simply
        # returns) before signalling readiness must fail the test, not hang
        # it forever.
        while not ready_event.wait(0.1):
            if not process.is_alive():
                # The child may have set the event just before exiting.
                if ready_event.is_set():
                    break
                raise RuntimeError(
                    'Child process exited (exitcode={0}) before signalling '
                    'that it was ready.'.format(process.exitcode))
        yield
    finally:
        die_event.set()
        process.join(5)
        if process.is_alive():  # pragma: NO COVER
            # Do not leak a stuck child process past the end of the test.
            process.terminate()
            process.join()
48
49
def _create_test_credentials(expiration=None):
    """Build an ``OAuth2Credentials`` object populated with fixed test values.

    Args:
        expiration: Optional token expiry. When falsy, the expiry defaults
            to one hour after the current UTC time.

    Returns:
        A ``client.OAuth2Credentials`` whose access token is ``'foo'``.
    """
    if not expiration:
        expiration = (
            datetime.datetime.utcnow() + datetime.timedelta(seconds=3600))
    return client.OAuth2Credentials(
        'foo',                                              # access token
        'test-client-id',                                   # client id
        'cOuDdkfjxxnv+',                                    # client secret
        '1/0/a.df219fjls0',                                 # refresh token
        expiration,                                         # token expiry
        'https://www.google.com/accounts/o8/oauth2/token',  # token URI
        'refresh_checker/1.0')                              # user agent
64
65
def _generate_token_response_http(new_token='new_token'):
    """Return a mock HTTP object that answers exactly one token request.

    The single queued response has status ``'200'`` and a JSON body that
    carries ``new_token`` as ``access_token`` and ``'3600'`` as
    ``expires_in``.

    Args:
        new_token: Access token value to place in the response body.

    Returns:
        An ``HttpMockSequence`` holding that one response.
    """
    body = json.dumps({'access_token': new_token, 'expires_in': '3600'})
    return HttpMockSequence([({'status': '200'}, body)])
76
77
class MultiprocessStorageBehaviorTests(unittest2.TestCase):
    """Behavioral tests for MultiprocessFileStorage, including multi-process."""

    def setUp(self):
        # Each test works against its own fresh temporary credentials file.
        filehandle, self.filename = tempfile.mkstemp(
            'oauth2client_test.data')
        os.close(filehandle)

    def tearDown(self):
        # Remove the data file and its companion lock file independently, so
        # that a missing data file (OSError) does not skip lock-file cleanup.
        for path in (self.filename, '{0}.lock'.format(self.filename)):
            try:
                os.unlink(path)
            except OSError:  # pragma: NO COVER
                pass

    def test_basic_operations(self):
        credentials = _create_test_credentials()

        store = multiprocess_file_storage.MultiprocessFileStorage(
            self.filename, 'basic')

        # Save credentials
        store.put(credentials)
        credentials = store.get()

        self.assertIsNotNone(credentials)
        self.assertEqual('foo', credentials.access_token)

        # Reset internal cache, ensure credentials were saved.
        store._backend._credentials = {}
        credentials = store.get()

        self.assertIsNotNone(credentials)
        self.assertEqual('foo', credentials.access_token)

        # Delete credentials
        store.delete()
        credentials = store.get()

        self.assertIsNone(credentials)

    def test_single_process_refresh(self):
        store = multiprocess_file_storage.MultiprocessFileStorage(
            self.filename, 'single-process')
        credentials = _create_test_credentials()
        credentials.set_store(store)

        # Refreshing credentials attached to the store should also update
        # what the store returns.
        http = _generate_token_response_http()
        credentials.refresh(http)
        self.assertEqual(credentials.access_token, 'new_token')

        retrieved = store.get()
        self.assertEqual(retrieved.access_token, 'new_token')

    def test_multi_process_refresh(self):
        # This will test that two processes attempting to refresh credentials
        # will only refresh once.
        store = multiprocess_file_storage.MultiprocessFileStorage(
            self.filename, 'multi-process')
        credentials = _create_test_credentials()
        credentials.set_store(store)
        store.put(credentials)

        def child_process_func(
                die_event, ready_event, check_event):  # pragma: NO COVER
            store = multiprocess_file_storage.MultiprocessFileStorage(
                self.filename, 'multi-process')

            credentials = store.get()
            self.assertIsNotNone(credentials)

            # Make sure this (child) process gets to refresh first. Once it
            # holds the lock, it signals the parent and then blocks until the
            # parent has checked that the lock is taken.
            original_acquire_lock = store.acquire_lock

            def replacement_acquire_lock(*args, **kwargs):
                result = original_acquire_lock(*args, **kwargs)
                ready_event.set()
                check_event.wait()
                return result

            credentials.store.acquire_lock = replacement_acquire_lock

            http = _generate_token_response_http('b')
            credentials.refresh(http)

            self.assertEqual(credentials.access_token, 'b')

        check_event = multiprocessing.Event()
        with scoped_child_process(child_process_func, check_event=check_event):
            # The lock should be currently held by the child process.
            self.assertFalse(
                store._backend._process_lock.acquire(blocking=False))
            check_event.set()

            # The child process will refresh first, so we should end up
            # with 'b' as the token, without this process issuing any HTTP
            # request of its own.
            http = mock.Mock()
            credentials.refresh(http=http)
            self.assertEqual(credentials.access_token, 'b')
            self.assertFalse(http.request.called)

        retrieved = store.get()
        self.assertEqual(retrieved.access_token, 'b')

    def test_read_only_file_fail_lock(self):
        credentials = _create_test_credentials()

        # Grab the lock in another process, preventing this process from
        # acquiring the lock.
        def child_process(die_event, ready_event):  # pragma: NO COVER
            lock = fasteners.InterProcessLock(
                '{0}.lock'.format(self.filename))
            with lock:
                ready_event.set()
                die_event.wait()

        with scoped_child_process(child_process):
            store = multiprocess_file_storage.MultiprocessFileStorage(
                self.filename, 'fail-lock')
            store.put(credentials)
            self.assertTrue(store._backend._read_only)

        # These credentials should still be in the store's memory-only cache.
        self.assertIsNotNone(store.get())
201
202
class MultiprocessStorageUnitTests(unittest2.TestCase):
    """Unit tests for the private helpers of multiprocess_file_storage."""

    def setUp(self):
        # Each test works against its own fresh temporary credentials file.
        filehandle, self.filename = tempfile.mkstemp(
            'oauth2client_test.data')
        os.close(filehandle)

    def tearDown(self):
        # Remove the data file and its companion lock file independently.
        # Some tests (e.g. test_acquire_lock_nonexistent_file) delete the
        # data file themselves; a single shared try block would then skip
        # lock-file cleanup.
        for path in (self.filename, '{0}.lock'.format(self.filename)):
            try:
                os.unlink(path)
            except OSError:  # pragma: NO COVER
                pass

    def test__create_file_if_needed(self):
        # The file already exists (created in setUp), so nothing is created.
        self.assertFalse(
            multiprocess_file_storage._create_file_if_needed(self.filename))
        os.unlink(self.filename)
        self.assertTrue(
            multiprocess_file_storage._create_file_if_needed(self.filename))
        self.assertTrue(
            os.path.exists(self.filename))

    def test__get_backend(self):
        # The same filename must map to the same backend instance.
        backend_one = multiprocess_file_storage._get_backend('file_a')
        backend_two = multiprocess_file_storage._get_backend('file_a')
        backend_three = multiprocess_file_storage._get_backend('file_b')

        self.assertIs(backend_one, backend_two)
        self.assertIsNot(backend_one, backend_three)

    def test__read_write_credentials_file(self):
        credentials = _create_test_credentials()
        contents = StringIO()

        multiprocess_file_storage._write_credentials_file(
            contents, {'key': credentials})

        contents.seek(0)
        data = json.load(contents)
        self.assertEqual(data['file_version'], 2)
        self.assertTrue(data['credentials']['key'])

        # Read it back.
        contents.seek(0)
        results = multiprocess_file_storage._load_credentials_file(contents)
        self.assertEqual(
            results['key'].access_token, credentials.access_token)

        # Add an invalid credential and try reading it back. It should ignore
        # the invalid one but still load the valid one.
        data['credentials']['invalid'] = '123'
        results = multiprocess_file_storage._load_credentials_file(
            StringIO(json.dumps(data)))
        self.assertNotIn('invalid', results)
        self.assertEqual(
            results['key'].access_token, credentials.access_token)

    def test__load_credentials_file_invalid_json(self):
        contents = StringIO('{[')
        self.assertEqual(
            multiprocess_file_storage._load_credentials_file(contents), {})

    def test__load_credentials_file_no_file_version(self):
        contents = StringIO('{}')
        self.assertEqual(
            multiprocess_file_storage._load_credentials_file(contents), {})

    def test__load_credentials_file_bad_file_version(self):
        contents = StringIO(json.dumps({'file_version': 1}))
        self.assertEqual(
            multiprocess_file_storage._load_credentials_file(contents), {})

    def test__load_credentials_no_open_file(self):
        # If _load_credentials touched the cache without an open file, the
        # mocked update() would raise AssertionError and fail this test.
        backend = multiprocess_file_storage._get_backend(self.filename)
        backend._credentials = mock.Mock()
        backend._credentials.update.side_effect = AssertionError()
        backend._load_credentials()

    def test_acquire_lock_nonexistent_file(self):
        backend = multiprocess_file_storage._get_backend(self.filename)
        os.unlink(self.filename)
        backend._process_lock = mock.Mock()
        backend._process_lock.acquire.return_value = False
        backend.acquire_lock()
        self.assertIsNone(backend._file)

    def test_release_lock_with_no_file(self):
        backend = multiprocess_file_storage._get_backend(self.filename)
        backend._file = None
        backend._read_only = True
        backend._thread_lock.acquire()
        backend.release_lock()

    def test__refresh_predicate(self):
        backend = multiprocess_file_storage._get_backend(self.filename)

        # Fresh, valid credentials do not need a refresh.
        credentials = _create_test_credentials()
        self.assertFalse(backend._refresh_predicate(credentials))

        credentials.invalid = True
        self.assertTrue(backend._refresh_predicate(credentials))

        # Credentials that expired an hour ago need a refresh.
        credentials = _create_test_credentials(
            expiration=(
                datetime.datetime.utcnow() - datetime.timedelta(seconds=3600)))
        self.assertTrue(backend._refresh_predicate(credentials))
310
311
if __name__ == '__main__':  # pragma: NO COVER
    # Allow running this test module directly as a script.
    unittest2.main()
314