1# Copyright 2016 The TensorFlow Authors. All Rights Reserved.
2#
3# Licensed under the Apache License, Version 2.0 (the "License");
4# you may not use this file except in compliance with the License.
5# You may obtain a copy of the License at
6#
7#     http://www.apache.org/licenses/LICENSE-2.0
8#
9# Unless required by applicable law or agreed to in writing, software
10# distributed under the License is distributed on an "AS IS" BASIS,
11# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
12# See the License for the specific language governing permissions and
13# limitations under the License.
14# ==============================================================================
15"""Utilities for unit-testing Keras."""
16
17from __future__ import absolute_import
18from __future__ import division
19from __future__ import print_function
20
21import collections.abc as collections_abc
22import functools
23import itertools
24import unittest
25
26from absl.testing import parameterized
27
28from tensorflow.python import keras
29from tensorflow.python import tf2
30from tensorflow.python.eager import context
31from tensorflow.python.framework import ops
32from tensorflow.python.keras import testing_utils
33from tensorflow.python.platform import test
34from tensorflow.python.util import nest
35
36try:
37  import h5py  # pylint:disable=g-import-not-at-top
38except ImportError:
39  h5py = None
40
41
class TestCase(test.TestCase, parameterized.TestCase):
  """Base class for Keras tests that combines TF and absl parameterized tests.

  After every test it calls `keras.backend.clear_session()`, so each test
  starts from a freshly reset Keras backend.
  """

  def tearDown(self):
    # Reset the Keras backend first, then let the base classes clean up.
    keras.backend.clear_session()
    super().tearDown()
47
48
def run_with_all_saved_model_formats(
    test_or_class=None,
    exclude_formats=None):
  """Execute the decorated test with all Keras saved model formats.

  This decorator is intended to be applied either to individual test methods in
  a `keras_parameterized.TestCase` class, or directly to a test class that
  extends it. Doing so will cause the contents of the individual test
  method (or all test methods in the class) to be executed multiple times - once
  for each Keras saved model format.

  The Keras saved model formats include:
  1. HDF5: 'h5' (skipped automatically when `h5py` cannot be imported)
  2. SavedModel: 'tf'
  3. SavedModel saved with `save_traces=False`: 'tf_no_traces'

  Note: if stacking this decorator with absl.testing's parameterized decorators,
  those should be at the bottom of the stack.

  Each run executes inside the matching
  `testing_utils.saved_model_format_scope`, so helpers such as
  `testing_utils.get_save_format()` report the format currently under test.
  This allows unittests to confirm the equivalence between the Keras saved
  model formats.

  For example, consider the following unittest:

  ```python
  class MyTests(keras_parameterized.TestCase):

    @keras_parameterized.run_with_all_saved_model_formats
    def test_foo(self):
      save_format = testing_utils.get_save_format()
      saved_model_dir = '/tmp/saved_model/'
      model = keras.models.Sequential()
      model.add(keras.layers.Dense(2, input_shape=(3,)))
      model.add(keras.layers.Dense(3))
      model.compile(loss='mse', optimizer='sgd', metrics=['acc'])

      keras.models.save_model(model, saved_model_dir, save_format=save_format)
      model = keras.models.load_model(saved_model_dir)

  if __name__ == "__main__":
    tf.test.main()
  ```

  This test is run once for each of the formats listed above ('h5', 'tf' and
  'tf_no_traces'), minus any excluded ones.

  We can also annotate the whole class if we want this to apply to all tests in
  the class:
  ```python
  @keras_parameterized.run_with_all_saved_model_formats
  class MyTests(keras_parameterized.TestCase):

    def test_foo(self):
      save_format = testing_utils.get_save_format()
      saved_model_dir = '/tmp/saved_model/'
      model = keras.models.Sequential()
      model.add(keras.layers.Dense(2, input_shape=(3,)))
      model.add(keras.layers.Dense(3))
      model.compile(loss='mse', optimizer='sgd', metrics=['acc'])

      keras.models.save_model(model, saved_model_dir, save_format=save_format)
      model = keras.models.load_model(saved_model_dir)

  if __name__ == "__main__":
    tf.test.main()
  ```

  Args:
    test_or_class: test method or class to be annotated. If None,
      this method returns a decorator that can be applied to a test method or
      test class. If it is not None this returns the decorator applied to the
      test or class.
    exclude_formats: A collection of Keras saved model formats to not run.
      (May also be a single format not wrapped in a collection).
      Defaults to None. The argument is never modified.

  Returns:
    Returns a decorator that will run the decorated test method multiple times:
    once for each desired Keras saved model format.

  Raises:
    ImportError: If abseil parameterized is not installed or not included as
      a target dependency.
  """
  # Work on a fresh flat list so that None, a single format string or any
  # nested collection are all accepted, and the caller's argument is never
  # mutated (nest.flatten(None) yields [None], which matches no format).
  excluded_formats = list(nest.flatten(exclude_formats))
  # Exclude h5 save format if h5py isn't available.
  if h5py is None:
    excluded_formats.append('h5')
  saved_model_formats = ['h5', 'tf', 'tf_no_traces']
  params = [('_%s' % saved_format, saved_format)
            for saved_format in saved_model_formats
            if saved_format not in excluded_formats]

  def single_method_decorator(f):
    """Decorator that constructs the test cases."""
    # Use named_parameters so it can be individually run from the command line
    @parameterized.named_parameters(*params)
    @functools.wraps(f)
    def decorated(self, saved_format, *args, **kwargs):
      """A run of a single test case w/ the specified saved model format."""
      if saved_format == 'h5':
        _test_h5_saved_model_format(f, self, *args, **kwargs)
      elif saved_format == 'tf':
        _test_tf_saved_model_format(f, self, *args, **kwargs)
      elif saved_format == 'tf_no_traces':
        _test_tf_saved_model_format_no_traces(f, self, *args, **kwargs)
      else:
        raise ValueError('Unknown saved model format: %s' % (saved_format,))
    return decorated

  return _test_or_class_decorator(test_or_class, single_method_decorator)
