1import sys
2import textwrap
3from StringIO import StringIO
4from test import test_support
5
6import traceback
7import unittest
8
9
class Test_TestResult(unittest.TestCase):
    # Note: there are no separate tests for TestResult.wasSuccessful(),
    # TestResult.errors, TestResult.failures, TestResult.testsRun or
    # TestResult.shouldStop because these only have meaning in terms of
    # other TestResult methods.
    #
    # Accordingly, assertions on those attributes are folded into the
    # tests for the methods that define their behaviour.
    ################################################################

    def test_init(self):
        result = unittest.TestResult()

        self.assertTrue(result.wasSuccessful())
        self.assertEqual(len(result.errors), 0)
        self.assertEqual(len(result.failures), 0)
        self.assertEqual(result.testsRun, 0)
        self.assertEqual(result.shouldStop, False)
        self.assertIsNone(result._stdout_buffer)
        self.assertIsNone(result._stderr_buffer)


    # "This method can be called to signal that the set of tests being
    # run should be aborted by setting the TestResult's shouldStop
    # attribute to True."
    def test_stop(self):
        result = unittest.TestResult()

        result.stop()

        self.assertEqual(result.shouldStop, True)

    # "Called when the test case test is about to be run. The default
    # implementation simply increments the instance's testsRun counter."
    def test_startTest(self):
        class Foo(unittest.TestCase):
            def test_1(self):
                pass

        test = Foo('test_1')

        result = unittest.TestResult()

        result.startTest(test)

        self.assertTrue(result.wasSuccessful())
        self.assertEqual(len(result.errors), 0)
        self.assertEqual(len(result.failures), 0)
        self.assertEqual(result.testsRun, 1)
        self.assertEqual(result.shouldStop, False)

        result.stopTest(test)

    # "Called after the test case test has been executed, regardless of
    # the outcome. The default implementation does nothing."
    def test_stopTest(self):
        class Foo(unittest.TestCase):
            def test_1(self):
                pass

        test = Foo('test_1')

        result = unittest.TestResult()

        result.startTest(test)

        self.assertTrue(result.wasSuccessful())
        self.assertEqual(len(result.errors), 0)
        self.assertEqual(len(result.failures), 0)
        self.assertEqual(result.testsRun, 1)
        self.assertEqual(result.shouldStop, False)

        result.stopTest(test)

        # Same tests as above; make sure nothing has changed
        self.assertTrue(result.wasSuccessful())
        self.assertEqual(len(result.errors), 0)
        self.assertEqual(len(result.failures), 0)
        self.assertEqual(result.testsRun, 1)
        self.assertEqual(result.shouldStop, False)

    # "Called before and after tests are run. The default implementation does nothing."
    def test_startTestRun_stopTestRun(self):
        result = unittest.TestResult()
        result.startTestRun()
        result.stopTestRun()

        # "Does nothing": the result's bookkeeping must be left untouched.
        self.assertTrue(result.wasSuccessful())
        self.assertEqual(len(result.errors), 0)
        self.assertEqual(len(result.failures), 0)
        self.assertEqual(result.testsRun, 0)
        self.assertEqual(result.shouldStop, False)

    # "addSuccess(test)"
    # ...
    # "Called when the test case test succeeds"
    # ...
    # "wasSuccessful() - Returns True if all tests run so far have passed,
    # otherwise returns False"
    # ...
    # "testsRun - The total number of tests run so far."
    # ...
    # "errors - A list containing 2-tuples of TestCase instances and
    # formatted tracebacks. Each tuple represents a test which raised an
    # unexpected exception. Contains formatted
    # tracebacks instead of sys.exc_info() results."
    # ...
    # "failures - A list containing 2-tuples of TestCase instances and
    # formatted tracebacks. Each tuple represents a test where a failure was
    # explicitly signalled using the TestCase.fail*() or TestCase.assert*()
    # methods. Contains formatted tracebacks instead
    # of sys.exc_info() results."
    def test_addSuccess(self):
        class Foo(unittest.TestCase):
            def test_1(self):
                pass

        test = Foo('test_1')

        result = unittest.TestResult()

        result.startTest(test)
        result.addSuccess(test)
        result.stopTest(test)

        self.assertTrue(result.wasSuccessful())
        self.assertEqual(len(result.errors), 0)
        self.assertEqual(len(result.failures), 0)
        self.assertEqual(result.testsRun, 1)
        self.assertEqual(result.shouldStop, False)

    # "addFailure(test, err)"
    # ...
    # "Called when the test case test signals a failure. err is a tuple of
    # the form returned by sys.exc_info(): (type, value, traceback)"
    # ...
    # "wasSuccessful() - Returns True if all tests run so far have passed,
    # otherwise returns False"
    # ...
    # "testsRun - The total number of tests run so far."
    # ...
    # "errors - A list containing 2-tuples of TestCase instances and
    # formatted tracebacks. Each tuple represents a test which raised an
    # unexpected exception. Contains formatted
    # tracebacks instead of sys.exc_info() results."
    # ...
    # "failures - A list containing 2-tuples of TestCase instances and
    # formatted tracebacks. Each tuple represents a test where a failure was
    # explicitly signalled using the TestCase.fail*() or TestCase.assert*()
    # methods. Contains formatted tracebacks instead
    # of sys.exc_info() results."
    def test_addFailure(self):
        class Foo(unittest.TestCase):
            def test_1(self):
                pass

        test = Foo('test_1')
        # Capture the exc_info of a genuine failure.  Catch only the
        # failure exception so that any unrelated error surfaces instead
        # of being silently recorded as the "failure".
        try:
            test.fail("foo")
        except test.failureException:
            exc_info_tuple = sys.exc_info()

        result = unittest.TestResult()

        result.startTest(test)
        result.addFailure(test, exc_info_tuple)
        result.stopTest(test)

        self.assertFalse(result.wasSuccessful())
        self.assertEqual(len(result.errors), 0)
        self.assertEqual(len(result.failures), 1)
        self.assertEqual(result.testsRun, 1)
        self.assertEqual(result.shouldStop, False)

        test_case, formatted_exc = result.failures[0]
        self.assertIs(test_case, test)
        self.assertIsInstance(formatted_exc, str)

    # "addError(test, err)"
    # ...
    # "Called when the test case test raises an unexpected exception err
    # is a tuple of the form returned by sys.exc_info():
    # (type, value, traceback)"
    # ...
    # "wasSuccessful() - Returns True if all tests run so far have passed,
    # otherwise returns False"
    # ...
    # "testsRun - The total number of tests run so far."
    # ...
    # "errors - A list containing 2-tuples of TestCase instances and
    # formatted tracebacks. Each tuple represents a test which raised an
    # unexpected exception. Contains formatted
    # tracebacks instead of sys.exc_info() results."
    # ...
    # "failures - A list containing 2-tuples of TestCase instances and
    # formatted tracebacks. Each tuple represents a test where a failure was
    # explicitly signalled using the TestCase.fail*() or TestCase.assert*()
    # methods. Contains formatted tracebacks instead
    # of sys.exc_info() results."
    def test_addError(self):
        class Foo(unittest.TestCase):
            def test_1(self):
                pass

        test = Foo('test_1')
        # Capture the exc_info of the specific exception raised here only.
        try:
            raise TypeError()
        except TypeError:
            exc_info_tuple = sys.exc_info()

        result = unittest.TestResult()

        result.startTest(test)
        result.addError(test, exc_info_tuple)
        result.stopTest(test)

        self.assertFalse(result.wasSuccessful())
        self.assertEqual(len(result.errors), 1)
        self.assertEqual(len(result.failures), 0)
        self.assertEqual(result.testsRun, 1)
        self.assertEqual(result.shouldStop, False)

        test_case, formatted_exc = result.errors[0]
        self.assertIs(test_case, test)
        self.assertIsInstance(formatted_exc, str)

    def testGetDescriptionWithoutDocstring(self):
        result = unittest.TextTestResult(None, True, 1)
        self.assertEqual(
                result.getDescription(self),
                'testGetDescriptionWithoutDocstring (' + __name__ +
                '.Test_TestResult)')

    @unittest.skipIf(sys.flags.optimize >= 2,
                     "Docstrings are omitted with -O2 and above")
    def testGetDescriptionWithOneLineDocstring(self):
        """Tests getDescription() for a method with a docstring."""
        result = unittest.TextTestResult(None, True, 1)
        self.assertEqual(
                result.getDescription(self),
                ('testGetDescriptionWithOneLineDocstring '
                 '(' + __name__ + '.Test_TestResult)\n'
                 'Tests getDescription() for a method with a docstring.'))

    @unittest.skipIf(sys.flags.optimize >= 2,
                     "Docstrings are omitted with -O2 and above")
    def testGetDescriptionWithMultiLineDocstring(self):
        """Tests getDescription() for a method with a longer docstring.
        The second line of the docstring.
        """
        result = unittest.TextTestResult(None, True, 1)
        # Only the first docstring line is expected in the description.
        self.assertEqual(
                result.getDescription(self),
                ('testGetDescriptionWithMultiLineDocstring '
                 '(' + __name__ + '.Test_TestResult)\n'
                 'Tests getDescription() for a method with a longer '
                 'docstring.'))

    def testStackFrameTrimming(self):
        # Minimal stand-in for a traceback object: only tb_frame.f_globals
        # is provided.
        class Frame(object):
            class tb_frame(object):
                f_globals = {}
        result = unittest.TestResult()
        self.assertFalse(result._is_relevant_tb_level(Frame))

        # Once the frame's globals contain '__unittest', the level is
        # reported as relevant.
        Frame.tb_frame.f_globals['__unittest'] = True
        self.assertTrue(result._is_relevant_tb_level(Frame))

    def testFailFast(self):
        # In each case, traceback formatting is stubbed out because the err
        # arguments passed below are dummies (None).
        result = unittest.TestResult()
        result._exc_info_to_string = lambda *_: ''
        result.failfast = True
        result.addError(None, None)
        self.assertTrue(result.shouldStop)

        result = unittest.TestResult()
        result._exc_info_to_string = lambda *_: ''
        result.failfast = True
        result.addFailure(None, None)
        self.assertTrue(result.shouldStop)

        result = unittest.TestResult()
        result._exc_info_to_string = lambda *_: ''
        result.failfast = True
        result.addUnexpectedSuccess(None)
        self.assertTrue(result.shouldStop)

    def testFailFastSetByRunner(self):
        runner = unittest.TextTestRunner(stream=StringIO(), failfast=True)
        # run() is handed a plain callable that receives the result object,
        # so we can check that the runner's failfast flag reached it.
        def test(result):
            self.assertTrue(result.failfast)
        runner.run(test)
