1# pylint: disable=missing-docstring
2"""
3Utility functions for rpc_interface.py.  We keep them in a separate file so that
4only RPC interface functions go into that file.
5"""
6
7__author__ = 'showard@google.com (Steve Howard)'
8
9import collections
10import datetime
11from functools import wraps
12import inspect
13import logging
14import os
15import sys
16import django.db.utils
17import django.http
18
19from autotest_lib.frontend import thread_local
20from autotest_lib.frontend.afe import models, model_logic
21from autotest_lib.client.common_lib import control_data, error
22from autotest_lib.client.common_lib import global_config
23from autotest_lib.client.common_lib import time_utils
24from autotest_lib.client.common_lib.cros import dev_server
25from autotest_lib.server import utils as server_utils
26from autotest_lib.server.cros import provision
27from autotest_lib.server.cros.dynamic_suite import frontend_wrappers
28
# Sentinel objects standing in for "no date/time". _prepare_data() checks for
# these exact objects (by identity) and serializes them as None.
NULL_DATETIME = datetime.datetime.max
NULL_DATE = datetime.date.max
# Substring searched for in the text of an IntegrityError to recognize a
# duplicate-key failure (see _ensure_label_exists()).
DUPLICATE_KEY_MSG = 'Duplicate entry'
# Read once, at import time, from the 'SKYLAB' section of the global config,
# with False passed as the default.
RESPECT_STATIC_LABELS = global_config.global_config.get_config_value(
        'SKYLAB', 'respect_static_labels', type=bool, default=False)
34
def prepare_for_serialization(objects):
    """
    Convert Python objects into a form that can be returned via RPC.

    A non-empty list whose first element is a dict carrying an 'id' key is
    first reduced to one dict per ID; the result is then passed through
    _prepare_data() for type conversion.

    @param objects: objects to be prepared.

    @returns The converted data, as produced by _prepare_data().
    """
    looks_like_object_dicts = (
            isinstance(objects, list) and len(objects) and
            isinstance(objects[0], dict) and 'id' in objects[0])
    if looks_like_object_dicts:
        objects = _gather_unique_dicts(objects)
    return _prepare_data(objects)
44
45
def prepare_rows_as_nested_dicts(query, nested_dict_column_names):
    """
    Turn the rows of a Django query into nested dictionaries for an RPC
    reply.

    @param query - A Django model query object with a select_related() method.
    @param nested_dict_column_names - A list of column/attribute names whose
            values are replaced, when not None, by the get_object_dict() of
            the attribute of the same name on the row.

    @returns A list suitable to be returned in an RPC.
    """
    def _expand(row):
        # Start from the flat dict, then swap in the nested dicts.
        expanded = row.get_object_dict()
        for name in nested_dict_column_names:
            if expanded[name] is None:
                continue
            expanded[name] = getattr(row, name).get_object_dict()
        return expanded

    return prepare_for_serialization(
            [_expand(row) for row in query.select_related()])
66
67
68def _prepare_data(data):
69    """
70    Recursively process data structures, performing necessary type
71    conversions to values in data to allow for RPC serialization:
72    -convert datetimes to strings
73    -convert tuples and sets to lists
74    """
75    if isinstance(data, dict):
76        new_data = {}
77        for key, value in data.iteritems():
78            new_data[key] = _prepare_data(value)
79        return new_data
80    elif (isinstance(data, list) or isinstance(data, tuple) or
81          isinstance(data, set)):
82        return [_prepare_data(item) for item in data]
83    elif isinstance(data, datetime.date):
84        if data is NULL_DATETIME or data is NULL_DATE:
85            return None
86        return str(data)
87    else:
88        return data
89
90
def fetchall_as_list_of_dicts(cursor):
    """
    Read every row from the cursor as a dictionary keyed by column name.

    @param cursor: The database cursor to read from.
    @returns: A list with one dictionary per row, mapping each column name
              (the first item of each cursor.description entry) to the
              row's value for that column.
    """
    columns = cursor.description
    rows_as_dicts = []
    for record in cursor.fetchall():
        names = [column[0] for column in columns]
        rows_as_dicts.append(dict(zip(names, record)))
    return rows_as_dicts
101
102
def raw_http_response(response_data, content_type=None):
    """
    Wrap raw data in a Django HTTP response with an explicit Content-length.

    @param response_data: The body of the response.
    @param content_type: Passed to HttpResponse as its content type; None
                         leaves the choice to Django.

    @returns A django.http.HttpResponse.
    """
    # HttpResponse's `mimetype` keyword was deprecated in Django 1.5 and
    # removed in 1.7; `content_type` is the equivalent keyword and is also
    # accepted by the older releases.
    response = django.http.HttpResponse(response_data,
                                        content_type=content_type)
    response['Content-length'] = str(len(response.content))
    return response
107
108
109def _gather_unique_dicts(dict_iterable):
110    """\
111    Pick out unique objects (by ID) from an iterable of object dicts.
112    """
113    objects = collections.OrderedDict()
114    for obj in dict_iterable:
115        objects.setdefault(obj['id'], obj)
116    return objects.values()
117
118
def extra_job_status_filters(not_yet_run=False, running=False, finished=False):
    """\
    Build the keyword args for query.extra() that restrict jobs by status.

    The requested statuses are OR'ed together into one SQL WHERE clause:
    * not_yet_run: all HQEs are Queued
    * finished: all HQEs are complete
    * running: everything else

    @returns A dict with a single-element 'where' list, or an empty dict
             when no status is requested.
    """
    if not any((not_yet_run, running, finished)):
        return {}

    queued = models.HostQueueEntry.Status.QUEUED
    # Jobs having at least one HQE that is not Queued / not complete.
    has_non_queued = ('(SELECT job_id FROM afe_host_queue_entries '
                      'WHERE status != "%s")' % queued)
    has_incomplete = ('(SELECT job_id FROM afe_host_queue_entries '
                      'WHERE not complete)')

    clauses = []
    if not_yet_run:
        clauses.append('id NOT IN ' + has_non_queued)
    if running:
        clauses.append('(id IN %s) AND (id IN %s)'
                       % (has_non_queued, has_incomplete))
    if finished:
        clauses.append('id NOT IN ' + has_incomplete)

    combined = ' OR '.join('(%s)' % clause for clause in clauses)
    return {'where': [combined]}
