1# pylint: disable=missing-docstring
2"""
3Utility functions for rpc_interface.py.  We keep them in a separate file so that
4only RPC interface functions go into that file.
5"""
6
7__author__ = 'showard@google.com (Steve Howard)'
8
9import collections
10import datetime
11from functools import wraps
12import inspect
13import os
14import sys
15import django.db.utils
16import django.http
17
18from autotest_lib.frontend import thread_local
19from autotest_lib.frontend.afe import models, model_logic
20from autotest_lib.client.common_lib import control_data, error
21from autotest_lib.client.common_lib import global_config
22from autotest_lib.client.common_lib import time_utils
23from autotest_lib.client.common_lib.cros import dev_server
24# TODO(akeshet): Replace with monarch once we know how to instrument rpc server
25# with ts_mon.
26from autotest_lib.client.common_lib.cros.graphite import autotest_stats
27from autotest_lib.server import utils as server_utils
28from autotest_lib.server.cros import provision
29from autotest_lib.server.cros.dynamic_suite import frontend_wrappers
30
# Sentinel datetime/date values: _prepare_data() compares against these exact
# objects (identity check) and serializes them as None.
NULL_DATETIME = datetime.datetime.max
NULL_DATE = datetime.date.max
# Substring searched for in a database IntegrityError message to recognise a
# duplicate-key insert (used by _ensure_label_exists).
DUPLICATE_KEY_MSG = 'Duplicate entry'
34
def prepare_for_serialization(objects):
    """
    Prepare Python objects to be returned via RPC.

    A non-empty list whose first element is a dict containing an 'id' key is
    treated as a list of object dicts and de-duplicated by id first; the
    result is then passed through _prepare_data().

    @param objects: objects to be prepared.
    @returns The serialization-ready form of objects.
    """
    is_object_dict_list = (isinstance(objects, list) and bool(objects) and
                           isinstance(objects[0], dict) and
                           'id' in objects[0])
    if is_object_dict_list:
        objects = _gather_unique_dicts(objects)
    return _prepare_data(objects)
44
45
def prepare_rows_as_nested_dicts(query, nested_dict_column_names):
    """
    Prepare a Django query to be returned via RPC as a sequence of nested
    dictionaries.

    @param query - A Django model query object with a select_related() method.
    @param nested_dict_column_names - A list of column/attribute names for the
            rows returned by query to expand into nested dictionaries using
            their get_object_dict() method when not None.

    @returns A list suitable to be returned in an RPC.
    """
    row_dicts = []
    for record in query.select_related():
        record_dict = record.get_object_dict()
        for attribute in nested_dict_column_names:
            # Leave empty (None) relations as they are.
            if record_dict[attribute] is None:
                continue
            record_dict[attribute] = getattr(
                    record, attribute).get_object_dict()
        row_dicts.append(record_dict)
    return prepare_for_serialization(row_dicts)
66
67
68def _prepare_data(data):
69    """
70    Recursively process data structures, performing necessary type
71    conversions to values in data to allow for RPC serialization:
72    -convert datetimes to strings
73    -convert tuples and sets to lists
74    """
75    if isinstance(data, dict):
76        new_data = {}
77        for key, value in data.iteritems():
78            new_data[key] = _prepare_data(value)
79        return new_data
80    elif (isinstance(data, list) or isinstance(data, tuple) or
81          isinstance(data, set)):
82        return [_prepare_data(item) for item in data]
83    elif isinstance(data, datetime.date):
84        if data is NULL_DATETIME or data is NULL_DATE:
85            return None
86        return str(data)
87    else:
88        return data
89
90
def fetchall_as_list_of_dicts(cursor):
    """
    Convert every row in a cursor into a dict keyed by column name.

    @param cursor: The database cursor to read from.
    @returns: A list with one dictionary per row, in fetch order.
    """
    description = cursor.description
    column_names = None
    result = []
    for values in cursor.fetchall():
        # Column names are only derived once a row actually exists.
        if column_names is None:
            column_names = [column[0] for column in description]
        result.append(dict(zip(column_names, values)))
    return result
101
102
def raw_http_response(response_data, content_type=None):
    """
    Wrap raw data in a Django HttpResponse with an explicit Content-length.

    @param response_data: The response body.
    @param content_type: MIME type of the response; None keeps Django's
            default content type.
    @returns A django.http.HttpResponse.
    """
    # HttpResponse's 'mimetype' keyword was deprecated in Django 1.5 and
    # removed in 1.7; 'content_type' has the same meaning and is supported
    # by both old and new versions.
    response = django.http.HttpResponse(response_data,
                                        content_type=content_type)
    response['Content-length'] = str(len(response.content))
    return response
107
108
109def _gather_unique_dicts(dict_iterable):
110    """\
111    Pick out unique objects (by ID) from an iterable of object dicts.
112    """
113    objects = collections.OrderedDict()
114    for obj in dict_iterable:
115        objects.setdefault(obj['id'], obj)
116    return objects.values()
117
118
def extra_job_status_filters(not_yet_run=False, running=False, finished=False):
    """
    Build a SQL WHERE clause that filters jobs by the state of their HQEs.

    Every requested state contributes one sub-clause, and the sub-clauses
    are OR-ed together:
    * not_yet_run: all HQEs are Queued
    * finished: all HQEs are complete
    * running: everything else

    @returns {} if no state was requested, otherwise {'where': [clause]},
            a dict of keyword args to pass to query.extra().
    """
    if not any((not_yet_run, running, finished)):
        return {}

    not_queued = ('(SELECT job_id FROM afe_host_queue_entries '
                  'WHERE status != "%s")' % models.HostQueueEntry.Status.QUEUED)
    not_finished = ('(SELECT job_id FROM afe_host_queue_entries '
                    'WHERE not complete)')

    clauses = []
    if not_yet_run:
        clauses.append('id NOT IN %s' % not_queued)
    if running:
        clauses.append('(id IN %s) AND (id IN %s)' % (not_queued, not_finished))
    if finished:
        clauses.append('id NOT IN %s' % not_finished)
    combined = ' OR '.join('(' + clause + ')' for clause in clauses)
    return {'where': [combined]}
