1import functools
2import importlib.util
3import os
4import py_compile
5import shutil
6import stat
7import sys
8import tempfile
9import unittest
10
11from test import support
12
13
def without_source_date_epoch(fxn):
    """Decorator: call *fxn* with SOURCE_DATE_EPOCH removed from os.environ.

    The variable is removed inside a support.EnvironmentVarGuard context,
    which is relied on to put the environment back once *fxn* returns or
    raises.  The wrapper forwards all arguments and returns *fxn*'s result.
    """
    @functools.wraps(fxn)
    def _run_without_epoch(*call_args, **call_kwargs):
        guard = support.EnvironmentVarGuard()
        with guard as environ:
            environ.unset('SOURCE_DATE_EPOCH')
            outcome = fxn(*call_args, **call_kwargs)
        return outcome
    return _run_without_epoch
22
23
def with_source_date_epoch(fxn):
    """Decorator: call *fxn* with SOURCE_DATE_EPOCH set to '123456789'.

    The variable is assigned inside a support.EnvironmentVarGuard context,
    which is relied on to put the environment back once *fxn* returns or
    raises.  The wrapper forwards all arguments and returns *fxn*'s result.
    """
    @functools.wraps(fxn)
    def _run_with_epoch(*call_args, **call_kwargs):
        guard = support.EnvironmentVarGuard()
        with guard as environ:
            environ['SOURCE_DATE_EPOCH'] = '123456789'
            outcome = fxn(*call_args, **call_kwargs)
        return outcome
    return _run_with_epoch
32
33
34# Run tests with SOURCE_DATE_EPOCH set or unset explicitly.
class SourceDateEpochTestMeta(type(unittest.TestCase)):
    """Metaclass that pins SOURCE_DATE_EPOCH around every test method.

    Classes using it must pass the keyword-only class argument
    ``source_date_epoch``.  Every attribute of the new class whose name starts
    with ``test_`` (inherited ones included, because ``dir()`` is used) is
    replaced on that class by a wrapper.  The wrapper runs the test with
    SOURCE_DATE_EPOCH set when the keyword is truthy, and unset otherwise.
    """
    def __new__(mcls, name, bases, dct, *, source_date_epoch):
        cls = super().__new__(mcls, name, bases, dct)

        test_names = [attr for attr in dir(cls) if attr.startswith('test_')]
        for attr in test_names:
            decorate = (with_source_date_epoch if source_date_epoch
                        else without_source_date_epoch)
            setattr(cls, attr, decorate(getattr(cls, attr)))

        return cls
