# Copyright (c) 2014 The Chromium OS Authors. All rights reserved.
# Use of this source code is governed by a BSD-style license that can be
# found in the LICENSE file.

import re, os, sys, time, random

import common
from autotest_lib.client.common_lib import global_config
from autotest_lib.client.common_lib.cros.graphite import autotest_stats
from autotest_lib.frontend import database_settings_helper
from autotest_lib.server import site_utils
from autotest_lib.tko import utils


class MySQLTooManyRows(Exception):
    pass


class db_sql(object):
    """Interface to the TKO results database."""

    def __init__(self, debug=False, autocommit=True, host=None,
                 database=None, user=None, password=None):
        self.debug = debug
        self.autocommit = autocommit
        self._load_config(host, database, user, password)

        self.con = None
        self._init_db()

        # build the mappings between status words and their indexes
        self.status_idx = {}
        self.status_word = {}
        status_rows = self.select('status_idx, word', 'tko_status', None)
        for s in status_rows:
            self.status_idx[s[1]] = s[0]
            self.status_word[s[0]] = s[1]

        machine_map = os.path.join(os.path.dirname(__file__),
                                   'machines')
        if os.path.exists(machine_map):
            self.machine_map = machine_map
        else:
            self.machine_map = None
        self.machine_group = {}


    def _load_config(self, host, database, user, password):
        """Loads configuration settings required to connect to the database.

        This will try to use the settings prefixed with global_db_. If they
        do not exist, the un-prefixed settings will be used.

        If parameters are supplied, these will be used instead of the values
        in global_config.

        @param host: If set, this host will be used, if not, the host will be
                     retrieved from global_config.
        @param database: If set, this database will be used, if not, the
                         database will be retrieved from global_config.
        @param user: If set, this user will be used, if not, the
                     user will be retrieved from global_config.
        @param password: If set, this password will be used, if not, the
                         password will be retrieved from global_config.
        """
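        # A minimal usage sketch (hostnames and credentials below are
        # illustrative, not real settings): explicit arguments win, otherwise
        # the global_db_-prefixed (or fallback un-prefixed) settings are used.
        #   db_sql(host='db.example.com', user='tko_user')  # overrides config
        #   db_sql()                                        # all from config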
        database_settings = database_settings_helper.get_global_db_config()

        # grab the host, database
        self.host = host or database_settings['HOST']
        self.database = database or database_settings['NAME']

        # grab the user and password
        self.user = user or database_settings['USER']
        self.password = password or database_settings['PASSWORD']

        # grab the timeout configuration
        self.query_timeout = (
                database_settings.get('OPTIONS', {}).get('timeout', 3600))

        # Using fallback to non-global in order to work without configuration
        # overhead on non-shard instances.
        get_value = global_config.global_config.get_config_value_with_fallback
        self.min_delay = get_value("AUTOTEST_WEB", "global_db_min_retry_delay",
                                   "min_retry_delay", type=int, default=20)
        self.max_delay = get_value("AUTOTEST_WEB", "global_db_max_retry_delay",
                                   "max_retry_delay", type=int, default=60)

        # TODO(beeps): Move this to django settings once we have routers.
        # On test instances mysql connects through a different port. No point
        # piping this through our entire infrastructure when it is only really
        # used for testing; ideally we would specify this through django
        # settings and default it to the empty string so django will figure
        # out the default based on the database backend (e.g. mysql, 3306),
        # but until we have database routers in place any django settings
        # will apply to both tko and afe.
        # The intended use of this port is to allow a testing shard vm to
        # update the master vm's database with test results. Specifying an
        # empty string will fall back to not even specifying the port to the
        # backend in tko/db.py. Unfortunately this means retries won't work
        # on the test cluster until we've migrated to routers.
        self.port = global_config.global_config.get_config_value(
                "AUTOTEST_WEB", "global_db_port", type=str, default='')


    def _init_db(self):
        # make sure we clean up any existing connection
        if self.con:
            self.con.close()
            self.con = None

        try:
            # create the db connection and cursor
            self.con = self.connect(self.host, self.database,
                                    self.user, self.password, self.port)
        except:
            autotest_stats.Counter('tko_db_con_error').increment()
            raise
        self.cur = self.con.cursor()


    def _random_delay(self):
        delay = random.randint(self.min_delay, self.max_delay)
        time.sleep(delay)


    def run_with_retry(self, function, *args, **dargs):
        """Call function(*args, **dargs) until either it passes
        without an operational error, or a timeout is reached.
        This will re-connect to the database, so it is NOT safe
        to call from inside an open database transaction.

        It can be safely used with transactions, but the
        transaction start & end must be completely contained
        within the call to 'function'."""
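        # Hedged usage sketch (the statement text is illustrative): wrap the
        # whole unit of work, including any transaction, in one callable so a
        # retry re-runs it from the start, e.g.
        #   def do_insert():
        #       self.cur.execute('insert into t (a) values (%s)', [1])
        #       self.con.commit()
        #   self.run_with_retry(do_insert)
        # This is the pattern _exec_sql_with_commit() below follows.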
        OperationalError = _get_error_class("OperationalError")

        success = False
        start_time = time.time()
        while not success:
            try:
                result = function(*args, **dargs)
            except OperationalError, e:
                self._log_operational_error(e)
                stop_time = time.time()
                elapsed_time = stop_time - start_time
                if elapsed_time > self.query_timeout:
                    raise
                else:
                    try:
                        self._random_delay()
                        self._init_db()
                        autotest_stats.Counter('tko_db_error').increment()
                    except OperationalError, e:
                        self._log_operational_error(e)
            else:
                success = True
        return result


    def _log_operational_error(self, e):
        msg = ("%s: An operational error occurred during a database "
               "operation: %s" % (time.strftime("%X %x"), str(e)))
        print >> sys.stderr, msg
        sys.stderr.flush() # we want these msgs to show up immediately


    def dprint(self, value):
        if self.debug:
            sys.stdout.write('SQL: ' + str(value) + '\n')


    def _commit(self):
        """Private method for function commit to call for retry.
        """
        return self.con.commit()


    def commit(self):
        if self.autocommit:
            return self.run_with_retry(self._commit)
        else:
            return self._commit()


    def rollback(self):
        self.con.rollback()


    def get_last_autonumber_value(self):
        self.cur.execute('SELECT LAST_INSERT_ID()', [])
        return self.cur.fetchall()[0][0]


    def _quote(self, field):
        return '`%s`' % field


    def _where_clause(self, where):
        if not where:
            return '', []

        if isinstance(where, dict):
            # key/value pairs (which should be equal, or None for null)
            keys, values = [], []
            for field, value in where.iteritems():
                quoted_field = self._quote(field)
                if value is None:
                    keys.append(quoted_field + ' is null')
                else:
                    keys.append(quoted_field + '=%s')
                    values.append(value)
            where_clause = ' and '.join(keys)
        elif isinstance(where, basestring):
            # the exact string
            where_clause = where
            values = []
        elif isinstance(where, tuple):
            # preformatted where clause + values
            where_clause, values = where
            assert where_clause
        else:
            raise ValueError('Invalid "where" value: %r' % where)

        return ' WHERE ' + where_clause, values


    def select(self, fields, table, where, distinct=False, group_by=None,
               max_rows=None):
        """\
                This selects all the fields requested from a
                specific table with a particular where clause.
                The where clause can either be a dictionary of
                field=value pairs, a string, or a tuple of (string,
                a list of values).  The last option is what you
                should use when accepting user input as it'll
                protect you against SQL injection attacks (if
                all user data is placed in the array rather than
                the raw SQL).

                For example:
                  where = ("a = %s AND b = %s", ['val', 'val'])
                is better than
                  where = "a = 'val' AND b = 'val'"
        """
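        # Illustrative calls (the table and field names are from this module,
        # the values are made up); all three are equivalent, but only the
        # tuple form is safe with untrusted input:
        #   self.select('test_idx', 'tko_tests', {'job_idx': 42})
        #   self.select('test_idx', 'tko_tests', 'job_idx = 42')
        #   self.select('test_idx', 'tko_tests', ('job_idx = %s', [42]))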
        cmd = ['select']
        if distinct:
            cmd.append('distinct')
        cmd += [fields, 'from', table]

        where_clause, values = self._where_clause(where)
        cmd.append(where_clause)

        if group_by:
            cmd.append(' GROUP BY ' + group_by)

        self.dprint('%s %s' % (' '.join(cmd), values))

        # create a re-runnable function for executing the query
        def exec_sql():
            sql = ' '.join(cmd)
            numRec = self.cur.execute(sql, values)
            if max_rows is not None and numRec > max_rows:
                msg = 'Exceeded allowed number of records'
                raise MySQLTooManyRows(msg)
            return self.cur.fetchall()

        # run the query, re-trying after operational errors
        if self.autocommit:
            return self.run_with_retry(exec_sql)
        else:
            return exec_sql()


    def select_sql(self, fields, table, sql, values):
        """\
                select fields from table "sql"
        """
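        # The trailing 'sql' fragment is appended verbatim after the table
        # name; a hedged example (values are illustrative):
        #   self.select_sql('test, status', 'tko_tests',
        #                   'where job_idx = %s order by test', [42])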
        cmd = 'select %s from %s %s' % (fields, table, sql)
        self.dprint(cmd)

        # create a re-runnable function for executing the query
        def exec_sql():
            self.cur.execute(cmd, values)
            return self.cur.fetchall()

        # run the query, re-trying after operational errors
        if self.autocommit:
            return self.run_with_retry(exec_sql)
        else:
            return exec_sql()


    def _exec_sql_with_commit(self, sql, values, commit):
        if self.autocommit:
            # re-run the query until it succeeds
            def exec_sql():
                self.cur.execute(sql, values)
                self.con.commit()
            self.run_with_retry(exec_sql)
        else:
            # take one shot at running the query
            self.cur.execute(sql, values)
            if commit:
                self.con.commit()


    def insert(self, table, data, commit=None):
        """\
                'insert into table (keys) values (%s ... %s)', values

                data:
                        dictionary of fields and data
        """
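        # Hedged example (the field values are made up); keys become quoted
        # column names and values are passed as query parameters:
        #   self.insert('tko_test_attributes',
        #               {'test_idx': 1234, 'attribute': 'CHROME_VERSION',
        #                'value': '51.0'}, commit=True)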
        fields = data.keys()
        refs = ['%s' for field in fields]
        values = [data[field] for field in fields]
        cmd = ('insert into %s (%s) values (%s)' %
               (table, ','.join(self._quote(field) for field in fields),
                ','.join(refs)))
        self.dprint('%s %s' % (cmd, values))

        self._exec_sql_with_commit(cmd, values, commit)


    def delete(self, table, where, commit = None):
        cmd = ['delete from', table]
        if commit is None:
            commit = self.autocommit
        where_clause, values = self._where_clause(where)
        cmd.append(where_clause)
        sql = ' '.join(cmd)
        self.dprint('%s %s' % (sql, values))

        self._exec_sql_with_commit(sql, values, commit)


    def update(self, table, data, where, commit = None):
        """\
                'update table set field1=%s, field2=%s ... where ...'

                data:
                        dictionary of fields and data
        """
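        # Hedged example (the values are made up); 'data' supplies the SET
        # assignments and 'where' selects the rows, as with select():
        #   self.update('tko_machines', {'machine_group': 'netbook'},
        #               {'hostname': 'host1'}, commit=True)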
        if commit is None:
            commit = self.autocommit
        cmd = 'update %s ' % table
        fields = data.keys()
        data_refs = [self._quote(field) + '=%s' for field in fields]
        data_values = [data[field] for field in fields]
        cmd += ' set ' + ', '.join(data_refs)

        where_clause, where_values = self._where_clause(where)
        cmd += where_clause

        values = data_values + where_values
        self.dprint('%s %s' % (cmd, values))

        self._exec_sql_with_commit(cmd, values, commit)


    def delete_job(self, tag, commit = None):
        job_idx = self.find_job(tag)
        for test_idx in self.find_tests(job_idx):
            where = {'test_idx' : test_idx}
            self.delete('tko_iteration_result', where)
            self.delete('tko_iteration_perf_value', where)
            self.delete('tko_iteration_attributes', where)
            self.delete('tko_test_attributes', where)
            self.delete('tko_test_labels_tests', {'test_id': test_idx})
        where = {'job_idx' : job_idx}
        self.delete('tko_tests', where)
        self.delete('tko_jobs', where)


    def insert_job(self, tag, job, parent_job_id=None, commit=None):
        job.machine_idx = self.lookup_machine(job.machine)
        if not job.machine_idx:
            job.machine_idx = self.insert_machine(job, commit=commit)
        elif job.machine:
            # Only try to update tko_machines record if machine is set. This
            # prevents unnecessary db writes for suite jobs.
            self.update_machine_information(job, commit=commit)

        afe_job_id = utils.get_afe_job_id(tag)

        data = {'tag':tag,
                'label': job.label,
                'username': job.user,
                'machine_idx': job.machine_idx,
                'queued_time': job.queued_time,
                'started_time': job.started_time,
                'finished_time': job.finished_time,
                'afe_job_id': afe_job_id,
                'afe_parent_job_id': parent_job_id}
        if job.label:
            label_info = site_utils.parse_job_name(job.label)
            if label_info:
                data['build'] = label_info.get('build', None)
                data['build_version'] = label_info.get('build_version', None)
                data['board'] = label_info.get('board', None)
                data['suite'] = label_info.get('suite', None)
        is_update = hasattr(job, 'index')
        if is_update:
            self.update('tko_jobs', data, {'job_idx': job.index}, commit=commit)
        else:
            self.insert('tko_jobs', data, commit=commit)
            job.index = self.get_last_autonumber_value()
        self.update_job_keyvals(job, commit=commit)
        for test in job.tests:
            self.insert_test(job, test, commit=commit)


    def update_job_keyvals(self, job, commit=None):
        for key, value in job.keyval_dict.iteritems():
            where = {'job_id': job.index, 'key': key}
            data = dict(where, value=value)
            exists = self.select('id', 'tko_job_keyvals', where=where)

            if exists:
                self.update('tko_job_keyvals', data, where=where, commit=commit)
            else:
                self.insert('tko_job_keyvals', data, commit=commit)


    def insert_test(self, job, test, commit = None):
        kver = self.insert_kernel(test.kernel, commit=commit)
        data = {'job_idx':job.index, 'test':test.testname,
                'subdir':test.subdir, 'kernel_idx':kver,
                'status':self.status_idx[test.status],
                'reason':test.reason, 'machine_idx':job.machine_idx,
                'started_time': test.started_time,
                'finished_time':test.finished_time}
        is_update = hasattr(test, "test_idx")
        if is_update:
            test_idx = test.test_idx
            self.update('tko_tests', data,
                        {'test_idx': test_idx}, commit=commit)
            where = {'test_idx': test_idx}
            self.delete('tko_iteration_result', where)
            self.delete('tko_iteration_perf_value', where)
            self.delete('tko_iteration_attributes', where)
            where['user_created'] = 0
            self.delete('tko_test_attributes', where)
        else:
            self.insert('tko_tests', data, commit=commit)
            test_idx = test.test_idx = self.get_last_autonumber_value()
        data = {'test_idx': test_idx}

        for i in test.iterations:
            data['iteration'] = i.index
            for key, value in i.attr_keyval.iteritems():
                data['attribute'] = key
                data['value'] = value
                self.insert('tko_iteration_attributes', data,
                            commit=commit)
            for key, value in i.perf_keyval.iteritems():
                data['attribute'] = key
                data['value'] = value
                self.insert('tko_iteration_result', data,
                            commit=commit)

        data = {'test_idx': test_idx}
        for i in test.perf_values:
            data['iteration'] = i.index
            for perf_dict in i.perf_measurements:
                data['description'] = perf_dict['description']
                data['value'] = perf_dict['value']
                data['stddev'] = perf_dict['stddev']
                data['units'] = perf_dict['units']
                # TODO(fdeng): In the db, higher_is_better doesn't allow null.
                # This is a workaround to avoid altering the table (very
                # expensive) while still allowing tests to send
                # higher_is_better=None. Ideally, the table should be
                # altered to allow this.
                if perf_dict['higher_is_better'] is not None:
                    data['higher_is_better'] = perf_dict['higher_is_better']
                data['graph'] = perf_dict['graph']
                self.insert('tko_iteration_perf_value', data, commit=commit)

        for key, value in test.attributes.iteritems():
            data = {'test_idx': test_idx, 'attribute': key,
                    'value': value}
            self.insert('tko_test_attributes', data, commit=commit)

        if not is_update:
            for label_index in test.labels:
                data = {'test_id': test_idx, 'testlabel_id': label_index}
                self.insert('tko_test_labels_tests', data, commit=commit)


    def read_machine_map(self):
        if self.machine_group or not self.machine_map:
            return
        for line in open(self.machine_map, 'r').readlines():
            (machine, group) = line.split()
            self.machine_group[machine] = group


    def machine_info_dict(self, job):
        hostname = job.machine
        group = job.machine_group
        owner = job.machine_owner

        if not group:
            self.read_machine_map()
            group = self.machine_group.get(hostname, hostname)
            if group == hostname and owner:
                group = owner + '/' + hostname

        return {'hostname': hostname, 'machine_group': group, 'owner': owner}


    def insert_machine(self, job, commit = None):
        machine_info = self.machine_info_dict(job)
        self.insert('tko_machines', machine_info, commit=commit)
        return self.get_last_autonumber_value()


    def update_machine_information(self, job, commit = None):
        machine_info = self.machine_info_dict(job)
        self.update('tko_machines', machine_info,
                    where={'hostname': machine_info['hostname']},
                    commit=commit)


    def lookup_machine(self, hostname):
        where = { 'hostname' : hostname }
        rows = self.select('machine_idx', 'tko_machines', where)
        if rows:
            return rows[0][0]
        else:
            return None


    def lookup_kernel(self, kernel):
        rows = self.select('kernel_idx', 'tko_kernels',
                                {'kernel_hash':kernel.kernel_hash})
        if rows:
            return rows[0][0]
        else:
            return None


    def insert_kernel(self, kernel, commit = None):
        kver = self.lookup_kernel(kernel)
        if kver:
            return kver

        # If this kernel has any significant patches, append a marker to the
        # printable name as a differentiator.
        printable = kernel.base
        patch_count = 0
        for patch in kernel.patches:
            match = re.match(r'.*(-mm[0-9]+|-git[0-9]+)\.(bz2|gz)$',
                                                    patch.reference)
            if not match:
                patch_count += 1

        self.insert('tko_kernels',
                    {'base':kernel.base,
                     'kernel_hash':kernel.kernel_hash,
                     'printable':printable},
                    commit=commit)
        kver = self.get_last_autonumber_value()

        if patch_count > 0:
            printable += ' p%d' % (kver)
            self.update('tko_kernels',
                    {'printable':printable},
                    {'kernel_idx':kver})

        for patch in kernel.patches:
            self.insert_patch(kver, patch, commit=commit)
        return kver


    def insert_patch(self, kver, patch, commit = None):
        print patch.reference
        name = os.path.basename(patch.reference)[:80]
        self.insert('tko_patches',
                    {'kernel_idx': kver,
                     'name':name,
                     'url':patch.reference,
                     'hash':patch.hash},
                    commit=commit)


    def find_test(self, job_idx, testname, subdir):
        where = {'job_idx': job_idx , 'test': testname, 'subdir': subdir}
        rows = self.select('test_idx', 'tko_tests', where)
        if rows:
            return rows[0][0]
        else:
            return None


    def find_tests(self, job_idx):
        where = { 'job_idx':job_idx }
        rows = self.select('test_idx', 'tko_tests', where)
        if rows:
            return [row[0] for row in rows]
        else:
            return []


    def find_job(self, tag):
        rows = self.select('job_idx', 'tko_jobs', {'tag': tag})
        if rows:
            return rows[0][0]
        else:
            return None


def _get_db_type():
    """Get the database type name to use from the global config."""
    get_value = global_config.global_config.get_config_value_with_fallback
    return "db_" + get_value("AUTOTEST_WEB", "global_db_type", "db_type",
                             default="mysql")


def _get_error_class(class_name):
    """Retrieves the appropriate error class by name from the database
    module."""
    db_module = __import__("autotest_lib.tko." + _get_db_type(),
                           globals(), locals(), ["driver"])
    return getattr(db_module.driver, class_name)


def db(*args, **dargs):
    """Creates an instance of the database class with the arguments
    provided in args and dargs, using the database type specified by
    the global configuration (defaulting to mysql)."""
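    # Hedged example of typical use (the tag value is illustrative):
    #   tko_db = db(debug=True)
    #   rows = tko_db.select('job_idx', 'tko_jobs', {'tag': '123-user/host'})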
    db_type = _get_db_type()
    db_module = __import__("autotest_lib.tko." + db_type, globals(),
                           locals(), [db_type])
    db = getattr(db_module, db_type)(*args, **dargs)
    return db