143
144
def extra_job_type_filters(extra_args, suite=False,
                           sub=False, standalone=False):
    """\
    Add a SQL WHERE clause for job type filtering to a dict of keyword args
    meant for query.extra().

    @param extra_args: a dict of existing extra_args. When a filter is
            requested, its 'where' list is extended in place (or created)
            and stored back under 'where'.

    No more than one of the parameters should be passed as True:
    * suite: job which is parent of other jobs
    * sub: job with a parent job
    * standalone: job with no child or parent jobs

    @returns extra_args itself, untouched when no filter is requested.
    """
    requested = [flag for flag in (suite, sub, standalone) if flag]
    assert len(requested) <= 1, ('Cannot specify more than one '
                                 'filter to this function')

    where = extra_args.get('where', [])
    if not requested:
        return extra_args

    jobs_with_parent = ('(SELECT %s FROM afe_jobs '
                        'WHERE parent_job_id IS NOT NULL)')
    if suite:
        clause = 'id IN ' + jobs_with_parent % 'DISTINCT parent_job_id'
    elif sub:
        clause = 'id IN ' + jobs_with_parent % 'id'
    else:
        clause = ('NOT EXISTS (SELECT 1 from afe_jobs AS sub_query '
                  'WHERE parent_job_id IS NOT NULL'
                  ' AND (sub_query.parent_job_id=afe_jobs.id'
                  ' OR sub_query.id=afe_jobs.id))')

    where.append(clause)
    extra_args['where'] = where
    return extra_args
183
184
def get_host_query(multiple_labels, exclude_only_if_needed_labels,
                   valid_only, filter_data):
    """
    Build a query for the hosts that carry the given labels and match
    filter_data.

    @param multiple_labels: Labels passed to
            models.Host.get_hosts_with_labels().
    @param exclude_only_if_needed_labels: Deprecated. By default it's false.
            Not read by this function.
    @param valid_only: When true, start from models.Host.valid_objects
            instead of models.Host.objects.
    @param filter_data: Filters passed to models.Host.query_objects().

    @returns The filtered host query; the falsy labelled-hosts result as is
            when no host carries the labels; an empty query when a label
            does not exist.
    """
    manager = models.Host.valid_objects if valid_only else models.Host.objects
    initial_query = manager.all()

    try:
        labelled_hosts = models.Host.get_hosts_with_labels(
                multiple_labels, initial_query)
        if labelled_hosts:
            return models.Host.query_objects(filter_data,
                                             initial_query=labelled_hosts)
        return labelled_hosts
    except models.Label.DoesNotExist:
        return models.Host.objects.none()
204
205
class InconsistencyException(Exception):
    """Raised when a list of objects does not have a consistent value."""
208
209
def get_consistent_value(objects, field):
    """
    Return the value of an attribute that every object in a list shares.

    @param objects: An indexable collection of objects.
    @param field: Name of the attribute to read from each object.

    @returns The common value, or None when objects is empty.
    @raises InconsistencyException: When an object's value differs from the
            first object's; its args are the first object and the differing
            one.
    """
    if objects:
        reference = objects[0]
        expected = getattr(reference, field)
        for candidate in objects:
            if getattr(candidate, field) != expected:
                raise InconsistencyException(reference, candidate)
        return expected

    # well a list of nothing is consistent
    return None
221
222
def afe_test_dict_to_test_object(test_dict):
    """
    Build an object exposing the entries of a test dictionary as attributes.

    Values that int() accepts are stored as integers; all others are stored
    unchanged.

    @param test_dict: A dictionary describing a test. Anything that is not a
            dict is returned as is.

    @returns A new class named 'TestObject' whose class attributes are the
            dictionary's entries, or test_dict itself if it is not a dict.
    """
    if not isinstance(test_dict, dict):
        return test_dict

    numerized_dict = {}
    # items() rather than the Python 2 only iteritems(), so this works on
    # both Python 2 and 3.
    for key, value in test_dict.items():
        try:
            numerized_dict[key] = int(value)
        except (ValueError, TypeError):
            numerized_dict[key] = value

    return type('TestObject', (object,), numerized_dict)
235
236
237def _check_is_server_test(test_type):
238    """Checks if the test type is a server test.
239
240    @param test_type The test type in enum integer or string.
241
242    @returns A boolean to identify if the test type is server test.
243    """
244    if test_type is not None:
245        if isinstance(test_type, basestring):
246            try:
247                test_type = control_data.CONTROL_TYPE.get_value(test_type)
248            except AttributeError:
249                return False
250        return (test_type == control_data.CONTROL_TYPE.SERVER)
251    return False
252
253
def prepare_generate_control_file(tests, profilers, db_tests=True):
    """
    Collect the information needed to generate a control file.

    @param tests: Tests to run. Looked up with models.Test.smart_get() when
            db_tests is true, otherwise converted with
            afe_test_dict_to_test_object().
    @param profilers: Profilers, looked up with models.Profiler.smart_get().
    @param db_tests: Whether tests refer to tests stored in the database.

    @returns A tuple (cf_info, test_objects, profiler_objects), where cf_info
            is a dict with the keys 'is_server', 'synch_count' and
            'dependencies'.
    @raises model_logic.ValidationError: If the tests do not all have the
            same test_type.
    """
    if db_tests:
        test_objects = [models.Test.smart_get(test) for test in tests]
    else:
        test_objects = [afe_test_dict_to_test_object(test) for test in tests]

    profiler_objects = [models.Profiler.smart_get(profiler)
                        for profiler in profilers]
    # ensure tests are all the same type
    try:
        test_type = get_consistent_value(test_objects, 'test_type')
    except InconsistencyException as exc:
        test1, test2 = exc.args
        raise model_logic.ValidationError(
            {'tests' : 'You cannot run both test_suites and server-side '
             'tests together (tests %s and %s differ)' % (
            test1.name, test2.name)})

    is_server = _check_is_server_test(test_type)
    if test_objects:
        synch_count = max(test.sync_count for test in test_objects)
    else:
        synch_count = 1

    if db_tests:
        dependencies = set(label.name for label
                           in models.Label.objects.filter(test__in=test_objects))
    else:
        # set().union(*...) also copes with an empty test list, for which
        # reduce() without an initial value raises TypeError.
        dependencies = set().union(
                *[set(test.dependencies) for test in test_objects])

    cf_info = dict(is_server=is_server, synch_count=synch_count,
                   dependencies=list(dependencies))
    return cf_info, test_objects, profiler_objects