159
160
def _test_h5_saved_model_format(f, test_or_class, *args, **kwargs):
  """Calls `f(test_or_class, *args, **kwargs)` with 'h5' as the save format."""
  h5_scope = testing_utils.saved_model_format_scope('h5')
  with h5_scope:
    f(test_or_class, *args, **kwargs)
164
165
def _test_tf_saved_model_format(f, test_or_class, *args, **kwargs):
  """Calls `f(test_or_class, *args, **kwargs)` with 'tf' as the save format."""
  tf_scope = testing_utils.saved_model_format_scope('tf')
  with tf_scope:
    f(test_or_class, *args, **kwargs)
169
170
def _test_tf_saved_model_format_no_traces(f, test_or_class, *args, **kwargs):
  """Calls `f(test_or_class, ...)` with the 'tf' format and save_traces=False."""
  no_traces_scope = testing_utils.saved_model_format_scope(
      'tf', save_traces=False)
  with no_traces_scope:
    f(test_or_class, *args, **kwargs)
174
175
def run_with_all_weight_formats(test_or_class=None, exclude_formats=None):
  """Runs all tests with the supported formats for saving weights.

  Behaves like `run_with_all_saved_model_formats`, except that the
  'tf_no_traces' format is always excluded: it only applies to saving whole
  models, not weights.

  Args:
    test_or_class: test method or class to be annotated. If None, a decorator
      is returned instead (see `run_with_all_saved_model_formats`).
    exclude_formats: A collection of saved formats to skip, or a single format
      string. Defaults to None. The argument is never modified.

  Returns:
    The result of `run_with_all_saved_model_formats` called with the combined
    exclusions.
  """
  # Build a fresh list so the caller's argument is never mutated and a lone
  # format string or a tuple is accepted, matching the sibling decorators.
  if exclude_formats is None:
    excluded_formats = []
  else:
    excluded_formats = list(nest.flatten(exclude_formats))
  excluded_formats.append('tf_no_traces')  # Only applies to saving models
  return run_with_all_saved_model_formats(test_or_class, excluded_formats)
