1#!/usr/bin/env python3.4
2#
3# Copyright 2016 - The Android Open Source Project
4#
5# Licensed under the Apache License, Version 2.0 (the "License");
6# you may not use this file except in compliance with the License.
7# You may obtain a copy of the License at
8#
9#     http://www.apache.org/licenses/LICENSE-2.0
10#
11# Unless required by applicable law or agreed to in writing, software
12# distributed under the License is distributed on an "AS IS" BASIS,
13# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
14# See the License for the specific language governing permissions and
15# limitations under the License.
16
17import re
18import unittest
19
20from acts import signals
21
22
# unittest.TestCase already implements rich comparison assertions (including
# readable failure messages), so keep a single instance around and let the
# helpers below delegate to it instead of re-implementing that logic.
# A trivial subclass with a no-op runTest is used because py2 refuses to
# instantiate unittest.TestCase directly.
class _ProxyTest(unittest.TestCase):
    """Minimal TestCase whose only purpose is to expose the assert* methods."""

    def runTest(self):
        """No-op test method so that the class can be instantiated."""
        return None


_pyunit_proxy = _ProxyTest()
33
34
def assert_equal(first, second, msg=None, extras=None):
    """Assert that two values are equal, otherwise fail the test.

    The comparison is delegated to unittest's assertEqual, so the default
    error message is "first != second" (or a diff for some container types).
    Additional explanation can be supplied via msg.

    Args:
        first: The first value to compare.
        second: The second value to compare.
        msg: A string that adds additional info about the failure; it is
             appended to the default error message.
        extras: An optional field for extra information to be included in
                test result.

    Raises:
        signals.TestFailure: if first != second.
    """
    my_msg = None
    try:
        _pyunit_proxy.assertEqual(first, second)
    except AssertionError as e:
        # unittest reports a failed comparison with its failureException,
        # which is AssertionError in both py2 and py3. Any other exception
        # (e.g. a broken __eq__) is a genuine error and is allowed to
        # propagate instead of being disguised as a test failure.
        my_msg = str(e)
        if msg:
            my_msg = "%s %s" % (my_msg, msg)
    # Raise outside of the except clause so the unittest exception is not
    # chained onto the TestFailure traceback.
    if my_msg is not None:
        fail(my_msg, extras=extras)
62
63
def assert_almost_equal(first,
                        second,
                        places=7,
                        msg=None,
                        delta=None,
                        extras=None):
    """Assert that two values are approximately equal, otherwise fail the test.

    If delta is given, first and second must differ by at most delta.
    Otherwise their difference, rounded to `places` decimal places, must be
    zero (unittest's assertAlmostEqual semantics).

    Args:
        first: The first value to compare (LHS).
        second: The second value to compare (RHS).
        places: Number of decimal places used for the comparison when delta
                is not given.
        msg: A string that adds additional info about the failure; it is
             appended to the default error message.
        delta: Maximum allowed absolute difference between first and second.
               When provided (including 0), places is ignored.
        extras: An optional field for extra information to be included in
                test result.

    Raises:
        signals.TestFailure: if the values are not almost equal.
    """
    my_msg = None
    try:
        # msg is deliberately not forwarded to unittest: it is appended below,
        # and forwarding it as well would make it appear twice.
        # `is not None` so that an explicit delta=0 means "exactly equal"
        # rather than silently falling back to the places comparison.
        if delta is not None:
            _pyunit_proxy.assertAlmostEqual(first, second, delta=delta)
        else:
            _pyunit_proxy.assertAlmostEqual(first, second, places=places)
    except AssertionError as e:
        my_msg = str(e)
        if msg:
            my_msg = "%s %s" % (my_msg, msg)
    # Raise outside of the except clause so the unittest exception is not
    # chained onto the TestFailure traceback.
    if my_msg is not None:
        fail(my_msg, extras=extras)