143
144
def extra_job_type_filters(extra_args, suite=False,
                           sub=False, standalone=False):
    """
    Add a SQL WHERE clause that filters jobs by type (suite / sub-job /
    standalone) to a dict of keyword args for query.extra().

    @param extra_args: a dict of existing extra_args. It is modified in place
            (its 'where' list is extended) when a filter is requested.
    @param suite: select jobs which are the parent of other jobs.
    @param sub: select jobs that have a parent job.
    @param standalone: select jobs with no child or parent jobs.

    No more than one of suite, sub and standalone may be True.

    @returns extra_args (the same object), unchanged if no filter was given.
    @raises AssertionError: if more than one filter is requested.
    """
    requested = sum(1 for flag in (suite, sub, standalone) if flag)
    assert requested <= 1, ('Cannot specify more than one '
                            'filter to this function')

    subquery_template = ('(SELECT %s FROM afe_jobs '
                         'WHERE parent_job_id IS NOT NULL)')
    if suite:
        clause = 'id IN ' + subquery_template % 'DISTINCT parent_job_id'
    elif sub:
        clause = 'id IN ' + subquery_template % 'id'
    elif standalone:
        clause = ('NOT EXISTS (SELECT 1 from afe_jobs AS sub_query '
                  'WHERE parent_job_id IS NOT NULL'
                  ' AND (sub_query.parent_job_id=afe_jobs.id'
                  ' OR sub_query.id=afe_jobs.id))')
    else:
        return extra_args

    clauses = extra_args.get('where', [])
    clauses.append(clause)
    extra_args['where'] = clauses
    return extra_args
183
184
185
def extra_host_filters(multiple_labels=()):
    """
    Generate SQL WHERE clauses for matching hosts in an intersection of
    labels.

    One clause (with one %s parameter) is produced per label.

    @param multiple_labels: Sized sequence of label identifiers accepted by
            models.Label.smart_get().
    @returns A dict with 'where' (list of SQL clauses) and 'params' (the
            corresponding label ids), for use with query.extra().
    """
    label_clause = ('afe_hosts.id in (select host_id from afe_hosts_labels '
                    'where label_id=%s)')
    clauses = [label_clause] * len(multiple_labels)
    label_ids = [models.Label.smart_get(label).id for label in multiple_labels]
    return {'where': clauses, 'params': label_ids}
198
199
def get_host_query(multiple_labels, exclude_only_if_needed_labels,
                   valid_only, filter_data):
    """
    Build a Host query restricted by labels and caller-supplied filters.

    @param multiple_labels: Label identifiers; matching hosts must carry all
            of them (see extra_host_filters).
    @param exclude_only_if_needed_labels: If True, exclude hosts carrying any
            valid label flagged only_if_needed.
    @param valid_only: If True, start from models.Host.valid_objects,
            otherwise from models.Host.objects.
    @param filter_data: Dict of filter arguments for
            models.Host.query_objects(). It must not already contain
            'extra_args'. The caller's dict is not modified.

    @returns A Host query, or an empty query if a label in multiple_labels
            does not exist.
    """
    if valid_only:
        query = models.Host.valid_objects.all()
    else:
        query = models.Host.objects.all()

    if exclude_only_if_needed_labels:
        # One query fetches the ids; an empty list means nothing to exclude.
        only_if_needed_ids = [
                str(label['id']) for label in
                models.Label.valid_objects.filter(
                        only_if_needed=True).values('id')]
        if only_if_needed_ids:
            query = models.Host.objects.add_join(
                query, 'afe_hosts_labels', join_key='host_id',
                join_condition=('afe_hosts_labels_exclude_OIN.label_id IN (%s)'
                                % ','.join(only_if_needed_ids)),
                suffix='_exclude_OIN', exclude=True)
    try:
        assert 'extra_args' not in filter_data
        # Work on a copy so that the caller's dict is left untouched.
        filter_data = dict(filter_data)
        filter_data['extra_args'] = extra_host_filters(multiple_labels)
        return models.Host.query_objects(filter_data, initial_query=query)
    except models.Label.DoesNotExist:
        return models.Host.objects.none()
225
226
class InconsistencyException(Exception):
    """Raised when a list of objects does not have a consistent value.

    get_consistent_value() raises it with two args: the first object of the
    list and the first object whose value differs from it.
    """
229
230
def get_consistent_value(objects, field):
    """
    Return the value of attribute |field| shared by all objects.

    @param objects: Indexable sequence of objects (may be empty or None).
    @param field: Name of the attribute to compare.
    @returns The common value, or None when objects is empty.
    @raises InconsistencyException: with args (objects[0], first object whose
            value differs) if the values are not all equal.
    """
    if not objects:
        # An empty list is trivially consistent.
        return None

    reference = objects[0]
    expected = getattr(reference, field)
    for candidate in objects:
        if getattr(candidate, field) != expected:
            raise InconsistencyException(reference, candidate)
    return expected
242
243
def afe_test_dict_to_test_object(test_dict):
    """
    Turn a test dict into a class whose attributes are the dict's entries.

    Values that int() accepts are stored as ints; other values are kept
    unchanged. Non-dict input is returned as is.

    @param test_dict: A dict describing a test, or any other object.
    @returns A new class named 'TestObject' (not an instance), or test_dict
            itself when it is not a dict.
    """
    if not isinstance(test_dict, dict):
        return test_dict

    def _as_int_if_possible(value):
        try:
            return int(value)
        except (ValueError, TypeError):
            return value

    attributes = {key: _as_int_if_possible(value)
                  for key, value in test_dict.items()}
    return type('TestObject', (object,), attributes)
256
257
258def _check_is_server_test(test_type):
259    """Checks if the test type is a server test.
260
261    @param test_type The test type in enum integer or string.
262
263    @returns A boolean to identify if the test type is server test.
264    """
265    if test_type is not None:
266        if isinstance(test_type, basestring):
267            try:
268                test_type = control_data.CONTROL_TYPE.get_value(test_type)
269            except AttributeError:
270                return False
271        return (test_type == control_data.CONTROL_TYPE.SERVER)
272    return False
273
274
def prepare_generate_control_file(tests, profilers, db_tests=True):
    """
    Gather what is needed to generate a control file for a set of tests.

    @param tests: Test identifiers for models.Test.smart_get() when db_tests
            is True; otherwise test dicts converted with
            afe_test_dict_to_test_object().
    @param profilers: Profiler identifiers for models.Profiler.smart_get().
    @param db_tests: Whether tests refer to Test rows in the database.

    @returns A tuple (cf_info, test_objects, profiler_objects), where cf_info
            is a dict with keys 'is_server', 'synch_count' and
            'dependencies' (a list of label names).
    @raises model_logic.ValidationError: if the tests do not all share the
            same test_type.
    """
    if db_tests:
        test_objects = [models.Test.smart_get(test) for test in tests]
    else:
        test_objects = [afe_test_dict_to_test_object(test) for test in tests]

    profiler_objects = [models.Profiler.smart_get(profiler)
                        for profiler in profilers]
    # ensure tests are all the same type
    try:
        test_type = get_consistent_value(test_objects, 'test_type')
    except InconsistencyException as exc:
        test1, test2 = exc.args
        raise model_logic.ValidationError(
            {'tests' : 'You cannot run both test_suites and server-side '
             'tests together (tests %s and %s differ)' % (
            test1.name, test2.name)})

    is_server = _check_is_server_test(test_type)
    # With no tests, fall back to a synch_count of 1.
    if test_objects:
        synch_count = max(test.sync_count for test in test_objects)
    else:
        synch_count = 1

    if db_tests:
        dependencies = set(label.name for label
                           in models.Label.objects.filter(test__in=test_objects))
    else:
        # set().union(*...) yields an empty set for an empty test list,
        # whereas reduce() without an initial value raises TypeError.
        dependencies = set().union(
                *[set(test.dependencies) for test in test_objects])

    cf_info = dict(is_server=is_server, synch_count=synch_count,
                   dependencies=list(dependencies))
    return cf_info, test_objects, profiler_objects
