1#pylint: disable-msg=C0111
2"""\
3Utility functions for rpc_interface.py.  We keep them in a separate file so that
4only RPC interface functions go into that file.
5"""
6
7__author__ = 'showard@google.com (Steve Howard)'
8
9import datetime
10from functools import wraps
11import inspect
12import os
13import sys
14import django.db.utils
15import django.http
16
17from autotest_lib.frontend import thread_local
18from autotest_lib.frontend.afe import models, model_logic
19from autotest_lib.client.common_lib import control_data, error
20from autotest_lib.client.common_lib import global_config, priorities
21from autotest_lib.client.common_lib import time_utils
22from autotest_lib.client.common_lib.cros.graphite import autotest_stats
23from autotest_lib.server import utils as server_utils
24from autotest_lib.server.cros import provision
25from autotest_lib.server.cros.dynamic_suite import frontend_wrappers
26
# Sentinel "no date/time" values: _prepare_data() turns objects that are
# identical (`is`) to these into None when preparing RPC results.
NULL_DATETIME = datetime.datetime.max
NULL_DATE = datetime.date.max
# Substring searched for in the text of a django IntegrityError to recognize
# that a row with the same unique key already exists (_ensure_label_exists).
DUPLICATE_KEY_MSG = 'Duplicate entry'
30
def prepare_for_serialization(objects):
    """
    Prepare Python objects to be returned via RPC.

    A non-empty list whose first element is a dict containing an 'id' key is
    treated as a list of object dicts and de-duplicated by id (first
    occurrence wins) before the type conversions of _prepare_data() run.

    @param objects: objects to be prepared.
    @returns The converted data, ready for RPC serialization.
    """
    looks_like_object_dicts = (isinstance(objects, list) and objects and
                               isinstance(objects[0], dict) and
                               'id' in objects[0])
    if looks_like_object_dicts:
        objects = gather_unique_dicts(objects)
    return _prepare_data(objects)
40
41
def prepare_rows_as_nested_dicts(query, nested_dict_column_names):
    """
    Serialize the rows of a Django query as (possibly nested) dictionaries.

    Every row is converted with its get_object_dict() method.  For each name
    in nested_dict_column_names whose value in that dict is not None, the
    value is replaced by get_object_dict() of the object found at the same
    attribute on the row.

    @param query - A Django model query object with a select_related() method.
    @param nested_dict_column_names - A list of column/attribute names to
            expand into nested dictionaries when not None.

    @returns A list suitable to be returned in an RPC.
    """
    def _row_to_dict(row):
        row_dict = row.get_object_dict()
        for column in nested_dict_column_names:
            if row_dict[column] is not None:
                row_dict[column] = getattr(row, column).get_object_dict()
        return row_dict

    return prepare_for_serialization(
            [_row_to_dict(row) for row in query.select_related()])
62
63
def _prepare_data(data):
    """
    Recursively convert data into types that can be serialized over RPC.

    * dicts are rebuilt with every value converted recursively;
    * lists, tuples and sets become lists of recursively converted items;
    * dates/datetimes become strings, except the NULL_DATE / NULL_DATETIME
      sentinel objects, which become None;
    * anything else is returned unchanged.
    """
    if isinstance(data, dict):
        return dict((key, _prepare_data(value))
                    for key, value in data.items())
    if isinstance(data, (list, tuple, set)):
        return [_prepare_data(item) for item in data]
    if not isinstance(data, datetime.date):
        return data
    if data is NULL_DATETIME or data is NULL_DATE:
        return None
    return str(data)
85
86
def fetchall_as_list_of_dicts(cursor):
    """
    Read all remaining rows of a DB-API cursor as dictionaries.

    @param cursor: The database cursor to read from.
    @returns: A list with one dict per row, mapping each column name (taken
              from cursor.description) to that row's value.
    """
    column_names = [column[0] for column in cursor.description]
    rows = cursor.fetchall()
    return [dict(zip(column_names, row)) for row in rows]
97
98
def raw_http_response(response_data, content_type=None):
    """
    Wrap raw data in a django HttpResponse with an explicit Content-length.

    @param response_data: The body of the response.
    @param content_type: Optional value for the Content-Type header; when None
            django's default content type is used.
    @returns A django.http.HttpResponse.
    """
    # 'content_type' is the supported keyword; 'mimetype' was deprecated in
    # Django 1.5 and removed in Django 1.7.
    response = django.http.HttpResponse(response_data,
                                        content_type=content_type)
    response['Content-length'] = str(len(response.content))
    return response
103
104
def gather_unique_dicts(dict_iterable):
    """\
    Pick out unique objects (by ID) from an iterable of object dicts.

    The first dict seen for each 'id' is kept and input order is preserved.
    """
    seen_ids = set()
    unique = []
    for obj_dict in dict_iterable:
        obj_id = obj_dict['id']
        if obj_id in seen_ids:
            continue
        seen_ids.add(obj_id)
        unique.append(obj_dict)
    return unique
116
117
def extra_job_status_filters(not_yet_run=False, running=False, finished=False):
    """\
    Generate a SQL WHERE clause for job status filtering, and return it in
    a dict of keyword args to pass to query.extra().

    Each flag that is set contributes one condition; the conditions are ORed.
    * not_yet_run: all HQEs are Queued
    * finished: all HQEs are complete
    * running: everything else

    @returns {} when no flag is set, otherwise {'where': [<clause>]}.
    """
    if not any((not_yet_run, running, finished)):
        return {}

    not_queued = ('(SELECT job_id FROM afe_host_queue_entries '
                  'WHERE status != "%s")'
                  % models.HostQueueEntry.Status.QUEUED)
    not_finished = ('(SELECT job_id FROM afe_host_queue_entries '
                    'WHERE not complete)')

    conditions = []
    if not_yet_run:
        conditions.append('id NOT IN ' + not_queued)
    if running:
        conditions.append('(id IN %s) AND (id IN %s)'
                          % (not_queued, not_finished))
    if finished:
        conditions.append('id NOT IN ' + not_finished)
    combined = ' OR '.join('(%s)' % clause for clause in conditions)
    return {'where': [combined]}
142
143
def extra_job_type_filters(extra_args, suite=False,
                           sub=False, standalone=False):
    """\
    Add a SQL WHERE clause for job type filtering to extra_args, which holds
    keyword args to pass to query.extra().

    @param extra_args: a dict of existing extra_args.  It is modified in place
            (its 'where' list is extended) and also returned.

    No more than one of the parameters should be passed as True:
    * suite: job which is parent of other jobs
    * sub: job with a parent job
    * standalone: job with no child or parent jobs

    @returns extra_args, with the new clause appended to its 'where' list when
            one of the filters was requested; unchanged otherwise.
    """
    requested = sum(1 for flag in (suite, sub, standalone) if flag)
    assert requested <= 1, ('Cannot specify more than one '
                            'filter to this function')

    filter_common = ('(SELECT %s FROM afe_jobs '
                     'WHERE parent_job_id IS NOT NULL)')
    if suite:
        clause = 'id IN ' + filter_common % 'DISTINCT parent_job_id'
    elif sub:
        clause = 'id IN ' + filter_common % 'id'
    elif standalone:
        clause = ('NOT EXISTS (SELECT 1 from afe_jobs AS sub_query '
                  'WHERE parent_job_id IS NOT NULL'
                  ' AND (sub_query.parent_job_id=afe_jobs.id'
                  ' OR sub_query.id=afe_jobs.id))')
    else:
        return extra_args

    where = extra_args.get('where', [])
    where.append(clause)
    extra_args['where'] = where
    return extra_args
