1# Copyright 2018 The TensorFlow Authors. All Rights Reserved.
2#
3# Licensed under the Apache License, Version 2.0 (the "License");
4# you may not use this file except in compliance with the License.
5# You may obtain a copy of the License at
6#
7#     http://www.apache.org/licenses/LICENSE-2.0
8#
9# Unless required by applicable law or agreed to in writing, software
10# distributed under the License is distributed on an "AS IS" BASIS,
11# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
12# See the License for the specific language governing permissions and
13# limitations under the License.
14# ==============================================================================
15"""Facilities for creating multiple test combinations.
16
17Here is a simple example for testing various optimizers in Eager and Graph:
18
19class AdditionExample(test.TestCase, parameterized.TestCase):
20  @combinations.generate(
21     combinations.combine(mode=["graph", "eager"],
22                          optimizer=[AdamOptimizer(),
23                                     GradientDescentOptimizer()]))
24  def testOptimizer(self, optimizer):
25    ... f(optimizer)...
26
27This will run `testOptimizer` 4 times with the specified optimizers: 2 in
28Eager and 2 in Graph mode.
The test receives the same parameters as the ones used in `combine()`, matched
by name between the `combine()` call and the test signature.  The test must
accept all parameters; see `OptionalParameter` for a way to implement optional
parameters.
33
34`combine()` function is available for creating a cross product of various
35options.  `times()` function exists for creating a product of N `combine()`-ed
36results.
37
38The execution of generated tests can be customized in a number of ways:
39-  The test can be skipped if it is not running in the correct environment.
40-  The arguments that are passed to the test can be additionally transformed.
41-  The test can be run with specific Python context managers.
42These behaviors can be customized by providing instances of `TestCombination` to
43`generate()`.
44"""
45
46from __future__ import absolute_import
47from __future__ import division
48from __future__ import print_function
49
from collections import OrderedDict
import contextlib
import functools
import re
import types
import unittest
55
56from absl.testing import parameterized
57import six
58
59from tensorflow.python.util import tf_inspect
60from tensorflow.python.util.tf_export import tf_export
61
62
@tf_export("__internal__.test.combinations.TestCombination", v1=[])
class TestCombination(object):
  """Customize the behavior of `generate()` and the tests that it executes.

  A test combination is executed in the following sequence of steps:
    1. `should_execute_combination` is called to decide whether the
       combination should run in the current environment.
    2. If it runs, the arguments for all combined parameters are validated.
       Some arguments can be handled in a special way; that logic lives in the
       `ParameterModifier` instances returned from `parameter_modifiers`.
    3. The `context_managers` are installed around the test before it runs.

  This base implementation always executes, modifies no parameters and
  installs no context managers.
  """

  def should_execute_combination(self, kwargs):
    """Indicates whether the combination of test arguments should be executed.

    If the environment doesn't satisfy the dependencies of the test
    combination, then it can be skipped.

    Args:
      kwargs:  Arguments that are passed to the test combination.

    Returns:
      A tuple `(should_execute, reason)`.  `should_execute` is a boolean;
      False means that the test should be skipped, and `reason` is then a
      textual description of why.  When the test is going to be executed,
      `reason` is `None`.
    """
    del kwargs  # Unused: the default combination runs everywhere.
    return True, None

  def parameter_modifiers(self):
    """Returns `ParameterModifier` instances that customize the arguments.

    Returns:
      A list of `ParameterModifier` instances; empty by default.
    """
    return []

  def context_managers(self, kwargs):
    """Return context managers for running the test combination.

    The test combination runs under the union of the context managers
    returned by every `TestCombination` instance passed to `generate()`.

    Args:
      kwargs:  Arguments and their values that are passed to the test
        combination.

    Returns:
      A list of instantiated context managers; empty by default.
    """
    del kwargs  # Unused: no context is needed by default.
    return []