309
310
def check_job_dependencies(host_objects, job_dependencies):
    """
    Check that a set of machines satisfies a job's dependencies.

    Dependencies for which provision.is_for_special_action() is true are not
    required to be present on the hosts.

    @param host_objects: list of models.Host objects
    @param job_dependencies: list of names of labels
    @raises model_logic.ValidationError: naming the hosts that lack at least
            one required label.
    """
    eligible = models.Host.objects.filter(
            id__in=[host.id for host in host_objects])
    for label_name in job_dependencies:
        if provision.is_for_special_action(label_name):
            continue
        eligible = eligible.filter(labels__name=label_name)

    requested_names = set(host.hostname for host in host_objects)
    eligible_names = set(host.hostname for host in eligible)
    failing_hosts = requested_names - eligible_names
    if not failing_hosts:
        return
    raise model_logic.ValidationError(
        {'hosts' : 'Host(s) failed to meet job dependencies (' +
                   ', '.join(job_dependencies) + '): ' +
                   ', '.join(failing_hosts)})
331
332
def check_job_metahost_dependencies(metahost_objects, job_dependencies):
    """
    Check that at least one machine within the metahost spec satisfies the job's
    dependencies.

    Dependencies for which provision.is_for_special_action() is true are not
    required to be present on the hosts.

    @param metahost_objects A list of label objects representing the metahosts.
    @param job_dependencies A list of strings of the required label names.
    @raises NoEligibleHostException If a metahost cannot run the job.
    """
    for metahost in metahost_objects:
        hosts = models.Host.objects.filter(labels=metahost)
        for label_name in job_dependencies:
            if not provision.is_for_special_action(label_name):
                hosts = hosts.filter(labels__name=label_name)
        # exists() asks the database for a single row instead of fetching
        # every matching Host object, as any(hosts) did.
        if not hosts.exists():
            raise error.NoEligibleHostException("No hosts within %s satisfy %s."
                    % (metahost.name, ', '.join(job_dependencies)))
350
351
352def _execution_key_for(host_queue_entry):
353    return (host_queue_entry.job.id, host_queue_entry.execution_subdir)
354
355
def check_abort_synchronous_jobs(host_queue_entries):
    """
    Ensure the user isn't aborting part of a synchronous autoserv execution.

    @param host_queue_entries: Entries being aborted; iterated twice.
    @raises model_logic.ValidationError: if, for an entry that has an
            execution_subdir, fewer entries of the same (job,
            execution_subdir) are being aborted than the job's synch_count.
    """
    entries_per_execution = collections.Counter(
            _execution_key_for(entry) for entry in host_queue_entries)

    for entry in host_queue_entries:
        # Entries without an execution_subdir are not part of an execution.
        if not entry.execution_subdir:
            continue
        included = entries_per_execution[_execution_key_for(entry)]
        expected = entry.job.synch_count
        if included < expected:
            raise model_logic.ValidationError(
                {'' : 'You cannot abort part of a synchronous job execution '
                      '(%d/%s), %d included, %d expected'
                      % (entry.job.id, entry.execution_subdir,
                         included, expected)})
374
375
def check_modify_host(update_data):
    """
    Sanity check modify_host* requests.

    Only the scheduler (monitor_db) may change a host's status; if the
    frontend could change it, the state would shift underneath tasks running
    on the host.

    @param update_data: A dictionary with the changes to make to a host
            or hosts.
    @raises model_logic.ValidationError: if update_data contains 'status'.
    """
    if 'status' not in update_data:
        return
    raise model_logic.ValidationError({
            'status': 'Host status can not be modified by the frontend.'})
389
390
def check_modify_host_locking(host, update_data):
    """
    Validate a lock/unlock request against the host's current lock state.

    @param host: models.Host object to be modified
    @param update_data: A dictionary with the changes to make to the host.
            Only its 'locked' and 'lock_reason' entries are examined.
    @raises model_logic.ValidationError: when locking an already locked host,
            unlocking an unlocked host, or locking without a lock_reason.
    """
    locked = update_data.get('locked', None)
    lock_reason = update_data.get('lock_reason', None)
    if locked is None:
        # No change to the lock state was requested.
        return

    if locked:
        if host.locked:
            raise model_logic.ValidationError({
                    'locked': 'Host %s already locked by %s on %s.' %
                    (host.hostname, host.locked_by, host.lock_time)})
        if not lock_reason:
            raise model_logic.ValidationError({
                    'locked': 'Please provide a reason for locking Host %s' %
                    host.hostname})
    elif not host.locked:
        raise model_logic.ValidationError({
                'locked': 'Host %s already unlocked.' % host.hostname})
413
414
def get_motd():
    """
    Return the message of the day.

    Reads motd.txt from two directories above this file. The file is
    optional: if it cannot be opened or read, an empty string is returned.

    @returns The file contents, or '' if it could not be read.
    """
    filename = os.path.join(os.path.dirname(__file__), "..", "..", "motd.txt")
    try:
        with open(filename, "r") as fp:
            return fp.read()
    except (EnvironmentError, UnicodeError):
        # A missing or unreadable motd is not an error; show no message.
        # Unlike a bare except, this no longer hides KeyboardInterrupt etc.
        return ''