296
297
# Build "OldResult": a result class whose namespace is a copy of
# unittest.TestResult's, minus the addSkip / addExpectedFailure /
# addUnexpectedSuccess hooks (which Test_OldTestResult expects to be missing)
# and with a minimal replacement __init__.
classDict = dict(unittest.TestResult.__dict__)
for m in ['addSkip', 'addExpectedFailure', 'addUnexpectedSuccess',
          '__init__']:
    # pop() raises KeyError for a missing name, just like `del` would.
    classDict.pop(m)


def __init__(self, stream=None, descriptions=None, verbosity=None):
    # Accepts the (stream, descriptions, verbosity) arguments a runner passes
    # to its resultclass, ignores them, and sets up only the basic counters.
    self.failures, self.errors = [], []
    self.testsRun = 0
    self.shouldStop = self.buffer = False

classDict.update(__init__=__init__)
OldResult = type('OldResult', (object,), classDict)
312
class Test_OldTestResult(unittest.TestCase):
    """Run tests against OldResult, a result class that lacks the
    addSkip / addExpectedFailure / addUnexpectedSuccess hooks."""

    def assertOldResultWarning(self, test, failures):
        """Run ``test`` on a fresh OldResult and check the outcome.

        A RuntimeWarning whose message matches
        "TestResult has no add.+ method," is required while the test runs,
        and the result must end up with exactly ``failures`` failures.
        """
        missing_hook_warning = ("TestResult has no add.+ method,",
                                RuntimeWarning)
        with test_support.check_warnings(missing_hook_warning):
            old_result = OldResult()
            test.run(old_result)
            self.assertEqual(len(old_result.failures), failures)

    def testOldTestResult(self):
        class Test(unittest.TestCase):
            def testSkip(self):
                self.skipTest('foobar')

            @unittest.expectedFailure
            def testExpectedFail(self):
                raise TypeError

            @unittest.expectedFailure
            def testUnexpectedSuccess(self):
                pass

        # Expected failure count per test method.  Only the unexpected
        # success is recorded as a failure on a result without the hooks.
        expectations = [
            ('testSkip', 0),
            ('testExpectedFail', 0),
            ('testUnexpectedSuccess', 1),
        ]
        for method_name, failure_count in expectations:
            self.assertOldResultWarning(Test(method_name), failure_count)

    # NOTE: "Tesult" is a typo for "Result"; the name is kept unchanged so
    # existing test ids stay stable.
    def testOldTestTesultSetup(self):
        # A skip raised from setUp() must also warn and record no failures.
        class Test(unittest.TestCase):
            def setUp(self):
                self.skipTest('no reason')

            def testFoo(self):
                pass

        self.assertOldResultWarning(Test('testFoo'), 0)

    def testOldTestResultClass(self):
        # Same check when the skip comes from a class-level decorator.
        @unittest.skip('no reason')
        class Test(unittest.TestCase):
            def testFoo(self):
                pass

        self.assertOldResultWarning(Test('testFoo'), 0)

    def testOldResultWithRunner(self):
        class Test(unittest.TestCase):
            def testFoo(self):
                pass

        # TextTestRunner must be able to drive an OldResult; any
        # incompatibility surfaces as an exception from run().
        runner = unittest.TextTestRunner(stream=StringIO(),
                                         resultclass=OldResult)
        runner.run(Test('testFoo'))
