1# Copyright 2015 gRPC authors.
2#
3# Licensed under the Apache License, Version 2.0 (the "License");
4# you may not use this file except in compliance with the License.
5# You may obtain a copy of the License at
6#
7#     http://www.apache.org/licenses/LICENSE-2.0
8#
9# Unless required by applicable law or agreed to in writing, software
10# distributed under the License is distributed on an "AS IS" BASIS,
11# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
12# See the License for the specific language governing permissions and
13# limitations under the License.
14
15from __future__ import absolute_import
16
17import collections
18import itertools
19import traceback
20import unittest
21from xml.etree import ElementTree
22
23import coverage
24from six import moves
25
26from tests import _loader
27
28
class CaseResult(
        collections.namedtuple('CaseResult', [
            'id', 'name', 'kind', 'stdout', 'stderr', 'skip_reason', 'traceback'
        ])):
    """A serializable result of a single test case.

    Attributes:
      id (object): Any serializable object used to denote the identity of this
        test case.
      name (str or None): A human-readable name of the test case.
      kind (CaseResult.Kind): The kind of test result.
      stdout (object or None): Output on stdout, or None if nothing was
        captured.
      stderr (object or None): Output on stderr, or None if nothing was
        captured.
      skip_reason (object or None): The reason the test was skipped. Must be
        something if self.kind is CaseResult.Kind.SKIP.
      traceback (object or None): The traceback of the test. Must be something
        if self.kind is CaseResult.Kind.{ERROR, FAILURE, EXPECTED_FAILURE}.
    """

    class Kind(object):
        """String constants naming every state a test case can be in."""
        UNTESTED = 'untested'
        RUNNING = 'running'
        ERROR = 'error'
        FAILURE = 'failure'
        SUCCESS = 'success'
        SKIP = 'skip'
        EXPECTED_FAILURE = 'expected failure'
        UNEXPECTED_SUCCESS = 'unexpected success'

    def __new__(cls,
                id=None,
                name=None,
                kind=None,
                stdout=None,
                stderr=None,
                skip_reason=None,
                traceback=None):
        """Helper keyword constructor for the namedtuple.

        See this class' attributes for information on the arguments.

        Raises:
          AssertionError: If `id` is None, `name` is neither None nor a str,
            `kind` is not one of the CaseResult.Kind values, or the data that
            the given kind requires (traceback / skip_reason) is missing.
        """
        kinds = CaseResult.Kind
        assert id is not None, 'a CaseResult requires an id'
        assert name is None or isinstance(name, str)
        # Kinds are compared by value rather than identity: they are plain
        # strings, and an equal string that is not the very same object (e.g.
        # one rebuilt by deserialization) is still a valid kind.
        if kind in (kinds.ERROR, kinds.FAILURE, kinds.EXPECTED_FAILURE):
            assert traceback is not None, (
                'kind {!r} requires a traceback'.format(kind))
        elif kind == kinds.SKIP:
            assert skip_reason is not None, 'kind skip requires a skip_reason'
        else:
            assert kind in (kinds.UNTESTED, kinds.RUNNING, kinds.SUCCESS,
                            kinds.UNEXPECTED_SUCCESS), (
                                'unknown kind {!r}'.format(kind))
        # super() takes (class, subclass-or-instance); with the arguments the
        # other way round, constructing any subclass raised TypeError.
        return super(CaseResult, cls).__new__(cls, id, name, kind, stdout,
                                              stderr, skip_reason, traceback)

    def updated(self,
                name=None,
                kind=None,
                stdout=None,
                stderr=None,
                skip_reason=None,
                traceback=None):
        """Get a new validated CaseResult with the fields updated.

        An argument left as None keeps the current value of that field, so a
        field cannot be reset to None through this method. The id is always
        carried over unchanged.

        See this class' attributes for information on the arguments.

        Returns:
          CaseResult: A new instance; this one is left untouched.
        """
        name = self.name if name is None else name
        kind = self.kind if kind is None else kind
        stdout = self.stdout if stdout is None else stdout
        stderr = self.stderr if stderr is None else stderr
        skip_reason = self.skip_reason if skip_reason is None else skip_reason
        traceback = self.traceback if traceback is None else traceback
        return CaseResult(
            id=self.id,
            name=name,
            kind=kind,
            stdout=stdout,
            stderr=stderr,
            skip_reason=skip_reason,
            traceback=traceback)
