1# Copyright 2014 Google Inc. All rights reserved.
2#
3# Licensed under the Apache License, Version 2.0 (the "License");
4# you may not use this file except in compliance with the License.
5# You may obtain a copy of the License at
6#
7#      http://www.apache.org/licenses/LICENSE-2.0
8#
9# Unless required by applicable law or agreed to in writing, software
10# distributed under the License is distributed on an "AS IS" BASIS,
11# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
12# See the License for the specific language governing permissions and
13# limitations under the License.
14
15import datetime
16import json
17import os
18import tempfile
19import time
20
21import dev_appserver
22
23dev_appserver.fix_sys_path()
24
25from google.appengine.api import apiproxy_stub
26from google.appengine.api import apiproxy_stub_map
27from google.appengine.api import app_identity
28from google.appengine.api import memcache
29from google.appengine.api import users
30from google.appengine.api.memcache import memcache_stub
31from google.appengine.ext import db
32from google.appengine.ext import ndb
33from google.appengine.ext import testbed
34import httplib2
35import mock
36from six.moves import urllib
37import unittest2
38import webapp2
39from webtest import TestApp
40
41import oauth2client
42from oauth2client import client
43from oauth2client import clientsecrets
44from oauth2client.contrib import appengine
45from ..http_mock import CacheMock
46
47__author__ = 'jcgregorio@google.com (Joe Gregorio)'
48
49DATA_DIR = os.path.join(os.path.dirname(__file__), '..', 'data')
50
51
def datafile(filename):
    """Return the path of ``filename`` inside the test data directory.

    Args:
        filename: Name of a file located under ``DATA_DIR``.

    Returns:
        ``DATA_DIR`` joined with ``filename``.
    """
    data_path = os.path.join(DATA_DIR, filename)
    return data_path
54
55
def load_and_cache(existing_file, fakename, cache_mock):
    """Load a client secrets data file and store it in a mock cache.

    Args:
        existing_file: Name of a client secrets file in the data directory.
        fakename: Key under which the loaded secrets are stored.
        cache_mock: Object exposing a dict-like ``cache`` attribute.
    """
    secrets_path = datafile(existing_file)
    secrets_type, secrets_info = clientsecrets._loadfile(secrets_path)
    cache_mock.cache[fakename] = {secrets_type: secrets_info}
59
60
class UserMock(object):
    """Mock the app engine user service with a signed-in user.

    An instance is callable and returns itself, so the same object acts both
    as the "get current user" function and as the user it returns.
    """

    def user_id(self):
        """Return the fixed id of the fake user."""
        return 'foo_user'

    def __call__(self):
        """Return this instance as the current user."""
        return self
69
70
class UserNotLoggedInMock(object):
    """Mock the app engine user service when nobody is signed in."""

    def __call__(self):
        """Return None, i.e. report that there is no current user."""
        current_user = None
        return current_user
76
77
class Http2Mock(object):
    """Mock httplib2.Http

    Each request gets the same canned reply: the mock itself acts as the
    response object (its ``status`` is 200) and the body is ``content``
    serialized as JSON.
    """
    status = 200
    content = {
        'access_token': 'foo_access_token',
        'refresh_token': 'foo_refresh_token',
        'expires_in': 3600,
        'extra': 'value',
    }

    def request(self, token_uri, method, body, headers, *args, **kwargs):
        """Record the request body and headers, then return the canned reply.

        Args:
            token_uri: Request URI (ignored).
            method: HTTP method (ignored).
            body: Request body; saved as ``self.body``.
            headers: Request headers; saved as ``self.headers``.

        Returns:
            A ``(response, content)`` tuple where ``response`` is this mock
            and ``content`` is the JSON encoding of ``self.content``.
        """
        self.headers = headers
        self.body = body
        payload = json.dumps(self.content)
        return self, payload
