import gc
import sys
import unittest
import weakref
from doctest import DocTestSuite
from test import support

# Modules under test
_thread = support.import_module('_thread')
threading = support.import_module('threading')
import _threading_local
11
12
class Weak:
    """Empty, weak-referenceable object stored in thread-locals by the tests."""
15
def target(local, weaklist):
    """Thread body: store a new ``Weak`` on *local* and record a weakref to it.

    The object is assigned to ``local.weak`` and a ``weakref.ref`` to it is
    appended to *weaklist*, so the caller can later check whether the
    object is still alive.
    """
    weak = Weak()
    local.weak = weak
    weaklist.append(weakref.ref(weak))
20
21
class BaseLocalTest:
    """Tests shared by every thread-local implementation.

    This is a mixin: a concrete subclass combines it with
    ``unittest.TestCase`` and sets the ``_local`` class attribute to the
    thread-local class under test.
    """

    def test_local_refs(self):
        self._local_refs(20)
        self._local_refs(50)
        self._local_refs(100)

    def _local_refs(self, n):
        """Run *n* threads that each store an object on one shared local.

        Each thread is joined before the next one starts.  Afterwards at
        least ``n - 1`` of the stored objects must have been released.
        """
        local = self._local()
        weaklist = []
        for i in range(n):
            t = threading.Thread(target=target, args=(local, weaklist))
            t.start()
            t.join()
        del t

        gc.collect()
        self.assertEqual(len(weaklist), n)

        # XXX _threading_local keeps the local of the last stopped thread alive.
        deadlist = [weak for weak in weaklist if weak() is None]
        self.assertIn(len(deadlist), (n-1, n), (n, len(deadlist)))

        # Assignment to the same thread local frees it sometimes (!)
        local.someothervar = None
        gc.collect()
        deadlist = [weak for weak in weaklist if weak() is None]
        self.assertIn(len(deadlist), (n-1, n), (n, len(deadlist)))

    def test_derived(self):
        # Issue 3088: if there is a threads switch inside the __init__
        # of a threading.local derived class, the per-thread dictionary
        # is created but not correctly set on the object.
        # The first member set may be bogus.
        import time
        class Local(self._local):
            def __init__(self):
                time.sleep(0.01)
        local = Local()

        def f(i):
            local.x = i
            # Simply check that the variable is correctly set
            self.assertEqual(local.x, i)

        with support.start_threads(threading.Thread(target=f, args=(i,))
                                   for i in range(10)):
            pass

    def test_derived_cycle_dealloc(self):
        # http://bugs.python.org/issue6990
        class Local(self._local):
            pass
        # Filled in by the main thread (step 3) and read by the worker
        # thread (step 4).
        new_locals = None
        passed = False
        e1 = threading.Event()
        e2 = threading.Event()

        def f():
            nonlocal passed
            # 1) Involve Local in a cycle
            cycle = [Local()]
            cycle.append(cycle)
            cycle[0].foo = 'bar'

            # 2) GC the cycle (triggers threadmodule.c::local_clear
            # before local_dealloc)
            del cycle
            gc.collect()
            e1.set()
            e2.wait()

            # 4) New Locals should be empty
            passed = all(not hasattr(local, 'foo') for local in new_locals)

        t = threading.Thread(target=f)
        t.start()
        e1.wait()

        # 3) New Locals should recycle the original's address. Creating
        # them in the thread overwrites the thread state and avoids the
        # bug
        new_locals = [Local() for i in range(10)]
        e2.set()
        t.join()

        self.assertTrue(passed)

    def test_arguments(self):
        # Issue 1522237
        class MyLocal(self._local):
            def __init__(self, *args, **kwargs):
                pass

        # A subclass defining __init__ may take arguments ...
        MyLocal(a=1)
        MyLocal(1)
        # ... but the base class itself must reject them.
        self.assertRaises(TypeError, self._local, a=1)
        self.assertRaises(TypeError, self._local, 1)

    def _test_one_class(self, c):
        """Check that an attribute set on a *c* instance in one thread is
        not visible from another thread.

        Thread 1 sets ``obj.x`` and then stays alive until the check is
        over; thread 2 reads ``obj.x`` and must get ``AttributeError``.
        The outcome is stored in ``self._failed``: an empty string on
        success, otherwise the error message, which is also written to
        ``sys.stderr``.
        """
        self._failed = "No error message set or cleared."
        obj = c()
        e1 = threading.Event()
        e2 = threading.Event()

        def f1():
            obj.x = 'foo'
            obj.y = 'bar'
            del obj.y
            e1.set()
            e2.wait()

        def f2():
            try:
                foo = obj.x
            except AttributeError:
                # This is expected -- we haven't set obj.x in this thread yet!
                self._failed = ""  # passed
            else:
                self._failed = ('Incorrectly got value %r from class %r\n' %
                                (foo, c))
                sys.stderr.write(self._failed)

        t1 = threading.Thread(target=f1)
        t1.start()
        e1.wait()
        t2 = threading.Thread(target=f2)
        t2.start()
        t2.join()
        # The test is done; just let t1 know it can exit, and wait for it.
        e2.set()
        t1.join()

        self.assertFalse(self._failed, self._failed)

    def test_threading_local(self):
        self._test_one_class(self._local)

    def test_threading_local_subclass(self):
        class LocalSubclass(self._local):
            """To test that subclasses behave properly."""
        self._test_one_class(LocalSubclass)

    def _test_dict_attribute(self, cls):
        """Check that ``__dict__`` of a *cls* instance reflects the
        attributes set on it and can be neither rebound nor deleted."""
        obj = cls()
        obj.x = 5
        self.assertEqual(obj.__dict__, {'x': 5})
        with self.assertRaises(AttributeError):
            obj.__dict__ = {}
        with self.assertRaises(AttributeError):
            del obj.__dict__

    def test_dict_attribute(self):
        self._test_dict_attribute(self._local)

    def test_dict_attribute_subclass(self):
        class LocalSubclass(self._local):
            """To test that subclasses behave properly."""
        self._test_dict_attribute(LocalSubclass)

    def test_cycle_collection(self):
        class X:
            pass

        # Build the cycle x -> x.local -> x and make sure the garbage
        # collector can reclaim it.
        x = X()
        x.local = self._local()
        x.local.x = x
        wr = weakref.ref(x)
        del x
        gc.collect()
        self.assertIsNone(wr())