117
118
class AugmentedResult(unittest.TestResult):
    """unittest.Result that keeps track of additional information.

    Uses CaseResult objects to store test-case results, providing additional
    information beyond that of the standard Python unittest library, such as
    standard output.

    Attributes:
      id_map (callable): A unary callable mapping unittest.TestCase objects to
        unique identifiers.
      cases (dict): A dictionary mapping from the identifiers returned by
        id_map to CaseResult objects corresponding to those IDs. It is None
        until startTestRun has been called.
    """

    def __init__(self, id_map):
        """Initialize the object with an identifier mapping.

        Arguments:
          id_map (callable): Corresponds to the attribute `id_map`.
        """
        super(AugmentedResult, self).__init__()
        self.id_map = id_map
        self.cases = None

    def _update_case(self, test, **fields):
        """Replace the stored result for `test` with an updated copy.

        Args:
          test (unittest.TestCase): The test whose entry is replaced.
          **fields: Keyword arguments forwarded to the stored entry's
            `updated` method.
        """
        case_id = self.id_map(test)
        self.cases[case_id] = self.cases[case_id].updated(**fields)

    @staticmethod
    def _as_text(output):
        """Decode `output` if it is bytes-like; return anything else as is."""
        if isinstance(output, (bytes, bytearray)):
            return output.decode()
        return output

    def startTestRun(self):
        """See unittest.TestResult.startTestRun."""
        super(AugmentedResult, self).startTestRun()
        self.cases = dict()

    def startTest(self, test):
        """See unittest.TestResult.startTest."""
        super(AugmentedResult, self).startTest(test)
        case_id = self.id_map(test)
        self.cases[case_id] = CaseResult(
            id=case_id, name=test.id(), kind=CaseResult.Kind.RUNNING)

    def addError(self, test, err):
        """See unittest.TestResult.addError."""
        super(AugmentedResult, self).addError(test, err)
        self._update_case(test, kind=CaseResult.Kind.ERROR, traceback=err)

    def addFailure(self, test, err):
        """See unittest.TestResult.addFailure."""
        super(AugmentedResult, self).addFailure(test, err)
        self._update_case(test, kind=CaseResult.Kind.FAILURE, traceback=err)

    def addSuccess(self, test):
        """See unittest.TestResult.addSuccess."""
        super(AugmentedResult, self).addSuccess(test)
        self._update_case(test, kind=CaseResult.Kind.SUCCESS)

    def addSkip(self, test, reason):
        """See unittest.TestResult.addSkip."""
        super(AugmentedResult, self).addSkip(test, reason)
        self._update_case(
            test, kind=CaseResult.Kind.SKIP, skip_reason=reason)

    def addExpectedFailure(self, test, err):
        """See unittest.TestResult.addExpectedFailure."""
        super(AugmentedResult, self).addExpectedFailure(test, err)
        self._update_case(
            test, kind=CaseResult.Kind.EXPECTED_FAILURE, traceback=err)

    def addUnexpectedSuccess(self, test):
        """See unittest.TestResult.addUnexpectedSuccess."""
        super(AugmentedResult, self).addUnexpectedSuccess(test)
        self._update_case(test, kind=CaseResult.Kind.UNEXPECTED_SUCCESS)

    def set_output(self, test, stdout, stderr):
        """Set the output attributes for the CaseResult corresponding to a test.

        Bytes-like output is decoded with the default codec; output that is
        already text (or None) is passed to the stored entry unchanged instead
        of failing on a missing `decode` method.

        Args:
          test (unittest.TestCase): The TestCase to set the outputs of.
          stdout (bytes or str): Output from stdout to assign to
            self.id_map(test).
          stderr (bytes or str): Output from stderr to assign to
            self.id_map(test).
        """
        self._update_case(
            test, stdout=self._as_text(stdout), stderr=self._as_text(stderr))

    def augmented_results(self, filter):
        """Convenience method to retrieve filtered case results.

        Args:
          filter (callable): A unary predicate to filter over CaseResult
            objects.

        Returns:
          A generator over the stored case results for which `filter` returns
          a true value.
        """
        return (case for case in self.cases.values() if filter(case))
