1# pylint: disable-msg=C0111
2
3import re, time, traceback
4import common
5from autotest_lib.client.common_lib import global_config
6
# Sentinel for DatabaseConnection.max_reconnect_attempts meaning "no upper
# bound on reconnection attempts".  It is compared by identity.
RECONNECT_FOREVER = object()

# Names of the DB-API exception classes mirrored from a database module (or
# backend) onto another object by _copy_exceptions().
_DB_EXCEPTIONS = (
    'DatabaseError',
    'OperationalError',
    'ProgrammingError',
)
# Option names whose global config key differs from the name used in this
# module; any option not listed here uses the same name in both places.
_GLOBAL_CONFIG_NAMES = dict(username='user', db_name='database')
14
def _copy_exceptions(source, destination):
    """Mirror the DB-API exception classes of `source` onto `destination`.

    Each name in _DB_EXCEPTIONS is read from `source` and set on
    `destination` under the same name.  A name that `source` lacks is filled
    with `source.DatabaseError` instead.

    @param source: object providing the exception classes (a database module
                   or a backend instance in this file).
    @param destination: object that receives the exception attributes.
    @raises AttributeError: if `source` lacks a listed class and also has no
                            DatabaseError.
    """
    for exception_name in _DB_EXCEPTIONS:
        # Only the lookup belongs in the try: an AttributeError raised while
        # setting the attribute on `destination` must not be mistaken for a
        # class that is missing from `source`.
        try:
            exception_class = getattr(source, exception_name)
        except AttributeError:
            # Under the django backend:
            # Django 1.3 does not have OperationalError and ProgrammingError.
            # Let's just substitute the base DatabaseError for these classes.
            exception_class = getattr(source, 'DatabaseError')
        setattr(destination, exception_name, exception_class)
26
27
class _GenericBackend(object):
    """Base class wrapping a DB-API style database module.

    Subclasses provide connect(); this class implements disconnect() and
    execute() on top of the stored connection and cursor.  The module's
    exception classes (named in _DB_EXCEPTIONS) are copied onto the instance
    so callers can write e.g. `except backend.OperationalError`.
    """
    def __init__(self, database_module):
        """
        @param database_module: module (or module-like object) whose
                exception classes are copied onto this backend and which
                subclasses use to open connections.
        """
        self._database_module = database_module
        self._connection = None
        self._cursor = None
        # cursor.rowcount as of the most recent execute() call.
        self.rowcount = None
        _copy_exceptions(database_module, self)


    def connect(self, host=None, username=None, password=None, db_name=None):
        """
        Open a connection.  Must be overridden by subclasses.

        This is assumed to enable autocommit.
        """
        raise NotImplementedError


    def disconnect(self):
        """Close the current connection, if any, and drop connection/cursor.

        The stored connection and cursor are cleared even when close()
        raises, so the backend never keeps a reference to a half-closed
        connection; the exception from close() still propagates.
        """
        connection = self._connection
        try:
            if connection:
                connection.close()
        finally:
            self._connection = None
            self._cursor = None


    def execute(self, query, parameters=None):
        """Run `query` on the current cursor and return every result row.

        @param query: query string passed straight to cursor.execute().
        @param parameters: query parameters; None is treated as ().
        @returns the result of cursor.fetchall().
        """
        if parameters is None:
            parameters = ()
        self._cursor.execute(query, parameters)
        self.rowcount = self._cursor.rowcount
        return self._cursor.fetchall()
57
58
class _MySqlBackend(_GenericBackend):
    """Backend for MySQL servers, built on the MySQLdb module."""

    def __init__(self):
        # Imported here so MySQLdb is only loaded when this backend is built.
        import MySQLdb
        super(_MySqlBackend, self).__init__(MySQLdb)


    @staticmethod
    def convert_boolean(boolean, conversion_dict):
        """Encode a bool as the integer string '1' or '0'.

        connect() registers this in MySQLdb's conversion dictionary;
        conversion_dict is accepted (presumably to match MySQLdb's encoder
        calling convention) but not used.
        """
        as_integer = int(boolean)
        return str(as_integer)


    def connect(self, host=None, username=None, password=None, db_name=None):
        """Open a MySQL connection with autocommit turned on.

        NOTE: the bool encoder is registered (only when no bool entry exists
        yet) directly in MySQLdb.converters.conversions -- the module-level
        dictionary itself, not a private copy.
        """
        import MySQLdb.converters
        conversions = MySQLdb.converters.conversions
        conversions.setdefault(bool, self.convert_boolean)

        connect_kwargs = dict(host=host, user=username, passwd=password,
                              db=db_name, conv=conversions)
        self._connection = self._database_module.connect(**connect_kwargs)
        self._connection.autocommit(True)
        self._cursor = self._connection.cursor()
