1# Copyright 2015 gRPC authors.
2#
3# Licensed under the Apache License, Version 2.0 (the "License");
4# you may not use this file except in compliance with the License.
5# You may obtain a copy of the License at
6#
7#     http://www.apache.org/licenses/LICENSE-2.0
8#
9# Unless required by applicable law or agreed to in writing, software
10# distributed under the License is distributed on an "AS IS" BASIS,
11# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
12# See the License for the specific language governing permissions and
13# limitations under the License.
14
from __future__ import absolute_import

import importlib
import pkgutil
import re
import sys
import unittest

import coverage
23
# Pattern compiled by Loader and matched (with re.match) against each visited
# module's __name__; only modules whose name ends in "_test" have their tests
# loaded into the suite.
TEST_MODULE_REGEX = r'^.*_test$'
25
26
class Loader(object):
    """Test loader for setuptools test suite support.

    Attributes:
      suite (unittest.TestSuite): All tests collected by the loader.
      loader (unittest.TestLoader): Standard Python unittest loader to be run
        per module discovered.
      module_matcher (compiled regular expression): Matched against module
        names to determine whether or not the discovered module contributes
        to the test suite.
    """

    def __init__(self):
        self.suite = unittest.TestSuite()
        self.loader = unittest.TestLoader()
        self.module_matcher = re.compile(TEST_MODULE_REGEX)

    def loadTestsFromNames(self, names, module=None):
        """Function mirroring TestLoader::loadTestsFromNames, as expected by
        setuptools.setup argument `test_loader`.

        Imports every module named in `names` and visits it. For each one that
        is a package, it also walks its submodules, importing them under their
        fully qualified dotted names.

        Args:
          names (iterable of str): Names of the modules to import and scan.
          module: Accepted for signature compatibility with
            unittest.TestLoader.loadTestsFromNames; unused.

        Returns:
          unittest.TestSuite: `self.suite`, holding every test collected.
        """
        # Ensure that we capture decorators and definitions (else our coverage
        # measure unnecessarily suffers).
        coverage_context = coverage.Coverage(data_suffix=True)
        coverage_context.start()
        # Stop and save coverage even when an import raises, so the tracer is
        # never left running after a failed discovery.
        try:
            imported_modules = tuple(
                importlib.import_module(name) for name in names)
            for imported_module in imported_modules:
                self.visit_module(imported_module)
            for imported_module in imported_modules:
                package_paths = getattr(imported_module, '__path__', None)
                if package_paths is None:
                    # A plain module, not a package: nothing to walk.
                    continue
                # Prefixing with the package name makes every discovered
                # module importable under its real, unique dotted name.
                self.walk_packages(
                    package_paths, prefix=imported_module.__name__ + '.')
        finally:
            coverage_context.stop()
            coverage_context.save()
        return self.suite

    def walk_packages(self, package_paths, prefix=''):
        """Walks over the packages, dispatching `visit_module` calls.

        Args:
          package_paths (list): A list of paths over which to walk through
            modules along.
          prefix (str): String prepended to every discovered module name, for
            example 'tests.' when walking the `tests` package. With the default
            empty prefix, modules are imported under bare top-level names,
            which may clash with unrelated modules of the same name.
        """
        for importer, module_name, _ in pkgutil.walk_packages(
                package_paths, prefix):
            self.visit_module(self._load_module(importer, module_name))

    @staticmethod
    def _load_module(importer, module_name):
        """Returns module `module_name`, importing it through `importer` if needed.

        Modules already present in sys.modules are reused rather than being
        executed a second time. This includes packages that
        pkgutil.walk_packages itself imports while recursing.

        Args:
          importer: The finder yielded by pkgutil.walk_packages for the
            directory containing the module.
          module_name (str): Name under which the module is imported.

        Returns:
          module: The imported module object.

        Raises:
          ImportError: If `importer` cannot locate `module_name`.
        """
        existing_module = sys.modules.get(module_name)
        if existing_module is not None:
            return existing_module
        find_spec = getattr(importer, 'find_spec', None)
        if find_spec is None:
            # Legacy PEP 302 finders (e.g. on Python 2) only provide
            # find_module/load_module.
            return importer.find_module(module_name).load_module(module_name)
        # find_module() was removed from the standard finders in Python 3.12,
        # so use the spec-based import API whenever it is available.
        import importlib.util  # Function-local: importlib.util is Python 3 only.
        spec = find_spec(module_name)
        if spec is None or spec.loader is None:
            raise ImportError('No module named {}'.format(module_name))
        module = importlib.util.module_from_spec(spec)
        # Register before executing, as the import system does, so circular
        # imports see the partially initialized module.
        sys.modules[module_name] = module
        try:
            spec.loader.exec_module(module)
        except BaseException:
            sys.modules.pop(module_name, None)
            raise
        return module

    def visit_module(self, module):
        """Visits the module, adding discovered tests to the test suite.

        Args:
          module (module): Module to match against self.module_matcher; if
            matched it has its tests loaded via self.loader into self.suite.
        """
        if self.module_matcher.match(module.__name__):
            module_suite = self.loader.loadTestsFromModule(module)
            self.suite.addTest(module_suite)
87
88
def iterate_suite_cases(suite):
    """Yields every unittest.TestCase contained in a unittest.TestSuite.

    Nested suites are flattened depth-first, so cases come out in the same
    order in which they appear in `suite`.

    Args:
      suite (unittest.TestSuite): Suite whose test cases should be produced.

    Yields:
      unittest.TestCase: Each test case found in `suite`, in order.

    Raises:
      ValueError: If `suite`, or any suite nested in it, contains an item that
        is neither a unittest.TestSuite nor a unittest.TestCase.
    """
    for entry in suite:
        if isinstance(entry, unittest.TestSuite):
            nested_cases = iterate_suite_cases(entry)
            for nested_case in nested_cases:
                yield nested_case
            continue
        if not isinstance(entry, unittest.TestCase):
            raise ValueError('unexpected suite item of type {}'.format(
                type(entry)))
        yield entry
107