1# Copyright 2014 Google Inc. All rights reserved. 2# 3# Licensed under the Apache License, Version 2.0 (the "License"); 4# you may not use this file except in compliance with the License. 5# You may obtain a copy of the License at 6# 7# http://www.apache.org/licenses/LICENSE-2.0 8# 9# Unless required by applicable law or agreed to in writing, software 10# distributed under the License is distributed on an "AS IS" BASIS, 11# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. 12# See the License for the specific language governing permissions and 13# limitations under the License. 14 15import fnmatch 16import importlib 17import inspect 18import json 19import os 20import pdb 21import sys 22import unittest 23import traceback 24 25from collections import OrderedDict 26 27# This ensures that absolute imports of typ modules will work when 28# running typ/runner.py as a script even if typ is not installed. 29# We need this entry in addition to the one in __main__.py to ensure 30# that typ/runner.py works when invoked via subprocess on windows in 31# _spawn_main(). 
path_to_file = os.path.realpath(__file__)
if path_to_file.endswith('.pyc'):  # pragma: no cover
    path_to_file = path_to_file[:-1]
dir_above_typ = os.path.dirname(os.path.dirname(path_to_file))
if dir_above_typ not in sys.path:  # pragma: no cover
    sys.path.append(dir_above_typ)


from typ import json_results
from typ.arg_parser import ArgumentParser
from typ.host import Host
from typ.pool import make_pool
from typ.stats import Stats
from typ.printer import Printer
from typ.test_case import TestCase as TypTestCase
from typ.version import VERSION


Result = json_results.Result
ResultSet = json_results.ResultSet
ResultType = json_results.ResultType


def main(argv=None, host=None, win_multiprocessing=None, **defaults):
    """Parse argv, run the requested tests, and return the exit code.

    Args:
        argv: command-line arguments handed to ArgumentParser.
        host: the Host to use; a fresh Host() is created when omitted.
        win_multiprocessing: optional WinMultiprocessing value that
            overrides the Runner's default (WinMultiprocessing.spawn).
        **defaults: default values for argument names; each name must
            already exist on the parsed args (see Runner.parse_args).
    """
    host = host or Host()
    runner = Runner(host=host)
    if win_multiprocessing is not None:
        runner.win_multiprocessing = win_multiprocessing
    return runner.main(argv, **defaults)


class TestInput(object):
    """A single test to run, identified by its fully-qualified name."""

    def __init__(self, name, msg='', timeout=None, expected=None):
        self.name = name
        # Free-form message; reported as the output of a skipped test.
        self.msg = msg
        self.timeout = timeout
        self.expected = expected


class TestSet(object):
    """The tests to run, partitioned by how they are scheduled.

    Each argument is an optional iterable whose items are TestInput
    objects or plain test names; plain names are wrapped in TestInput.
    """

    def __init__(self, parallel_tests=None, isolated_tests=None,
                 tests_to_skip=None):

        def promote(tests):
            tests = tests or []
            return [test if isinstance(test, TestInput) else TestInput(test)
                    for test in tests]

        # Run concurrently with up to args.jobs workers.
        self.parallel_tests = promote(parallel_tests)
        # Run one at a time.
        self.isolated_tests = promote(isolated_tests)
        # Reported as skipped without being run.
        self.tests_to_skip = promote(tests_to_skip)


class WinMultiprocessing(object):
    """Allowed values for Runner.win_multiprocessing."""
    ignore = 'ignore'
    importable = 'importable'
    spawn = 'spawn'

    values = [ignore, importable, spawn]


class _AddTestsError(Exception):
    """Raised while collecting tests when a module or load_tests fails."""
    pass