429
430
431def _get_metahost_counts(metahost_objects):
432    metahost_counts = {}
433    for metahost in metahost_objects:
434        metahost_counts.setdefault(metahost, 0)
435        metahost_counts[metahost] += 1
436    return metahost_counts
437
438
def get_job_info(job, preserve_metahosts=False, queue_entry_filter_data=None):
    """
    Summarize how a job's host queue entries are assigned.

    @param job: Job model object whose hostqueueentry_set is examined.
    @param preserve_metahosts: If True, an entry having both a host and a
            meta_host is reported under its host; otherwise under its
            meta_host.
    @param queue_entry_filter_data: Optional filter data applied to the
            entries through models.HostQueueEntry.query_objects().

    @returns A dict with keys 'dependencies' (label names), 'hosts',
            'meta_hosts', 'meta_host_counts', 'one_time_hosts' and
            'hostless' (True if some entry has neither host nor meta_host).
    """
    entries = job.hostqueueentry_set.all()
    if queue_entry_filter_data:
        entries = models.HostQueueEntry.query_objects(
            queue_entry_filter_data, initial_query=entries)

    hosts, one_time_hosts, meta_hosts = [], [], []
    hostless = False
    for entry in entries:
        report_by_host = entry.host and (preserve_metahosts or
                                         not entry.meta_host)
        if report_by_host:
            # Deleted entries are skipped only when reported by host.
            if entry.deleted:
                continue
            # Hosts flagged invalid are reported as one-time hosts.
            bucket = one_time_hosts if entry.host.invalid else hosts
            bucket.append(entry.host)
        elif entry.meta_host:
            meta_hosts.append(entry.meta_host)
        else:
            hostless = True

    meta_host_counts = _get_metahost_counts(meta_hosts)
    dependency_names = [label.name for label in job.dependency_labels.all()]
    return dict(dependencies=dependency_names,
                hosts=hosts,
                meta_hosts=meta_hosts,
                meta_host_counts=meta_host_counts,
                one_time_hosts=one_time_hosts,
                hostless=hostless)
474
475
def check_for_duplicate_hosts(host_objects):
    """
    Reject a host list that contains the same host more than once.

    @param host_objects: List of host objects (compared by hash/equality).
    @raises model_logic.ValidationError: listing the hostnames of every host
            that occurs more than once.
    """
    occurrences = collections.Counter(host_objects)
    repeated = set(host.hostname for host, seen in occurrences.items()
                   if seen > 1)
    if not repeated:
        return
    raise model_logic.ValidationError(
            {'hosts' : 'Duplicate hosts: %s'
             % ', '.join(repeated)})
485
486
def create_new_job(owner, options, host_objects, metahost_objects):
    """
    Validate a job request, then create and queue the job.

    @param owner: Owner passed to models.Job.create (create_job_common passes
            models.User.current_user().login).
    @param options: Dict of job options. Its 'dependencies' entry (label
            names) is replaced in place by a list of matching models.Label
            objects before the job is created.
    @param host_objects: List of explicitly targeted host objects.
    @param metahost_objects: List of label objects used as metahosts.

    @returns The id of the created job.
    @raises model_logic.ValidationError: if synch_count exceeds the number of
            hosts plus metahosts, a host is listed twice, or a targeted host
            fails the dependencies.
    @raises error.NoEligibleHostException: if no host under some metahost
            satisfies the dependencies.
    """
    all_host_objects = host_objects + metahost_objects
    dependencies = options.get('dependencies', [])
    synch_count = options.get('synch_count')

    if synch_count is not None and synch_count > len(all_host_objects):
        raise model_logic.ValidationError(
                {'hosts':
                 'only %d hosts provided for job with synch_count = %d' %
                 (len(all_host_objects), synch_count)})

    check_for_duplicate_hosts(host_objects)

    # Labels for which provision.is_for_special_action() is true are created
    # on demand, so the Label lookup further down can find them.
    for label_name in dependencies:
        if provision.is_for_special_action(label_name):
            # TODO: We could save a few queries
            # if we had a bulk ensure-label-exists function, which used
            # a bulk .get() call. The win is probably very small.
            _ensure_label_exists(label_name)

    # This only checks targeted hosts, not hosts eligible due to the metahost
    check_job_dependencies(host_objects, dependencies)
    check_job_metahost_dependencies(metahost_objects, dependencies)

    # Swap label names for Label objects; names with no Label row are dropped.
    options['dependencies'] = list(
            models.Label.objects.filter(name__in=dependencies))

    job = models.Job.create(owner=owner, options=options,
                            hosts=all_host_objects)
    job.queue(all_host_objects,
              is_template=options.get('is_template', False))
    return job.id
519
520
def _ensure_label_exists(name):
    """
    Ensure that a label called |name| exists in the Django models.

    This function is to be called from within afe rpcs only, as an
    alternative to server.cros.provision.ensure_label_exists(...). It works
    by Django model manipulation, rather than by making another create_label
    rpc call. It must only run on the master, never on a shard.

    @param name: the label to check for/create.
    @raises AssertionError: if called on a shard.
    @raises django.db.utils.IntegrityError: if creating the label failed for
            a reason other than the label already existing.
    @returns True if a label was created, False otherwise.
    """
    # Make sure this function is not called on shards but only on master.
    assert not server_utils.is_shard()
    try:
        models.Label.objects.get(name=name)
    except models.Label.DoesNotExist:
        try:
            # objects.create() already saves the new row; calling save()
            # again would only issue a redundant UPDATE.
            models.Label.objects.create(name=name)
            return True
        except django.db.utils.IntegrityError as e:
            # It is possible that another suite/test already
            # created the label between the check and the insert.
            if DUPLICATE_KEY_MSG in str(e):
                return False
            raise
    return False
552
553
def find_platform(host):
    """
    Figure out the platform name for the given host object.

    @param host: Object with .hostname and .label_list (labels having .name
            and .platform).
    @returns The name of the host's platform label, or None if it has none.
    @raises ValueError: if the host has more than one platform label.
    """
    platforms = [label.name for label in host.label_list if label.platform]
    if len(platforms) > 1:
        raise ValueError('Host %s has more than one platform: %s' %
                         (host.hostname, ', '.join(platforms)))
    return platforms[0] if platforms else None
570
571
572# support for get_host_queue_entries_and_special_tasks()
573
574def _common_entry_to_dict(entry, type, job_dict, exec_path, status, started_on):
575    return dict(type=type,
576                host=entry['host'],
577                job=job_dict,
578                execution_path=exec_path,
579                status=status,
580                started_on=started_on,
581                id=str(entry['id']) + type,
582                oid=entry['id'])
583
584
def _special_task_to_dict(task, queue_entries):
    """Transforms a special task dictionary to another form of dictionary.

    The job details come from the matching entry in queue_entries when
    present; otherwise the job is loaded from the database.

    @param task           Special task as a dictionary type
    @param queue_entries  Host queue entries as a list of dictionaries.

    @return Transformed dictionary for a special task.
    """
    job_dict = None
    linked_entry = task['queue_entry']
    if linked_entry:
        job_dict = next((qentry['job'] for qentry in queue_entries
                         if linked_entry['id'] == qentry['id']), None)
        if job_dict is None:
            # Not available among the given entries; ask the database.
            job_dict = models.Job.objects.get(
                    id=linked_entry['job']).get_object_dict()

    exec_path = server_utils.get_special_task_exec_path(
            task['host']['hostname'], task['id'], task['task'],
            time_utils.time_string_to_datetime(task['time_requested']))
    status = server_utils.get_special_task_status(
            task['is_complete'], task['success'], task['is_active'])
    return _common_entry_to_dict(task, task['task'], job_dict, exec_path,
                                 status, task['time_started'])