181
182
183# TODO(kaftan): Possibly enable 'subclass_custom_build' when tests begin to pass
184# it. Or perhaps make 'subclass' always use a custom build method.
def run_with_all_model_types(
    test_or_class=None,
    exclude_models=None):
  """Execute the decorated test once per Keras model type.

  Apply this decorator either to individual test methods of a
  `keras_parameterized.TestCase` subclass, or directly to such a class. The
  body of each decorated test method (or of every test method in the class)
  then runs several times, once for every Keras model type that is not
  excluded.

  The Keras model types are: ['functional', 'subclass', 'sequential']

  Note: if stacking this decorator with absl.testing's parameterized decorators,
  those should be at the bottom of the stack.

  Each run happens inside `testing_utils.model_type_scope(<model type>)`, so
  the model-building helpers in `testing_utils` produce a model of the model
  type currently under test. This lets a single unittest check that the
  different kinds of Keras models behave the same way.

  Example on a single test method:

  ```python
  class MyTests(keras_parameterized.TestCase):

    @keras_parameterized.run_with_all_model_types(
      exclude_models=['sequential'])
    def test_foo(self):
      model = testing_utils.get_small_mlp(1, 4, input_dim=3)
      optimizer = RMSPropOptimizer(learning_rate=0.001)
      loss = 'mse'
      metrics = ['mae']
      model.compile(optimizer, loss, metrics=metrics)

      inputs = np.zeros((10, 3))
      targets = np.zeros((10, 4))
      dataset = dataset_ops.Dataset.from_tensor_slices((inputs, targets))
      dataset = dataset.repeat(100)
      dataset = dataset.batch(10)

      model.fit(dataset, epochs=1, steps_per_epoch=2, verbose=1)

  if __name__ == "__main__":
    tf.test.main()
  ```

  Because 'sequential' is excluded, the small mlp is built and trained twice:
  once as a functional model and once as a subclass model.

  The same decorator can annotate a whole class, applying it to every test in
  that class:

  ```python
  @keras_parameterized.run_with_all_model_types(exclude_models=['sequential'])
  class MyTests(keras_parameterized.TestCase):

    def test_foo(self):
      model = testing_utils.get_small_mlp(1, 4, input_dim=3)
      optimizer = RMSPropOptimizer(learning_rate=0.001)
      loss = 'mse'
      metrics = ['mae']
      model.compile(optimizer, loss, metrics=metrics)

      inputs = np.zeros((10, 3))
      targets = np.zeros((10, 4))
      dataset = dataset_ops.Dataset.from_tensor_slices((inputs, targets))
      dataset = dataset.repeat(100)
      dataset = dataset.batch(10)

      model.fit(dataset, epochs=1, steps_per_epoch=2, verbose=1)

  if __name__ == "__main__":
    tf.test.main()
  ```

  Args:
    test_or_class: test method or class to be annotated. When None, a
      decorator is returned that can later be applied to a test method or test
      class; otherwise the decorated test or class is returned.
    exclude_models: Keras model types to skip, either as a collection or as a
      single model type string. Defaults to None (run every model type).

  Returns:
    A decorator (or the decorated target) that runs the test method once for
    each selected Keras model type.

  Raises:
    ImportError: If abseil parameterized is not installed or not included as
      a target dependency.
  """
  excluded_models = nest.flatten(exclude_models)
  params = []
  for model_type in ('functional', 'subclass', 'sequential'):
    if model_type not in excluded_models:
      params.append(('_%s' % model_type, model_type))

  def single_method_decorator(f):
    """Wraps `f` so that it runs once per selected Keras model type."""

    # named_parameters gives every case its own '_<model type>'-suffixed name,
    # so each one can be run individually from the command line.
    @parameterized.named_parameters(*params)
    @functools.wraps(f)
    def decorated(self, model_type, *args, **kwargs):
      """Runs one test case under the requested model type."""
      runners = {
          'functional': _test_functional_model_type,
          'subclass': _test_subclass_model_type,
          'sequential': _test_sequential_model_type,
      }
      runner = runners.get(model_type)
      if runner is None:
        raise ValueError('Unknown model type: %s' % (model_type,))
      runner(f, self, *args, **kwargs)

    return decorated

  return _test_or_class_decorator(test_or_class, single_method_decorator)