363
364
class MockTraceback(object):
    """Stand-in for the ``traceback`` module as used by ``unittest.result``.

    Only ``format_exception`` is provided.  It ignores whatever positional
    arguments it receives and always returns the single-line list
    ``['A traceback']``, so formatted messages are predictable in tests.
    """

    @staticmethod
    def format_exception(*args):
        # A new list is built on each call, so callers may mutate it freely.
        lines = ['A traceback']
        return lines
369
def restore_traceback():
    """Reinstall the real ``traceback`` module on ``unittest.result``.

    Used as a cleanup after ``unittest.result.traceback`` has been patched
    with MockTraceback.
    """
    setattr(unittest.result, 'traceback', traceback)
372
373
374class TestOutputBuffering(unittest.TestCase):
375
376    def setUp(self):
377        self._real_out = sys.stdout
378        self._real_err = sys.stderr
379
380    def tearDown(self):
381        sys.stdout = self._real_out
382        sys.stderr = self._real_err
383
384    def testBufferOutputOff(self):
385        real_out = self._real_out
386        real_err = self._real_err
387
388        result = unittest.TestResult()
389        self.assertFalse(result.buffer)
390
391        self.assertIs(real_out, sys.stdout)
392        self.assertIs(real_err, sys.stderr)
393
394        result.startTest(self)
395
396        self.assertIs(real_out, sys.stdout)
397        self.assertIs(real_err, sys.stderr)
398
399    def testBufferOutputStartTestAddSuccess(self):
400        real_out = self._real_out
401        real_err = self._real_err
402
403        result = unittest.TestResult()
404        self.assertFalse(result.buffer)
405
406        result.buffer = True
407
408        self.assertIs(real_out, sys.stdout)
409        self.assertIs(real_err, sys.stderr)
410
411        result.startTest(self)
412
413        self.assertIsNot(real_out, sys.stdout)
414        self.assertIsNot(real_err, sys.stderr)
415        self.assertIsInstance(sys.stdout, StringIO)
416        self.assertIsInstance(sys.stderr, StringIO)
417        self.assertIsNot(sys.stdout, sys.stderr)
418
419        out_stream = sys.stdout
420        err_stream = sys.stderr
421
422        result._original_stdout = StringIO()
423        result._original_stderr = StringIO()
424
425        print 'foo'
426        print >> sys.stderr, 'bar'
427
428        self.assertEqual(out_stream.getvalue(), 'foo\n')
429        self.assertEqual(err_stream.getvalue(), 'bar\n')
430
431        self.assertEqual(result._original_stdout.getvalue(), '')
432        self.assertEqual(result._original_stderr.getvalue(), '')
433
434        result.addSuccess(self)
435        result.stopTest(self)
436
437        self.assertIs(sys.stdout, result._original_stdout)
438        self.assertIs(sys.stderr, result._original_stderr)
439
440        self.assertEqual(result._original_stdout.getvalue(), '')
441        self.assertEqual(result._original_stderr.getvalue(), '')
442
443        self.assertEqual(out_stream.getvalue(), '')
444        self.assertEqual(err_stream.getvalue(), '')
445
446
447    def getStartedResult(self):
448        result = unittest.TestResult()
449        result.buffer = True
450        result.startTest(self)
451        return result
452
453    def testBufferOutputAddErrorOrFailure(self):
454        unittest.result.traceback = MockTraceback
455        self.addCleanup(restore_traceback)
456
457        for message_attr, add_attr, include_error in [
458            ('errors', 'addError', True),
459            ('failures', 'addFailure', False),
460            ('errors', 'addError', True),
461            ('failures', 'addFailure', False)
462        ]:
463            result = self.getStartedResult()
464            buffered_out = sys.stdout
465            buffered_err = sys.stderr
466            result._original_stdout = StringIO()
467            result._original_stderr = StringIO()
468
469            print >> sys.stdout, 'foo'
470            if include_error:
471                print >> sys.stderr, 'bar'
472
473
474            addFunction = getattr(result, add_attr)
475            addFunction(self, (None, None, None))
476            result.stopTest(self)
477
478            result_list = getattr(result, message_attr)
479            self.assertEqual(len(result_list), 1)
480
481            test, message = result_list[0]
482            expectedOutMessage = textwrap.dedent("""
483                Stdout:
484                foo
485            """)
486            expectedErrMessage = ''
487            if include_error:
488                expectedErrMessage = textwrap.dedent("""
489                Stderr:
490                bar
491            """)
492            expectedFullMessage = 'A traceback%s%s' % (expectedOutMessage, expectedErrMessage)
493
494            self.assertIs(test, self)
495            self.assertEqual(result._original_stdout.getvalue(), expectedOutMessage)
496            self.assertEqual(result._original_stderr.getvalue(), expectedErrMessage)
497            self.assertMultiLineEqual(message, expectedFullMessage)
498
499    def testBufferSetupClass(self):
500        result = unittest.TestResult()
501        result.buffer = True
502
503        class Foo(unittest.TestCase):
504            @classmethod
505            def setUpClass(cls):
506                1//0
507            def test_foo(self):
508                pass
509        suite = unittest.TestSuite([Foo('test_foo')])
510        suite(result)
511        self.assertEqual(len(result.errors), 1)
512
513    def testBufferTearDownClass(self):
514        result = unittest.TestResult()
515        result.buffer = True
516
517        class Foo(unittest.TestCase):
518            @classmethod
519            def tearDownClass(cls):
520                1//0
521            def test_foo(self):
522                pass
523        suite = unittest.TestSuite([Foo('test_foo')])
524        suite(result)
525        self.assertEqual(len(result.errors), 1)
526
527    def testBufferSetUpModule(self):
528        result = unittest.TestResult()
529        result.buffer = True
530
531        class Foo(unittest.TestCase):
532            def test_foo(self):
533                pass
534        class Module(object):
535            @staticmethod
536            def setUpModule():
537                1//0
538
539        Foo.__module__ = 'Module'
540        sys.modules['Module'] = Module
541        self.addCleanup(sys.modules.pop, 'Module')
542        suite = unittest.TestSuite([Foo('test_foo')])
543        suite(result)
544        self.assertEqual(len(result.errors), 1)
545
546    def testBufferTearDownModule(self):
547        result = unittest.TestResult()
548        result.buffer = True
549
550        class Foo(unittest.TestCase):
551            def test_foo(self):
552                pass
553        class Module(object):
554            @staticmethod
555            def tearDownModule():
556                1//0
557
558        Foo.__module__ = 'Module'
559        sys.modules['Module'] = Module
560        self.addCleanup(sys.modules.pop, 'Module')
561        suite = unittest.TestSuite([Foo('test_foo')])
562        suite(result)
563        self.assertEqual(len(result.errors), 1)
564
565
if __name__ == "__main__":
    # Run this module's tests when it is executed as a script.
    unittest.main()
568