1import unittest
2from doctest import DocTestSuite
3from test import test_support as support
4import weakref
5import gc
6
7# Modules under test
8_thread = support.import_module('thread')
9threading = support.import_module('threading')
10import _threading_local
11
12
class Weak(object):
    """Trivial weak-referenceable object that ``target`` stores on a thread local."""
15
def target(local, weaklist):
    """Thread body: put a fresh Weak on ``local`` and record a weakref to it.

    BaseLocalTest._local_refs runs this in many short-lived threads. It then
    checks through ``weaklist`` whether the per-thread values were freed.
    """
    value = Weak()
    local.weak = value
    # Store only a weak reference, so the list itself does not keep the
    # value alive.
    weaklist.append(weakref.ref(value))
20
class BaseLocalTest:
    """Tests shared by both thread-local implementations.

    This is a mixin. Concrete test classes combine it with unittest.TestCase
    and set the ``_local`` class attribute to the implementation under test
    (``thread._local`` or ``_threading_local.local``).
    """

    def test_local_refs(self):
        for n in (20, 50, 100):
            self._local_refs(n)

    def _local_refs(self, n):
        """Check that values stored on a local are released with their threads.

        Runs *n* threads one after another. Each thread stores a fresh Weak
        object on a shared local (see ``target``) and records a weak
        reference to it.
        """
        local = self._local()
        weaklist = []
        # Pre-bind t so that the ``del t`` below also works when n == 0.
        t = None
        for _ in range(n):
            t = threading.Thread(target=target, args=(local, weaklist))
            t.start()
            t.join()
        # Drop our reference to the last Thread object before collecting.
        del t

        gc.collect()
        self.assertEqual(len(weaklist), n)

        # XXX _threading_local keeps the local of the last stopped thread alive.
        deadlist = [weak for weak in weaklist if weak() is None]
        self.assertIn(len(deadlist), (n-1, n), (n, len(deadlist)))

        # Assignment to the same thread local frees it sometimes (!)
        local.someothervar = None
        gc.collect()
        deadlist = [weak for weak in weaklist if weak() is None]
        self.assertIn(len(deadlist), (n-1, n), (n, len(deadlist)))

    def test_derived(self):
        # Issue 3088: if there is a threads switch inside the __init__
        # of a threading.local derived class, the per-thread dictionary
        # is created but not correctly set on the object.
        # The first member set may be bogus.
        import time
        class Local(self._local):
            def __init__(self):
                time.sleep(0.01)
        local = Local()
        errors = []

        def f(i):
            local.x = i
            # Check that the variable is correctly set. Record mismatches
            # instead of asserting here: an AssertionError raised in a worker
            # thread only prints a traceback and would not fail the test.
            # getattr() with a default also turns a missing attribute into a
            # recorded failure rather than a silent thread crash.
            got = getattr(local, 'x', None)
            if got != i:
                errors.append((i, got))

        with support.start_threads(threading.Thread(target=f, args=(i,))
                                   for i in range(10)):
            pass
        self.assertEqual(errors, [])

    def test_derived_cycle_dealloc(self):
        # http://bugs.python.org/issue6990
        class Local(self._local):
            pass
        new_locals = None
        passed = [False]
        e1 = threading.Event()
        e2 = threading.Event()

        def f():
            try:
                # 1) Involve Local in a cycle
                cycle = [Local()]
                cycle.append(cycle)
                cycle[0].foo = 'bar'

                # 2) GC the cycle (triggers threadmodule.c::local_clear
                # before local_dealloc)
                del cycle
                gc.collect()
            finally:
                # Always release the main thread, even if the steps above
                # raised. The test then fails (passed stays False) instead
                # of hanging in e1.wait().
                e1.set()
            e2.wait()

            # 4) New Locals should be empty
            passed[0] = all(not hasattr(local, 'foo') for local in new_locals)

        t = threading.Thread(target=f)
        t.start()
        e1.wait()

        # 3) New Locals should recycle the original's address. Creating
        # them in the thread overwrites the thread state and avoids the
        # bug
        new_locals = [Local() for i in range(10)]
        e2.set()
        t.join()

        self.assertTrue(passed[0])

    def test_arguments(self):
        # Issue 1522237
        # Each concrete subclass exercises its own implementation through
        # self._local, so both implementations are still covered without
        # every test class re-testing both.
        class MyLocal(self._local):
            def __init__(self, *args, **kwargs):
                pass

        # A subclass with its own __init__ may accept arbitrary arguments...
        MyLocal(a=1)
        MyLocal(1)
        # ...but the base class itself must reject them.
        self.assertRaises(TypeError, self._local, a=1)
        self.assertRaises(TypeError, self._local, 1)

    def _test_one_class(self, c):
        """Check that an attribute set in one thread is invisible in another.

        Thread f1 sets ``obj.x`` and keeps running. Thread f2 then must get
        AttributeError when it reads ``obj.x``.
        """
        # sys was never imported at module level; without this import the
        # failure branch in f2 would raise NameError.
        import sys
        self._failed = "No error message set or cleared."
        obj = c()
        e1 = threading.Event()
        e2 = threading.Event()

        def f1():
            obj.x = 'foo'
            obj.y = 'bar'
            del obj.y
            e1.set()
            # Keep this thread (and its per-thread data) alive while f2 runs.
            e2.wait()

        def f2():
            try:
                foo = obj.x
            except AttributeError:
                # This is expected -- we haven't set obj.x in this thread yet!
                self._failed = ""  # passed
            else:
                self._failed = ('Incorrectly got value %r from class %r\n' %
                                (foo, c))
                sys.stderr.write(self._failed)

        t1 = threading.Thread(target=f1)
        t1.start()
        e1.wait()
        t2 = threading.Thread(target=f2)
        t2.start()
        t2.join()
        # The test is done; just let t1 know it can exit, and wait for it.
        e2.set()
        t1.join()

        self.assertFalse(self._failed, self._failed)

    def test_threading_local(self):
        self._test_one_class(self._local)

    def test_threading_local_subclass(self):
        class LocalSubclass(self._local):
            """To test that subclasses behave properly."""
        self._test_one_class(LocalSubclass)

    def _test_dict_attribute(self, cls):
        """Check that __dict__ reflects set attributes and cannot be replaced or deleted."""
        obj = cls()
        obj.x = 5
        self.assertEqual(obj.__dict__, {'x': 5})
        with self.assertRaises(AttributeError):
            obj.__dict__ = {}
        with self.assertRaises(AttributeError):
            del obj.__dict__

    def test_dict_attribute(self):
        self._test_dict_attribute(self._local)

    def test_dict_attribute_subclass(self):
        class LocalSubclass(self._local):
            """To test that subclasses behave properly."""
        self._test_dict_attribute(LocalSubclass)