299
300
def _test_functional_model_type(f, test_or_class, *args, **kwargs):
  """Calls `f(test_or_class, *args, **kwargs)` under the 'functional' type."""
  functional_scope = testing_utils.model_type_scope('functional')
  with functional_scope:
    f(test_or_class, *args, **kwargs)
304
305
def _test_subclass_model_type(f, test_or_class, *args, **kwargs):
  """Calls `f(test_or_class, *args, **kwargs)` under the 'subclass' type."""
  subclass_scope = testing_utils.model_type_scope('subclass')
  with subclass_scope:
    f(test_or_class, *args, **kwargs)
309
310
def _test_sequential_model_type(f, test_or_class, *args, **kwargs):
  """Calls `f(test_or_class, *args, **kwargs)` under the 'sequential' type."""
  sequential_scope = testing_utils.model_type_scope('sequential')
  with sequential_scope:
    f(test_or_class, *args, **kwargs)
314
315
def run_all_keras_modes(test_or_class=None,
                        config=None,
                        always_skip_v1=False,
                        always_skip_eager=False,
                        **kwargs):
  """Execute the decorated test with all keras execution modes.

  This decorator is intended to be applied either to individual test methods in
  a `keras_parameterized.TestCase` class, or directly to a test class that
  extends it. Doing so will cause the contents of the individual test
  method (or all test methods in the class) to be executed multiple times:

  * 'v2_function': in eager mode with `should_run_eagerly` returning False.
  * 'v2_eager': in eager mode with `should_run_eagerly` returning True
    (skipped when `always_skip_eager` is True).
  * 'v1_session': in legacy graph mode inside a test session (skipped when
    `always_skip_v1` is True or Tensorflow v2 behavior is enabled).

  Note: if stacking this decorator with absl.testing's parameterized decorators,
  those should be at the bottom of the stack.

  For example, consider the following unittest:

  ```python
  class MyTests(keras_parameterized.TestCase):

    @keras_parameterized.run_all_keras_modes
    def test_foo(self):
      model = testing_utils.get_small_functional_mlp(1, 4, input_dim=3)
      optimizer = RMSPropOptimizer(learning_rate=0.001)
      loss = 'mse'
      metrics = ['mae']
      model.compile(
          optimizer, loss, metrics=metrics,
          run_eagerly=testing_utils.should_run_eagerly())

      inputs = np.zeros((10, 3))
      targets = np.zeros((10, 4))
      dataset = dataset_ops.Dataset.from_tensor_slices((inputs, targets))
      dataset = dataset.repeat(100)
      dataset = dataset.batch(10)

      model.fit(dataset, epochs=1, steps_per_epoch=2, verbose=1)

  if __name__ == "__main__":
    tf.test.main()
  ```

  This test will try compiling & fitting the small functional mlp in every
  enabled Keras execution mode.

  Args:
    test_or_class: test method or class to be annotated. If None,
      this method returns a decorator that can be applied to a test method or
      test class. If it is not None this returns the decorator applied to the
      test or class.
    config: An optional config_pb2.ConfigProto to use to configure the
      session when executing graphs.
    always_skip_v1: If True, does not try running the legacy graph mode even
      when Tensorflow v2 behavior is not enabled.
    always_skip_eager: If True, skips the 'v2_eager' run (where
      `should_run_eagerly` returns True). The 'v2_function' run still executes.
    **kwargs: Additional kwargs for configuring tests for
     in-progress Keras behaviors/ refactorings that we haven't fully
     rolled out yet. Currently none are recognized.

  Returns:
    Returns a decorator that will run the decorated test method multiple times.

  Raises:
    ImportError: If abseil parameterized is not installed or not included as
      a target dependency.
    ValueError: If any unrecognized keyword arguments are passed.
  """
  if kwargs:
    raise ValueError('Unrecognized keyword args: {}'.format(kwargs))

  params = [('_v2_function', 'v2_function')]
  if not always_skip_eager:
    params.append(('_v2_eager', 'v2_eager'))
  if not (always_skip_v1 or tf2.enabled()):
    params.append(('_v1_session', 'v1_session'))

  def single_method_decorator(f):
    """Decorator that constructs the test cases."""

    # Use named_parameters so it can be individually run from the command line
    @parameterized.named_parameters(*params)
    @functools.wraps(f)
    def decorated(self, run_mode, *args, **test_kwargs):
      """A run of a single test case w/ specified run mode."""
      if run_mode == 'v1_session':
        _v1_session_test(f, self, config, *args, **test_kwargs)
      elif run_mode == 'v2_eager':
        _v2_eager_test(f, self, *args, **test_kwargs)
      elif run_mode == 'v2_function':
        _v2_function_test(f, self, *args, **test_kwargs)
      else:
        # Raise rather than return the exception: a returned value is ignored
        # by the test runner, so the case would silently "pass" without ever
        # running the test body.
        raise ValueError('Unknown run mode %s' % run_mode)

    return decorated

  return _test_or_class_decorator(test_or_class, single_method_decorator)
