1import contextlib
2import functools
3import sys
4import threading
5import unittest
6from test.support import import_fresh_module
7
# Module-level locks used to serialize tests that mutate shared,
# process-global state.
# NOTE(review): OS_ENV_LOCK is not used in this section; presumably it guards
# os.environ mutation (e.g. TZ) — confirm against callers.
OS_ENV_LOCK = threading.Lock()
# Default `lock` for ZoneInfoTestBase.tzpath_context, which mutates the zoneinfo
# module's TZPATH and the tzdata entries in sys.modules while holding it.
TZPATH_LOCK = threading.Lock()
# NOTE(review): not used in this section; presumably a separate lock for tests
# that exercise TZPATH handling themselves — confirm against callers.
TZPATH_TEST_LOCK = threading.Lock()
11
12
def call_once(f):
    """Wrap the zero-argument callable *f* so its body executes only once.

    The first call that returns normally stores the result, and every later
    call returns that very same object. A lock serializes callers, so two
    threads making the first call concurrently cannot both run *f*. If *f*
    raises, nothing is stored and the next call tries again.
    """
    guard = threading.Lock()
    memo = []  # empty until f() has succeeded; then holds its single result

    @functools.wraps(f)
    def inner():
        with guard:
            if not memo:
                memo.append(f())
            return memo[0]

    return inner
24
25
@call_once
def get_modules():
    """Retrieve two copies of zoneinfo: pure Python and C accelerated.

    Because this function manipulates the import system in a way that might
    be fragile or do unexpected things if it is run many times, it is wrapped
    in `call_once`. Its body executes only on the first call, and every later
    call returns the same pair of module objects rather than performing a
    fresh import each time.

    Returns:
        A ``(py_module, c_module)`` tuple. ``c_module`` is whatever a normal
        ``import zoneinfo`` yields. ``py_module`` is a separate, freshly
        imported copy of ``zoneinfo`` with the ``_zoneinfo`` module blocked.
    """
    # The ordinary import. Presumably this picks up the _zoneinfo C
    # accelerator when the build provides it — confirm for builds without it.
    import zoneinfo as c_module

    # A fresh copy of zoneinfo imported with _zoneinfo blocked, so it
    # presumably falls back to the pure-Python implementation.
    py_module = import_fresh_module("zoneinfo", blocked=["_zoneinfo"])

    return py_module, c_module
41
42
@contextlib.contextmanager
def set_zoneinfo_module(module):
    """Make sure sys.modules["zoneinfo"] refers to `module`.

    This is necessary because `pickle` will refuse to serialize
    a type calling itself `zoneinfo.ZoneInfo` unless `zoneinfo.ZoneInfo`
    refers to the same object.

    On exit, the previous ``sys.modules["zoneinfo"]`` entry is restored. If
    there was no entry, it is removed again. This happens even when the body
    of the ``with`` statement raises, so a failing test cannot leak the
    substituted module into later tests.

    Args:
        module: The module object to install as ``sys.modules["zoneinfo"]``
            for the duration of the context.
    """
    NOT_PRESENT = object()
    old_zoneinfo = sys.modules.get("zoneinfo", NOT_PRESENT)
    sys.modules["zoneinfo"] = module
    try:
        yield
    finally:
        if old_zoneinfo is not NOT_PRESENT:
            sys.modules["zoneinfo"] = old_zoneinfo
        else:  # pragma: nocover
            # zoneinfo was not in sys.modules before entry, so remove the entry
            # we added. The default avoids a KeyError (which would mask any
            # in-flight exception) if the body already removed it.
            sys.modules.pop("zoneinfo", None)
60
61
class ZoneInfoTestBase(unittest.TestCase):
    """Common base class for zoneinfo test cases.

    Subclasses must define a ``module`` class attribute holding the zoneinfo
    module under test. ``setUpClass`` exposes that module's ``ZoneInfo`` class
    as ``cls.klass``.
    """

    @classmethod
    def setUpClass(cls):
        cls.klass = cls.module.ZoneInfo
        super().setUpClass()

    @contextlib.contextmanager
    def tzpath_context(self, tzpath, block_tzdata=True, lock=TZPATH_LOCK):
        """Temporarily apply *tzpath* via ``self.module.reset_tzpath``.

        On exit, the value of ``self.module.TZPATH`` captured on entry is
        passed back to ``reset_tzpath``.

        Args:
            tzpath: Value passed to ``self.module.reset_tzpath`` on entry.
            block_tzdata: If true, hide the ``tzdata`` package and any
                already-imported ``tzdata.*`` submodules from ``sys.modules``
                for the duration of the context, then put them back on exit.
            lock: Lock held for the whole context, serializing users of the
                shared TZPATH / ``sys.modules`` state.

        Cleanup runs even if ``reset_tzpath(tzpath)`` or the ``with`` body
        raises. TZPATH is restored even if restoring ``sys.modules`` fails.
        """

        def pop_tzdata_modules():
            # Remove every "tzdata" / "tzdata.*" entry from sys.modules and
            # return them so they can be reinstated afterwards.
            tzdata_modules = {}
            for modname in list(sys.modules):
                if modname.split(".", 1)[0] != "tzdata":  # pragma: nocover
                    continue

                tzdata_modules[modname] = sys.modules.pop(modname)

            return tzdata_modules

        with lock:
            old_path = self.module.TZPATH
            tzdata_modules = {}
            try:
                if block_tzdata:
                    # In order to fully exclude tzdata from the path, we need
                    # to clear the sys.modules cache of all its contents.
                    # Setting the root package to None is not enough to block
                    # direct access of already-imported submodules (though it
                    # will prevent new imports of submodules).
                    tzdata_modules = pop_tzdata_modules()
                    sys.modules["tzdata"] = None

                self.module.reset_tzpath(tzpath)
                yield
            finally:
                try:
                    if block_tzdata:
                        # The body may already have removed the blocker entry.
                        # Don't let a KeyError here mask the real failure.
                        sys.modules.pop("tzdata", None)
                        sys.modules.update(tzdata_modules)
                finally:
                    # Always restore the search path, even if restoring
                    # sys.modules above failed.
                    self.module.reset_tzpath(old_path)
101