182
183
184
def extra_host_filters(multiple_labels=()):
    """\
    Generate SQL WHERE clauses for matching hosts in an intersection of
    labels.

    @param multiple_labels: Labels, each resolved with
            models.Label.smart_get(); a host must carry all of them to match.
    @returns A dict with 'where' (one clause per label) and 'params' (the
            matching label ids), suitable for query.extra().
    """
    where_clause = ('afe_hosts.id in (select host_id from afe_hosts_labels '
                    'where label_id=%s)')
    clauses = [where_clause] * len(multiple_labels)
    label_ids = [models.Label.smart_get(label).id
                 for label in multiple_labels]
    return {'where': clauses, 'params': label_ids}
197
198
def _exclude_hosts_with_labels(query, labels, suffix):
    """
    Exclude from query every host that carries one of the given labels.

    @param query: A models.Host query to restrict.
    @param labels: A models.Label query selecting the labels to exclude.
    @param suffix: Suffix appended to 'afe_hosts_labels' to name the joined
            table; the join condition refers to the table by that name.
    @returns The restricted query, or query unchanged when labels is empty.
    """
    label_ids = [str(label['id']) for label in labels.values('id')]
    if not label_ids:
        return query
    return models.Host.objects.add_join(
            query, 'afe_hosts_labels', join_key='host_id',
            join_condition=('afe_hosts_labels%s.label_id IN (%s)'
                            % (suffix, ','.join(label_ids))),
            suffix=suffix, exclude=True)


def get_host_query(multiple_labels, exclude_only_if_needed_labels,
                   exclude_atomic_group_hosts, valid_only, filter_data):
    """
    Build a models.Host query from the given label and filter options.

    @param multiple_labels: Labels a host must all carry (see
            extra_host_filters()).
    @param exclude_only_if_needed_labels: If True, drop hosts that have any
            valid label with only_if_needed set.
    @param exclude_atomic_group_hosts: If True, drop hosts that have any valid
            label belonging to an atomic group.
    @param valid_only: If True start from models.Host.valid_objects, otherwise
            from models.Host.objects.
    @param filter_data: Filter dict for models.Host.query_objects().  NOTE: it
            is modified in place (an 'extra_args' entry is added), so it must
            not already contain 'extra_args'.

    @returns The host query, or an empty query when one of multiple_labels
            does not exist.
    """
    if valid_only:
        query = models.Host.valid_objects.all()
    else:
        query = models.Host.objects.all()

    if exclude_only_if_needed_labels:
        query = _exclude_hosts_with_labels(
                query,
                models.Label.valid_objects.filter(only_if_needed=True),
                '_exclude_OIN')

    if exclude_atomic_group_hosts:
        query = _exclude_hosts_with_labels(
                query,
                models.Label.valid_objects.filter(atomic_group__isnull=False),
                '_exclude_AG')

    assert 'extra_args' not in filter_data
    try:
        filter_data['extra_args'] = extra_host_filters(multiple_labels)
        return models.Host.query_objects(filter_data, initial_query=query)
    except models.Label.DoesNotExist:
        # An unknown label cannot match any host.
        return models.Host.objects.none()
238
239
class InconsistencyException(Exception):
    """
    Raised when a list of objects does not have a consistent value.

    get_consistent_value() raises it with two args: the first object of the
    list and the first object whose value differs from it.
    """
242
243
def get_consistent_value(objects, field):
    """
    Return the value of attribute `field` that all of objects share.

    @param objects: An indexable sequence of objects (may be empty).
    @param field: Name of the attribute to compare.
    @returns The common value, or None when objects is empty.
    @raises InconsistencyException(objects[0], differing_object) when an
            object's value differs from the value of objects[0].
    """
    if not objects:
        # An empty list is trivially consistent.
        return None

    reference = objects[0]
    expected = getattr(reference, field)
    for candidate in objects:
        if getattr(candidate, field) != expected:
            raise InconsistencyException(reference, candidate)
    return expected
255
256
def afe_test_dict_to_test_object(test_dict):
    """
    Turn a test description dict into an object with matching attributes.

    Values that int() accepts are converted to ints and the rest are kept
    as they are.  The result is a new class named 'TestObject' (not an
    instance) whose class attributes are the converted entries.

    @param test_dict: A dict describing a test, or any other object.
    @returns The generated class, or test_dict itself if it is not a dict.
    """
    if not isinstance(test_dict, dict):
        return test_dict

    def _numerize(value):
        try:
            return int(value)
        except (ValueError, TypeError):
            return value

    attributes = dict((key, _numerize(value))
                      for key, value in test_dict.items())
    return type('TestObject', (object,), attributes)
269
270
def prepare_generate_control_file(tests, kernel, label, profilers,
                                  db_tests=True):
    """
    Collect the information needed to generate a control file.

    @param tests: Tests to include: identifiers for models.Test.smart_get()
            when db_tests is True, otherwise test dicts (or objects) as
            accepted by afe_test_dict_to_test_object().
    @param kernel: Not used by this function.
    @param label: Optional label, resolved with models.Label.smart_get().
    @param profilers: Profiler identifiers for models.Profiler.smart_get().
    @param db_tests: Whether the tests come from the database.

    @returns A tuple (cf_info, test_objects, profiler_objects, label) where
            cf_info is a dict with keys is_server, synch_count and
            dependencies (a list of label names).
    @raises model_logic.ValidationError if the tests mix control types.
    """
    if db_tests:
        test_objects = [models.Test.smart_get(test) for test in tests]
    else:
        test_objects = [afe_test_dict_to_test_object(test) for test in tests]

    profiler_objects = [models.Profiler.smart_get(profiler)
                        for profiler in profilers]
    # ensure tests are all the same type
    try:
        test_type = get_consistent_value(test_objects, 'test_type')
    except InconsistencyException as exc:
        test1, test2 = exc.args
        raise model_logic.ValidationError(
            {'tests' : 'You cannot run both test_suites and server-side '
             'tests together (tests %s and %s differ)' % (
            test1.name, test2.name)})

    is_server = (test_type == control_data.CONTROL_TYPE.SERVER)
    if test_objects:
        synch_count = max(test.sync_count for test in test_objects)
    else:
        synch_count = 1
    if label:
        label = models.Label.smart_get(label)

    if db_tests:
        dependencies = set(dependency_label.name for dependency_label
                           in models.Label.objects.filter(test__in=test_objects))
    else:
        # set().union() with no arguments is an empty set, so an empty test
        # list is handled (reduce() without an initial value raised here).
        dependencies = set().union(
                *[set(test.dependencies) for test in test_objects])

    cf_info = dict(is_server=is_server, synch_count=synch_count,
                   dependencies=list(dependencies))
    return cf_info, test_objects, profiler_objects, label
308
309
def check_job_dependencies(host_objects, job_dependencies):
    """
    Check that a set of machines satisfies a job's dependencies.

    Dependencies for which provision.is_for_special_action() is true are not
    checked here.

    @param host_objects: list of models.Host objects
    @param job_dependencies: list of names of labels
    @raises model_logic.ValidationError naming the hosts that lack at least
            one of the checked labels.
    """
    host_ids = [host.id for host in host_objects]
    ok_hosts = models.Host.objects.filter(id__in=host_ids)
    for dependency in job_dependencies:
        if not provision.is_for_special_action(dependency):
            # Chained filter() calls on a multi-valued relation are applied
            # independently, so a host must carry every label to remain.
            ok_hosts = ok_hosts.filter(labels__name=dependency)
    failing_hosts = (set(host.hostname for host in host_objects) -
                     set(host.hostname for host in ok_hosts))
    if failing_hosts:
        raise model_logic.ValidationError(
            {'hosts' : 'Host(s) failed to meet job dependencies (' +
                       (', '.join(job_dependencies)) + '): ' +
                       (', '.join(sorted(failing_hosts)))})
