1"""Test case implementation"""
2
3import sys
4import functools
5import difflib
6import logging
7import pprint
8import re
9import warnings
10import collections
11import contextlib
12import traceback
13
14from . import result
15from .util import (strclass, safe_repr, _count_diff_all_purpose,
16                   _count_diff_hashable, _common_shorten_repr)
17
18__unittest = True
19
20_subtest_msg_sentinel = object()
21
22DIFF_OMITTED = ('\nDiff is %s characters long. '
23                 'Set self.maxDiff to None to see it.')
24
25class SkipTest(Exception):
26    """
27    Raise this exception in a test to skip it.
28
29    Usually you can use TestCase.skipTest() or one of the skipping decorators
30    instead of raising this directly.
31    """
32
33class _ShouldStop(Exception):
34    """
35    The test should stop.
36    """
37
38class _UnexpectedSuccess(Exception):
39    """
40    The test was supposed to fail, but it didn't!
41    """
42
43
44class _Outcome(object):
45    def __init__(self, result=None):
46        self.expecting_failure = False
47        self.result = result
48        self.result_supports_subtests = hasattr(result, "addSubTest")
49        self.success = True
50        self.skipped = []
51        self.expectedFailure = None
52        self.errors = []
53
54    @contextlib.contextmanager
55    def testPartExecutor(self, test_case, isTest=False):
56        old_success = self.success
57        self.success = True
58        try:
59            yield
60        except KeyboardInterrupt:
61            raise
62        except SkipTest as e:
63            self.success = False
64            self.skipped.append((test_case, str(e)))
65        except _ShouldStop:
66            pass
67        except:
68            exc_info = sys.exc_info()
69            if self.expecting_failure:
70                self.expectedFailure = exc_info
71            else:
72                self.success = False
73                self.errors.append((test_case, exc_info))
74            # explicitly break a reference cycle:
75            # exc_info -> frame -> exc_info
76            exc_info = None
77        else:
78            if self.result_supports_subtests and self.success:
79                self.errors.append((test_case, None))
80        finally:
81            self.success = self.success and old_success
82
83
84def _id(obj):
85    return obj
86
87def skip(reason):
88    """
89    Unconditionally skip a test.
90    """
91    def decorator(test_item):
92        if not isinstance(test_item, type):
93            @functools.wraps(test_item)
94            def skip_wrapper(*args, **kwargs):
95                raise SkipTest(reason)
96            test_item = skip_wrapper
97
98        test_item.__unittest_skip__ = True
99        test_item.__unittest_skip_why__ = reason
100        return test_item
101    return decorator
102
103def skipIf(condition, reason):
104    """
105    Skip a test if the condition is true.
106    """
107    if condition:
108        return skip(reason)
109    return _id
110
111def skipUnless(condition, reason):
112    """
113    Skip a test unless the condition is true.
114    """
115    if not condition:
116        return skip(reason)
117    return _id
118
119def expectedFailure(test_item):
120    test_item.__unittest_expecting_failure__ = True
121    return test_item
122
123def _is_subtype(expected, basetype):
124    if isinstance(expected, tuple):
125        return all(_is_subtype(e, basetype) for e in expected)
126    return isinstance(expected, type) and issubclass(expected, basetype)
127
128class _BaseTestCaseContext:
129
130    def __init__(self, test_case):
131        self.test_case = test_case
132
133    def _raiseFailure(self, standardMsg):
134        msg = self.test_case._formatMessage(self.msg, standardMsg)
135        raise self.test_case.failureException(msg)
136
137class _AssertRaisesBaseContext(_BaseTestCaseContext):
138
139    def __init__(self, expected, test_case, expected_regex=None):
140        _BaseTestCaseContext.__init__(self, test_case)
141        self.expected = expected
142        self.test_case = test_case
143        if expected_regex is not None:
144            expected_regex = re.compile(expected_regex)
145        self.expected_regex = expected_regex
146        self.obj_name = None
147        self.msg = None
148
149    def handle(self, name, args, kwargs):
150        """
151        If args is empty, assertRaises/Warns is being used as a
152        context manager, so check for a 'msg' kwarg and return self.
153        If args is not empty, call a callable passing positional and keyword
154        arguments.
155        """
156        try:
157            if not _is_subtype(self.expected, self._base_type):
158                raise TypeError('%s() arg 1 must be %s' %
159                                (name, self._base_type_str))
160            if args and args[0] is None:
161                warnings.warn("callable is None",
162                              DeprecationWarning, 3)
163                args = ()
164            if not args:
165                self.msg = kwargs.pop('msg', None)
166                if kwargs:
167                    warnings.warn('%r is an invalid keyword argument for '
168                                  'this function' % next(iter(kwargs)),
169                                  DeprecationWarning, 3)
170                return self
171
172            callable_obj, *args = args
173            try:
174                self.obj_name = callable_obj.__name__
175            except AttributeError:
176                self.obj_name = str(callable_obj)
177            with self:
178                callable_obj(*args, **kwargs)
179        finally:
180            # bpo-23890: manually break a reference cycle
181            self = None
182
183
184class _AssertRaisesContext(_AssertRaisesBaseContext):
185    """A context manager used to implement TestCase.assertRaises* methods."""
186
187    _base_type = BaseException
188    _base_type_str = 'an exception type or tuple of exception types'
189
190    def __enter__(self):
191        return self
192
193    def __exit__(self, exc_type, exc_value, tb):
194        if exc_type is None:
195            try:
196                exc_name = self.expected.__name__
197            except AttributeError:
198                exc_name = str(self.expected)
199            if self.obj_name:
200                self._raiseFailure("{} not raised by {}".format(exc_name,
201                                                                self.obj_name))
202            else:
203                self._raiseFailure("{} not raised".format(exc_name))
204        else:
205            traceback.clear_frames(tb)
206        if not issubclass(exc_type, self.expected):
207            # let unexpected exceptions pass through
208            return False
209        # store exception, without traceback, for later retrieval
210        self.exception = exc_value.with_traceback(None)
211        if self.expected_regex is None:
212            return True
213
214        expected_regex = self.expected_regex
215        if not expected_regex.search(str(exc_value)):
216            self._raiseFailure('"{}" does not match "{}"'.format(
217                     expected_regex.pattern, str(exc_value)))
218        return True
219
220
221class _AssertWarnsContext(_AssertRaisesBaseContext):
222    """A context manager used to implement TestCase.assertWarns* methods."""
223
224    _base_type = Warning
225    _base_type_str = 'a warning type or tuple of warning types'
226
227    def __enter__(self):
228        # The __warningregistry__'s need to be in a pristine state for tests
229        # to work properly.
230        for v in sys.modules.values():
231            if getattr(v, '__warningregistry__', None):
232                v.__warningregistry__ = {}
233        self.warnings_manager = warnings.catch_warnings(record=True)
234        self.warnings = self.warnings_manager.__enter__()
235        warnings.simplefilter("always", self.expected)
236        return self
237
238    def __exit__(self, exc_type, exc_value, tb):
239        self.warnings_manager.__exit__(exc_type, exc_value, tb)
240        if exc_type is not None:
241            # let unexpected exceptions pass through
242            return
243        try:
244            exc_name = self.expected.__name__
245        except AttributeError:
246            exc_name = str(self.expected)
247        first_matching = None
248        for m in self.warnings:
249            w = m.message
250            if not isinstance(w, self.expected):
251                continue
252            if first_matching is None:
253                first_matching = w
254            if (self.expected_regex is not None and
255                not self.expected_regex.search(str(w))):
256                continue
257            # store warning for later retrieval
258            self.warning = w
259            self.filename = m.filename
260            self.lineno = m.lineno
261            return
262        # Now we simply try to choose a helpful failure message
263        if first_matching is not None:
264            self._raiseFailure('"{}" does not match "{}"'.format(
265                     self.expected_regex.pattern, str(first_matching)))
266        if self.obj_name:
267            self._raiseFailure("{} not triggered by {}".format(exc_name,
268                                                               self.obj_name))
269        else:
270            self._raiseFailure("{} not triggered".format(exc_name))
271
272
273
274_LoggingWatcher = collections.namedtuple("_LoggingWatcher",
275                                         ["records", "output"])
276
277
278class _CapturingHandler(logging.Handler):
279    """
280    A logging handler capturing all (raw and formatted) logging output.
281    """
282
283    def __init__(self):
284        logging.Handler.__init__(self)
285        self.watcher = _LoggingWatcher([], [])
286
287    def flush(self):
288        pass
289
290    def emit(self, record):
291        self.watcher.records.append(record)
292        msg = self.format(record)
293        self.watcher.output.append(msg)
294
295
296
297class _AssertLogsContext(_BaseTestCaseContext):
298    """A context manager used to implement TestCase.assertLogs()."""
299
300    LOGGING_FORMAT = "%(levelname)s:%(name)s:%(message)s"
301
302    def __init__(self, test_case, logger_name, level):
303        _BaseTestCaseContext.__init__(self, test_case)
304        self.logger_name = logger_name
305        if level:
306            self.level = logging._nameToLevel.get(level, level)
307        else:
308            self.level = logging.INFO
309        self.msg = None
310
311    def __enter__(self):
312        if isinstance(self.logger_name, logging.Logger):
313            logger = self.logger = self.logger_name
314        else:
315            logger = self.logger = logging.getLogger(self.logger_name)
316        formatter = logging.Formatter(self.LOGGING_FORMAT)
317        handler = _CapturingHandler()
318        handler.setFormatter(formatter)
319        self.watcher = handler.watcher
320        self.old_handlers = logger.handlers[:]
321        self.old_level = logger.level
322        self.old_propagate = logger.propagate
323        logger.handlers = [handler]
324        logger.setLevel(self.level)
325        logger.propagate = False
326        return handler.watcher
327
328    def __exit__(self, exc_type, exc_value, tb):
329        self.logger.handlers = self.old_handlers
330        self.logger.propagate = self.old_propagate
331        self.logger.setLevel(self.old_level)
332        if exc_type is not None:
333            # let unexpected exceptions pass through
334            return False
335        if len(self.watcher.records) == 0:
336            self._raiseFailure(
337                "no logs of level {} or higher triggered on {}"
338                .format(logging.getLevelName(self.level), self.logger.name))
339
340
341class _OrderedChainMap(collections.ChainMap):
342    def __iter__(self):
343        seen = set()
344        for mapping in self.maps:
345            for k in mapping:
346                if k not in seen:
347                    seen.add(k)
348                    yield k
349
350
351class TestCase(object):
352    """A class whose instances are single test cases.
353
354    By default, the test code itself should be placed in a method named
355    'runTest'.
356
357    If the fixture may be used for many test cases, create as
358    many test methods as are needed. When instantiating such a TestCase
359    subclass, specify in the constructor arguments the name of the test method
360    that the instance is to execute.
361
362    Test authors should subclass TestCase for their own tests. Construction
363    and deconstruction of the test's environment ('fixture') can be
364    implemented by overriding the 'setUp' and 'tearDown' methods respectively.
365
366    If it is necessary to override the __init__ method, the base class
367    __init__ method must always be called. It is important that subclasses
368    should not change the signature of their __init__ method, since instances
369    of the classes are instantiated automatically by parts of the framework
370    in order to be run.
371
372    When subclassing TestCase, you can set these attributes:
373    * failureException: determines which exception will be raised when
374        the instance's assertion methods fail; test methods raising this
375        exception will be deemed to have 'failed' rather than 'errored'.
376    * longMessage: determines whether long messages (including repr of
377        objects used in assert methods) will be printed on failure in *addition*
378        to any explicit message passed.
379    * maxDiff: sets the maximum length of a diff in failure messages
380        by assert methods using difflib. It is looked up as an instance
381        attribute so can be configured by individual tests if required.
382    """
383
384    failureException = AssertionError
385
386    longMessage = True
387
388    maxDiff = 80*8
389
390    # If a string is longer than _diffThreshold, use normal comparison instead
391    # of difflib.  See #11763.
392    _diffThreshold = 2**16
393
394    # Attribute used by TestSuite for classSetUp
395
396    _classSetupFailed = False
397
398    def __init__(self, methodName='runTest'):
399        """Create an instance of the class that will use the named test
400           method when executed. Raises a ValueError if the instance does
401           not have a method with the specified name.
402        """
403        self._testMethodName = methodName
404        self._outcome = None
405        self._testMethodDoc = 'No test'
406        try:
407            testMethod = getattr(self, methodName)
408        except AttributeError:
409            if methodName != 'runTest':
410                # we allow instantiation with no explicit method name
411                # but not an *incorrect* or missing method name
412                raise ValueError("no such test method in %s: %s" %
413                      (self.__class__, methodName))
414        else:
415            self._testMethodDoc = testMethod.__doc__
416        self._cleanups = []
417        self._subtest = None
418
419        # Map types to custom assertEqual functions that will compare
420        # instances of said type in more detail to generate a more useful
421        # error message.
422        self._type_equality_funcs = {}
423        self.addTypeEqualityFunc(dict, 'assertDictEqual')
424        self.addTypeEqualityFunc(list, 'assertListEqual')
425        self.addTypeEqualityFunc(tuple, 'assertTupleEqual')
426        self.addTypeEqualityFunc(set, 'assertSetEqual')
427        self.addTypeEqualityFunc(frozenset, 'assertSetEqual')
428        self.addTypeEqualityFunc(str, 'assertMultiLineEqual')
429
430    def addTypeEqualityFunc(self, typeobj, function):
431        """Add a type specific assertEqual style function to compare a type.
432
433        This method is for use by TestCase subclasses that need to register
434        their own type equality functions to provide nicer error messages.
435
436        Args:
437            typeobj: The data type to call this function on when both values
438                    are of the same type in assertEqual().
439            function: The callable taking two arguments and an optional
440                    msg= argument that raises self.failureException with a
441                    useful error message when the two arguments are not equal.
442        """
443        self._type_equality_funcs[typeobj] = function
444
445    def addCleanup(self, function, *args, **kwargs):
446        """Add a function, with arguments, to be called when the test is
447        completed. Functions added are called on a LIFO basis and are
448        called after tearDown on test failure or success.
449
450        Cleanup items are called even if setUp fails (unlike tearDown)."""
451        self._cleanups.append((function, args, kwargs))
452
453    def setUp(self):
454        "Hook method for setting up the test fixture before exercising it."
455        pass
456
457    def tearDown(self):
458        "Hook method for deconstructing the test fixture after testing it."
459        pass
460
461    @classmethod
462    def setUpClass(cls):
463        "Hook method for setting up class fixture before running tests in the class."
464
465    @classmethod
466    def tearDownClass(cls):
467        "Hook method for deconstructing the class fixture after running all tests in the class."
468
469    def countTestCases(self):
470        return 1
471
472    def defaultTestResult(self):
473        return result.TestResult()
474
475    def shortDescription(self):
476        """Returns a one-line description of the test, or None if no
477        description has been provided.
478
479        The default implementation of this method returns the first line of
480        the specified test method's docstring.
481        """
482        doc = self._testMethodDoc
483        return doc and doc.split("\n")[0].strip() or None
484
485
486    def id(self):
487        return "%s.%s" % (strclass(self.__class__), self._testMethodName)
488
489    def __eq__(self, other):
490        if type(self) is not type(other):
491            return NotImplemented
492
493        return self._testMethodName == other._testMethodName
494
495    def __hash__(self):
496        return hash((type(self), self._testMethodName))
497
498    def __str__(self):
499        return "%s (%s)" % (self._testMethodName, strclass(self.__class__))
500
501    def __repr__(self):
502        return "<%s testMethod=%s>" % \
503               (strclass(self.__class__), self._testMethodName)
504
505    def _addSkip(self, result, test_case, reason):
506        addSkip = getattr(result, 'addSkip', None)
507        if addSkip is not None:
508            addSkip(test_case, reason)
509        else:
510            warnings.warn("TestResult has no addSkip method, skips not reported",
511                          RuntimeWarning, 2)
512            result.addSuccess(test_case)
513
514    @contextlib.contextmanager
515    def subTest(self, msg=_subtest_msg_sentinel, **params):
516        """Return a context manager that will return the enclosed block
517        of code in a subtest identified by the optional message and
518        keyword parameters.  A failure in the subtest marks the test
519        case as failed but resumes execution at the end of the enclosed
520        block, allowing further test code to be executed.
521        """
522        if self._outcome is None or not self._outcome.result_supports_subtests:
523            yield
524            return
525        parent = self._subtest
526        if parent is None:
527            params_map = _OrderedChainMap(params)
528        else:
529            params_map = parent.params.new_child(params)
530        self._subtest = _SubTest(self, msg, params_map)
531        try:
532            with self._outcome.testPartExecutor(self._subtest, isTest=True):
533                yield
534            if not self._outcome.success:
535                result = self._outcome.result
536                if result is not None and result.failfast:
537                    raise _ShouldStop
538            elif self._outcome.expectedFailure:
539                # If the test is expecting a failure, we really want to
540                # stop now and register the expected failure.
541                raise _ShouldStop
542        finally:
543            self._subtest = parent
544
545    def _feedErrorsToResult(self, result, errors):
546        for test, exc_info in errors:
547            if isinstance(test, _SubTest):
548                result.addSubTest(test.test_case, test, exc_info)
549            elif exc_info is not None:
550                if issubclass(exc_info[0], self.failureException):
551                    result.addFailure(test, exc_info)
552                else:
553                    result.addError(test, exc_info)
554
555    def _addExpectedFailure(self, result, exc_info):
556        try:
557            addExpectedFailure = result.addExpectedFailure
558        except AttributeError:
559            warnings.warn("TestResult has no addExpectedFailure method, reporting as passes",
560                          RuntimeWarning)
561            result.addSuccess(self)
562        else:
563            addExpectedFailure(self, exc_info)
564
565    def _addUnexpectedSuccess(self, result):
566        try:
567            addUnexpectedSuccess = result.addUnexpectedSuccess
568        except AttributeError:
569            warnings.warn("TestResult has no addUnexpectedSuccess method, reporting as failure",
570                          RuntimeWarning)
571            # We need to pass an actual exception and traceback to addFailure,
572            # otherwise the legacy result can choke.
573            try:
574                raise _UnexpectedSuccess from None
575            except _UnexpectedSuccess:
576                result.addFailure(self, sys.exc_info())
577        else:
578            addUnexpectedSuccess(self)
579
580    def run(self, result=None):
581        orig_result = result
582        if result is None:
583            result = self.defaultTestResult()
584            startTestRun = getattr(result, 'startTestRun', None)
585            if startTestRun is not None:
586                startTestRun()
587
588        result.startTest(self)
589
590        testMethod = getattr(self, self._testMethodName)
591        if (getattr(self.__class__, "__unittest_skip__", False) or
592            getattr(testMethod, "__unittest_skip__", False)):
593            # If the class or method was skipped.
594            try:
595                skip_why = (getattr(self.__class__, '__unittest_skip_why__', '')
596                            or getattr(testMethod, '__unittest_skip_why__', ''))
597                self._addSkip(result, self, skip_why)
598            finally:
599                result.stopTest(self)
600            return
601        expecting_failure_method = getattr(testMethod,
602                                           "__unittest_expecting_failure__", False)
603        expecting_failure_class = getattr(self,
604                                          "__unittest_expecting_failure__", False)
605        expecting_failure = expecting_failure_class or expecting_failure_method
606        outcome = _Outcome(result)
607        try:
608            self._outcome = outcome
609
610            with outcome.testPartExecutor(self):
611                self.setUp()
612            if outcome.success:
613                outcome.expecting_failure = expecting_failure
614                with outcome.testPartExecutor(self, isTest=True):
615                    testMethod()
616                outcome.expecting_failure = False
617                with outcome.testPartExecutor(self):
618                    self.tearDown()
619
620            self.doCleanups()
621            for test, reason in outcome.skipped:
622                self._addSkip(result, test, reason)
623            self._feedErrorsToResult(result, outcome.errors)
624            if outcome.success:
625                if expecting_failure:
626                    if outcome.expectedFailure:
627                        self._addExpectedFailure(result, outcome.expectedFailure)
628                    else:
629                        self._addUnexpectedSuccess(result)
630                else:
631                    result.addSuccess(self)
632            return result
633        finally:
634            result.stopTest(self)
635            if orig_result is None:
636                stopTestRun = getattr(result, 'stopTestRun', None)
637                if stopTestRun is not None:
638                    stopTestRun()
639
640            # explicitly break reference cycles:
641            # outcome.errors -> frame -> outcome -> outcome.errors
642            # outcome.expectedFailure -> frame -> outcome -> outcome.expectedFailure
643            outcome.errors.clear()
644            outcome.expectedFailure = None
645
646            # clear the outcome, no more needed
647            self._outcome = None
648
649    def doCleanups(self):
650        """Execute all cleanup functions. Normally called for you after
651        tearDown."""
652        outcome = self._outcome or _Outcome()
653        while self._cleanups:
654            function, args, kwargs = self._cleanups.pop()
655            with outcome.testPartExecutor(self):
656                function(*args, **kwargs)
657
658        # return this for backwards compatibility
659        # even though we no longer us it internally
660        return outcome.success
661
662    def __call__(self, *args, **kwds):
663        return self.run(*args, **kwds)
664
665    def debug(self):
666        """Run the test without collecting errors in a TestResult"""
667        self.setUp()
668        getattr(self, self._testMethodName)()
669        self.tearDown()
670        while self._cleanups:
671            function, args, kwargs = self._cleanups.pop(-1)
672            function(*args, **kwargs)
673
674    def skipTest(self, reason):
675        """Skip this test."""
676        raise SkipTest(reason)
677
678    def fail(self, msg=None):
679        """Fail immediately, with the given message."""
680        raise self.failureException(msg)
681
682    def assertFalse(self, expr, msg=None):
683        """Check that the expression is false."""
684        if expr:
685            msg = self._formatMessage(msg, "%s is not false" % safe_repr(expr))
686            raise self.failureException(msg)
687
688    def assertTrue(self, expr, msg=None):
689        """Check that the expression is true."""
690        if not expr:
691            msg = self._formatMessage(msg, "%s is not true" % safe_repr(expr))
692            raise self.failureException(msg)
693
694    def _formatMessage(self, msg, standardMsg):
695        """Honour the longMessage attribute when generating failure messages.
696        If longMessage is False this means:
697        * Use only an explicit message if it is provided
698        * Otherwise use the standard message for the assert
699
700        If longMessage is True:
701        * Use the standard message
702        * If an explicit message is provided, plus ' : ' and the explicit message
703        """
704        if not self.longMessage:
705            return msg or standardMsg
706        if msg is None:
707            return standardMsg
708        try:
709            # don't switch to '{}' formatting in Python 2.X
710            # it changes the way unicode input is handled
711            return '%s : %s' % (standardMsg, msg)
712        except UnicodeDecodeError:
713            return  '%s : %s' % (safe_repr(standardMsg), safe_repr(msg))
714
715    def assertRaises(self, expected_exception, *args, **kwargs):
716        """Fail unless an exception of class expected_exception is raised
717           by the callable when invoked with specified positional and
718           keyword arguments. If a different type of exception is
719           raised, it will not be caught, and the test case will be
720           deemed to have suffered an error, exactly as for an
721           unexpected exception.
722
723           If called with the callable and arguments omitted, will return a
724           context object used like this::
725
726                with self.assertRaises(SomeException):
727                    do_something()
728
729           An optional keyword argument 'msg' can be provided when assertRaises
730           is used as a context object.
731
732           The context manager keeps a reference to the exception as
733           the 'exception' attribute. This allows you to inspect the
734           exception after the assertion::
735
736               with self.assertRaises(SomeException) as cm:
737                   do_something()
738               the_exception = cm.exception
739               self.assertEqual(the_exception.error_code, 3)
740        """
741        context = _AssertRaisesContext(expected_exception, self)
742        try:
743            return context.handle('assertRaises', args, kwargs)
744        finally:
745            # bpo-23890: manually break a reference cycle
746            context = None
747
748    def assertWarns(self, expected_warning, *args, **kwargs):
749        """Fail unless a warning of class warnClass is triggered
750           by the callable when invoked with specified positional and
751           keyword arguments.  If a different type of warning is
752           triggered, it will not be handled: depending on the other
753           warning filtering rules in effect, it might be silenced, printed
754           out, or raised as an exception.
755
756           If called with the callable and arguments omitted, will return a
757           context object used like this::
758
759                with self.assertWarns(SomeWarning):
760                    do_something()
761
762           An optional keyword argument 'msg' can be provided when assertWarns
763           is used as a context object.
764
765           The context manager keeps a reference to the first matching
766           warning as the 'warning' attribute; similarly, the 'filename'
767           and 'lineno' attributes give you information about the line
768           of Python code from which the warning was triggered.
769           This allows you to inspect the warning after the assertion::
770
771               with self.assertWarns(SomeWarning) as cm:
772                   do_something()
773               the_warning = cm.warning
774               self.assertEqual(the_warning.some_attribute, 147)
775        """
776        context = _AssertWarnsContext(expected_warning, self)
777        return context.handle('assertWarns', args, kwargs)
778
779    def assertLogs(self, logger=None, level=None):
780        """Fail unless a log message of level *level* or higher is emitted
781        on *logger_name* or its children.  If omitted, *level* defaults to
782        INFO and *logger* defaults to the root logger.
783
784        This method must be used as a context manager, and will yield
785        a recording object with two attributes: `output` and `records`.
786        At the end of the context manager, the `output` attribute will
787        be a list of the matching formatted log messages and the
788        `records` attribute will be a list of the corresponding LogRecord
789        objects.
790
791        Example::
792
793            with self.assertLogs('foo', level='INFO') as cm:
794                logging.getLogger('foo').info('first message')
795                logging.getLogger('foo.bar').error('second message')
796            self.assertEqual(cm.output, ['INFO:foo:first message',
797                                         'ERROR:foo.bar:second message'])
798        """
799        return _AssertLogsContext(self, logger, level)
800
801    def _getAssertEqualityFunc(self, first, second):
802        """Get a detailed comparison function for the types of the two args.
803
804        Returns: A callable accepting (first, second, msg=None) that will
805        raise a failure exception if first != second with a useful human
806        readable error message for those types.
807        """
808        #
809        # NOTE(gregory.p.smith): I considered isinstance(first, type(second))
810        # and vice versa.  I opted for the conservative approach in case
811        # subclasses are not intended to be compared in detail to their super
812        # class instances using a type equality func.  This means testing
813        # subtypes won't automagically use the detailed comparison.  Callers
814        # should use their type specific assertSpamEqual method to compare
815        # subclasses if the detailed comparison is desired and appropriate.
816        # See the discussion in http://bugs.python.org/issue2578.
817        #
818        if type(first) is type(second):
819            asserter = self._type_equality_funcs.get(type(first))
820            if asserter is not None:
821                if isinstance(asserter, str):
822                    asserter = getattr(self, asserter)
823                return asserter
824
825        return self._baseAssertEqual
826
827    def _baseAssertEqual(self, first, second, msg=None):
828        """The default assertEqual implementation, not type specific."""
829        if not first == second:
830            standardMsg = '%s != %s' % _common_shorten_repr(first, second)
831            msg = self._formatMessage(msg, standardMsg)
832            raise self.failureException(msg)
833
834    def assertEqual(self, first, second, msg=None):
835        """Fail if the two objects are unequal as determined by the '=='
836           operator.
837        """
838        assertion_func = self._getAssertEqualityFunc(first, second)
839        assertion_func(first, second, msg=msg)
840
841    def assertNotEqual(self, first, second, msg=None):
842        """Fail if the two objects are equal as determined by the '!='
843           operator.
844        """
845        if not first != second:
846            msg = self._formatMessage(msg, '%s == %s' % (safe_repr(first),
847                                                          safe_repr(second)))
848            raise self.failureException(msg)
849
850    def assertAlmostEqual(self, first, second, places=None, msg=None,
851                          delta=None):
852        """Fail if the two objects are unequal as determined by their
853           difference rounded to the given number of decimal places
854           (default 7) and comparing to zero, or by comparing that the
855           difference between the two objects is more than the given
856           delta.
857
858           Note that decimal places (from zero) are usually not the same
859           as significant digits (measured from the most significant digit).
860
861           If the two objects compare equal then they will automatically
862           compare almost equal.
863        """
864        if first == second:
865            # shortcut
866            return
867        if delta is not None and places is not None:
868            raise TypeError("specify delta or places not both")
869
870        diff = abs(first - second)
871        if delta is not None:
872            if diff <= delta:
873                return
874
875            standardMsg = '%s != %s within %s delta (%s difference)' % (
876                safe_repr(first),
877                safe_repr(second),
878                safe_repr(delta),
879                safe_repr(diff))
880        else:
881            if places is None:
882                places = 7
883
884            if round(diff, places) == 0:
885                return
886
887            standardMsg = '%s != %s within %r places (%s difference)' % (
888                safe_repr(first),
889                safe_repr(second),
890                places,
891                safe_repr(diff))
892        msg = self._formatMessage(msg, standardMsg)
893        raise self.failureException(msg)
894
895    def assertNotAlmostEqual(self, first, second, places=None, msg=None,
896                             delta=None):
897        """Fail if the two objects are equal as determined by their
898           difference rounded to the given number of decimal places
899           (default 7) and comparing to zero, or by comparing that the
900           difference between the two objects is less than the given delta.
901
902           Note that decimal places (from zero) are usually not the same
903           as significant digits (measured from the most significant digit).
904
905           Objects that are equal automatically fail.
906        """
907        if delta is not None and places is not None:
908            raise TypeError("specify delta or places not both")
909        diff = abs(first - second)
910        if delta is not None:
911            if not (first == second) and diff > delta:
912                return
913            standardMsg = '%s == %s within %s delta (%s difference)' % (
914                safe_repr(first),
915                safe_repr(second),
916                safe_repr(delta),
917                safe_repr(diff))
918        else:
919            if places is None:
920                places = 7
921            if not (first == second) and round(diff, places) != 0:
922                return
923            standardMsg = '%s == %s within %r places' % (safe_repr(first),
924                                                         safe_repr(second),
925                                                         places)
926
927        msg = self._formatMessage(msg, standardMsg)
928        raise self.failureException(msg)
929
930    def assertSequenceEqual(self, seq1, seq2, msg=None, seq_type=None):
931        """An equality assertion for ordered sequences (like lists and tuples).
932
933        For the purposes of this function, a valid ordered sequence type is one
934        which can be indexed, has a length, and has an equality operator.
935
936        Args:
937            seq1: The first sequence to compare.
938            seq2: The second sequence to compare.
939            seq_type: The expected datatype of the sequences, or None if no
940                    datatype should be enforced.
941            msg: Optional message to use on failure instead of a list of
942                    differences.
943        """
944        if seq_type is not None:
945            seq_type_name = seq_type.__name__
946            if not isinstance(seq1, seq_type):
947                raise self.failureException('First sequence is not a %s: %s'
948                                        % (seq_type_name, safe_repr(seq1)))
949            if not isinstance(seq2, seq_type):
950                raise self.failureException('Second sequence is not a %s: %s'
951                                        % (seq_type_name, safe_repr(seq2)))
952        else:
953            seq_type_name = "sequence"
954
955        differing = None
956        try:
957            len1 = len(seq1)
958        except (TypeError, NotImplementedError):
959            differing = 'First %s has no length.    Non-sequence?' % (
960                    seq_type_name)
961
962        if differing is None:
963            try:
964                len2 = len(seq2)
965            except (TypeError, NotImplementedError):
966                differing = 'Second %s has no length.    Non-sequence?' % (
967                        seq_type_name)
968
969        if differing is None:
970            if seq1 == seq2:
971                return
972
973            differing = '%ss differ: %s != %s\n' % (
974                    (seq_type_name.capitalize(),) +
975                    _common_shorten_repr(seq1, seq2))
976
977            for i in range(min(len1, len2)):
978                try:
979                    item1 = seq1[i]
980                except (TypeError, IndexError, NotImplementedError):
981                    differing += ('\nUnable to index element %d of first %s\n' %
982                                 (i, seq_type_name))
983                    break
984
985                try:
986                    item2 = seq2[i]
987                except (TypeError, IndexError, NotImplementedError):
988                    differing += ('\nUnable to index element %d of second %s\n' %
989                                 (i, seq_type_name))
990                    break
991
992                if item1 != item2:
993                    differing += ('\nFirst differing element %d:\n%s\n%s\n' %
994                                 ((i,) + _common_shorten_repr(item1, item2)))
995                    break
996            else:
997                if (len1 == len2 and seq_type is None and
998                    type(seq1) != type(seq2)):
999                    # The sequences are the same, but have differing types.
1000                    return
1001
1002            if len1 > len2:
1003                differing += ('\nFirst %s contains %d additional '
1004                             'elements.\n' % (seq_type_name, len1 - len2))
1005                try:
1006                    differing += ('First extra element %d:\n%s\n' %
1007                                  (len2, safe_repr(seq1[len2])))
1008                except (TypeError, IndexError, NotImplementedError):
1009                    differing += ('Unable to index element %d '
1010                                  'of first %s\n' % (len2, seq_type_name))
1011            elif len1 < len2:
1012                differing += ('\nSecond %s contains %d additional '
1013                             'elements.\n' % (seq_type_name, len2 - len1))
1014                try:
1015                    differing += ('First extra element %d:\n%s\n' %
1016                                  (len1, safe_repr(seq2[len1])))
1017                except (TypeError, IndexError, NotImplementedError):
1018                    differing += ('Unable to index element %d '
1019                                  'of second %s\n' % (len1, seq_type_name))
1020        standardMsg = differing
1021        diffMsg = '\n' + '\n'.join(
1022            difflib.ndiff(pprint.pformat(seq1).splitlines(),
1023                          pprint.pformat(seq2).splitlines()))
1024
1025        standardMsg = self._truncateMessage(standardMsg, diffMsg)
1026        msg = self._formatMessage(msg, standardMsg)
1027        self.fail(msg)
1028
1029    def _truncateMessage(self, message, diff):
1030        max_diff = self.maxDiff
1031        if max_diff is None or len(diff) <= max_diff:
1032            return message + diff
1033        return message + (DIFF_OMITTED % len(diff))
1034
1035    def assertListEqual(self, list1, list2, msg=None):
1036        """A list-specific equality assertion.
1037
1038        Args:
1039            list1: The first list to compare.
1040            list2: The second list to compare.
1041            msg: Optional message to use on failure instead of a list of
1042                    differences.
1043
1044        """
1045        self.assertSequenceEqual(list1, list2, msg, seq_type=list)
1046
1047    def assertTupleEqual(self, tuple1, tuple2, msg=None):
1048        """A tuple-specific equality assertion.
1049
1050        Args:
1051            tuple1: The first tuple to compare.
1052            tuple2: The second tuple to compare.
1053            msg: Optional message to use on failure instead of a list of
1054                    differences.
1055        """
1056        self.assertSequenceEqual(tuple1, tuple2, msg, seq_type=tuple)
1057
1058    def assertSetEqual(self, set1, set2, msg=None):
1059        """A set-specific equality assertion.
1060
1061        Args:
1062            set1: The first set to compare.
1063            set2: The second set to compare.
1064            msg: Optional message to use on failure instead of a list of
1065                    differences.
1066
1067        assertSetEqual uses ducktyping to support different types of sets, and
1068        is optimized for sets specifically (parameters must support a
1069        difference method).
1070        """
1071        try:
1072            difference1 = set1.difference(set2)
1073        except TypeError as e:
1074            self.fail('invalid type when attempting set difference: %s' % e)
1075        except AttributeError as e:
1076            self.fail('first argument does not support set difference: %s' % e)
1077
1078        try:
1079            difference2 = set2.difference(set1)
1080        except TypeError as e:
1081            self.fail('invalid type when attempting set difference: %s' % e)
1082        except AttributeError as e:
1083            self.fail('second argument does not support set difference: %s' % e)
1084
1085        if not (difference1 or difference2):
1086            return
1087
1088        lines = []
1089        if difference1:
1090            lines.append('Items in the first set but not the second:')
1091            for item in difference1:
1092                lines.append(repr(item))
1093        if difference2:
1094            lines.append('Items in the second set but not the first:')
1095            for item in difference2:
1096                lines.append(repr(item))
1097
1098        standardMsg = '\n'.join(lines)
1099        self.fail(self._formatMessage(msg, standardMsg))
1100
1101    def assertIn(self, member, container, msg=None):
1102        """Just like self.assertTrue(a in b), but with a nicer default message."""
1103        if member not in container:
1104            standardMsg = '%s not found in %s' % (safe_repr(member),
1105                                                  safe_repr(container))
1106            self.fail(self._formatMessage(msg, standardMsg))
1107
1108    def assertNotIn(self, member, container, msg=None):
1109        """Just like self.assertTrue(a not in b), but with a nicer default message."""
1110        if member in container:
1111            standardMsg = '%s unexpectedly found in %s' % (safe_repr(member),
1112                                                        safe_repr(container))
1113            self.fail(self._formatMessage(msg, standardMsg))
1114
1115    def assertIs(self, expr1, expr2, msg=None):
1116        """Just like self.assertTrue(a is b), but with a nicer default message."""
1117        if expr1 is not expr2:
1118            standardMsg = '%s is not %s' % (safe_repr(expr1),
1119                                             safe_repr(expr2))
1120            self.fail(self._formatMessage(msg, standardMsg))
1121
1122    def assertIsNot(self, expr1, expr2, msg=None):
1123        """Just like self.assertTrue(a is not b), but with a nicer default message."""
1124        if expr1 is expr2:
1125            standardMsg = 'unexpectedly identical: %s' % (safe_repr(expr1),)
1126            self.fail(self._formatMessage(msg, standardMsg))
1127
1128    def assertDictEqual(self, d1, d2, msg=None):
1129        self.assertIsInstance(d1, dict, 'First argument is not a dictionary')
1130        self.assertIsInstance(d2, dict, 'Second argument is not a dictionary')
1131
1132        if d1 != d2:
1133            standardMsg = '%s != %s' % _common_shorten_repr(d1, d2)
1134            diff = ('\n' + '\n'.join(difflib.ndiff(
1135                           pprint.pformat(d1).splitlines(),
1136                           pprint.pformat(d2).splitlines())))
1137            standardMsg = self._truncateMessage(standardMsg, diff)
1138            self.fail(self._formatMessage(msg, standardMsg))
1139
1140    def assertDictContainsSubset(self, subset, dictionary, msg=None):
1141        """Checks whether dictionary is a superset of subset."""
1142        warnings.warn('assertDictContainsSubset is deprecated',
1143                      DeprecationWarning)
1144        missing = []
1145        mismatched = []
1146        for key, value in subset.items():
1147            if key not in dictionary:
1148                missing.append(key)
1149            elif value != dictionary[key]:
1150                mismatched.append('%s, expected: %s, actual: %s' %
1151                                  (safe_repr(key), safe_repr(value),
1152                                   safe_repr(dictionary[key])))
1153
1154        if not (missing or mismatched):
1155            return
1156
1157        standardMsg = ''
1158        if missing:
1159            standardMsg = 'Missing: %s' % ','.join(safe_repr(m) for m in
1160                                                    missing)
1161        if mismatched:
1162            if standardMsg:
1163                standardMsg += '; '
1164            standardMsg += 'Mismatched values: %s' % ','.join(mismatched)
1165
1166        self.fail(self._formatMessage(msg, standardMsg))
1167
1168
1169    def assertCountEqual(self, first, second, msg=None):
1170        """An unordered sequence comparison asserting that the same elements,
1171        regardless of order.  If the same element occurs more than once,
1172        it verifies that the elements occur the same number of times.
1173
1174            self.assertEqual(Counter(list(first)),
1175                             Counter(list(second)))
1176
1177         Example:
1178            - [0, 1, 1] and [1, 0, 1] compare equal.
1179            - [0, 0, 1] and [0, 1] compare unequal.
1180
1181        """
1182        first_seq, second_seq = list(first), list(second)
1183        try:
1184            first = collections.Counter(first_seq)
1185            second = collections.Counter(second_seq)
1186        except TypeError:
1187            # Handle case with unhashable elements
1188            differences = _count_diff_all_purpose(first_seq, second_seq)
1189        else:
1190            if first == second:
1191                return
1192            differences = _count_diff_hashable(first_seq, second_seq)
1193
1194        if differences:
1195            standardMsg = 'Element counts were not equal:\n'
1196            lines = ['First has %d, Second has %d:  %r' % diff for diff in differences]
1197            diffMsg = '\n'.join(lines)
1198            standardMsg = self._truncateMessage(standardMsg, diffMsg)
1199            msg = self._formatMessage(msg, standardMsg)
1200            self.fail(msg)
1201
1202    def assertMultiLineEqual(self, first, second, msg=None):
1203        """Assert that two multi-line strings are equal."""
1204        self.assertIsInstance(first, str, 'First argument is not a string')
1205        self.assertIsInstance(second, str, 'Second argument is not a string')
1206
1207        if first != second:
1208            # don't use difflib if the strings are too long
1209            if (len(first) > self._diffThreshold or
1210                len(second) > self._diffThreshold):
1211                self._baseAssertEqual(first, second, msg)
1212            firstlines = first.splitlines(keepends=True)
1213            secondlines = second.splitlines(keepends=True)
1214            if len(firstlines) == 1 and first.strip('\r\n') == first:
1215                firstlines = [first + '\n']
1216                secondlines = [second + '\n']
1217            standardMsg = '%s != %s' % _common_shorten_repr(first, second)
1218            diff = '\n' + ''.join(difflib.ndiff(firstlines, secondlines))
1219            standardMsg = self._truncateMessage(standardMsg, diff)
1220            self.fail(self._formatMessage(msg, standardMsg))
1221
1222    def assertLess(self, a, b, msg=None):
1223        """Just like self.assertTrue(a < b), but with a nicer default message."""
1224        if not a < b:
1225            standardMsg = '%s not less than %s' % (safe_repr(a), safe_repr(b))
1226            self.fail(self._formatMessage(msg, standardMsg))
1227
1228    def assertLessEqual(self, a, b, msg=None):
1229        """Just like self.assertTrue(a <= b), but with a nicer default message."""
1230        if not a <= b:
1231            standardMsg = '%s not less than or equal to %s' % (safe_repr(a), safe_repr(b))
1232            self.fail(self._formatMessage(msg, standardMsg))
1233
1234    def assertGreater(self, a, b, msg=None):
1235        """Just like self.assertTrue(a > b), but with a nicer default message."""
1236        if not a > b:
1237            standardMsg = '%s not greater than %s' % (safe_repr(a), safe_repr(b))
1238            self.fail(self._formatMessage(msg, standardMsg))
1239
1240    def assertGreaterEqual(self, a, b, msg=None):
1241        """Just like self.assertTrue(a >= b), but with a nicer default message."""
1242        if not a >= b:
1243            standardMsg = '%s not greater than or equal to %s' % (safe_repr(a), safe_repr(b))
1244            self.fail(self._formatMessage(msg, standardMsg))
1245
1246    def assertIsNone(self, obj, msg=None):
1247        """Same as self.assertTrue(obj is None), with a nicer default message."""
1248        if obj is not None:
1249            standardMsg = '%s is not None' % (safe_repr(obj),)
1250            self.fail(self._formatMessage(msg, standardMsg))
1251
1252    def assertIsNotNone(self, obj, msg=None):
1253        """Included for symmetry with assertIsNone."""
1254        if obj is None:
1255            standardMsg = 'unexpectedly None'
1256            self.fail(self._formatMessage(msg, standardMsg))
1257
1258    def assertIsInstance(self, obj, cls, msg=None):
1259        """Same as self.assertTrue(isinstance(obj, cls)), with a nicer
1260        default message."""
1261        if not isinstance(obj, cls):
1262            standardMsg = '%s is not an instance of %r' % (safe_repr(obj), cls)
1263            self.fail(self._formatMessage(msg, standardMsg))
1264
1265    def assertNotIsInstance(self, obj, cls, msg=None):
1266        """Included for symmetry with assertIsInstance."""
1267        if isinstance(obj, cls):
1268            standardMsg = '%s is an instance of %r' % (safe_repr(obj), cls)
1269            self.fail(self._formatMessage(msg, standardMsg))
1270
1271    def assertRaisesRegex(self, expected_exception, expected_regex,
1272                          *args, **kwargs):
1273        """Asserts that the message in a raised exception matches a regex.
1274
1275        Args:
1276            expected_exception: Exception class expected to be raised.
1277            expected_regex: Regex (re.Pattern object or string) expected
1278                    to be found in error message.
1279            args: Function to be called and extra positional args.
1280            kwargs: Extra kwargs.
1281            msg: Optional message used in case of failure. Can only be used
1282                    when assertRaisesRegex is used as a context manager.
1283        """
1284        context = _AssertRaisesContext(expected_exception, self, expected_regex)
1285        return context.handle('assertRaisesRegex', args, kwargs)
1286
1287    def assertWarnsRegex(self, expected_warning, expected_regex,
1288                         *args, **kwargs):
1289        """Asserts that the message in a triggered warning matches a regexp.
1290        Basic functioning is similar to assertWarns() with the addition
1291        that only warnings whose messages also match the regular expression
1292        are considered successful matches.
1293
1294        Args:
1295            expected_warning: Warning class expected to be triggered.
1296            expected_regex: Regex (re.Pattern object or string) expected
1297                    to be found in error message.
1298            args: Function to be called and extra positional args.
1299            kwargs: Extra kwargs.
1300            msg: Optional message used in case of failure. Can only be used
1301                    when assertWarnsRegex is used as a context manager.
1302        """
1303        context = _AssertWarnsContext(expected_warning, self, expected_regex)
1304        return context.handle('assertWarnsRegex', args, kwargs)
1305
1306    def assertRegex(self, text, expected_regex, msg=None):
1307        """Fail the test unless the text matches the regular expression."""
1308        if isinstance(expected_regex, (str, bytes)):
1309            assert expected_regex, "expected_regex must not be empty."
1310            expected_regex = re.compile(expected_regex)
1311        if not expected_regex.search(text):
1312            standardMsg = "Regex didn't match: %r not found in %r" % (
1313                expected_regex.pattern, text)
1314            # _formatMessage ensures the longMessage option is respected
1315            msg = self._formatMessage(msg, standardMsg)
1316            raise self.failureException(msg)
1317
1318    def assertNotRegex(self, text, unexpected_regex, msg=None):
1319        """Fail the test if the text matches the regular expression."""
1320        if isinstance(unexpected_regex, (str, bytes)):
1321            unexpected_regex = re.compile(unexpected_regex)
1322        match = unexpected_regex.search(text)
1323        if match:
1324            standardMsg = 'Regex matched: %r matches %r in %r' % (
1325                text[match.start() : match.end()],
1326                unexpected_regex.pattern,
1327                text)
1328            # _formatMessage ensures the longMessage option is respected
1329            msg = self._formatMessage(msg, standardMsg)
1330            raise self.failureException(msg)
1331
1332
1333    def _deprecate(original_func):
1334        def deprecated_func(*args, **kwargs):
1335            warnings.warn(
1336                'Please use {0} instead.'.format(original_func.__name__),
1337                DeprecationWarning, 2)
1338            return original_func(*args, **kwargs)
1339        return deprecated_func
1340
1341    # see #9424
1342    failUnlessEqual = assertEquals = _deprecate(assertEqual)
1343    failIfEqual = assertNotEquals = _deprecate(assertNotEqual)
1344    failUnlessAlmostEqual = assertAlmostEquals = _deprecate(assertAlmostEqual)
1345    failIfAlmostEqual = assertNotAlmostEquals = _deprecate(assertNotAlmostEqual)
1346    failUnless = assert_ = _deprecate(assertTrue)
1347    failUnlessRaises = _deprecate(assertRaises)
1348    failIf = _deprecate(assertFalse)
1349    assertRaisesRegexp = _deprecate(assertRaisesRegex)
1350    assertRegexpMatches = _deprecate(assertRegex)
1351    assertNotRegexpMatches = _deprecate(assertNotRegex)
1352
1353
1354
1355class FunctionTestCase(TestCase):
1356    """A test case that wraps a test function.
1357
1358    This is useful for slipping pre-existing test functions into the
1359    unittest framework. Optionally, set-up and tidy-up functions can be
1360    supplied. As with TestCase, the tidy-up ('tearDown') function will
1361    always be called if the set-up ('setUp') function ran successfully.
1362    """
1363
1364    def __init__(self, testFunc, setUp=None, tearDown=None, description=None):
1365        super(FunctionTestCase, self).__init__()
1366        self._setUpFunc = setUp
1367        self._tearDownFunc = tearDown
1368        self._testFunc = testFunc
1369        self._description = description
1370
1371    def setUp(self):
1372        if self._setUpFunc is not None:
1373            self._setUpFunc()
1374
1375    def tearDown(self):
1376        if self._tearDownFunc is not None:
1377            self._tearDownFunc()
1378
1379    def runTest(self):
1380        self._testFunc()
1381
1382    def id(self):
1383        return self._testFunc.__name__
1384
1385    def __eq__(self, other):
1386        if not isinstance(other, self.__class__):
1387            return NotImplemented
1388
1389        return self._setUpFunc == other._setUpFunc and \
1390               self._tearDownFunc == other._tearDownFunc and \
1391               self._testFunc == other._testFunc and \
1392               self._description == other._description
1393
1394    def __hash__(self):
1395        return hash((type(self), self._setUpFunc, self._tearDownFunc,
1396                     self._testFunc, self._description))
1397
1398    def __str__(self):
1399        return "%s (%s)" % (strclass(self.__class__),
1400                            self._testFunc.__name__)
1401
1402    def __repr__(self):
1403        return "<%s tec=%s>" % (strclass(self.__class__),
1404                                     self._testFunc)
1405
1406    def shortDescription(self):
1407        if self._description is not None:
1408            return self._description
1409        doc = self._testFunc.__doc__
1410        return doc and doc.split("\n")[0].strip() or None
1411
1412
1413class _SubTest(TestCase):
1414
1415    def __init__(self, test_case, message, params):
1416        super().__init__()
1417        self._message = message
1418        self.test_case = test_case
1419        self.params = params
1420        self.failureException = test_case.failureException
1421
1422    def runTest(self):
1423        raise NotImplementedError("subtests cannot be run directly")
1424
1425    def _subDescription(self):
1426        parts = []
1427        if self._message is not _subtest_msg_sentinel:
1428            parts.append("[{}]".format(self._message))
1429        if self.params:
1430            params_desc = ', '.join(
1431                "{}={!r}".format(k, v)
1432                for (k, v) in self.params.items())
1433            parts.append("({})".format(params_desc))
1434        return " ".join(parts) or '(<subtest>)'
1435
1436    def id(self):
1437        return "{} {}".format(self.test_case.id(), self._subDescription())
1438
1439    def shortDescription(self):
1440        """Returns a one-line description of the subtest, or None if no
1441        description has been provided.
1442        """
1443        return self.test_case.shortDescription()
1444
1445    def __str__(self):
1446        return "{} {}".format(self.test_case, self._subDescription())
1447