1# Copyright 2014 Google Inc. All rights reserved.
2#
3# Licensed under the Apache License, Version 2.0 (the "License");
4# you may not use this file except in compliance with the License.
5# You may obtain a copy of the License at
6#
7#      http://www.apache.org/licenses/LICENSE-2.0
8#
9# Unless required by applicable law or agreed to in writing, software
10# distributed under the License is distributed on an "AS IS" BASIS,
11# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
12# See the License for the specific language governing permissions and
13# limitations under the License.
14
15"""Oauth2client.file tests
16
17Unit tests for oauth2client.file
18"""
19
20import copy
21import datetime
22import json
23import os
24import pickle
25import stat
26import tempfile
27
28import six
29from six.moves import http_client
30import unittest2
31
32from oauth2client import client
33from oauth2client import file
34from .http_mock import HttpMockSequence
35
36try:
37    # Python2
38    from future_builtins import oct
39except:  # pragma: NO COVER
40    pass
41
__author__ = 'jcgregorio@google.com (Joe Gregorio)'

# A single scratch path shared by every test in this module. mkstemp both
# creates the file and hands back an open OS-level descriptor; only the path
# is wanted here, so the descriptor is closed right away.
_filehandle, FILENAME = tempfile.mkstemp(suffix='oauth2client_test.data')
os.close(_filehandle)
46
47
class OAuth2ClientFileTests(unittest2.TestCase):
    """Tests for ``oauth2client.file.Storage``.

    Every test uses the module-level temporary path ``FILENAME``. The file is
    removed both before and after each test so that no test observes
    credentials written by another.
    """

    @staticmethod
    def _remove_test_file():
        """Delete ``FILENAME``, ignoring the error raised if it is absent."""
        try:
            os.unlink(FILENAME)
        except OSError:
            # Deliberately best-effort: the file often does not exist yet.
            pass

    def tearDown(self):
        self._remove_test_file()

    def setUp(self):
        self._remove_test_file()

    def _create_test_credentials(self, client_id='some_client_id',
                                 expiration=None):
        """Build an ``OAuth2Credentials`` populated with fixed dummy values.

        Args:
            client_id: client ID to embed in the credentials.
            expiration: token expiry as a naive ``datetime``; when falsy it
                defaults to ``datetime.datetime.utcnow()``.

        Returns:
            A ``client.OAuth2Credentials`` whose access token is ``'foo'``.
        """
        access_token = 'foo'
        client_secret = 'cOuDdkfjxxnv+'
        refresh_token = '1/0/a.df219fjls0'
        token_expiry = expiration or datetime.datetime.utcnow()
        token_uri = 'https://www.google.com/accounts/o8/oauth2/token'
        user_agent = 'refresh_checker/1.0'

        credentials = client.OAuth2Credentials(
            access_token, client_id, client_secret,
            refresh_token, token_expiry, token_uri,
            user_agent)
        return credentials

    def _store_with_stale_copy(self, expiration):
        """Store credentials, read them back, then overwrite the store.

        A copy of the credentials read from ``FILENAME``, with its access
        token changed to ``'bar'``, is written back to the same file. The
        returned object therefore still carries ``'foo'`` while the file
        carries ``'bar'`` (both share ``expiration``).

        Args:
            expiration: token expiry passed to ``_create_test_credentials``.

        Returns:
            The credentials object obtained from ``Storage.get()`` before the
            store was overwritten.
        """
        storage = file.Storage(FILENAME)
        storage.put(self._create_test_credentials(expiration=expiration))
        credentials = storage.get()
        new_cred = copy.copy(credentials)
        new_cred.access_token = 'bar'
        storage.put(new_cred)
        return credentials

    def test_non_existent_file_storage(self):
        """``get()`` on a missing file yields None."""
        storage = file.Storage(FILENAME)
        self.assertIsNone(storage.get())

    @unittest2.skipIf(not hasattr(os, 'symlink'), 'No symlink available')
    def test_no_sym_link_credentials(self):
        """Reading credentials through a symbolic link is refused."""
        sym_filename = FILENAME + '.sym'
        os.symlink(FILENAME, sym_filename)
        try:
            storage = file.Storage(sym_filename)
            with self.assertRaises(file.CredentialsFileSymbolicLinkError):
                storage.get()
        finally:
            os.unlink(sym_filename)

    def test_pickle_and_json_interop(self):
        """A pickled credentials file is unreadable; ``put`` rewrites JSON."""
        credentials = self._create_test_credentials()

        # Write a file with a pickled OAuth2Credentials.
        with open(FILENAME, 'wb') as f:
            pickle.dump(credentials, f)

        # Storage should not be able to read that object, as the capability
        # to read and write credentials as pickled objects has been removed.
        storage = file.Storage(FILENAME)
        self.assertIsNone(storage.get())

        # Now write it back out and confirm it has been rewritten as JSON.
        storage.put(credentials)
        with open(FILENAME) as f:
            data = json.load(f)

        self.assertEqual(data['access_token'], 'foo')
        self.assertEqual(data['_class'], 'OAuth2Credentials')
        self.assertEqual(data['_module'], client.OAuth2Credentials.__module__)

    def test_token_refresh_store_expired(self):
        """Refreshing with an expired stored token uses the token endpoint.

        The stored copy ('bar') shares the past expiry, so the test expects
        the refresh to fetch a new token from the mocked HTTP response.
        """
        expiration = (datetime.datetime.utcnow() -
                      datetime.timedelta(minutes=15))
        credentials = self._store_with_stale_copy(expiration)

        access_token = '1/3w'
        token_response = {'access_token': access_token, 'expires_in': 3600}
        http = HttpMockSequence([
            ({'status': '200'}, json.dumps(token_response).encode('utf-8')),
        ])

        credentials._refresh(http.request)
        self.assertEqual(credentials.access_token, access_token)

    def test_token_refresh_store_expires_soon(self):
        """Both the in-memory and the stored token are rejected.

        Tests the case where an access token that is valid when it is read
        from the store expires before the original request succeeds.
        """
        expiration = (datetime.datetime.utcnow() +
                      datetime.timedelta(minutes=15))
        credentials = self._store_with_stale_copy(expiration)

        access_token = '1/3w'
        token_response = {'access_token': access_token, 'expires_in': 3600}
        http = HttpMockSequence([
            ({'status': str(int(http_client.UNAUTHORIZED))},
             b'Initial token expired'),
            ({'status': str(int(http_client.UNAUTHORIZED))},
             b'Store token expired'),
            ({'status': str(int(http_client.OK))},
             json.dumps(token_response).encode('utf-8')),
            ({'status': str(int(http_client.OK))},
             b'Valid response to original request')
        ])

        credentials.authorize(http)
        http.request('https://example.com')
        self.assertEqual(credentials.access_token, access_token)

    def test_token_refresh_good_store(self):
        """An unexpired token in the store is picked up without HTTP.

        No request callable is supplied (``None``); the test expects the
        refresh to adopt the stored token 'bar'.
        """
        expiration = (datetime.datetime.utcnow() +
                      datetime.timedelta(minutes=15))
        credentials = self._store_with_stale_copy(expiration)

        credentials._refresh(None)
        self.assertEqual(credentials.access_token, 'bar')

    def test_token_refresh_stream_body(self):
        """A streamed request body survives the refresh-and-retry cycle.

        The final mocked response uses the 'echo_request_body' marker; the
        assertion below expects the retried request's body to be echoed back.
        """
        expiration = (datetime.datetime.utcnow() +
                      datetime.timedelta(minutes=15))
        credentials = self._store_with_stale_copy(expiration)

        valid_access_token = '1/3w'
        token_response = {'access_token': valid_access_token,
                          'expires_in': 3600}
        http = HttpMockSequence([
            ({'status': str(int(http_client.UNAUTHORIZED))},
             b'Initial token expired'),
            ({'status': str(int(http_client.UNAUTHORIZED))},
             b'Store token expired'),
            ({'status': str(int(http_client.OK))},
             json.dumps(token_response).encode('utf-8')),
            ({'status': str(int(http_client.OK))}, 'echo_request_body')
        ])

        body = six.StringIO('streaming body')

        credentials.authorize(http)
        _, content = http.request('https://example.com', body=body)
        self.assertEqual(content, 'streaming body')
        self.assertEqual(credentials.access_token, valid_access_token)

    def test_credentials_delete(self):
        """``delete()`` removes stored credentials so ``get()`` is None."""
        storage = file.Storage(FILENAME)
        storage.put(self._create_test_credentials())
        self.assertIsNotNone(storage.get())
        storage.delete()
        self.assertIsNone(storage.get())

    def test_access_token_credentials(self):
        """``AccessTokenCredentials`` round-trip; file is 0600 on POSIX."""
        access_token = 'foo'
        user_agent = 'refresh_checker/1.0'

        credentials = client.AccessTokenCredentials(access_token, user_agent)

        storage = file.Storage(FILENAME)
        storage.put(credentials)
        credentials = storage.get()

        self.assertIsNotNone(credentials)
        self.assertEqual('foo', credentials.access_token)

        self.assertTrue(os.path.exists(FILENAME))

        if os.name == 'posix':  # pragma: NO COVER
            mode = os.stat(FILENAME).st_mode
            self.assertEqual('0o600', oct(stat.S_IMODE(mode)))
244