288
289
def check_job_dependencies(host_objects, job_dependencies):
    """
    Verify that every given machine satisfies a job's dependencies.

    Dependencies for which provision.is_for_special_action() is true are
    not checked against the hosts.

    @param host_objects: list of models.Host objects
    @param job_dependencies: list of names of labels
    @raises model_logic.ValidationError: If any host lacks a dependency.
    """
    requested_ids = [host.id for host in host_objects]
    ok_hosts = models.Host.objects.filter(id__in=requested_ids)

    # Narrow the query down once per dependency that hosts must carry.
    for dependency in job_dependencies:
        if provision.is_for_special_action(dependency):
            continue

        try:
            label = models.Label.smart_get(dependency)
        except models.Label.DoesNotExist:
            logging.info('Label %r does not exist, so it cannot '
                         'be replaced by static label.', dependency)
            label = None

        if label is None or not label.is_replaced_by_static():
            ok_hosts = ok_hosts.filter(labels__name=dependency)
        else:
            ok_hosts = ok_hosts.filter(static_labels__name=dependency)

    requested_names = set(host.hostname for host in host_objects)
    satisfying_names = set(host.hostname for host in ok_hosts)
    failing_hosts = requested_names - satisfying_names
    if not failing_hosts:
        return

    raise model_logic.ValidationError(
        {'hosts' : 'Host(s) failed to meet job dependencies (' +
                   (', '.join(job_dependencies)) + '): ' +
                   (', '.join(failing_hosts))})
321
322
def check_job_metahost_dependencies(metahost_objects, job_dependencies):
    """
    Verify that, for each metahost, at least one machine carrying it also
    satisfies the job's dependencies.

    Dependencies for which provision.is_for_special_action() is true are
    not checked against the hosts.

    @param metahost_objects A list of label objects representing the metahosts.
    @param job_dependencies A list of strings of the required label names.
    @raises NoEligibleHostException If a metahost cannot run the job.
    """
    for metahost in metahost_objects:
        # Start from all hosts carrying the metahost label itself.
        if metahost.is_replaced_by_static():
            static_label = models.StaticLabel.smart_get(metahost.name)
            candidates = models.Host.objects.filter(static_labels=static_label)
        else:
            candidates = models.Host.objects.filter(labels=metahost)

        for label_name in job_dependencies:
            if provision.is_for_special_action(label_name):
                continue

            try:
                label = models.Label.smart_get(label_name)
            except models.Label.DoesNotExist:
                logging.info('Label %r does not exist, so it cannot '
                             'be replaced by static label.', label_name)
                label = None

            if label is None or not label.is_replaced_by_static():
                candidates = candidates.filter(labels__name=label_name)
            else:
                candidates = candidates.filter(static_labels__name=label_name)

        if any(candidates):
            continue
        raise error.NoEligibleHostException("No hosts within %s satisfy %s."
                % (metahost.name, ', '.join(job_dependencies)))
356
357
358def _execution_key_for(host_queue_entry):
359    return (host_queue_entry.job.id, host_queue_entry.execution_subdir)
360
361
def check_abort_synchronous_jobs(host_queue_entries):
    """
    Ensure the user isn't aborting only part of a synchronous autoserv
    execution.

    @param host_queue_entries: The host queue entries about to be aborted.
    @raises model_logic.ValidationError: If, for an entry with an execution
            subdir, fewer entries of its execution are included than the
            job's synch_count.
    """
    entries_per_execution = collections.Counter(
            _execution_key_for(entry) for entry in host_queue_entries)

    for entry in host_queue_entries:
        # Entries without an execution subdir are not checked.
        if entry.execution_subdir:
            included = entries_per_execution[_execution_key_for(entry)]
            if included < entry.job.synch_count:
                raise model_logic.ValidationError(
                    {'' : 'You cannot abort part of a synchronous job execution '
                          '(%d/%s), %d included, %d expected'
                          % (entry.job.id, entry.execution_subdir,
                             included, entry.job.synch_count)})
380
381
def check_modify_host(update_data):
    """
    Sanity check modify_host* requests.

    @param update_data: A dictionary with the changes to make to a host
            or hosts.
    @raises model_logic.ValidationError: If update_data has a 'status' key.
    """
    if 'status' not in update_data:
        return

    # Only the scheduler (monitor_db) is allowed to modify Host status.
    # Otherwise race conditions happen as a hosts state is changed out from
    # beneath tasks being run on a host.
    raise model_logic.ValidationError({
            'status': 'Host status can not be modified by the frontend.'})
395
396
def check_modify_host_locking(host, update_data):
    """
    Reject lock/unlock requests that do not change the host's lock state,
    and lock requests that come without a reason.

    Nothing is checked when update_data has no 'locked' entry (or it is
    None).

    @param host: models.Host object to be modified
    @param update_data: A dictionary with the changes to make to the host.
    @raises model_logic.ValidationError: If the host is already in the
            requested lock state, or a lock is requested without
            'lock_reason'.
    """
    locked = update_data.get('locked', None)
    lock_reason = update_data.get('lock_reason', None)
    if locked is None:
        return

    if locked:
        if host.locked:
            raise model_logic.ValidationError({
                    'locked': 'Host %s already locked by %s on %s.' %
                    (host.hostname, host.locked_by, host.lock_time)})
        if not lock_reason:
            raise model_logic.ValidationError({
                    'locked': 'Please provide a reason for locking Host %s' %
                    host.hostname})
    elif not host.locked:
        raise model_logic.ValidationError({
                'locked': 'Host %s already unlocked.' % host.hostname})
419
420
def get_motd():
    """
    Read the message of the day.

    @returns The contents of motd.txt, looked up two directories above the
            one holding this file, or '' if the file cannot be read.
    """
    dirname = os.path.dirname(__file__)
    filename = os.path.join(dirname, "..", "..", "motd.txt")
    try:
        with open(filename, "r") as fp:
            return fp.read()
    except Exception:
        # Reading the MOTD is best effort: any failure yields an empty
        # message. Unlike a bare except, this lets KeyboardInterrupt and
        # SystemExit propagate.
        return ''
