1"""Loading unittests."""
2
3import os
4import re
5import sys
6import traceback
7import types
8import functools
9import warnings
10
11from fnmatch import fnmatch
12
13from . import case, suite, util
14
# Module-level marker; presumably consulted by unittest's result machinery to
# hide this module's frames from reported tracebacks (consumer not shown here).
__unittest = True

# Discovery only considers plain ``.py`` files whose stem starts with a letter
# or underscore followed by word characters (case-insensitive).  Compiled
# files (.pyc etc.) are not matched: supporting them would require avoiding
# loading the same tests twice, once from '.py' and once from '.pyc'.
VALID_MODULE_NAME = re.compile(r'[_a-z]\w*\.py$', flags=re.IGNORECASE)
21
22
23class _FailedTest(case.TestCase):
24    _testMethodName = None
25
26    def __init__(self, method_name, exception):
27        self._exception = exception
28        super(_FailedTest, self).__init__(method_name)
29
30    def __getattr__(self, name):
31        if name != self._testMethodName:
32            return super(_FailedTest, self).__getattr__(name)
33        def testFailure():
34            raise self._exception
35        return testFailure
36
37
def _make_failed_import_test(name, suiteClass):
    """Turn the exception currently being handled into a failing import test.

    Must be called from inside an ``except`` block: the active traceback is
    formatted into the message.  Returns ``(suite, message)`` where the suite
    holds a single test that raises an ImportError carrying that message.
    """
    details = traceback.format_exc()
    message = 'Failed to import test module: %s\n%s' % (name, details)
    failure = ImportError(message)
    return _make_failed_test(name, failure, suiteClass, message)
42
def _make_failed_load_tests(name, exception, suiteClass):
    """Wrap an exception raised by a module's load_tests hook as a failing test.

    Must be called while the exception is being handled, since the active
    traceback is formatted into the message.  Returns ``(suite, message)``.
    """
    formatted = traceback.format_exc()
    return _make_failed_test(
        name, exception, suiteClass,
        'Failed to call load_tests:\n%s' % (formatted,))
47
def _make_failed_test(methodname, exception, suiteClass, message):
    """Build ``(suite, message)`` where the suite runs one test raising *exception*.

    The single test is a _FailedTest whose method is named *methodname*;
    *message* is passed through unchanged for the caller to record.
    """
    failing_case = _FailedTest(methodname, exception)
    wrapped = suiteClass((failing_case,))
    return wrapped, message