81
82
class _SqliteBackend(_GenericBackend):
    """Backend for SQLite files, using pysqlite2 if present, else sqlite3."""

    def __init__(self):
        try:
            from pysqlite2 import dbapi2
        except ImportError:
            from sqlite3 import dbapi2
        super(_SqliteBackend, self).__init__(dbapi2)
        # Case-insensitive match for MySQL's LAST_INSERT_ID() when it is
        # preceded by whitespace.
        self._last_insert_id_re = re.compile(r'\sLAST_INSERT_ID\(\)',
                                             re.IGNORECASE)


    def connect(self, host=None, username=None, password=None, db_name=None):
        """Open SQLite database `db_name` in autocommit mode.

        host, username and password are accepted for interface compatibility
        and ignored.
        """
        connection = self._database_module.connect(db_name)
        self._connection = connection
        # An isolation_level of None enables autocommit.
        connection.isolation_level = None
        self._cursor = connection.cursor()


    def execute(self, query, parameters=None):
        """Run a MySQL-flavoured query against SQLite.

        '%s' placeholders become '?' (pysqlite2 uses paramstyle=qmark), and
        LAST_INSERT_ID() becomes LAST_INSERT_ROWID(), which does enough of
        what we want for our non-concurrent unittest use case.

        TODO: make the placeholder rewrite more sophisticated if necessary;
        it also rewrites '%s' occurring inside string literals.
        """
        qmark_query = query.replace('%s', '?')
        # pysqlite2 can't handle parameters=None (it throws a nonsense
        # exception).
        if parameters is None:
            parameters = ()
        sqlite_query = self._last_insert_id_re.sub(' LAST_INSERT_ROWID()',
                                                   qmark_query)
        return super(_SqliteBackend, self).execute(sqlite_query, parameters)
113
114
class _DjangoBackend(_GenericBackend):
    """Backend that runs queries on Django's django.db.connection.

    Exception classes are taken from django.db, and after every query the
    Django transaction is committed unless it is being managed.
    """
    def __init__(self):
        # Only connection and transaction are needed from django.db.
        from django.db import connection, transaction
        import django.db as django_db
        super(_DjangoBackend, self).__init__(django_db)
        self._django_connection = connection
        self._django_transaction = transaction


    def connect(self, host=None, username=None, password=None, db_name=None):
        """Use Django's existing connection; all arguments are ignored."""
        self._connection = self._django_connection
        self._cursor = self._connection.cursor()


    def execute(self, query, parameters=None):
        """Execute `query`, then commit unless the transaction is managed.

        The commit is issued from a finally clause, so it runs even when the
        query raises.
        """
        try:
            return super(_DjangoBackend, self).execute(query,
                                                       parameters=parameters)
        finally:
            self._django_transaction.commit_unless_managed()