96
97
def assert_raises(expected_exception, extras=None, *args, **kwargs):
    """Build a context manager asserting that its body raises an exception.

    Leaving the `with` block without any exception fails the test. An
    exception whose type is not a subclass of expected_exception is let
    through unchanged.

    Usage:
        with assert_raises(Exception):
            func()

    Args:
        expected_exception: An exception class that the body of the `with`
                            block is expected to raise.
        extras: An optional field for extra information to be included in
                test result.
        *args: Accepted but ignored.
        **kwargs: Accepted but ignored.

    Returns:
        An _AssertRaisesContext meant to be used in a `with` statement.
    """
    return _AssertRaisesContext(expected_exception, extras=extras)
116
117
def assert_raises_regex(expected_exception,
                        expected_regex,
                        extras=None,
                        *args,
                        **kwargs):
    """Build a context manager asserting that its body raises a matching
    exception.

    Leaving the `with` block without any exception fails the test. An
    exception whose type is not a subclass of expected_exception is let
    through unchanged. If an exception of the expected type is raised but
    expected_regex is not found in its string form, the test fails.

    Usage:
        with assert_raises_regex(Exception, 'some pattern'):
            func()

    Args:
        expected_exception: An exception class that the body of the `with`
                            block is expected to raise.
        expected_regex: A regex string or compiled pattern that must be found
                        (via search) in str() of the raised exception.
        extras: An optional field for extra information to be included in
                test result.
        *args: Accepted but ignored.
        **kwargs: Accepted but ignored.

    Returns:
        An _AssertRaisesContext meant to be used in a `with` statement.
    """
    context = _AssertRaisesContext(
        expected_exception, expected_regex, extras=extras)
    return context
143
144
def assert_true(expr, msg, extras=None):
    """Fail the test unless the given expression is truthy.

    Args:
        expr: The value whose truthiness is checked.
        msg: A string explaining the details in case of failure.
        extras: An optional field for extra information to be included in
                test result.

    Raises:
        signals.TestFailure: (via fail) when expr is falsy.
    """
    if expr:
        return
    fail(msg, extras)
156
157
def assert_false(expr, msg, extras=None):
    """Fail the test unless the given expression is falsy.

    Args:
        expr: The value whose truthiness is checked.
        msg: A string explaining the details in case of failure.
        extras: An optional field for extra information to be included in
                test result.

    Raises:
        signals.TestFailure: (via fail) when expr is truthy.
    """
    if not expr:
        return
    fail(msg, extras)
169
170
def skip(reason, extras=None):
    """Unconditionally skip the current test case.

    Args:
        reason: Why the test is being skipped.
        extras: An optional field for extra information to be included in
                test result.

    Raises:
        signals.TestSkip: always, to mark the test case as skipped.
    """
    skip_signal = signals.TestSkip(reason, extras)
    raise skip_signal
183
184
def skip_if(expr, reason, extras=None):
    """Skip the current test case when the given expression is truthy.

    Args:
        expr: The value whose truthiness decides whether to skip.
        reason: Why the test is being skipped.
        extras: An optional field for extra information to be included in
                test result.
    """
    if not expr:
        return
    skip(reason, extras)
196
197
def abort_class(reason, extras=None):
    """Abort every remaining test case of the current test class, for one
    iteration.

    When a test class is requested several times in one test run, only the
    current requested execution is aborted, NOT all of them.

    Args:
        reason: Why the test class is being aborted.
        extras: An optional field for extra information to be included in
                test result.

    Raises:
        signals.TestAbortClass: always, to abort the remaining tests of the
        test class.
    """
    abort_signal = signals.TestAbortClass(reason, extras)
    raise abort_signal
215
216
def abort_class_if(expr, reason, extras=None):
    """Abort every remaining test case of the current test class, for one
    iteration, when the given expression is truthy.

    When a test class is requested several times in one test run, only the
    current requested execution is aborted, NOT all of them.

    Args:
        expr: The value whose truthiness decides whether to abort.
        reason: Why the test class is being aborted.
        extras: An optional field for extra information to be included in
                test result.

    Raises:
        signals.TestAbortClass: (via abort_class) when expr is truthy.
    """
    if not expr:
        return
    abort_class(reason, extras)
