1"""Loading unittests."""
2
3import os
4import re
5import sys
6import traceback
7import types
8
9from functools import cmp_to_key as _CmpToKey
10from fnmatch import fnmatch
11
12from . import case, suite
13
# NOTE(review): this flag is not read anywhere in this file; presumably it is
# a marker inspected by other modules of the package -- confirm before
# removing or renaming it.
__unittest = True
15
# File names accepted as test modules: a Python identifier followed by '.py'
# (matched case-insensitively).
# TODO: what about .pyc or .pyo (etc.)? Supporting them would require
# avoiding loading the same tests multiple times from '.py', '.pyc'
# *and* '.pyo'.
VALID_MODULE_NAME = re.compile(r'[_a-z]\w*\.py$', re.IGNORECASE)
20
21
def _make_failed_import_test(name, suiteClass):
    """Build a one-test suite that reports a failed import of module `name`.

    Must be called while the import exception is being handled: the text of
    the active traceback is embedded in the ImportError the test raises.
    """
    details = traceback.format_exc()
    message = 'Failed to import test module: %s\n%s' % (name, details)
    import_error = ImportError(message)
    return _make_failed_test('ModuleImportFailure', name, import_error,
                             suiteClass)
26
def _make_failed_load_tests(name, exception, suiteClass):
    """Build a one-test suite that re-raises `exception` from a load_tests hook.

    `name` becomes the name of the failing test method.
    """
    failure_suite = _make_failed_test('LoadTestsFailure', name, exception,
                                      suiteClass)
    return failure_suite
29
def _make_failed_test(classname, methodname, exception, suiteClass):
    """Return suiteClass wrapping one synthetic test that raises `exception`.

    A TestCase subclass called `classname` is created on the fly with a
    single method `methodname`; running that method raises `exception`.
    """
    def testFailure(self):
        raise exception

    TestClass = type(classname, (case.TestCase,), {methodname: testFailure})
    failing_test = TestClass(methodname)
    return suiteClass((failing_test,))