115
116
@tf_export("__internal__.test.combinations.ParameterModifier", v1=[])
class ParameterModifier(object):
  """Customizes the behavior of a particular parameter.

  Subclasses override `modified_arguments()` to adjust the arguments passed to
  the test case, e.g. to change the value of a parameter or to keep it from
  being passed to the test method at all.

  The sample below replaces any negative argument with zero before it is
  passed to the test case.
  ```
  class NonNegativeParameterModifier(ParameterModifier):

    def modified_arguments(self, kwargs, requested_parameters):
      updates = {}
      for name, value in kwargs.items():
        if value < 0:
          updates[name] = 0
      return updates
  ```
  """

  # Sentinel: an argument that `modified_arguments()` maps to this value is
  # removed instead of being passed to the test.
  DO_NOT_PASS_TO_THE_TEST = object()

  def __init__(self, parameter_name=None):
    """Construct a parameter modifier that may be specific to a parameter.

    Args:
      parameter_name:  A `ParameterModifier` instance may operate on a class of
        parameters or on a parameter with a particular name.  Instances are
        deduplicated by type and `parameter_name` (see `__eq__` and
        `__hash__`), so only one modifier of a given type is used per
        distinct `parameter_name`.
    """
    object.__init__(self)
    self._parameter_name = parameter_name

  def modified_arguments(self, kwargs, requested_parameters):
    """Replace user-provided arguments before they are passed to a test.

    This makes it possible to adjust user-provided arguments before passing
    them to the test method.

    Args:
      kwargs:  The combined arguments for the test.
      requested_parameters: The set of parameters that are defined in the
        signature of the test method.

    Returns:
      A dictionary with updates to `kwargs`.  Keys with values set to
      `ParameterModifier.DO_NOT_PASS_TO_THE_TEST` are going to be deleted and
      not passed to the test.  The default implementation changes nothing.
    """
    del kwargs, requested_parameters  # The default modifies nothing.
    return {}

  def __eq__(self, other):
    """Two modifiers are equal if they share a type and `parameter_name`."""
    return self is other or (type(self) is type(other) and
                             self._parameter_name == other._parameter_name)

  def __ne__(self, other):
    # Python 2 does not derive `!=` from `__eq__`.
    return not self.__eq__(other)

  def __hash__(self):
    """Hash consistently with `__eq__`: by `parameter_name`, else by type."""
    if not self._parameter_name:
      return id(self.__class__)
    return hash(self._parameter_name)
191
192
@tf_export("__internal__.test.combinations.OptionalParameter", v1=[])
class OptionalParameter(ParameterModifier):
  """A parameter that is optional in `combine()` and in the test signature.

  `OptionalParameter` is usually returned by a `TestCombination` from its
  `parameter_modifiers()`.  It lets the `TestCombination` consume a parameter
  given to `combine()` (for example to create some context based on its value)
  without requiring every test method to declare that parameter.

  See the sample usage below:

  ```
  class EagerGraphCombination(TestCombination):

    def context_managers(self, kwargs):
      mode = kwargs.pop("mode", None)
      if mode is None:
        return []
      elif mode == "eager":
        return [context.eager_mode()]
      elif mode == "graph":
        return [ops.Graph().as_default(), context.graph_mode()]
      else:
        raise ValueError(
            "'mode' has to be either 'eager' or 'graph', got {}".format(mode))

    def parameter_modifiers(self):
      return [test_combinations.OptionalParameter("mode")]
  ```

  When the test case is generated, the param "mode" is passed to the test
  method only if the method declares a `mode` parameter; otherwise it is
  consumed by the `EagerGraphCombination` alone.
  """

  def modified_arguments(self, kwargs, requested_parameters):
    """Drops the parameter unless the test method requests it by name."""
    del kwargs  # Only the test signature matters here.
    if self._parameter_name not in requested_parameters:
      return {self._parameter_name: ParameterModifier.DO_NOT_PASS_TO_THE_TEST}
    return {}