612
613
def _queue_entry_to_dict(queue_entry):
    """
    Transform a host queue entry dictionary for the combined listing.

    @param queue_entry: Host queue entry dict whose 'job' value is a job
            dict with 'id' and 'owner'.
    @return The dict from _common_entry_to_dict() with type 'Job'.
    """
    job = queue_entry['job']
    job_tag = server_utils.get_job_tag(job['id'], job['owner'])
    return _common_entry_to_dict(
            queue_entry, 'Job', job,
            server_utils.get_hqe_exec_path(job_tag,
                                           queue_entry['execution_subdir']),
            queue_entry['status'], queue_entry['started_on'])
621
622
def prepare_host_queue_entries_and_special_tasks(interleaved_entries,
                                                 queue_entries):
    """
    Prepare for serialization the interleaved entries of host queue entries
    and special tasks.

    Entries carrying a 'task' key are special tasks; all others are host
    queue entries. Special task dicts only hold a job id, so queue_entries
    is consulted to look up the full job details.

    @param interleaved_entries  Host queue entries and special tasks as a list
                                of dictionaries.
    @param queue_entries        Host queue entries as a list of dictionaries.

    @return A post-processed list of dictionaries that is to be serialized.
    """
    def _convert(entry):
        if 'task' in entry:
            return _special_task_to_dict(entry, queue_entries)
        return _queue_entry_to_dict(entry)

    return prepare_for_serialization(
            [_convert(entry) for entry in interleaved_entries])
649
650
def _compute_next_job_for_tasks(queue_entries, special_tasks):
    """
    For each task, try to figure out the next job that ran after that task.
    This is done using two pieces of information:
    * if the task has a queue entry, we can use that entry's job ID.
    * if the task has a time_started, we can try to compare that against the
      started_on field of queue_entries. this isn't guaranteed to work perfectly
      since queue_entries may also have null started_on values.
    * if the task has neither, or if use of time_started fails, just use the
      last computed job ID.

    The result is stored in place as task['next_job_id'] (possibly None) on
    every dict in special_tasks; nothing is returned.

    NOTE(review): the index scanning below appears to assume both lists are
    ordered by descending ID, as interleave_entries() documents for its
    inputs -- confirm for any other caller.

    @param queue_entries    Host queue entries as a list of dictionaries.
    @param special_tasks    Special tasks as a list of dictionaries.
    """
    next_job_id = None # most recently computed next job
    hqe_index = 0 # index for scanning by started_on times
    for task in special_tasks:
        if task['queue_entry']:
            # The task is tied to a queue entry; use that entry's job id.
            next_job_id = task['queue_entry']['job']
        elif task['time_started'] is not None:
            # Walk the remaining entries, remembering the job of each one
            # that started at or after the task; stop at the first entry
            # that started before it. Entries without started_on are skipped.
            for queue_entry in queue_entries[hqe_index:]:
                if queue_entry['started_on'] is None:
                    continue
                t1 = time_utils.time_string_to_datetime(
                        queue_entry['started_on'])
                t2 = time_utils.time_string_to_datetime(task['time_started'])
                if t1 < t2:
                    break
                next_job_id = queue_entry['job']['id']

        # If neither branch assigned it, next_job_id keeps the value computed
        # for the previous task.
        task['next_job_id'] = next_job_id

        # advance hqe_index to just after next_job_id
        # (skip entries whose job id is >= next_job_id)
        if next_job_id is not None:
            for queue_entry in queue_entries[hqe_index:]:
                if queue_entry['job']['id'] < next_job_id:
                    break
                hqe_index += 1
689
690
def interleave_entries(queue_entries, special_tasks):
    """
    Merge host queue entries and special tasks into one list.

    Both lists should be ordered by descending ID.

    Each special task dict is first annotated in place with 'next_job_id'
    (see _compute_next_job_for_tasks). Leading tasks whose next_job_id is
    None come first; every other task is appended right after the first
    queue entry whose job id is not greater than the task's next_job_id.

    NOTE(review): tasks still left after the last queue entry (next_job_id
    smaller than every remaining job id) are not included in the result.

    @param queue_entries  Host queue entries as a list of dictionaries.
    @param special_tasks  Special tasks as a list of dictionaries.
    @return A list mixing queue entry dicts and special task dicts.
    """
    _compute_next_job_for_tasks(queue_entries, special_tasks)

    # start with all special tasks that've run since the last job
    # (the leading tasks whose next_job_id is None)
    interleaved_entries = []
    for task in special_tasks:
        if task['next_job_id'] is not None:
            break
        interleaved_entries.append(task)

    # now interleave queue entries with the remaining special tasks
    special_task_index = len(interleaved_entries)
    for queue_entry in queue_entries:
        interleaved_entries.append(queue_entry)
        # add all tasks that ran between this job and the previous one
        # (consecutive tasks whose next_job_id is >= this entry's job id)
        for task in special_tasks[special_task_index:]:
            if task['next_job_id'] < queue_entry['job']['id']:
                break
            interleaved_entries.append(task)
            special_task_index += 1

    return interleaved_entries
716
717
def bucket_hosts_by_shard(host_objs, rpc_hostnames=False):
    """Group the hostnames of sharded hosts by their shard.

    Hosts without a shard are left out.

    @param host_objs: A list of host objects.
    @param rpc_hostnames: If True, key the map by each shard's rpc_hostname()
        instead of its 'real' hostname. This only matters for testing
        environments.

    @return: A defaultdict(list) mapping shard hostname to the list of
        hostnames of the hosts on that shard.
    """
    buckets = collections.defaultdict(list)
    for host in (candidate for candidate in host_objs if candidate.shard):
        if rpc_hostnames:
            shard_key = host.shard.rpc_hostname()
        else:
            shard_key = host.shard.hostname
        buckets[shard_key].append(host.hostname)
    return buckets