418
419
def _v1_session_test(f, test_or_class, config, *args, **kwargs):
  """Runs `f` in the default graph, non-eagerly, inside a test session.

  `config` is forwarded to `test_or_class.test_session`. The context managers
  are entered left to right and exited in reverse order.
  """
  with ops.get_default_graph().as_default(), \
       testing_utils.run_eagerly_scope(False), \
       test_or_class.test_session(config=config):
    f(test_or_class, *args, **kwargs)
425
426
def _v2_eager_test(f, test_or_class, *args, **kwargs):
  """Runs `f` in eager mode with `run_eagerly_scope(True)` active."""
  with context.eager_mode(), testing_utils.run_eagerly_scope(True):
    f(test_or_class, *args, **kwargs)
431
432
def _v2_function_test(f, test_or_class, *args, **kwargs):
  """Runs `f` in eager mode with `run_eagerly_scope(False)` active."""
  with context.eager_mode(), testing_utils.run_eagerly_scope(False):
    f(test_or_class, *args, **kwargs)
437
438
439def _test_or_class_decorator(test_or_class, single_method_decorator):
440  """Decorate a test or class with a decorator intended for one method.
441
442  If the test_or_class is a class:
443    This will apply the decorator to all test methods in the class.
444
445  If the test_or_class is an iterable of already-parameterized test cases:
446    This will apply the decorator to all the cases, and then flatten the
447    resulting cross-product of test cases. This allows stacking the Keras
448    parameterized decorators w/ each other, and to apply them to test methods
449    that have already been marked with an absl parameterized decorator.
450
451  Otherwise, treat the obj as a single method and apply the decorator directly.
452
453  Args:
454    test_or_class: A test method (that may have already been decorated with a
455      parameterized decorator, or a test class that extends
456      keras_parameterized.TestCase
457    single_method_decorator:
458      A parameterized decorator intended for a single test method.
459  Returns:
460    The decorated result.
461  """
462  def _decorate_test_or_class(obj):
463    if isinstance(obj, collections_abc.Iterable):
464      return itertools.chain.from_iterable(
465          single_method_decorator(method) for method in obj)
466    if isinstance(obj, type):
467      cls = obj
468      for name, value in cls.__dict__.copy().items():
469        if callable(value) and name.startswith(
470            unittest.TestLoader.testMethodPrefix):
471          setattr(cls, name, single_method_decorator(value))
472
473      cls = type(cls).__new__(type(cls), cls.__name__, cls.__bases__,
474                              cls.__dict__.copy())
475      return cls
476
477    return single_method_decorator(obj)
478
479  if test_or_class is not None:
480    return _decorate_test_or_class(test_or_class)
481
482  return _decorate_test_or_class
483