92
93
class TestAppAssertionCredentials(unittest2.TestCase):
    """Tests for appengine.AppAssertionCredentials using stubbed services."""
    account_name = "service_account_name@appspot.com"
    signature = "signature"

    class AppIdentityStubImpl(apiproxy_stub.APIProxyStub):
        """App identity stub that returns canned values and records calls."""

        def __init__(self, key_name=None, sig_bytes=None,
                     svc_acct=None):
            super(TestAppAssertionCredentials.AppIdentityStubImpl,
                  self).__init__('app_identity_service')
            self._key_name = key_name
            self._sig_bytes = sig_bytes
            # Every blob passed to SignForApp, in call order.
            self._sign_calls = []
            self._svc_acct = svc_acct
            # Number of GetServiceAccountName calls served so far.
            self._get_acct_name_calls = 0

        def _Dynamic_GetAccessToken(self, request, response):
            response.set_access_token('a_token_123')
            response.set_expiration_time(time.time() + 1800)

        def _Dynamic_SignForApp(self, request, response):
            response.set_key_name(self._key_name)
            response.set_signature_bytes(self._sig_bytes)
            self._sign_calls.append(request.bytes_to_sign())

        def _Dynamic_GetServiceAccountName(self, request, response):
            response.set_service_account_name(self._svc_acct)
            self._get_acct_name_calls += 1

    class ErroringAppIdentityStubImpl(apiproxy_stub.APIProxyStub):
        """App identity stub whose GetAccessToken call always fails."""

        def __init__(self):
            super(TestAppAssertionCredentials.ErroringAppIdentityStubImpl,
                  self).__init__('app_identity_service')

        def _Dynamic_GetAccessToken(self, request, response):
            raise app_identity.BackendDeadlineExceeded()

    def setUp(self):
        # Several tests replace the process-wide apiproxy with a fresh stub
        # map. Put the original back afterwards so the stubs registered here
        # do not leak into whatever test runs next.
        self.addCleanup(setattr, apiproxy_stub_map, 'apiproxy',
                        apiproxy_stub_map.apiproxy)

    def _register_stubs(self, app_identity_stub, with_memcache=False):
        """Swap in a fresh apiproxy that serves ``app_identity_stub``.

        Args:
            app_identity_stub: Stub registered as 'app_identity_service'.
            with_memcache: When true, also register a memcache stub.
        """
        apiproxy_stub_map.apiproxy = apiproxy_stub_map.APIProxyStubMap()
        apiproxy_stub_map.apiproxy.RegisterStub('app_identity_service',
                                                app_identity_stub)
        if with_memcache:
            apiproxy_stub_map.apiproxy.RegisterStub(
                'memcache', memcache_stub.MemcacheServiceStub())

    def test_raise_correct_type_of_exception(self):
        self._register_stubs(self.ErroringAppIdentityStubImpl(),
                             with_memcache=True)

        scope = 'http://www.googleapis.com/scope'
        credentials = appengine.AppAssertionCredentials(scope)
        http = httplib2.Http()
        with self.assertRaises(client.AccessTokenRefreshError):
            credentials.refresh(http)

    def test_get_access_token_on_refresh(self):
        self._register_stubs(self.AppIdentityStubImpl(), with_memcache=True)

        # A list of scopes is accepted and stored space-joined.
        scope = [
            "http://www.googleapis.com/scope",
            "http://www.googleapis.com/scope2"]
        credentials = appengine.AppAssertionCredentials(scope)
        http = httplib2.Http()
        credentials.refresh(http)
        self.assertEqual('a_token_123', credentials.access_token)

        # The scope survives a JSON round trip. The local is deliberately
        # not called ``json`` so it does not shadow the imported module.
        serialized = credentials.to_json()
        credentials = client.Credentials.new_from_json(serialized)
        self.assertEqual(
            'http://www.googleapis.com/scope http://www.googleapis.com/scope2',
            credentials.scope)

        # A space-separated scope string behaves the same way.
        scope = ('http://www.googleapis.com/scope '
                 'http://www.googleapis.com/scope2')
        credentials = appengine.AppAssertionCredentials(scope)
        http = httplib2.Http()
        credentials.refresh(http)
        self.assertEqual('a_token_123', credentials.access_token)
        self.assertEqual(
            'http://www.googleapis.com/scope http://www.googleapis.com/scope2',
            credentials.scope)

    def test_custom_service_account(self):
        scope = "http://www.googleapis.com/scope"
        account_id = "service_account_name_2@appspot.com"

        with mock.patch.object(app_identity, 'get_access_token',
                               return_value=('a_token_456', None),
                               autospec=True) as get_access_token:
            credentials = appengine.AppAssertionCredentials(
                scope, service_account_id=account_id)
            http = httplib2.Http()
            credentials.refresh(http)

            self.assertEqual('a_token_456', credentials.access_token)
            self.assertEqual(scope, credentials.scope)
            get_access_token.assert_called_once_with(
                [scope], service_account_id=account_id)

    def test_create_scoped_required_without_scopes(self):
        credentials = appengine.AppAssertionCredentials([])
        self.assertTrue(credentials.create_scoped_required())

    def test_create_scoped_required_with_scopes(self):
        credentials = appengine.AppAssertionCredentials(['dummy_scope'])
        self.assertFalse(credentials.create_scoped_required())

    def test_create_scoped(self):
        credentials = appengine.AppAssertionCredentials([])
        new_credentials = credentials.create_scoped(['dummy_scope'])
        self.assertNotEqual(credentials, new_credentials)
        self.assertIsInstance(
            new_credentials, appengine.AppAssertionCredentials)
        self.assertEqual('dummy_scope', new_credentials.scope)

    def test_sign_blob(self):
        key_name = b'1234567890'
        sig_bytes = b'himom'
        app_identity_stub = self.AppIdentityStubImpl(
            key_name=key_name, sig_bytes=sig_bytes)
        self._register_stubs(app_identity_stub)
        credentials = appengine.AppAssertionCredentials([])
        to_sign = b'blob'
        self.assertEqual(app_identity_stub._sign_calls, [])
        result = credentials.sign_blob(to_sign)
        self.assertEqual(result, (key_name, sig_bytes))
        self.assertEqual(app_identity_stub._sign_calls, [to_sign])

    def test_service_account_email(self):
        acct_name = 'new-value@appspot.gserviceaccount.com'
        app_identity_stub = self.AppIdentityStubImpl(svc_acct=acct_name)
        self._register_stubs(app_identity_stub)

        credentials = appengine.AppAssertionCredentials([])
        self.assertIsNone(credentials._service_account_email)
        self.assertEqual(app_identity_stub._get_acct_name_calls, 0)
        # First access goes to the stub exactly once and stores the value.
        self.assertEqual(credentials.service_account_email, acct_name)
        self.assertIsNotNone(credentials._service_account_email)
        self.assertEqual(app_identity_stub._get_acct_name_calls, 1)

    def test_service_account_email_already_set(self):
        acct_name = 'existing@appspot.gserviceaccount.com'
        credentials = appengine.AppAssertionCredentials([])
        credentials._service_account_email = acct_name

        app_identity_stub = self.AppIdentityStubImpl(svc_acct=acct_name)
        self._register_stubs(app_identity_stub)

        # With the value pre-set, the stub must never be consulted.
        self.assertEqual(app_identity_stub._get_acct_name_calls, 0)
        self.assertEqual(credentials.service_account_email, acct_name)
        self.assertEqual(app_identity_stub._get_acct_name_calls, 0)

    def test_get_access_token(self):
        self._register_stubs(self.AppIdentityStubImpl(), with_memcache=True)

        credentials = appengine.AppAssertionCredentials(['dummy_scope'])
        token = credentials.get_access_token()
        self.assertEqual('a_token_123', token.access_token)
        self.assertIsNone(token.expires_in)

    def test_save_to_well_known_file(self):
        import shutil

        config_dir = tempfile.mkdtemp()
        # Remove the scratch directory no matter how the test ends.
        self.addCleanup(shutil.rmtree, config_dir, ignore_errors=True)

        # patch.dict restores os.environ to its previous contents on exit,
        # even if the assertion fails or the variable was already set.
        with mock.patch.dict(
                os.environ, {client._CLOUDSDK_CONFIG_ENV_VAR: config_dir}):
            credentials = appengine.AppAssertionCredentials([])
            with self.assertRaises(NotImplementedError):
                client.save_to_well_known_file(credentials)