232
233
def generate(combinations, test_combinations=()):
  """A decorator for generating combinations of a test method or a test class.

  Parameters of the test method must match by name to get the corresponding
  value of the combination.  Tests must accept all parameters that are passed
  other than the ones that are `OptionalParameter`.

  Every combination is given a `testcase_name` built from the alphanumeric
  characters of its keys and values, so that the generated tests can be
  selected with --test_filter.

  Args:
    combinations: a list of dictionaries created using combine() and times().
    test_combinations: a tuple of `TestCombination` instances that customize
      the execution of generated tests.

  Returns:
    a decorator that will cause the test method or the test class to be run
    under the specified conditions.

  Raises:
    ValueError: when a generated test runs, if any parameters were not
      accepted by the test method.
  """

  def _with_test_name(combination):
    # `combine()` and `times()` produce OrderedDicts, so the key order, and
    # therefore the generated name, is stable.
    assert isinstance(combination, OrderedDict)
    suffix = "".join([
        "_{}_{}".format("".join(filter(str.isalnum, key)),
                        "".join(filter(str.isalnum, _get_name(value, index))))
        for index, (key, value) in enumerate(combination.items())
    ])
    entries = list(combination.items())
    entries.append(("testcase_name", "_test{}".format(suffix)))
    return OrderedDict(entries)

  def decorator(test_method_or_class):
    """The decorator to be returned."""
    named_combinations = [_with_test_name(c) for c in combinations]

    if not isinstance(test_method_or_class, type):
      # A single test method: let absl expand it into named test cases.
      wrapped_method = _augment_with_special_arguments(
          test_method_or_class, test_combinations=test_combinations)
      return parameterized.named_parameters(*named_combinations)(wrapped_method)

    # A test class: replace each test method with its generated variants.
    class_object = test_method_or_class
    test_method_ids = {}
    class_object._test_method_ids = test_method_ids
    test_prefix = unittest.TestLoader.testMethodPrefix
    # Iterate over a snapshot because the class dict is modified below.
    for attr_name, attr_value in six.iteritems(class_object.__dict__.copy()):
      is_test_method = (attr_name.startswith(test_prefix) and
                        isinstance(attr_value, types.FunctionType))
      if not is_test_method:
        continue
      delattr(class_object, attr_name)
      generated_methods = {}
      parameterized._update_class_dict_for_param_test_case(
          class_object.__name__, generated_methods, test_method_ids, attr_name,
          parameterized._ParameterizedTestIter(
              _augment_with_special_arguments(
                  attr_value, test_combinations=test_combinations),
              named_combinations, parameterized._NAMED, attr_name))
      for method_name, method in six.iteritems(generated_methods):
        setattr(class_object, method_name, method)
    return class_object

  return decorator
296
297
298def _augment_with_special_arguments(test_method, test_combinations):
299  def decorated(self, **kwargs):
300    """A wrapped test method that can treat some arguments in a special way."""
301    original_kwargs = kwargs.copy()
302
303    # Skip combinations that are going to be executed in a different testing
304    # environment.
305    reasons_to_skip = []
306    for combination in test_combinations:
307      should_execute, reason = combination.should_execute_combination(
308          original_kwargs.copy())
309      if not should_execute:
310        reasons_to_skip.append(" - " + reason)
311
312    if reasons_to_skip:
313      self.skipTest("\n".join(reasons_to_skip))
314
315    customized_parameters = []
316    for combination in test_combinations:
317      customized_parameters.extend(combination.parameter_modifiers())
318    customized_parameters = set(customized_parameters)
319
320    # The function for running the test under the total set of
321    # `context_managers`:
322    def execute_test_method():
323      requested_parameters = tf_inspect.getfullargspec(test_method).args
324      for customized_parameter in customized_parameters:
325        for argument, value in customized_parameter.modified_arguments(
326            original_kwargs.copy(), requested_parameters).items():
327          if value is ParameterModifier.DO_NOT_PASS_TO_THE_TEST:
328            kwargs.pop(argument, None)
329          else:
330            kwargs[argument] = value
331
332      omitted_arguments = set(requested_parameters).difference(
333          set(list(kwargs.keys()) + ["self"]))
334      if omitted_arguments:
335        raise ValueError("The test requires parameters whose arguments "
336                         "were not passed: {} .".format(omitted_arguments))
337      missing_arguments = set(list(kwargs.keys()) + ["self"]).difference(
338          set(requested_parameters))
339      if missing_arguments:
340        raise ValueError("The test does not take parameters that were passed "
341                         ": {} .".format(missing_arguments))
342
343      kwargs_to_pass = {}
344      for parameter in requested_parameters:
345        if parameter == "self":
346          kwargs_to_pass[parameter] = self
347        else:
348          kwargs_to_pass[parameter] = kwargs[parameter]
349      test_method(**kwargs_to_pass)
350
351    # Install `context_managers` before running the test:
352    context_managers = []
353    for combination in test_combinations:
354      for manager in combination.context_managers(
355          original_kwargs.copy()):
356        context_managers.append(manager)
357
358    if hasattr(contextlib, "nested"):  # Python 2
359      # TODO(isaprykin): Switch to ExitStack when contextlib2 is available.
360      with contextlib.nested(*context_managers):
361        execute_test_method()
362    else:  # Python 3
363      with contextlib.ExitStack() as context_stack:
364        for manager in context_managers:
365          context_stack.enter_context(manager)
366        execute_test_method()
367
368  return decorated
369
370
@tf_export("__internal__.test.combinations.combine", v1=[])
def combine(**kwargs):
  """Generate combinations based on its keyword arguments.

  Two sets of returned combinations can be concatenated using +.  Their product
  can be computed using `times()`.

  For example, `combine(b=[1, 2], a="x")` returns
  `[OrderedDict([("a", "x"), ("b", 1)]), OrderedDict([("a", "x"), ("b", 2)])]`:
  keys are sorted, and the values of earlier keys vary slowest.

  Args:
    **kwargs: keyword arguments of form `option=[possibilities, ...]`
         or `option=the_only_possibility`.  Only a `list` is treated as a
         set of possibilities; any other value (including a tuple) is a single
         possibility.

  Returns:
    a list of `OrderedDict`s, one for each combination.  Keys in the
    dictionaries are the keyword argument names, in sorted order.  Each key
    has one value - one of the corresponding keyword argument values.
  """
  # Grow every partial combination by one option at a time, in sorted key
  # order; starting from a single empty combination makes `combine()` return
  # `[OrderedDict()]`.
  partial_combinations = [[]]
  for option in sorted(kwargs):
    possibilities = kwargs[option]
    if not isinstance(possibilities, list):
      possibilities = [possibilities]
    partial_combinations = [
        prefix + [(option, possibility)]
        for prefix in partial_combinations
        for possibility in possibilities
    ]
  return [OrderedDict(pairs) for pairs in partial_combinations]