330
331
def check_job_metahost_dependencies(metahost_objects, job_dependencies):
    """
    Check that at least one machine within the metahost spec satisfies the job's
    dependencies.

    Dependencies for which provision.is_for_special_action() is true are not
    checked.

    @param metahost_objects A list of label objects representing the metahosts.
    @param job_dependencies A list of strings of the required label names.
    @raises NoEligibleHostException If a metahost cannot run the job.
    """
    for metahost in metahost_objects:
        hosts = models.Host.objects.filter(labels=metahost)
        for label_name in job_dependencies:
            if not provision.is_for_special_action(label_name):
                hosts = hosts.filter(labels__name=label_name)
        # exists() only asks whether a matching row exists instead of
        # loading every matching host.
        if not hosts.exists():
            raise error.NoEligibleHostException("No hosts within %s satisfy %s."
                    % (metahost.name, ', '.join(job_dependencies)))
349
350
def _execution_key_for(host_queue_entry):
    """
    Key identifying the autoserv execution a host queue entry belongs to.

    @param host_queue_entry: An object with .job.id and .execution_subdir.
    @returns The tuple (job id, execution_subdir).
    """
    job_id = host_queue_entry.job.id
    subdir = host_queue_entry.execution_subdir
    return job_id, subdir
353
354
def check_abort_synchronous_jobs(host_queue_entries):
    """
    Ensure the user isn't aborting only part of a synchronous autoserv
    execution.

    Entries are grouped by (job id, execution_subdir).  For every entry that
    has an execution_subdir, its group must contain at least the job's
    synch_count entries.

    @param host_queue_entries: The host queue entries to be aborted.  It is
            iterated twice, so it must be a re-iterable collection.
    @raises model_logic.ValidationError when an execution is only partly
            included.
    """
    execution_sizes = {}
    for entry in host_queue_entries:
        key = (entry.job.id, entry.execution_subdir)
        execution_sizes[key] = execution_sizes.get(key, 0) + 1

    for entry in host_queue_entries:
        subdir = entry.execution_subdir
        if not subdir:
            continue
        included = execution_sizes[(entry.job.id, subdir)]
        expected = entry.job.synch_count
        if included < expected:
            raise model_logic.ValidationError(
                {'' : 'You cannot abort part of a synchronous job execution '
                      '(%d/%s), %d included, %d expected'
                      % (entry.job.id, subdir, included, expected)})
373
374
def check_atomic_group_create_job(synch_count, host_objects, metahost_objects,
                                  dependencies, atomic_group):
    """
    Attempt to reject create_job requests with an atomic group that
    will be impossible to schedule.  The checks are not perfect but
    should catch the most obvious issues.

    @param synch_count - The job's minimum synch count (None or 0 is treated
            as 1).
    @param host_objects - A list of models.Host instances.
    @param metahost_objects - A list of models.Label instances.
    @param dependencies - A list of job dependency label names.  Names that
            match no models.Label are ignored.
    @param atomic_group - The atomic group the job is to run in; its
            label_set and name are used.

    @raises model_logic.ValidationError - When an issue is found.
    """
    # If specific host objects were supplied with an atomic group, verify
    # that there are enough to satisfy the synch_count.
    minimum_required = synch_count or 1
    if (host_objects and not metahost_objects and
        len(host_objects) < minimum_required):
        raise model_logic.ValidationError(
                {'hosts':
                 'only %d hosts provided for job with synch_count = %d' %
                 (len(host_objects), synch_count)})

    # Check that the atomic group has a hope of running this job
    # given any supplied metahosts and dependencies that may limit.

    # Get a set of hostnames in the atomic group.
    possible_hosts = set()
    for label in atomic_group.label_set.all():
        possible_hosts.update(h.hostname for h in label.host_set.all())

    # Filter out hosts that don't match all of the job dependency labels.
    for label in models.Label.objects.filter(name__in=dependencies):
        hosts_in_label = (h.hostname for h in label.host_set.all())
        possible_hosts.intersection_update(hosts_in_label)

    if not host_objects and not metahost_objects:
        # No hosts or metahosts are required to queue an atomic group Job.
        # However, if they are given, we respect them below.
        # NOTE: host_set aliases possible_hosts here; that is harmless because
        # the metahost loop below does nothing when there are no metahosts.
        host_set = possible_hosts
    else:
        host_set = set(host.hostname for host in host_objects)
        unusable_host_set = host_set.difference(possible_hosts)
        if unusable_host_set:
            raise model_logic.ValidationError(
                {'hosts': 'Hosts "%s" are not in Atomic Group "%s"' %
                 (', '.join(sorted(unusable_host_set)), atomic_group.name)})

    # Lookup hosts provided by each meta host and merge them into the
    # host_set for final counting.
    for meta_host in metahost_objects:
        meta_possible = possible_hosts.copy()
        hosts_in_meta_host = (h.hostname for h in meta_host.host_set.all())
        meta_possible.intersection_update(hosts_in_meta_host)

        # Count all hosts that this meta_host will provide.
        host_set.update(meta_possible)

    if len(host_set) < minimum_required:
        raise model_logic.ValidationError(
                {'atomic_group_name':
                 'Insufficient hosts in Atomic Group "%s" with the'
                 ' supplied dependencies and meta_hosts.' %
                 (atomic_group.name,)})
442
443
def check_modify_host(update_data):
    """
    Sanity check modify_host* requests.

    Only the scheduler (monitor_db) may change a Host's status.  Letting the
    frontend do it causes races, because a host's state would change out
    from under the tasks running on it.

    @param update_data: A dictionary with the changes to make to a host
            or hosts.
    @raises model_logic.ValidationError if update_data contains 'status'.
    """
    if 'status' not in update_data:
        return
    raise model_logic.ValidationError({
            'status': 'Host status can not be modified by the frontend.'})
457
458
def check_modify_host_locking(host, update_data):
    """
    Validate a lock/unlock request against the host's current lock state.

    @param host: models.Host object to be modified
    @param update_data: A dictionary with the changes to make to the host.
    @raises model_logic.ValidationError when locking an already locked host,
            unlocking an unlocked host, or locking without a lock_reason.
    """
    locked = update_data.get('locked', None)
    lock_reason = update_data.get('lock_reason', None)
    if locked is None:
        # Not a lock/unlock request; nothing to check.
        return

    if locked:
        if host.locked:
            raise model_logic.ValidationError({
                    'locked': 'Host %s already locked by %s on %s.' %
                    (host.hostname, host.locked_by, host.lock_time)})
        if not lock_reason:
            raise model_logic.ValidationError({
                    'locked': 'Please provide a reason for locking Host %s' %
                    host.hostname})
    elif not host.locked:
        raise model_logic.ValidationError({
                'locked': 'Host %s already unlocked.' % host.hostname})