36
37
class TestLoader(object):
    """
    This class is responsible for loading tests according to various criteria
    and returning them wrapped in a TestSuite
    """
    # Only callable attributes whose name starts with this prefix are
    # collected by getTestCaseNames().
    testMethodPrefix = 'test'
    # cmp-style function used to order test method names; a false value
    # leaves the names in the order produced by dir().
    sortTestMethodsUsing = cmp
    # Factory called with an iterable of tests to build every returned suite.
    suiteClass = suite.TestSuite
    # Absolute top level directory remembered by discover() so that nested
    # discover() calls (e.g. from a package's load_tests) can omit it.
    _top_level_dir = None

    def loadTestsFromTestCase(self, testCaseClass):
        """Return a suite of all tests cases contained in testCaseClass.

        When the class has no method matching testMethodPrefix but does
        define 'runTest', the suite holds that single test instead.

        Raises TypeError if testCaseClass is derived from TestSuite.
        """
        if issubclass(testCaseClass, suite.TestSuite):
            raise TypeError("Test cases should not be derived from TestSuite."
                            " Maybe you meant to derive from TestCase?")
        testCaseNames = self.getTestCaseNames(testCaseClass)
        if not testCaseNames and hasattr(testCaseClass, 'runTest'):
            testCaseNames = ['runTest']
        return self.suiteClass(map(testCaseClass, testCaseNames))

    def loadTestsFromModule(self, module, use_load_tests=True):
        """Return a suite of all tests cases contained in the given module.

        Every TestCase subclass found through dir(module) is loaded.  If
        use_load_tests is true and the module defines 'load_tests', that hook
        is called as load_tests(loader, tests, None) and its result is
        returned; an exception raised by the hook is reported as a suite
        holding a single failing test instead of being propagated.
        """
        tests = []
        for name in dir(module):
            obj = getattr(module, name)
            if isinstance(obj, type) and issubclass(obj, case.TestCase):
                tests.append(self.loadTestsFromTestCase(obj))

        load_tests = getattr(module, 'load_tests', None)
        tests = self.suiteClass(tests)
        if use_load_tests and load_tests is not None:
            try:
                return load_tests(self, tests, None)
            except Exception as e:
                return _make_failed_load_tests(module.__name__, e,
                                               self.suiteClass)
        return tests

    def loadTestsFromName(self, name, module=None):
        """Return a suite of all tests cases given a string specifier.

        The name may resolve either to a module, a test case class, a
        test method within a test case class, or a callable object which
        returns a TestCase or TestSuite instance.

        The method optionally resolves the names relative to a given module.

        Raises ImportError if no leading part of the name can be imported,
        AttributeError if a later part cannot be looked up, and TypeError if
        the resolved object cannot be turned into a test.
        """
        parts = name.split('.')
        if module is None:
            # Import the longest importable dotted prefix of the name; the
            # remaining parts are resolved with getattr below.
            parts_copy = parts[:]
            while parts_copy:
                try:
                    module = __import__('.'.join(parts_copy))
                    break
                except ImportError:
                    del parts_copy[-1]
                    if not parts_copy:
                        raise
            # __import__ returned the top level package, so only the first
            # part has been consumed.
            parts = parts[1:]
        obj = module
        for part in parts:
            parent, obj = obj, getattr(obj, part)

        if isinstance(obj, types.ModuleType):
            return self.loadTestsFromModule(obj)
        elif isinstance(obj, type) and issubclass(obj, case.TestCase):
            return self.loadTestsFromTestCase(obj)
        elif (isinstance(obj, types.UnboundMethodType) and
              isinstance(parent, type) and
              issubclass(parent, case.TestCase)):
            return self.suiteClass([parent(obj.__name__)])
        elif isinstance(obj, suite.TestSuite):
            return obj
        elif hasattr(obj, '__call__'):
            test = obj()
            if isinstance(test, suite.TestSuite):
                return test
            elif isinstance(test, case.TestCase):
                return self.suiteClass([test])
            else:
                raise TypeError("calling %s returned %s, not a test" %
                                (obj, test))
        else:
            raise TypeError("don't know how to make test from: %s" % obj)

    def loadTestsFromNames(self, names, module=None):
        """Return a suite of all tests cases found using the given sequence
        of string specifiers. See 'loadTestsFromName()'.
        """
        suites = [self.loadTestsFromName(name, module) for name in names]
        return self.suiteClass(suites)

    def getTestCaseNames(self, testCaseClass):
        """Return a sorted sequence of method names found within testCaseClass

        A name qualifies when it starts with testMethodPrefix and the
        attribute it refers to is callable.  The result is always a list; it
        is sorted with sortTestMethodsUsing unless that attribute is false.
        """
        prefix = self.testMethodPrefix
        # Build a real list (rather than relying on filter() returning one)
        # so that the in-place sort below is always available.
        testFnNames = [attrname for attrname in dir(testCaseClass)
                       if attrname.startswith(prefix) and
                       hasattr(getattr(testCaseClass, attrname), '__call__')]
        if self.sortTestMethodsUsing:
            testFnNames.sort(key=_CmpToKey(self.sortTestMethodsUsing))
        return testFnNames

    def discover(self, start_dir, pattern='test*.py', top_level_dir=None):
        """Find and return all test modules from the specified start
        directory, recursing into subdirectories to find them. Only test files
        that match the pattern will be loaded. (Using shell style pattern
        matching.)

        All test modules must be importable from the top level of the project.
        If the start directory is not the top level directory then the top
        level directory must be specified separately.

        If a test package name (directory with '__init__.py') matches the
        pattern then the package will be checked for a 'load_tests' function. If
        this exists then it will be called with loader, tests, pattern.

        If load_tests exists then discovery does  *not* recurse into the package,
        load_tests is responsible for loading all tests in the package.

        The pattern is deliberately not stored as a loader attribute so that
        packages can continue discovery themselves. top_level_dir is stored so
        load_tests does not need to pass this argument in to loader.discover().

        start_dir may also be a dotted module name instead of a directory.
        Raises ImportError if the start directory is not importable.
        """
        set_implicit_top = False
        if top_level_dir is None and self._top_level_dir is not None:
            # make top_level_dir optional if called from load_tests in a package
            top_level_dir = self._top_level_dir
        elif top_level_dir is None:
            set_implicit_top = True
            top_level_dir = start_dir

        top_level_dir = os.path.abspath(top_level_dir)

        if top_level_dir not in sys.path:
            # all test modules must be importable from the top level directory
            # should we *unconditionally* put the start directory in first
            # in sys.path to minimise likelihood of conflicts between installed
            # modules and development versions?
            sys.path.insert(0, top_level_dir)
        self._top_level_dir = top_level_dir

        is_not_importable = False
        if os.path.isdir(os.path.abspath(start_dir)):
            start_dir = os.path.abspath(start_dir)
            if start_dir != top_level_dir:
                # a sub-directory is only importable if it is a package
                is_not_importable = not os.path.isfile(os.path.join(start_dir, '__init__.py'))
        else:
            # support for discovery from dotted module names
            try:
                __import__(start_dir)
            except ImportError:
                is_not_importable = True
            else:
                the_module = sys.modules[start_dir]
                top_part = start_dir.split('.')[0]
                start_dir = os.path.abspath(os.path.dirname((the_module.__file__)))
                if set_implicit_top:
                    # The implicit top level was derived from the dotted name
                    # and is not a real directory: replace it with the
                    # directory containing the top level package and undo
                    # the sys.path entry.
                    self._top_level_dir = self._get_directory_containing_module(top_part)
                    sys.path.remove(top_level_dir)

        if is_not_importable:
            raise ImportError('Start directory is not importable: %r' % start_dir)

        tests = list(self._find_tests(start_dir, pattern))
        return self.suiteClass(tests)

    def _get_directory_containing_module(self, module_name):
        """Return the directory holding the already imported module_name.

        For a package (file name starting with '__init__.py', which also
        covers compiled variants such as '__init__.pyc') this is the parent
        of the package directory; for a plain module it is the module's own
        directory.
        """
        module = sys.modules[module_name]
        full_path = os.path.abspath(module.__file__)

        if os.path.basename(full_path).lower().startswith('__init__.py'):
            return os.path.dirname(os.path.dirname(full_path))
        else:
            # here we have been given a module rather than a package - so
            # all we can do is search the *same* directory the module is in
            # should an exception be raised instead
            return os.path.dirname(full_path)

    def _get_name_from_path(self, path):
        """Convert a path below the top level directory to a dotted name.

        The extension is dropped and path separators become dots.  The
        assertions guard against paths outside self._top_level_dir.
        """
        path = os.path.splitext(os.path.normpath(path))[0]

        _relpath = os.path.relpath(path, self._top_level_dir)
        assert not os.path.isabs(_relpath), "Path must be within the project"
        assert not _relpath.startswith('..'), "Path must be within the project"

        name = _relpath.replace(os.path.sep, '.')
        return name

    def _get_module_from_name(self, name):
        """Import the dotted module `name` and return the module object."""
        # __import__ returns the top level package, so fetch the leaf
        # module from sys.modules instead.
        __import__(name)
        return sys.modules[name]

    def _match_path(self, path, full_path, pattern):
        """Return true if the directory entry `path` matches `pattern`."""
        # override this method to use alternative matching strategy
        return fnmatch(path, pattern)

    def _find_tests(self, start_dir, pattern):
        """Used by discovery. Yields test suites it loads.

        Modules and packages that fail to import are yielded as suites
        holding a single failing test.  Raises ImportError when a module of
        the expected name was imported from a different location.
        """
        # Sort so that the load order does not depend on the order in which
        # the file system happens to list directory entries.
        paths = sorted(os.listdir(start_dir))

        for path in paths:
            full_path = os.path.join(start_dir, path)
            if os.path.isfile(full_path):
                if not VALID_MODULE_NAME.match(path):
                    # valid Python identifiers only
                    continue
                if not self._match_path(path, full_path, pattern):
                    continue
                # if the test file matches, load it
                name = self._get_name_from_path(full_path)
                try:
                    module = self._get_module_from_name(name)
                except:
                    # Deliberately broad: whatever the import raises is
                    # reported as a failing test instead of aborting
                    # discovery.
                    yield _make_failed_import_test(name, self.suiteClass)
                else:
                    mod_file = os.path.abspath(getattr(module, '__file__', full_path))
                    realpath = os.path.splitext(mod_file)[0]
                    fullpath_noext = os.path.splitext(full_path)[0]
                    if realpath.lower() != fullpath_noext.lower():
                        # A different file with the same module name was
                        # imported (e.g. an installed copy shadowing this one).
                        module_dir = os.path.dirname(realpath)
                        mod_name = os.path.splitext(os.path.basename(full_path))[0]
                        expected_dir = os.path.dirname(full_path)
                        msg = ("%r module incorrectly imported from %r. Expected %r. "
                               "Is this module globally installed?")
                        raise ImportError(msg % (mod_name, module_dir, expected_dir))
                    yield self.loadTestsFromModule(module)
            elif os.path.isdir(full_path):
                if not os.path.isfile(os.path.join(full_path, '__init__.py')):
                    continue

                load_tests = None
                tests = None
                # Go through the overridable hook, exactly as for files, so
                # that a custom matching strategy also applies to packages.
                if self._match_path(path, full_path, pattern):
                    # only check load_tests if the package directory itself matches the filter
                    name = self._get_name_from_path(full_path)
                    try:
                        package = self._get_module_from_name(name)
                    except:
                        # Same policy as for modules above: report the broken
                        # package as a failing test and carry on.  Its
                        # contents cannot be imported either, so do not
                        # recurse into it.
                        yield _make_failed_import_test(name, self.suiteClass)
                        continue
                    load_tests = getattr(package, 'load_tests', None)
                    tests = self.loadTestsFromModule(package, use_load_tests=False)

                if load_tests is None:
                    if tests is not None:
                        # tests loaded from package file
                        yield tests
                    # recurse into the package
                    for test in self._find_tests(full_path, pattern):
                        yield test
                else:
                    try:
                        yield load_tests(self, tests, pattern)
                    except Exception as e:
                        yield _make_failed_load_tests(package.__name__, e,
                                                      self.suiteClass)
