1# Copyright (c) 2014 The Chromium OS Authors. All rights reserved.
2# Use of this source code is governed by a BSD-style license that can be
3# found in the LICENSE file.
4
5import math
6import os
7import random
8import re
9import sys
10import time
11
12import common
13from autotest_lib.client.common_lib import global_config
14from autotest_lib.client.common_lib.cros.graphite import autotest_stats
15from autotest_lib.frontend import database_settings_helper
16from autotest_lib.tko import utils
17
18
19def _log_error(msg):
20    """Log an error message.
21
22    @param msg: Message string
23    """
24    print >> sys.stderr, msg
25    sys.stderr.flush()  # we want these msgs to show up immediately
26
27
28def _format_operational_error(e):
29    """Format OperationalError.
30
31    @param e: OperationalError instance.
32    """
33    return ("%s: An operational error occurred during a database "
34            "operation: %s" % (time.strftime("%X %x"), str(e)))
35
36
class MySQLTooManyRows(Exception):
    """Raised when a query yields more records than the caller allows."""
40
41
42class db_sql(object):
43    """Data access."""
44
45    def __init__(self, debug=False, autocommit=True, host=None,
46                 database=None, user=None, password=None):
47        self.debug = debug
48        self.autocommit = autocommit
49        self._load_config(host, database, user, password)
50
51        self.con = None
52        self._init_db()
53
54        # if not present, insert statuses
55        self.status_idx = {}
56        self.status_word = {}
57        status_rows = self.select('status_idx, word', 'tko_status', None)
58        for s in status_rows:
59            self.status_idx[s[1]] = s[0]
60            self.status_word[s[0]] = s[1]
61
62        machine_map = os.path.join(os.path.dirname(__file__),
63                                   'machines')
64        if os.path.exists(machine_map):
65            self.machine_map = machine_map
66        else:
67            self.machine_map = None
68        self.machine_group = {}
69
70
71    def _load_config(self, host, database, user, password):
72        """Loads configuration settings required to connect to the database.
73
74        This will try to connect to use the settings prefixed with global_db_.
75        If they do not exist, they un-prefixed settings will be used.
76
77        If parameters are supplied, these will be taken instead of the values
78        in global_config.
79
80        @param host: If set, this host will be used, if not, the host will be
81                     retrieved from global_config.
82        @param database: If set, this database will be used, if not, the
83                         database will be retrieved from global_config.
84        @param user: If set, this user will be used, if not, the
85                         user will be retrieved from global_config.
86        @param password: If set, this password will be used, if not, the
87                         password will be retrieved from global_config.
88        """
89        database_settings = database_settings_helper.get_global_db_config()
90
91        # grab the host, database
92        self.host = host or database_settings['HOST']
93        self.database = database or database_settings['NAME']
94
95        # grab the user and password
96        self.user = user or database_settings['USER']
97        self.password = password or database_settings['PASSWORD']
98
99        # grab the timeout configuration
100        self.query_timeout =(
101                database_settings.get('OPTIONS', {}).get('timeout', 3600))
102
103        # Using fallback to non-global in order to work without configuration
104        # overhead on non-shard instances.
105        get_value = global_config.global_config.get_config_value_with_fallback
106        self.min_delay = get_value("AUTOTEST_WEB", "global_db_min_retry_delay",
107                                   "min_retry_delay", type=int, default=20)
108        self.max_delay = get_value("AUTOTEST_WEB", "global_db_max_retry_delay",
109                                   "max_retry_delay", type=int, default=60)
110
111        # TODO(beeps): Move this to django settings once we have routers.
112        # On test instances mysql connects through a different port. No point
113        # piping this through our entire infrastructure when it is only really
114        # used for testing; Ideally we would specify this through django
115        # settings and default it to the empty string so django will figure out
116        # the default based on the database backend (eg: mysql, 3306), but until
117        # we have database routers in place any django settings will apply to
118        # both tko and afe.
119        # The intended use of this port is to allow a testing shard vm to
120        # update the master vm's database with test results. Specifying
121        # and empty string will fallback to not even specifying the port
122        # to the backend in tko/db.py. Unfortunately this means retries
123        # won't work on the test cluster till we've migrated to routers.
124        self.port = global_config.global_config.get_config_value(
125                "AUTOTEST_WEB", "global_db_port", type=str, default='')
126
127
128    def _init_db(self):
129        # make sure we clean up any existing connection
130        if self.con:
131            self.con.close()
132            self.con = None
133
134        try:
135            # create the db connection and cursor
136            self.con = self.connect(self.host, self.database,
137                                    self.user, self.password, self.port)
138        except:
139            autotest_stats.Counter('tko_db_con_error').increment()
140            raise
141        self.cur = self.con.cursor()
142
143
144    def _random_delay(self):
145        delay = random.randint(self.min_delay, self.max_delay)
146        time.sleep(delay)
147
148
149    def run_with_retry(self, function, *args, **dargs):
150        """Call function(*args, **dargs) until either it passes
151        without an operational error, or a timeout is reached.
152        This will re-connect to the database, so it is NOT safe
153        to use this inside of a database transaction.
154
155        It can be safely used with transactions, but the
156        transaction start & end must be completely contained
157        within the call to 'function'.
158
159        @param function: The function to run with retry.
160        @param args: The arguments
161        @param dargs: The named arguments.
162        """
163        OperationalError = _get_error_class("OperationalError")
164
165        success = False
166        start_time = time.time()
167        while not success:
168            try:
169                result = function(*args, **dargs)
170            except OperationalError, e:
171                _log_error("%s; retrying, don't panic yet"
172                           % _format_operational_error(e))
173                stop_time = time.time()
174                elapsed_time = stop_time - start_time
175                if elapsed_time > self.query_timeout:
176                    raise
177                else:
178                    try:
179                        self._random_delay()
180                        self._init_db()
181                        autotest_stats.Counter('tko_db_error').increment()
182                    except OperationalError, e:
183                        _log_error('%s; panic now'
184                                   % _format_operational_error(e))
185            else:
186                success = True
187        return result
188
189
190    def dprint(self, value):
191        """Print out debug value.
192
193        @param value: The value to print out.
194        """
195        if self.debug:
196            sys.stdout.write('SQL: ' + str(value) + '\n')
197
198
199    def _commit(self):
200        """Private method for function commit to call for retry.
201        """
202        return self.con.commit()
203
204
205    def commit(self):
206        """Commit the sql transaction."""
207        if self.autocommit:
208            return self.run_with_retry(self._commit)
209        else:
210            return self._commit()
211
212
213    def rollback(self):
214        """Rollback the sql transaction."""
215        self.con.rollback()
216
217
218    def get_last_autonumber_value(self):
219        """Gets the last auto number.
220
221        @return: The last auto number.
222        """
223        self.cur.execute('SELECT LAST_INSERT_ID()', [])
224        return self.cur.fetchall()[0][0]
225
226
227    def _quote(self, field):
228        return '`%s`' % field
229
230
231    def _where_clause(self, where):
232        if not where:
233            return '', []
234
235        if isinstance(where, dict):
236            # key/value pairs (which should be equal, or None for null)
237            keys, values = [], []
238            for field, value in where.iteritems():
239                quoted_field = self._quote(field)
240                if value is None:
241                    keys.append(quoted_field + ' is null')
242                else:
243                    keys.append(quoted_field + '=%s')
244                    values.append(value)
245            where_clause = ' and '.join(keys)
246        elif isinstance(where, basestring):
247            # the exact string
248            where_clause = where
249            values = []
250        elif isinstance(where, tuple):
251            # preformatted where clause + values
252            where_clause, values = where
253            assert where_clause
254        else:
255            raise ValueError('Invalid "where" value: %r' % where)
256
257        return ' WHERE ' + where_clause, values
258
259
260
261    def select(self, fields, table, where, distinct=False, group_by=None,
262               max_rows=None):
263        """\
264                This selects all the fields requested from a
265                specific table with a particular where clause.
266                The where clause can either be a dictionary of
267                field=value pairs, a string, or a tuple of (string,
268                a list of values).  The last option is what you
269                should use when accepting user input as it'll
270                protect you against sql injection attacks (if
271                all user data is placed in the array rather than
272                the raw SQL).
273
274                For example:
275                  where = ("a = %s AND b = %s", ['val', 'val'])
276                is better than
277                  where = "a = 'val' AND b = 'val'"
278
279        @param fields: The list of selected fields string.
280        @param table: The name of the database table.
281        @param where: The where clause string.
282        @param distinct: If select distinct values.
283        @param group_by: Group by clause.
284        @param max_rows: unused.
285        """
286        cmd = ['select']
287        if distinct:
288            cmd.append('distinct')
289        cmd += [fields, 'from', table]
290
291        where_clause, values = self._where_clause(where)
292        cmd.append(where_clause)
293
294        if group_by:
295            cmd.append(' GROUP BY ' + group_by)
296
297        self.dprint('%s %s' % (' '.join(cmd), values))
298
299        # create a re-runable function for executing the query
300        def exec_sql():
301            """Exeuctes an the sql command."""
302            sql = ' '.join(cmd)
303            numRec = self.cur.execute(sql, values)
304            if max_rows is not None and numRec > max_rows:
305                msg = 'Exceeded allowed number of records'
306                raise MySQLTooManyRows(msg)
307            return self.cur.fetchall()
308
309        # run the query, re-trying after operational errors
310        if self.autocommit:
311            return self.run_with_retry(exec_sql)
312        else:
313            return exec_sql()
314
315
316    def select_sql(self, fields, table, sql, values):
317        """\
318                select fields from table "sql"
319
320        @param fields: The list of selected fields string.
321        @param table: The name of the database table.
322        @param sql: The sql string.
323        @param values: The sql string parameter values.
324        """
325        cmd = 'select %s from %s %s' % (fields, table, sql)
326        self.dprint(cmd)
327
328        # create a -re-runable function for executing the query
329        def _exec_sql():
330            self.cur.execute(cmd, values)
331            return self.cur.fetchall()
332
333        # run the query, re-trying after operational errors
334        if self.autocommit:
335            return self.run_with_retry(_exec_sql)
336        else:
337            return _exec_sql()
338
339
340    def _exec_sql_with_commit(self, sql, values, commit):
341        if self.autocommit:
342            # re-run the query until it succeeds
343            def _exec_sql():
344                self.cur.execute(sql, values)
345                self.con.commit()
346            self.run_with_retry(_exec_sql)
347        else:
348            # take one shot at running the query
349            self.cur.execute(sql, values)
350            if commit:
351                self.con.commit()
352
353
354    def insert(self, table, data, commit=None):
355        """\
356                'insert into table (keys) values (%s ... %s)', values
357
358                data:
359                        dictionary of fields and data
360
361        @param table: The name of the table.
362        @param data: The insert data.
363        @param commit: If commit the transaction .
364        """
365        fields = data.keys()
366        refs = ['%s' for field in fields]
367        values = [data[field] for field in fields]
368        cmd = ('insert into %s (%s) values (%s)' %
369               (table, ','.join(self._quote(field) for field in fields),
370                ','.join(refs)))
371        self.dprint('%s %s' % (cmd, values))
372
373        self._exec_sql_with_commit(cmd, values, commit)
374
375
376    def delete(self, table, where, commit = None):
377        """Delete entries.
378
379        @param table: The name of the table.
380        @param where: The where clause.
381        @param commit: If commit the transaction .
382        """
383        cmd = ['delete from', table]
384        if commit is None:
385            commit = self.autocommit
386        where_clause, values = self._where_clause(where)
387        cmd.append(where_clause)
388        sql = ' '.join(cmd)
389        self.dprint('%s %s' % (sql, values))
390
391        self._exec_sql_with_commit(sql, values, commit)
392
393
394    def update(self, table, data, where, commit = None):
395        """\
396                'update table set data values (%s ... %s) where ...'
397
398                data:
399                        dictionary of fields and data
400
401        @param table: The name of the table.
402        @param data: The sql parameter values.
403        @param where: The where clause.
404        @param commit: If commit the transaction .
405        """
406        if commit is None:
407            commit = self.autocommit
408        cmd = 'update %s ' % table
409        fields = data.keys()
410        data_refs = [self._quote(field) + '=%s' for field in fields]
411        data_values = [data[field] for field in fields]
412        cmd += ' set ' + ', '.join(data_refs)
413
414        where_clause, where_values = self._where_clause(where)
415        cmd += where_clause
416
417        values = data_values + where_values
418        self.dprint('%s %s' % (cmd, values))
419
420        self._exec_sql_with_commit(cmd, values, commit)
421
422
423    def delete_job(self, tag, commit = None):
424        """Delete a tko job.
425
426        @param tag: The job tag.
427        @param commit: If commit the transaction .
428        """
429        job_idx = self.find_job(tag)
430        for test_idx in self.find_tests(job_idx):
431            where = {'test_idx' : test_idx}
432            self.delete('tko_iteration_result', where)
433            self.delete('tko_iteration_perf_value', where)
434            self.delete('tko_iteration_attributes', where)
435            self.delete('tko_test_attributes', where)
436            self.delete('tko_test_labels_tests', {'test_id': test_idx})
437        where = {'job_idx' : job_idx}
438        self.delete('tko_tests', where)
439        self.delete('tko_jobs', where)
440
441
442    def insert_job(self, tag, job, parent_job_id=None, commit=None):
443        """Insert a tko job.
444
445        @param tag: The job tag.
446        @param job: The job object.
447        @param parent_job_id: The parent job id.
448        @param commit: If commit the transaction .
449
450        @return The dict of data inserted into the tko_jobs table.
451        """
452        job.machine_idx = self.lookup_machine(job.machine)
453        if not job.machine_idx:
454            job.machine_idx = self.insert_machine(job, commit=commit)
455        elif job.machine:
456            # Only try to update tko_machines record if machine is set. This
457            # prevents unnecessary db writes for suite jobs.
458            self.update_machine_information(job, commit=commit)
459
460        afe_job_id = utils.get_afe_job_id(tag)
461
462        data = {'tag':tag,
463                'label': job.label,
464                'username': job.user,
465                'machine_idx': job.machine_idx,
466                'queued_time': job.queued_time,
467                'started_time': job.started_time,
468                'finished_time': job.finished_time,
469                'afe_job_id': afe_job_id,
470                'afe_parent_job_id': parent_job_id,
471                'build': job.build,
472                'build_version': job.build_version,
473                'board': job.board,
474                'suite': job.suite}
475        job.afe_job_id = afe_job_id
476        if parent_job_id:
477            job.afe_parent_job_id = str(parent_job_id)
478
479        # TODO(ntang): check job.index directly.
480        is_update = hasattr(job, 'index')
481        if is_update:
482            self.update('tko_jobs', data, {'job_idx': job.index}, commit=commit)
483        else:
484            self.insert('tko_jobs', data, commit=commit)
485            job.index = self.get_last_autonumber_value()
486        self.update_job_keyvals(job, commit=commit)
487        for test in job.tests:
488            self.insert_test(job, test, commit=commit)
489
490        return data
491
492
493    def update_job_keyvals(self, job, commit=None):
494        """Updates the job key values.
495
496        @param job: The job object.
497        @param commit: If commit the transaction .
498        """
499        for key, value in job.keyval_dict.iteritems():
500            where = {'job_id': job.index, 'key': key}
501            data = dict(where, value=value)
502            exists = self.select('id', 'tko_job_keyvals', where=where)
503
504            if exists:
505                self.update('tko_job_keyvals', data, where=where, commit=commit)
506            else:
507                self.insert('tko_job_keyvals', data, commit=commit)
508
509
510    def insert_test(self, job, test, commit = None):
511        """Inserts a job test.
512
513        @param job: The job object.
514        @param test: The test object.
515        @param commit: If commit the transaction .
516        """
517        kver = self.insert_kernel(test.kernel, commit=commit)
518        data = {'job_idx':job.index, 'test':test.testname,
519                'subdir':test.subdir, 'kernel_idx':kver,
520                'status':self.status_idx[test.status],
521                'reason':test.reason, 'machine_idx':job.machine_idx,
522                'started_time': test.started_time,
523                'finished_time':test.finished_time}
524        is_update = hasattr(test, "test_idx")
525        if is_update:
526            test_idx = test.test_idx
527            self.update('tko_tests', data,
528                        {'test_idx': test_idx}, commit=commit)
529            where = {'test_idx': test_idx}
530            self.delete('tko_iteration_result', where)
531            self.delete('tko_iteration_perf_value', where)
532            self.delete('tko_iteration_attributes', where)
533            where['user_created'] = 0
534            self.delete('tko_test_attributes', where)
535        else:
536            self.insert('tko_tests', data, commit=commit)
537            test_idx = test.test_idx = self.get_last_autonumber_value()
538        data = {'test_idx': test_idx}
539
540        for i in test.iterations:
541            data['iteration'] = i.index
542            for key, value in i.attr_keyval.iteritems():
543                data['attribute'] = key
544                data['value'] = value
545                self.insert('tko_iteration_attributes', data,
546                            commit=commit)
547            for key, value in i.perf_keyval.iteritems():
548                data['attribute'] = key
549                if math.isnan(value) or math.isinf(value):
550                    data['value'] = None
551                else:
552                    data['value'] = value
553                self.insert('tko_iteration_result', data,
554                            commit=commit)
555
556        data = {'test_idx': test_idx}
557
558        for key, value in test.attributes.iteritems():
559            data = {'test_idx': test_idx, 'attribute': key,
560                    'value': value}
561            self.insert('tko_test_attributes', data, commit=commit)
562
563        if not is_update:
564            for label_index in test.labels:
565                data = {'test_id': test_idx, 'testlabel_id': label_index}
566                self.insert('tko_test_labels_tests', data, commit=commit)
567
568
569    def read_machine_map(self):
570        """Reads the machine map."""
571        if self.machine_group or not self.machine_map:
572            return
573        for line in open(self.machine_map, 'r').readlines():
574            (machine, group) = line.split()
575            self.machine_group[machine] = group
576
577
578    def machine_info_dict(self, job):
579        """Reads the machine information of a job.
580
581        @param job: The job object.
582
583        @return: The machine info dictionary.
584        """
585        hostname = job.machine
586        group = job.machine_group
587        owner = job.machine_owner
588
589        if not group:
590            self.read_machine_map()
591            group = self.machine_group.get(hostname, hostname)
592            if group == hostname and owner:
593                group = owner + '/' + hostname
594
595        return {'hostname': hostname, 'machine_group': group, 'owner': owner}
596
597
598    def insert_machine(self, job, commit = None):
599        """Inserts the job machine.
600
601        @param job: The job object.
602        @param commit: If commit the transaction .
603        """
604        machine_info = self.machine_info_dict(job)
605        self.insert('tko_machines', machine_info, commit=commit)
606        return self.get_last_autonumber_value()
607
608
609    def update_machine_information(self, job, commit = None):
610        """Updates the job machine information.
611
612        @param job: The job object.
613        @param commit: If commit the transaction .
614        """
615        machine_info = self.machine_info_dict(job)
616        self.update('tko_machines', machine_info,
617                    where={'hostname': machine_info['hostname']},
618                    commit=commit)
619
620
621    def lookup_machine(self, hostname):
622        """Look up the machine information.
623
624        @param hostname: The hostname as string.
625        """
626        where = { 'hostname' : hostname }
627        rows = self.select('machine_idx', 'tko_machines', where)
628        if rows:
629            return rows[0][0]
630        else:
631            return None
632
633
634    def lookup_kernel(self, kernel):
635        """Look up the kernel.
636
637        @param kernel: The kernel object.
638        """
639        rows = self.select('kernel_idx', 'tko_kernels',
640                                {'kernel_hash':kernel.kernel_hash})
641        if rows:
642            return rows[0][0]
643        else:
644            return None
645
646
647    def insert_kernel(self, kernel, commit = None):
648        """Insert a kernel.
649
650        @param kernel: The kernel object.
651        @param commit: If commit the transaction .
652        """
653        kver = self.lookup_kernel(kernel)
654        if kver:
655            return kver
656
657        # If this kernel has any significant patches, append their hash
658        # as diferentiator.
659        printable = kernel.base
660        patch_count = 0
661        for patch in kernel.patches:
662            match = re.match(r'.*(-mm[0-9]+|-git[0-9]+)\.(bz2|gz)$',
663                                                    patch.reference)
664            if not match:
665                patch_count += 1
666
667        self.insert('tko_kernels',
668                    {'base':kernel.base,
669                     'kernel_hash':kernel.kernel_hash,
670                     'printable':printable},
671                    commit=commit)
672        kver = self.get_last_autonumber_value()
673
674        if patch_count > 0:
675            printable += ' p%d' % (kver)
676            self.update('tko_kernels',
677                    {'printable':printable},
678                    {'kernel_idx':kver})
679
680        for patch in kernel.patches:
681            self.insert_patch(kver, patch, commit=commit)
682        return kver
683
684
685    def insert_patch(self, kver, patch, commit = None):
686        """Insert a kernel patch.
687
688        @param kver: The kernel version.
689        @param patch: The kernel patch object.
690        @param commit: If commit the transaction .
691        """
692        print patch.reference
693        name = os.path.basename(patch.reference)[:80]
694        self.insert('tko_patches',
695                    {'kernel_idx': kver,
696                     'name':name,
697                     'url':patch.reference,
698                     'hash':patch.hash},
699                    commit=commit)
700
701
702    def find_test(self, job_idx, testname, subdir):
703        """Find a test by name.
704
705        @param job_idx: The job index.
706        @param testname: The test name.
707        @param subdir: The test sub directory under the job directory.
708        """
709        where = {'job_idx': job_idx , 'test': testname, 'subdir': subdir}
710        rows = self.select('test_idx', 'tko_tests', where)
711        if rows:
712            return rows[0][0]
713        else:
714            return None
715
716
717    def find_tests(self, job_idx):
718        """Find all tests by job index.
719
720        @param job_idx: The job index.
721        @return: A list of tests.
722        """
723        where = { 'job_idx':job_idx }
724        rows = self.select('test_idx', 'tko_tests', where)
725        if rows:
726            return [row[0] for row in rows]
727        else:
728            return []
729
730
731    def find_job(self, tag):
732        """Find a job by tag.
733
734        @param tag: The job tag name.
735        @return: The job object or None.
736        """
737        rows = self.select('job_idx', 'tko_jobs', {'tag': tag})
738        if rows:
739            return rows[0][0]
740        else:
741            return None
742
743
def _get_db_type():
    """Get the database type name to use from the global config.

    @return: "db_" followed by the configured database type; the type
             defaults to "mysql".
    """
    config = global_config.global_config
    db_type = config.get_config_value_with_fallback(
            "AUTOTEST_WEB", "global_db_type", "db_type", default="mysql")
    return "db_" + db_type
749
750
def _get_error_class(class_name):
    """Retrieve an error class by name from the database module's driver.

    @param class_name: Name of the attribute to read from the 'driver'
                       object of the configured autotest_lib.tko.db_* module.

    @return: The attribute found under that name.
    """
    module_name = "autotest_lib.tko." + _get_db_type()
    db_module = __import__(module_name, globals(), locals(), ["driver"])
    driver = db_module.driver
    return getattr(driver, class_name)
757
758
def db(*args, **dargs):
    """Create a database object of the globally configured type.

    The class is looked up in the autotest_lib.tko module named after the
    database type from the global configuration (defaulting to mysql) and
    is instantiated with the given arguments.

    @param args: The db_type arguments.
    @param dargs: The db_type named arguments.

    @return: An db object.
    """
    db_type = _get_db_type()
    db_module = __import__("autotest_lib.tko." + db_type, globals(),
                           locals(), [db_type])
    db_class = getattr(db_module, db_type)
    return db_class(*args, **dargs)
774