435
436
437def _get_metahost_counts(metahost_objects):
438    metahost_counts = {}
439    for metahost in metahost_objects:
440        metahost_counts.setdefault(metahost, 0)
441        metahost_counts[metahost] += 1
442    return metahost_counts
443
444
def get_job_info(job, preserve_metahosts=False, queue_entry_filter_data=None):
    """
    Summarize the hosts, metahosts and dependencies of a job.

    @param job: The job whose host queue entries are inspected.
    @param preserve_metahosts: When true, an entry that has a host is listed
            under its host even if it also has a metahost.
    @param queue_entry_filter_data: Optional filters applied to the job's
            host queue entries via models.HostQueueEntry.query_objects().

    @returns A dict with the keys 'dependencies', 'hosts', 'meta_hosts',
            'meta_host_counts', 'one_time_hosts' and 'hostless'.
    """
    queue_entries = job.hostqueueentry_set.all()
    if queue_entry_filter_data:
        queue_entries = models.HostQueueEntry.query_objects(
            queue_entry_filter_data, initial_query=queue_entries)

    hosts, one_time_hosts, meta_hosts = [], [], []
    hostless = False
    for entry in queue_entries:
        listed_by_host = entry.host and (preserve_metahosts or
                                         not entry.meta_host)
        if listed_by_host:
            if entry.deleted:
                continue
            # Invalid hosts are reported as one-time hosts.
            bucket = one_time_hosts if entry.host.invalid else hosts
            bucket.append(entry.host)
        elif entry.meta_host:
            meta_hosts.append(entry.meta_host)
        else:
            hostless = True

    meta_host_counts = _get_metahost_counts(meta_hosts)
    dependency_names = [label.name for label in job.dependency_labels.all()]

    return {
        'dependencies': dependency_names,
        'hosts': hosts,
        'meta_hosts': meta_hosts,
        'meta_host_counts': meta_host_counts,
        'one_time_hosts': one_time_hosts,
        'hostless': hostless,
    }
480
481
def check_for_duplicate_hosts(host_objects):
    """
    Reject a list of hosts in which the same host appears more than once.

    @param host_objects: An iterable of hashable host objects; the hostname
            of a duplicated host is read for the error message.
    @raises model_logic.ValidationError: If any host is listed more than
            once.
    """
    host_counts = collections.Counter(host_objects)
    # items() rather than the Python 2 only iteritems(), so this works on
    # both Python 2 and 3.
    duplicate_hostnames = {host.hostname
                           for host, count in host_counts.items()
                           if count > 1}
    if duplicate_hostnames:
        raise model_logic.ValidationError(
                {'hosts' : 'Duplicate hosts: %s'
                 % ', '.join(duplicate_hostnames)})
491
492
def create_new_job(owner, options, host_objects, metahost_objects):
    """
    Validate the requested hosts and dependencies, then create and queue a
    job.

    options['dependencies'] is replaced by the matching models.Label objects
    before the job is created.

    @param owner: Owner passed to models.Job.create().
    @param options: Dict of job options; 'dependencies', 'synch_count' and
            'is_template' are read here.
    @param host_objects: List of hosts the job targets directly.
    @param metahost_objects: List of labels used as metahosts.

    @returns The ID of the new job.
    @raises model_logic.ValidationError: If synch_count exceeds the number
            of hosts and metahosts provided (the checks called from here
            may raise as well).
    """
    all_host_objects = host_objects + metahost_objects
    dependencies = options.get('dependencies', [])
    synch_count = options.get('synch_count')

    host_total = len(all_host_objects)
    if synch_count is not None and synch_count > host_total:
        raise model_logic.ValidationError(
                {'hosts':
                 'only %d hosts provided for job with synch_count = %d' %
                 (host_total, synch_count)})

    check_for_duplicate_hosts(host_objects)

    for label_name in dependencies:
        if not provision.is_for_special_action(label_name):
            continue
        # TODO: We could save a few queries
        # if we had a bulk ensure-label-exists function, which used
        # a bulk .get() call. The win is probably very small.
        _ensure_label_exists(label_name)

    # This only checks targeted hosts, not hosts eligible due to the metahost
    check_job_dependencies(host_objects, dependencies)
    check_job_metahost_dependencies(metahost_objects, dependencies)

    dependency_labels = models.Label.objects.filter(name__in=dependencies)
    options['dependencies'] = list(dependency_labels)

    job = models.Job.create(owner=owner, options=options,
                            hosts=all_host_objects)
    is_template = options.get('is_template', False)
    job.queue(all_host_objects, is_template=is_template)
    return job.id
525
526
def _ensure_label_exists(name):
    """
    Make sure a label called |name| exists in the Django models.

    This function is to be called from within afe rpcs only, as an
    alternative to server.cros.provision.ensure_label_exists(...). It works
    by Django model manipulation, rather than by making another create_label
    rpc call.

    @param name: the label to check for/create.
    @raises django.db.utils.IntegrityError: Creating the label failed for a
            reason other than the label already existing.
    @returns True if a label was created, False otherwise.
    """
    # Make sure this function is not called on shards but only on main.
    assert not server_utils.is_shard()

    try:
        models.Label.objects.get(name=name)
    except models.Label.DoesNotExist:
        label_missing = True
    else:
        label_missing = False
    if not label_missing:
        return False

    try:
        created = models.Label.objects.create(name=name)
        # NOTE(review): create() presumably already persists the row, which
        # would make this save() redundant -- confirm before removing.
        created.save()
    except django.db.utils.IntegrityError as e:
        # It is possible that another suite/test already
        # created the label between the check and save.
        if DUPLICATE_KEY_MSG not in str(e):
            raise
        return False
    return True
558
559
def find_platform(hostname, label_list):
    """
    Figure out the platform name of a host from its labels.

    @param hostname: The hostname; only used in the error message.
    @param label_list: The labels to search for a platform label.

    @returns The name of the label whose platform attribute is true, or
            None if there is no such label.
    @raises ValueError: If more than one label is a platform.
    """
    platform_names = []
    for label in label_list:
        if label.platform:
            platform_names.append(label.name)

    if len(platform_names) > 1:
        raise ValueError('Host %s has more than one platform: %s' %
                         (hostname, ', '.join(platform_names)))

    if platform_names:
        return platform_names[0]
    return None
581
582
583# support for get_host_queue_entries_and_special_tasks()
584
585def _common_entry_to_dict(entry, type, job_dict, exec_path, status, started_on):
586    return dict(type=type,
587                host=entry['host'],
588                job=job_dict,
589                execution_path=exec_path,
590                status=status,
591                started_on=started_on,
592                id=str(entry['id']) + type,
593                oid=entry['id'])
594
595
def _special_task_to_dict(task, queue_entries):
    """Converts a special task dictionary into the common entry layout.

    @param task           Special task as a dictionary type
    @param queue_entries  Host queue entries as a list of dictionaries.

    @return Transformed dictionary for a special task.
    """
    job_dict = None
    if task['queue_entry']:
        # Take the job details from the matching host queue entry, if any.
        job_dict = next(
                (qentry['job'] for qentry in queue_entries
                 if task['queue_entry']['id'] == qentry['id']),
                None)
        # If not found, get it from DB.
        if job_dict is None:
            job_id = task['queue_entry']['job']
            job_dict = models.Job.objects.get(id=job_id).get_object_dict()

    requested_at = time_utils.time_string_to_datetime(task['time_requested'])
    exec_path = server_utils.get_special_task_exec_path(
            task['host']['hostname'], task['id'], task['task'], requested_at)
    status = server_utils.get_special_task_status(
            task['is_complete'], task['success'], task['is_active'])
    return _common_entry_to_dict(task, task['task'], job_dict,
            exec_path, status, task['time_started'])