481
482
def get_motd():
    """
    Return the message of the day.

    The message is read from motd.txt, located two directories above this
    file.  Reading is best-effort: if the file is missing or cannot be read,
    an empty string is returned.

    @returns The contents of motd.txt, or '' if it could not be read.
    """
    filename = os.path.join(os.path.dirname(__file__), "..", "..", "motd.txt")
    try:
        with open(filename, "r") as motd_file:
            return motd_file.read()
    except (IOError, OSError, UnicodeError):
        # No readable motd file simply means there is no message of the day.
        # Unlike a bare except, this no longer swallows KeyboardInterrupt etc.
        return ''
497
498
def _get_metahost_counts(metahost_objects):
    """
    Count how many times each metahost appears.

    @param metahost_objects: An iterable of hashable metahost objects.
    @returns A dict mapping each distinct metahost to its number of
            occurrences.
    """
    counts = {}
    for metahost in metahost_objects:
        counts[metahost] = counts.get(metahost, 0) + 1
    return counts
505
506
def get_job_info(job, preserve_metahosts=False, queue_entry_filter_data=None):
    """
    Summarize the hosts, metahosts, atomic group and dependencies of a job.

    @param job: A models.Job.
    @param preserve_metahosts: If True, a queue entry that has both a host and
            a meta_host is reported under its host rather than its metahost.
    @param queue_entry_filter_data: Optional filter data applied to the job's
            host queue entries via models.HostQueueEntry.query_objects().

    @returns A dict with keys dependencies (label names), hosts, meta_hosts,
            meta_host_counts, one_time_hosts (hosts of entries whose host is
            invalid), atomic_group and hostless.
    """
    hosts = []
    one_time_hosts = []
    meta_hosts = []
    atomic_group = None
    hostless = False

    queue_entries = job.hostqueueentry_set.all()
    if queue_entry_filter_data:
        queue_entries = models.HostQueueEntry.query_objects(
            queue_entry_filter_data, initial_query=queue_entries)

    for queue_entry in queue_entries:
        if (queue_entry.host and (preserve_metahosts or
                                  not queue_entry.meta_host)):
            if queue_entry.deleted:
                # Deleted entries are skipped entirely, including the atomic
                # group consistency check below.
                continue
            if queue_entry.host.invalid:
                one_time_hosts.append(queue_entry.host)
            else:
                hosts.append(queue_entry.host)
        elif queue_entry.meta_host:
            meta_hosts.append(queue_entry.meta_host)
        else:
            hostless = True

        entry_group = queue_entry.atomic_group
        if atomic_group is None:
            if entry_group is not None:
                atomic_group = entry_group
        else:
            # An entry without an atomic group is as inconsistent as one in a
            # different group; report it through the assertion instead of
            # failing with an AttributeError on None.name.
            if entry_group is not None:
                entry_group_name = entry_group.name
            else:
                entry_group_name = None
            assert atomic_group.name == entry_group_name, (
                    'DB inconsistency.  HostQueueEntries with multiple atomic'
                    ' groups on job %s: %s != %s' % (
                        job.id, atomic_group.name, entry_group_name))

    meta_host_counts = _get_metahost_counts(meta_hosts)

    info = dict(dependencies=[label.name for label
                              in job.dependency_labels.all()],
                hosts=hosts,
                meta_hosts=meta_hosts,
                meta_host_counts=meta_host_counts,
                one_time_hosts=one_time_hosts,
                atomic_group=atomic_group,
                hostless=hostless)
    return info
553
554
def check_for_duplicate_hosts(host_objects):
    """
    Reject a host list that names the same host (by id) more than once.

    @param host_objects: A list of models.Host objects.
    @raises model_logic.ValidationError listing the duplicated hostnames.
    """
    seen_ids = set()
    duplicates = set()
    for host in host_objects:
        if host.id in seen_ids:
            duplicates.add(host.hostname)
        else:
            seen_ids.add(host.id)

    if not duplicates:
        return
    raise model_logic.ValidationError(
            {'hosts' : 'Duplicate hosts: %s'
             % ', '.join(duplicates)})
567
568
def create_new_job(owner, options, host_objects, metahost_objects,
                   atomic_group=None):
    """
    Validate a job request, then create and queue the job.

    @param owner: The owner of the new job.
    @param options: Dict of job options passed to models.Job.create().  The
            'dependencies' (label names), 'synch_count' and 'is_template'
            entries are read here, and 'dependencies' is replaced in place by
            the list of matching models.Label objects.
    @param host_objects: List of models.Host objects to run the job on.
    @param metahost_objects: List of models.Label objects used as metahosts.
    @param atomic_group: Optional atomic group to run the job in.

    @returns The id of the newly created job.
    @raises model_logic.ValidationError when the request is inconsistent.
    @raises error.NoEligibleHostException when a metahost has no host that
            satisfies the dependencies.
    """
    all_host_objects = host_objects + metahost_objects
    dependencies = options.get('dependencies', [])
    synch_count = options.get('synch_count')

    if atomic_group:
        check_atomic_group_create_job(
                synch_count, host_objects, metahost_objects,
                dependencies, atomic_group)
    else:
        if synch_count is not None and synch_count > len(all_host_objects):
            raise model_logic.ValidationError(
                    {'hosts':
                     'only %d hosts provided for job with synch_count = %d' %
                     (len(all_host_objects), synch_count)})
        # A host is an atomic group host when one of its labels belongs to an
        # atomic group.  atomic_group is a foreign key, so test it with
        # __isnull; comparing it to True compares the key value (i.e. only
        # the atomic group with id 1 would match).
        atomic_hosts = models.Host.objects.filter(
                id__in=[host.id for host in host_objects],
                labels__atomic_group__isnull=False)
        # The join over labels can yield a host once per matching label.
        unusable_host_names = sorted(set(host.hostname
                                         for host in atomic_hosts))
        if unusable_host_names:
            raise model_logic.ValidationError(
                    {'hosts':
                     'Host(s) "%s" are atomic group hosts but no '
                     'atomic group was specified for this job.' %
                     (', '.join(unusable_host_names),)})

    check_for_duplicate_hosts(host_objects)

    for label_name in dependencies:
        if provision.is_for_special_action(label_name):
            # TODO: We could save a few queries
            # if we had a bulk ensure-label-exists function, which used
            # a bulk .get() call. The win is probably very small.
            _ensure_label_exists(label_name)

    # This only checks targeted hosts, not hosts eligible due to the metahost
    check_job_dependencies(host_objects, dependencies)
    check_job_metahost_dependencies(metahost_objects, dependencies)

    options['dependencies'] = list(
            models.Label.objects.filter(name__in=dependencies))

    for label in metahost_objects + options['dependencies']:
        if label.atomic_group and not atomic_group:
            raise model_logic.ValidationError(
                    {'atomic_group_name':
                     'Dependency %r requires an atomic group but no '
                     'atomic_group_name or meta_host in an atomic group was '
                     'specified for this job.' % label.name})
        elif (label.atomic_group and
              label.atomic_group.name != atomic_group.name):
            raise model_logic.ValidationError(
                    {'atomic_group_name':
                     'meta_hosts or dependency %r requires atomic group '
                     '%r instead of the supplied atomic_group_name=%r.' %
                     (label.name, label.atomic_group.name, atomic_group.name)})

    job = models.Job.create(owner=owner, options=options,
                            hosts=all_host_objects)
    job.queue(all_host_objects, atomic_group=atomic_group,
              is_template=options.get('is_template', False))
    return job.id