735
736
def create_job_common(
        name,
        priority,
        control_type,
        control_file=None,
        hosts=(),
        meta_hosts=(),
        one_time_hosts=(),
        synch_count=None,
        is_template=False,
        timeout=None,
        timeout_mins=None,
        max_runtime_mins=None,
        run_verify=True,
        email_list='',
        dependencies=(),
        reboot_before=None,
        reboot_after=None,
        parse_failed_repair=None,
        hostless=False,
        keyvals=None,
        drone_set=None,
        parent_job_id=None,
        test_retry=0,
        run_reset=True,
        require_ssp=None):
    #pylint: disable-msg=C0111
    """
    Common code between creating "standard" jobs and creating parameterized jobs

    Validates the host arguments, resolves hostnames and metahost label names
    into model objects, creates any one-time hosts and hands everything to
    create_new_job() on behalf of the current user.

    @returns The id of the created job, as returned by create_new_job().
    @raises model_logic.ValidationError: if a hostless job names hosts or uses
            a non-server control type, a host job names no hosts at all, or a
            metahost label does not exist.
    @raises ValueError: if the hosts break job sharding rules.
    """
    # input validation
    host_args_passed = any((hosts, meta_hosts, one_time_hosts))
    if hostless:
        if host_args_passed:
            raise model_logic.ValidationError({
                    'hostless': 'Hostless jobs cannot include any hosts!'})
        if control_type != control_data.CONTROL_TYPE_NAMES.SERVER:
            raise model_logic.ValidationError({
                    'control_type': 'Hostless jobs cannot use client-side '
                                    'control files'})
    elif not host_args_passed:
        raise model_logic.ValidationError({
            'arguments' : "For host jobs, you must pass at least one of"
                          " 'hosts', 'meta_hosts', 'one_time_hosts'."
            })
    label_objects = list(models.Label.objects.filter(name__in=meta_hosts))

    # convert hostnames & meta hosts to host/label objects
    host_objects = models.Host.smart_get_bulk(hosts)
    _validate_host_job_sharding(host_objects)

    # Resolve metahost labels before calling create_one_time_host(), so an
    # unknown label is rejected before any one-time host is created.
    meta_host_labels_by_name = {label.name: label for label in label_objects}
    metahost_objects = []
    for label_name in meta_hosts:
        if label_name not in meta_host_labels_by_name:
            raise model_logic.ValidationError(
                {'meta_hosts' : 'Label "%s" not found' % label_name})
        metahost_objects.append(meta_host_labels_by_name[label_name])

    for host in one_time_hosts:
        host_objects.append(models.Host.create_one_time_host(host))

    options = dict(name=name,
                   priority=priority,
                   control_file=control_file,
                   control_type=control_type,
                   is_template=is_template,
                   timeout=timeout,
                   timeout_mins=timeout_mins,
                   max_runtime_mins=max_runtime_mins,
                   synch_count=synch_count,
                   run_verify=run_verify,
                   email_list=email_list,
                   dependencies=dependencies,
                   reboot_before=reboot_before,
                   reboot_after=reboot_after,
                   parse_failed_repair=parse_failed_repair,
                   keyvals=keyvals,
                   drone_set=drone_set,
                   parent_job_id=parent_job_id,
                   test_retry=test_retry,
                   run_reset=run_reset,
                   require_ssp=require_ssp)

    return create_new_job(owner=models.User.current_user().login,
                          options=options,
                          host_objects=host_objects,
                          metahost_objects=metahost_objects)
826
827
def _validate_host_job_sharding(host_objects):
    """Reject host sets that a job created on the master may not span.

    Nothing is checked when running on a shard.  On the master, the hosts
    must satisfy _allowed_hosts_for_master_job().

    @param host_objects: Host objects the job would run on.

    @raises ValueError: when the hosts are not allowed for a master job; the
            message includes the result of bucket_hosts_by_shard().
    """
    if server_utils.is_shard():
        return
    if _allowed_hosts_for_master_job(host_objects):
        return
    offending_map = bucket_hosts_by_shard(host_objects)
    raise ValueError(
            'The following hosts are on shard(s), please create '
            'seperate jobs for hosts on each shard: %s ' %
            offending_map)
837
838
def _allowed_hosts_for_master_job(host_objects):
    """Return whether a job on the master may use the given hosts.

    A master job is refused when its hosts span more than one shard, or
    when only some of them sit on a single shard (the job would straddle
    that shard and the master).  It is allowed when no host is bucketed
    onto a shard, or when every host is on the same single shard.

    @param host_objects: Host objects the job would run on.

    @returns: True if the hosts are allowed for a master job, else False.
    """
    hosts_by_shard = bucket_hosts_by_shard(host_objects)
    shard_count = len(hosts_by_shard)
    if shard_count == 0:
        return True
    if shard_count > 1:
        return False
    # Exactly one shard: allowed only if every requested host lives on it.
    single_shard_hosts = next(iter(hosts_by_shard.values()))
    assert len(single_shard_hosts) <= len(host_objects)
    return len(single_shard_hosts) == len(host_objects)
856
857
def encode_ascii(control_file):
    """Force a control file to only contain ascii characters.

    @param control_file: Control file to encode.

    @returns the control file in an ascii encoding.

    @raises error.ControlFileMalformed: if encoding fails.
    """
    try:
        return control_file.encode('ascii')
    except UnicodeError as e:
        # A unicode object holding non-ascii characters raises
        # UnicodeEncodeError here, while a Python 2 byte string holding
        # non-ascii bytes raises UnicodeDecodeError (it is implicitly decoded
        # as ascii first).  Catch their common base so both cases surface as
        # ControlFileMalformed.
        raise error.ControlFileMalformed(str(e))
871
872
def get_wmatrix_url():
    """Read the wmatrix URL from the AUTOTEST_WEB config section.

    @returns the configured 'wmatrix_url' value, or '' when it is not set.
    """
    config = global_config.global_config
    return config.get_config_value('AUTOTEST_WEB', 'wmatrix_url', default='')
881
882
def inject_times_to_filter(start_time_key=None, end_time_key=None,
                         start_time_value=None, end_time_value=None,
                         **filter_data):
    """Add start/end time entries to a filter dict when values are given.

    A time value is only added when it is truthy; the end time is written
    after the start time, so it wins if both keys are identical.

    @param start_time_key: Filter key to store start_time_value under.
    @param end_time_key: Filter key to store end_time_value under.
    @param start_time_value: Start time value, or a falsy value to skip it.
    @param end_time_value: End time value, or a falsy value to skip it.
    @param filter_data: The base filter entries.

    @returns the filter_data dict with the time entries added.
    """
    time_entries = ((start_time_key, start_time_value),
                    (end_time_key, end_time_value))
    for filter_key, filter_value in time_entries:
        if filter_value:
            filter_data[filter_key] = filter_value
    return filter_data
