1# Copyright 2016 Google Inc. All rights reserved.
2#
3# Licensed under the Apache License, Version 2.0 (the "License");
4# you may not use this file except in compliance with the License.
5# You may obtain a copy of the License at
6#
7#      http://www.apache.org/licenses/LICENSE-2.0
8#
9# Unless required by applicable law or agreed to in writing, software
10# distributed under the License is distributed on an "AS IS" BASIS,
11# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
12# See the License for the specific language governing permissions and
13# limitations under the License.
14
15import json
16import os
17
18from google.appengine.ext import ndb
19from google.appengine.ext import testbed
20import mock
21import unittest2
22
23from oauth2client import client
24from oauth2client.contrib import appengine
25
26
# Shared fixture directory: a sibling ``data`` folder one level above this
# test module's own directory.
DATA_DIR = os.path.join(os.path.dirname(__file__), os.pardir, 'data')


def datafile(filename):
    """Return the full path of ``filename`` inside ``DATA_DIR``.

    Args:
        filename: Name of a file stored in the test data directory.

    Returns:
        The joined path; the file's existence is not checked.
    """
    fixture_path = os.path.join(DATA_DIR, filename)
    return fixture_path
32
33
class TestNDBModel(ndb.Model):
    """Fixture model carrying one property of each oauth2client NDB type.

    The tests in this module store and read back instances of this model,
    and also call the property objects' hooks (``_validate``,
    ``_to_base_type``, ``_from_base_type``) directly via the class
    attributes below.
    """

    # Stores an oauth2client Flow; the tests show None is also accepted.
    flow = appengine.FlowNDBProperty()
    # Stores oauth2client Credentials; the tests show None is also accepted.
    creds = appengine.CredentialsNDBProperty()
37
38
class TestFlowNDBProperty(unittest2.TestCase):
    """Tests for ``appengine.FlowNDBProperty`` as declared on TestNDBModel."""

    def setUp(self):
        self.testbed = testbed.Testbed()
        self.testbed.activate()
        # Registered right after activation (instead of relying on
        # tearDown) so the testbed is deactivated even if a stub
        # initialisation below raises.
        self.addCleanup(self.testbed.deactivate)
        self.testbed.init_datastore_v3_stub()
        self.testbed.init_memcache_stub()
        # NDB's in-context cache is not reset by the testbed; clear it so
        # entities cannot leak from one test into another.
        ndb.get_context().clear_cache()

    @staticmethod
    def _make_flow():
        """Build the Flow used by these tests from client_secrets.json."""
        return client.flow_from_clientsecrets(
            datafile('client_secrets.json'), 'foo', redirect_uri='oob')

    def test_flow_get_put(self):
        instance = TestNDBModel(flow=self._make_flow(), id='foo')
        instance.put()
        # Bypass NDB's in-context cache and memcache so the entity is read
        # back from the datastore stub; otherwise the cached object could be
        # returned and the property's serialisation never exercised.
        retrieved = TestNDBModel.get_by_id(
            'foo', use_cache=False, use_memcache=False)

        self.assertEqual('foo_client_id', retrieved.flow.client_id)

    @mock.patch('oauth2client.contrib._appengine_ndb._LOGGER')
    def test_validate_success(self, mock_logger):
        flow_prop = TestNDBModel.flow
        flow_val = self._make_flow()
        flow_prop._validate(flow_val)
        mock_logger.info.assert_called_once_with('validate: Got type %s',
                                                 type(flow_val))

    @mock.patch('oauth2client.contrib._appengine_ndb._LOGGER')
    def test_validate_none(self, mock_logger):
        flow_prop = TestNDBModel.flow
        flow_val = None
        flow_prop._validate(flow_val)
        mock_logger.info.assert_called_once_with('validate: Got type %s',
                                                 type(flow_val))

    @mock.patch('oauth2client.contrib._appengine_ndb._LOGGER')
    def test_validate_bad_type(self, mock_logger):
        flow_prop = TestNDBModel.flow
        flow_val = object()
        with self.assertRaises(TypeError):
            flow_prop._validate(flow_val)
        mock_logger.info.assert_called_once_with('validate: Got type %s',
                                                 type(flow_val))