class Runner(object):
    """Discovers, schedules, runs and reports on a set of tests."""

    def __init__(self, host=None):
        self.args = None
        # Optional callable(test_set, test) deciding where a test goes;
        # falls back to _default_classifier(args) when None.
        self.classifier = None
        self.cov = None
        # Object handed to setup_fn in each worker process.
        self.context = None
        self.coverage_source = None
        self.host = host or Host()
        self.loader = unittest.loader.TestLoader()
        self.printer = None
        # Optional per-worker hooks: setup_fn(child, context) and
        # teardown_fn(child, context_after_setup).
        self.setup_fn = None
        self.stats = None
        self.teardown_fn = None
        self.top_level_dir = None
        self.win_multiprocessing = WinMultiprocessing.spawn
        # Values returned by _teardown_process from each worker.
        self.final_responses = []

        # initialize self.args to the defaults.
        parser = ArgumentParser(self.host)
        self.parse_args(parser, [])

    def main(self, argv=None, **defaults):
        """Parse argv and run the tests; return a process exit code.

        Returns the parser's exit status if parsing ended early (e.g. on
        --help or a bad argument), and 130 on KeyboardInterrupt.
        """
        parser = ArgumentParser(self.host)
        self.parse_args(parser, argv, **defaults)
        if parser.exit_status is not None:
            return parser.exit_status

        try:
            ret, _, _ = self.run()
            return ret
        except KeyboardInterrupt:
            self.print_("interrupted, exiting", stream=self.host.stderr)
            return 130

    def parse_args(self, parser, argv, **defaults):
        """Parse argv into self.args, applying **defaults first.

        Every key in `defaults` must already be an attribute of the
        current self.args; otherwise parser.error() is reported and
        self.args is left untouched.
        """
        for attrname in defaults:
            if not hasattr(self.args, attrname):
                parser.error("Unknown default argument name '%s'" % attrname,
                             bailout=False)
                return
        parser.set_defaults(**defaults)
        self.args = parser.parse_args(args=argv)

    def print_(self, msg='', end='\n', stream=None):
        self.host.print_(msg, end, stream=stream)

    def run(self, test_set=None):
        """Run the tests and return (exit_code, full_results, trace).

        When test_set is falsy the tests are discovered from self.args.
        full_results and trace are None when only the version was
        printed or set-up failed; when tests are spawned in a subprocess
        they are whatever that subprocess wrote out.
        """
        ret = 0
        h = self.host

        if self.args.version:
            self.print_(VERSION)
            return ret, None, None

        should_spawn = self._check_win_multiprocessing()
        if should_spawn:
            return self._spawn(test_set)

        ret = self._set_up_runner()
        if ret:  # pragma: no cover
            return ret, None, None

        find_start = h.time()
        if self.cov:  # pragma: no cover
            self.cov.erase()
            self.cov.start()

        full_results = None
        result_set = ResultSet()

        if not test_set:
            ret, test_set = self.find_tests(self.args)
        find_end = h.time()

        if not ret:
            ret, full_results = self._run_tests(result_set, test_set)

        if self.cov:  # pragma: no cover
            self.cov.stop()
            self.cov.save()
        test_end = h.time()

        trace = self._trace_from_results(result_set)
        if full_results:
            self._summarize(full_results)
            self._write(self.args.write_full_results_to, full_results)
            upload_ret = self._upload(full_results)
            # A failed upload only affects the exit code if tests passed.
            if not ret:
                ret = upload_ret
            reporting_end = h.time()
            self._add_trace_event(trace, 'run', find_start, reporting_end)
            self._add_trace_event(trace, 'discovery', find_start, find_end)
            self._add_trace_event(trace, 'testing', find_end, test_end)
            self._add_trace_event(trace, 'reporting', test_end,
                                  reporting_end)
            self._write(self.args.write_trace_to, trace)
            self.report_coverage()

        return ret, full_results, trace

    def _check_win_multiprocessing(self):
        """Return True if the tests must be run in a spawned subprocess.

        Raises:
            ValueError: win_multiprocessing is not a known value, is
                'ignore' while actually on win32, or is 'importable' but
                the __main__ module does not look importable.
        """
        wmp = self.win_multiprocessing

        ignore, importable, spawn = WinMultiprocessing.values

        if wmp not in WinMultiprocessing.values:
            raise ValueError('illegal value %s for win_multiprocessing' %
                             wmp)

        h = self.host
        if wmp == ignore and h.platform == 'win32':  # pragma: win32
            raise ValueError('Cannot use WinMultiprocessing.ignore for '
                             'win_multiprocessing when actually running '
                             'on Windows.')

        if wmp == ignore or self.args.jobs == 1:
            return False

        if wmp == importable:
            if self._main_is_importable():
                return False
            raise ValueError('The __main__ module (%s) '  # pragma: no cover
                             'may not be importable' %
                             sys.modules['__main__'].__file__)

        assert wmp == spawn
        return True

    def _main_is_importable(self):  # pragma: untested
        """Heuristically decide whether __main__ could be re-imported.

        True when __main__'s file is a .py (or .pyc) file not named
        __main__.py whose real path starts with the real path of some
        sys.path entry.
        """
        path = sys.modules['__main__'].__file__
        if not path:
            return False
        if path.endswith('.pyc'):
            path = path[:-1]
        if not path.endswith('.py'):
            return False
        if path.endswith('__main__.py'):
            # main modules are not directly importable.
            return False

        path = self.host.realpath(path)
        for d in sys.path:
            if path.startswith(self.host.realpath(d)):
                return True
        return False  # pragma: no cover

    def _spawn(self, test_set):
        """Re-run this file in a subprocess with the current arguments.

        Temporary files are used for the trace and full results unless
        the user already asked for them to be written somewhere; the
        temporaries are deleted after being read back.
        Returns (exit_code, full_results, trace).
        """
        # TODO: Handle picklable hooks, rather than requiring them to be None.
        assert self.classifier is None
        assert self.context is None
        assert self.setup_fn is None
        assert self.teardown_fn is None
        assert test_set is None
        h = self.host

        if self.args.write_trace_to:  # pragma: untested
            should_delete_trace = False
        else:
            should_delete_trace = True
            fp = h.mktempfile(delete=False)
            fp.close()
            self.args.write_trace_to = fp.name

        if self.args.write_full_results_to:  # pragma: untested
            should_delete_results = False
        else:
            should_delete_results = True
            fp = h.mktempfile(delete=False)
            fp.close()
            self.args.write_full_results_to = fp.name

        argv = ArgumentParser(h).argv_from_args(self.args)
        ret = h.call_inline([h.python_interpreter, path_to_file] + argv)

        trace = self._read_and_delete(self.args.write_trace_to,
                                      should_delete_trace)
        full_results = self._read_and_delete(self.args.write_full_results_to,
                                             should_delete_results)
        return ret, full_results, trace

    def _set_up_runner(self):
        """Create stats/printer, pick top_level_dir, set up sys.path.

        Returns 0 on success and 1 if coverage was requested but the
        coverage module cannot be imported.
        """
        h = self.host
        args = self.args

        self.stats = Stats(args.status_format, h.time, args.jobs)
        self.printer = Printer(
            self.print_, args.overwrite, args.terminal_width)

        self.top_level_dir = args.top_level_dir
        if not self.top_level_dir:
            if args.tests and h.isdir(args.tests[0]):
                # TODO: figure out what to do if multiple files are
                # specified and they don't all have the same correct
                # top level dir.
                d = h.realpath(h.dirname(args.tests[0]))
                if h.exists(d, '__init__.py'):
                    top_dir = d
                else:
                    top_dir = args.tests[0]
            else:
                top_dir = h.getcwd()
            # Walk up out of any package so imports resolve from above it.
            while h.exists(top_dir, '__init__.py'):
                top_dir = h.dirname(top_dir)
            self.top_level_dir = h.realpath(top_dir)

        h.add_to_path(self.top_level_dir)

        for path in args.path:
            h.add_to_path(path)

        if args.coverage:  # pragma: no cover
            try:
                import coverage
            except ImportError:
                h.print_("Error: coverage is not installed")
                return 1
            source = self.args.coverage_source
            if not source:
                source = [self.top_level_dir] + self.args.path
            self.coverage_source = source
            self.cov = coverage.coverage(source=self.coverage_source,
                                         data_suffix=True)
            self.cov.erase()
        return 0

    def find_tests(self, args):
        """Collect and classify the tests named by args.

        With args.all, unittest.skip/skipIf are temporarily replaced by
        no-ops so decorated tests are loaded normally; the originals are
        always restored. Returns (0, TestSet) on success, or (1, None)
        after printing an error if some name failed to load.
        """
        test_set = TestSet()

        orig_skip = unittest.skip
        orig_skip_if = unittest.skipIf
        if args.all:
            unittest.skip = lambda reason: lambda x: x
            unittest.skipIf = lambda condition, reason: lambda x: x

        try:
            names = self._name_list_from_args(args)
            classifier = self.classifier or _default_classifier(args)

            for name in names:
                try:
                    self._add_tests_to_set(test_set, args.suffixes,
                                           self.top_level_dir, classifier,
                                           name)
                except (AttributeError, ImportError, SyntaxError) as e:
                    self.print_('Failed to load "%s": %s' % (name, e))
                    return 1, None
                except _AddTestsError as e:
                    self.print_(str(e))
                    return 1, None

            # TODO: Add support for discovering setupProcess/teardownProcess?

            test_set.parallel_tests = _sort_inputs(test_set.parallel_tests)
            test_set.isolated_tests = _sort_inputs(test_set.isolated_tests)
            test_set.tests_to_skip = _sort_inputs(test_set.tests_to_skip)
            return 0, test_set
        finally:
            unittest.skip = orig_skip
            unittest.skipIf = orig_skip_if

    def _name_list_from_args(self, args):
        """Return the test names/paths to load.

        Source order of preference: args.tests, then args.file_list
        (one name per line; '-' reads stdin), then top_level_dir.
        """
        if args.tests:
            names = args.tests
        elif args.file_list:
            if args.file_list == '-':
                s = self.host.stdin.read()
            else:
                s = self.host.read_text_file(args.file_list)
            names = [line.strip() for line in s.splitlines()]
        else:
            names = [self.top_level_dir]
        return names

    def _add_tests_to_set(self, test_set, suffixes, top_level_dir, classifier,
                          name):
        """Load the tests for one name and classify them into test_set.

        `name` may be a file path, a directory (discovered once per
        suffix pattern), a dotted package name that maps to a directory
        under top_level_dir, or any name loadTestsFromName accepts.
        """
        h = self.host
        loader = self.loader
        add_tests = _test_adder(test_set, classifier)

        if h.isfile(name):
            rpath = h.relpath(name, top_level_dir)
            if rpath.endswith('.py'):
                rpath = rpath[:-3]
            module = rpath.replace(h.sep, '.')
            add_tests(loader.loadTestsFromName(module))
        elif h.isdir(name):
            for suffix in suffixes:
                add_tests(loader.discover(name, suffix, top_level_dir))
        else:
            possible_dir = name.replace('.', h.sep)
            if h.isdir(top_level_dir, possible_dir):
                for suffix in suffixes:
                    path = h.join(top_level_dir, possible_dir)
                    suite = loader.discover(path, suffix, top_level_dir)
                    add_tests(suite)
            else:
                add_tests(loader.loadTestsFromName(name))

    def _run_tests(self, result_set, test_set):
        """Run test_set (plus retries) and build the full results.

        Returns (exit_code, full_results). full_results is None when
        there is nothing to run (exit code 1) or with --list-only
        (exit code 0, names printed). Failed tests are retried one at a
        time up to args.retry_limit times.
        """
        h = self.host
        if not test_set.parallel_tests and not test_set.isolated_tests:
            self.print_('No tests to run.')
            return 1, None

        all_tests = [ti.name for ti in
                     _sort_inputs(test_set.parallel_tests +
                                  test_set.isolated_tests +
                                  test_set.tests_to_skip)]

        if self.args.list_only:
            self.print_('\n'.join(all_tests))
            return 0, None

        self._run_one_set(self.stats, result_set, test_set)

        failed_tests = sorted(json_results.failed_test_names(result_set))
        retry_limit = self.args.retry_limit

        while retry_limit and failed_tests:
            if retry_limit == self.args.retry_limit:
                # First retry: switch to plain, low-verbosity output.
                self.flush()
                self.args.overwrite = False
                self.printer.should_overwrite = False
                self.args.verbose = min(self.args.verbose, 1)

            self.print_('')
            self.print_('Retrying failed tests (attempt #%d of %d)...' %
                        (self.args.retry_limit - retry_limit + 1,
                         self.args.retry_limit))
            self.print_('')

            stats = Stats(self.args.status_format, h.time, 1)
            stats.total = len(failed_tests)
            tests_to_retry = TestSet(isolated_tests=list(failed_tests))
            retry_set = ResultSet()
            self._run_one_set(stats, retry_set, tests_to_retry)
            result_set.results.extend(retry_set.results)
            failed_tests = json_results.failed_test_names(retry_set)
            retry_limit -= 1

        if retry_limit != self.args.retry_limit:
            self.print_('')

        full_results = json_results.make_full_results(self.args.metadata,
                                                      int(h.time()),
                                                      all_tests, result_set)

        return (json_results.exit_code_from_full_results(full_results),
                full_results)

    def _run_one_set(self, stats, result_set, test_set):
        """Skip, then run parallel tests, then run isolated tests."""
        stats.total = (len(test_set.parallel_tests) +
                       len(test_set.isolated_tests) +
                       len(test_set.tests_to_skip))
        self._skip_tests(stats, result_set, test_set.tests_to_skip)
        self._run_list(stats, result_set,
                       test_set.parallel_tests, self.args.jobs)
        self._run_list(stats, result_set,
                       test_set.isolated_tests, 1)

    def _skip_tests(self, stats, result_set, tests_to_skip):
        """Record an expected Skip result for each test without running it."""
        for test_input in tests_to_skip:
            last = self.host.time()
            stats.started += 1
            self._print_test_started(stats, test_input)
            now = self.host.time()
            result = Result(test_input.name, actual=ResultType.Skip,
                            started=last, took=(now - last), worker=0,
                            expected=[ResultType.Skip],
                            out=test_input.msg)
            result_set.add(result)
            stats.finished += 1
            self._print_test_finished(stats, result)

    def _run_list(self, stats, result_set, test_inputs, jobs):
        """Run test_inputs on a worker pool of at most `jobs` workers.

        At most `jobs` tests are in flight at any time, so jobs=1 really
        runs the tests one after another. The caller's test_inputs list
        is not modified. Worker teardown responses are appended to
        self.final_responses even if running raises.
        """
        h = self.host
        running_jobs = set()

        jobs = min(len(test_inputs), jobs)
        if not jobs:
            return

        # Reversed private copy: pop() then yields tests in their
        # original order in O(1) without consuming the caller's list.
        pending = list(reversed(test_inputs))

        child = _Child(self)
        pool = make_pool(h, jobs, _run_one_test, child,
                         _setup_process, _teardown_process)
        try:
            while pending or running_jobs:
                while pending and len(running_jobs) < jobs:
                    test_input = pending.pop()
                    stats.started += 1
                    pool.send(test_input)
                    running_jobs.add(test_input.name)
                    self._print_test_started(stats, test_input)

                result = pool.get()
                running_jobs.remove(result.name)
                result_set.add(result)
                stats.finished += 1
                self._print_test_finished(stats, result)
            pool.close()
        finally:
            self.final_responses.extend(pool.join())

    def _print_test_started(self, stats, test_input):
        if self.args.quiet:
            # Print nothing when --quiet was passed.
            return

        # If -vvv was passed, print when the test is queued to be run.
        # We don't actually know when the test is picked up to run,
        # because that is handled by the child process (where we can't
        # easily print things). Otherwise, only print when the test is
        # started if we know we can overwrite the line, so that we do not
        # get multiple lines of output as noise (in -vvv, we actually want
        # the noise).
        test_start_msg = stats.format() + test_input.name
        if self.args.verbose > 2:
            self.update(test_start_msg + ' queued', elide=False)
        if self.args.overwrite:
            self.update(test_start_msg, elide=(not self.args.verbose))

    def _print_test_finished(self, stats, result):
        """Print the outcome line for a result, plus its output if needed.

        Results with a non-zero code are always printed with their
        captured stdout/stderr; other results are printed unless
        --quiet, with output only at verbose > 1.
        """
        stats.add_time()

        assert result.actual in [ResultType.Failure, ResultType.Skip,
                                 ResultType.Pass]
        if result.actual == ResultType.Failure:
            result_str = ' failed'
        elif result.actual == ResultType.Skip:
            result_str = ' was skipped'
        elif result.actual == ResultType.Pass:
            result_str = ' passed'

        if result.unexpected:
            result_str += ' unexpectedly'
        if self.args.timing:
            timing_str = ' %.4fs' % result.took
        else:
            timing_str = ''
        suffix = '%s%s' % (result_str, timing_str)
        out = result.out
        err = result.err
        if result.code:
            if out or err:
                suffix += ':\n'
            self.update(stats.format() + result.name + suffix, elide=False)
            for l in out.splitlines():
                self.print_(' %s' % l)
            for l in err.splitlines():
                self.print_(' %s' % l)
        elif not self.args.quiet:
            if self.args.verbose > 1 and (out or err):
                suffix += ':\n'
            self.update(stats.format() + result.name + suffix,
                        elide=(not self.args.verbose))
            if self.args.verbose > 1:
                for l in out.splitlines():
                    self.print_(' %s' % l)
                for l in err.splitlines():
                    self.print_(' %s' % l)
            if self.args.verbose:
                self.flush()

    def update(self, msg, elide):
        self.printer.update(msg, elide)

    def flush(self):
        self.printer.flush()

    def _summarize(self, full_results):
        """Print the 'N tests run, M failures.' line.

        Suppressed under --quiet when there are no failures.
        """
        num_tests = self.stats.finished
        num_failures = json_results.num_failures(full_results)

        if self.args.quiet and num_failures == 0:
            return

        if self.args.timing:
            timing_clause = ' in %.1fs' % (self.host.time() -
                                           self.stats.started_time)
        else:
            timing_clause = ''
        self.update('%d test%s run%s, %d failure%s.' %
                    (num_tests,
                     '' if num_tests == 1 else 's',
                     timing_clause,
                     num_failures,
                     '' if num_failures == 1 else 's'), elide=False)
        self.print_()

    def _read_and_delete(self, path, delete):
        """Return the JSON object stored at path (None if absent/empty).

        The file is removed afterwards when `delete` is true.
        """
        h = self.host
        obj = None
        if h.exists(path):
            contents = h.read_text_file(path)
            if contents:
                obj = json.loads(contents)
            if delete:
                h.remove(path)
        return obj

    def _write(self, path, obj):
        """Write obj as indented JSON to path; do nothing if path is empty."""
        if path:
            self.host.write_text_file(path, json.dumps(obj, indent=2) + '\n')

    def _upload(self, full_results):
        """Upload full_results to args.test_results_server, if set.

        Returns 0 on success or when no server is configured, and 1
        (after printing the error) if the upload raised.
        """
        h = self.host
        if not self.args.test_results_server:
            return 0

        url, content_type, data = json_results.make_upload_request(
            self.args.test_results_server, self.args.builder_name,
            self.args.master_name, self.args.test_type,
            full_results)

        try:
            h.fetch(url, data, {'Content-Type': content_type})
            return 0
        except Exception as e:
            h.print_('Uploading the JSON results raised "%s"' % str(e))
            return 1

    def report_coverage(self):
        """Combine per-process coverage data and print the report."""
        if self.args.coverage:  # pragma: no cover
            self.host.print_()
            import coverage
            cov = coverage.coverage(data_suffix=True)
            cov.combine()
            cov.report(show_missing=self.args.coverage_show_missing,
                       omit=self.args.coverage_omit)
            if self.args.coverage_annotate:
                cov.annotate(omit=self.args.coverage_omit)

    def _add_trace_event(self, trace, name, start, end):
        """Append a phase event ('ph': 'X') spanning [start, end].

        Times are converted to integer microseconds, with 'ts' measured
        from self.stats.started_time.
        """
        event = {
            'name': name,
            'ts': int((start - self.stats.started_time) * 1000000),
            'dur': int((end - start) * 1000000),
            'ph': 'X',
            'pid': self.host.getpid(),
            'tid': 0,
        }
        trace['traceEvents'].append(event)

    def _trace_from_results(self, result_set):
        """Build the trace dict: one 'X' event per result plus metadata.

        args.metadata entries are 'key=value' strings; only the first
        '=' separates key from value, so values may contain '='.
        """
        trace = OrderedDict()
        trace['traceEvents'] = []
        trace['otherData'] = {}
        for m in self.args.metadata:
            k, v = m.split('=', 1)
            trace['otherData'][k] = v

        for result in result_set.results:
            started = int((result.started - self.stats.started_time) * 1000000)
            took = int(result.took * 1000000)
            event = OrderedDict()
            event['name'] = result.name
            event['dur'] = took
            event['ts'] = started
            event['ph'] = 'X'  # "Complete" events
            event['pid'] = result.pid
            event['tid'] = result.worker

            args = OrderedDict()
            args['expected'] = sorted(str(r) for r in result.expected)
            args['actual'] = str(result.actual)
            args['out'] = result.out
            args['err'] = result.err
            args['code'] = result.code
            args['unexpected'] = result.unexpected
            args['flaky'] = result.flaky
            event['args'] = args

            trace['traceEvents'].append(event)
        return trace