632
633
def _ensure_label_exists(name):
    """
    Ensure that a label called |name| exists in the Django models.

    This function is to be called from within afe rpcs only, as an
    alternative to server.cros.provision.ensure_label_exists(...). It works
    by Django model manipulation, rather than by making another create_label
    rpc call.  It must run on the master only, never on a shard.

    @param name: the label to check for/create.
    @raises django.db.utils.IntegrityError: Creating the label failed for a
            reason other than the label already existing.
    @returns True if a label was created, False otherwise.
    """
    # Make sure this function is not called on shards but only on master.
    assert not server_utils.is_shard()
    try:
        models.Label.objects.get(name=name)
    except models.Label.DoesNotExist:
        try:
            # objects.create() already saves the new row; no save() needed.
            models.Label.objects.create(name=name)
            return True
        except django.db.utils.IntegrityError as e:
            # It is possible that another suite/test already
            # created the label between the check and save.
            if DUPLICATE_KEY_MSG in str(e):
                return False
            raise
    return False
665
666
def find_platform_and_atomic_group(host):
    """
    Figure out the platform name and atomic group name for the given host.

    @param host: A host object whose label_list holds its labels.
    @returns (platform name, atomic group name) for the given host; either
            is None when the host has no such label.
    @raises ValueError if the host has more than one platform label.
    """
    platforms = [label.name for label in host.label_list if label.platform]
    if len(platforms) > 1:
        raise ValueError('Host %s has more than one platform: %s' %
                         (host.hostname, ', '.join(platforms)))
    platform = platforms[0] if platforms else None

    # Several atomic groups on one host is an error, but it must not trip up
    # the RPC interface (monitor_db_cleanup deals with it), so just report
    # the first one found.
    atomic_group_name = next((label.atomic_group.name
                              for label in host.label_list
                              if label.atomic_group),
                             None)
    return platform, atomic_group_name
692
693
694# support for get_host_queue_entries_and_special_tasks()
695
def _common_entry_to_dict(entry, type, job_dict, exec_path, status, started_on):
    """
    Build the dict layout shared by host queue entries and special tasks.

    @param entry: The entry as a dict; its 'host' and 'id' values are used.
    @param type: The entry kind ('Job' for host queue entries, the task name
            for special tasks); it is also appended to the stringified id,
            presumably to keep ids distinct across kinds.
    @param job_dict: The job dictionary (or None).
    @param exec_path: The execution path of the entry.
    @param status: The status string of the entry.
    @param started_on: When the entry started.
    @returns The combined dictionary; 'oid' keeps the original id.
    """
    entry_id = entry['id']
    return {
        'type': type,
        'host': entry['host'],
        'job': job_dict,
        'execution_path': exec_path,
        'status': status,
        'started_on': started_on,
        'id': str(entry_id) + type,
        'oid': entry_id,
    }
705
706
def _special_task_to_dict(task, queue_entries):
    """Transforms a special task dictionary to another form of dictionary.

    When the task has a queue entry, the job's dictionary is taken from the
    matching entry in queue_entries; if there is no match (or the match has
    no job), the job is loaded from the database.

    @param task           Special task as a dictionary type
    @param queue_entries  Host queue entries as a list of dictionaries.

    @return Transformed dictionary for a special task.
    """
    job_dict = None
    task_entry = task['queue_entry']
    if task_entry:
        job_dict = next((qentry['job'] for qentry in queue_entries
                         if task_entry['id'] == qentry['id']),
                        None)
        if job_dict is None:
            db_job = models.Job.objects.get(id=task_entry['job'])
            job_dict = db_job.get_object_dict()

    exec_path = server_utils.get_special_task_exec_path(
            task['host']['hostname'],
            task['id'],
            task['task'],
            time_utils.time_string_to_datetime(task['time_requested']))
    status = server_utils.get_special_task_status(
            task['is_complete'], task['success'], task['is_active'])
    return _common_entry_to_dict(task, task['task'], job_dict, exec_path,
                                 status, task['time_started'])
734
735
def _queue_entry_to_dict(queue_entry):
    """Transforms a host queue entry dictionary to the common entry form.

    @param queue_entry  Host queue entry as a dictionary whose 'job' value is
                        the job's dictionary (with 'id' and 'owner').

    @return Transformed dictionary for the host queue entry.
    """
    job_dict = queue_entry['job']
    job_tag = server_utils.get_job_tag(job_dict['id'], job_dict['owner'])
    exec_path = server_utils.get_hqe_exec_path(
            job_tag, queue_entry['execution_subdir'])
    return _common_entry_to_dict(queue_entry, 'Job', job_dict, exec_path,
                                 queue_entry['status'],
                                 queue_entry['started_on'])
743
744
def prepare_host_queue_entries_and_special_tasks(interleaved_entries,
                                                 queue_entries):
    """
    Prepare for serialization the interleaved entries of host queue entries
    and special tasks.

    Each element of interleaved_entries is a dictionary.  Entries that have a
    'task' key are special tasks; all others are host queue entries.  A
    special task dictionary only carries a job id, so queue_entries is used to
    look up the job's details.

    @param interleaved_entries  Host queue entries and special tasks as a list
                                of dictionaries.
    @param queue_entries        Host queue entries as a list of dictionaries.

    @return A post-processed list of dictionaries that is to be serialized.
    """
    converted = [
            _special_task_to_dict(entry, queue_entries) if 'task' in entry
            else _queue_entry_to_dict(entry)
            for entry in interleaved_entries]
    return prepare_for_serialization(converted)
771
772
def _compute_next_job_for_tasks(queue_entries, special_tasks):
    """
    For each task, try to figure out the next job that ran after that task.

    The answer is written into each task dictionary in place, under the key
    'next_job_id' (None if no job has been determined yet); nothing is
    returned. This is done using three pieces of information:
    * if the task has a queue entry, we can use that entry's job ID.
    * if the task has a time_started, we can try to compare that against the
      started_on field of queue_entries. this isn't guaranteed to work perfectly
      since queue_entries may also have null started_on values.
    * if the task has neither, or if use of time_started fails, just use the
      last computed job ID.

    NOTE(review): queue_entries is scanned only once, front to back
    (hqe_index never decreases). This presumably relies on both lists being
    ordered by descending ID, as interleave_entries() requires of its inputs.
    Confirm this for any other caller.

    @param queue_entries    Host queue entries as a list of dictionaries.
    @param special_tasks    Special tasks as a list of dictionaries. Each
                            dictionary is mutated to gain a 'next_job_id' key.
    """
    # Carried over from one task to the next: a task whose next job cannot be
    # determined on its own inherits the value computed for the previous task.
    next_job_id = None # most recently computed next job
    hqe_index = 0 # index for scanning by started_on times
    for task in special_tasks:
        if task['queue_entry']:
            # NOTE(review): 'job' is used directly as an ID here, whereas
            # the entries in queue_entries carry a nested dict
            # (['job']['id']). Presumably a special task's queue entry is
            # serialized with a bare job foreign key; verify against the
            # serializer.
            next_job_id = task['queue_entry']['job']
        elif task['time_started'] is not None:
            # Walk forward from hqe_index. Remember the job of every entry
            # that started no earlier than this task, and stop at the first
            # entry that started before it. Entries with no start time are
            # ignored.
            for queue_entry in queue_entries[hqe_index:]:
                if queue_entry['started_on'] is None:
                    continue
                t1 = time_utils.time_string_to_datetime(
                        queue_entry['started_on'])
                t2 = time_utils.time_string_to_datetime(task['time_started'])
                if t1 < t2:
                    break
                next_job_id = queue_entry['job']['id']

        task['next_job_id'] = next_job_id

        # advance hqe_index to just after next_job_id
        # (skip every queue entry whose job ID is >= next_job_id, so the next
        # time-based scan starts past the job already assigned).
        if next_job_id is not None:
            for queue_entry in queue_entries[hqe_index:]:
                if queue_entry['job']['id'] < next_job_id:
                    break
                hqe_index += 1