217
218
class CoverageResult(AugmentedResult):
    """AugmentedResult that records coverage.py data for each test.

    Attributes:
      coverage_context (coverage.Coverage): The coverage.py object created for
        the test currently in progress; None before a test starts and after it
        stops.
    """

    def __init__(self, id_map):
        """See AugmentedResult.__init__."""
        super(CoverageResult, self).__init__(id_map=id_map)
        self.coverage_context = None

    def startTest(self, test):
        """See unittest.TestResult.startTest.

        After the base-class bookkeeping, creates a new coverage.Coverage
        object (with data_suffix=True) and starts it.
        """
        super(CoverageResult, self).startTest(test)
        tracker = coverage.Coverage(data_suffix=True)
        self.coverage_context = tracker
        tracker.start()

    def stopTest(self, test):
        """See unittest.TestResult.stopTest.

        After the base-class bookkeeping, stops the current coverage object,
        saves its data, and clears `coverage_context`.
        """
        super(CoverageResult, self).stopTest(test)
        tracker = self.coverage_context
        tracker.stop()
        tracker.save()
        self.coverage_context = None
247
248
249class _Colors(object):
250    """Namespaced constants for terminal color magic numbers."""
251    HEADER = '\033[95m'
252    INFO = '\033[94m'
253    OK = '\033[92m'
254    WARN = '\033[93m'
255    FAIL = '\033[91m'
256    BOLD = '\033[1m'
257    UNDERLINE = '\033[4m'
258    END = '\033[0m'
259
260
class TerminalResult(CoverageResult):
    """CoverageResult that also writes colored live results to a stream."""

    def __init__(self, out, id_map):
        """Initialize the result object.

        Args:
          out (file-like): Output file to which terminal-colored live results
            will be written.
          id_map (callable): See AugmentedResult.__init__.
        """
        super(TerminalResult, self).__init__(id_map=id_map)
        self.out = out

    def _announce(self, color, template, test):
        """Write one colored status line for `test` and flush the stream.

        Args:
          color (str): Escape sequence written before the line.
          template (str): Format string with one positional field, which
            receives test.id().
          test (unittest.TestCase): The test being reported.
        """
        line = color + template.format(test.id()) + _Colors.END
        self.out.write(line)
        self.out.flush()

    def startTestRun(self):
        """See unittest.TestResult.startTestRun."""
        super(TerminalResult, self).startTestRun()
        banner = _Colors.HEADER + 'Testing gRPC Python...\n' + _Colors.END
        self.out.write(banner)

    def stopTestRun(self):
        """See unittest.TestResult.stopTestRun."""
        super(TerminalResult, self).stopTestRun()
        report = summary(self)
        self.out.write(report)
        self.out.flush()

    def addError(self, test, err):
        """See unittest.TestResult.addError."""
        super(TerminalResult, self).addError(test, err)
        self._announce(_Colors.FAIL, 'ERROR         {}\n', test)

    def addFailure(self, test, err):
        """See unittest.TestResult.addFailure."""
        super(TerminalResult, self).addFailure(test, err)
        self._announce(_Colors.FAIL, 'FAILURE       {}\n', test)

    def addSuccess(self, test):
        """See unittest.TestResult.addSuccess."""
        super(TerminalResult, self).addSuccess(test)
        self._announce(_Colors.OK, 'SUCCESS       {}\n', test)

    def addSkip(self, test, reason):
        """See unittest.TestResult.addSkip."""
        super(TerminalResult, self).addSkip(test, reason)
        self._announce(_Colors.INFO, 'SKIP          {}\n', test)

    def addExpectedFailure(self, test, err):
        """See unittest.TestResult.addExpectedFailure."""
        super(TerminalResult, self).addExpectedFailure(test, err)
        self._announce(_Colors.INFO, 'FAILURE_OK    {}\n', test)

    def addUnexpectedSuccess(self, test):
        """See unittest.TestResult.addUnexpectedSuccess."""
        super(TerminalResult, self).addUnexpectedSuccess(test)
        self._announce(_Colors.INFO, 'UNEXPECTED_OK {}\n', test)