624
def _queue_entry_to_dict(queue_entry):
    """
    Convert a host queue entry dictionary into the common entry layout.

    @param queue_entry: Host queue entry as a dictionary; its 'job' entry
            must hold the job's 'id' and 'owner'.

    @returns The dictionary built by _common_entry_to_dict(), of type 'Job'.
    """
    job = queue_entry['job']
    job_tag = server_utils.get_job_tag(job['id'], job['owner'])
    subdir = queue_entry['execution_subdir']
    exec_path = server_utils.get_hqe_exec_path(job_tag, subdir)
    return _common_entry_to_dict(queue_entry, 'Job', job, exec_path,
            queue_entry['status'], queue_entry['started_on'])
632
633
def prepare_host_queue_entries_and_special_tasks(interleaved_entries,
                                                 queue_entries):
    """
    Prepare for serialization the interleaved entries of host queue entries
    and special tasks.

    Each element in the entries is a dictionary. A special task dictionary
    only carries the ID of its job, while a host queue entry dictionary
    carries the job details; queue_entries is used to look those details up
    for the special tasks.

    @param interleaved_entries  Host queue entries and special tasks as a list
                                of dictionaries.
    @param queue_entries        Host queue entries as a list of dictionaries.

    @return A post-processed list of dictionaries that is to be serialized.
    """
    # The two kinds of entries are told apart by the key "task": only
    # special tasks have it.
    converted = [
            _special_task_to_dict(entry, queue_entries) if 'task' in entry
            else _queue_entry_to_dict(entry)
            for entry in interleaved_entries]
    return prepare_for_serialization(converted)
660
661
662def _compute_next_job_for_tasks(queue_entries, special_tasks):
663    """
664    For each task, try to figure out the next job that ran after that task.
665    This is done using two pieces of information:
666    * if the task has a queue entry, we can use that entry's job ID.
667    * if the task has a time_started, we can try to compare that against the
668      started_on field of queue_entries. this isn't guaranteed to work perfectly
669      since queue_entries may also have null started_on values.
670    * if the task has neither, or if use of time_started fails, just use the
671      last computed job ID.
672
673    @param queue_entries    Host queue entries as a list of dictionaries.
674    @param special_tasks    Special tasks as a list of dictionaries.
675    """
676    next_job_id = None # most recently computed next job
677    hqe_index = 0 # index for scanning by started_on times
678    for task in special_tasks:
679        if task['queue_entry']:
680            next_job_id = task['queue_entry']['job']
681        elif task['time_started'] is not None:
682            for queue_entry in queue_entries[hqe_index:]:
683                if queue_entry['started_on'] is None:
684                    continue
685                t1 = time_utils.time_string_to_datetime(
686                        queue_entry['started_on'])
687                t2 = time_utils.time_string_to_datetime(task['time_started'])
688                if t1 < t2:
689                    break
690                next_job_id = queue_entry['job']['id']
691
692        task['next_job_id'] = next_job_id
693
694        # advance hqe_index to just after next_job_id
695        if next_job_id is not None:
696            for queue_entry in queue_entries[hqe_index:]:
697                if queue_entry['job']['id'] < next_job_id:
698                    break
699                hqe_index += 1
700
701
def interleave_entries(queue_entries, special_tasks):
    """
    Merge host queue entries and special tasks into a single list.

    Both lists should be ordered by descending ID. Every task gets a
    'next_job_id' entry as a side effect (see _compute_next_job_for_tasks()).

    @param queue_entries: Host queue entries as a list of dictionaries.
    @param special_tasks: Special tasks as a list of dictionaries.

    @returns A new list holding all the queue entries and special tasks.
    """
    _compute_next_job_for_tasks(queue_entries, special_tasks)

    task_count = len(special_tasks)

    # start with all special tasks that've run since the last job
    task_pos = 0
    while (task_pos < task_count and
           special_tasks[task_pos]['next_job_id'] is None):
        task_pos += 1
    merged = list(special_tasks[:task_pos])

    # now interleave queue entries with the remaining special tasks
    for entry in queue_entries:
        merged.append(entry)
        job_id = entry['job']['id']
        # add all tasks that ran between this job and the previous one
        while (task_pos < task_count and
               not (special_tasks[task_pos]['next_job_id'] < job_id)):
            merged.append(special_tasks[task_pos])
            task_pos += 1

    return merged
727
728
def bucket_hosts_by_shard(host_objs):
    """Group the hostnames of sharded hosts by the shard they are on.

    Hosts whose shard attribute is false-like are left out.

    @param host_objs: A list of host objects.

    @return: A map of shard hostname: list of hosts on the shard.
    """
    hostnames_by_shard = collections.defaultdict(list)
    sharded_hosts = (host for host in host_objs if host.shard)
    for host in sharded_hosts:
        hostnames_by_shard[host.shard.hostname].append(host.hostname)
    return hostnames_by_shard
