# Copyright 2014 Google Inc. All rights reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
#    http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.

import fnmatch
import importlib
import inspect
import json
import os
import pdb
import sys
import unittest
import traceback

from collections import OrderedDict
# This ensures that absolute imports of typ modules will work when
# running typ/runner.py as a script even if typ is not installed.
# We need this entry in addition to the one in __main__.py to ensure
# that typ/runner.py works when invoked via subprocess on Windows in
# _spawn_main().
path_to_file = os.path.realpath(__file__)
if path_to_file.endswith('.pyc'):  # pragma: no cover
    path_to_file = path_to_file[:-1]
dir_above_typ = os.path.dirname(os.path.dirname(path_to_file))
if dir_above_typ not in sys.path:  # pragma: no cover
    sys.path.append(dir_above_typ)


from typ import json_results
from typ.arg_parser import ArgumentParser
from typ.host import Host
from typ.pool import make_pool
from typ.stats import Stats
from typ.printer import Printer
from typ.test_case import TestCase as TypTestCase
from typ.version import VERSION


Result = json_results.Result
ResultSet = json_results.ResultSet
ResultType = json_results.ResultType


def main(argv=None, host=None, win_multiprocessing=None, **defaults):
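    """Convenience entry point: runs the tests with a fresh Runner.

    Returns an exit code suitable for passing to sys.exit().
    """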
    host = host or Host()
    runner = Runner(host=host)
    if win_multiprocessing is not None:
        runner.win_multiprocessing = win_multiprocessing
    return runner.main(argv, **defaults)


class TestInput(object):
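    """A single test to run: its name plus optional metadata."""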

    def __init__(self, name, msg='', timeout=None, expected=None):
        self.name = name
        self.msg = msg
        self.timeout = timeout
        self.expected = expected


class TestSet(object):
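    """The tests to run, bucketed by how they should be run.

    parallel_tests may run concurrently, isolated_tests run serially,
    and tests_to_skip are reported as skipped without being run.
    """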

    def __init__(self, parallel_tests=None, isolated_tests=None,
                 tests_to_skip=None):

        def promote(tests):
            tests = tests or []
            return [test if isinstance(test, TestInput) else TestInput(test)
                    for test in tests]

        self.parallel_tests = promote(parallel_tests)
        self.isolated_tests = promote(isolated_tests)
        self.tests_to_skip = promote(tests_to_skip)


class WinMultiprocessing(object):
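    """Valid policies for multiprocessing on Windows.

    'ignore' assumes the platform is not actually Windows, 'importable'
    requires the __main__ module to be importable, and 'spawn' re-runs
    the runner in a subprocess when more than one job is requested.
    """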
    ignore = 'ignore'
    importable = 'importable'
    spawn = 'spawn'

    values = [ignore, importable, spawn]


class _AddTestsError(Exception):
    pass


class Runner(object):

    def __init__(self, host=None):
        self.args = None
        self.classifier = None
        self.cov = None
        self.context = None
        self.coverage_source = None
        self.host = host or Host()
        self.loader = unittest.loader.TestLoader()
        self.printer = None
        self.setup_fn = None
        self.stats = None
        self.teardown_fn = None
        self.top_level_dir = None
        self.win_multiprocessing = WinMultiprocessing.spawn
        self.final_responses = []

        # initialize self.args to the defaults.
        parser = ArgumentParser(self.host)
        self.parse_args(parser, [])

    def main(self, argv=None, **defaults):
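        """Parses argv and runs the tests; returns an exit code.

        Returns 130 (128 + SIGINT) if interrupted from the keyboard.
        """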
        parser = ArgumentParser(self.host)
        self.parse_args(parser, argv, **defaults)
        if parser.exit_status is not None:
            return parser.exit_status

        try:
            ret, _, _ = self.run()
            return ret
        except KeyboardInterrupt:
            self.print_("interrupted, exiting", stream=self.host.stderr)
            return 130

    def parse_args(self, parser, argv, **defaults):
        for attrname in defaults:
            if not hasattr(self.args, attrname):
                parser.error("Unknown default argument name '%s'" % attrname,
                             bailout=False)
                return
        parser.set_defaults(**defaults)
        self.args = parser.parse_args(args=argv)
        if parser.exit_status is not None:
            return

    def print_(self, msg='', end='\n', stream=None):
        self.host.print_(msg, end, stream=stream)

    def run(self, test_set=None):
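        """Finds and runs the tests; the main entry point after arg parsing.

        Returns an (exit_code, full_results, trace) tuple, where
        full_results and trace are JSON-serializable dicts (or None if
        nothing was run).
        """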

        ret = 0
        h = self.host

        if self.args.version:
            self.print_(VERSION)
            return ret, None, None

        should_spawn = self._check_win_multiprocessing()
        if should_spawn:
            return self._spawn(test_set)

        ret = self._set_up_runner()
        if ret:  # pragma: no cover
            return ret, None, None

        find_start = h.time()
        if self.cov:  # pragma: no cover
            self.cov.erase()
            self.cov.start()

        full_results = None
        result_set = ResultSet()

        if not test_set:
            ret, test_set = self.find_tests(self.args)
        find_end = h.time()

        if not ret:
            ret, full_results = self._run_tests(result_set, test_set)

        if self.cov:  # pragma: no cover
            self.cov.stop()
            self.cov.save()
        test_end = h.time()

        trace = self._trace_from_results(result_set)
        if full_results:
            self._summarize(full_results)
            self._write(self.args.write_full_results_to, full_results)
            upload_ret = self._upload(full_results)
            if not ret:
                ret = upload_ret
            reporting_end = h.time()
            self._add_trace_event(trace, 'run', find_start, reporting_end)
            self._add_trace_event(trace, 'discovery', find_start, find_end)
            self._add_trace_event(trace, 'testing', find_end, test_end)
            self._add_trace_event(trace, 'reporting', test_end, reporting_end)
            self._write(self.args.write_trace_to, trace)
            self.report_coverage()
        else:
            upload_ret = 0

        return ret, full_results, trace

    def _check_win_multiprocessing(self):
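        """Validates self.win_multiprocessing against the platform.

        Returns True if the runner needs to re-spawn itself in a
        subprocess in order to use multiple processes safely.
        """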
        wmp = self.win_multiprocessing

        ignore, importable, spawn = WinMultiprocessing.values

        if wmp not in WinMultiprocessing.values:
            raise ValueError('illegal value %s for win_multiprocessing' %
                             wmp)

        h = self.host
        if wmp == ignore and h.platform == 'win32':  # pragma: win32
            raise ValueError('Cannot use WinMultiprocessing.ignore for '
                             'win_multiprocessing when actually running '
                             'on Windows.')

        if wmp == ignore or self.args.jobs == 1:
            return False

        if wmp == importable:
            if self._main_is_importable():
                return False
            raise ValueError('The __main__ module (%s) '  # pragma: no cover
                             'may not be importable' %
                             sys.modules['__main__'].__file__)

        assert wmp == spawn
        return True

    def _main_is_importable(self):  # pragma: untested
        path = sys.modules['__main__'].__file__
        if not path:
            return False
        if path.endswith('.pyc'):
            path = path[:-1]
        if not path.endswith('.py'):
            return False
        if path.endswith('__main__.py'):
            # main modules are not directly importable.
            return False

        path = self.host.realpath(path)
        for d in sys.path:
            if path.startswith(self.host.realpath(d)):
                return True
        return False  # pragma: no cover

    def _spawn(self, test_set):
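        """Re-runs the tests by re-invoking this script in a subprocess.

        This sidesteps the requirement on Windows that the __main__
        module be importable by worker processes. The results and trace
        are passed back through (temporary) files.
        """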
        # TODO: Handle picklable hooks, rather than requiring them to be None.
        assert self.classifier is None
        assert self.context is None
        assert self.setup_fn is None
        assert self.teardown_fn is None
        assert test_set is None
        h = self.host

        if self.args.write_trace_to:  # pragma: untested
            should_delete_trace = False
        else:
            should_delete_trace = True
            fp = h.mktempfile(delete=False)
            fp.close()
            self.args.write_trace_to = fp.name

        if self.args.write_full_results_to:  # pragma: untested
            should_delete_results = False
        else:
            should_delete_results = True
            fp = h.mktempfile(delete=False)
            fp.close()
            self.args.write_full_results_to = fp.name

        argv = ArgumentParser(h).argv_from_args(self.args)
        ret = h.call_inline([h.python_interpreter, path_to_file] + argv)

        trace = self._read_and_delete(self.args.write_trace_to,
                                      should_delete_trace)
        full_results = self._read_and_delete(self.args.write_full_results_to,
                                             should_delete_results)
        return ret, full_results, trace

    def _set_up_runner(self):
        h = self.host
        args = self.args

        self.stats = Stats(args.status_format, h.time, args.jobs)
        self.printer = Printer(
            self.print_, args.overwrite, args.terminal_width)

        self.top_level_dir = args.top_level_dir
        if not self.top_level_dir:
            if args.tests and h.isdir(args.tests[0]):
                # TODO: figure out what to do if multiple files are
                # specified and they don't all have the same correct
                # top level dir.
                d = h.realpath(h.dirname(args.tests[0]))
                if h.exists(d, '__init__.py'):
                    top_dir = d
                else:
                    top_dir = args.tests[0]
            else:
                top_dir = h.getcwd()
            while h.exists(top_dir, '__init__.py'):
                top_dir = h.dirname(top_dir)
            self.top_level_dir = h.realpath(top_dir)

        h.add_to_path(self.top_level_dir)

        for path in args.path:
            h.add_to_path(path)

        if args.coverage:  # pragma: no cover
            try:
                import coverage
            except ImportError:
                h.print_("Error: coverage is not installed")
                return 1
            source = self.args.coverage_source
            if not source:
                source = [self.top_level_dir] + self.args.path
            self.coverage_source = source
            self.cov = coverage.coverage(source=self.coverage_source,
                                         data_suffix=True)
            self.cov.erase()
        return 0

    def find_tests(self, args):
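        """Discovers the tests matching args; returns (exit_code, TestSet).

        The TestSet is None (and the exit code nonzero) if discovery fails.
        """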
        test_set = TestSet()

        orig_skip = unittest.skip
        orig_skip_if = unittest.skipIf
        if args.all:
            unittest.skip = lambda reason: lambda x: x
            unittest.skipIf = lambda condition, reason: lambda x: x

        try:
            names = self._name_list_from_args(args)
            classifier = self.classifier or _default_classifier(args)

            for name in names:
                try:
                    self._add_tests_to_set(test_set, args.suffixes,
                                           self.top_level_dir, classifier,
                                           name)
                except (AttributeError, ImportError, SyntaxError) as e:
                    self.print_('Failed to load "%s": %s' % (name, e))
                    return 1, None
                except _AddTestsError as e:
                    self.print_(str(e))
                    return 1, None

            # TODO: Add support for discovering setupProcess/teardownProcess?

            test_set.parallel_tests = _sort_inputs(test_set.parallel_tests)
            test_set.isolated_tests = _sort_inputs(test_set.isolated_tests)
            test_set.tests_to_skip = _sort_inputs(test_set.tests_to_skip)
            return 0, test_set
        finally:
            unittest.skip = orig_skip
            unittest.skipIf = orig_skip_if

    def _name_list_from_args(self, args):
        if args.tests:
            names = args.tests
        elif args.file_list:
            if args.file_list == '-':
                s = self.host.stdin.read()
            else:
                s = self.host.read_text_file(args.file_list)
            names = [line.strip() for line in s.splitlines()]
        else:
            names = [self.top_level_dir]
        return names

    def _add_tests_to_set(self, test_set, suffixes, top_level_dir, classifier,
                          name):
        h = self.host
        loader = self.loader
        add_tests = _test_adder(test_set, classifier)

        if h.isfile(name):
            rpath = h.relpath(name, top_level_dir)
            if rpath.endswith('.py'):
                rpath = rpath[:-3]
            module = rpath.replace(h.sep, '.')
            add_tests(loader.loadTestsFromName(module))
        elif h.isdir(name):
            for suffix in suffixes:
                add_tests(loader.discover(name, suffix, top_level_dir))
        else:
            possible_dir = name.replace('.', h.sep)
            if h.isdir(top_level_dir, possible_dir):
                for suffix in suffixes:
                    path = h.join(top_level_dir, possible_dir)
                    suite = loader.discover(path, suffix, top_level_dir)
                    add_tests(suite)
            else:
                add_tests(loader.loadTestsFromName(name))

    def _run_tests(self, result_set, test_set):
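        """Runs the test set, retrying failures up to args.retry_limit times.

        Returns (exit_code, full_results); retried tests are re-run
        serially, in isolation.
        """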
        h = self.host
        if not test_set.parallel_tests and not test_set.isolated_tests:
            self.print_('No tests to run.')
            return 1, None

        all_tests = [ti.name for ti in
                     _sort_inputs(test_set.parallel_tests +
                                  test_set.isolated_tests +
                                  test_set.tests_to_skip)]

        if self.args.list_only:
            self.print_('\n'.join(all_tests))
            return 0, None

        self._run_one_set(self.stats, result_set, test_set)

        failed_tests = sorted(json_results.failed_test_names(result_set))
        retry_limit = self.args.retry_limit

        while retry_limit and failed_tests:
            if retry_limit == self.args.retry_limit:
                self.flush()
                self.args.overwrite = False
                self.printer.should_overwrite = False
                self.args.verbose = min(self.args.verbose, 1)

            self.print_('')
            self.print_('Retrying failed tests (attempt #%d of %d)...' %
                        (self.args.retry_limit - retry_limit + 1,
                         self.args.retry_limit))
            self.print_('')

            stats = Stats(self.args.status_format, h.time, 1)
            stats.total = len(failed_tests)
            tests_to_retry = TestSet(isolated_tests=list(failed_tests))
            retry_set = ResultSet()
            self._run_one_set(stats, retry_set, tests_to_retry)
            result_set.results.extend(retry_set.results)
            failed_tests = json_results.failed_test_names(retry_set)
            retry_limit -= 1

        if retry_limit != self.args.retry_limit:
            self.print_('')

        full_results = json_results.make_full_results(self.args.metadata,
                                                      int(h.time()),
                                                      all_tests, result_set)

        return (json_results.exit_code_from_full_results(full_results),
                full_results)

    def _run_one_set(self, stats, result_set, test_set):
        stats.total = (len(test_set.parallel_tests) +
                       len(test_set.isolated_tests) +
                       len(test_set.tests_to_skip))
        self._skip_tests(stats, result_set, test_set.tests_to_skip)
        self._run_list(stats, result_set,
                       test_set.parallel_tests, self.args.jobs)
        self._run_list(stats, result_set,
                       test_set.isolated_tests, 1)

    def _skip_tests(self, stats, result_set, tests_to_skip):
        for test_input in tests_to_skip:
            last = self.host.time()
            stats.started += 1
            self._print_test_started(stats, test_input)
            now = self.host.time()
            result = Result(test_input.name, actual=ResultType.Skip,
                            started=last, took=(now - last), worker=0,
                            expected=[ResultType.Skip],
                            out=test_input.msg)
            result_set.add(result)
            stats.finished += 1
            self._print_test_finished(stats, result)

    def _run_list(self, stats, result_set, test_inputs, jobs):
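        """Runs test_inputs on a pool of at most `jobs` worker processes,
        keeping the pool full and collecting results as they complete."""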
        h = self.host
        running_jobs = set()

        jobs = min(len(test_inputs), jobs)
        if not jobs:
            return

        child = _Child(self)
        pool = make_pool(h, jobs, _run_one_test, child,
                         _setup_process, _teardown_process)
        try:
            while test_inputs or running_jobs:
                while test_inputs and (len(running_jobs) < self.args.jobs):
                    test_input = test_inputs.pop(0)
                    stats.started += 1
                    pool.send(test_input)
                    running_jobs.add(test_input.name)
                    self._print_test_started(stats, test_input)

                result = pool.get()
                running_jobs.remove(result.name)
                result_set.add(result)
                stats.finished += 1
                self._print_test_finished(stats, result)
            pool.close()
        finally:
            self.final_responses.extend(pool.join())

    def _print_test_started(self, stats, test_input):
        if self.args.quiet:
            # Print nothing when --quiet was passed.
            return

        # If -vvv was passed, print when the test is queued to be run.
        # We don't actually know when the test is picked up to run,
        # because that is handled by the child process (where we can't
        # easily print things). Otherwise, only print when the test is
        # started if we know we can overwrite the line, so that we do
        # not get multiple lines of output as noise (in -vvv, we
        # actually want the noise).
        test_start_msg = stats.format() + test_input.name
        if self.args.verbose > 2:
            self.update(test_start_msg + ' queued', elide=False)
        if self.args.overwrite:
            self.update(test_start_msg, elide=(not self.args.verbose))

    def _print_test_finished(self, stats, result):
        stats.add_time()

        assert result.actual in [ResultType.Failure, ResultType.Skip,
                                 ResultType.Pass]
        if result.actual == ResultType.Failure:
            result_str = ' failed'
        elif result.actual == ResultType.Skip:
            result_str = ' was skipped'
        elif result.actual == ResultType.Pass:
            result_str = ' passed'

        if result.unexpected:
            result_str += ' unexpectedly'
        if self.args.timing:
            timing_str = ' %.4fs' % result.took
        else:
            timing_str = ''
        suffix = '%s%s' % (result_str, timing_str)
        out = result.out
        err = result.err
        if result.code:
            if out or err:
                suffix += ':\n'
            self.update(stats.format() + result.name + suffix, elide=False)
            for l in out.splitlines():
                self.print_('  %s' % l)
            for l in err.splitlines():
                self.print_('  %s' % l)
        elif not self.args.quiet:
            if self.args.verbose > 1 and (out or err):
                suffix += ':\n'
            self.update(stats.format() + result.name + suffix,
                        elide=(not self.args.verbose))
            if self.args.verbose > 1:
                for l in out.splitlines():
                    self.print_('  %s' % l)
                for l in err.splitlines():
                    self.print_('  %s' % l)
            if self.args.verbose:
                self.flush()

    def update(self, msg, elide):
        self.printer.update(msg, elide)

    def flush(self):
        self.printer.flush()

    def _summarize(self, full_results):
        num_tests = self.stats.finished
        num_failures = json_results.num_failures(full_results)

        if self.args.quiet and num_failures == 0:
            return

        if self.args.timing:
            timing_clause = ' in %.1fs' % (self.host.time() -
                                           self.stats.started_time)
        else:
            timing_clause = ''
        self.update('%d test%s run%s, %d failure%s.' %
                    (num_tests,
                     '' if num_tests == 1 else 's',
                     timing_clause,
                     num_failures,
                     '' if num_failures == 1 else 's'), elide=False)
        self.print_()

    def _read_and_delete(self, path, delete):
        h = self.host
        obj = None
        if h.exists(path):
            contents = h.read_text_file(path)
            if contents:
                obj = json.loads(contents)
            if delete:
                h.remove(path)
        return obj

    def _write(self, path, obj):
        if path:
            self.host.write_text_file(path, json.dumps(obj, indent=2) + '\n')

    def _upload(self, full_results):
        h = self.host
        if not self.args.test_results_server:
            return 0

        url, content_type, data = json_results.make_upload_request(
            self.args.test_results_server, self.args.builder_name,
            self.args.master_name, self.args.test_type,
            full_results)

        try:
            h.fetch(url, data, {'Content-Type': content_type})
            return 0
        except Exception as e:
            h.print_('Uploading the JSON results raised "%s"' % str(e))
            return 1

    def report_coverage(self):
        if self.args.coverage:  # pragma: no cover
            self.host.print_()
            import coverage
            cov = coverage.coverage(data_suffix=True)
            cov.combine()
            cov.report(show_missing=self.args.coverage_show_missing,
                       omit=self.args.coverage_omit)
            if self.args.coverage_annotate:
                cov.annotate(omit=self.args.coverage_omit)

    def _add_trace_event(self, trace, name, start, end):
        event = {
            'name': name,
            'ts': int((start - self.stats.started_time) * 1000000),
            'dur': int((end - start) * 1000000),
            'ph': 'X',
            'pid': self.host.getpid(),
            'tid': 0,
        }
        trace['traceEvents'].append(event)

    def _trace_from_results(self, result_set):
        trace = OrderedDict()
        trace['traceEvents'] = []
        trace['otherData'] = {}
        for m in self.args.metadata:
            k, v = m.split('=')
            trace['otherData'][k] = v

        for result in result_set.results:
            started = int((result.started - self.stats.started_time) * 1000000)
            took = int(result.took * 1000000)
            event = OrderedDict()
            event['name'] = result.name
            event['dur'] = took
            event['ts'] = started
            event['ph'] = 'X'  # "Complete" events
            event['pid'] = result.pid
            event['tid'] = result.worker

            args = OrderedDict()
            args['expected'] = sorted(str(r) for r in result.expected)
            args['actual'] = str(result.actual)
            args['out'] = result.out
            args['err'] = result.err
            args['code'] = result.code
            args['unexpected'] = result.unexpected
            args['flaky'] = result.flaky
            event['args'] = args

            trace['traceEvents'].append(event)
        return trace


def _matches(name, globs):
    return any(fnmatch.fnmatch(name, glob) for glob in globs)


def _default_classifier(args):
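    """Returns a classifier that buckets each test into the skipped,
    isolated, or parallel set based on the --skip and --isolate globs."""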
    def default_classifier(test_set, test):
        name = test.id()
        if not args.all and _matches(name, args.skip):
            test_set.tests_to_skip.append(TestInput(name,
                                                    'skipped by request'))
        elif _matches(name, args.isolate):
            test_set.isolated_tests.append(TestInput(name))
        else:
            test_set.parallel_tests.append(TestInput(name))
    return default_classifier


def _test_adder(test_set, classifier):
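    """Returns a function that recursively walks a test or suite,
    classifying each test case into test_set and surfacing loader
    failures as _AddTestsError."""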
    def add_tests(obj):
        if isinstance(obj, unittest.suite.TestSuite):
            for el in obj:
                add_tests(el)
        elif (obj.id().startswith('unittest.loader.LoadTestsFailure') or
              obj.id().startswith('unittest.loader.ModuleImportFailure')):
            # Access to protected member pylint: disable=W0212
            module_name = obj._testMethodName
            try:
                method = getattr(obj, obj._testMethodName)
                method()
            except Exception as e:
                if 'LoadTests' in obj.id():
                    raise _AddTestsError('%s.load_tests() failed: %s'
                                         % (module_name, str(e)))
                else:
                    raise _AddTestsError(str(e))
        else:
            assert isinstance(obj, unittest.TestCase)
            classifier(test_set, obj)
    return add_tests


class _Child(object):
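    """A picklable snapshot of the runner state that is shipped to
    worker processes."""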

    def __init__(self, parent):
        self.host = None
        self.worker_num = None
        self.all = parent.args.all
        self.debugger = parent.args.debugger
        self.coverage = parent.args.coverage and parent.args.jobs > 1
        self.coverage_source = parent.coverage_source
        self.dry_run = parent.args.dry_run
        self.loader = parent.loader
        self.passthrough = parent.args.passthrough
        self.context = parent.context
        self.setup_fn = parent.setup_fn
        self.teardown_fn = parent.teardown_fn
        self.context_after_setup = None
        self.top_level_dir = parent.top_level_dir
        self.loaded_suites = {}
        self.cov = None


def _setup_process(host, worker_num, child):
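    """Per-worker initialization: binds the host and worker number,
    starts coverage collection if requested, and runs the user-supplied
    setup hook to produce the context passed to each test."""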
    child.host = host
    child.worker_num = worker_num
    # pylint: disable=protected-access

    if child.coverage:  # pragma: no cover
        import coverage
        child.cov = coverage.coverage(source=child.coverage_source,
                                      data_suffix=True)
        child.cov._warn_no_data = False
        child.cov.start()

    if child.setup_fn:
        child.context_after_setup = child.setup_fn(child, child.context)
    else:
        child.context_after_setup = child.context
    return child


def _teardown_process(child):
    res = None
    e = None
    if child.teardown_fn:
        try:
            res = child.teardown_fn(child, child.context_after_setup)
        except Exception as exc:
            # 'except ... as' unbinds its target when the block exits in
            # Python 3, so stash the exception in an outer variable.
            e = exc

    if child.cov:  # pragma: no cover
        child.cov.stop()
        child.cov.save()

    return (child.worker_num, res, e)


def _run_one_test(child, test_input):
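    """Loads and runs a single test in a worker process, capturing its
    stdout/stderr; returns a Result describing the outcome."""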
    h = child.host
    pid = h.getpid()
    test_name = test_input.name

    start = h.time()

    # It is important to capture the output before loading the test
    # to ensure that:
    # 1) the loader doesn't log something we fail to capture, and
    # 2) neither the loader nor the test case grabs a reference to the
    #    uncaptured stdout or stderr that is later used when the test runs.
    # This comes up when using the FakeTestLoader and testing typ itself,
    # but could come up when testing non-typ code as well.
    h.capture_output(divert=not child.passthrough)

    tb_str = ''
    try:
        orig_skip = unittest.skip
        orig_skip_if = unittest.skipIf
        if child.all:
            unittest.skip = lambda reason: lambda x: x
            unittest.skipIf = lambda condition, reason: lambda x: x

        try:
            suite = child.loader.loadTestsFromName(test_name)
        except Exception:
            try:
                suite = _load_via_load_tests(child, test_name)
            except Exception:  # pragma: untested
                suite = []
                tb_str = traceback.format_exc()
    finally:
        unittest.skip = orig_skip
        unittest.skipIf = orig_skip_if

    tests = list(suite)
    if len(tests) != 1:
        err = 'Failed to load %s' % test_name
        if tb_str:  # pragma: untested
            err += (' (traceback follows):\n  %s' %
                    '\n  '.join(tb_str.splitlines()))

        h.restore_output()
        return Result(test_name, ResultType.Failure, start, 0,
                      child.worker_num, unexpected=True, code=1,
                      err=err, pid=pid)

    test_case = tests[0]
    if isinstance(test_case, TypTestCase):
        test_case.child = child
        test_case.context = child.context_after_setup

    test_result = unittest.TestResult()
    out = ''
    err = ''
    try:
        if child.dry_run:
            pass
        elif child.debugger:  # pragma: no cover
            _run_under_debugger(h, test_case, suite, test_result)
        else:
            suite.run(test_result)
    finally:
        out, err = h.restore_output()

    took = h.time() - start
    return _result_from_test_result(test_result, test_name, start, took, out,
                                    err, child.worker_num, pid)


def _run_under_debugger(host, test_case, suite,
                        test_result):  # pragma: no cover
    # Access to protected member pylint: disable=W0212
    test_func = getattr(test_case, test_case._testMethodName)
    fname = inspect.getsourcefile(test_func)
    lineno = inspect.getsourcelines(test_func)[1] + 1
    dbg = pdb.Pdb(stdout=host.stdout.stream)
    dbg.set_break(fname, lineno)
    dbg.runcall(suite.run, test_result)


def _result_from_test_result(test_result, test_name, start, took, out, err,
                             worker_num, pid):
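    """Maps a unittest.TestResult onto a single typ Result, folding the
    first failure/error/skip message into the captured stderr."""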
    flaky = False
    if test_result.failures:
        expected = [ResultType.Pass]
        actual = ResultType.Failure
        code = 1
        unexpected = True
        err = err + test_result.failures[0][1]
    elif test_result.errors:
        expected = [ResultType.Pass]
        actual = ResultType.Failure
        code = 1
        unexpected = True
        err = err + test_result.errors[0][1]
    elif test_result.skipped:
        expected = [ResultType.Skip]
        actual = ResultType.Skip
        err = err + test_result.skipped[0][1]
        code = 0
        unexpected = False
    elif test_result.expectedFailures:
        expected = [ResultType.Failure]
        actual = ResultType.Failure
        code = 1
        err = err + test_result.expectedFailures[0][1]
        unexpected = False
    elif test_result.unexpectedSuccesses:
        expected = [ResultType.Failure]
        actual = ResultType.Pass
        code = 0
        unexpected = True
    else:
        expected = [ResultType.Pass]
        actual = ResultType.Pass
        code = 0
        unexpected = False

    return Result(test_name, actual, start, took, worker_num,
                  expected, unexpected, flaky, code, out, err, pid)


def _load_via_load_tests(child, test_name):
    # If we couldn't import a test directly, the test may only be
    # loadable via unittest's load_tests protocol. See if we can find a
    # load_tests entry point that will work for this test.
    loader = child.loader
    comps = test_name.split('.')
    new_suite = unittest.TestSuite()

    while comps:
        name = '.'.join(comps)
        module = None
        suite = None
        if name not in child.loaded_suites:
            try:
                module = importlib.import_module(name)
            except ImportError:
                pass
            if module:
                suite = loader.loadTestsFromModule(module)
            child.loaded_suites[name] = suite
        suite = child.loaded_suites[name]
        if suite:
            for test_case in suite:
                assert isinstance(test_case, unittest.TestCase)
                if test_case.id() == test_name:
                    new_suite.addTest(test_case)
                    break
        comps.pop()
    return new_suite


def _sort_inputs(inps):
    return sorted(inps, key=lambda inp: inp.name)


if __name__ == '__main__':  # pragma: no cover
    sys.modules['__main__'].__file__ = path_to_file
    sys.exit(main(win_multiprocessing=WinMultiprocessing.importable))