293
# Module-level TestLoader instance created with the class defaults.
defaultTestLoader = TestLoader()
295
296
def _makeLoader(prefix, sortUsing, suiteClass=None):
    """Return a fresh TestLoader configured with the given settings.

    suiteClass only replaces the loader's default when it is a true value.
    """
    configured = TestLoader()
    configured.testMethodPrefix = prefix
    configured.sortTestMethodsUsing = sortUsing
    if suiteClass:
        configured.suiteClass = suiteClass
    return configured
304
def getTestCaseNames(testCaseClass, prefix, sortUsing=cmp):
    """Return the test method names of testCaseClass.

    Convenience wrapper: a temporary loader is configured with `prefix` and
    `sortUsing` and asked for the names.
    """
    loader = _makeLoader(prefix, sortUsing)
    return loader.getTestCaseNames(testCaseClass)
307
def makeSuite(testCaseClass, prefix='test', sortUsing=cmp,
              suiteClass=suite.TestSuite):
    """Return a suite of the tests found in testCaseClass.

    Convenience wrapper around a temporary loader configured with `prefix`,
    `sortUsing` and `suiteClass`.
    """
    loader = _makeLoader(prefix, sortUsing, suiteClass)
    return loader.loadTestsFromTestCase(testCaseClass)
311
def findTestCases(module, prefix='test', sortUsing=cmp,
                  suiteClass=suite.TestSuite):
    """Return a suite of the tests found in `module`.

    Convenience wrapper around a temporary loader configured with `prefix`,
    `sortUsing` and `suiteClass`.
    """
    loader = _makeLoader(prefix, sortUsing, suiteClass)
    return loader.loadTestsFromModule(module)
315