def _matches(name, globs):
    """Return True if name matches any of the fnmatch-style globs."""
    return any(fnmatch.fnmatch(name, glob) for glob in globs)


def _default_classifier(args):
    """Return a classifier driven by args.skip / args.isolate globs.

    Tests matching args.skip are skipped (unless args.all), tests
    matching args.isolate run isolated, everything else runs in parallel.
    """
    def default_classifier(test_set, test):
        name = test.id()
        if not args.all and _matches(name, args.skip):
            test_set.tests_to_skip.append(TestInput(name,
                                                    'skipped by request'))
        elif _matches(name, args.isolate):
            test_set.isolated_tests.append(TestInput(name))
        else:
            test_set.parallel_tests.append(TestInput(name))
    return default_classifier


# id() prefixes unittest uses for placeholder tests that represent a
# loading failure. Python >= 3.5 uses _FailedTest; older versions use
# ModuleImportFailure / LoadTestsFailure.
_LOAD_FAILURE_PREFIXES = (
    'unittest.loader.LoadTestsFailure',
    'unittest.loader.ModuleImportFailure',
    'unittest.loader._FailedTest',
)


def _test_adder(test_set, classifier):
    """Return a function that flattens a suite and classifies its tests.

    Placeholder tests produced by unittest for import / load_tests
    failures are executed to recover their error, which is re-raised as
    _AddTestsError instead of being scheduled as a real test.
    """
    def add_tests(obj):
        if isinstance(obj, unittest.suite.TestSuite):
            for el in obj:
                add_tests(el)
        elif obj.id().startswith(_LOAD_FAILURE_PREFIXES):
            # Access to protected member pylint: disable=W0212
            module_name = obj._testMethodName
            try:
                method = getattr(obj, obj._testMethodName)
                method()
            except Exception as e:
                if 'LoadTests' in obj.id():
                    raise _AddTestsError('%s.load_tests() failed: %s'
                                         % (module_name, str(e)))
                else:
                    raise _AddTestsError(str(e))
        else:
            assert isinstance(obj, unittest.TestCase)
            classifier(test_set, obj)
    return add_tests