741
742
def create_job_common(
        name,
        priority,
        control_type,
        control_file=None,
        hosts=(),
        meta_hosts=(),
        one_time_hosts=(),
        synch_count=None,
        is_template=False,
        timeout=None,
        timeout_mins=None,
        max_runtime_mins=None,
        run_verify=True,
        email_list='',
        dependencies=(),
        reboot_before=None,
        reboot_after=None,
        parse_failed_repair=None,
        hostless=False,
        keyvals=None,
        drone_set=None,
        parent_job_id=None,
        run_reset=True,
        require_ssp=None):
    #pylint: disable-msg=C0111
    """
    Shared implementation for creating "standard" and parameterized jobs.

    Validates the host arguments, resolves hostnames and metahost label
    names into objects, and hands everything to create_new_job() with the
    current user as the owner.

    @returns The value returned by create_new_job().
    @raises model_logic.ValidationError: For a hostless job that is given
            hosts or a non-server control type, a host job that is given no
            hosts, or a metahost label that is not found.
    """
    # input validation
    host_args_passed = bool(hosts or meta_hosts or one_time_hosts)
    if hostless and host_args_passed:
        raise model_logic.ValidationError({
                'hostless': 'Hostless jobs cannot include any hosts!'})
    if hostless and control_type != control_data.CONTROL_TYPE_NAMES.SERVER:
        raise model_logic.ValidationError({
                'control_type': 'Hostless jobs cannot use client-side '
                                'control files'})
    if not hostless and not host_args_passed:
        raise model_logic.ValidationError({
            'arguments' : "For host jobs, you must pass at least one of"
                          " 'hosts', 'meta_hosts', 'one_time_hosts'."
            })

    label_objects = list(models.Label.objects.filter(name__in=meta_hosts))

    # convert hostnames & meta hosts to host/label objects
    host_objects = models.Host.smart_get_bulk(hosts)
    _validate_host_job_sharding(host_objects)
    for one_time_host in one_time_hosts:
        host_objects.append(models.Host.create_one_time_host(one_time_host))

    labels_by_name = {label.name: label for label in label_objects}
    metahost_objects = []
    for label_name in meta_hosts:
        if label_name not in labels_by_name:
            raise model_logic.ValidationError(
                {'meta_hosts' : 'Label "%s" not found' % label_name})
        metahost_objects.append(labels_by_name[label_name])

    options = {
        'name': name,
        'priority': priority,
        'control_file': control_file,
        'control_type': control_type,
        'is_template': is_template,
        'timeout': timeout,
        'timeout_mins': timeout_mins,
        'max_runtime_mins': max_runtime_mins,
        'synch_count': synch_count,
        'run_verify': run_verify,
        'email_list': email_list,
        'dependencies': dependencies,
        'reboot_before': reboot_before,
        'reboot_after': reboot_after,
        'parse_failed_repair': parse_failed_repair,
        'keyvals': keyvals,
        'drone_set': drone_set,
        'parent_job_id': parent_job_id,
        # TODO(crbug.com/873716) DEPRECATED. Remove entirely.
        'test_retry': 0,
        'run_reset': run_reset,
        'require_ssp': require_ssp,
    }

    owner = models.User.current_user().login
    return create_new_job(owner=owner,
                          options=options,
                          host_objects=host_objects,
                          metahost_objects=metahost_objects)
832
833
def _validate_host_job_sharding(host_objects):
    """Reject host sets that break the job sharding rules.

    Nothing is checked when this code runs on a shard.  Otherwise the hosts
    must be accepted by _allowed_hosts_for_main_job().

    @param host_objects: Host objects the job is being created for.

    @raises ValueError: if the hosts are not allowed for one job on main.
            The message includes the mapping returned by
            bucket_hosts_by_shard().
    """
    if server_utils.is_shard():
        return
    if _allowed_hosts_for_main_job(host_objects):
        return
    hosts_by_shard = bucket_hosts_by_shard(host_objects)
    raise ValueError(
            'The following hosts are on shard(s), please create '
            'seperate jobs for hosts on each shard: %s ' % hosts_by_shard)
843
844
def _allowed_hosts_for_main_job(host_objects):
    """Check that the hosts are allowed for a job on main.

    The following jobs are disallowed on main:
      - the hosts map to more than one shard: the job would span multiple
        shards.
      - the hosts map to exactly one shard, but fewer hosts are on that
        shard than were passed in: the job would span one shard and the
        main.

    @param host_objects: Host objects requested for the job.

    @returns: True if the hosts may be used together for a job created on
              main, False otherwise.
    """
    shard_host_map = bucket_hosts_by_shard(host_objects)
    num_shards = len(shard_host_map)
    if num_shards > 1:
        return False
    if num_shards == 1:
        # dict.values() is only indexable when it returns a list (Python 2);
        # pulling the single value through an iterator works for both lists
        # and Python 3 dict views.
        hosts_on_shard = next(iter(shard_host_map.values()))
        assert len(hosts_on_shard) <= len(host_objects)
        return len(hosts_on_shard) == len(host_objects)
    # No host is on any shard.
    return True
862
863
def encode_ascii(control_file):
    """Force a control file to only contain ascii characters.

    @param control_file: Control file to encode.

    @returns the control file in an ascii encoding.

    @raises error.ControlFileMalformed: if encoding fails.
    """
    try:
        return control_file.encode('ascii')
    except UnicodeError as e:
        # UnicodeError is the common base class: text containing non-ascii
        # characters fails with UnicodeEncodeError, while a Python 2 byte
        # string is implicitly decoded first and fails with
        # UnicodeDecodeError.  Both must be reported the same way.
        raise error.ControlFileMalformed(str(e))
877
878
def get_wmatrix_url():
    """Look up the wmatrix url in the global config.

    Reads option 'wmatrix_url' of section 'AUTOTEST_WEB'.

    @returns the configured wmatrix url; an empty string is passed as the
             default value for the lookup.
    """
    config = global_config.global_config
    return config.get_config_value('AUTOTEST_WEB', 'wmatrix_url', default='')
887
888
def get_stainless_url():
    """Look up the stainless url in the global config.

    Reads option 'stainless_url' of section 'AUTOTEST_WEB'.

    @returns the configured stainless url; an empty string is passed as the
             default value for the lookup.
    """
    config = global_config.global_config
    return config.get_config_value('AUTOTEST_WEB', 'stainless_url',
                                   default='')
897
898
def inject_times_to_filter(start_time_key=None, end_time_key=None,
                         start_time_value=None, end_time_value=None,
                         **filter_data):
    """Add start/end time entries to a filter when values are provided.

    A time entry is only added when its value is truthy; the remaining
    keyword arguments are returned unchanged alongside it.

    @param start_time_key: Filter key to store start_time_value under.
    @param end_time_key: Filter key to store end_time_value under.
    @param start_time_value: Start time value, skipped when falsy.
    @param end_time_value: End time value, skipped when falsy.
    @param filter_data: All other filter entries, as keyword arguments.

    @returns the filter_data dict with the time entries injected.
    """
    time_entries = ((start_time_key, start_time_value),
                    (end_time_key, end_time_value))
    for key, value in time_entries:
        if value:
            filter_data[key] = value
    return filter_data
