1import gc
2import io
3import os
4import sys
5import signal
6import weakref
7
8import unittest
9
10
@unittest.skipUnless(hasattr(os, 'kill'), "Test requires os.kill")
@unittest.skipIf(sys.platform == "win32", "Test cannot run on Windows")
class TestBreak(unittest.TestCase):
    """Tests for unittest's control-C (SIGINT) handling.

    Exercises ``unittest.installHandler``, ``removeHandler``,
    ``registerResult`` and ``removeResult``, mostly by sending real SIGINTs
    to the current process with ``os.kill``.  Subclasses override
    ``int_handler`` to repeat every test with a different SIGINT
    disposition installed beforehand.
    """

    # SIGINT disposition installed by setUp(); None leaves the current
    # handler untouched.
    int_handler = None

    def setUp(self):
        self._default_handler = signal.getsignal(signal.SIGINT)
        # The unittest.signals globals are process-wide.  Save them via a
        # cleanup (cleanups run even if the rest of setUp fails) and start
        # every test from a clean slate: a runner executing this suite may
        # already have registered its own result (TextTestRunner does, see
        # testRunner) or installed a handler, and tests must neither see
        # nor destroy that state.
        self.addCleanup(self._restoreSignalsModule,
                        unittest.signals._results,
                        unittest.signals._interrupt_handler)
        unittest.signals._results = weakref.WeakKeyDictionary()
        unittest.signals._interrupt_handler = None
        if self.int_handler is not None:
            signal.signal(signal.SIGINT, self.int_handler)

    def tearDown(self):
        signal.signal(signal.SIGINT, self._default_handler)

    @staticmethod
    def _restoreSignalsModule(results, interrupt_handler):
        """Put back the unittest.signals globals saved by setUp()."""
        unittest.signals._results = results
        unittest.signals._interrupt_handler = interrupt_handler

    def testInstallHandler(self):
        default_handler = signal.getsignal(signal.SIGINT)
        unittest.installHandler()
        self.assertNotEqual(signal.getsignal(signal.SIGINT), default_handler)

        try:
            pid = os.getpid()
            os.kill(pid, signal.SIGINT)
        except KeyboardInterrupt:
            self.fail("KeyboardInterrupt not handled")

        self.assertTrue(unittest.signals._interrupt_handler.called)

    def testRegisterResult(self):
        result = unittest.TestResult()
        self.assertNotIn(result, unittest.signals._results)

        unittest.registerResult(result)
        try:
            self.assertIn(result, unittest.signals._results)
        finally:
            unittest.removeResult(result)

    def testInterruptCaught(self):
        default_handler = signal.getsignal(signal.SIGINT)

        result = unittest.TestResult()
        unittest.installHandler()
        unittest.registerResult(result)

        self.assertNotEqual(signal.getsignal(signal.SIGINT), default_handler)

        def test(result):
            pid = os.getpid()
            os.kill(pid, signal.SIGINT)
            # Execution continues past the first SIGINT; the registered
            # result is only asked to stop.
            result.breakCaught = True
            self.assertTrue(result.shouldStop)

        try:
            test(result)
        except KeyboardInterrupt:
            self.fail("KeyboardInterrupt not handled")
        self.assertTrue(result.breakCaught)

    def testSecondInterrupt(self):
        # Can't use skipIf decorator because the signal handler may have
        # been changed after defining this method.
        if signal.getsignal(signal.SIGINT) == signal.SIG_IGN:
            self.skipTest("test requires SIGINT to not be ignored")
        result = unittest.TestResult()
        unittest.installHandler()
        unittest.registerResult(result)

        def test(result):
            pid = os.getpid()
            os.kill(pid, signal.SIGINT)
            result.breakCaught = True
            self.assertTrue(result.shouldStop)
            # A second SIGINT must raise KeyboardInterrupt; if it does not,
            # this returns normally and the caller's else-branch fails.
            os.kill(pid, signal.SIGINT)

        try:
            test(result)
        except KeyboardInterrupt:
            pass
        else:
            self.fail("Second KeyboardInterrupt not raised")
        self.assertTrue(result.breakCaught)

    def testTwoResults(self):
        unittest.installHandler()

        result = unittest.TestResult()
        unittest.registerResult(result)
        new_handler = signal.getsignal(signal.SIGINT)

        # Registering another result must not replace the installed handler.
        result2 = unittest.TestResult()
        unittest.registerResult(result2)
        self.assertEqual(signal.getsignal(signal.SIGINT), new_handler)

        # Never registered, so a SIGINT must leave it alone.
        result3 = unittest.TestResult()

        def test(result):
            pid = os.getpid()
            os.kill(pid, signal.SIGINT)

        try:
            test(result)
        except KeyboardInterrupt:
            self.fail("KeyboardInterrupt not handled")

        self.assertTrue(result.shouldStop)
        self.assertTrue(result2.shouldStop)
        self.assertFalse(result3.shouldStop)

    def testHandlerReplacedButCalled(self):
        # Can't use skipIf decorator because the signal handler may have
        # been changed after defining this method.
        if signal.getsignal(signal.SIGINT) == signal.SIG_IGN:
            self.skipTest("test requires SIGINT to not be ignored")
        # If our handler has been replaced (is no longer installed) but is
        # called by the *new* handler, then it isn't safe to delay the
        # SIGINT and we should immediately delegate to the default handler
        unittest.installHandler()

        handler = signal.getsignal(signal.SIGINT)
        def new_handler(signum, frame):
            handler(signum, frame)
        signal.signal(signal.SIGINT, new_handler)

        try:
            pid = os.getpid()
            os.kill(pid, signal.SIGINT)
        except KeyboardInterrupt:
            pass
        else:
            self.fail("replaced but delegated handler doesn't raise interrupt")

    def testRunner(self):
        # Creating a TextTestRunner with the appropriate argument should
        # register the TextTestResult it creates
        runner = unittest.TextTestRunner(stream=io.StringIO())

        result = runner.run(unittest.TestSuite())
        self.assertIn(result, unittest.signals._results)

    def testWeakReferences(self):
        # Calling registerResult on a result should not keep it alive
        result = unittest.TestResult()
        unittest.registerResult(result)

        ref = weakref.ref(result)
        del result

        # For non-reference counting implementations
        gc.collect()
        gc.collect()
        self.assertIsNone(ref())

    def testRemoveResult(self):
        result = unittest.TestResult()
        unittest.registerResult(result)

        unittest.installHandler()
        self.assertTrue(unittest.removeResult(result))

        # Should this raise an error instead?
        self.assertFalse(unittest.removeResult(unittest.TestResult()))

        # Our handler is the installed one and has not fired yet, so this
        # first SIGINT must be absorbed, as in the other tests.
        try:
            pid = os.getpid()
            os.kill(pid, signal.SIGINT)
        except KeyboardInterrupt:
            self.fail("KeyboardInterrupt not handled")

        self.assertFalse(result.shouldStop)

    def testMainInstallsHandler(self):
        failfast = object()
        test = object()
        verbosity = object()
        result = object()
        default_handler = signal.getsignal(signal.SIGINT)

        class FakeRunner(object):
            initArgs = []
            runArgs = []
            def __init__(self, *args, **kwargs):
                self.initArgs.append((args, kwargs))
            def run(self, test):
                self.runArgs.append(test)
                return result

        class Program(unittest.TestProgram):
            def __init__(self, catchbreak):
                self.exit = False
                self.verbosity = verbosity
                self.failfast = failfast
                self.catchbreak = catchbreak
                self.tb_locals = False
                self.testRunner = FakeRunner
                self.test = test
                self.result = None
                # Newer TestProgram versions (3.12+) read ``durations`` in
                # runTests() and forward it to the runner; define it so this
                # fake, which skips TestProgram.__init__, works there too.
                self.durations = None

        expected_init_args = [((), {'buffer': None,
                                    'verbosity': verbosity,
                                    'failfast': failfast,
                                    'tb_locals': False,
                                    'warnings': None})]

        def recorded_init_args():
            # Ignore an unset ``durations`` keyword so the comparison holds
            # whether or not this TestProgram version forwards it.
            return [(args, {key: value for key, value in kwargs.items()
                            if not (key == 'durations' and value is None)})
                    for args, kwargs in FakeRunner.initArgs]

        p = Program(False)
        p.runTests()

        self.assertEqual(recorded_init_args(), expected_init_args)
        self.assertEqual(FakeRunner.runArgs, [test])
        self.assertEqual(p.result, result)

        # catchbreak=False must leave the SIGINT handler alone.
        self.assertEqual(signal.getsignal(signal.SIGINT), default_handler)

        FakeRunner.initArgs = []
        FakeRunner.runArgs = []
        p = Program(True)
        p.runTests()

        self.assertEqual(recorded_init_args(), expected_init_args)
        self.assertEqual(FakeRunner.runArgs, [test])
        self.assertEqual(p.result, result)

        # catchbreak=True must install unittest's handler.
        self.assertNotEqual(signal.getsignal(signal.SIGINT), default_handler)

    def testRemoveHandler(self):
        default_handler = signal.getsignal(signal.SIGINT)
        unittest.installHandler()
        unittest.removeHandler()
        self.assertEqual(signal.getsignal(signal.SIGINT), default_handler)

        # check that calling removeHandler multiple times has no ill-effect
        unittest.removeHandler()
        self.assertEqual(signal.getsignal(signal.SIGINT), default_handler)

    def testRemoveHandlerAsDecorator(self):
        default_handler = signal.getsignal(signal.SIGINT)
        unittest.installHandler()

        @unittest.removeHandler
        def test():
            self.assertEqual(signal.getsignal(signal.SIGINT), default_handler)

        test()
        # Outside the decorated call the installed handler is back.
        self.assertNotEqual(signal.getsignal(signal.SIGINT), default_handler)
262
@unittest.skipUnless(hasattr(os, "kill"), "Test requires os.kill")
@unittest.skipIf(sys.platform == "win32", "Test cannot run on Windows")
class TestBreakDefaultIntHandler(TestBreak):
    """Repeat every TestBreak test with ``signal.default_int_handler``
    installed as the SIGINT handler before each test."""

    int_handler = signal.default_int_handler
267
@unittest.skipUnless(hasattr(os, "kill"), "Test requires os.kill")
@unittest.skipIf(sys.platform == "win32", "Test cannot run on Windows")
class TestBreakSignalIgnored(TestBreak):
    """Repeat every TestBreak test with SIGINT set to ``signal.SIG_IGN``
    before each test."""

    int_handler = signal.SIG_IGN
272
@unittest.skipUnless(hasattr(os, "kill"), "Test requires os.kill")
@unittest.skipIf(sys.platform == "win32", "Test cannot run on Windows")
class TestBreakSignalDefault(TestBreak):
    """Repeat every TestBreak test with SIGINT set to ``signal.SIG_DFL``
    before each test."""

    int_handler = signal.SIG_DFL
277
278
if __name__ == "__main__":
    # Run this module's tests when it is executed as a script.
    unittest.main(module="__main__")
281