class _Child(object):
    """Per-worker state copied from the parent Runner.

    host and worker_num are filled in by _setup_process inside the
    worker; context_after_setup holds setup_fn's return value.
    """

    def __init__(self, parent):
        self.host = None
        self.worker_num = None
        self.all = parent.args.all
        self.debugger = parent.args.debugger
        # Workers collect their own coverage only when running with
        # more than one job.
        self.coverage = parent.args.coverage and parent.args.jobs > 1
        self.coverage_source = parent.coverage_source
        self.dry_run = parent.args.dry_run
        self.loader = parent.loader
        self.passthrough = parent.args.passthrough
        self.context = parent.context
        self.setup_fn = parent.setup_fn
        self.teardown_fn = parent.teardown_fn
        self.context_after_setup = None
        self.top_level_dir = parent.top_level_dir
        # Cache of module name -> suite (or None) for _load_via_load_tests.
        self.loaded_suites = {}
        self.cov = None


def _setup_process(host, worker_num, child):
    """Initialize `child` inside a worker: coverage and setup_fn hook."""
    child.host = host
    child.worker_num = worker_num
    # pylint: disable=protected-access

    if child.coverage:  # pragma: no cover
        import coverage
        child.cov = coverage.coverage(source=child.coverage_source,
                                      data_suffix=True)
        child.cov._warn_no_data = False
        child.cov.start()

    if child.setup_fn:
        child.context_after_setup = child.setup_fn(child, child.context)
    else:
        child.context_after_setup = child.context
    return child