916
917
def inject_times_to_hqe_special_tasks_filters(filter_data_common,
                                              start_time, end_time):
    """Build the hqe and special tasks filters for a time range.

    Both filters start from the entries of filter_data_common.  The hqe
    filter gets the times under 'started_on__gte'/'started_on__lte', the
    special tasks filter under 'time_started__gte'/'time_started__lte'.

    @param filter_data_common: Common filter for hqe and special tasks.
    @param start_time: Start time value, only injected when truthy.
    @param end_time: End time value, only injected when truthy.

    @returns a pair of hqe and special tasks filters.
    """
    special_tasks_base = filter_data_common.copy()
    hqe_filter = inject_times_to_filter(
            'started_on__gte', 'started_on__lte', start_time, end_time,
            **filter_data_common)
    special_tasks_filter = inject_times_to_filter(
            'time_started__gte', 'time_started__lte', start_time, end_time,
            **special_tasks_base)
    return hqe_filter, special_tasks_filter
934
935
def retrieve_shard(shard_hostname):
    """
    Look up the shard with the given hostname.

    The lookup is delegated to models.Shard.smart_get().

    @param shard_hostname: Hostname of the shard to retrieve

    @raises models.Shard.DoesNotExist, if no shard with this hostname was
            found (propagated from the lookup).

    @returns: Shard object
    """
    shard = models.Shard.smart_get(shard_hostname)
    return shard
947
948
def find_records_for_shard(shard, known_job_ids, known_host_ids):
    """Collect the records that should be sent to a shard.

    Hosts and jobs come from the assign_to_shard() methods of their models;
    the keyvals are those of the parent jobs of the returned jobs.

    @param shard: Shard to find records for.
    @param known_job_ids: List of ids of jobs the shard already has.
    @param known_host_ids: List of ids of hosts the shard already has.

    @returns: Tuple of lists:
              (hosts, jobs, suite_job_keyvals, invalid_host_ids)
    """
    hosts, invalid_host_ids = models.Host.assign_to_shard(
            shard, known_host_ids)
    jobs = models.Job.assign_to_shard(shard, known_job_ids)
    suite_job_keyvals = models.JobKeyval.objects.filter(
            job_id__in=[job.parent_job_id for job in jobs])
    return (hosts, jobs, suite_job_keyvals, invalid_host_ids)
966
967
def _persist_records_with_type_sent_from_shard(
    shard, records, record_type, *args, **kwargs):
    """
    Apply serialized records of one type that a shard sent to the main.

    Each record is looked up by its 'id', checked with the existing
    object's sanity_check_update_from_shard() and, if the check passes,
    applied through update_from_serialized().

    @param shard: The shard the records were sent from.
    @param records: The records sent in their serialized format.
    @param record_type: Type of the objects represented by records.
    @param args: Additional arguments that will be passed on to the sanity
                 checks.
    @param kwargs: Additional arguments that will be passed on to the sanity
                  checks.

    @raises error.UnallowedRecordsSentToMain if a record does not exist on
            main, or if a sanity check raises it.

    @returns: List of primary keys of the records that were updated.
              Records whose sanity check raised
              error.IgnorableUnallowedRecordsSentToMain are left out.
    """
    persisted_pks = []
    for record_data in records:
        record_pk = record_data['id']
        try:
            existing = record_type.objects.get(pk=record_pk)
        except record_type.DoesNotExist:
            raise error.UnallowedRecordsSentToMain(
                'Object with pk %s of type %s does not exist on main.' % (
                    record_pk, record_type))

        try:
            existing.sanity_check_update_from_shard(
                shard, record_data, *args, **kwargs)
        except error.IgnorableUnallowedRecordsSentToMain:
            # A non-fatal illegal change was attempted: skip this record
            # without updating it or reporting its pk.
            continue

        existing.update_from_serialized(record_data)
        persisted_pks.append(record_pk)

    return persisted_pks
1007
1008
def persist_records_sent_from_shard(shard, jobs, hqes):
    """
    Sanity check, then save, serialized records sent to main from a shard.

    During heartbeats shards upload jobs and host queue entries.  Jobs are
    processed first; the ids of the jobs that were actually persisted are
    then handed to the host queue entry checks as job_ids_sent.

    The sanity checks are meant to cover:
    - Checking if the objects sent already exist on the main.
    - Checking if the objects sent were assigned to this shard.
    - Host queue entries must be sent together with their jobs.

    @param shard: The shard the records were sent from.
    @param jobs: The jobs the shard sent.
    @param hqes: The host queue entries the shard sent.

    @raises error.UnallowedRecordsSentToMain if any of the sanity checks fail.
    """
    persisted_job_ids = _persist_records_with_type_sent_from_shard(
            shard, jobs, models.Job)
    _persist_records_with_type_sent_from_shard(
            shard, hqes, models.HostQueueEntry,
            job_ids_sent=persisted_job_ids)
1033
1034
def forward_single_host_rpc_to_shard(func):
    """This decorator forwards rpc calls that modify a host to a shard.

    If a host is assigned to a shard, rpcs that change its attributes should
    be forwarded to the shard.  The rpc is first executed locally; when the
    host has a shard and this process is not itself a shard, the same rpc
    (looked up by the decorated function's name) is then run on that shard
    with the same keyword arguments.

    The decorated function must be called with keyword arguments only, and
    one of them must be 'id', identifying the host.

    @param func: The function to decorate

    @returns: The function to replace func with.  It returns whatever func
              returns.
    """
    @wraps(func)
    def replacement(**kwargs):
        # Only keyword arguments can be accepted here, as we need the argument
        # names to send the rpc. serviceHandler always provides arguments with
        # their keywords, so this is not a problem.

        # A host record (identified by kwargs['id']) can be deleted in
        # func(). Therefore, we should save the data that can be needed later
        # before func() is called.
        shard_hostname = None
        host = models.Host.smart_get(kwargs['id'])
        if host and host.shard:
            shard_hostname = host.shard.hostname
        ret = func(**kwargs)
        if shard_hostname and not server_utils.is_shard():
            # __name__ is available on every Python version, unlike the
            # Python-2-only func_name alias.
            run_rpc_on_multiple_hostnames(func.__name__,
                                          [shard_hostname],
                                          **kwargs)
        return ret

    return replacement