193
194
class ThreadLocalTest(unittest.TestCase, BaseLocalTest):
    """Run the shared ``BaseLocalTest`` tests against ``_thread._local``."""

    _local = _thread._local
197
class PyThreadingLocalTest(unittest.TestCase, BaseLocalTest):
    """Run the shared ``BaseLocalTest`` tests against ``_threading_local.local``."""

    _local = _threading_local.local
200
201
def test_main():
    """Build and run the full suite for this module.

    The suite contains the ``_threading_local`` doctests, both
    ``TestCase`` classes, and a second pass over the same doctests with
    ``_threading_local.local`` temporarily replaced by ``_thread._local``.
    """
    suite = unittest.TestSuite()
    suite.addTest(DocTestSuite('_threading_local'))
    suite.addTest(unittest.makeSuite(ThreadLocalTest))
    suite.addTest(unittest.makeSuite(PyThreadingLocalTest))

    # Second doctest pass: swap in _thread._local around each doctest and
    # restore the original class afterwards.
    local_orig = _threading_local.local

    def setUp(test):
        _threading_local.local = _thread._local

    def tearDown(test):
        _threading_local.local = local_orig

    suite.addTest(DocTestSuite('_threading_local',
                               setUp=setUp, tearDown=tearDown))

    support.run_unittest(suite)
218
# Allow running this test module directly as a script.
if __name__ == '__main__':
    test_main()
221