def _teardown_process(child):
    """Run teardown_fn and save coverage in a worker.

    Returns (worker_num, teardown_result, exception_or_None); an
    exception raised by teardown_fn is returned rather than propagated.
    """
    res = None
    exc = None
    if child.teardown_fn:
        try:
            res = child.teardown_fn(child, child.context_after_setup)
        except Exception as e:
            # Python 3 unbinds `e` when the handler exits, so keep the
            # exception in a separate variable.
            exc = e

    if child.cov:  # pragma: no cover
        child.cov.stop()
        child.cov.save()

    return (child.worker_num, res, exc)


def _run_one_test(child, test_input):
    """Load and run a single test in a worker and return its Result.

    If the name does not resolve to exactly one test, a Failure result
    describing the load error is returned instead.
    """
    h = child.host
    pid = h.getpid()
    test_name = test_input.name

    start = h.time()

    # It is important to capture the output before loading the test
    # to ensure that
    # 1) the loader doesn't log something we don't capture
    # 2) neither the loader nor the test case grab a reference to the
    #    uncaptured stdout or stderr that later is used when the test is run.
    # This comes up when using the FakeTestLoader and testing typ itself,
    # but could come up when testing non-typ code as well.
    h.capture_output(divert=not child.passthrough)

    tb_str = ''
    try:
        orig_skip = unittest.skip
        orig_skip_if = unittest.skipIf
        if child.all:
            unittest.skip = lambda reason: lambda x: x
            unittest.skipIf = lambda condition, reason: lambda x: x

        try:
            suite = child.loader.loadTestsFromName(test_name)
        except Exception:
            try:
                suite = _load_via_load_tests(child, test_name)
            except Exception:  # pragma: untested
                suite = []
                # format_exc() takes an optional limit, not an exception.
                tb_str = traceback.format_exc()
    finally:
        unittest.skip = orig_skip
        unittest.skipIf = orig_skip_if

    tests = list(suite)
    if len(tests) != 1:
        err = 'Failed to load %s' % test_name
        if tb_str:  # pragma: untested
            err += (' (traceback follows):\n %s' %
                    ' \n'.join(tb_str.splitlines()))

        h.restore_output()
        return Result(test_name, ResultType.Failure, start, 0,
                      child.worker_num, unexpected=True, code=1,
                      err=err, pid=pid)

    test_case = tests[0]
    if isinstance(test_case, TypTestCase):
        test_case.child = child
        test_case.context = child.context_after_setup

    test_result = unittest.TestResult()
    out = ''
    err = ''
    try:
        if child.dry_run:
            pass
        elif child.debugger:  # pragma: no cover
            _run_under_debugger(h, test_case, suite, test_result)
        else:
            suite.run(test_result)
    finally:
        out, err = h.restore_output()

    took = h.time() - start
    return _result_from_test_result(test_result, test_name, start, took, out,
                                    err, child.worker_num, pid)