407
408
@tf_export("__internal__.test.combinations.times", v1=[])
def times(*combined):
  """Generate a product of N sets of combinations.

  times(combine(a=[1,2]), combine(b=[3,4])) == combine(a=[1,2], b=[3,4])

  Args:
    *combined: N lists of dictionaries that specify combinations.

  Returns:
    a list of dictionaries for each combination.  Each holds the keys of a
    dictionary from the first input followed by those from the later inputs.
    With a single input, that input itself is returned.

  Raises:
    ValueError: if some of the inputs have overlapping keys.
  """
  assert combined

  # Fold from the right so that later inputs are combined first, exactly as
  # `times(first, times(*rest))` would.
  product = combined[-1]
  for left_side in reversed(combined[:-1]):
    merged = []
    for left in left_side:
      for right in product:
        if set(left.keys()).intersection(set(right.keys())):
          raise ValueError("Keys need to not overlap: {} vs {}".format(
              left.keys(), right.keys()))
        merged.append(OrderedDict(list(left.items()) + list(right.items())))
    product = merged
  return product
441
442
@tf_export("__internal__.test.combinations.NamedObject", v1=[])
class NamedObject(object):
  """A class that translates an object into a good test name.

  The `repr()` (and therefore `str()`) of a `NamedObject` is the given `name`,
  which is what ends up in generated test names.  Attribute access, calls and
  iteration are forwarded to the wrapped object.
  """

  def __init__(self, name, obj):
    """Wraps `obj` so that it is displayed as `name`.

    Args:
      name: The string returned by `repr()` of this object.
      obj: The wrapped object that attribute access, calls and iteration are
        forwarded to.
    """
    object.__init__(self)
    self._name = name
    self._obj = obj

  def __getattr__(self, name):
    # Only called when regular attribute lookup fails.  If `_name` or `_obj`
    # is missing because the instance was created without `__init__` (as
    # `copy.copy()` and unpickling do), forwarding to `self._obj` would
    # re-enter `__getattr__` forever; raise AttributeError instead.
    if name in ("_name", "_obj"):
      raise AttributeError(name)
    return getattr(self._obj, name)

  def __call__(self, *args, **kwargs):
    return self._obj(*args, **kwargs)

  def __iter__(self):
    return self._obj.__iter__()

  def __repr__(self):
    return self._name
463
464
465def _get_name(value, index):
466  return re.sub("0[xX][0-9a-fA-F]+", str(index), str(value))
467