1# Written by Will Bond <will@wbond.net>
2#
3# The author or authors of this code dedicate any and all copyright interest in
4# this code to the public domain. We make this dedication for the benefit of the
5# public at large and to the detriment of our heirs and successors. We intend
6# this dedication to be an overt act of relinquishment in perpetuity of all
7# present and future rights to this code under copyright law.
8
9
def data(provider_method, first_param_name_suffix=False):
    """
    Method decorator for unittest.TestCase classes that marks a method as
    being driven by a static data-provider method, so that one test is
    produced for every set of parameters the provider returns

    :param provider_method:
        The name (a string) of the staticmethod on the class that returns the
        parameter sets

    :param first_param_name_suffix:
        When True, the first value of each parameter set is appended to the
        method name to build the test name. When False, a counter is used.

    :return:
        A decorator that records the settings on the function and returns the
        very same function object
    """

    def mark(method):
        # The configuration is stored as attributes on the function object
        # itself; data_decorator() later looks for these attributes.
        method._provider_name_suffix = first_param_name_suffix
        method._provider_method = provider_method
        return method

    return mark
32
33
def data_decorator(cls):
    """
    A class decorator that works with the @data method decorator to generate
    one test method per set of parameters returned by a data provider

    For every attribute of the class that carries a _provider_method marker
    (as set by @data), the named provider is looked up on the class and
    called. For each parameter set it returns, a new method named
    test_<name>_<suffix> is added to the class. The suffix is a 1-based
    counter or, when the method was marked with first_param_name_suffix=True,
    the first value of the parameter set. In that second case the first value
    is not passed on to the wrapped method.

    :param cls:
        The class (normally a unittest.TestCase subclass) to add generated
        test methods to

    :return:
        The same class object, modified in place
    """

    def generate_test_func(name, original_function, num, params):
        # Each call has its own scope, so the nested function defined below
        # closes over this call's params rather than a shared loop variable.
        if original_function._provider_name_suffix:
            data_name = params[0]
            params = params[1:]
        else:
            data_name = num
        expanded_name = 'test_%s_%s' % (name, data_name)

        # We use expanded variable names here since this line is present in
        # backtraces that are generated from test failures.
        def generated_test_function(self):
            original_function(self, *params)

        setattr(cls, expanded_name, generated_test_function)

    # dir() returns a list snapshot, so the methods added inside this loop
    # do not change which names are visited.
    for name in dir(cls):
        func = getattr(cls, name)
        if not hasattr(func, '_provider_method'):
            continue
        provider = getattr(cls, func._provider_method)
        for num, params in enumerate(provider(), 1):
            generate_test_func(name, func, num, params)

    return cls
64