273
274
class TestFlowModel(db.Model):
    """Datastore model with a single FlowProperty, used by the tests."""

    # Property under test.
    flow = appengine.FlowProperty()
277
278
class FlowPropertyTest(unittest2.TestCase):
    """Tests for appengine.FlowProperty against a stubbed datastore."""

    def setUp(self):
        self.testbed = testbed.Testbed()
        self.testbed.activate()
        self.testbed.init_datastore_v3_stub()

        secrets_path = datafile('client_secrets.json')
        self.flow = client.flow_from_clientsecrets(
            secrets_path, 'foo', redirect_uri='oob')

    def tearDown(self):
        self.testbed.deactivate()

    def test_flow_get_put(self):
        # Store a flow, then read it back through the model.
        stored = TestFlowModel(flow=self.flow, key_name='foo')
        stored.put()

        fetched = TestFlowModel.get_by_key_name('foo')
        self.assertEqual('foo_client_id', fetched.flow.client_id)

    def test_make_value_from_datastore_none(self):
        flow_property = appengine.FlowProperty()
        self.assertIsNone(flow_property.make_value_from_datastore(None))

    def test_validate(self):
        # None is accepted; a value that is not a flow is rejected.
        appengine.FlowProperty().validate(None)
        not_a_flow = 42
        with self.assertRaises(db.BadValueError):
            appengine.FlowProperty().validate(not_a_flow)
312
313
class TestCredentialsModel(db.Model):
    """Datastore model with a single CredentialsProperty, used by the tests."""

    # Property under test.
    credentials = appengine.CredentialsProperty()
316
317
class CredentialsPropertyTest(unittest2.TestCase):
    """Tests for appengine.CredentialsProperty against a stubbed datastore."""

    def setUp(self):
        self.testbed = testbed.Testbed()
        self.testbed.activate()
        self.testbed.init_datastore_v3_stub()

        # Positional arguments, in order: access token, client id, client
        # secret, refresh token, token expiry, token URI, user agent.
        self.credentials = client.OAuth2Credentials(
            'foo',
            'some_client_id',
            'cOuDdkfjxxnv+',
            '1/0/a.df219fjls0',
            datetime.datetime.utcnow(),
            oauth2client.GOOGLE_TOKEN_URI,
            'refresh_checker/1.0')

    def tearDown(self):
        self.testbed.deactivate()

    def test_credentials_get_put(self):
        # Store credentials, then read them back through the model.
        stored = TestCredentialsModel(
            credentials=self.credentials, key_name='foo')
        stored.put()

        fetched = TestCredentialsModel.get_by_key_name('foo')
        self.assertEqual(
            self.credentials.to_json(),
            fetched.credentials.to_json())

    def test_make_value_from_datastore(self):
        # Missing, empty and malformed stored values all decode to None.
        for stored_value in (None, '', '{'):
            credentials_property = appengine.CredentialsProperty()
            self.assertIsNone(
                credentials_property.make_value_from_datastore(stored_value))

        # A well-formed JSON value round-trips.
        as_json = self.credentials.to_json()
        decoded = appengine.CredentialsProperty().make_value_from_datastore(
            as_json)
        self.assertEqual(
            self.credentials.to_json(),
            decoded.to_json())

    def test_validate(self):
        # Credentials and None are accepted; anything else is rejected.
        for acceptable in (self.credentials, None):
            appengine.CredentialsProperty().validate(acceptable)
        with self.assertRaises(db.BadValueError):
            appengine.CredentialsProperty().validate(42)
370
371
def _http_request(*args, **kwargs):
    """Fake HTTP request callable that always returns access token 'bar'.

    All arguments are ignored.

    Returns:
        A ``(response, content)`` tuple: an ``httplib2.Response`` with
        status '200' and a JSON body holding the access token.
    """
    body = json.dumps({'access_token': 'bar'})
    ok_response = httplib2.Response({'status': '200'})
    return ok_response, body