328
329
330def _traceback_string(type, value, trace):
331    """Generate a descriptive string of a Python exception traceback.
332
333  Args:
334    type (class): The type of the exception.
335    value (Exception): The value of the exception.
336    trace (traceback): Traceback of the exception.
337
338  Returns:
339    str: Formatted exception descriptive string.
340  """
341    buffer = moves.cStringIO()
342    traceback.print_exception(type, value, trace, file=buffer)
343    return buffer.getvalue()
344
345
def summary(result):
    """A summary string of a result object.

    The summary consists of per-kind counts, the names of tests still marked
    as running (reported as interrupted), a colored traceback/stdout/stderr
    section for every failure and error, and the names of unexpected
    successes.

    Args:
      result (AugmentedResult): The result object to get the summary of.

    Returns:
      str: The summary string.
    """
    assert isinstance(result, AugmentedResult)

    # Bucket every case by kind in a single pass instead of re-scanning all
    # cases once per kind. Insertion order within each bucket follows the
    # iteration order of the cases.
    cases_by_kind = {}
    for case in result.augmented_results(lambda case_result: True):
        cases_by_kind.setdefault(case.kind, []).append(case)

    def of_kind(kind):
        return cases_by_kind.get(kind, [])

    running = of_kind(CaseResult.Kind.RUNNING)
    failures = of_kind(CaseResult.Kind.FAILURE)
    errors = of_kind(CaseResult.Kind.ERROR)
    successes = of_kind(CaseResult.Kind.SUCCESS)
    skips = of_kind(CaseResult.Kind.SKIP)
    expected_failures = of_kind(CaseResult.Kind.EXPECTED_FAILURE)
    unexpected_successes = of_kind(CaseResult.Kind.UNEXPECTED_SUCCESS)

    running_names = [case.name for case in running]
    # Skipped tests are reported on their own line and are not part of the
    # "finished" total.
    finished_count = (len(failures) + len(errors) + len(successes) +
                      len(expected_failures) + len(unexpected_successes))
    statistics = ('{finished} tests finished:\n'
                  '\t{successful} successful\n'
                  '\t{unsuccessful} unsuccessful\n'
                  '\t{skipped} skipped\n'
                  '\t{expected_fail} expected failures\n'
                  '\t{unexpected_successful} unexpected successes\n'
                  'Interrupted Tests:\n'
                  '\t{interrupted}\n'.format(
                      finished=finished_count,
                      successful=len(successes),
                      unsuccessful=(len(failures) + len(errors)),
                      skipped=len(skips),
                      expected_fail=len(expected_failures),
                      unexpected_successful=len(unexpected_successes),
                      interrupted=str(running_names)))

    # Built once; the loop below only fills in the per-case fields.
    case_template = (
        _Colors.FAIL + '{test_name}' + _Colors.END + '\n' + _Colors.BOLD +
        'traceback:' + _Colors.END + '\n' + '{traceback}\n' + _Colors.BOLD +
        'stdout:' + _Colors.END + '\n' + '{stdout}\n' + _Colors.BOLD +
        'stderr:' + _Colors.END + '\n' + '{stderr}\n')
    # The loop variable is deliberately not called `result`: on Python 2 a
    # list-comprehension variable leaks and would rebind the parameter.
    tracebacks = '\n\n'.join([
        case_template.format(
            test_name=case.name,
            traceback=_traceback_string(*case.traceback),
            stdout=case.stdout,
            stderr=case.stderr)
        for case in itertools.chain(failures, errors)
    ])
    notes = 'Unexpected successes: {}\n'.format(
        [case.name for case in unexpected_successes])
    return statistics + '\nErrors/Failures: \n' + tracebacks + '\n' + notes
413
414
def _junit_traceback_text(trace):
    """Render a stored CaseResult traceback as text for the XML report."""
    # unittest hands over (type, value, traceback) triples; anything else is
    # rendered with str() rather than guessed at.
    if isinstance(trace, tuple) and len(trace) == 3:
        return ''.join(traceback.format_exception(*trace))
    return '' if trace is None else str(trace)


def jenkins_junit_xml(result):
    """An XML tree object that when written is recognizable by Jenkins.

    Only successful, failed and errored cases are emitted. Failed and errored
    cases both get an `error` child element whose text is the captured stderr
    (if any) followed by the formatted traceback.

    Args:
      result (AugmentedResult): The result object to get the junit xml output
        of.

    Returns:
      ElementTree.ElementTree: The XML tree.
    """
    assert isinstance(result, AugmentedResult)
    root = ElementTree.Element('testsuites')
    suite = ElementTree.SubElement(root, 'testsuite', {
        'name': 'Python gRPC tests',
    })
    unsuccessful_kinds = (CaseResult.Kind.ERROR, CaseResult.Kind.FAILURE)
    for case in result.cases.values():
        if case.kind == CaseResult.Kind.SUCCESS:
            ElementTree.SubElement(suite, 'testcase', {
                'name': case.name,
            })
        elif case.kind in unsuccessful_kinds:
            case_xml = ElementTree.SubElement(suite, 'testcase', {
                'name': case.name,
            })
            error_xml = ElementTree.SubElement(case_xml, 'error', {})
            # The previous format string was empty, so the element text was
            # always '' and the report carried no diagnostics at all.
            stderr_text = '' if case.stderr is None else case.stderr
            error_xml.text = '{}\n{}'.format(
                stderr_text, _junit_traceback_text(case.traceback))
    return ElementTree.ElementTree(element=root)
441