1# Copyright 2020 The Pigweed Authors
2#
3# Licensed under the Apache License, Version 2.0 (the "License"); you may not
4# use this file except in compliance with the License. You may obtain a copy of
5# the License at
6#
7#     https://www.apache.org/licenses/LICENSE-2.0
8#
9# Unless required by applicable law or agreed to in writing, software
10# distributed under the License is distributed on an "AS IS" BASIS, WITHOUT
11# WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. See the
12# License for the specific language governing permissions and limitations under
13# the License.
14"""Tools for generating Pigweed tests that execute in C++ and Python."""
15
16import argparse
17from dataclasses import dataclass
18from datetime import datetime
19from collections import defaultdict
20import unittest
21
22from typing import (Any, Callable, Dict, Generic, Iterable, Iterator, List,
23                    Sequence, TextIO, TypeVar, Union)
24
# Read the clock once so the copyright year and the "Generated at" timestamp
# always describe the same moment (two separate datetime.now() calls could
# straddle a year boundary and disagree).
_GENERATED_AT = datetime.now()

# Preamble emitted at the top of every generated C++ test file.
_CPP_HEADER = f"""\
// Copyright {_GENERATED_AT.year} The Pigweed Authors
//
// Licensed under the Apache License, Version 2.0 (the "License"); you may not
// use this file except in compliance with the License. You may obtain a copy of
// the License at
//
//     https://www.apache.org/licenses/LICENSE-2.0
//
// Unless required by applicable law or agreed to in writing, software
// distributed under the License is distributed on an "AS IS" BASIS, WITHOUT
// WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. See the
// License for the specific language governing permissions and limitations under
// the License.

// AUTOGENERATED - DO NOT EDIT
//
// Generated at {_GENERATED_AT.isoformat()}

// clang-format off
"""
46
47
class Error(Exception):
    """Raised when the supplied test cases cannot be turned into tests."""
50
51
T = TypeVar('T')


@dataclass
class Context(Generic[T]):
    """Describes a single test case to a test generator function.

    Fields, in positional-constructor order:
      group: name of the group the test case belongs to.
      count: position of the test case within its group (TestGenerator
          numbers cases starting from 1).
      total: number of test cases in the group.
      test_case: the user-supplied test case object.
    """
    group: str
    count: int
    total: int
    test_case: T

    def cc_name(self) -> str:
        """Returns a CamelCase name intended for C++ test names.

        The group name is split into words on spaces and hyphens, each word is
        capitalized, and every remaining non-alphanumeric character becomes
        '_'. When the group has more than one case, '_<count>' is appended.
        """
        words = self.group.replace('-', ' ').split(' ')
        camel_case = ''.join(map(str.capitalize, words))
        identifier = ''.join(
            char if char.isalnum() else '_' for char in camel_case)
        if self.total > 1:
            return f'{identifier}_{self.count}'
        return identifier

    def py_name(self) -> str:
        """Returns a 'test_'-prefixed snake_case Python test method name.

        The group name is lowercased and every non-alphanumeric character
        becomes '_'. When the group has more than one case, '_<count>' is
        appended.
        """
        sanitized = ''.join(
            char if char.isalnum() else '_' for char in self.group.lower())
        base = 'test_' + sanitized
        return base if self.total <= 1 else f'{base}_{self.count}'
73
74
# Test cases are specified as a single sequence that interleaves group-name
# strings with test case objects. Each string starts a named group, and every
# test case object after it (up to the next string) belongs to that group. If
# the same group name appears more than once, TestGenerator merges its test
# cases into one group. For example:
#
#   STR_SPLIT_TEST_CASES = (
#     'Empty input',
#     MyTestCase('', '', []),
#     MyTestCase('', 'foo', []),
#     'Split on single character',
#     MyTestCase('abcde', 'c', ['ab', 'de']),
#     ...
#   )
#
GroupOrTest = Union[str, T]

# Python tests are generated by a function that takes a test case's Context and
# returns a function usable as a unittest.TestCase method (it receives the
# TestCase instance as its only argument).
PyTest = Callable[[unittest.TestCase], None]
PyTestGenerator = Callable[[Context[T]], PyTest]

