1# pylint: disable=missing-docstring 2""" 3Utility functions for rpc_interface.py. We keep them in a separate file so that 4only RPC interface functions go into that file. 5""" 6 7__author__ = 'showard@google.com (Steve Howard)' 8 9import collections 10import datetime 11from functools import wraps 12import inspect 13import os 14import sys 15import django.db.utils 16import django.http 17 18from autotest_lib.frontend import thread_local 19from autotest_lib.frontend.afe import models, model_logic 20from autotest_lib.client.common_lib import control_data, error 21from autotest_lib.client.common_lib import global_config 22from autotest_lib.client.common_lib import time_utils 23from autotest_lib.client.common_lib.cros import dev_server 24# TODO(akeshet): Replace with monarch once we know how to instrument rpc server 25# with ts_mon. 26from autotest_lib.client.common_lib.cros.graphite import autotest_stats 27from autotest_lib.server import utils as server_utils 28from autotest_lib.server.cros import provision 29from autotest_lib.server.cros.dynamic_suite import frontend_wrappers 30 31NULL_DATETIME = datetime.datetime.max 32NULL_DATE = datetime.date.max 33DUPLICATE_KEY_MSG = 'Duplicate entry' 34 35def prepare_for_serialization(objects): 36 """ 37 Prepare Python objects to be returned via RPC. 38 @param objects: objects to be prepared. 39 """ 40 if (isinstance(objects, list) and len(objects) and 41 isinstance(objects[0], dict) and 'id' in objects[0]): 42 objects = _gather_unique_dicts(objects) 43 return _prepare_data(objects) 44 45 46def prepare_rows_as_nested_dicts(query, nested_dict_column_names): 47 """ 48 Prepare a Django query to be returned via RPC as a sequence of nested 49 dictionaries. 50 51 @param query - A Django model query object with a select_related() method. 52 @param nested_dict_column_names - A list of column/attribute names for the 53 rows returned by query to expand into nested dictionaries using 54 their get_object_dict() method when not None. 55 56 @returns An list suitable to returned in an RPC. 57 """ 58 all_dicts = [] 59 for row in query.select_related(): 60 row_dict = row.get_object_dict() 61 for column in nested_dict_column_names: 62 if row_dict[column] is not None: 63 row_dict[column] = getattr(row, column).get_object_dict() 64 all_dicts.append(row_dict) 65 return prepare_for_serialization(all_dicts) 66 67 68def _prepare_data(data): 69 """ 70 Recursively process data structures, performing necessary type 71 conversions to values in data to allow for RPC serialization: 72 -convert datetimes to strings 73 -convert tuples and sets to lists 74 """ 75 if isinstance(data, dict): 76 new_data = {} 77 for key, value in data.iteritems(): 78 new_data[key] = _prepare_data(value) 79 return new_data 80 elif (isinstance(data, list) or isinstance(data, tuple) or 81 isinstance(data, set)): 82 return [_prepare_data(item) for item in data] 83 elif isinstance(data, datetime.date): 84 if data is NULL_DATETIME or data is NULL_DATE: 85 return None 86 return str(data) 87 else: 88 return data 89 90 91def fetchall_as_list_of_dicts(cursor): 92 """ 93 Converts each row in the cursor to a dictionary so that values can be read 94 by using the column name. 95 @param cursor: The database cursor to read from. 96 @returns: A list of each row in the cursor as a dictionary. 97 """ 98 desc = cursor.description 99 return [ dict(zip([col[0] for col in desc], row)) 100 for row in cursor.fetchall() ] 101 102 103def raw_http_response(response_data, content_type=None): 104 response = django.http.HttpResponse(response_data, mimetype=content_type) 105 response['Content-length'] = str(len(response.content)) 106 return response 107 108 109def _gather_unique_dicts(dict_iterable): 110 """\ 111 Pick out unique objects (by ID) from an iterable of object dicts. 112 """ 113 objects = collections.OrderedDict() 114 for obj in dict_iterable: 115 objects.setdefault(obj['id'], obj) 116 return objects.values() 117 118 119def extra_job_status_filters(not_yet_run=False, running=False, finished=False): 120 """\ 121 Generate a SQL WHERE clause for job status filtering, and return it in 122 a dict of keyword args to pass to query.extra(). 123 * not_yet_run: all HQEs are Queued 124 * finished: all HQEs are complete 125 * running: everything else 126 """ 127 if not (not_yet_run or running or finished): 128 return {} 129 not_queued = ('(SELECT job_id FROM afe_host_queue_entries ' 130 'WHERE status != "%s")' 131 % models.HostQueueEntry.Status.QUEUED) 132 not_finished = ('(SELECT job_id FROM afe_host_queue_entries ' 133 'WHERE not complete)') 134 135 where = [] 136 if not_yet_run: 137 where.append('id NOT IN ' + not_queued) 138 if running: 139 where.append('(id IN %s) AND (id IN %s)' % (not_queued, not_finished)) 140 if finished: 141 where.append('id NOT IN ' + not_finished) 142 return {'where': [' OR '.join(['(%s)' % x for x in where])]} 143 144 145def extra_job_type_filters(extra_args, suite=False, 146 sub=False, standalone=False): 147 """\ 148 Generate a SQL WHERE clause for job status filtering, and return it in 149 a dict of keyword args to pass to query.extra(). 150 151 param extra_args: a dict of existing extra_args. 152 153 No more than one of the parameters should be passed as True: 154 * suite: job which is parent of other jobs 155 * sub: job with a parent job 156 * standalone: job with no child or parent jobs 157 """ 158 assert not ((suite and sub) or 159 (suite and standalone) or 160 (sub and standalone)), ('Cannot specify more than one ' 161 'filter to this function') 162 163 where = extra_args.get('where', []) 164 parent_job_id = ('DISTINCT parent_job_id') 165 child_job_id = ('id') 166 filter_common = ('(SELECT %s FROM afe_jobs ' 167 'WHERE parent_job_id IS NOT NULL)') 168 169 if suite: 170 where.append('id IN ' + filter_common % parent_job_id) 171 elif sub: 172 where.append('id IN ' + filter_common % child_job_id) 173 elif standalone: 174 where.append('NOT EXISTS (SELECT 1 from afe_jobs AS sub_query ' 175 'WHERE parent_job_id IS NOT NULL' 176 ' AND (sub_query.parent_job_id=afe_jobs.id' 177 ' OR sub_query.id=afe_jobs.id))') 178 else: 179 return extra_args 180 181 extra_args['where'] = where 182 return extra_args 183 184 185 186def extra_host_filters(multiple_labels=()): 187 """\ 188 Generate SQL WHERE clauses for matching hosts in an intersection of 189 labels. 190 """ 191 extra_args = {} 192 where_str = ('afe_hosts.id in (select host_id from afe_hosts_labels ' 193 'where label_id=%s)') 194 extra_args['where'] = [where_str] * len(multiple_labels) 195 extra_args['params'] = [models.Label.smart_get(label).id 196 for label in multiple_labels] 197 return extra_args 198 199 200def get_host_query(multiple_labels, exclude_only_if_needed_labels, 201 valid_only, filter_data): 202 if valid_only: 203 query = models.Host.valid_objects.all() 204 else: 205 query = models.Host.objects.all() 206 207 if exclude_only_if_needed_labels: 208 only_if_needed_labels = models.Label.valid_objects.filter( 209 only_if_needed=True) 210 if only_if_needed_labels.count() > 0: 211 only_if_needed_ids = ','.join( 212 str(label['id']) 213 for label in only_if_needed_labels.values('id')) 214 query = models.Host.objects.add_join( 215 query, 'afe_hosts_labels', join_key='host_id', 216 join_condition=('afe_hosts_labels_exclude_OIN.label_id IN (%s)' 217 % only_if_needed_ids), 218 suffix='_exclude_OIN', exclude=True) 219 try: 220 assert 'extra_args' not in filter_data 221 filter_data['extra_args'] = extra_host_filters(multiple_labels) 222 return models.Host.query_objects(filter_data, initial_query=query) 223 except models.Label.DoesNotExist: 224 return models.Host.objects.none() 225 226 227class InconsistencyException(Exception): 228 'Raised when a list of objects does not have a consistent value' 229 230 231def get_consistent_value(objects, field): 232 if not objects: 233 # well a list of nothing is consistent 234 return None 235 236 value = getattr(objects[0], field) 237 for obj in objects: 238 this_value = getattr(obj, field) 239 if this_value != value: 240 raise InconsistencyException(objects[0], obj) 241 return value 242 243 244def afe_test_dict_to_test_object(test_dict): 245 if not isinstance(test_dict, dict): 246 return test_dict 247 248 numerized_dict = {} 249 for key, value in test_dict.iteritems(): 250 try: 251 numerized_dict[key] = int(value) 252 except (ValueError, TypeError): 253 numerized_dict[key] = value 254 255 return type('TestObject', (object,), numerized_dict) 256 257 258def _check_is_server_test(test_type): 259 """Checks if the test type is a server test. 260 261 @param test_type The test type in enum integer or string. 262 263 @returns A boolean to identify if the test type is server test. 264 """ 265 if test_type is not None: 266 if isinstance(test_type, basestring): 267 try: 268 test_type = control_data.CONTROL_TYPE.get_value(test_type) 269 except AttributeError: 270 return False 271 return (test_type == control_data.CONTROL_TYPE.SERVER) 272 return False 273 274 275def prepare_generate_control_file(tests, profilers, db_tests=True): 276 if db_tests: 277 test_objects = [models.Test.smart_get(test) for test in tests] 278 else: 279 test_objects = [afe_test_dict_to_test_object(test) for test in tests] 280 281 profiler_objects = [models.Profiler.smart_get(profiler) 282 for profiler in profilers] 283 # ensure tests are all the same type 284 try: 285 test_type = get_consistent_value(test_objects, 'test_type') 286 except InconsistencyException, exc: 287 test1, test2 = exc.args 288 raise model_logic.ValidationError( 289 {'tests' : 'You cannot run both test_suites and server-side ' 290 'tests together (tests %s and %s differ' % ( 291 test1.name, test2.name)}) 292 293 is_server = _check_is_server_test(test_type) 294 if test_objects: 295 synch_count = max(test.sync_count for test in test_objects) 296 else: 297 synch_count = 1 298 299 if db_tests: 300 dependencies = set(label.name for label 301 in models.Label.objects.filter(test__in=test_objects)) 302 else: 303 dependencies = reduce( 304 set.union, [set(test.dependencies) for test in test_objects]) 305 306 cf_info = dict(is_server=is_server, synch_count=synch_count, 307 dependencies=list(dependencies)) 308 return cf_info, test_objects, profiler_objects 309 310 311def check_job_dependencies(host_objects, job_dependencies): 312 """ 313 Check that a set of machines satisfies a job's dependencies. 314 host_objects: list of models.Host objects 315 job_dependencies: list of names of labels 316 """ 317 # check that hosts satisfy dependencies 318 host_ids = [host.id for host in host_objects] 319 hosts_in_job = models.Host.objects.filter(id__in=host_ids) 320 ok_hosts = hosts_in_job 321 for index, dependency in enumerate(job_dependencies): 322 if not provision.is_for_special_action(dependency): 323 ok_hosts = ok_hosts.filter(labels__name=dependency) 324 failing_hosts = (set(host.hostname for host in host_objects) - 325 set(host.hostname for host in ok_hosts)) 326 if failing_hosts: 327 raise model_logic.ValidationError( 328 {'hosts' : 'Host(s) failed to meet job dependencies (' + 329 (', '.join(job_dependencies)) + '): ' + 330 (', '.join(failing_hosts))}) 331 332 333def check_job_metahost_dependencies(metahost_objects, job_dependencies): 334 """ 335 Check that at least one machine within the metahost spec satisfies the job's 336 dependencies. 337 338 @param metahost_objects A list of label objects representing the metahosts. 339 @param job_dependencies A list of strings of the required label names. 340 @raises NoEligibleHostException If a metahost cannot run the job. 341 """ 342 for metahost in metahost_objects: 343 hosts = models.Host.objects.filter(labels=metahost) 344 for label_name in job_dependencies: 345 if not provision.is_for_special_action(label_name): 346 hosts = hosts.filter(labels__name=label_name) 347 if not any(hosts): 348 raise error.NoEligibleHostException("No hosts within %s satisfy %s." 349 % (metahost.name, ', '.join(job_dependencies))) 350 351 352def _execution_key_for(host_queue_entry): 353 return (host_queue_entry.job.id, host_queue_entry.execution_subdir) 354 355 356def check_abort_synchronous_jobs(host_queue_entries): 357 # ensure user isn't aborting part of a synchronous autoserv execution 358 count_per_execution = {} 359 for queue_entry in host_queue_entries: 360 key = _execution_key_for(queue_entry) 361 count_per_execution.setdefault(key, 0) 362 count_per_execution[key] += 1 363 364 for queue_entry in host_queue_entries: 365 if not queue_entry.execution_subdir: 366 continue 367 execution_count = count_per_execution[_execution_key_for(queue_entry)] 368 if execution_count < queue_entry.job.synch_count: 369 raise model_logic.ValidationError( 370 {'' : 'You cannot abort part of a synchronous job execution ' 371 '(%d/%s), %d included, %d expected' 372 % (queue_entry.job.id, queue_entry.execution_subdir, 373 execution_count, queue_entry.job.synch_count)}) 374 375 376def check_modify_host(update_data): 377 """ 378 Sanity check modify_host* requests. 379 380 @param update_data: A dictionary with the changes to make to a host 381 or hosts. 382 """ 383 # Only the scheduler (monitor_db) is allowed to modify Host status. 384 # Otherwise race conditions happen as a hosts state is changed out from 385 # beneath tasks being run on a host. 386 if 'status' in update_data: 387 raise model_logic.ValidationError({ 388 'status': 'Host status can not be modified by the frontend.'}) 389 390 391def check_modify_host_locking(host, update_data): 392 """ 393 Checks when locking/unlocking has been requested if the host is already 394 locked/unlocked. 395 396 @param host: models.Host object to be modified 397 @param update_data: A dictionary with the changes to make to the host. 398 """ 399 locked = update_data.get('locked', None) 400 lock_reason = update_data.get('lock_reason', None) 401 if locked is not None: 402 if locked and host.locked: 403 raise model_logic.ValidationError({ 404 'locked': 'Host %s already locked by %s on %s.' % 405 (host.hostname, host.locked_by, host.lock_time)}) 406 if not locked and not host.locked: 407 raise model_logic.ValidationError({ 408 'locked': 'Host %s already unlocked.' % host.hostname}) 409 if locked and not lock_reason and not host.locked: 410 raise model_logic.ValidationError({ 411 'locked': 'Please provide a reason for locking Host %s' % 412 host.hostname}) 413 414 415def get_motd(): 416 dirname = os.path.dirname(__file__) 417 filename = os.path.join(dirname, "..", "..", "motd.txt") 418 text = '' 419 try: 420 fp = open(filename, "r") 421 try: 422 text = fp.read() 423 finally: 424 fp.close() 425 except: 426 pass 427 428 return text 429 430 431def _get_metahost_counts(metahost_objects): 432 metahost_counts = {} 433 for metahost in metahost_objects: 434 metahost_counts.setdefault(metahost, 0) 435 metahost_counts[metahost] += 1 436 return metahost_counts 437 438 439def get_job_info(job, preserve_metahosts=False, queue_entry_filter_data=None): 440 hosts = [] 441 one_time_hosts = [] 442 meta_hosts = [] 443 hostless = False 444 445 queue_entries = job.hostqueueentry_set.all() 446 if queue_entry_filter_data: 447 queue_entries = models.HostQueueEntry.query_objects( 448 queue_entry_filter_data, initial_query=queue_entries) 449 450 for queue_entry in queue_entries: 451 if (queue_entry.host and (preserve_metahosts or 452 not queue_entry.meta_host)): 453 if queue_entry.deleted: 454 continue 455 if queue_entry.host.invalid: 456 one_time_hosts.append(queue_entry.host) 457 else: 458 hosts.append(queue_entry.host) 459 elif queue_entry.meta_host: 460 meta_hosts.append(queue_entry.meta_host) 461 else: 462 hostless = True 463 464 meta_host_counts = _get_metahost_counts(meta_hosts) 465 466 info = dict(dependencies=[label.name for label 467 in job.dependency_labels.all()], 468 hosts=hosts, 469 meta_hosts=meta_hosts, 470 meta_host_counts=meta_host_counts, 471 one_time_hosts=one_time_hosts, 472 hostless=hostless) 473 return info 474 475 476def check_for_duplicate_hosts(host_objects): 477 host_counts = collections.Counter(host_objects) 478 duplicate_hostnames = {host.hostname 479 for host, count in host_counts.iteritems() 480 if count > 1} 481 if duplicate_hostnames: 482 raise model_logic.ValidationError( 483 {'hosts' : 'Duplicate hosts: %s' 484 % ', '.join(duplicate_hostnames)}) 485 486 487def create_new_job(owner, options, host_objects, metahost_objects): 488 all_host_objects = host_objects + metahost_objects 489 dependencies = options.get('dependencies', []) 490 synch_count = options.get('synch_count') 491 492 if synch_count is not None and synch_count > len(all_host_objects): 493 raise model_logic.ValidationError( 494 {'hosts': 495 'only %d hosts provided for job with synch_count = %d' % 496 (len(all_host_objects), synch_count)}) 497 498 check_for_duplicate_hosts(host_objects) 499 500 for label_name in dependencies: 501 if provision.is_for_special_action(label_name): 502 # TODO: We could save a few queries 503 # if we had a bulk ensure-label-exists function, which used 504 # a bulk .get() call. The win is probably very small. 505 _ensure_label_exists(label_name) 506 507 # This only checks targeted hosts, not hosts eligible due to the metahost 508 check_job_dependencies(host_objects, dependencies) 509 check_job_metahost_dependencies(metahost_objects, dependencies) 510 511 options['dependencies'] = list( 512 models.Label.objects.filter(name__in=dependencies)) 513 514 job = models.Job.create(owner=owner, options=options, 515 hosts=all_host_objects) 516 job.queue(all_host_objects, 517 is_template=options.get('is_template', False)) 518 return job.id 519 520 521def _ensure_label_exists(name): 522 """ 523 Ensure that a label called |name| exists in the Django models. 524 525 This function is to be called from within afe rpcs only, as an 526 alternative to server.cros.provision.ensure_label_exists(...). It works 527 by Django model manipulation, rather than by making another create_label 528 rpc call. 529 530 @param name: the label to check for/create. 531 @raises ValidationError: There was an error in the response that was 532 not because the label already existed. 533 @returns True is a label was created, False otherwise. 534 """ 535 # Make sure this function is not called on shards but only on master. 536 assert not server_utils.is_shard() 537 try: 538 models.Label.objects.get(name=name) 539 except models.Label.DoesNotExist: 540 try: 541 new_label = models.Label.objects.create(name=name) 542 new_label.save() 543 return True 544 except django.db.utils.IntegrityError as e: 545 # It is possible that another suite/test already 546 # created the label between the check and save. 547 if DUPLICATE_KEY_MSG in str(e): 548 return False 549 else: 550 raise 551 return False 552 553 554def find_platform(host): 555 """ 556 Figure out the platform name for the given host 557 object. If none, the return value for either will be None. 558 559 @returns platform name for the given host. 560 """ 561 platforms = [label.name for label in host.label_list if label.platform] 562 if not platforms: 563 platform = None 564 else: 565 platform = platforms[0] 566 if len(platforms) > 1: 567 raise ValueError('Host %s has more than one platform: %s' % 568 (host.hostname, ', '.join(platforms))) 569 return platform 570 571 572# support for get_host_queue_entries_and_special_tasks() 573 574def _common_entry_to_dict(entry, type, job_dict, exec_path, status, started_on): 575 return dict(type=type, 576 host=entry['host'], 577 job=job_dict, 578 execution_path=exec_path, 579 status=status, 580 started_on=started_on, 581 id=str(entry['id']) + type, 582 oid=entry['id']) 583 584 585def _special_task_to_dict(task, queue_entries): 586 """Transforms a special task dictionary to another form of dictionary. 587 588 @param task Special task as a dictionary type 589 @param queue_entries Host queue entries as a list of dictionaries. 590 591 @return Transformed dictionary for a special task. 592 """ 593 job_dict = None 594 if task['queue_entry']: 595 # Scan queue_entries to get the job detail info. 596 for qentry in queue_entries: 597 if task['queue_entry']['id'] == qentry['id']: 598 job_dict = qentry['job'] 599 break 600 # If not found, get it from DB. 601 if job_dict is None: 602 job = models.Job.objects.get(id=task['queue_entry']['job']) 603 job_dict = job.get_object_dict() 604 605 exec_path = server_utils.get_special_task_exec_path( 606 task['host']['hostname'], task['id'], task['task'], 607 time_utils.time_string_to_datetime(task['time_requested'])) 608 status = server_utils.get_special_task_status( 609 task['is_complete'], task['success'], task['is_active']) 610 return _common_entry_to_dict(task, task['task'], job_dict, 611 exec_path, status, task['time_started']) 612 613 614def _queue_entry_to_dict(queue_entry): 615 job_dict = queue_entry['job'] 616 tag = server_utils.get_job_tag(job_dict['id'], job_dict['owner']) 617 exec_path = server_utils.get_hqe_exec_path(tag, 618 queue_entry['execution_subdir']) 619 return _common_entry_to_dict(queue_entry, 'Job', job_dict, exec_path, 620 queue_entry['status'], queue_entry['started_on']) 621 622 623def prepare_host_queue_entries_and_special_tasks(interleaved_entries, 624 queue_entries): 625 """ 626 Prepare for serialization the interleaved entries of host queue entries 627 and special tasks. 628 Each element in the entries is a dictionary type. 629 The special task dictionary has only a job id for a job and lacks 630 the detail of the job while the host queue entry dictionary has. 631 queue_entries is used to look up the job detail info. 632 633 @param interleaved_entries Host queue entries and special tasks as a list 634 of dictionaries. 635 @param queue_entries Host queue entries as a list of dictionaries. 636 637 @return A post-processed list of dictionaries that is to be serialized. 638 """ 639 dict_list = [] 640 for e in interleaved_entries: 641 # Distinguish the two mixed entries based on the existence of 642 # the key "task". If an entry has the key, the entry is for 643 # special task. Otherwise, host queue entry. 644 if 'task' in e: 645 dict_list.append(_special_task_to_dict(e, queue_entries)) 646 else: 647 dict_list.append(_queue_entry_to_dict(e)) 648 return prepare_for_serialization(dict_list) 649 650 651def _compute_next_job_for_tasks(queue_entries, special_tasks): 652 """ 653 For each task, try to figure out the next job that ran after that task. 654 This is done using two pieces of information: 655 * if the task has a queue entry, we can use that entry's job ID. 656 * if the task has a time_started, we can try to compare that against the 657 started_on field of queue_entries. this isn't guaranteed to work perfectly 658 since queue_entries may also have null started_on values. 659 * if the task has neither, or if use of time_started fails, just use the 660 last computed job ID. 661 662 @param queue_entries Host queue entries as a list of dictionaries. 663 @param special_tasks Special tasks as a list of dictionaries. 664 """ 665 next_job_id = None # most recently computed next job 666 hqe_index = 0 # index for scanning by started_on times 667 for task in special_tasks: 668 if task['queue_entry']: 669 next_job_id = task['queue_entry']['job'] 670 elif task['time_started'] is not None: 671 for queue_entry in queue_entries[hqe_index:]: 672 if queue_entry['started_on'] is None: 673 continue 674 t1 = time_utils.time_string_to_datetime( 675 queue_entry['started_on']) 676 t2 = time_utils.time_string_to_datetime(task['time_started']) 677 if t1 < t2: 678 break 679 next_job_id = queue_entry['job']['id'] 680 681 task['next_job_id'] = next_job_id 682 683 # advance hqe_index to just after next_job_id 684 if next_job_id is not None: 685 for queue_entry in queue_entries[hqe_index:]: 686 if queue_entry['job']['id'] < next_job_id: 687 break 688 hqe_index += 1 689 690 691def interleave_entries(queue_entries, special_tasks): 692 """ 693 Both lists should be ordered by descending ID. 694 """ 695 _compute_next_job_for_tasks(queue_entries, special_tasks) 696 697 # start with all special tasks that've run since the last job 698 interleaved_entries = [] 699 for task in special_tasks: 700 if task['next_job_id'] is not None: 701 break 702 interleaved_entries.append(task) 703 704 # now interleave queue entries with the remaining special tasks 705 special_task_index = len(interleaved_entries) 706 for queue_entry in queue_entries: 707 interleaved_entries.append(queue_entry) 708 # add all tasks that ran between this job and the previous one 709 for task in special_tasks[special_task_index:]: 710 if task['next_job_id'] < queue_entry['job']['id']: 711 break 712 interleaved_entries.append(task) 713 special_task_index += 1 714 715 return interleaved_entries 716 717 718def bucket_hosts_by_shard(host_objs, rpc_hostnames=False): 719 """Figure out which hosts are on which shards. 720 721 @param host_objs: A list of host objects. 722 @param rpc_hostnames: If True, the rpc_hostnames of a shard are returned 723 instead of the 'real' shard hostnames. This only matters for testing 724 environments. 725 726 @return: A map of shard hostname: list of hosts on the shard. 727 """ 728 shard_host_map = collections.defaultdict(list) 729 for host in host_objs: 730 if host.shard: 731 shard_name = (host.shard.rpc_hostname() if rpc_hostnames 732 else host.shard.hostname) 733 shard_host_map[shard_name].append(host.hostname) 734 return shard_host_map 735 736 737def create_job_common( 738 name, 739 priority, 740 control_type, 741 control_file=None, 742 hosts=(), 743 meta_hosts=(), 744 one_time_hosts=(), 745 synch_count=None, 746 is_template=False, 747 timeout=None, 748 timeout_mins=None, 749 max_runtime_mins=None, 750 run_verify=True, 751 email_list='', 752 dependencies=(), 753 reboot_before=None, 754 reboot_after=None, 755 parse_failed_repair=None, 756 hostless=False, 757 keyvals=None, 758 drone_set=None, 759 parent_job_id=None, 760 test_retry=0, 761 run_reset=True, 762 require_ssp=None): 763 #pylint: disable-msg=C0111 764 """ 765 Common code between creating "standard" jobs and creating parameterized jobs 766 """ 767 # input validation 768 host_args_passed = any((hosts, meta_hosts, one_time_hosts)) 769 if hostless: 770 if host_args_passed: 771 raise model_logic.ValidationError({ 772 'hostless': 'Hostless jobs cannot include any hosts!'}) 773 if control_type != control_data.CONTROL_TYPE_NAMES.SERVER: 774 raise model_logic.ValidationError({ 775 'control_type': 'Hostless jobs cannot use client-side ' 776 'control files'}) 777 elif not host_args_passed: 778 raise model_logic.ValidationError({ 779 'arguments' : "For host jobs, you must pass at least one of" 780 " 'hosts', 'meta_hosts', 'one_time_hosts'." 781 }) 782 label_objects = list(models.Label.objects.filter(name__in=meta_hosts)) 783 784 # convert hostnames & meta hosts to host/label objects 785 host_objects = models.Host.smart_get_bulk(hosts) 786 _validate_host_job_sharding(host_objects) 787 for host in one_time_hosts: 788 this_host = models.Host.create_one_time_host(host) 789 host_objects.append(this_host) 790 791 metahost_objects = [] 792 meta_host_labels_by_name = {label.name: label for label in label_objects} 793 for label_name in meta_hosts: 794 if label_name in meta_host_labels_by_name: 795 metahost_objects.append(meta_host_labels_by_name[label_name]) 796 else: 797 raise model_logic.ValidationError( 798 {'meta_hosts' : 'Label "%s" not found' % label_name}) 799 800 options = dict(name=name, 801 priority=priority, 802 control_file=control_file, 803 control_type=control_type, 804 is_template=is_template, 805 timeout=timeout, 806 timeout_mins=timeout_mins, 807 max_runtime_mins=max_runtime_mins, 808 synch_count=synch_count, 809 run_verify=run_verify, 810 email_list=email_list, 811 dependencies=dependencies, 812 reboot_before=reboot_before, 813 reboot_after=reboot_after, 814 parse_failed_repair=parse_failed_repair, 815 keyvals=keyvals, 816 drone_set=drone_set, 817 parent_job_id=parent_job_id, 818 test_retry=test_retry, 819 run_reset=run_reset, 820 require_ssp=require_ssp) 821 822 return create_new_job(owner=models.User.current_user().login, 823 options=options, 824 host_objects=host_objects, 825 metahost_objects=metahost_objects) 826 827 828def _validate_host_job_sharding(host_objects): 829 """Check that the hosts obey job sharding rules.""" 830 if not (server_utils.is_shard() 831 or _allowed_hosts_for_master_job(host_objects)): 832 shard_host_map = bucket_hosts_by_shard(host_objects) 833 raise ValueError( 834 'The following hosts are on shard(s), please create ' 835 'seperate jobs for hosts on each shard: %s ' % 836 shard_host_map) 837 838 839def _allowed_hosts_for_master_job(host_objects): 840 """Check that the hosts are allowed for a job on master.""" 841 # We disallow the following jobs on master: 842 # num_shards > 1: this is a job spanning across multiple shards. 843 # num_shards == 1 but number of hosts on shard is less 844 # than total number of hosts: this is a job that spans across 845 # one shard and the master. 846 shard_host_map = bucket_hosts_by_shard(host_objects) 847 num_shards = len(shard_host_map) 848 if num_shards > 1: 849 return False 850 if num_shards == 1: 851 hosts_on_shard = shard_host_map.values()[0] 852 assert len(hosts_on_shard) <= len(host_objects) 853 return len(hosts_on_shard) == len(host_objects) 854 else: 855 return True 856 857 858def encode_ascii(control_file): 859 """Force a control file to only contain ascii characters. 860 861 @param control_file: Control file to encode. 862 863 @returns the control file in an ascii encoding. 864 865 @raises error.ControlFileMalformed: if encoding fails. 866 """ 867 try: 868 return control_file.encode('ascii') 869 except UnicodeDecodeError as e: 870 raise error.ControlFileMalformed(str(e)) 871 872 873def get_wmatrix_url(): 874 """Get wmatrix url from config file. 875 876 @returns the wmatrix url or an empty string. 877 """ 878 return global_config.global_config.get_config_value('AUTOTEST_WEB', 879 'wmatrix_url', 880 default='') 881 882 883def inject_times_to_filter(start_time_key=None, end_time_key=None, 884 start_time_value=None, end_time_value=None, 885 **filter_data): 886 """Inject the key value pairs of start and end time if provided. 887 888 @param start_time_key: A string represents the filter key of start_time. 889 @param end_time_key: A string represents the filter key of end_time. 890 @param start_time_value: Start_time value. 891 @param end_time_value: End_time value. 892 893 @returns the injected filter_data. 894 """ 895 if start_time_value: 896 filter_data[start_time_key] = start_time_value 897 if end_time_value: 898 filter_data[end_time_key] = end_time_value 899 return filter_data 900 901 902def inject_times_to_hqe_special_tasks_filters(filter_data_common, 903 start_time, end_time): 904 """Inject start and end time to hqe and special tasks filters. 905 906 @param filter_data_common: Common filter for hqe and special tasks. 907 @param start_time_key: A string represents the filter key of start_time. 908 @param end_time_key: A string represents the filter key of end_time. 909 910 @returns a pair of hqe and special tasks filters. 911 """ 912 filter_data_special_tasks = filter_data_common.copy() 913 return (inject_times_to_filter('started_on__gte', 'started_on__lte', 914 start_time, end_time, **filter_data_common), 915 inject_times_to_filter('time_started__gte', 'time_started__lte', 916 start_time, end_time, 917 **filter_data_special_tasks)) 918 919 920def retrieve_shard(shard_hostname): 921 """ 922 Retrieves the shard with the given hostname from the database. 923 924 @param shard_hostname: Hostname of the shard to retrieve 925 926 @raises models.Shard.DoesNotExist, if no shard with this hostname was found. 927 928 @returns: Shard object 929 """ 930 timer = autotest_stats.Timer('shard_heartbeat.retrieve_shard') 931 with timer: 932 return models.Shard.smart_get(shard_hostname) 933 934 935def find_records_for_shard(shard, known_job_ids, known_host_ids): 936 """Find records that should be sent to a shard. 937 938 @param shard: Shard to find records for. 939 @param known_job_ids: List of ids of jobs the shard already has. 940 @param known_host_ids: List of ids of hosts the shard already has. 941 942 @returns: Tuple of three lists for hosts, jobs, and suite job keyvals: 943 (hosts, jobs, suite_job_keyvals). 944 """ 945 timer = autotest_stats.Timer('shard_heartbeat') 946 with timer.get_client('find_hosts'): 947 hosts = models.Host.assign_to_shard(shard, known_host_ids) 948 with timer.get_client('find_jobs'): 949 jobs = models.Job.assign_to_shard(shard, known_job_ids) 950 with timer.get_client('find_suite_job_keyvals'): 951 parent_job_ids = [job.parent_job_id for job in jobs] 952 suite_job_keyvals = models.JobKeyval.objects.filter( 953 job_id__in=parent_job_ids) 954 return hosts, jobs, suite_job_keyvals 955 956 957def _persist_records_with_type_sent_from_shard( 958 shard, records, record_type, *args, **kwargs): 959 """ 960 Handle records of a specified type that were sent to the shard master. 961 962 @param shard: The shard the records were sent from. 963 @param records: The records sent in their serialized format. 964 @param record_type: Type of the objects represented by records. 965 @param args: Additional arguments that will be passed on to the sanity 966 checks. 967 @param kwargs: Additional arguments that will be passed on to the sanity 968 checks. 969 970 @raises error.UnallowedRecordsSentToMaster if any of the sanity checks fail. 971 972 @returns: List of primary keys of the processed records. 973 """ 974 pks = [] 975 for serialized_record in records: 976 pk = serialized_record['id'] 977 try: 978 current_record = record_type.objects.get(pk=pk) 979 except record_type.DoesNotExist: 980 raise error.UnallowedRecordsSentToMaster( 981 'Object with pk %s of type %s does not exist on master.' % ( 982 pk, record_type)) 983 984 current_record.sanity_check_update_from_shard( 985 shard, serialized_record, *args, **kwargs) 986 987 current_record.update_from_serialized(serialized_record) 988 pks.append(pk) 989 return pks 990 991 992def persist_records_sent_from_shard(shard, jobs, hqes): 993 """ 994 Sanity checking then saving serialized records sent to master from shard. 995 996 During heartbeats shards upload jobs and hostqueuentries. This performs 997 some sanity checks on these and then updates the existing records for those 998 entries with the updated ones from the heartbeat. 999 1000 The sanity checks include: 1001 - Checking if the objects sent already exist on the master. 1002 - Checking if the objects sent were assigned to this shard. 1003 - hostqueueentries must be sent together with their jobs. 1004 1005 @param shard: The shard the records were sent from. 1006 @param jobs: The jobs the shard sent. 1007 @param hqes: The hostqueuentries the shart sent. 1008 1009 @raises error.UnallowedRecordsSentToMaster if any of the sanity checks fail. 1010 """ 1011 timer = autotest_stats.Timer('shard_heartbeat') 1012 with timer.get_client('persist_jobs'): 1013 job_ids_sent = _persist_records_with_type_sent_from_shard( 1014 shard, jobs, models.Job) 1015 1016 with timer.get_client('persist_hqes'): 1017 _persist_records_with_type_sent_from_shard( 1018 shard, hqes, models.HostQueueEntry, job_ids_sent=job_ids_sent) 1019 1020 1021def forward_single_host_rpc_to_shard(func): 1022 """This decorator forwards rpc calls that modify a host to a shard. 1023 1024 If a host is assigned to a shard, rpcs that change his attributes should be 1025 forwarded to the shard. 1026 1027 This assumes the first argument of the function represents a host id. 1028 1029 @param func: The function to decorate 1030 1031 @returns: The function to replace func with. 1032 """ 1033 def replacement(**kwargs): 1034 # Only keyword arguments can be accepted here, as we need the argument 1035 # names to send the rpc. serviceHandler always provides arguments with 1036 # their keywords, so this is not a problem. 1037 1038 # A host record (identified by kwargs['id']) can be deleted in 1039 # func(). Therefore, we should save the data that can be needed later 1040 # before func() is called. 1041 shard_hostname = None 1042 host = models.Host.smart_get(kwargs['id']) 1043 if host and host.shard: 1044 shard_hostname = host.shard.rpc_hostname() 1045 ret = func(**kwargs) 1046 if shard_hostname and not server_utils.is_shard(): 1047 run_rpc_on_multiple_hostnames(func.func_name, 1048 [shard_hostname], 1049 **kwargs) 1050 return ret 1051 1052 return replacement 1053 1054 1055def fanout_rpc(host_objs, rpc_name, include_hostnames=True, **kwargs): 1056 """Fanout the given rpc to shards of given hosts. 1057 1058 @param host_objs: Host objects for the rpc. 1059 @param rpc_name: The name of the rpc. 1060 @param include_hostnames: If True, include the hostnames in the kwargs. 1061 Hostnames are not always necessary, this functions is designed to 1062 send rpcs to the shard a host is on, the rpcs themselves could be 1063 related to labels, acls etc. 1064 @param kwargs: The kwargs for the rpc. 1065 """ 1066 # Figure out which hosts are on which shards. 1067 shard_host_map = bucket_hosts_by_shard( 1068 host_objs, rpc_hostnames=True) 1069 1070 # Execute the rpc against the appropriate shards. 1071 for shard, hostnames in shard_host_map.iteritems(): 1072 if include_hostnames: 1073 kwargs['hosts'] = hostnames 1074 try: 1075 run_rpc_on_multiple_hostnames(rpc_name, [shard], **kwargs) 1076 except: 1077 ei = sys.exc_info() 1078 new_exc = error.RPCException('RPC %s failed on shard %s due to ' 1079 '%s: %s' % (rpc_name, shard, ei[0].__name__, ei[1])) 1080 raise new_exc.__class__, new_exc, ei[2] 1081 1082 1083def run_rpc_on_multiple_hostnames(rpc_call, shard_hostnames, **kwargs): 1084 """Runs an rpc to multiple AFEs 1085 1086 This is i.e. used to propagate changes made to hosts after they are assigned 1087 to a shard. 1088 1089 @param rpc_call: Name of the rpc endpoint to call. 1090 @param shard_hostnames: List of hostnames to run the rpcs on. 1091 @param **kwargs: Keyword arguments to pass in the rpcs. 1092 """ 1093 # Make sure this function is not called on shards but only on master. 1094 assert not server_utils.is_shard() 1095 for shard_hostname in shard_hostnames: 1096 afe = frontend_wrappers.RetryingAFE(server=shard_hostname, 1097 user=thread_local.get_user()) 1098 afe.run(rpc_call, **kwargs) 1099 1100 1101def get_label(name): 1102 """Gets a label object using a given name. 1103 1104 @param name: Label name. 1105 @raises model.Label.DoesNotExist: when there is no label matching 1106 the given name. 1107 @return: a label object matching the given name. 1108 """ 1109 try: 1110 label = models.Label.smart_get(name) 1111 except models.Label.DoesNotExist: 1112 return None 1113 return label 1114 1115 1116# TODO: hide the following rpcs under is_moblab 1117def moblab_only(func): 1118 """Ensure moblab specific functions only run on Moblab devices.""" 1119 def verify(*args, **kwargs): 1120 if not server_utils.is_moblab(): 1121 raise error.RPCException('RPC: %s can only run on Moblab Systems!', 1122 func.__name__) 1123 return func(*args, **kwargs) 1124 return verify 1125 1126 1127def route_rpc_to_master(func): 1128 """Route RPC to master AFE. 1129 1130 When a shard receives an RPC decorated by this, the RPC is just 1131 forwarded to the master. 1132 When the master gets the RPC, the RPC function is executed. 1133 1134 @param func: An RPC function to decorate 1135 1136 @returns: A function replacing the RPC func. 1137 """ 1138 argspec = inspect.getargspec(func) 1139 if argspec.varargs is not None: 1140 raise Exception('RPC function must not have *args.') 1141 1142 @wraps(func) 1143 def replacement(*args, **kwargs): 1144 """We need special handling when decorating an RPC that can be called 1145 directly using positional arguments. 1146 1147 One example is rpc_interface.create_job(). 1148 rpc_interface.create_job_page_handler() calls the function using both 1149 positional and keyword arguments. Since frontend.RpcClient.run() 1150 takes only keyword arguments for an RPC, positional arguments of the 1151 RPC function need to be transformed into keyword arguments. 1152 """ 1153 kwargs = _convert_to_kwargs_only(func, args, kwargs) 1154 if server_utils.is_shard(): 1155 afe = frontend_wrappers.RetryingAFE( 1156 server=server_utils.get_global_afe_hostname(), 1157 user=thread_local.get_user()) 1158 return afe.run(func.func_name, **kwargs) 1159 return func(**kwargs) 1160 1161 return replacement 1162 1163 1164def _convert_to_kwargs_only(func, args, kwargs): 1165 """Convert a function call's arguments to a kwargs dict. 1166 1167 This is best illustrated with an example. Given: 1168 1169 def foo(a, b, **kwargs): 1170 pass 1171 _to_kwargs(foo, (1, 2), {'c': 3}) # corresponding to foo(1, 2, c=3) 1172 1173 foo(**kwargs) 1174 1175 @param func: function whose signature to use 1176 @param args: positional arguments of call 1177 @param kwargs: keyword arguments of call 1178 1179 @returns: kwargs dict 1180 """ 1181 argspec = inspect.getargspec(func) 1182 # callargs looks like {'a': 1, 'b': 2, 'kwargs': {'c': 3}} 1183 callargs = inspect.getcallargs(func, *args, **kwargs) 1184 if argspec.keywords is None: 1185 kwargs = {} 1186 else: 1187 kwargs = callargs.pop(argspec.keywords) 1188 kwargs.update(callargs) 1189 return kwargs 1190 1191 1192def get_sample_dut(board, pool): 1193 """Get a dut with the given board and pool. 1194 1195 This method is used to help to locate a dut with the given board and pool. 1196 The dut then can be used to identify a devserver in the same subnet. 1197 1198 @param board: Name of the board. 1199 @param pool: Name of the pool. 1200 1201 @return: Name of a dut with the given board and pool. 1202 """ 1203 if not (dev_server.PREFER_LOCAL_DEVSERVER and pool and board): 1204 return None 1205 hosts = list(get_host_query( 1206 multiple_labels=('pool:%s' % pool, 'board:%s' % board), 1207 exclude_only_if_needed_labels=False, 1208 valid_only=True, 1209 filter_data={}, 1210 )) 1211 if not hosts: 1212 return None 1213 else: 1214 return hosts[0].hostname 1215