900
901
def inject_times_to_hqe_special_tasks_filters(filter_data_common,
                                              start_time, end_time):
    """Inject start and end time to hqe and special tasks filters.

    Each returned filter is a new dict built from filter_data_common;
    filter_data_common itself is not modified.

    @param filter_data_common: Common filter for hqe and special tasks.
    @param start_time: Start time value; used as the lower bound on
            'started_on' (hqe) and 'time_started' (special tasks).  Passed to
            inject_times_to_filter(), which skips falsy values.
    @param end_time: End time value; used as the upper bound on the same
            fields.  Passed to inject_times_to_filter(), which skips falsy
            values.

    @returns a pair (hqe_filter, special_tasks_filter), in that order.
    """
    # `**filter_data_common` unpacks into a fresh dict for each call, so the
    # two filters never share state and no explicit copy is needed.
    hqe_filter = inject_times_to_filter('started_on__gte', 'started_on__lte',
                                        start_time, end_time,
                                        **filter_data_common)
    special_tasks_filter = inject_times_to_filter(
            'time_started__gte', 'time_started__lte',
            start_time, end_time, **filter_data_common)
    return hqe_filter, special_tasks_filter
918
919
def retrieve_shard(shard_hostname):
    """
    Look up a shard by hostname, timing the lookup.

    The lookup runs inside the 'shard_heartbeat.retrieve_shard' stats timer.

    @param shard_hostname: Hostname of the shard to look up.

    @raises models.Shard.DoesNotExist, if no shard with this hostname was found.

    @returns: Shard object
    """
    with autotest_stats.Timer('shard_heartbeat.retrieve_shard'):
        found_shard = models.Shard.smart_get(shard_hostname)
    return found_shard
933
934
def find_records_for_shard(shard, known_job_ids, known_host_ids):
    """Collect the hosts, jobs and suite keyvals to send to a shard.

    Each step is timed under a 'shard_heartbeat' stats timer client.

    @param shard: Shard to find records for.
    @param known_job_ids: List of ids of jobs the shard already has.
    @param known_host_ids: List of ids of hosts the shard already has.

    @returns: Tuple (hosts, jobs, suite_job_keyvals), where hosts and jobs
              come from the models' assign_to_shard() and suite_job_keyvals
              is a JobKeyval query filtered on the jobs' parent job ids.
    """
    heartbeat_timer = autotest_stats.Timer('shard_heartbeat')

    with heartbeat_timer.get_client('find_hosts'):
        assigned_hosts = models.Host.assign_to_shard(shard, known_host_ids)

    with heartbeat_timer.get_client('find_jobs'):
        assigned_jobs = models.Job.assign_to_shard(shard, known_job_ids)

    with heartbeat_timer.get_client('find_suite_job_keyvals'):
        suite_ids = [assigned.parent_job_id for assigned in assigned_jobs]
        keyvals = models.JobKeyval.objects.filter(job_id__in=suite_ids)

    return assigned_hosts, assigned_jobs, keyvals
955
956
def _persist_records_with_type_sent_from_shard(
    shard, records, record_type, *args, **kwargs):
    """
    Apply serialized records of one model type sent by a shard to the master.

    Each record must already exist on the master.  The existing object's
    sanity_check_update_from_shard() runs before update_from_serialized()
    is applied.

    @param shard: The shard the records were sent from.
    @param records: The records sent in their serialized format.
    @param record_type: Model class of the objects represented by records.
    @param args: Extra positional arguments forwarded to the sanity check.
    @param kwargs: Extra keyword arguments forwarded to the sanity check.

    @raises error.UnallowedRecordsSentToMaster if a record does not exist on
            the master or (per the sanity check's contract) fails the checks.

    @returns: List of primary keys of the processed records.
    """
    processed_pks = []
    for record in records:
        record_pk = record['id']
        try:
            existing_record = record_type.objects.get(pk=record_pk)
        except record_type.DoesNotExist:
            raise error.UnallowedRecordsSentToMaster(
                'Object with pk %s of type %s does not exist on master.' % (
                    record_pk, record_type))
        # Validate before applying any of the shard's changes.
        existing_record.sanity_check_update_from_shard(
            shard, record, *args, **kwargs)
        existing_record.update_from_serialized(record)
        processed_pks.append(record_pk)
    return processed_pks
990
991
def persist_records_sent_from_shard(shard, jobs, hqes):
    """
    Sanity-check and save the serialized records a shard sent to the master.

    During heartbeats, shards upload jobs and host queue entries.  This
    sanity-checks them and then updates the existing master records with
    the data from the heartbeat.

    The sanity checks include:
    - Checking if the objects sent already exist on the master.
    - Checking if the objects sent were assigned to this shard.
    - Host queue entries must be sent together with their jobs.

    Jobs are persisted first; their ids are then handed to the host queue
    entry checks as job_ids_sent.

    @param shard: The shard the records were sent from.
    @param jobs: The jobs the shard sent.
    @param hqes: The host queue entries the shard sent.

    @raises error.UnallowedRecordsSentToMaster if any of the sanity checks fail.
    """
    heartbeat_timer = autotest_stats.Timer('shard_heartbeat')

    with heartbeat_timer.get_client('persist_jobs'):
        persisted_job_ids = _persist_records_with_type_sent_from_shard(
                shard, jobs, models.Job)

    with heartbeat_timer.get_client('persist_hqes'):
        _persist_records_with_type_sent_from_shard(
                shard, hqes, models.HostQueueEntry,
                job_ids_sent=persisted_job_ids)
1019
1020
def forward_single_host_rpc_to_shard(func):
    """This decorator forwards rpc calls that modify a host to a shard.

    If a host is assigned to a shard, rpcs that change its attributes should
    be forwarded to the shard.  The rpc is first run locally; it is then
    re-run on the host's shard, but only when this server is not itself a
    shard.

    This assumes the decorated function receives the host identifier as the
    keyword argument 'id' (looked up with models.Host.smart_get()).

    @param func: The function to decorate

    @returns: The function to replace func with.  It keeps func's name and
              docstring.
    """
    @wraps(func)
    def replacement(**kwargs):
        # Only keyword arguments can be accepted here, as we need the argument
        # names to send the rpc. serviceHandler always provides arguments with
        # their keywords, so this is not a problem.

        # A host record (identified by kwargs['id']) can be deleted in
        # func(). Therefore, we should save the data that can be needed later
        # before func() is called.
        shard_hostname = None
        host = models.Host.smart_get(kwargs['id'])
        if host and host.shard:
            shard_hostname = host.shard.rpc_hostname()
        ret = func(**kwargs)
        if shard_hostname and not server_utils.is_shard():
            run_rpc_on_multiple_hostnames(func.__name__,
                                          [shard_hostname],
                                          **kwargs)
        return ret

    return replacement