51
52def _make_skipped_test(methodname, exception, suiteClass):
53    @case.skip(str(exception))
54    def testSkipped(self):
55        pass
56    attrs = {methodname: testSkipped}
57    TestClass = type("ModuleSkipped", (case.TestCase,), attrs)
58    return suiteClass((TestClass(methodname),))
59
60def _jython_aware_splitext(path):
61    if path.lower().endswith('$py.class'):
62        return path[:-9]
63    return os.path.splitext(path)[0]
64
65
class TestLoader(object):
    """
    This class is responsible for loading tests according to various criteria
    and returning them wrapped in a TestSuite
    """
    # Only attributes whose names start with this prefix are treated as tests.
    testMethodPrefix = 'test'
    # Three-way comparison used to order test method names (falsy: no sort).
    sortTestMethodsUsing = staticmethod(util.three_way_cmp)
    # Callable used to wrap every group of tests this loader returns.
    suiteClass = suite.TestSuite
    # Recorded by discover() so nested discover() calls (e.g. from a
    # package's load_tests) may omit top_level_dir.
    _top_level_dir = None

    def __init__(self):
        super(TestLoader, self).__init__()
        # Messages describing import/load failures; each failure is also
        # returned to the caller as a failing test.
        self.errors = []
        # Tracks packages which we have called into via load_tests, to
        # avoid infinite re-entrancy.
        self._loading_packages = set()

    def loadTestsFromTestCase(self, testCaseClass):
        """Return a suite of all test cases contained in testCaseClass.

        The bare ``TestCase`` and ``FunctionTestCase`` base classes yield an
        empty suite: they define no tests of their own, and instantiating
        ``FunctionTestCase`` for its ``runTest`` method would fail because it
        requires a test function argument.

        Raises TypeError if testCaseClass derives from TestSuite.
        """
        if issubclass(testCaseClass, suite.TestSuite):
            raise TypeError("Test cases should not be derived from "
                            "TestSuite. Maybe you meant to derive from "
                            "TestCase?")
        if testCaseClass in (case.TestCase, case.FunctionTestCase):
            # Base classes commonly imported into test modules; no tests.
            testCaseNames = []
        else:
            testCaseNames = self.getTestCaseNames(testCaseClass)
            if not testCaseNames and hasattr(testCaseClass, 'runTest'):
                testCaseNames = ['runTest']
        loaded_suite = self.suiteClass(map(testCaseClass, testCaseNames))
        return loaded_suite

    # XXX After Python 3.5, remove backward compatibility hacks for
    # use_load_tests deprecation via *args and **kws.  See issue 16662.
    def loadTestsFromModule(self, module, *args, pattern=None, **kws):
        """Return a suite of all test cases contained in the given module.

        Every TestCase subclass among the module's attributes is loaded,
        except the bare TestCase and FunctionTestCase base classes.  If the
        module defines ``load_tests``, it is called as
        ``load_tests(self, tests, pattern)`` and its return value is used
        instead; an exception raised by it is recorded in ``self.errors`` and
        returned as a failing test.
        """
        # This method used to take an undocumented and unofficial
        # use_load_tests argument.  For backward compatibility, we still
        # accept the argument (which can also be the first position) but we
        # ignore it and issue a deprecation warning if it's present.
        if len(args) > 0 or 'use_load_tests' in kws:
            warnings.warn('use_load_tests is deprecated and ignored',
                          DeprecationWarning)
            kws.pop('use_load_tests', None)
        if len(args) > 1:
            # Complain about the number of arguments, but don't forget the
            # required `module` argument.
            complaint = len(args) + 1
            raise TypeError('loadTestsFromModule() takes 1 positional argument but {} were given'.format(complaint))
        if len(kws) != 0:
            # Since the keyword arguments are unsorted (see PEP 468), just
            # pick the alphabetically sorted first argument to complain about,
            # if multiple were given.  At least the error message will be
            # predictable.
            complaint = sorted(kws)[0]
            raise TypeError("loadTestsFromModule() got an unexpected keyword argument '{}'".format(complaint))
        tests = []
        for name in dir(module):
            obj = getattr(module, name)
            if (isinstance(obj, type)
                    and issubclass(obj, case.TestCase)
                    and obj not in (case.TestCase, case.FunctionTestCase)):
                tests.append(self.loadTestsFromTestCase(obj))

        load_tests = getattr(module, 'load_tests', None)
        tests = self.suiteClass(tests)
        if load_tests is not None:
            try:
                return load_tests(self, tests, pattern)
            except Exception as e:
                error_case, error_message = _make_failed_load_tests(
                    module.__name__, e, self.suiteClass)
                self.errors.append(error_message)
                return error_case
        return tests

    def loadTestsFromName(self, name, module=None):
        """Return a suite of all test cases given a string specifier.

        The name may resolve either to a module, a test case class, a
        test method within a test case class, or a callable object which
        returns a TestCase or TestSuite instance.

        The method optionally resolves the names relative to a given module.

        Import and attribute-lookup failures are recorded in ``self.errors``
        and returned as failing tests.  Raises TypeError if the resolved
        object cannot be turned into a test.
        """
        parts = name.split('.')
        error_case, error_message = None, None
        if module is None:
            # Import the longest importable dotted prefix of the name.
            parts_copy = parts[:]
            while parts_copy:
                try:
                    module_name = '.'.join(parts_copy)
                    module = __import__(module_name)
                    break
                except ImportError:
                    next_attribute = parts_copy.pop()
                    # Last error so we can give it to the user if needed.
                    error_case, error_message = _make_failed_import_test(
                        next_attribute, self.suiteClass)
                    if not parts_copy:
                        # Even the top level import failed: report that error.
                        self.errors.append(error_message)
                        return error_case
            # __import__ returns the top-level package, so traverse from there.
            parts = parts[1:]
        obj = module
        for part in parts:
            try:
                parent, obj = obj, getattr(obj, part)
            except AttributeError as e:
                # We can't traverse some part of the name.
                if (getattr(obj, '__path__', None) is not None
                    and error_case is not None):
                    # This is a package (it has a __path__, per the importlib
                    # docs), and we encountered an error importing something.
                    # We cannot tell the difference between
                    # package.WrongNameTestClass and package.wrong_module_name
                    # so we just report the ImportError - it is more
                    # informative.
                    self.errors.append(error_message)
                    return error_case
                else:
                    # Otherwise, we signal that an AttributeError has occurred.
                    error_case, error_message = _make_failed_test(
                        part, e, self.suiteClass,
                        'Failed to access attribute:\n%s' % (
                            traceback.format_exc(),))
                    self.errors.append(error_message)
                    return error_case

        if isinstance(obj, types.ModuleType):
            return self.loadTestsFromModule(obj)
        elif isinstance(obj, type) and issubclass(obj, case.TestCase):
            return self.loadTestsFromTestCase(obj)
        elif (isinstance(obj, types.FunctionType) and
              isinstance(parent, type) and
              issubclass(parent, case.TestCase)):
            name = parts[-1]
            inst = parent(name)
            # static methods follow a different path
            if not isinstance(getattr(inst, name), types.FunctionType):
                return self.suiteClass([inst])
        elif isinstance(obj, suite.TestSuite):
            return obj
        if callable(obj):
            test = obj()
            if isinstance(test, suite.TestSuite):
                return test
            elif isinstance(test, case.TestCase):
                return self.suiteClass([test])
            else:
                raise TypeError("calling %s returned %s, not a test" %
                                (obj, test))
        else:
            raise TypeError("don't know how to make test from: %s" % obj)

    def loadTestsFromNames(self, names, module=None):
        """Return a suite of all test cases found using the given sequence
        of string specifiers. See 'loadTestsFromName()'.
        """
        suites = [self.loadTestsFromName(name, module) for name in names]
        return self.suiteClass(suites)

    def getTestCaseNames(self, testCaseClass):
        """Return a sorted sequence of method names found within testCaseClass
        """
        def isTestMethod(attrname, testCaseClass=testCaseClass,
                         prefix=self.testMethodPrefix):
            return attrname.startswith(prefix) and \
                callable(getattr(testCaseClass, attrname))
        testFnNames = list(filter(isTestMethod, dir(testCaseClass)))
        if self.sortTestMethodsUsing:
            testFnNames.sort(key=functools.cmp_to_key(self.sortTestMethodsUsing))
        return testFnNames

    def discover(self, start_dir, pattern='test*.py', top_level_dir=None):
        """Find and return all test modules from the specified start
        directory, recursing into subdirectories to find them and return all
        tests found within them. Only test files that match the pattern will
        be loaded. (Using shell style pattern matching.)

        All test modules must be importable from the top level of the project.
        If the start directory is not the top level directory then the top
        level directory must be specified separately.

        If a test package name (directory with '__init__.py') matches the
        pattern then the package will be checked for a 'load_tests' function. If
        this exists then it will be called with (loader, tests, pattern) unless
        the package has already had load_tests called from the same discovery
        invocation, in which case the package module object is not scanned for
        tests - this ensures that when a package uses discover to further
        discover child tests that infinite recursion does not happen.

        If load_tests exists then discovery does *not* recurse into the package,
        load_tests is responsible for loading all tests in the package.

        The pattern is deliberately not stored as a loader attribute so that
        packages can continue discovery themselves. top_level_dir is stored so
        load_tests does not need to pass this argument in to loader.discover().

        Paths are sorted before being imported to ensure reproducible execution
        order even on filesystems with non-alphabetical ordering like ext3/4.
        """
        set_implicit_top = False
        if top_level_dir is None and self._top_level_dir is not None:
            # make top_level_dir optional if called from load_tests in a package
            top_level_dir = self._top_level_dir
        elif top_level_dir is None:
            set_implicit_top = True
            top_level_dir = start_dir

        top_level_dir = os.path.abspath(top_level_dir)

        # Remember whether this call added the entry, so that the cleanup
        # below never removes a sys.path entry the caller already had.
        inserted_top_level_dir = False
        if top_level_dir not in sys.path:
            # all test modules must be importable from the top level directory
            # should we *unconditionally* put the start directory in first
            # in sys.path to minimise likelihood of conflicts between installed
            # modules and development versions?
            sys.path.insert(0, top_level_dir)
            inserted_top_level_dir = True
        self._top_level_dir = top_level_dir

        is_not_importable = False
        is_namespace = False
        tests = []
        if os.path.isdir(os.path.abspath(start_dir)):
            start_dir = os.path.abspath(start_dir)
            if start_dir != top_level_dir:
                is_not_importable = not os.path.isfile(os.path.join(start_dir, '__init__.py'))
        else:
            # support for discovery from dotted module names
            try:
                __import__(start_dir)
            except ImportError:
                is_not_importable = True
            else:
                the_module = sys.modules[start_dir]
                top_part = start_dir.split('.')[0]
                try:
                    start_dir = os.path.abspath(
                       os.path.dirname((the_module.__file__)))
                except AttributeError:
                    # look for namespace packages
                    try:
                        spec = the_module.__spec__
                    except AttributeError:
                        spec = None

                    if spec and spec.loader is None:
                        if spec.submodule_search_locations is not None:
                            is_namespace = True

                            for path in the_module.__path__:
                                if (not set_implicit_top and
                                    not path.startswith(top_level_dir)):
                                    continue
                                self._top_level_dir = \
                                    (path.split(the_module.__name__
                                         .replace(".", os.path.sep))[0])
                                tests.extend(self._find_tests(path,
                                                              pattern,
                                                              namespace=True))
                    elif the_module.__name__ in sys.builtin_module_names:
                        # builtin module
                        raise TypeError('Can not use builtin modules '
                                        'as dotted module names') from None
                    else:
                        raise TypeError(
                            'don\'t know how to discover from {!r}'
                            .format(the_module)) from None

                if set_implicit_top:
                    if not is_namespace:
                        self._top_level_dir = \
                           self._get_directory_containing_module(top_part)
                    # The implicit top (derived from the dotted name) is no
                    # longer needed on sys.path; only undo our own insertion.
                    if inserted_top_level_dir:
                        sys.path.remove(top_level_dir)

        if is_not_importable:
            raise ImportError('Start directory is not importable: %r' % start_dir)

        if not is_namespace:
            tests = list(self._find_tests(start_dir, pattern))
        return self.suiteClass(tests)

    def _get_directory_containing_module(self, module_name):
        """Return the directory from which the imported *module_name* is importable.

        For a package (its file is ``__init__.py*``) this is the parent of the
        package directory; for a plain module it is the module's directory.
        """
        module = sys.modules[module_name]
        full_path = os.path.abspath(module.__file__)

        if os.path.basename(full_path).lower().startswith('__init__.py'):
            return os.path.dirname(os.path.dirname(full_path))
        else:
            # here we have been given a module rather than a package - so
            # all we can do is search the *same* directory the module is in
            # should an exception be raised instead
            return os.path.dirname(full_path)

    def _get_name_from_path(self, path):
        """Convert a filesystem path under the top level dir into a dotted name.

        Returns '.' for the top level directory itself.
        """
        if path == self._top_level_dir:
            return '.'
        path = _jython_aware_splitext(os.path.normpath(path))

        _relpath = os.path.relpath(path, self._top_level_dir)
        assert not os.path.isabs(_relpath), "Path must be within the project"
        assert not _relpath.startswith('..'), "Path must be within the project"

        name = _relpath.replace(os.path.sep, '.')
        return name

    def _get_module_from_name(self, name):
        """Import the dotted *name* and return the module object itself."""
        __import__(name)
        return sys.modules[name]

    def _match_path(self, path, full_path, pattern):
        # override this method to use alternative matching strategy
        return fnmatch(path, pattern)

    def _find_tests(self, start_dir, pattern, namespace=False):
        """Used by discovery. Yields test suites it loads."""
        # Handle the __init__ in this package
        name = self._get_name_from_path(start_dir)
        # name is '.' when start_dir == top_level_dir (and top_level_dir is by
        # definition not a package).
        if name != '.' and name not in self._loading_packages:
            # name is in self._loading_packages while we have called into
            # loadTestsFromModule with name.
            tests, should_recurse = self._find_test_path(
                start_dir, pattern, namespace)
            if tests is not None:
                yield tests
            if not should_recurse:
                # Either an error occurred, or load_tests was used by the
                # package.
                return
        # Handle the contents.
        paths = sorted(os.listdir(start_dir))
        for path in paths:
            full_path = os.path.join(start_dir, path)
            tests, should_recurse = self._find_test_path(
                full_path, pattern, namespace)
            if tests is not None:
                yield tests
            if should_recurse:
                # we found a package that didn't use load_tests.
                name = self._get_name_from_path(full_path)
                self._loading_packages.add(name)
                try:
                    yield from self._find_tests(full_path, pattern, namespace)
                finally:
                    self._loading_packages.discard(name)

    def _find_test_path(self, full_path, pattern, namespace=False):
        """Used by discovery.

        Loads tests from a single file, or a directories' __init__.py when
        passed the directory.

        Returns a tuple (None_or_tests_from_file, should_recurse).
        """
        basename = os.path.basename(full_path)
        if os.path.isfile(full_path):
            if not VALID_MODULE_NAME.match(basename):
                # valid Python identifiers only
                return None, False
            if not self._match_path(basename, full_path, pattern):
                return None, False
            # if the test file matches, load it
            name = self._get_name_from_path(full_path)
            try:
                module = self._get_module_from_name(name)
            except case.SkipTest as e:
                return _make_skipped_test(name, e, self.suiteClass), False
            except:
                error_case, error_message = \
                    _make_failed_import_test(name, self.suiteClass)
                self.errors.append(error_message)
                return error_case, False
            else:
                # Guard against a same-named module being imported from a
                # different location (e.g. a globally installed copy).
                mod_file = os.path.abspath(
                    getattr(module, '__file__', full_path))
                realpath = _jython_aware_splitext(
                    os.path.realpath(mod_file))
                fullpath_noext = _jython_aware_splitext(
                    os.path.realpath(full_path))
                if realpath.lower() != fullpath_noext.lower():
                    module_dir = os.path.dirname(realpath)
                    mod_name = _jython_aware_splitext(
                        os.path.basename(full_path))
                    expected_dir = os.path.dirname(full_path)
                    msg = ("%r module incorrectly imported from %r. Expected "
                           "%r. Is this module globally installed?")
                    raise ImportError(
                        msg % (mod_name, module_dir, expected_dir))
                return self.loadTestsFromModule(module, pattern=pattern), False
        elif os.path.isdir(full_path):
            if (not namespace and
                not os.path.isfile(os.path.join(full_path, '__init__.py'))):
                return None, False

            load_tests = None
            tests = None
            name = self._get_name_from_path(full_path)
            try:
                package = self._get_module_from_name(name)
            except case.SkipTest as e:
                return _make_skipped_test(name, e, self.suiteClass), False
            except:
                error_case, error_message = \
                    _make_failed_import_test(name, self.suiteClass)
                self.errors.append(error_message)
                return error_case, False
            else:
                load_tests = getattr(package, 'load_tests', None)
                # Mark this package as being in load_tests (possibly ;))
                self._loading_packages.add(name)
                try:
                    tests = self.loadTestsFromModule(package, pattern=pattern)
                    if load_tests is not None:
                        # loadTestsFromModule(package) has loaded tests for us.
                        return tests, False
                    return tests, True
                finally:
                    self._loading_packages.discard(name)
        else:
            return None, False