377
378
class StorageByKeyNameTest(unittest2.TestCase):
    """Tests for appengine.StorageByKeyName with db and ndb models."""

    def setUp(self):
        self.testbed = testbed.Testbed()
        self.testbed.activate()
        self.testbed.init_datastore_v3_stub()
        self.testbed.init_memcache_stub()
        self.testbed.init_user_stub()

        # Positional arguments, in order: access token, client id, client
        # secret, refresh token, token expiry, token URI, user agent.
        self.credentials = client.OAuth2Credentials(
            'foo',
            'some_client_id',
            'cOuDdkfjxxnv+',
            '1/0/a.df219fjls0',
            datetime.datetime.utcnow(),
            oauth2client.GOOGLE_TOKEN_URI,
            'refresh_checker/1.0')

    def tearDown(self):
        self.testbed.deactivate()

    def test_bad_ctor(self):
        with self.assertRaises(ValueError):
            appengine.StorageByKeyName(appengine.CredentialsModel, None, None)

    def test__is_ndb(self):
        storage = appengine.StorageByKeyName(
            object(), 'foo', 'credentials')

        # Neither an arbitrary instance nor an arbitrary class is a model.
        with self.assertRaises(TypeError):
            storage._is_ndb()

        storage._model = type(object)
        with self.assertRaises(TypeError):
            storage._is_ndb()

        # The db model is reported as not NDB, the NDB model as NDB.
        storage._model = appengine.CredentialsModel
        self.assertFalse(storage._is_ndb())

        storage._model = appengine.CredentialsNDBModel
        self.assertTrue(storage._is_ndb())

    def test_get_and_put_simple(self):
        storage = appengine.StorageByKeyName(
            appengine.CredentialsModel, 'foo', 'credentials')

        self.assertIsNone(storage.get())
        self.credentials.set_store(storage)

        # Refreshing writes the new token through to the datastore.
        self.credentials._refresh(_http_request)
        stored = appengine.CredentialsModel.get_by_key_name('foo')
        self.assertEqual('bar', stored.credentials.access_token)

    def test_get_and_put_cached(self):
        storage = appengine.StorageByKeyName(
            appengine.CredentialsModel, 'foo', 'credentials', cache=memcache)

        self.assertIsNone(storage.get())
        self.credentials.set_store(storage)

        self.credentials._refresh(_http_request)
        stored = appengine.CredentialsModel.get_by_key_name('foo')
        self.assertEqual('bar', stored.credentials.access_token)

        # Drop the cache entry ...
        memcache.delete('foo')

        # ... and verify that a get() repopulates it.
        fetched = storage.get()
        self.assertEqual('bar', fetched.access_token)
        self.assertIsNotNone(memcache.get('foo'))

        # delete() must empty both the storage and the cache.
        storage.delete()
        self.assertIsNone(storage.get())
        self.assertIsNone(memcache.get('foo'))

    def test_get_and_put_set_store_on_cache_retrieval(self):
        storage = appengine.StorageByKeyName(
            appengine.CredentialsModel, 'foo', 'credentials', cache=memcache)

        self.assertIsNone(storage.get())
        self.credentials.set_store(storage)
        storage.put(self.credentials)
        # Pre-bug 292 old_creds wouldn't have storage, and the _refresh
        # wouldn't be able to store the updated cred back into the storage.
        old_creds = storage.get()
        self.assertEqual(old_creds.access_token, 'foo')
        old_creds.invalid = True
        old_creds._refresh(_http_request)
        self.assertEqual(storage.get().access_token, 'bar')

    def test_get_and_put_ndb(self):
        # Nothing is stored initially.
        storage = appengine.StorageByKeyName(
            appengine.CredentialsNDBModel, 'foo', 'credentials')
        self.assertIsNone(storage.get())

        # Refresh through the storage, then read the model directly.
        self.credentials.set_store(storage)
        self.credentials._refresh(_http_request)
        stored = appengine.CredentialsNDBModel.get_by_id('foo')
        self.assertEqual('bar', stored.credentials.access_token)
        self.assertEqual(stored.credentials.to_json(),
                         self.credentials.to_json())

    def test_delete_ndb(self):
        # Nothing is stored initially.
        storage = appengine.StorageByKeyName(
            appengine.CredentialsNDBModel, 'foo', 'credentials')
        self.assertIsNone(storage.get())

        # Put via the storage and compare with a direct model read.
        storage.put(self.credentials)
        stored = appengine.CredentialsNDBModel.get_by_id('foo')
        self.assertEqual(stored.credentials.to_json(),
                         self.credentials.to_json())

        # After delete() the storage is empty again.
        storage.delete()
        self.assertIsNone(storage.get())

    def test_get_and_put_mixed_ndb_storage_db_get(self):
        # Nothing is stored initially.
        storage = appengine.StorageByKeyName(
            appengine.CredentialsNDBModel, 'foo', 'credentials')
        self.assertIsNone(storage.get())

        # Write through the NDB-backed storage.
        self.credentials.set_store(storage)
        self.credentials._refresh(_http_request)

        # The same key must be readable through the db model.
        stored = appengine.CredentialsModel.get_by_key_name('foo')
        self.assertEqual('bar', stored.credentials.access_token)
        self.assertEqual(self.credentials.to_json(),
                         stored.credentials.to_json())

    def test_get_and_put_mixed_db_storage_ndb_get(self):
        # Nothing is stored initially.
        storage = appengine.StorageByKeyName(
            appengine.CredentialsModel, 'foo', 'credentials')
        self.assertIsNone(storage.get())

        # Write through the db-backed storage.
        self.credentials.set_store(storage)
        self.credentials._refresh(_http_request)

        # The same key must be readable through the NDB model.
        stored = appengine.CredentialsNDBModel.get_by_id('foo')
        self.assertEqual('bar', stored.credentials.access_token)
        self.assertEqual(self.credentials.to_json(),
                         stored.credentials.to_json())

    def test_delete_db_ndb_mixed(self):
        # Two storages over the same key, one per model flavour.
        storage_ndb = appengine.StorageByKeyName(
            appengine.CredentialsNDBModel, 'foo', 'credentials')
        storage = appengine.StorageByKeyName(
            appengine.CredentialsModel, 'foo', 'credentials')

        # Put with db, delete with NDB.
        self.assertIsNone(storage.get())
        storage.put(self.credentials)
        self.assertIsNotNone(storage.get())

        storage_ndb.delete()
        self.assertIsNone(storage.get())

        # Put with NDB, delete with db.
        self.assertIsNone(storage_ndb.get())
        storage_ndb.put(self.credentials)

        storage.delete()
        self.assertIsNotNone(storage_ndb.get())
        # NDB uses memcache and an instance cache (Context)
        ndb.get_context().clear_cache()
        memcache.flush_all()
        self.assertIsNone(storage_ndb.get())