1067
1068
1069def fanout_rpc(host_objs, rpc_name, include_hostnames=True, **kwargs):
1070    """Fanout the given rpc to shards of given hosts.
1071
1072    @param host_objs: Host objects for the rpc.
1073    @param rpc_name: The name of the rpc.
1074    @param include_hostnames: If True, include the hostnames in the kwargs.
1075        Hostnames are not always necessary, this functions is designed to
1076        send rpcs to the shard a host is on, the rpcs themselves could be
1077        related to labels, acls etc.
1078    @param kwargs: The kwargs for the rpc.
1079    """
1080    # Figure out which hosts are on which shards.
1081    shard_host_map = bucket_hosts_by_shard(host_objs)
1082
1083    # Execute the rpc against the appropriate shards.
1084    for shard, hostnames in shard_host_map.iteritems():
1085        if include_hostnames:
1086            kwargs['hosts'] = hostnames
1087        try:
1088            run_rpc_on_multiple_hostnames(rpc_name, [shard], **kwargs)
1089        except:
1090            ei = sys.exc_info()
1091            new_exc = error.RPCException('RPC %s failed on shard %s due to '
1092                    '%s: %s' % (rpc_name, shard, ei[0].__name__, ei[1]))
1093            raise new_exc.__class__, new_exc, ei[2]
1094
1095
def run_rpc_on_multiple_hostnames(rpc_call, shard_hostnames, **kwargs):
    """Run an rpc on multiple AFEs.

    This is used, for example, to propagate changes made to hosts after
    they are assigned to a shard.  For every hostname a RetryingAFE is
    created for the current thread-local user and the rpc is run through it.

    @param rpc_call: Name of the rpc endpoint to call.
    @param shard_hostnames: List of hostnames to run the rpcs on.
    @param **kwargs: Keyword arguments to pass in the rpcs.
    """
    # Make sure this function is not called on shards but only on main.
    assert not server_utils.is_shard()
    for hostname in shard_hostnames:
        shard_afe = frontend_wrappers.RetryingAFE(
                server=hostname, user=thread_local.get_user())
        shard_afe.run(rpc_call, **kwargs)
1112
1113
def get_label(name):
    """Get a label object by name, or None if there is no such label.

    models.Label.DoesNotExist from the lookup is caught here and is not
    propagated to the caller.

    @param name: Label name.
    @return: a label object matching the given name, or None when the
             lookup raises models.Label.DoesNotExist.
    """
    try:
        return models.Label.smart_get(name)
    except models.Label.DoesNotExist:
        return None
1127
1128
1129# TODO: hide the following rpcs under is_moblab
def moblab_only(func):
    """Ensure moblab specific functions only run on Moblab devices.

    @param func: The function to decorate.

    @returns: A wrapper that calls func with the given arguments when
              server_utils.is_moblab() is true.

    @raises error.RPCException: (from the wrapper) when not running on a
            Moblab system.
    """
    @wraps(func)
    def verify(*args, **kwargs):
        if not server_utils.is_moblab():
            # Interpolate the name here: handing it to the exception as a
            # second argument would leave the '%s' placeholder unformatted.
            raise error.RPCException(
                    'RPC: %s can only run on Moblab Systems!' % func.__name__)
        return func(*args, **kwargs)
    return verify
1138
1139
def route_rpc_to_main(func):
    """Route RPC to main AFE.

    When a shard receives an RPC decorated by this, the RPC is just
    forwarded to the main.
    When the main gets the RPC, the RPC function is executed.

    @param func: An RPC function to decorate

    @raises Exception: at decoration time, if func accepts *args.

    @returns: A function replacing the RPC func.
    """
    # inspect.getargspec() was removed in Python 3.11; use getfullargspec()
    # whenever the running interpreter provides it.  Both results expose
    # the name of the *args parameter as 'varargs'.
    if hasattr(inspect, 'getfullargspec'):
        argspec = inspect.getfullargspec(func)
    else:
        argspec = inspect.getargspec(func)
    if argspec.varargs is not None:
        raise Exception('RPC function must not have *args.')

    @wraps(func)
    def replacement(*args, **kwargs):
        # We need special handling when decorating an RPC that can be called
        # directly using positional arguments.
        #
        # One example is rpc_interface.create_job().
        # rpc_interface.create_job_page_handler() calls the function using
        # both positional and keyword arguments.  Since
        # frontend.RpcClient.run() takes only keyword arguments for an RPC,
        # positional arguments of the RPC function need to be transformed
        # into keyword arguments.
        kwargs = _convert_to_kwargs_only(func, args, kwargs)
        if server_utils.is_shard():
            afe = frontend_wrappers.RetryingAFE(
                    server=server_utils.get_global_afe_hostname(),
                    user=thread_local.get_user())
            # __name__ is available on every Python version, unlike the
            # Python-2-only func_name alias.
            return afe.run(func.__name__, **kwargs)
        return func(**kwargs)

    return replacement
1175
1176
def _convert_to_kwargs_only(func, args, kwargs):
    """Convert a function call's arguments to a kwargs dict.

    This is best illustrated with an example.  Given:

    def foo(a, b, **kwargs):
        pass

    the call foo(1, 2, c=3) is converted with

    _convert_to_kwargs_only(foo, (1, 2), {'c': 3})

    which returns {'a': 1, 'b': 2, 'c': 3}, so that foo(**result) is
    equivalent to the original call.

    @param func: function whose signature to use
    @param args: positional arguments of call
    @param kwargs: keyword arguments of call

    @returns: kwargs dict
    """
    # inspect.getargspec() was removed in Python 3.11.  getfullargspec()
    # reports the name of the **kwargs parameter as 'varkw', while the old
    # getargspec() calls it 'keywords'.
    if hasattr(inspect, 'getfullargspec'):
        varkw_name = inspect.getfullargspec(func).varkw
    else:
        varkw_name = inspect.getargspec(func).keywords
    # callargs looks like {'a': 1, 'b': 2, 'kwargs': {'c': 3}}
    callargs = inspect.getcallargs(func, *args, **kwargs)
    if varkw_name is None:
        kwargs = {}
    else:
        # Start from the extra keyword arguments, then merge in the named
        # parameters below.
        kwargs = callargs.pop(varkw_name)
    kwargs.update(callargs)
    return kwargs
1203
1204
def get_sample_dut(board, pool):
    """Get the name of a dut with the given board and pool.

    This method is used to help to locate a dut with the given board and
    pool.  The dut then can be used to identify a devserver in the same
    subnet.

    @param board: Name of the board.
    @param pool: Name of the pool.

    @return: Hostname of the first host returned for the labels
             'pool:<pool>' and 'board:<board>'.  None if
             dev_server.PREFER_LOCAL_DEVSERVER is false, if board or pool
             is empty, or if no host matches.
    """
    if (not dev_server.PREFER_LOCAL_DEVSERVER) or (not pool) or (not board):
        return None

    labels = ('pool:%s' % pool, 'board:%s' % board)
    matching_hosts = list(get_host_query(
        multiple_labels=labels,
        exclude_only_if_needed_labels=False,
        valid_only=True,
        filter_data={},
    ))
    return matching_hosts[0].hostname if matching_hosts else None
1229