811
812
def interleave_entries(queue_entries, special_tasks):
    """
    Merge host queue entries and special tasks into a single list.

    Both lists should be ordered by descending ID.

    Each special task is first annotated with a 'next_job_id' key (via
    _compute_next_job_for_tasks). The result starts with the tasks that have
    no next job. After that, each queue entry is followed by the tasks whose
    next_job_id is not lower than that entry's job id.

    @param queue_entries    Host queue entries as a list of dictionaries.
    @param special_tasks    Special tasks as a list of dictionaries; they are
                            mutated to gain a 'next_job_id' key.

    @returns The interleaved list of entries.
    """
    _compute_next_job_for_tasks(queue_entries, special_tasks)

    merged = []
    position = 0
    task_total = len(special_tasks)

    # Leading tasks with no next job ran after the most recent job.
    while (position < task_total and
           special_tasks[position]['next_job_id'] is None):
        merged.append(special_tasks[position])
        position += 1

    # Place each job, then the tasks that ran between it and the previous job.
    for queue_entry in queue_entries:
        merged.append(queue_entry)
        job_id = queue_entry['job']['id']
        while (position < task_total and
               not special_tasks[position]['next_job_id'] < job_id):
            merged.append(special_tasks[position])
            position += 1

    return merged
838
839
def bucket_hosts_by_shard(host_objs, rpc_hostnames=False):
    """Group the hostnames of the given hosts by the shard they belong to.

    Hosts that are not assigned to any shard are left out of the result.

    @param host_objs: A list of host objects.
    @param rpc_hostnames: If True, key the map by each shard's rpc_hostname()
        instead of its 'real' hostname. This only matters for testing
        environments.

    @return: A map of shard hostname: list of hostnames on that shard, in the
        order the hosts were given.
    """
    buckets = {}
    for host in host_objs:
        shard = host.shard
        if not shard:
            continue
        if rpc_hostnames:
            key = shard.rpc_hostname()
        else:
            key = shard.hostname
        if key not in buckets:
            buckets[key] = []
        buckets[key].append(host.hostname)
    return buckets
857
858
def get_create_job_common_args(local_args):
    """
    Returns a dict containing only the args that apply for create_job_common

    Keeps only the entries of local_args whose key names a parameter of
    create_job_common(), so the result can be passed on as keyword arguments.

    Note: local_args must contain a 'priority' key. If that value is a str,
    it is overwritten in place with priorities.Priority.DEFAULT.

    @param local_args: A dictionary of argument name -> value.

    @returns A new dictionary with the subset of local_args accepted by
             create_job_common().
    """
    # <migration hack> This code is only here to not kill suites scheduling
    # tests when priority becomes an int instead of a string.
    if isinstance(local_args['priority'], str):
        local_args['priority'] = priorities.Priority.DEFAULT
    # </migration hack>
    accepted = inspect.getargspec(create_job_common).args
    return {key: value for key, value in local_args.iteritems()
            if key in accepted}
873
874
def create_job_common(name, priority, control_type, control_file=None,
                      hosts=(), meta_hosts=(), one_time_hosts=(),
                      atomic_group_name=None, synch_count=None,
                      is_template=False, timeout=None, timeout_mins=None,
                      max_runtime_mins=None, run_verify=True, email_list='',
                      dependencies=(), reboot_before=None, reboot_after=None,
                      parse_failed_repair=None, hostless=False, keyvals=None,
                      drone_set=None, parameterized_job=None,
                      parent_job_id=None, test_retry=0, run_reset=True,
                      require_ssp=None):
    #pylint: disable-msg=C0111
    """
    Common code between creating "standard" jobs and creating parameterized jobs

    Validates the requested host specification, resolves hostnames, metahost
    labels and atomic groups into model objects, and creates any one-time
    hosts. Everything is then passed to create_new_job(), with the current
    user as owner.

    @param hosts: Hosts to run on. None is treated as empty.
    @param meta_hosts: Label names to schedule against. A name that is not a
            label but matches an atomic group selects that atomic group.
            None is treated as empty.
    @param one_time_hosts: Hostnames to create as one-time hosts. None is
            treated as empty.
    @param atomic_group_name: Name of the atomic group to schedule on. If not
            given, it is inferred from a metahost label that belongs to an
            atomic group.
    @param hostless: True for a job that runs on no hosts. Such a job must
            not name any hosts and must use a server-side control file.
    The remaining parameters are passed through unchanged as job options.

    @raises model_logic.ValidationError: for a missing or contradictory host
            specification, an unknown metahost label, conflicting atomic
            groups, one-time hosts combined with an atomic group, or a
            synch_count above the atomic group's maximum.
    @raises ValueError: on the master, if the hosts span more than one shard,
            or span one shard and the master.

    @returns: The result of create_new_job().
    """
    user = models.User.current_user()
    owner = user.login

    # Callers may pass None for the host arguments. Normalize them up front:
    # the label query and smart_get_bulk() below need an (empty) sequence.
    # Previously only the later loops guarded against None.
    hosts = hosts or ()
    meta_hosts = meta_hosts or ()
    one_time_hosts = one_time_hosts or ()

    # input validation
    if not (hosts or meta_hosts or one_time_hosts or atomic_group_name
            or hostless):
        raise model_logic.ValidationError({
            'arguments' : "You must pass at least one of 'hosts', "
                          "'meta_hosts', 'one_time_hosts', "
                          "'atomic_group_name', or 'hostless'"
            })

    if hostless:
        if hosts or meta_hosts or one_time_hosts or atomic_group_name:
            raise model_logic.ValidationError({
                    'hostless': 'Hostless jobs cannot include any hosts!'})
        server_type = control_data.CONTROL_TYPE_NAMES.SERVER
        if control_type != server_type:
            raise model_logic.ValidationError({
                    'control_type': 'Hostless jobs cannot use client-side '
                                    'control files'})

    atomic_groups_by_name = dict((ag.name, ag)
                                 for ag in models.AtomicGroup.objects.all())
    label_objects = list(models.Label.objects.filter(name__in=meta_hosts))

    # Schedule on an atomic group automagically if one of the labels given
    # is an atomic group label and no explicit atomic_group_name was supplied.
    if not atomic_group_name:
        for label in label_objects:
            if label and label.atomic_group:
                atomic_group_name = label.atomic_group.name
                break
    # convert hostnames & meta hosts to host/label objects
    host_objects = models.Host.smart_get_bulk(hosts)
    if not server_utils.is_shard():
        shard_host_map = bucket_hosts_by_shard(host_objects)
        num_shards = len(shard_host_map)
        if (num_shards > 1 or (num_shards == 1 and
                len(shard_host_map.values()[0]) != len(host_objects))):
            # We disallow the following jobs on master:
            #   num_shards > 1: this is a job spanning across multiple shards.
            #   num_shards == 1 but number of hosts on shard is less
            #   than total number of hosts: this is a job that spans across
            #   one shard and the master.
            raise ValueError(
                    'The following hosts are on shard(s), please create '
                    'seperate jobs for hosts on each shard: %s ' %
                    shard_host_map)
    metahost_objects = []
    meta_host_labels_by_name = {label.name: label for label in label_objects}
    for label_name in meta_hosts:
        if label_name in meta_host_labels_by_name:
            metahost_objects.append(meta_host_labels_by_name[label_name])
        elif label_name in atomic_groups_by_name:
            # If given a metahost name that isn't a Label, check to
            # see if the user was specifying an Atomic Group instead.
            atomic_group = atomic_groups_by_name[label_name]
            if atomic_group_name and atomic_group_name != atomic_group.name:
                raise model_logic.ValidationError({
                        'meta_hosts': (
                                'Label "%s" not found.  If assumed to be an '
                                'atomic group it would conflict with the '
                                'supplied atomic group "%s".' % (
                                        label_name, atomic_group_name))})
            atomic_group_name = atomic_group.name
        else:
            raise model_logic.ValidationError(
                {'meta_hosts' : 'Label "%s" not found' % label_name})

    # Create and sanity check an AtomicGroup object if requested.
    if atomic_group_name:
        if one_time_hosts:
            raise model_logic.ValidationError(
                    {'one_time_hosts':
                     'One time hosts cannot be used with an Atomic Group.'})
        atomic_group = models.AtomicGroup.smart_get(atomic_group_name)
        if synch_count and synch_count > atomic_group.max_number_of_machines:
            raise model_logic.ValidationError(
                    {'atomic_group_name' :
                     'You have requested a synch_count (%d) greater than the '
                     'maximum machines in the requested Atomic Group (%d).' %
                     (synch_count, atomic_group.max_number_of_machines)})
    else:
        atomic_group = None

    # One-time hosts are created only after all validation has passed.
    for host in one_time_hosts:
        this_host = models.Host.create_one_time_host(host)
        host_objects.append(this_host)

    options = dict(name=name,
                   priority=priority,
                   control_file=control_file,
                   control_type=control_type,
                   is_template=is_template,
                   timeout=timeout,
                   timeout_mins=timeout_mins,
                   max_runtime_mins=max_runtime_mins,
                   synch_count=synch_count,
                   run_verify=run_verify,
                   email_list=email_list,
                   dependencies=dependencies,
                   reboot_before=reboot_before,
                   reboot_after=reboot_after,
                   parse_failed_repair=parse_failed_repair,
                   keyvals=keyvals,
                   drone_set=drone_set,
                   parameterized_job=parameterized_job,
                   parent_job_id=parent_job_id,
                   test_retry=test_retry,
                   run_reset=run_reset,
                   require_ssp=require_ssp)
    return create_new_job(owner=owner,
                          options=options,
                          host_objects=host_objects,
                          metahost_objects=metahost_objects,
                          atomic_group=atomic_group)