def _run_under_debugger(host, test_case, suite,
                        test_result):  # pragma: no cover
    """Run suite under pdb with a breakpoint at the test method's body."""
    # Access to protected member pylint: disable=W0212
    test_func = getattr(test_case, test_case._testMethodName)
    fname = inspect.getsourcefile(test_func)
    lineno = inspect.getsourcelines(test_func)[1] + 1
    dbg = pdb.Pdb(stdout=host.stdout.stream)
    dbg.set_break(fname, lineno)
    dbg.runcall(suite.run, test_result)


def _result_from_test_result(test_result, test_name, start, took, out, err,
                             worker_num, pid):
    """Convert a unittest.TestResult for one test into a typ Result.

    The first failure/error/skip/expected-failure message is appended
    to `err`.
    """
    flaky = False
    if test_result.failures:
        expected = [ResultType.Pass]
        actual = ResultType.Failure
        code = 1
        unexpected = True
        err = err + test_result.failures[0][1]
    elif test_result.errors:
        expected = [ResultType.Pass]
        actual = ResultType.Failure
        code = 1
        unexpected = True
        err = err + test_result.errors[0][1]
    elif test_result.skipped:
        expected = [ResultType.Skip]
        actual = ResultType.Skip
        err = err + test_result.skipped[0][1]
        code = 0
        unexpected = False
    elif test_result.expectedFailures:
        expected = [ResultType.Failure]
        actual = ResultType.Failure
        code = 1
        err = err + test_result.expectedFailures[0][1]
        unexpected = False
    elif test_result.unexpectedSuccesses:
        expected = [ResultType.Failure]
        actual = ResultType.Pass
        code = 0
        unexpected = True
    else:
        expected = [ResultType.Pass]
        actual = ResultType.Pass
        code = 0
        unexpected = False

    return Result(test_name, actual, start, took, worker_num,
                  expected, unexpected, flaky, code, out, err, pid)