236
237
def abort_all(reason, extras=None):
    """Abort every remaining test case, including those in other test
    classes or iterations.

    Args:
        reason: Why the test run is being aborted.
        extras: An optional field for extra information to be included in
                test result.

    Raises:
        signals.TestAbortAll: always, to abort all subsequent tests.
    """
    abort_signal = signals.TestAbortAll(reason, extras)
    raise abort_signal
251
252
def abort_all_if(expr, reason, extras=None):
    """Abort every remaining test case when the given expression is truthy.

    Args:
        expr: The value whose truthiness decides whether to abort.
        reason: Why the test run is being aborted.
        extras: An optional field for extra information to be included in
                test result.

    Raises:
        signals.TestAbortAll: (via abort_all) when expr is truthy.
    """
    if not expr:
        return
    abort_all(reason, extras)
268
269
def fail(msg, extras=None):
    """Unconditionally fail the current test case.

    Args:
        msg: A string describing why the test failed.
        extras: An optional field for extra information to be included in
                test result.

    Raises:
        signals.TestFailure: always, to mark the test case as failed.
    """
    failure = signals.TestFailure(msg, extras)
    raise failure
282
283
def explicit_pass(msg, extras=None):
    """Explicitly mark the current test case as passed.

    A test that finishes without an uncaught exception already passes
    implicitly, so calling this is optional. It exists to attach extra
    information to the result of a passing test.

    Args:
        msg: A string describing the details of the passed test.
        extras: An optional field for extra information to be included in
                test result.

    Raises:
        signals.TestPass: always, to mark the test case as passed.
    """
    pass_signal = signals.TestPass(msg, extras)
    raise pass_signal
300
301
class _AssertRaisesContext(object):
    """A context manager used to implement the assert_raises* helpers.

    Mirrors unittest's assertRaises context manager, but reports failures as
    signals.TestFailure (carrying extras) instead of AssertionError.
    """

    def __init__(self, expected, expected_regexp=None, extras=None):
        """
        Args:
            expected: Exception class (or tuple of classes) that the body of
                      the `with` block is expected to raise.
            expected_regexp: Optional regex (string or compiled pattern) that
                             must be found in str() of the raised exception.
            extras: Extra information attached to any TestFailure raised.
        """
        self.expected = expected
        self.failureException = signals.TestFailure
        self.expected_regexp = expected_regexp
        self.extras = extras
        # The caught exception, set in __exit__ once an expected exception
        # has been raised. Initialized here so the attribute always exists.
        self.exception = None

    def __enter__(self):
        return self

    def __exit__(self, exc_type, exc_value, tb):
        if exc_type is None:
            try:
                exc_name = self.expected.__name__
            except AttributeError:
                # e.g. a tuple of exception classes has no __name__.
                exc_name = str(self.expected)
            raise signals.TestFailure(
                "{} not raised".format(exc_name), extras=self.extras)
        if not issubclass(exc_type, self.expected):
            # Let unexpected exceptions pass through unchanged.
            return False
        self.exception = exc_value  # store for later retrieval
        if self.expected_regexp is None:
            # Returning True suppresses the expected exception.
            return True

        expected_regexp = self.expected_regexp
        # Compile anything that is not already a compiled pattern. Checking
        # for `search` rather than isinstance(str) also accepts py2 unicode
        # patterns, which would otherwise reach .search() uncompiled.
        if not hasattr(expected_regexp, 'search'):
            expected_regexp = re.compile(expected_regexp)
        if not expected_regexp.search(str(exc_value)):
            raise signals.TestFailure(
                '"%s" does not match "%s"' %
                (expected_regexp.pattern, str(exc_value)),
                extras=self.extras)
        return True
338