1006
1007
def encode_ascii(control_file):
    """Force a control file to only contain ascii characters.

    @param control_file: Control file to encode.

    @returns the control file in an ascii encoding.

    @raises error.ControlFileMalformed: if the control file contains any
            non-ascii character.
    """
    try:
        return control_file.encode('ascii')
    except UnicodeError as e:
        # Encoding a unicode control file that has non-ascii characters raises
        # UnicodeEncodeError. A Python 2 byte string fails during the implicit
        # decode with UnicodeDecodeError. UnicodeError is the base of both.
        raise error.ControlFileMalformed(str(e))
1021
1022
def get_wmatrix_url():
    """Read the wmatrix url from the AUTOTEST_WEB section of the config.

    @returns the configured wmatrix url, or '' when it is not configured.
    """
    config = global_config.global_config
    return config.get_config_value('AUTOTEST_WEB', 'wmatrix_url', default='')
1031
1032
def inject_times_to_filter(start_time_key=None, end_time_key=None,
                         start_time_value=None, end_time_value=None,
                         **filter_data):
    """Add start/end time constraints to a filter dictionary.

    A time value is only added when it is truthy. Falsy values (None, '')
    leave the filter without that key.

    @param start_time_key: A string represents the filter key of start_time.
    @param end_time_key: A string represents the filter key of end_time.
    @param start_time_value: Start_time value.
    @param end_time_value: End_time value.
    @param filter_data: The remaining filter key/value pairs.

    @returns the injected filter_data.
    """
    bounds = ((start_time_key, start_time_value),
              (end_time_key, end_time_value))
    for key, value in bounds:
        if value:
            filter_data[key] = value
    return filter_data
1050
1051
def inject_times_to_hqe_special_tasks_filters(filter_data_common,
                                              start_time, end_time):
    """Inject start and end time to hqe and special tasks filters.

    Host queue entries are filtered on 'started_on' and special tasks on
    'time_started'. A falsy start_time/end_time adds no bound.

    @param filter_data_common: Common filter for hqe and special tasks. It is
                               not modified.
    @param start_time: Lower bound on the start time, or a falsy value.
    @param end_time: Upper bound on the start time, or a falsy value.

    @returns a pair of hqe and special tasks filters.
    """
    # Unpacking with ** gives each call its own fresh dict. So the caller's
    # dict is never modified, and the two filters never share state; no
    # explicit copy is needed.
    hqe_filter = inject_times_to_filter('started_on__gte', 'started_on__lte',
                                        start_time, end_time,
                                        **filter_data_common)
    special_task_filter = inject_times_to_filter(
            'time_started__gte', 'time_started__lte',
            start_time, end_time, **filter_data_common)
    return hqe_filter, special_task_filter
1068
1069
def retrieve_shard(shard_hostname):
    """
    Look up a shard by hostname in the database, timing the lookup under the
    'shard_heartbeat.retrieve_shard' stat.

    @param shard_hostname: Hostname of the shard to retrieve

    @raises models.Shard.DoesNotExist, if no shard with this hostname was found.

    @returns: Shard object
    """
    with autotest_stats.Timer('shard_heartbeat.retrieve_shard'):
        shard = models.Shard.smart_get(shard_hostname)
    return shard
1083
1084
def find_records_for_shard(shard, known_job_ids, known_host_ids):
    """Find records that should be sent to a shard.

    Each of the three lookups is timed separately under the
    'shard_heartbeat' stat.

    @param shard: Shard to find records for.
    @param known_job_ids: List of ids of jobs the shard already has.
    @param known_host_ids: List of ids of hosts the shard already has.

    @returns: Tuple of three lists for hosts, jobs, and suite job keyvals:
              (hosts, jobs, suite_job_keyvals).
    """
    heartbeat_timer = autotest_stats.Timer('shard_heartbeat')
    with heartbeat_timer.get_client('find_hosts'):
        hosts = models.Host.assign_to_shard(shard, known_host_ids)
    with heartbeat_timer.get_client('find_jobs'):
        jobs = models.Job.assign_to_shard(shard, known_job_ids)
    with heartbeat_timer.get_client('find_suite_job_keyvals'):
        # Keyvals of the parent (suite) jobs of the jobs being sent.
        suite_job_keyvals = models.JobKeyval.objects.filter(
                job_id__in=[assigned.parent_job_id for assigned in jobs])
    return hosts, jobs, suite_job_keyvals
