1from . import util as test_util
2
# The importlib implementation(s) under test. Below it is iterated with
# .items() and each value's `_bootstrap` is read, so it is presumably a
# mapping of variant kind -> importlib module (TODO confirm in test_util).
init = test_util.import_importlib('importlib')
4
5import sys
6import threading
7import weakref
8
9from test import support
10from test import lock_tests
11
12
class ModuleLockAsRLockTests:
    """Run the generic RLock test-suite against importlib's module locks.

    A concrete subclass must provide a ``LockType`` attribute; instances of
    it are created through :meth:`locktype`.  Inherited tests that exercise
    RLock features the module lock does not offer are overridden with
    ``None`` so they are not collected.
    """

    @classmethod
    def locktype(cls):
        return cls.LockType("some_lock")

    # RLock features that the module lock does not implement:
    test__is_owned = None                   # _is_owned()
    test_try_acquire = None                 # acquire(blocking=False)
    test_try_acquire_contended = None       # acquire(blocking=False)
    test_with = None                        # `with` statement support
    test_timeout = None                     # acquire(timeout=...)
    test_release_save_unacquired = None     # _release_save()
    test_repr = None                        # lock status shown in repr()
    test_locked_repr = None                 # lock status shown in repr()
30
# The _ModuleLock class of each importlib variant, keyed by the same kinds
# as `init`.
LOCK_TYPES = dict(
    (kind, module._bootstrap._ModuleLock) for kind, module in init.items()
)

(Frozen_ModuleLockAsRLockTests,
 Source_ModuleLockAsRLockTests) = test_util.test_both(
    ModuleLockAsRLockTests,
    lock_tests.RLockTests,
    LockType=LOCK_TYPES,
)
38
39
class DeadlockAvoidanceTests:
    """Check that module locks refuse acquisitions that would close a cycle.

    The ``LockType`` and ``DeadlockError`` attributes used below are not
    defined here; they are passed as keyword arguments to
    ``test_util.test_both`` when this class is specialized.
    """

    def setUp(self):
        # Use a tiny thread switch interval so the worker threads interleave
        # as much as possible.  The previous value is remembered for
        # tearDown; None means nothing needs restoring.
        try:
            self.old_switchinterval = sys.getswitchinterval()
            support.setswitchinterval(1e-6)
        except AttributeError:
            self.old_switchinterval = None

    def tearDown(self):
        previous = self.old_switchinterval
        if previous is not None:
            sys.setswitchinterval(previous)

    def run_deadlock_avoidance_test(self, create_deadlock):
        """Drive worker threads around a ring of ten locks.

        Lock ``i`` is paired with lock ``(i + 1) % 10``.  Each worker pops
        one pair, acquires its first lock, waits at a barrier until every
        worker has done so, and then asks for its second lock.  With
        *create_deadlock* true, all ten pairs are used and the requests form
        a closed cycle.  Otherwise only nine workers run, and the ring stays
        open.

        Returns one ``(got_first, got_second)`` tuple of booleans per worker.
        """
        lock_count = 10
        locks = [self.LockType(str(n)) for n in range(lock_count)]
        pairs = [(lock, locks[(n + 1) % lock_count])
                 for n, lock in enumerate(locks)]
        thread_count = lock_count if create_deadlock else lock_count - 1
        barrier = threading.Barrier(thread_count)
        results = []

        def try_acquire(lock):
            # Turn the deadlock-avoidance error into False.  Any other
            # exception still propagates.
            try:
                lock.acquire()
            except self.DeadlockError:
                return False
            return True

        def worker():
            first, second = pairs.pop()
            got_first = try_acquire(first)
            # No worker requests its second lock until every worker holds
            # its first one.
            barrier.wait()
            got_second = try_acquire(second)
            results.append((got_first, got_second))
            # Release in the reverse order of acquisition.
            if got_second:
                second.release()
            if got_first:
                first.release()

        lock_tests.Bunch(worker, thread_count).wait_for_finished()
        self.assertEqual(len(results), thread_count)
        return results

    def test_deadlock(self):
        results = self.run_deadlock_avoidance_test(create_deadlock=True)
        # The avoidance logic is conservative, so possibly several threads,
        # and at the very least one, are refused their second lock.
        refused = results.count((True, False))
        self.assertGreaterEqual(refused, 1)
        # Every remaining thread must have obtained both locks.
        self.assertEqual(results.count((True, True)), len(results) - refused)

    def test_no_deadlock(self):
        results = self.run_deadlock_avoidance_test(create_deadlock=False)
        # With the ring left open, no request should be refused.
        self.assertEqual(results.count((True, False)), 0)
        self.assertEqual(results.count((True, True)), len(results))
101
102
# The _DeadlockError exception class of each importlib variant, keyed by the
# same kinds as `init`.
DEADLOCK_ERRORS = dict(
    (kind, module._bootstrap._DeadlockError) for kind, module in init.items()
)

(Frozen_DeadlockAvoidanceTests,
 Source_DeadlockAvoidanceTests) = test_util.test_both(
    DeadlockAvoidanceTests,
    LockType=LOCK_TYPES,
    DeadlockError=DEADLOCK_ERRORS,
)
111
112
class LifetimeTests:
    """Check the lifetime of entries in the module-lock registry.

    The ``init`` attribute is not defined here; it is supplied through
    ``test_util.test_both(..., init=init)``.
    """

    @property
    def bootstrap(self):
        # The `_bootstrap` submodule of the importlib variant under test.
        return self.init._bootstrap

    def test_lock_lifetime(self):
        bootstrap = self.bootstrap
        key = "xyzzy"
        self.assertNotIn(key, bootstrap._module_locks)
        module_lock = bootstrap._get_module_lock(key)
        self.assertIn(key, bootstrap._module_locks)
        # Keep only a weak reference, so dropping the local below releases
        # the last strong reference held by this test.
        ref = weakref.ref(module_lock)
        del module_lock
        support.gc_collect()
        # Both the registry entry and the lock object itself must be gone.
        self.assertNotIn(key, bootstrap._module_locks)
        self.assertIsNone(ref())

    def test_all_locks(self):
        support.gc_collect()
        # After a collection, no module locks should remain registered.
        registry = self.bootstrap._module_locks
        self.assertEqual(0, len(registry), registry)
134
135
Frozen_LifetimeTests, Source_LifetimeTests = test_util.test_both(
    LifetimeTests, init=init)
139
140
@support.reap_threads
def test_main():
    """Run both variants of every lock test class defined in this file."""
    test_classes = (
        Frozen_ModuleLockAsRLockTests, Source_ModuleLockAsRLockTests,
        Frozen_DeadlockAvoidanceTests, Source_DeadlockAvoidanceTests,
        Frozen_LifetimeTests, Source_LifetimeTests,
    )
    support.run_unittest(*test_classes)
149
150
if __name__ == "__main__":
    # Run the tests when this file is executed directly as a script.
    test_main()
153