135
136
# Maps each supported db_type value to the backend class implementing it.
# DatabaseConnection._get_backend() instantiates these and lists the keys in
# its error message for unknown types.
_BACKEND_MAP = dict([
    ('mysql', _MySqlBackend),
    ('sqlite', _SqliteBackend),
    ('django', _DjangoBackend),
])
142
143
class DatabaseConnection(object):
    """
    Generic wrapper for a database connection.  Supports the mysql, sqlite
    and django backends.

    Public attributes:
    * reconnect_enabled: if True, when an OperationalError occurs the class will
      try to reconnect to the database automatically.
    * reconnect_delay_sec: seconds to wait before reconnecting
    * max_reconnect_attempts: maximum number of times to try reconnecting
      before giving up.  Setting to RECONNECT_FOREVER removes the limit.
    * rowcount - will hold cursor.rowcount after each call to execute().
    * global_config_section - the section in which to find DB information. this
      should be passed to the constructor, not set later, and may be None, in
      which case the database name must be passed to connect().  Note that
      when none of db_type/host/username/password is given to connect(), all
      four are read from the AUTOTEST_WEB section regardless.
    * debug - if set True, all queries will be printed before being executed
    """
    _DATABASE_ATTRIBUTES = ('db_type', 'host', 'username', 'password',
                            'db_name')

    def __init__(self, global_config_section=None, debug=False):
        """
        @param global_config_section: global config section to read database
                settings from, or None.
        @param debug: if True, print every query before executing it.
        """
        self.global_config_section = global_config_section
        self._backend = None
        self.rowcount = None
        self.debug = debug

        # reconnect defaults
        self.reconnect_enabled = True
        self.reconnect_delay_sec = 20
        self.max_reconnect_attempts = 10

        self._read_options()


    def _get_option(self, name, provided_value, use_afe_setting=False):
        """Get value of given option from global config.

        @param name: Name of the config.
        @param provided_value: Value being provided to override the one from
                               global config.
        @param use_afe_setting: Force to use the settings in AFE, default is
                                False.
        @returns provided_value if it is not None; otherwise the value from
                 the AUTOTEST_WEB section (use_afe_setting) or from
                 self.global_config_section; when neither section applies,
                 the current value of attribute `name` on self (or None).
        """
        # TODO(dshi): This function returns the option value depends on multiple
        # conditions. The value of `provided_value` has highest priority, then
        # the code checks if use_afe_setting is True, if that's the case, force
        # to use settings in AUTOTEST_WEB. At last the value is retrieved from
        # specified global config section.
        # The logic is too complicated for a generic function named like
        # _get_option. Ideally we want to make it clear from caller that it
        # wants to get database credential from one of the 3 ways:
        # 1. Use the credential from given config section
        # 2. Use the credential from AUTOTEST_WEB section
        # 3. Use the credential provided by caller.
        if provided_value is not None:
            return provided_value
        section = ('AUTOTEST_WEB' if use_afe_setting else
                   self.global_config_section)
        if section:
            # Some options are stored under a different key in global config.
            global_config_name = _GLOBAL_CONFIG_NAMES.get(name, name)
            return global_config.global_config.get_config_value(
                    section, global_config_name)

        return getattr(self, name, None)


    def _read_options(self, db_type=None, host=None, username=None,
                      password=None, db_name=None):
        """Read database information from global config.

        Unless any parameter is specified a value, the connection will use
        database name from given configure section (self.global_config_section),
        and database credential from AFE database settings (AUTOTEST_WEB).

        @param db_type: database type, default to None.
        @param host: database hostname, default to None.
        @param username: user name for database connection, default to None.
        @param password: database password, default to None.
        @param db_name: database name, default to None.
        """
        self.db_name = self._get_option('db_name', db_name)
        # Supplying any credential switches all credentials away from the
        # AUTOTEST_WEB section.
        use_afe_setting = not bool(db_type or host or username or password)

        # Database credential can be provided by the caller, as passed in from
        # function connect.
        self.db_type = self._get_option('db_type', db_type, use_afe_setting)
        self.host = self._get_option('host', host, use_afe_setting)
        self.username = self._get_option('username', username, use_afe_setting)
        self.password = self._get_option('password', password, use_afe_setting)


    def _get_backend(self, db_type):
        """Instantiate the backend class registered for `db_type`.

        @raises ValueError: if db_type is not a key of _BACKEND_MAP.
        """
        if db_type not in _BACKEND_MAP:
            raise ValueError('Invalid database type: %s, should be one of %s' %
                             (db_type, ', '.join(_BACKEND_MAP.keys())))
        backend_class = _BACKEND_MAP[db_type]
        return backend_class()


    def _reached_max_attempts(self, num_attempts):
        """True once num_attempts exceeds max_reconnect_attempts.

        Never true when max_reconnect_attempts is RECONNECT_FOREVER.
        """
        return (self.max_reconnect_attempts is not RECONNECT_FOREVER and
                num_attempts > self.max_reconnect_attempts)


    def _is_reconnect_enabled(self, supplied_param):
        """Return supplied_param unless it is None, else reconnect_enabled."""
        if supplied_param is not None:
            return supplied_param
        return self.reconnect_enabled


    def _connect_backend(self, try_reconnecting=None):
        """Connect self._backend, retrying on OperationalError.

        A failed attempt is re-raised immediately when reconnecting is
        disabled, or once more than max_reconnect_attempts attempts have
        failed.  Otherwise the traceback is printed, the method sleeps
        reconnect_delay_sec seconds, disconnects and tries again.
        """
        num_attempts = 0
        while True:
            try:
                self._backend.connect(host=self.host, username=self.username,
                                      password=self.password,
                                      db_name=self.db_name)
                return
            except self._backend.OperationalError:
                num_attempts += 1
                if not self._is_reconnect_enabled(try_reconnecting):
                    raise
                if self._reached_max_attempts(num_attempts):
                    raise
                traceback.print_exc()
                print("Can't connect to database; reconnecting in %s sec" %
                      self.reconnect_delay_sec)
                time.sleep(self.reconnect_delay_sec)
                self.disconnect()


    def connect(self, db_type=None, host=None, username=None, password=None,
                db_name=None, try_reconnecting=None):
        """
        Parameters passed to this function will override defaults from global
        config.  try_reconnecting, if passed, will override
        self.reconnect_enabled.
        """
        self.disconnect()
        self._read_options(db_type, host, username, password, db_name)

        self._backend = self._get_backend(self.db_type)
        # Expose the backend's exception classes on this object so callers
        # can catch e.g. connection.OperationalError.
        _copy_exceptions(self._backend, self)
        self._connect_backend(try_reconnecting)


    def disconnect(self):
        """Disconnect the current backend, if one has been created."""
        if self._backend:
            self._backend.disconnect()


    def execute(self, query, parameters=None, try_reconnecting=None):
        """
        Execute a query and return cursor.fetchall(). try_reconnecting, if
        passed, will override self.reconnect_enabled.

        On OperationalError with reconnecting enabled, the connection is
        re-established and the query is run once more; a second failure
        propagates.
        """
        if self.debug:
            # Single-argument print(...) behaves identically under the
            # Python 2 print statement and the Python 3 print function.
            print('Executing %s, %s' % (query, parameters))
        # _connect_backend() contains a retry loop, so don't loop here
        try:
            results = self._backend.execute(query, parameters)
        except self._backend.OperationalError:
            if not self._is_reconnect_enabled(try_reconnecting):
                raise
            traceback.print_exc()
            print("MYSQL connection died; reconnecting")
            self.disconnect()
            self._connect_backend(try_reconnecting)
            results = self._backend.execute(query, parameters)

        self.rowcount = self._backend.rowcount
        return results


    def get_database_info(self):
        """Return a dict of db_type, host, username, password and db_name."""
        return dict((attribute, getattr(self, attribute))
                    for attribute in self._DATABASE_ATTRIBUTES)


    @classmethod
    def get_test_database(cls, file_path=':memory:', **constructor_kwargs):
        """
        Factory method returning a DatabaseConnection for a temporary in-memory
        database.
        """
        database = cls(**constructor_kwargs)
        database.reconnect_enabled = False
        database.connect(db_type='sqlite', db_name=file_path)
        return database
333
334
class TranslatingDatabase(DatabaseConnection):
    """
    Database wrapper that applies arbitrary translation callables (such as
    regexp substitutions) to each query string.  Useful for SQLite testing.
    """
    def __init__(self, translators, global_config_section=None, debug=False):
        """
        @param translators: list of callables to apply to each query
                string (in order).  Each accepts a query string and returns a
                (possibly) modified query string.
        @param global_config_section: passed through to DatabaseConnection.
        @param debug: passed through to DatabaseConnection.
        """
        super(TranslatingDatabase, self).__init__(
                global_config_section=global_config_section, debug=debug)
        self._translators = translators


    def execute(self, query, parameters=None, try_reconnecting=None):
        """Run `query` through every translator in order, then execute it.

        Parameters and return value are those of DatabaseConnection.execute().
        """
        for translator in self._translators:
            query = translator(query)
        return super(TranslatingDatabase, self).execute(
                query, parameters=parameters, try_reconnecting=try_reconnecting)


    @classmethod
    def make_regexp_translator(cls, search_re, replace_str):
        """
        Returns a translator that calls re.sub() on the query with the given
        search and replace arguments.
        """
        def translator(query):
            return re.sub(search_re, replace_str, query)
        return translator
366