1105
1106
1107def _persist_records_with_type_sent_from_shard(
1108    shard, records, record_type, *args, **kwargs):
1109    """
1110    Handle records of a specified type that were sent to the shard master.
1111
1112    @param shard: The shard the records were sent from.
1113    @param records: The records sent in their serialized format.
1114    @param record_type: Type of the objects represented by records.
1115    @param args: Additional arguments that will be passed on to the sanity
1116                 checks.
1117    @param kwargs: Additional arguments that will be passed on to the sanity
1118                  checks.
1119
1120    @raises error.UnallowedRecordsSentToMaster if any of the sanity checks fail.
1121
1122    @returns: List of primary keys of the processed records.
1123    """
1124    pks = []
1125    for serialized_record in records:
1126        pk = serialized_record['id']
1127        try:
1128            current_record = record_type.objects.get(pk=pk)
1129        except record_type.DoesNotExist:
1130            raise error.UnallowedRecordsSentToMaster(
1131                'Object with pk %s of type %s does not exist on master.' % (
1132                    pk, record_type))
1133
1134        current_record.sanity_check_update_from_shard(
1135            shard, serialized_record, *args, **kwargs)
1136
1137        current_record.update_from_serialized(serialized_record)
1138        pks.append(pk)
1139    return pks
1140
1141
def persist_records_sent_from_shard(shard, jobs, hqes):
    """
    Sanity check, then save, serialized records sent to master by a shard.

    During heartbeats shards upload jobs and hostqueueentries. This performs
    some sanity checks on these and then updates the existing records for those
    entries with the updated ones from the heartbeat.

    The sanity checks include:
    - Checking if the objects sent already exist on the master.
    - Checking if the objects sent were assigned to this shard.
    - hostqueueentries must be sent together with their jobs.

    Jobs are processed first. Their ids are then handed to the
    hostqueueentry checks as job_ids_sent.

    @param shard: The shard the records were sent from.
    @param jobs: The jobs the shard sent.
    @param hqes: The hostqueueentries the shard sent.

    @raises error.UnallowedRecordsSentToMaster if any of the sanity checks fail.
    """
    heartbeat_timer = autotest_stats.Timer('shard_heartbeat')
    with heartbeat_timer.get_client('persist_jobs'):
        persisted_job_ids = _persist_records_with_type_sent_from_shard(
                shard, jobs, models.Job)

    with heartbeat_timer.get_client('persist_hqes'):
        _persist_records_with_type_sent_from_shard(
                shard, hqes, models.HostQueueEntry,
                job_ids_sent=persisted_job_ids)
1169
1170
def forward_single_host_rpc_to_shard(func):
    """This decorator forwards rpc calls that modify a host to a shard.

    If a host is assigned to a shard, rpcs that change its attributes should
    be forwarded to the shard. The rpc is always run locally first. On the
    master it is then re-run on the host's shard under the same rpc name.

    This assumes the host is identified by the keyword argument 'id'.

    @param func: The function to decorate

    @returns: The function to replace func with. It keeps func's name and
              docstring.
    """
    # @wraps keeps func's __name__/__doc__ on the replacement, matching
    # route_rpc_to_master(). This lets code that inspects the decorated rpc
    # still see the real rpc.
    @wraps(func)
    def replacement(**kwargs):
        # Only keyword arguments can be accepted here, as we need the argument
        # names to send the rpc. serviceHandler always provides arguments with
        # their keywords, so this is not a problem.

        # A host record (identified by kwargs['id']) can be deleted in
        # func(). Therefore, we should save the data that can be needed later
        # before func() is called.
        shard_hostname = None
        host = models.Host.smart_get(kwargs['id'])
        if host and host.shard:
            shard_hostname = host.shard.rpc_hostname()
        ret = func(**kwargs)
        if shard_hostname and not server_utils.is_shard():
            run_rpc_on_multiple_hostnames(func.__name__,
                                          [shard_hostname],
                                          **kwargs)
        return ret

    return replacement
1203
1204
1205def fanout_rpc(host_objs, rpc_name, include_hostnames=True, **kwargs):
1206    """Fanout the given rpc to shards of given hosts.
1207
1208    @param host_objs: Host objects for the rpc.
1209    @param rpc_name: The name of the rpc.
1210    @param include_hostnames: If True, include the hostnames in the kwargs.
1211        Hostnames are not always necessary, this functions is designed to
1212        send rpcs to the shard a host is on, the rpcs themselves could be
1213        related to labels, acls etc.
1214    @param kwargs: The kwargs for the rpc.
1215    """
1216    # Figure out which hosts are on which shards.
1217    shard_host_map = bucket_hosts_by_shard(
1218            host_objs, rpc_hostnames=True)
1219
1220    # Execute the rpc against the appropriate shards.
1221    for shard, hostnames in shard_host_map.iteritems():
1222        if include_hostnames:
1223            kwargs['hosts'] = hostnames
1224        try:
1225            run_rpc_on_multiple_hostnames(rpc_name, [shard], **kwargs)
1226        except:
1227            ei = sys.exc_info()
1228            new_exc = error.RPCException('RPC %s failed on shard %s due to '
1229                    '%s: %s' % (rpc_name, shard, ei[0].__name__, ei[1]))
1230            raise new_exc.__class__, new_exc, ei[2]
1231
1232
def run_rpc_on_multiple_hostnames(rpc_call, shard_hostnames, **kwargs):
    """Run the same rpc against each of several AFEs, one after another.

    This is e.g. used to propagate changes made to hosts after they are
    assigned to a shard. Each call goes through a RetryingAFE acting as the
    current thread's user.

    @param rpc_call: Name of the rpc endpoint to call.
    @param shard_hostnames: List of hostnames to run the rpcs on.
    @param **kwargs: Keyword arguments to pass in the rpcs.
    """
    # Only the master may fan rpcs out to other AFEs; never call on a shard.
    assert not server_utils.is_shard()
    for hostname in shard_hostnames:
        shard_afe = frontend_wrappers.RetryingAFE(
                server=hostname, user=thread_local.get_user())
        shard_afe.run(rpc_call, **kwargs)
1249
1250
def get_label(name):
    """Gets a label object using a given name.

    @param name: Label name.

    @return: a label object matching the given name, or None when there is
             no label matching the given name (models.Label.DoesNotExist is
             caught, not raised).
    """
    try:
        return models.Label.smart_get(name)
    except models.Label.DoesNotExist:
        return None
1264
1265
def route_rpc_to_master(func):
    """Route RPC to master AFE.

    When a shard receives an RPC decorated by this, the RPC is just
    forwarded to the master. When the master gets the RPC, the RPC function
    is executed.

    @param func: An RPC function to decorate. It must not declare *args.

    @returns: A function replacing the RPC func, carrying func's name and
              docstring.
    """
    @wraps(func)
    def replacement(*args, **kwargs):
        # Some RPCs are called directly with positional arguments; for example,
        # rpc_interface.create_job_page_handler() calls
        # rpc_interface.create_job() with positional and keyword arguments.
        # frontend.RpcClient.run() only takes keyword arguments, so all
        # arguments are first mapped to names with inspect.getcallargs().
        #
        # For a function declared as
        #
        #     def rpcfunc(a, b, **kwargs)
        #
        # the call rpcfunc(1, 2, id=3, name='mk') makes getcallargs() return
        #
        #     {'a':1, 'b':2, 'kwargs': {'id':3, 'name':'mk'}}
        #
        # so the **kwargs bucket is flattened back into the top level:
        #
        #     {'a':1, 'b':2, 'id':3, 'name':'mk'}
        spec = inspect.getargspec(func)
        if spec.varargs is not None:
            raise Exception('RPC function must not have *args.')
        call_kwargs = inspect.getcallargs(func, *args, **kwargs)
        if spec.keywords:
            call_kwargs.update(call_kwargs.pop(spec.keywords))

        if not server_utils.is_shard():
            return func(**call_kwargs)
        master_afe = frontend_wrappers.RetryingAFE(
                server=server_utils.get_global_afe_hostname(),
                user=thread_local.get_user())
        return master_afe.run(func.func_name, **call_kwargs)
    return replacement
1324