import os
import random
import shutil
import string
import sys
import tempfile
import unittest

from test.test_support import run_unittest
4
class TestImport(unittest.TestCase):
    """Check how a package submodule behaves when importing it fails.

    Each test gets a fresh temporary directory at the front of ``sys.path``
    containing an empty package whose name is not already in
    ``sys.modules``.  The package's submodule ``foo`` is (re)written by the
    individual tests via ``rewrite_file``.
    """

    def __init__(self, *args, **kw):
        # Pick a package name nobody has imported yet so the test neither
        # picks up nor clobbers an existing module.  ascii_letters (rather
        # than the locale-dependent string.letters) keeps the name a valid
        # identifier.
        self.package_name = 'PACKAGE_'
        while self.package_name in sys.modules:
            self.package_name += random.choice(string.ascii_letters)
        self.module_name = self.package_name + '.foo'
        unittest.TestCase.__init__(self, *args, **kw)

    def remove_modules(self):
        """Drop the test package and its submodule from sys.modules, if present."""
        for module_name in (self.package_name, self.module_name):
            sys.modules.pop(module_name, None)

    def setUp(self):
        """Create <tmpdir>/<package_name>/__init__.py and put <tmpdir> on sys.path."""
        self.test_dir = tempfile.mkdtemp()
        # Insert at the front so our package wins over any same-named
        # package that might exist further along sys.path.
        sys.path.insert(0, self.test_dir)
        self.package_dir = os.path.join(self.test_dir, self.package_name)
        os.mkdir(self.package_dir)
        open(os.path.join(self.package_dir, '__init__' + os.extsep + 'py'),
             'w').close()
        self.module_path = os.path.join(self.package_dir,
                                        'foo' + os.extsep + 'py')

    def tearDown(self):
        """Undo setUp: restore sys.path, forget the modules, delete the files."""
        try:
            # setUp put test_dir on sys.path; it must still be there.
            self.assertIn(self.test_dir, sys.path)
            sys.path.remove(self.test_dir)
        finally:
            # Always clean up, even if the assertion above failed.
            self.remove_modules()
            # rmtree also removes compiled files and any bytecode-cache
            # subdirectory, which a plain os.remove loop would choke on.
            shutil.rmtree(self.test_dir)

    def rewrite_file(self, contents):
        """Replace the source of the ``foo`` submodule with *contents*.

        Any ``foo.pyc`` / ``foo.pyo`` next to the source is deleted first so
        a stale compiled file cannot be used instead of the new source.
        """
        for extension in "co":
            compiled_path = self.module_path + extension
            if os.path.exists(compiled_path):
                os.remove(compiled_path)
        with open(self.module_path, 'w') as f:
            f.write(contents)

    def test_package_import__semantics(self):
        # A submodule with a SyntaxError must not be left behind in
        # sys.modules, nor bound as an attribute of its package.
        self.rewrite_file('for')
        self.assertRaises(SyntaxError, __import__, self.module_name)
        self.assertNotIn(self.module_name, sys.modules)
        self.assertFalse(hasattr(sys.modules[self.package_name], 'foo'))

        # Make up a variable name that is not a builtin.  In an imported
        # module CPython usually binds __builtins__ to the builtins *dict*
        # rather than the module, so dir(__builtins__) would list dict
        # methods; look the name up in the namespace mapping instead.
        builtin_names = getattr(__builtins__, '__dict__', __builtins__)
        var = 'a'
        while var in builtin_names:
            var += random.choice(string.ascii_letters)

        # A module consisting of just that unbound name raises NameError...
        self.rewrite_file(var)
        self.assertRaises(NameError, __import__, self.module_name)

        # ...and once the name is bound, the same module imports cleanly.
        self.rewrite_file('%s = 1' % var)
        module = __import__(self.module_name).foo
        self.assertEqual(getattr(module, var), 1)
76
77
def test_main():
    """Run every test case in this module through the regrtest helper."""
    test_classes = (TestImport,)
    run_unittest(*test_classes)
80
81
if __name__ == "__main__":
    # Executed as a script (not imported by the test runner): run the suite.
    test_main()
84