1# Copyright 2016 The TensorFlow Authors. All Rights Reserved.
2#
3# Licensed under the Apache License, Version 2.0 (the "License");
4# you may not use this file except in compliance with the License.
5# You may obtain a copy of the License at
6#
7#     http://www.apache.org/licenses/LICENSE-2.0
8#
9# Unless required by applicable law or agreed to in writing, software
10# distributed under the License is distributed on an "AS IS" BASIS,
11# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
12# See the License for the specific language governing permissions and
13# limitations under the License.
14# ==============================================================================
15"""Utilities for unit-testing Keras."""
16
17from __future__ import absolute_import
18from __future__ import division
19from __future__ import print_function
20
21import collections
22import functools
23import itertools
24import unittest
25
26from absl.testing import parameterized
27
28from tensorflow.python import keras
29from tensorflow.python import tf2
30from tensorflow.python.eager import context
31from tensorflow.python.keras import testing_utils
32from tensorflow.python.platform import test
33from tensorflow.python.util import nest
34
35
class TestCase(test.TestCase, parameterized.TestCase):
  """Base class for Keras unit tests.

  Mixes TensorFlow's `test.TestCase` with absl's `parameterized.TestCase`, so
  the parameterizing decorators in this module can be used on its methods.
  """

  def tearDown(self):
    """Clears the Keras backend session, then defers to the parent teardown."""
    parent = super(TestCase, self)
    # Drop Keras backend state accumulated by the test that just ran.
    keras.backend.clear_session()
    parent.tearDown()
41
42
43# TODO(kaftan): Possibly enable 'subclass_custom_build' when tests begin to pass
44# it. Or perhaps make 'subclass' always use a custom build method.
def run_with_all_model_types(
    test_or_class=None,
    exclude_models=None):
  """Execute the decorated test with all Keras model types.

  This decorator is intended to be applied either to individual test methods in
  a `keras_parameterized.TestCase` class, or directly to a test class that
  extends it. Doing so will cause the contents of the individual test
  method (or all test methods in the class) to be executed multiple times - once
  for each Keras model type.

  The Keras model types are: ['functional', 'subclass', 'sequential']

  Note: if stacking this decorator with absl.testing's parameterized decorators,
  those should be at the bottom of the stack.

  Various methods in `testing_utils` to get models will auto-generate a model
  of the currently active Keras model type. This allows unittests to confirm
  the equivalence between different Keras models.

  For example, consider the following unittest:

  ```python
  class MyTests(keras_parameterized.TestCase):

    @keras_parameterized.run_with_all_model_types(
      exclude_models=['sequential'])
    def test_foo(self):
      model = testing_utils.get_small_mlp(1, 4, input_dim=3)
      optimizer = RMSPropOptimizer(learning_rate=0.001)
      loss = 'mse'
      metrics = ['mae']
      model.compile(optimizer, loss, metrics=metrics)

      inputs = np.zeros((10, 3))
      targets = np.zeros((10, 4))
      dataset = dataset_ops.Dataset.from_tensor_slices((inputs, targets))
      dataset = dataset.repeat(100)
      dataset = dataset.batch(10)

      model.fit(dataset, epochs=1, steps_per_epoch=2, verbose=1)

  if __name__ == "__main__":
    tf.test.main()
  ```

  This test tries building a small mlp as both a functional model and as a
  subclass model.

  We can also annotate the whole class if we want this to apply to all tests in
  the class:
  ```python
  @keras_parameterized.run_with_all_model_types(exclude_models=['sequential'])
  class MyTests(keras_parameterized.TestCase):

    def test_foo(self):
      model = testing_utils.get_small_mlp(1, 4, input_dim=3)
      optimizer = RMSPropOptimizer(learning_rate=0.001)
      loss = 'mse'
      metrics = ['mae']
      model.compile(optimizer, loss, metrics=metrics)

      inputs = np.zeros((10, 3))
      targets = np.zeros((10, 4))
      dataset = dataset_ops.Dataset.from_tensor_slices((inputs, targets))
      dataset = dataset.repeat(100)
      dataset = dataset.batch(10)

      model.fit(dataset, epochs=1, steps_per_epoch=2, verbose=1)

  if __name__ == "__main__":
    tf.test.main()
  ```

  Each generated test case is named after the original test with the model
  type appended (e.g. `test_foo_functional`, `test_foo_subclass`), so it can
  be run individually from the command line.

  Args:
    test_or_class: test method or class to be annotated. If None,
      this method returns a decorator that can be applied to a test method or
      test class. If it is not None this returns the decorator applied to the
      test or class.
    exclude_models: A collection of Keras model types to not run.
      (May also be a single model type not wrapped in a collection).
      Defaults to None.

  Returns:
    Returns a decorator that will run the decorated test method multiple times:
    once for each desired Keras model type.

  Raises:
    ValueError: raised by a generated test case if it is invoked with a model
      type other than 'functional', 'subclass' or 'sequential'.
  """
  model_types = ['functional', 'subclass', 'sequential']
  # Flatten once, up front: `exclude_models` may be None, a single model type,
  # or a collection of model types.
  excluded = nest.flatten(exclude_models)
  params = [('_%s' % model, model) for model in model_types
            if model not in excluded]

  def single_method_decorator(f):
    """Decorator that constructs the test cases."""
    # Use named_parameters so it can be individually run from the command line
    @parameterized.named_parameters(*params)
    @functools.wraps(f)
    def decorated(self, model_type, *args, **kwargs):
      """A run of a single test case w/ the specified model type."""
      if model_type == 'functional':
        _test_functional_model_type(f, self, *args, **kwargs)
      elif model_type == 'subclass':
        _test_subclass_model_type(f, self, *args, **kwargs)
      elif model_type == 'sequential':
        _test_sequential_model_type(f, self, *args, **kwargs)
      else:
        raise ValueError('Unknown model type: %s' % (model_type,))
    return decorated

  return _test_or_class_decorator(test_or_class, single_method_decorator)