563
564
class MockRequest(object):
    """Minimal request object with a fixed base URL."""
    url = 'https://example.org'

    def relative_url(self, rel):
        """Return ``rel`` appended to this request's ``url``."""
        return ''.join((self.url, rel))
570
571
class MockRequestHandler(object):
    """Minimal request handler exposing a shared MockRequest."""

    # One MockRequest instance, created when the class is defined.
    request = MockRequest()
574
575
576class DecoratorTests(unittest2.TestCase):
577
578    def setUp(self):
579        self.testbed = testbed.Testbed()
580        self.testbed.activate()
581        self.testbed.init_datastore_v3_stub()
582        self.testbed.init_memcache_stub()
583        self.testbed.init_user_stub()
584
585        decorator = appengine.OAuth2Decorator(
586            client_id='foo_client_id', client_secret='foo_client_secret',
587            scope=['foo_scope', 'bar_scope'], user_agent='foo')
588
589        self._finish_setup(decorator, user_mock=UserMock)
590
591    def _finish_setup(self, decorator, user_mock):
592        self.decorator = decorator
593        self.had_credentials = False
594        self.found_credentials = None
595        self.should_raise = False
596        parent = self
597
598        class TestRequiredHandler(webapp2.RequestHandler):
599            @decorator.oauth_required
600            def get(self):
601                parent.assertTrue(decorator.has_credentials())
602                parent.had_credentials = True
603                parent.found_credentials = decorator.credentials
604                if parent.should_raise:
605                    raise parent.should_raise
606
607        class TestAwareHandler(webapp2.RequestHandler):
608            @decorator.oauth_aware
609            def get(self, *args, **kwargs):
610                self.response.out.write('Hello World!')
611                assert(kwargs['year'] == '2012')
612                assert(kwargs['month'] == '01')
613                if decorator.has_credentials():
614                    parent.had_credentials = True
615                    parent.found_credentials = decorator.credentials
616                if parent.should_raise:
617                    raise parent.should_raise
618
619        routes = [
620            ('/oauth2callback', self.decorator.callback_handler()),
621            ('/foo_path', TestRequiredHandler),
622            webapp2.Route(r'/bar_path/<year:\d{4}>/<month:\d{2}>',
623                          handler=TestAwareHandler, name='bar'),
624        ]
625        application = webapp2.WSGIApplication(routes, debug=True)
626
627        self.app = TestApp(application, extra_environ={
628            'wsgi.url_scheme': 'http',
629            'HTTP_HOST': 'localhost',
630        })
631        self.current_user = user_mock()
632        users.get_current_user = self.current_user
633        self.httplib2_orig = httplib2.Http
634        httplib2.Http = Http2Mock
635
636    def tearDown(self):
637        self.testbed.deactivate()
638        httplib2.Http = self.httplib2_orig
639
640    def test_in_error(self):
641        # NOTE: This branch is never reached. _in_error is not set by any code
642        # path. It appears to be intended to be set during construction.
643        self.decorator._in_error = True
644        self.decorator._message = 'foobar'
645
646        response = self.app.get('http://localhost/foo_path')
647        self.assertIn('foobar', response.body)
648
649        response = self.app.get('http://localhost/bar_path/1234/56')
650        self.assertIn('foobar', response.body)
651
652    def test_callback_application(self):
653        app = self.decorator.callback_application()
654        self.assertEqual(
655            app.router.match_routes[0].handler.__name__,
656            'OAuth2Handler')
657
658    def test_required(self):
659        # An initial request to an oauth_required decorated path should be a
660        # redirect to start the OAuth dance.
661        self.assertEqual(self.decorator.flow, None)
662        self.assertEqual(self.decorator.credentials, None)
663        response = self.app.get('http://localhost/foo_path')
664        self.assertTrue(response.status.startswith('302'))
665        q = urllib.parse.parse_qs(
666            response.headers['Location'].split('?', 1)[1])
667        self.assertEqual('http://localhost/oauth2callback',
668                         q['redirect_uri'][0])
669        self.assertEqual('foo_client_id', q['client_id'][0])
670        self.assertEqual('foo_scope bar_scope', q['scope'][0])
671        self.assertEqual('http://localhost/foo_path',
672                         q['state'][0].rsplit(':', 1)[0])
673        self.assertEqual('code', q['response_type'][0])
674        self.assertEqual(False, self.decorator.has_credentials())
675
676        with mock.patch.object(appengine, '_parse_state_value',
677                               return_value='foo_path',
678                               autospec=True) as parse_state_value:
679            # Now simulate the callback to /oauth2callback.
680            response = self.app.get('/oauth2callback', {
681                'code': 'foo_access_code',
682                'state': 'foo_path:xsrfkey123',
683            })
684            parts = response.headers['Location'].split('?', 1)
685            self.assertEqual('http://localhost/foo_path', parts[0])
686            self.assertEqual(None, self.decorator.credentials)
687            if self.decorator._token_response_param:
688                response_query = urllib.parse.parse_qs(parts[1])
689                response = response_query[
690                    self.decorator._token_response_param][0]
691                self.assertEqual(Http2Mock.content,
692                                 json.loads(urllib.parse.unquote(response)))
693            self.assertEqual(self.decorator.flow, self.decorator._tls.flow)
694            self.assertEqual(self.decorator.credentials,
695                             self.decorator._tls.credentials)
696
697            parse_state_value.assert_called_once_with(
698                'foo_path:xsrfkey123', self.current_user)
699
700        # Now requesting the decorated path should work.
701        response = self.app.get('/foo_path')
702        self.assertEqual('200 OK', response.status)
703        self.assertEqual(True, self.had_credentials)
704        self.assertEqual('foo_refresh_token',
705                         self.found_credentials.refresh_token)
706        self.assertEqual('foo_access_token',
707                         self.found_credentials.access_token)
708        self.assertEqual(None, self.decorator.credentials)
709
710        # Raising an exception still clears the Credentials.
711        self.should_raise = Exception('')
712        with self.assertRaises(Exception):
713            self.app.get('/foo_path')
714        self.should_raise = False
715        self.assertEqual(None, self.decorator.credentials)
716
717        # Access token refresh error should start the dance again
718        self.should_raise = client.AccessTokenRefreshError()
719        response = self.app.get('/foo_path')
720        self.should_raise = False
721        self.assertTrue(response.status.startswith('302'))
722        query_params = urllib.parse.parse_qs(
723            response.headers['Location'].split('?', 1)[1])
724        self.assertEqual('http://localhost/oauth2callback',
725                         query_params['redirect_uri'][0])
726
727        # Invalidate the stored Credentials.
728        self.found_credentials.invalid = True
729        self.found_credentials.store.put(self.found_credentials)
730
731        # Invalid Credentials should start the OAuth dance again.
732        response = self.app.get('/foo_path')
733        self.assertTrue(response.status.startswith('302'))
734        query_params = urllib.parse.parse_qs(
735            response.headers['Location'].split('?', 1)[1])
736        self.assertEqual('http://localhost/oauth2callback',
737                         query_params['redirect_uri'][0])
738
739    def test_storage_delete(self):
740        # An initial request to an oauth_required decorated path should be a
741        # redirect to start the OAuth dance.
742        response = self.app.get('/foo_path')
743        self.assertTrue(response.status.startswith('302'))
744
745        with mock.patch.object(appengine, '_parse_state_value',
746                               return_value='foo_path',
747                               autospec=True) as parse_state_value:
748            # Now simulate the callback to /oauth2callback.
749            response = self.app.get('/oauth2callback', {
750                'code': 'foo_access_code',
751                'state': 'foo_path:xsrfkey123',
752            })
753            self.assertEqual('http://localhost/foo_path',
754                             response.headers['Location'])
755            self.assertEqual(None, self.decorator.credentials)
756
757            # Now requesting the decorated path should work.
758            response = self.app.get('/foo_path')
759
760            self.assertTrue(self.had_credentials)
761
762            # Credentials should be cleared after each call.
763            self.assertEqual(None, self.decorator.credentials)
764
765            # Invalidate the stored Credentials.
766            self.found_credentials.store.delete()
767
768            # Invalid Credentials should start the OAuth dance again.
769            response = self.app.get('/foo_path')
770            self.assertTrue(response.status.startswith('302'))
771
772            parse_state_value.assert_called_once_with(
773                'foo_path:xsrfkey123', self.current_user)
774
775    def test_aware(self):
776        # An initial request to an oauth_aware decorated path should
777        # not redirect.
778        response = self.app.get('http://localhost/bar_path/2012/01')
779        self.assertEqual('Hello World!', response.body)
780        self.assertEqual('200 OK', response.status)
781        self.assertEqual(False, self.decorator.has_credentials())
782        url = self.decorator.authorize_url()
783        q = urllib.parse.parse_qs(url.split('?', 1)[1])
784        self.assertEqual('http://localhost/oauth2callback',
785                         q['redirect_uri'][0])
786        self.assertEqual('foo_client_id', q['client_id'][0])
787        self.assertEqual('foo_scope bar_scope', q['scope'][0])
788        self.assertEqual('http://localhost/bar_path/2012/01',
789                         q['state'][0].rsplit(':', 1)[0])
790        self.assertEqual('code', q['response_type'][0])
791
792        with mock.patch.object(appengine, '_parse_state_value',
793                               return_value='bar_path',
794                               autospec=True) as parse_state_value:
795            # Now simulate the callback to /oauth2callback.
796            url = self.decorator.authorize_url()
797            response = self.app.get('/oauth2callback', {
798                'code': 'foo_access_code',
799                'state': 'bar_path:xsrfkey456',
800            })
801
802            self.assertEqual('http://localhost/bar_path',
803                             response.headers['Location'])
804            self.assertEqual(False, self.decorator.has_credentials())
805            parse_state_value.assert_called_once_with(
806                'bar_path:xsrfkey456', self.current_user)
807
808        # Now requesting the decorated path will have credentials.
809        response = self.app.get('/bar_path/2012/01')
810        self.assertEqual('200 OK', response.status)
811        self.assertEqual('Hello World!', response.body)
812        self.assertEqual(True, self.had_credentials)
813        self.assertEqual('foo_refresh_token',
814                         self.found_credentials.refresh_token)
815        self.assertEqual('foo_access_token',
816                         self.found_credentials.access_token)
817
818        # Credentials should be cleared after each call.
819        self.assertEqual(None, self.decorator.credentials)
820
821        # Raising an exception still clears the Credentials.
822        self.should_raise = Exception('')
823        with self.assertRaises(Exception):
824            self.app.get('/bar_path/2012/01')
825        self.should_raise = False
826        self.assertEqual(None, self.decorator.credentials)
827
828    def test_error_in_step2(self):
829        # An initial request to an oauth_aware decorated path should
830        # not redirect.
831        response = self.app.get('/bar_path/2012/01')
832        self.decorator.authorize_url()
833        response = self.app.get('/oauth2callback', {
834            'error': 'Bad<Stuff>Happened\''
835        })
836        self.assertEqual('200 OK', response.status)
837        self.assertTrue('Bad&lt;Stuff&gt;Happened&#39;' in response.body)
838
839    def test_kwargs_are_passed_to_underlying_flow(self):
840        decorator = appengine.OAuth2Decorator(
841            client_id='foo_client_id', client_secret='foo_client_secret',
842            user_agent='foo_user_agent', scope=['foo_scope', 'bar_scope'],
843            access_type='offline', prompt='consent',
844            revoke_uri='dummy_revoke_uri')
845        request_handler = MockRequestHandler()
846        decorator._create_flow(request_handler)
847
848        self.assertEqual('https://example.org/oauth2callback',
849                         decorator.flow.redirect_uri)
850        self.assertEqual('offline', decorator.flow.params['access_type'])
851        self.assertEqual('consent', decorator.flow.params['prompt'])
852        self.assertEqual('foo_user_agent', decorator.flow.user_agent)
853        self.assertEqual('dummy_revoke_uri', decorator.flow.revoke_uri)
854        self.assertEqual(None, decorator.flow.params.get('user_agent', None))
855        self.assertEqual(decorator.flow, decorator._tls.flow)
856
857    def test_token_response_param(self):
858        self.decorator._token_response_param = 'foobar'
859        self.test_required()
860
861    def test_decorator_from_client_secrets(self):
862        decorator = appengine.OAuth2DecoratorFromClientSecrets(
863            datafile('client_secrets.json'),
864            scope=['foo_scope', 'bar_scope'])
865        self._finish_setup(decorator, user_mock=UserMock)
866
867        self.assertFalse(decorator._in_error)
868        self.decorator = decorator
869        self.test_required()
870        http = self.decorator.http()
871        self.assertEquals('foo_access_token',
872                          http.request.credentials.access_token)
873
874        # revoke_uri is not required
875        self.assertEqual(self.decorator._revoke_uri,
876                         'https://accounts.google.com/o/oauth2/revoke')
877        self.assertEqual(self.decorator._revoke_uri,
878                         self.decorator.credentials.revoke_uri)
879
880    def test_decorator_from_client_secrets_toplevel(self):
881        decorator_patch = mock.patch(
882            'oauth2client.contrib.appengine.OAuth2DecoratorFromClientSecrets')
883
884        with decorator_patch as decorator_mock:
885            filename = datafile('client_secrets.json')
886            appengine.oauth2decorator_from_clientsecrets(
887                filename, scope='foo_scope')
888            decorator_mock.assert_called_once_with(
889                filename,
890                'foo_scope',
891                cache=None,
892                message=None)
893
894    def test_decorator_from_client_secrets_bad_type(self):
895        # NOTE: this code path is not currently reachable, as the only types
896        # that oauth2client.clientsecrets can load is web and installed, so
897        # this test forces execution of this code path. Despite not being
898        # normally reachable, this should remain in case future types of
899        # credentials are added.
900
901        loadfile_patch = mock.patch(
902            'oauth2client.contrib.appengine.clientsecrets.loadfile')
903        with loadfile_patch as loadfile_mock:
904            loadfile_mock.return_value = ('badtype', None)
905            with self.assertRaises(clientsecrets.InvalidClientSecretsError):
906                appengine.OAuth2DecoratorFromClientSecrets(
907                    'doesntmatter.json',
908                    scope=['foo_scope', 'bar_scope'])
909
910    def test_decorator_from_client_secrets_kwargs(self):
911        decorator = appengine.OAuth2DecoratorFromClientSecrets(
912            datafile('client_secrets.json'),
913            scope=['foo_scope', 'bar_scope'],
914            prompt='consent')
915        self.assertIn('prompt', decorator._kwargs)
916
917    def test_decorator_from_cached_client_secrets(self):
918        cache_mock = CacheMock()
919        load_and_cache('client_secrets.json', 'secret', cache_mock)
920        decorator = appengine.OAuth2DecoratorFromClientSecrets(
921            # filename, scope, message=None, cache=None
922            'secret', '', cache=cache_mock)
923        self.assertFalse(decorator._in_error)
924
925    def test_decorator_from_client_secrets_not_logged_in_required(self):
926        decorator = appengine.OAuth2DecoratorFromClientSecrets(
927            datafile('client_secrets.json'),
928            scope=['foo_scope', 'bar_scope'], message='NotLoggedInMessage')
929        self.decorator = decorator
930        self._finish_setup(decorator, user_mock=UserNotLoggedInMock)
931
932        self.assertFalse(decorator._in_error)
933
934        # An initial request to an oauth_required decorated path should be a
935        # redirect to login.
936        response = self.app.get('/foo_path')
937        self.assertTrue(response.status.startswith('302'))
938        self.assertTrue('Login' in str(response))
939
940    def test_decorator_from_client_secrets_not_logged_in_aware(self):
941        decorator = appengine.OAuth2DecoratorFromClientSecrets(
942            datafile('client_secrets.json'),
943            scope=['foo_scope', 'bar_scope'], message='NotLoggedInMessage')
944        self.decorator = decorator
945        self._finish_setup(decorator, user_mock=UserNotLoggedInMock)
946
947        # An initial request to an oauth_aware decorated path should be a
948        # redirect to login.
949        response = self.app.get('/bar_path/2012/03')
950        self.assertTrue(response.status.startswith('302'))
951        self.assertTrue('Login' in str(response))
952
953    def test_decorator_from_unfilled_client_secrets_required(self):
954        MESSAGE = 'File is missing'
955        try:
956            appengine.OAuth2DecoratorFromClientSecrets(
957                datafile('unfilled_client_secrets.json'),
958                scope=['foo_scope', 'bar_scope'], message=MESSAGE)
959        except clientsecrets.InvalidClientSecretsError:
960            pass
961
962    def test_decorator_from_unfilled_client_secrets_aware(self):
963        MESSAGE = 'File is missing'
964        try:
965            appengine.OAuth2DecoratorFromClientSecrets(
966                datafile('unfilled_client_secrets.json'),
967                scope=['foo_scope', 'bar_scope'], message=MESSAGE)
968        except clientsecrets.InvalidClientSecretsError:
969            pass
970
971    def test_decorator_from_client_secrets_with_optional_settings(self):
972        # Test that the decorator works with the absense of a revoke_uri in
973        # the client secrets.
974        loadfile_patch = mock.patch(
975            'oauth2client.contrib.appengine.clientsecrets.loadfile')
976        with loadfile_patch as loadfile_mock:
977            loadfile_mock.return_value = (clientsecrets.TYPE_WEB, {
978                "client_id": "foo_client_id",
979                "client_secret": "foo_client_secret",
980                "redirect_uris": [],
981                "auth_uri": "https://accounts.google.com/o/oauth2/v2/auth",
982                "token_uri": "https://www.googleapis.com/oauth2/v4/token",
983                # No revoke URI
984            })
985
986            decorator = appengine.OAuth2DecoratorFromClientSecrets(
987                'doesntmatter.json',
988                scope=['foo_scope', 'bar_scope'])
989
990        self.assertEqual(decorator._revoke_uri, oauth2client.GOOGLE_REVOKE_URI)
991        # This is never set, but it's consistent with other tests.
992        self.assertFalse(decorator._in_error)
993
994    def test_invalid_state(self):
995        with mock.patch.object(appengine, '_parse_state_value',
996                               return_value=None, autospec=True):
997            # Now simulate the callback to /oauth2callback.
998            response = self.app.get('/oauth2callback', {
999                'code': 'foo_access_code',
1000                'state': 'foo_path:xsrfkey123',
1001            })
1002            self.assertEqual('200 OK', response.status)
1003            self.assertEqual('The authorization request failed', response.body)
1004
1005
class DecoratorXsrfSecretTests(unittest2.TestCase):
    """Test xsrf_secret_key."""

    def setUp(self):
        # Activate a testbed with the datastore and memcache stubs.
        self.testbed = testbed.Testbed()
        bed = self.testbed
        bed.activate()
        bed.init_datastore_v3_stub()
        bed.init_memcache_stub()

    def tearDown(self):
        self.testbed.deactivate()

    def _evict_cached_secret(self):
        """Delete the XSRF secret entry from memcache."""
        memcache.delete(appengine.XSRF_MEMCACHE_ID,
                        namespace=appengine.OAUTH2CLIENT_NAMESPACE)

    def test_build_and_parse_state(self):
        """The secret only changes once memcache and the model are gone."""
        first = appengine.xsrf_secret_key()

        # Asking again should give back the same secret.
        second = appengine.xsrf_secret_key()
        self.assertEqual(first, second)

        # Losing only the memcache entry should not change the secret.
        self._evict_cached_secret()
        third = appengine.xsrf_secret_key()
        self.assertEqual(second, third)

        # With both the memcache entry and the stored model removed, a
        # different secret should come back.
        self._evict_cached_secret()
        appengine.SiteXsrfSecretKey.get_or_insert('site').delete()

        fourth = appengine.xsrf_secret_key()
        self.assertNotEqual(third, fourth)

    def test_ndb_insert_db_get(self):
        """A secret stored via the NDB model is readable via the db one."""
        expected = appengine._generate_new_xsrf_secret_key()
        ndb_entity = appengine.SiteXsrfSecretKeyNDB(
            id='site', secret=expected)
        ndb_entity.put()

        db_entity = appengine.SiteXsrfSecretKey.get_by_key_name('site')
        self.assertEqual(db_entity.secret, expected)

    def test_db_insert_ndb_get(self):
        """A secret stored via the db model is readable via the NDB one."""
        expected = appengine._generate_new_xsrf_secret_key()
        db_entity = appengine.SiteXsrfSecretKey(
            key_name='site', secret=expected)
        db_entity.put()

        ndb_entity = appengine.SiteXsrfSecretKeyNDB.get_by_id('site')
        self.assertEqual(ndb_entity.secret, expected)
1053
1054
class DecoratorXsrfProtectionTests(unittest2.TestCase):
    """Test _build_state_value and _parse_state_value."""

    def setUp(self):
        # Activate a testbed with the datastore and memcache stubs.
        self.testbed = testbed.Testbed()
        bed = self.testbed
        bed.activate()
        bed.init_datastore_v3_stub()
        bed.init_memcache_stub()

    def tearDown(self):
        self.testbed.deactivate()

    def test_build_and_parse_state(self):
        """A built state parses back to the URI; a truncated one does not."""
        handler = MockRequestHandler()
        token = appengine._build_state_value(handler, UserMock())
        parsed_uri = appengine._parse_state_value(token, UserMock())
        self.assertEqual('https://example.org', parsed_uri)

        # Dropping the first character should make parsing return None.
        truncated = token[1:]
        self.assertIsNone(
            appengine._parse_state_value(truncated, UserMock()))
1074