49
50
class PyCompileTestsBase:
    """Mixin holding the tests for py_compile.compile().

    This is not a TestCase on its own: concrete subclasses also inherit from
    unittest.TestCase.  Every test starts with a fresh temporary directory
    containing the one-line source file ``_test.py``.
    """

    def setUp(self):
        self.directory = tempfile.mkdtemp()
        # Register removal immediately.  tearDown() is skipped when setUp()
        # fails, but cleanups still run, so the directory is never leaked.
        self.addCleanup(shutil.rmtree, self.directory)
        self.source_path = os.path.join(self.directory, '_test.py')
        # Explicit bytecode target placed next to the source file.
        self.pyc_path = self.source_path + 'c'
        # __pycache__ target that py_compile uses when no cfile is passed.
        self.cache_path = importlib.util.cache_from_source(self.source_path)
        self.cwd_drive = os.path.splitdrive(os.getcwd())[0]
        # In these tests we compute relative paths.  When using Windows, the
        # current working directory path and the 'self.source_path' might be
        # on different drives.  Therefore we need to switch to the drive where
        # the temporary source file lives.
        drive = os.path.splitdrive(self.source_path)[0]
        if drive:
            os.chdir(drive)
        with open(self.source_path, 'w') as file:
            file.write('x = 123\n')

    def tearDown(self):
        # The temporary directory is removed by the cleanup registered in
        # setUp(); only the drive switch needs to be undone here.
        if self.cwd_drive:
            os.chdir(self.cwd_drive)

    def test_absolute_path(self):
        # An explicit cfile is honoured and nothing goes to __pycache__.
        py_compile.compile(self.source_path, self.pyc_path)
        self.assertTrue(os.path.exists(self.pyc_path))
        self.assertFalse(os.path.exists(self.cache_path))

    def test_do_not_overwrite_symlinks(self):
        # In the face of a cfile argument being a symlink, bail out.
        # Issue #17222
        try:
            os.symlink(self.pyc_path + '.actual', self.pyc_path)
        except (NotImplementedError, OSError):
            self.skipTest('need to be able to create a symlink for a file')
        else:
            # A real assertion (not a bare assert) so the precondition is
            # still verified when running under -O.
            self.assertTrue(os.path.islink(self.pyc_path))
            with self.assertRaises(FileExistsError):
                py_compile.compile(self.source_path, self.pyc_path)

    @unittest.skipIf(not os.path.exists(os.devnull) or os.path.isfile(os.devnull),
                     'requires os.devnull and for it to be a non-regular file')
    def test_do_not_overwrite_nonregular_files(self):
        # In the face of a cfile argument being a non-regular file, bail out.
        # Issue #17222
        with self.assertRaises(FileExistsError):
            py_compile.compile(self.source_path, os.devnull)

    def test_cache_path(self):
        # Without a cfile, the bytecode is written to the __pycache__ path.
        py_compile.compile(self.source_path)
        self.assertTrue(os.path.exists(self.cache_path))

    def test_cwd(self):
        # Bare file names are resolved against the current directory.
        with support.change_cwd(self.directory):
            py_compile.compile(os.path.basename(self.source_path),
                               os.path.basename(self.pyc_path))
        self.assertTrue(os.path.exists(self.pyc_path))
        self.assertFalse(os.path.exists(self.cache_path))

    def test_relative_path(self):
        py_compile.compile(os.path.relpath(self.source_path),
                           os.path.relpath(self.pyc_path))
        self.assertTrue(os.path.exists(self.pyc_path))
        self.assertFalse(os.path.exists(self.cache_path))

    @unittest.skipIf(hasattr(os, 'geteuid') and os.geteuid() == 0,
                     'non-root user required')
    @unittest.skipIf(os.name == 'nt',
                     'cannot control directory permissions on Windows')
    def test_exceptions_propagate(self):
        # Make sure that OS-level errors hit while compiling (for example,
        # when the bytecode file cannot be written) propagate to the caller.
        # http://bugs.python.org/issue17244
        orig_stat = os.stat(self.directory)
        # Leave only the read bit on the directory, so files inside it can
        # no longer be created or reached.
        os.chmod(self.directory, stat.S_IREAD)
        try:
            with self.assertRaises(OSError):
                py_compile.compile(self.source_path, self.pyc_path)
        finally:
            # Restore permissions so the cleanup can delete the directory.
            os.chmod(self.directory, orig_stat.st_mode)

    def test_bad_coding(self):
        # With doraise=False, a source that fails to compile yields None,
        # and no cached bytecode is produced.
        bad_coding = os.path.join(os.path.dirname(__file__), 'bad_coding2.py')
        with support.captured_stderr():
            self.assertIsNone(py_compile.compile(bad_coding, doraise=False))
        self.assertFalse(os.path.exists(
            importlib.util.cache_from_source(bad_coding)))

    def test_source_date_epoch(self):
        py_compile.compile(self.source_path, self.pyc_path)
        self.assertTrue(os.path.exists(self.pyc_path))
        self.assertFalse(os.path.exists(self.cache_path))
        with open(self.pyc_path, 'rb') as fp:
            flags = importlib._bootstrap_external._classify_pyc(
                fp.read(), 'test', {})
        # With SOURCE_DATE_EPOCH set, the default output is expected to carry
        # the same flags as CHECKED_HASH (0b11, see test_invalidation_mode).
        # Otherwise no hash-based flag bits are set.
        if os.environ.get('SOURCE_DATE_EPOCH'):
            expected_flags = 0b11
        else:
            expected_flags = 0b00

        self.assertEqual(flags, expected_flags)

    @unittest.skipIf(sys.flags.optimize > 0, 'test does not work with -O')
    def test_double_dot_no_clobber(self):
        # http://bugs.python.org/issue22966
        # py_compile foo.bar.py -> __pycache__/foo.cpython-34.pyc
        weird_path = os.path.join(self.directory, 'foo.bar.py')
        cache_path = importlib.util.cache_from_source(weird_path)
        pyc_path = weird_path + 'c'
        head, tail = os.path.split(cache_path)
        penultimate_tail = os.path.basename(head)
        self.assertEqual(
            os.path.join(penultimate_tail, tail),
            os.path.join(
                '__pycache__',
                'foo.bar.{}.pyc'.format(sys.implementation.cache_tag)))
        with open(weird_path, 'w') as file:
            file.write('x = 123\n')
        py_compile.compile(weird_path)
        self.assertTrue(os.path.exists(cache_path))
        self.assertFalse(os.path.exists(pyc_path))

    def test_optimization_path(self):
        # Specifying optimized bytecode should lead to a path reflecting that.
        self.assertIn('opt-2', py_compile.compile(self.source_path, optimize=2))

    def test_invalidation_mode(self):
        # CHECKED_HASH sets both flag bits; UNCHECKED_HASH sets only bit 0.
        py_compile.compile(
            self.source_path,
            invalidation_mode=py_compile.PycInvalidationMode.CHECKED_HASH,
        )
        with open(self.cache_path, 'rb') as fp:
            flags = importlib._bootstrap_external._classify_pyc(
                fp.read(), 'test', {})
        self.assertEqual(flags, 0b11)
        py_compile.compile(
            self.source_path,
            invalidation_mode=py_compile.PycInvalidationMode.UNCHECKED_HASH,
        )
        with open(self.cache_path, 'rb') as fp:
            flags = importlib._bootstrap_external._classify_pyc(
                fp.read(), 'test', {})
        self.assertEqual(flags, 0b1)
194
195
class PyCompileTestsWithSourceEpoch(
        PyCompileTestsBase, unittest.TestCase,
        metaclass=SourceDateEpochTestMeta, source_date_epoch=True):
    """Run every PyCompileTestsBase test with SOURCE_DATE_EPOCH set."""
201
202
class PyCompileTestsWithoutSourceEpoch(
        PyCompileTestsBase, unittest.TestCase,
        metaclass=SourceDateEpochTestMeta, source_date_epoch=False):
    """Run every PyCompileTestsBase test with SOURCE_DATE_EPOCH unset."""
208
209
if __name__ == '__main__':
    # Run the test cases when this module is executed as a script.
    unittest.main()
212