159
160
def _test_functional_model_type(f, test_or_class, *args, **kwargs):
  """Calls `f(test_or_class, *args, **kwargs)` under the 'functional' scope.

  The scope comes from `testing_utils.model_type_scope('functional')`.
  """
  functional_scope = testing_utils.model_type_scope('functional')
  with functional_scope:
    f(test_or_class, *args, **kwargs)
164
165
def _test_subclass_model_type(f, test_or_class, *args, **kwargs):
  """Calls `f(test_or_class, *args, **kwargs)` under the 'subclass' scope.

  The scope comes from `testing_utils.model_type_scope('subclass')`.
  """
  subclass_scope = testing_utils.model_type_scope('subclass')
  with subclass_scope:
    f(test_or_class, *args, **kwargs)
169
170
def _test_sequential_model_type(f, test_or_class, *args, **kwargs):
  """Calls `f(test_or_class, *args, **kwargs)` under the 'sequential' scope.

  The scope comes from `testing_utils.model_type_scope('sequential')`.
  """
  sequential_scope = testing_utils.model_type_scope('sequential')
  with sequential_scope:
    f(test_or_class, *args, **kwargs)
174
175
def run_all_keras_modes(
    test_or_class=None,
    config=None,
    always_skip_v1=False):
  """Execute the decorated test with all keras execution modes.

  This decorator is intended to be applied either to individual test methods in
  a `keras_parameterized.TestCase` class, or directly to a test class that
  extends it. Doing so will cause the contents of the individual test
  method (or all test methods in the class) to be executed multiple times -
  once executing in legacy graph mode, once running eagerly and with
  `should_run_eagerly` returning True, and once running eagerly with
  `should_run_eagerly` returning False.

  If Tensorflow v2 behavior is enabled, legacy graph mode will be skipped, and
  the test will only run twice.

  Note: if stacking this decorator with absl.testing's parameterized decorators,
  those should be at the bottom of the stack.

  For example, consider the following unittest:

  ```python
  class MyTests(keras_parameterized.TestCase):

    @keras_parameterized.run_all_keras_modes
    def test_foo(self):
      model = testing_utils.get_small_functional_mlp(1, 4, input_dim=3)
      optimizer = RMSPropOptimizer(learning_rate=0.001)
      loss = 'mse'
      metrics = ['mae']
      model.compile(optimizer, loss, metrics=metrics,
                    run_eagerly=testing_utils.should_run_eagerly())

      inputs = np.zeros((10, 3))
      targets = np.zeros((10, 4))
      dataset = dataset_ops.Dataset.from_tensor_slices((inputs, targets))
      dataset = dataset.repeat(100)
      dataset = dataset.batch(10)

      model.fit(dataset, epochs=1, steps_per_epoch=2, verbose=1)

  if __name__ == "__main__":
    tf.test.main()
  ```

  This test will try compiling & fitting the small functional mlp using all
  three Keras execution modes.

  The generated test cases are suffixed `_v2_eager`, `_v2_function` and (when
  legacy graph mode is not skipped) `_v1_graph`, so each can be run
  individually from the command line.

  Args:
    test_or_class: test method or class to be annotated. If None,
      this method returns a decorator that can be applied to a test method or
      test class. If it is not None this returns the decorator applied to the
      test or class.
    config: An optional config_pb2.ConfigProto to use to configure the
      session when executing graphs.
    always_skip_v1: If True, does not try running the legacy graph mode even
      when Tensorflow v2 behavior is not enabled.

  Returns:
    Returns a decorator that will run the decorated test method multiple times.

  Raises:
    ValueError: raised by a generated test case if it is invoked with a run
      mode other than 'v1_graph', 'v2_function' or 'v2_eager'.
  """
  params = [('_v2_eager', 'v2_eager'),
            ('_v2_function', 'v2_function')]
  if not (always_skip_v1 or tf2.enabled()):
    params.append(('_v1_graph', 'v1_graph'))

  def single_method_decorator(f):
    """Decorator that constructs the test cases."""

    # Use named_parameters so it can be individually run from the command line
    @parameterized.named_parameters(*params)
    @functools.wraps(f)
    def decorated(self, run_mode, *args, **kwargs):
      """A run of a single test case w/ specified run mode."""
      if run_mode == 'v1_graph':
        _v1_graph_test(f, self, config, *args, **kwargs)
      elif run_mode == 'v2_function':
        _v2_graph_functions_test(f, self, *args, **kwargs)
      elif run_mode == 'v2_eager':
        _v2_eager_test(f, self, *args, **kwargs)
      else:
        # Raise (not return) so an unexpected run mode fails the test instead
        # of silently passing.
        raise ValueError('Unknown run mode %s' % run_mode)

    return decorated

  return _test_or_class_decorator(test_or_class, single_method_decorator)