# C++ tests are generated by a function that takes a test case's Context and
# returns or yields lines of C++ code for it. TestGenerator.cc_tests writes each
# yielded string followed by a newline, so lines should not end with '\n'.
CcTestGenerator = Callable[[Context[T]], Iterable[str]]
97
98
class TestGenerator(Generic[T]):
    """Generates tests for multiple languages from a series of test cases.

    The test cases interleave group-name strings with test case objects. Each
    object joins the group named by the closest string before it; a group name
    that appears more than once has all of its test cases merged together.
    """
    def __init__(self, test_cases: Sequence[GroupOrTest[T]]):
        """Sorts test_cases into named groups.

        Raises:
          Error: if test_cases contains no test case objects, if its first
              item is not a group name string, or if any group name is empty.
        """
        self._cases: Dict[str, List[T]] = defaultdict(list)

        # A usable sequence needs at least one group name and one test case.
        if len(test_cases) < 2:
            raise Error('At least one test case must be provided')

        if not isinstance(test_cases[0], str):
            raise Error(
                'The first item in the test cases must be a group name string')

        group = ''
        for case in test_cases:
            if isinstance(case, str):
                # Validate each name where it appears, so an empty name is
                # rejected even when no test case follows it.
                if not case:
                    raise Error('Empty test group names are not permitted')
                group = case
            else:
                self._cases[group].append(case)

        # The length check above still accepts input made only of group names
        # (e.g. ['A', 'B']), which would otherwise silently generate nothing.
        if not self._cases:
            raise Error('At least one test case must be provided')

    def _test_contexts(self) -> Iterator[Context[T]]:
        """Yields a Context for every test case, one group at a time.

        Groups are visited in the order in which each received its first test
        case; within a group, cases keep their input order and are numbered
        starting from 1.
        """
        for group, test_list in self._cases.items():
            for i, test_case in enumerate(test_list, 1):
                yield Context(group, i, len(test_list), test_case)

    def _generate_python_tests(
            self, define_py_test: PyTestGenerator
    ) -> Dict[str, Callable[[Any], None]]:
        """Returns a mapping from test method name to test method.

        Each function returned by define_py_test is renamed to its Context's
        py_name().

        Raises:
          Error: if two test cases produce the same Python test name.
        """
        tests: Dict[str, Callable[[Any], None]] = {}

        for ctx in self._test_contexts():
            test = define_py_test(ctx)
            name = ctx.py_name()
            test.__name__ = name

            # Distinct group names can sanitize to the same Python name (for
            # example 'a-b' and 'a b'); fail rather than silently drop a test.
            if name in tests:
                raise Error(f'Multiple Python tests are named {name}!')

            tests[name] = test

        return tests

    def python_tests(self, name: str, define_py_test: PyTestGenerator) -> type:
        """Returns a Python unittest.TestCase class with tests for each case.

        Args:
          name: name given to the generated class.
          define_py_test: called once per test case with its Context; returns
              a function that takes the TestCase instance.
        """
        return type(name, (unittest.TestCase, ),
                    self._generate_python_tests(define_py_test))

    def _generate_cc_tests(self, define_cpp_test: CcTestGenerator, header: str,
                           footer: str) -> Iterator[str]:
        """Yields the lines of the generated C++ test file.

        The output is the generated-file preamble, then header, then each test
        case's lines followed by an empty separator line, then footer.
        """
        yield _CPP_HEADER
        yield header

        for ctx in self._test_contexts():
            yield from define_cpp_test(ctx)
            yield ''

        yield footer

    def cc_tests(self, output: TextIO, define_cpp_test: CcTestGenerator,
                 header: str, footer: str):
        """Writes C++ unit tests for each test case to the given file.

        Every generated string (including header and footer) is written
        followed by a newline. output is not closed.
        """
        for line in self._generate_cc_tests(define_cpp_test, header, footer):
            output.write(line)
            output.write('\n')
163
164
# Characters that, placed right after a \x escape, would be absorbed into it.
_HEX_DIGITS = frozenset('0123456789abcdefABCDEF')


def _to_chars(data: bytes) -> Iterator[str]:
    r"""Yields the C++ string-literal spelling of each byte in data.

    Printable ASCII (0x20-0x7e) is emitted as-is, except that double quotes
    and backslashes are backslash-escaped so they cannot end or corrupt the
    literal. Every other byte (control characters, DEL, and anything >= 0x80,
    including each byte of a multi-byte UTF-8 sequence) becomes a two-digit
    \xNN escape.

    A C++ hex escape consumes every hex digit that follows it, so the bytes
    01 61 written as \x01a would be read as one escape with value 0x01a. To
    keep each escape two digits long, a hex-digit character that directly
    follows a \x escape is itself emitted as a \x escape.
    """
    after_hex_escape = False

    for byte in data:
        char = chr(byte)
        printable = 0x20 <= byte < 0x7f
        extends_escape = after_hex_escape and char in _HEX_DIGITS

        if char in '"\\':
            yield '\\' + char
            after_hex_escape = False
        elif printable and not extends_escape:
            yield char
            after_hex_escape = False
        else:
            yield fr'\x{byte:02x}'
            after_hex_escape = True


def cc_string(data: Union[str, bytes]) -> str:
    """Returns a C++ string literal version of a byte string or UTF-8 string.

    str input is encoded as UTF-8 first. The returned text includes the
    surrounding double quotes.
    """
    if isinstance(data, str):
        data = data.encode()

    return '"' + ''.join(_to_chars(data)) + '"'
180
181
def parse_test_generation_args() -> argparse.Namespace:
    """Parses the command-line flags understood by test generator scripts.

    Command-line arguments this parser does not define are ignored instead of
    causing an error.

    Returns:
      A namespace whose generate_cc_test attribute is a file opened for
      writing when --generate-cc-test was given, and None otherwise.
    """
    arg_parser = argparse.ArgumentParser(description='Generate unit test files')
    arg_parser.add_argument(
        '--generate-cc-test',
        type=argparse.FileType('w'),
        help='Generate the C++ test file',
    )
    known_args, _unrecognized = arg_parser.parse_known_args()
    return known_args
188