def _load_via_load_tests(child, test_name):
    """Find test_name by loading successively shorter module prefixes.

    If we couldn't import a test directly, the test may be only loadable
    via unittest's load_tests protocol. Each dotted prefix of test_name
    that imports is loaded with loadTestsFromModule (cached in
    child.loaded_suites) and searched for a test whose id() equals
    test_name. Returns a TestSuite of the matches found (possibly empty).
    """
    loader = child.loader
    comps = test_name.split('.')
    new_suite = unittest.TestSuite()

    while comps:
        name = '.'.join(comps)
        module = None
        suite = None
        if name not in child.loaded_suites:
            try:
                module = importlib.import_module(name)
            except ImportError:
                pass
            if module:
                suite = loader.loadTestsFromModule(module)
            # None is cached too, so failed imports aren't retried.
            child.loaded_suites[name] = suite
        suite = child.loaded_suites[name]
        if suite:
            for test_case in suite:
                assert isinstance(test_case, unittest.TestCase)
                if test_case.id() == test_name:
                    new_suite.addTest(test_case)
                    break
        comps.pop()
    return new_suite


def _sort_inputs(inps):
    """Return the TestInputs sorted by name."""
    return sorted(inps, key=lambda inp: inp.name)


if __name__ == '__main__':  # pragma: no cover
    sys.modules['__main__'].__file__ = path_to_file
    sys.exit(main(win_multiprocessing=WinMultiprocessing.importable))