183
184
class ThreadLocalTest(unittest.TestCase, BaseLocalTest):
    """Run the shared local tests against the C implementation, thread._local."""

    _local = _thread._local

    # Only the C implementation passes this one; the pure Python
    # implementation fails it, so it is not part of BaseLocalTest.
    def test_cycle_collection(self):
        class X:
            pass

        holder = X()
        holder.local = self._local()
        # Build a reference cycle that passes through the local:
        # holder -> holder.local -> holder.
        holder.local.x = holder
        ref = weakref.ref(holder)
        del holder
        gc.collect()
        # The cycle must be collectable even though a local is part of it.
        self.assertIsNone(ref())
200
class PyThreadingLocalTest(unittest.TestCase, BaseLocalTest):
    """Run the shared local tests against the pure Python _threading_local.local."""

    _local = _threading_local.local
203
204
def test_main():
    """Build and run the complete test suite for this module.

    The suite contains:
    - the _threading_local doctests, run against the pure Python ``local``;
    - the unittest cases for both implementations;
    - the _threading_local doctests again, with ``_threading_local.local``
      temporarily replaced by the C implementation ``thread._local``.
    """
    loader = unittest.TestLoader()
    suite = unittest.TestSuite()
    suite.addTest(DocTestSuite('_threading_local'))
    suite.addTest(loader.loadTestsFromTestCase(ThreadLocalTest))
    suite.addTest(loader.loadTestsFromTestCase(PyThreadingLocalTest))

    # The C implementation is always available at this point. If 'thread'
    # were missing, the module-level support.import_module('thread') call
    # would already have skipped this whole module, so the old ImportError
    # fallback was unreachable.
    local_orig = _threading_local.local

    def setUp(test):
        # Run the same doctests against the C implementation.
        _threading_local.local = _thread._local

    def tearDown(test):
        _threading_local.local = local_orig

    suite.addTest(DocTestSuite('_threading_local',
                               setUp=setUp, tearDown=tearDown))

    support.run_unittest(suite)
227
if __name__ == "__main__":
    # Run the full suite when this file is executed as a script.
    test_main()
230