1#!/usr/bin/python
2
3import unittest, time
4import common
5from autotest_lib.client.common_lib import global_config
6from autotest_lib.client.common_lib.test_utils import mock
7from autotest_lib.database import database_connection
8
# Global config section the tests override, plus the values written into it.
_CONFIG_SECTION = 'AUTOTEST_WEB'
_HOST = 'myhost'
_USER = 'myuser'
_PASS = 'mypass'
_DB_NAME = 'mydb'
_DB_TYPE = 'mydbtype'

# Keyword arguments the fake backend's connect() is expected to receive.
_CONNECT_KWARGS = {
    'host': _HOST,
    'username': _USER,
    'password': _PASS,
    'db_name': _DB_NAME,
}
# Value assigned to reconnect_delay_sec on every DatabaseConnection under test.
_RECONNECT_DELAY = 10
19
class FakeDatabaseError(Exception):
    """Stand-in raised wherever a real backend database error would occur."""
22
23
class DatabaseConnectionTest(unittest.TestCase):
    """Tests for database_connection.DatabaseConnection against a mock backend.

    Each test records expected backend calls on self._fake_backend (a mock
    class created by mock_god), exercises the DatabaseConnection, and then
    verifies the recorded expectations with self.god.check_playback().
    time.sleep is stubbed so reconnect delays are asserted, not waited on.
    """

    def setUp(self):
        self.god = mock.mock_god()
        # Reconnect attempts are preceded by time.sleep(); stub it so the
        # tests can expect the call without actually sleeping.
        self.god.stub_function(time, 'sleep')


    def tearDown(self):
        # Drop any overrides written by _override_config() so they cannot
        # leak into other tests, then restore everything that was stubbed.
        global_config.global_config.reset_config_values()
        self.god.unstub_all()


    def _get_database_connection(self, config_section=_CONFIG_SECTION):
        """Return a DatabaseConnection whose backend is a mock.

        Args:
            config_section: global config section handed to DatabaseConnection.
                When it equals _CONFIG_SECTION the test values are first
                written into global config via _override_config(); any other
                value (e.g. None) leaves global config untouched.

        Side effects:
            Stores the mock backend in self._fake_backend and, once the
            connection asks for a backend, the requested db_type in
            self._db_type.
        """
        if config_section == _CONFIG_SECTION:
            self._override_config()
        db = database_connection.DatabaseConnection(config_section)

        self._fake_backend = self.god.create_mock_class(
            database_connection._GenericBackend, 'fake_backend')
        # Expose FakeDatabaseError under every exception name the module lists
        # in _DB_EXCEPTIONS, presumably so DatabaseConnection treats it as a
        # database error when deciding whether to reconnect.
        for exception in database_connection._DB_EXCEPTIONS:
            setattr(self._fake_backend, exception, FakeDatabaseError)
        self._fake_backend.rowcount = 0

        def get_fake_backend(db_type):
            # Remember what db_type was requested so tests can assert on it.
            self._db_type = db_type
            return self._fake_backend
        self.god.stub_with(db, '_get_backend', get_fake_backend)

        # Fixed delay so time.sleep expectations can match it exactly.
        db.reconnect_delay_sec = _RECONNECT_DELAY
        return db


    def _override_config(self):
        """Write the test connection settings into global config."""
        c = global_config.global_config
        c.override_config_value(_CONFIG_SECTION, 'host', _HOST)
        c.override_config_value(_CONFIG_SECTION, 'user', _USER)
        c.override_config_value(_CONFIG_SECTION, 'password', _PASS)
        c.override_config_value(_CONFIG_SECTION, 'database', _DB_NAME)
        c.override_config_value(_CONFIG_SECTION, 'db_type', _DB_TYPE)


    def test_connect(self):
        """Explicit connect() arguments are forwarded to the backend."""
        db = self._get_database_connection(config_section=None)
        self._fake_backend.connect.expect_call(**_CONNECT_KWARGS)

        db.connect(db_type=_DB_TYPE, host=_HOST, username=_USER,
                   password=_PASS, db_name=_DB_NAME)

        self.assertEqual(self._db_type, _DB_TYPE)
        self.god.check_playback()


    def test_global_config(self):
        """Without arguments, connect() uses the global config section."""
        db = self._get_database_connection()
        self._fake_backend.connect.expect_call(**_CONNECT_KWARGS)

        db.connect()

        self.assertEqual(self._db_type, _DB_TYPE)
        self.god.check_playback()


    def _expect_reconnect(self, fail=False):
        """Expect one reconnect: disconnect, then connect (raising if fail)."""
        self._fake_backend.disconnect.expect_call()
        call = self._fake_backend.connect.expect_call(**_CONNECT_KWARGS)
        if fail:
            call.and_raises(FakeDatabaseError())


    def _expect_fail_and_reconnect(self, num_reconnects, fail_last=False):
        """Expect a failing connect followed by num_reconnects retries.

        Each retry is preceded by time.sleep(_RECONNECT_DELAY). Every retry
        except the last one fails; the last one fails only if fail_last.
        With num_reconnects == 0 only the initial failing connect is expected.
        """
        self._fake_backend.connect.expect_call(**_CONNECT_KWARGS).and_raises(
            FakeDatabaseError())
        for attempt in range(num_reconnects):
            time.sleep.expect_call(_RECONNECT_DELAY)
            is_last = attempt == num_reconnects - 1
            self._expect_reconnect(fail=fail_last if is_last else True)


    def test_connect_retry(self):
        """connect() retries once by default, but not when retrying is off."""
        db = self._get_database_connection()
        self._expect_fail_and_reconnect(1)

        db.connect()
        self.god.check_playback()

        # The connection is now open, so connect() is expected to disconnect
        # before trying again; try_reconnecting=False must re-raise at once.
        self._fake_backend.disconnect.expect_call()
        self._expect_fail_and_reconnect(0)
        self.assertRaises(FakeDatabaseError, db.connect,
                          try_reconnecting=False)
        self.god.check_playback()

        # Disabling reconnects on the instance has the same effect.
        db.reconnect_enabled = False
        self._fake_backend.disconnect.expect_call()
        self._expect_fail_and_reconnect(0)
        self.assertRaises(FakeDatabaseError, db.connect)
        self.god.check_playback()


    def test_max_reconnect(self):
        """connect() gives up and re-raises after max_reconnect_attempts."""
        db = self._get_database_connection()
        db.max_reconnect_attempts = 5
        self._expect_fail_and_reconnect(5, fail_last=True)

        self.assertRaises(FakeDatabaseError, db.connect)
        self.god.check_playback()


    def test_reconnect_forever(self):
        """RECONNECT_FOREVER keeps retrying (30 is just a large sample)."""
        db = self._get_database_connection()
        db.max_reconnect_attempts = database_connection.RECONNECT_FOREVER
        self._expect_fail_and_reconnect(30)

        db.connect()
        self.god.check_playback()


    def _simple_connect(self, db):
        """Connect db successfully and verify the backend saw that connect."""
        self._fake_backend.connect.expect_call(**_CONNECT_KWARGS)
        db.connect()
        self.god.check_playback()


    def test_disconnect(self):
        db = self._get_database_connection()
        self._simple_connect(db)
        self._fake_backend.disconnect.expect_call()

        db.disconnect()
        self.god.check_playback()


    def test_execute(self):
        """execute() forwards the query and params object unchanged."""
        db = self._get_database_connection()
        self._simple_connect(db)
        # A fresh object() only equals itself, so the expectation matches
        # only if the exact same params object reaches the backend.
        params = object()
        self._fake_backend.execute.expect_call('query', params)

        db.execute('query', params)
        self.god.check_playback()


    def test_execute_retry(self):
        """A failing execute() reconnects and retries unless told not to."""
        db = self._get_database_connection()
        self._simple_connect(db)
        self._fake_backend.execute.expect_call('query', None).and_raises(
            FakeDatabaseError())
        self._expect_reconnect()
        self._fake_backend.execute.expect_call('query', None)

        db.execute('query')
        self.god.check_playback()

        self._fake_backend.execute.expect_call('query', None).and_raises(
            FakeDatabaseError())
        self.assertRaises(FakeDatabaseError, db.execute, 'query',
                          try_reconnecting=False)
        # Verify the no-retry path made exactly the expected backend calls
        # (previously this phase was never checked).
        self.god.check_playback()
181
182
if __name__ == '__main__':
    # Executed directly: hand every TestCase in this module to unittest's runner.
    unittest.main()
185