1# Copyright (c) 2011 The Chromium OS Authors. All rights reserved.
2# Use of this source code is governed by a BSD-style license that can be
3# found in the LICENSE file.
4
5"""Site extensions to server_job.  Adds distribute_across_machines()."""
6
7import os, logging, multiprocessing
8from autotest_lib.server import site_gtest_runner, site_server_job_utils
9from autotest_lib.server import subcommand
10from autotest_lib.server.server_job import base_server_job
11import utils
12
13
def get_site_job_data(job):
    """Build site specific keyval data for a job.

    For a job using more than one machine the hostname is replaced by the
    platform of the first machine that reports one, rather than the
    machine1,machine2,... list.  Job reports become easier to read and the
    tko_machines table does not grow as quickly.

    Args:
        job: instance of server_job.

    Returns:
        Dictionary holding the new 'hostname' value, or an empty dictionary
        when the job has at most one machine or no machine has a platform.
    """
    # Jobs with zero or one machine keep their hostname untouched.
    if len(job.machines) <= 1:
        return {}

    # All hosts are assumed to share one platform, so the first machine whose
    # keyvals carry a non-empty platform decides the hostname.
    for machine in job.machines:
        host_keyvals = utils.read_keyval(
                os.path.join(job.resultdir, 'host_keyvals', machine))
        host_platform = host_keyvals.get('platform')
        if host_platform:
            return {'hostname': host_platform}

    return {}
42
43
class site_server_job(base_server_job):
    """Site specific server_job.

    Adds a gtest_runner to the namespace handed to run() and provides
    distribute_across_machines() together with record_skipped_test().
    """

    def __init__(self, *args, **dargs):
        """Pass every argument through to base_server_job unchanged."""
        super(site_server_job, self).__init__(*args, **dargs)


    def run(self, *args, **dargs):
        """Call server_job.run with a gtest_runner added to the namespace.

        Args:
            *args: Positional arguments for the base run().  When five or
                more are given, the fifth one is used as the namespace.
            **dargs: Keyword arguments for the base run().  When the
                namespace was not passed positionally, the 'namespace' entry
                is used, and created if it is missing.
        """
        runner_entry = {'gtest_runner': site_gtest_runner.gtest_runner()}

        # Namespace is the 5th parameter to run(), so it is either args[4] or
        # it has to live in dargs.
        if len(args) < 5:
            namespace = dargs.setdefault('namespace', runner_entry)
        else:
            namespace = args[4]
        namespace.update(runner_entry)

        # Hand over to the original run() with the extended namespace.
        super(site_server_job, self).run(*args, **dargs)


    @staticmethod
    def _modernize_test_entry(test_entry):
        """Return test_entry with its attributes held in a dictionary.

        An entry with more than two elements whose third element is not a
        dictionary is old style: its third, fourth and fifth elements (as
        far as present) are moved into a dictionary under the keys
        'include', 'exclude' and 'attributes'.  Any other entry is returned
        as is.
        """
        if len(test_entry) <= 2 or isinstance(test_entry[2], dict):
            return test_entry

        legacy_keys = ('include', 'exclude', 'attributes')
        # zip() stops at the shorter input, so only the elements that are
        # actually present end up in the dictionary.
        attribs = dict(zip(legacy_keys, test_entry[2:5]))
        return list(test_entry[:2]) + [attribs]


    def distribute_across_machines(self, tests, machines,
                                   continuous_parsing=False):
        """Run every test once, spread over the given machines.

        parallel_on_machines runs each test on each machine, which adds up to
        len(tests) * len(machines) test runs.  This method instead queues
        each test a single time, so at most len(tests) tests are run.  Tests
        that are not valid on any machine are not queued but recorded with
        record_skipped_test().

        Args:
            tests: List of tests to run.
            machines: List of machines to use.
            continuous_parsing: Bool, if true parse job while running.
        """
        # The queue copes with concurrent access on its own.  The lock is for
        # machines that have to search the queue for a valid test and thus
        # need exclusive access for longer than one get call.
        pending_tests = multiprocessing.JoinableQueue()
        pending_tests_lock = multiprocessing.Lock()

        worker_commands = []
        # Distinct attribute sets of the machines in this job.  Checking a
        # test against these instead of against every machine shrinks the
        # search from the number of machines to the number of unique ones.
        known_attribute_sets = []
        machine_dir = self.resultdir

        for machine in machines:
            if 'group' in self.resultdir:
                machine_dir = os.path.join(self.resultdir, machine)

            worker = site_server_job_utils.machine_worker(
                    self, machine, machine_dir, pending_tests,
                    pending_tests_lock, continuous_parsing)

            # One subcommand per machine worker.
            worker_commands.append(
                    subcommand.subcommand(worker.run, [], machine_dir))

            if worker.attribute_set not in known_attribute_sets:
                known_attribute_sets.append(worker.attribute_set)

        for entry in tests:
            item = site_server_job_utils.test_item(
                    *self._modernize_test_entry(entry))
            for attribute_set in known_attribute_sets:
                if item.validate(attribute_set):
                    pending_tests.put(item)
                    break
            else:
                # No attribute set accepted the test, so it is not queued.
                self.record_skipped_test(item)

        # Run valid tests and wait for completion.
        subcommand.parallel(worker_commands)


    def record_skipped_test(self, skipped_test, message=None):
        """Log a skipped test and record it with an END TEST_NA status.

        Args:
            skipped_test: Test that was skipped; its test_name is used for
                the records.
            message: Text to log and record.  When None, a message saying
                that no valid machines were found for the test is used.
        """
        if message is None:
            message = 'No valid machines found for test %s.' % skipped_test
        logging.info(message)

        test_name = skipped_test.test_name
        self.record('START', None, test_name)
        for status in ('INFO', 'END TEST_NA'):
            self.record(status, None, test_name, message)
154