1# Copyright 2016 Google Inc. All rights reserved.
2#
3# Licensed under the Apache License, Version 2.0 (the "License");
4# you may not use this file except in compliance with the License.
5# You may obtain a copy of the License at
6#
7#      http://www.apache.org/licenses/LICENSE-2.0
8#
9# Unless required by applicable law or agreed to in writing, software
10# distributed under the License is distributed on an "AS IS" BASIS,
11# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
12# See the License for the specific language governing permissions and
13# limitations under the License.
14
15import datetime
16
17import sqlalchemy
18import sqlalchemy.ext.declarative
19import sqlalchemy.orm
20import unittest2
21
22import oauth2client
23import oauth2client.client
24import oauth2client.contrib.sqlalchemy
25
# Declarative base for the test model below; its metadata is what the tests
# use to create the schema.
Base = sqlalchemy.ext.declarative.declarative_base()


class DummyModel(Base):
    """Minimal mapped model with one credentials column, used as the Storage target."""

    __tablename__ = 'dummy'

    # Surrogate primary key.
    id = sqlalchemy.Column('id', sqlalchemy.Integer, primary_key=True)
    # Lookup column the Storage under test filters on (key_name='key'). It is
    # kept separate from ``id`` so the tests control the lookup value —
    # presumably because an SQLite INTEGER PRIMARY KEY aliases ROWID and is
    # auto-assigned (the original note said "because of ROWID").
    key = sqlalchemy.Column('key', sqlalchemy.Integer)
    # Column typed with the custom SQLAlchemy type under test.
    credentials = sqlalchemy.Column(
        'credentials', oauth2client.contrib.sqlalchemy.CredentialsType)
37
38
class TestSQLAlchemyStorage(unittest2.TestCase):
    """Tests for oauth2client.contrib.sqlalchemy.Storage backed by DummyModel."""

    def setUp(self):
        # Every test gets a brand-new in-memory SQLite database (an in-memory
        # SQLite database lives only as long as its connection), so no rows
        # leak between tests. Cleanups run LIFO: registering the engine
        # disposal first means it runs after every session opened by the
        # test has been closed. Registering it before create_all() ensures
        # it still runs if schema creation fails.
        engine = sqlalchemy.create_engine('sqlite://')
        self.addCleanup(engine.dispose)
        Base.metadata.create_all(engine)

        # Session factory bound to this test's engine.
        self.session = sqlalchemy.orm.sessionmaker(bind=engine)
        self.credentials = oauth2client.client.OAuth2Credentials(
            access_token='token',
            client_id='client_id',
            client_secret='client_secret',
            refresh_token='refresh_token',
            token_expiry=datetime.datetime.utcnow(),
            token_uri=oauth2client.GOOGLE_TOKEN_URI,
            user_agent='DummyAgent',
        )

    def _open_session(self):
        """Return a new session that is closed automatically after the test."""
        session = self.session()
        self.addCleanup(session.close)
        return session

    def _make_storage(self, session):
        """Return a Storage for the DummyModel row whose ``key`` is 1.

        Args:
            session: The SQLAlchemy session the Storage should use.
        """
        return oauth2client.contrib.sqlalchemy.Storage(
            session=session,
            model_class=DummyModel,
            key_name='key',
            key_value=1,
            property_name='credentials',
        )

    def compare_credentials(self, result):
        """Assert that ``result`` matches ``self.credentials`` field by field.

        Args:
            result: Credentials read back from storage. None fails with a
                clear assertion message rather than an AttributeError.
        """
        self.assertIsNotNone(result)
        for attr in ('access_token', 'client_id', 'client_secret',
                     'refresh_token', 'token_expiry', 'token_uri',
                     'user_agent'):
            self.assertEqual(getattr(result, attr),
                             getattr(self.credentials, attr),
                             msg=attr)

    def test_get(self):
        """get() returns None when absent and the stored credentials after."""
        session = self._open_session()
        credentials_storage = self._make_storage(session)
        self.assertIsNone(credentials_storage.get())

        session.add(DummyModel(key=1, credentials=self.credentials))
        session.commit()

        self.compare_credentials(credentials_storage.get())

    def test_put(self):
        """put() creates a row holding the given credentials."""
        session = self._open_session()
        self._make_storage(session).put(self.credentials)
        session.commit()

        entity = session.query(DummyModel).filter_by(key=1).first()
        self.assertIsNotNone(entity)
        self.compare_credentials(entity.credentials)

    def test_delete(self):
        """delete() removes the row addressed by the Storage key."""
        session = self._open_session()
        session.add(DummyModel(key=1, credentials=self.credentials))
        session.commit()

        query = session.query(DummyModel).filter_by(key=1)
        self.assertIsNotNone(query.first())
        self._make_storage(session).delete()
        session.commit()
        self.assertIsNone(query.first())
120