import contextlib
import imp
import importlib
import sys
import types
import unittest
6
7
@contextlib.contextmanager
def uncache(*names):
    """Keep the named modules out of sys.modules around a block of code.

    Every name is evicted from sys.modules on entry and evicted again on
    exit, so whatever gets imported under those names inside the block does
    not linger afterwards.  Entries that were cached before entry are not
    put back.

    Raises ValueError when asked to uncache 'sys', 'marshal' or 'imp'.  The
    check runs name by name, so names listed before the offending one have
    already been evicted when the error is raised.

    """
    protected = ('sys', 'marshal', 'imp')
    for module_name in names:
        if module_name in protected:
            raise ValueError(
                "cannot uncache {0} as it will break _importlib".format(
                    module_name))
        # pop() with a default tolerates names that are not cached.
        sys.modules.pop(module_name, None)
    try:
        yield
    finally:
        for module_name in names:
            sys.modules.pop(module_name, None)
32
33
@contextlib.contextmanager
def import_state(**kwargs):
    """Temporarily replace the import-related state held on the sys module.

    Accepted keywords are 'meta_path', 'path', 'path_hooks' and
    'path_importer_cache'.  Each of those sys attributes is set to the value
    given, or to a fresh empty list/dict when omitted, and the previous
    objects are put back on exit.

    'modules' is deliberately not supported: the interpreter state keeps its
    own pointer to the dict it uses internally, so reassigning sys.modules
    does not have the desired effect.

    Raises ValueError if any other keyword is passed; the sys attributes
    are restored before the error propagates.

    """
    defaults = (('meta_path', []), ('path', []), ('path_hooks', []),
                ('path_importer_cache', {}))
    saved = {}
    try:
        for attr, fallback in defaults:
            saved[attr] = getattr(sys, attr)
            # pop() consumes recognized keywords so that whatever is left
            # over can be reported as unrecognized below.
            setattr(sys, attr, kwargs.pop(attr, fallback))
        if kwargs:
            raise ValueError(
                    'unrecognized arguments: {0}'.format(kwargs.keys()))
        yield
    finally:
        for attr in saved:
            setattr(sys, attr, saved[attr])
63
64
class mock_modules(object):

    """A mock importer/loader.

    Each positional name becomes an empty in-memory module.  A name ending
    in '.__init__' describes a package: the suffix is dropped from the
    import name and the module is given a __path__.  Every module records
    the name it was created from in its 'attr' attribute.

    The instance provides find_module()/load_module(), so it can be placed
    on sys.meta_path.  Used as a context manager it keeps its module names
    out of sys.modules on entry and again on exit.

    """

    def __init__(self, *names):
        """Create one mock module per entry in *names*."""
        self.modules = {}
        init_suffix = '.__init__'
        for name in names:
            is_package = name.endswith(init_suffix)
            if is_package:
                import_name = name[:-len(init_suffix)]
            else:
                import_name = name
            if '.' not in name:
                # Top-level, non-package module.
                package = None
            elif is_package:
                # A package is its own __package__.
                package = import_name
            else:
                # A submodule belongs to its parent package.
                package = name.rsplit('.', 1)[0]
            # types.ModuleType is what imp.new_module() wraps; unlike the
            # deprecated imp module it is available on every Python version.
            module = types.ModuleType(import_name)
            module.__loader__ = self
            module.__file__ = '<mock __file__>'
            module.__package__ = package
            module.attr = name
            if is_package:
                module.__path__ = ['<mock __path__>']
            self.modules[import_name] = module

    def __getitem__(self, name):
        """Return the mock module stored under import name *name*."""
        return self.modules[name]

    def find_module(self, fullname, path=None):
        """Return self if *fullname* is a known mock module, else None."""
        if fullname in self.modules:
            return self
        return None

    def load_module(self, fullname):
        """Register the mock module in sys.modules and return it.

        Raises ImportError if *fullname* is not one of the mock modules.

        """
        try:
            module = self.modules[fullname]
        except KeyError:
            raise ImportError(
                'no mock module named {0}'.format(fullname))
        sys.modules[fullname] = module
        return module

    def __enter__(self):
        """Evict the mock module names from sys.modules; return self."""
        self._uncache = uncache(*self.modules.keys())
        self._uncache.__enter__()
        return self

    def __exit__(self, *exc_info):
        """Evict the mock module names again; exceptions propagate."""
        self._uncache.__exit__(None, None, None)
114
115
116
class ImportModuleTests(unittest.TestCase):

    """Exercise importlib.import_module against mock importers."""

    def test_module_import(self):
        # A top-level module is importable by its plain name.
        with mock_modules('top_level') as importer, \
                import_state(meta_path=[importer]):
            imported = importlib.import_module('top_level')
            self.assertEqual(imported.__name__, 'top_level')

    def test_absolute_package_import(self):
        # A module inside a package is importable by its absolute name.
        package = 'pkg'
        package_init = '{0}.__init__'.format(package)
        target = '{0}.mod'.format(package)
        with mock_modules(package_init, target) as importer, \
                import_state(meta_path=[importer]):
            imported = importlib.import_module(target)
            self.assertEqual(imported.__name__, target)

    def test_shallow_relative_package_import(self):
        # A single leading dot resolves against the given package itself.
        layout = ('a.__init__', 'a.b.__init__', 'a.b.c.__init__', 'a.b.c.d')
        with mock_modules(*layout) as importer, \
                import_state(meta_path=[importer]):
            imported = importlib.import_module('.d', 'a.b.c')
            self.assertEqual(imported.__name__, 'a.b.c.d')

    def test_deep_relative_package_import(self):
        # Two leading dots resolve against the parent of the given package.
        layout = ('a.__init__', 'a.b.__init__', 'a.c')
        with mock_modules(*layout) as importer, \
                import_state(meta_path=[importer]):
            imported = importlib.import_module('..c', 'a.b')
            self.assertEqual(imported.__name__, 'a.c')

    def test_absolute_import_with_package(self):
        # An absolute name is still importable when the 'package' argument
        # is supplied as well.
        package = 'pkg'
        package_init = '{0}.__init__'.format(package)
        target = '{0}.mod'.format(package)
        with mock_modules(package_init, target) as importer, \
                import_state(meta_path=[importer]):
            imported = importlib.import_module(target, package)
            self.assertEqual(imported.__name__, target)

    def test_relative_import_wo_package(self):
        # A relative name without the 'package' argument must be rejected
        # with TypeError.
        self.assertRaises(TypeError, importlib.import_module, '.support')
168
169
def test_main():
    """Run ImportModuleTests through test.test_support.run_unittest."""
    # Imported here rather than at module level, so the helper is only
    # needed when the suite is actually run.
    from test import test_support
    test_support.run_unittest(ImportModuleTests)


# Allow running this test module directly as a script.
if __name__ == '__main__':
    test_main()
177