86
87
class TestCredentialsNDBProperty(unittest2.TestCase):
    """Tests for ``appengine.CredentialsNDBProperty`` on TestNDBModel."""

    def setUp(self):
        self.testbed = testbed.Testbed()
        self.testbed.activate()
        # Registered right after activation (instead of relying on
        # tearDown) so the testbed is deactivated even if a stub
        # initialisation below raises.
        self.addCleanup(self.testbed.deactivate)
        self.testbed.init_datastore_v3_stub()
        self.testbed.init_memcache_stub()
        # NDB's in-context cache is not reset by the testbed; clear it so
        # entities cannot leak from one test into another.
        ndb.get_context().clear_cache()

    def test_valid_creds_get_put(self):
        creds = client.Credentials()
        instance = TestNDBModel(creds=creds, id='bar')
        instance.put()
        # Read back from the datastore stub rather than NDB's caches so the
        # JSON round trip through the property is actually exercised.
        retrieved = TestNDBModel.get_by_id(
            'bar', use_cache=False, use_memcache=False)
        self.assertIsInstance(retrieved.creds, client.Credentials)

    @mock.patch('oauth2client.contrib._appengine_ndb._LOGGER')
    def test_validate_success(self, mock_logger):
        creds_prop = TestNDBModel.creds
        creds_val = client.Credentials()
        creds_prop._validate(creds_val)
        mock_logger.info.assert_called_once_with('validate: Got type %s',
                                                 type(creds_val))

    @mock.patch('oauth2client.contrib._appengine_ndb._LOGGER')
    def test_validate_none(self, mock_logger):
        creds_prop = TestNDBModel.creds
        creds_val = None
        creds_prop._validate(creds_val)
        mock_logger.info.assert_called_once_with('validate: Got type %s',
                                                 type(creds_val))

    @mock.patch('oauth2client.contrib._appengine_ndb._LOGGER')
    def test_validate_bad_type(self, mock_logger):
        creds_prop = TestNDBModel.creds
        creds_val = object()
        with self.assertRaises(TypeError):
            creds_prop._validate(creds_val)
        mock_logger.info.assert_called_once_with('validate: Got type %s',
                                                 type(creds_val))

    def test__to_base_type_valid_creds(self):
        creds_prop = TestNDBModel.creds
        creds = client.Credentials()
        creds_json = json.loads(creds_prop._to_base_type(creds))
        self.assertDictEqual(creds_json, {
            '_class': 'Credentials',
            '_module': 'oauth2client.client',
            'token_expiry': None,
        })

    def test__to_base_type_null_creds(self):
        creds_prop = TestNDBModel.creds
        # None is stored as the empty string rather than as JSON.
        self.assertEqual(creds_prop._to_base_type(None), '')

    def test__from_base_type_valid_creds(self):
        creds_prop = TestNDBModel.creds
        creds_json = json.dumps({
            '_class': 'Credentials',
            '_module': 'oauth2client.client',
            'token_expiry': None,
        })
        creds = creds_prop._from_base_type(creds_json)
        self.assertIsInstance(creds, client.Credentials)

    def test__from_base_type_false_value(self):
        creds_prop = TestNDBModel.creds
        # Every falsy stored value is expected to decode to None.
        self.assertIsNone(creds_prop._from_base_type(''))
        self.assertIsNone(creds_prop._from_base_type(False))
        self.assertIsNone(creds_prop._from_base_type(None))
        self.assertIsNone(creds_prop._from_base_type([]))
        self.assertIsNone(creds_prop._from_base_type({}))

    def test__from_base_type_bad_json(self):
        creds_prop = TestNDBModel.creds
        creds_json = '{JK-I-AM-NOT-JSON'
        # Undecodable input is expected to yield None rather than raise.
        self.assertIsNone(creds_prop._from_base_type(creds_json))
167