267
268
def _v1_graph_test(f, test_or_class, config, *args, **kwargs):
  """Calls `f(test_or_class, *args, **kwargs)` in legacy graph mode.

  The call runs inside three nested contexts:
  - `context.graph_mode()`;
  - `testing_utils.run_eagerly_scope(False)`;
  - a `test_session(use_gpu=True, config=config)` of the test case.
  """
  with context.graph_mode():
    with testing_utils.run_eagerly_scope(False):
      with test_or_class.test_session(use_gpu=True, config=config):
        f(test_or_class, *args, **kwargs)
273
274
def _v2_graph_functions_test(f, test_or_class, *args, **kwargs):
  """Calls `f(test_or_class, *args, **kwargs)` eagerly with run_eagerly off.

  Eager mode is enabled via `context.eager_mode()`, while
  `testing_utils.run_eagerly_scope(False)` is active.
  """
  with context.eager_mode(), testing_utils.run_eagerly_scope(False):
    f(test_or_class, *args, **kwargs)
279
280
def _v2_eager_test(f, test_or_class, *args, **kwargs):
  """Calls `f(test_or_class, *args, **kwargs)` eagerly with run_eagerly on.

  Eager mode is enabled via `context.eager_mode()`, while
  `testing_utils.run_eagerly_scope(True)` is active.
  """
  with context.eager_mode(), testing_utils.run_eagerly_scope(True):
    f(test_or_class, *args, **kwargs)
285
286
287def _test_or_class_decorator(test_or_class, single_method_decorator):
288  """Decorate a test or class with a decorator intended for one method.
289
290  If the test_or_class is a class:
291    This will apply the decorator to all test methods in the class.
292
293  If the test_or_class is an iterable of already-parameterized test cases:
294    This will apply the decorator to all the cases, and then flatten the
295    resulting cross-product of test cases. This allows stacking the Keras
296    parameterized decorators w/ each other, and to apply them to test methods
297    that have already been marked with an absl parameterized decorator.
298
299  Otherwise, treat the obj as a single method and apply the decorator directly.
300
301  Args:
302    test_or_class: A test method (that may have already been decorated with a
303      parameterized decorator, or a test class that extends
304      keras_parameterized.TestCase
305    single_method_decorator:
306      A parameterized decorator intended for a single test method.
307  Returns:
308    The decorated result.
309  """
310  def _decorate_test_or_class(obj):
311    if isinstance(obj, collections.Iterable):
312      return itertools.chain.from_iterable(
313          single_method_decorator(method) for method in obj)
314    if isinstance(obj, type):
315      cls = obj
316      for name, value in cls.__dict__.copy().items():
317        if callable(value) and name.startswith(
318            unittest.TestLoader.testMethodPrefix):
319          setattr(cls, name, single_method_decorator(value))
320
321      cls = type(cls).__new__(type(cls), cls.__name__, cls.__bases__,
322                              cls.__dict__.copy())
323      return cls
324
325    return single_method_decorator(obj)
326
327  if test_or_class is not None:
328    return _decorate_test_or_class(test_or_class)
329
330  return _decorate_test_or_class
331