484
485
# Module-level TestLoader using the class defaults (prefix 'test',
# util.three_way_cmp ordering, suite.TestSuite wrapping).
defaultTestLoader = TestLoader()
487
488
def _makeLoader(prefix, sortUsing, suiteClass=None):
    """Return a fresh TestLoader configured with the given options.

    *suiteClass* only overrides the loader's default when it is truthy.
    """
    loader = TestLoader()
    loader.testMethodPrefix = prefix
    loader.sortTestMethodsUsing = sortUsing
    if suiteClass:
        loader.suiteClass = suiteClass
    return loader
496
def getTestCaseNames(testCaseClass, prefix, sortUsing=util.three_way_cmp):
    """Return the test method names of testCaseClass matching *prefix*.

    Convenience wrapper around a one-off TestLoader configured with *prefix*
    and *sortUsing*.
    """
    loader = _makeLoader(prefix, sortUsing)
    return loader.getTestCaseNames(testCaseClass)
499
def makeSuite(testCaseClass, prefix='test', sortUsing=util.three_way_cmp,
              suiteClass=suite.TestSuite):
    """Return a suite of the tests found in testCaseClass.

    Builds a one-off TestLoader configured with *prefix*, *sortUsing* and
    *suiteClass* and delegates to its loadTestsFromTestCase().
    """
    loader = _makeLoader(prefix, sortUsing, suiteClass)
    return loader.loadTestsFromTestCase(testCaseClass)
504
def findTestCases(module, prefix='test', sortUsing=util.three_way_cmp,
                  suiteClass=suite.TestSuite):
    """Return a suite of the tests found in *module*.

    Builds a one-off TestLoader configured with *prefix*, *sortUsing* and
    *suiteClass* and delegates to its loadTestsFromModule().
    """
    loader = _makeLoader(prefix, sortUsing, suiteClass)
    return loader.loadTestsFromModule(module)
509