1053
1054
1055def fanout_rpc(host_objs, rpc_name, include_hostnames=True, **kwargs):
1056    """Fanout the given rpc to shards of given hosts.
1057
1058    @param host_objs: Host objects for the rpc.
1059    @param rpc_name: The name of the rpc.
1060    @param include_hostnames: If True, include the hostnames in the kwargs.
1061        Hostnames are not always necessary, this functions is designed to
1062        send rpcs to the shard a host is on, the rpcs themselves could be
1063        related to labels, acls etc.
1064    @param kwargs: The kwargs for the rpc.
1065    """
1066    # Figure out which hosts are on which shards.
1067    shard_host_map = bucket_hosts_by_shard(
1068            host_objs, rpc_hostnames=True)
1069
1070    # Execute the rpc against the appropriate shards.
1071    for shard, hostnames in shard_host_map.iteritems():
1072        if include_hostnames:
1073            kwargs['hosts'] = hostnames
1074        try:
1075            run_rpc_on_multiple_hostnames(rpc_name, [shard], **kwargs)
1076        except:
1077            ei = sys.exc_info()
1078            new_exc = error.RPCException('RPC %s failed on shard %s due to '
1079                    '%s: %s' % (rpc_name, shard, ei[0].__name__, ei[1]))
1080            raise new_exc.__class__, new_exc, ei[2]
1081
1082
def run_rpc_on_multiple_hostnames(rpc_call, shard_hostnames, **kwargs):
    """Run one rpc against several AFE servers in turn.

    This is used, for example, to propagate changes made to hosts after
    they are assigned to a shard.  Each server is contacted through a
    RetryingAFE acting as the current thread-local user.

    @param rpc_call: Name of the rpc endpoint to call.
    @param shard_hostnames: List of hostnames to run the rpcs on.
    @param **kwargs: Keyword arguments to pass in the rpcs.
    """
    # Only the master may fan rpcs out; refuse to run on a shard.
    assert not server_utils.is_shard()
    for target_hostname in shard_hostnames:
        rpc_client = frontend_wrappers.RetryingAFE(
                server=target_hostname,
                user=thread_local.get_user())
        rpc_client.run(rpc_call, **kwargs)
1099
1100
def get_label(name):
    """Gets a label object using a given name.

    @param name: Label name.

    @return: a label object matching the given name, or None when
             models.Label.smart_get() raises models.Label.DoesNotExist
             (i.e. no label matches).
    """
    try:
        return models.Label.smart_get(name)
    except models.Label.DoesNotExist:
        return None
1114
1115
1116# TODO: hide the following rpcs under is_moblab
def moblab_only(func):
    """Ensure moblab specific functions only run on Moblab devices.

    @param func: The RPC function to guard.

    @returns: A wrapper with func's name and docstring that calls func only
              when server_utils.is_moblab() is true.  Otherwise the wrapper
              raises error.RPCException.
    """
    @wraps(func)
    def verify(*args, **kwargs):
        if not server_utils.is_moblab():
            # Format the message here; passing func.__name__ as a separate
            # exception argument would leave the '%s' unsubstituted.
            raise error.RPCException(
                    'RPC: %s can only run on Moblab Systems!' % func.__name__)
        return func(*args, **kwargs)
    return verify
1125
1126
def route_rpc_to_master(func):
    """Decorate an RPC so that shards forward it to the master AFE.

    On a shard, the decorated RPC is sent to the global (master) AFE as the
    current thread-local user.  On the master, func itself is executed.

    @param func: An RPC function to decorate; it must not take *args.

    @raises Exception: at decoration time if func declares *args.

    @returns: A function replacing the RPC func.
    """
    if inspect.getargspec(func).varargs is not None:
        raise Exception('RPC function must not have *args.')

    @wraps(func)
    def replacement(*args, **kwargs):
        """Normalize the call to keyword arguments, then run or forward it.

        Some RPCs are also called directly with positional arguments, e.g.
        rpc_interface.create_job_page_handler() calls
        rpc_interface.create_job() with both positional and keyword
        arguments.  frontend.RpcClient.run() only accepts keyword arguments,
        so positional ones are converted first.
        """
        call_kwargs = _convert_to_kwargs_only(func, args, kwargs)
        if not server_utils.is_shard():
            return func(**call_kwargs)
        master_afe = frontend_wrappers.RetryingAFE(
                server=server_utils.get_global_afe_hostname(),
                user=thread_local.get_user())
        return master_afe.run(func.func_name, **call_kwargs)

    return replacement
1162
1163
1164def _convert_to_kwargs_only(func, args, kwargs):
1165    """Convert a function call's arguments to a kwargs dict.
1166
1167    This is best illustrated with an example.  Given:
1168
1169    def foo(a, b, **kwargs):
1170        pass
1171    _to_kwargs(foo, (1, 2), {'c': 3})  # corresponding to foo(1, 2, c=3)
1172
1173        foo(**kwargs)
1174
1175    @param func: function whose signature to use
1176    @param args: positional arguments of call
1177    @param kwargs: keyword arguments of call
1178
1179    @returns: kwargs dict
1180    """
1181    argspec = inspect.getargspec(func)
1182    # callargs looks like {'a': 1, 'b': 2, 'kwargs': {'c': 3}}
1183    callargs = inspect.getcallargs(func, *args, **kwargs)
1184    if argspec.keywords is None:
1185        kwargs = {}
1186    else:
1187        kwargs = callargs.pop(argspec.keywords)
1188    kwargs.update(callargs)
1189    return kwargs
1190
1191
def get_sample_dut(board, pool):
    """Find one dut carrying both the given board and pool labels.

    The returned dut can be used to pick a devserver in the same subnet.
    Nothing is looked up unless dev_server.PREFER_LOCAL_DEVSERVER is set
    and both board and pool are non-empty.

    @param board: Name of the board.
    @param pool: Name of the pool.

    @return: Hostname of the first matching valid dut, or None when the
             lookup is skipped or nothing matches.
    """
    if not dev_server.PREFER_LOCAL_DEVSERVER or not pool or not board:
        return None
    required_labels = ('pool:%s' % pool, 'board:%s' % board)
    matching_duts = list(get_host_query(
            multiple_labels=required_labels,
            exclude_only_if_needed_labels=False,
            valid_only=True,
            filter_data={},
    ))
    return matching_duts[0].hostname if matching_duts else None
1215