1# Copyright 2016 The TensorFlow Authors. All Rights Reserved.
2#
3# Licensed under the Apache License, Version 2.0 (the "License");
4# you may not use this file except in compliance with the License.
5# You may obtain a copy of the License at
6#
7#     http://www.apache.org/licenses/LICENSE-2.0
8#
9# Unless required by applicable law or agreed to in writing, software
10# distributed under the License is distributed on an "AS IS" BASIS,
11# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
12# See the License for the specific language governing permissions and
13# limitations under the License.
14# ==============================================================================
15"""Utilities for unit-testing Keras."""
16
17from __future__ import absolute_import
18from __future__ import division
19from __future__ import print_function
20
21import collections
22import contextlib
23import functools
24import itertools
25import threading
26
27import numpy as np
28
29from tensorflow.python import tf2
30from tensorflow.python.eager import context
31from tensorflow.python.framework import config
32from tensorflow.python.framework import dtypes
33from tensorflow.python.framework import ops
34from tensorflow.python.framework import tensor_shape
35from tensorflow.python.framework import tensor_spec
36from tensorflow.python.framework import test_util
37from tensorflow.python.keras import backend
38from tensorflow.python.keras import layers
39from tensorflow.python.keras import models
40from tensorflow.python.keras.engine import base_layer_utils
41from tensorflow.python.keras.engine import keras_tensor
42from tensorflow.python.keras.optimizer_v2 import adadelta as adadelta_v2
43from tensorflow.python.keras.optimizer_v2 import adagrad as adagrad_v2
44from tensorflow.python.keras.optimizer_v2 import adam as adam_v2
45from tensorflow.python.keras.optimizer_v2 import adamax as adamax_v2
46from tensorflow.python.keras.optimizer_v2 import gradient_descent as gradient_descent_v2
47from tensorflow.python.keras.optimizer_v2 import nadam as nadam_v2
48from tensorflow.python.keras.optimizer_v2 import rmsprop as rmsprop_v2
49from tensorflow.python.keras.utils import tf_contextlib
50from tensorflow.python.keras.utils import tf_inspect
51from tensorflow.python.util import tf_decorator
52
53
def string_test(actual, expected):
  """Asserts that `actual` and `expected` are exactly equal element-wise.

  Args:
    actual: Array-like of observed values.
    expected: Array-like of reference values.

  Raises:
    AssertionError: if the two arrays are not equal.
  """
  check_equal = np.testing.assert_array_equal
  check_equal(actual, expected)
56
57
def numeric_test(actual, expected):
  """Asserts that `actual` is numerically close to `expected`.

  The comparison uses a relative tolerance of 1e-3 and an absolute tolerance
  of 1e-6.

  Args:
    actual: Array-like of observed values.
    expected: Array-like of reference values.

  Raises:
    AssertionError: if the values differ by more than the tolerances.
  """
  tolerances = dict(rtol=1e-3, atol=1e-6)
  np.testing.assert_allclose(actual, expected, **tolerances)
60
61
def get_test_data(train_samples,
                  test_samples,
                  input_shape,
                  num_classes,
                  random_seed=None):
  """Builds a synthetic classification dataset for model tests.

  One random template is drawn per class; every sample is its class template
  plus unit-variance Gaussian noise.

  Args:
    train_samples: Integer, how many training samples to generate.
    test_samples: Integer, how many test samples to generate.
    input_shape: Tuple of integers, shape of the inputs.
    num_classes: Integer, number of classes for the data and targets.
    random_seed: Integer, random seed used by numpy to generate data. When
      given, it reseeds numpy's global random state.

  Returns:
    A tuple of Numpy arrays: `(x_train, y_train), (x_test, y_test)`.
  """
  if random_seed is not None:
    np.random.seed(random_seed)

  total = train_samples + test_samples
  # The draw order (templates, labels, then noise sample by sample) determines
  # the generated values for a given seed.
  class_templates = 2 * num_classes * np.random.random(
      (num_classes,) + input_shape)
  labels = np.random.randint(0, num_classes, size=(total,))
  features = np.zeros((total,) + input_shape, dtype=np.float32)
  for row, label in enumerate(labels):
    noise = np.random.normal(loc=0, scale=1., size=input_shape)
    features[row] = class_templates[label] + noise

  train_split = (features[:train_samples], labels[:train_samples])
  test_split = (features[train_samples:], labels[train_samples:])
  return train_split, test_split
89
90
@test_util.disable_cudnn_autotune
def layer_test(layer_cls,
               kwargs=None,
               input_shape=None,
               input_dtype=None,
               input_data=None,
               expected_output=None,
               expected_output_dtype=None,
               expected_output_shape=None,
               validate_training=True,
               adapt_data=None,
               custom_objects=None,
               test_harness=None,
               supports_masking=None):
  """Test routine for a layer with a single input and single output.

  The layer is instantiated, run through a functional model and a Sequential
  model, round-tripped through `get_config`/`from_config`, and (optionally)
  trained for one batch. Shapes and dtypes reported by the layer are checked
  against what `predict` actually produces.

  Args:
    layer_cls: Layer class object.
    kwargs: Optional dictionary of keyword arguments for instantiating the
      layer.
    input_shape: Input shape tuple. Required when `input_data` is None;
      `None` entries are replaced by a random size in [1, 4) when random
      input data is generated.
    input_dtype: Data type of the input data. Defaults to 'float32' when
      data is generated, otherwise to the dtype of `input_data`.
    input_data: Numpy array of input data. Generated randomly if None.
    expected_output: Numpy array of the expected output.
    expected_output_dtype: Data type expected for the output. Defaults to
      `input_dtype`.
    expected_output_shape: Shape tuple for the expected shape of the output.
    validate_training: Whether to attempt to validate training on this layer.
      This might be set to False for non-differentiable layers that output
      string or integer values.
    adapt_data: Optional data for an 'adapt' call. If None, adapt() will not
      be tested for this layer. This is only relevant for PreprocessingLayers.
    custom_objects: Optional dictionary mapping name strings to custom objects
      in the layer class. This is helpful for testing custom layers.
    test_harness: The Tensorflow test, if any, that this function is being
      called in.
    supports_masking: Optional boolean to check the `supports_masking` property
      of the layer. If None, the check will not be performed.

  Returns:
    The output data (Numpy array) returned by the layer, for additional
    checks to be done by the calling code.

  Raises:
    ValueError: if both `input_data` and `input_shape` are None.
    AssertionError: if the layer's dtype, shape, `supports_masking` property
      or output values do not match expectations.
  """
  if input_data is None:
    if input_shape is None:
      raise ValueError('input_shape is None')
    if not input_dtype:
      input_dtype = 'float32'
    # Replace unknown dimensions with small random sizes so that concrete
    # data can be generated.
    input_data_shape = list(input_shape)
    for i, e in enumerate(input_data_shape):
      if e is None:
        input_data_shape[i] = np.random.randint(1, 4)
    input_data = 10 * np.random.random(input_data_shape)
    if input_dtype[:5] == 'float':
      input_data -= 0.5
    input_data = input_data.astype(input_dtype)
  elif input_shape is None:
    input_shape = input_data.shape
  if input_dtype is None:
    input_dtype = input_data.dtype
  if expected_output_dtype is None:
    expected_output_dtype = input_dtype

  # Strings need exact comparison; everything else is compared numerically.
  # The test harness' assertions are preferred when a harness is supplied.
  if dtypes.as_dtype(expected_output_dtype) == dtypes.string:
    if test_harness:
      assert_equal = test_harness.assertAllEqual
    else:
      assert_equal = string_test
  else:
    if test_harness:
      assert_equal = test_harness.assertAllClose
    else:
      assert_equal = numeric_test

  # instantiation
  kwargs = kwargs or {}
  layer = layer_cls(**kwargs)

  if (supports_masking is not None
      and layer.supports_masking != supports_masking):
    raise AssertionError(
        'When testing layer %s, the `supports_masking` property is %r '
        'but expected to be %r.\nFull kwargs: %s' %
        (layer_cls.__name__, layer.supports_masking, supports_masking, kwargs))

  # Test adapt, if data was passed.
  if adapt_data is not None:
    layer.adapt(adapt_data)

  # test get_weights , set_weights at layer level
  weights = layer.get_weights()
  layer.set_weights(weights)

  # test and instantiation from weights
  # NOTE(review): `getargspec` returns an ArgSpec tuple, so this membership
  # test compares 'weights' against the tuple's fields rather than against the
  # argument names; presumably `.args` was intended. Left unchanged because
  # enabling the branch would alter which layers get re-instantiated -- confirm
  # before fixing.
  if 'weights' in tf_inspect.getargspec(layer_cls.__init__):
    kwargs['weights'] = weights
    layer = layer_cls(**kwargs)

  # test in functional API
  x = layers.Input(shape=input_shape[1:], dtype=input_dtype)
  y = layer(x)
  if backend.dtype(y) != expected_output_dtype:
    raise AssertionError('When testing layer %s, for input %s, found output '
                         'dtype=%s but expected to find %s.\nFull kwargs: %s' %
                         (layer_cls.__name__, x, backend.dtype(y),
                          expected_output_dtype, kwargs))

  def assert_shapes_equal(expected, actual):
    """Asserts that the output shape from the layer matches the actual shape.

    Dimensions that are `None` in `expected` are treated as wildcards.
    """
    if len(expected) != len(actual):
      raise AssertionError(
          'When testing layer %s, for input %s, found output_shape='
          '%s but expected to find %s.\nFull kwargs: %s' %
          (layer_cls.__name__, x, actual, expected, kwargs))

    for expected_dim, actual_dim in zip(expected, actual):
      # Unwrap `Dimension` objects so plain ints/None are compared.
      if isinstance(expected_dim, tensor_shape.Dimension):
        expected_dim = expected_dim.value
      if isinstance(actual_dim, tensor_shape.Dimension):
        actual_dim = actual_dim.value
      if expected_dim is not None and expected_dim != actual_dim:
        raise AssertionError(
            'When testing layer %s, for input %s, found output_shape='
            '%s but expected to find %s.\nFull kwargs: %s' %
            (layer_cls.__name__, x, actual, expected, kwargs))

  if expected_output_shape is not None:
    assert_shapes_equal(tensor_shape.TensorShape(expected_output_shape),
                        y.shape)

  # check shape inference
  model = models.Model(x, y)
  computed_output_shape = tuple(
      layer.compute_output_shape(
          tensor_shape.TensorShape(input_shape)).as_list())
  computed_output_signature = layer.compute_output_signature(
      tensor_spec.TensorSpec(shape=input_shape, dtype=input_dtype))
  actual_output = model.predict(input_data)
  actual_output_shape = actual_output.shape
  assert_shapes_equal(computed_output_shape, actual_output_shape)
  assert_shapes_equal(computed_output_signature.shape, actual_output_shape)
  if computed_output_signature.dtype != actual_output.dtype:
    raise AssertionError(
        'When testing layer %s, for input %s, found output_dtype='
        '%s but expected to find %s.\nFull kwargs: %s' %
        (layer_cls.__name__, x, actual_output.dtype,
         computed_output_signature.dtype, kwargs))
  if expected_output is not None:
    assert_equal(actual_output, expected_output)

  # test serialization, weight setting at model level
  model_config = model.get_config()
  recovered_model = models.Model.from_config(model_config, custom_objects)
  if model.weights:
    weights = model.get_weights()
    recovered_model.set_weights(weights)
    output = recovered_model.predict(input_data)
    assert_equal(output, actual_output)

  # test training mode (e.g. useful for dropout tests)
  # Rebuild the model to avoid the graph being reused between predict() and
  # train_on_batch(). See b/120160788 for more details. This should be
  # mitigated after 2.0.
  layer_weights = layer.get_weights()  # Get the layer weights BEFORE training.
  if validate_training:
    model = models.Model(x, layer(x))
    # Only pass `run_eagerly` explicitly when a run_eagerly scope is active.
    if _thread_local_data.run_eagerly is not None:
      model.compile(
          'rmsprop',
          'mse',
          weighted_metrics=['acc'],
          run_eagerly=should_run_eagerly())
    else:
      model.compile('rmsprop', 'mse', weighted_metrics=['acc'])
    model.train_on_batch(input_data, actual_output)

  # test as first layer in Sequential API
  layer_config = layer.get_config()
  layer_config['batch_input_shape'] = input_shape
  layer = layer.__class__.from_config(layer_config)

  # Test adapt, if data was passed.
  if adapt_data is not None:
    layer.adapt(adapt_data)

  model = models.Sequential()
  model.add(layers.Input(shape=input_shape[1:], dtype=input_dtype))
  model.add(layer)

  # Restore the pre-training weights so outputs are comparable with
  # `expected_output`.
  layer.set_weights(layer_weights)
  actual_output = model.predict(input_data)
  actual_output_shape = actual_output.shape
  for expected_dim, actual_dim in zip(computed_output_shape,
                                      actual_output_shape):
    if expected_dim is not None:
      if expected_dim != actual_dim:
        raise AssertionError(
            'When testing layer %s **after deserialization**, '
            'for input %s, found output_shape='
            '%s but expected to find inferred shape %s.\nFull kwargs: %s' %
            (layer_cls.__name__,
             x,
             actual_output_shape,
             computed_output_shape,
             kwargs))
  if expected_output is not None:
    assert_equal(actual_output, expected_output)

  # test serialization, weight setting at model level
  model_config = model.get_config()
  recovered_model = models.Sequential.from_config(model_config, custom_objects)
  if model.weights:
    weights = model.get_weights()
    recovered_model.set_weights(weights)
    output = recovered_model.predict(input_data)
    assert_equal(output, actual_output)

  # for further checks in the caller function
  return actual_output
311
312
class _ThreadLocalData(threading.local):
  """Thread-local test configuration whose fields default to `None`.

  Attributes assigned on a plain `threading.local()` instance exist only in
  the thread that assigned them, so defaults set at import time would be
  missing (AttributeError) in any other thread. Declaring the defaults as
  class attributes makes them visible from every thread, while assignments
  made through the instance remain private to the assigning thread.
  """
  # Model type under test; None when no model type scope is active.
  model_type = None
  # Whether models should be compiled to run eagerly; None outside a scope.
  run_eagerly = None
  # Saved model format under test; None outside a scope.
  saved_model_format = None
  # Extra keyword arguments for the save function; None outside a scope.
  save_kwargs = None


_thread_local_data = _ThreadLocalData()
318
319
@tf_contextlib.contextmanager
def model_type_scope(value):
  """Context manager that sets the model type under test to `value`.

  On exit (including on error) the previously active model type is put back.

  Args:
     value: model type value

  Yields:
    The provided value.
  """
  saved_model_type = _thread_local_data.model_type
  _thread_local_data.model_type = value
  try:
    yield value
  finally:
    # Put back whatever model type was active before entering the scope.
    _thread_local_data.model_type = saved_model_type
339
340
@tf_contextlib.contextmanager
def run_eagerly_scope(value):
  """Context manager that sets whether models are compiled to run eagerly.

  On exit (including on error) the previously active setting is put back.

  Args:
     value: Bool specifying if we should run models eagerly in the active test.
     Should be True or False.

  Yields:
    The provided value.
  """
  saved_run_eagerly = _thread_local_data.run_eagerly
  _thread_local_data.run_eagerly = value
  try:
    yield value
  finally:
    # Put back the run_eagerly setting that was active before the scope.
    _thread_local_data.run_eagerly = saved_run_eagerly
361
362
@tf_contextlib.contextmanager
def use_keras_tensors_scope(value):
  """Context manager toggling KerasTensor usage in the functional API.

  It overwrites the module-level `keras_tensor._KERAS_TENSORS_ENABLED` flag
  for the duration of the scope and puts the previous value back on exit
  (including on error).

  Args:
     value: Bool specifying if we should build functional models
      using KerasTensors in the active test.
     Should be True or False.

  Yields:
    The provided value.
  """
  # pylint: disable=protected-access
  saved_flag = keras_tensor._KERAS_TENSORS_ENABLED
  keras_tensor._KERAS_TENSORS_ENABLED = value
  try:
    yield value
  finally:
    # Put back the KerasTensor flag that was active before the scope.
    keras_tensor._KERAS_TENSORS_ENABLED = saved_flag
  # pylint: enable=protected-access
384
385
def should_run_eagerly():
  """Returns whether the models we are testing should be run eagerly.

  The result is truthy only when the active scope requested eager running
  and TensorFlow is currently executing eagerly.

  Raises:
    ValueError: if no run_eagerly value has been set for the current thread.
  """
  requested = _thread_local_data.run_eagerly
  if requested is None:
    raise ValueError('Cannot call `should_run_eagerly()` outside of a '
                     '`run_eagerly_scope()` or `run_all_keras_modes` '
                     'decorator.')
  return requested and context.executing_eagerly()
394
395
@tf_contextlib.contextmanager
def saved_model_format_scope(value, **kwargs):
  """Provides a scope within which the saved model format to test is `value`.

  The saved model format and the save kwargs get restored to their original
  values upon exiting the scope.

  Args:
     value: saved model format value
     **kwargs: optional kwargs to pass to the save function.

  Yields:
    The provided value.
  """
  previous_format = _thread_local_data.saved_model_format
  previous_kwargs = _thread_local_data.save_kwargs
  try:
    _thread_local_data.saved_model_format = value
    _thread_local_data.save_kwargs = kwargs
    # Yield the value, as documented and as the sibling scopes do. Callers
    # that do not bind the result with `as` are unaffected.
    yield value
  finally:
    # Restore saved model format and save kwargs to their initial values.
    _thread_local_data.saved_model_format = previous_format
    _thread_local_data.save_kwargs = previous_kwargs
420
421
def get_save_format():
  """Gets the saved model format that should be tested.

  Raises:
    ValueError: if no saved model format has been set for the current thread.
  """
  save_format = _thread_local_data.saved_model_format
  if save_format is not None:
    return save_format
  raise ValueError(
      'Cannot call `get_save_format()` outside of a '
      '`saved_model_format_scope()` or `run_with_all_saved_model_formats` '
      'decorator.')
429
430
def get_save_kwargs():
  """Gets the keyword arguments to pass to the save function under test.

  Returns:
    The save kwargs of the active scope, or an empty dict if they are empty.

  Raises:
    ValueError: if no save kwargs have been set for the current thread.
  """
  save_kwargs = _thread_local_data.save_kwargs
  if save_kwargs is None:
    raise ValueError(
        'Cannot call `get_save_kwargs()` outside of a '
        '`saved_model_format_scope()` or `run_with_all_saved_model_formats` '
        'decorator.')
  return save_kwargs if save_kwargs else {}
438
439
def get_model_type():
  """Gets the model type that should be tested.

  Raises:
    ValueError: if no model type has been set for the current thread.
  """
  model_type = _thread_local_data.model_type
  if model_type is not None:
    return model_type
  raise ValueError('Cannot call `get_model_type()` outside of a '
                   '`model_type_scope()` or `run_with_all_model_types` '
                   'decorator.')
448
449
def get_small_sequential_mlp(num_hidden, num_classes, input_dim=None):
  """Builds a two-layer Sequential MLP.

  Args:
    num_hidden: Number of units in the hidden ReLU layer.
    num_classes: Number of output units. A single output uses a sigmoid
      activation, otherwise softmax is used.
    input_dim: Optional input dimension. When truthy it is passed to the
      first Dense layer; otherwise the first layer gets no input size.

  Returns:
    The Sequential model.
  """
  hidden_kwargs = {'activation': 'relu'}
  if input_dim:
    hidden_kwargs['input_dim'] = input_dim
  output_activation = 'sigmoid' if num_classes == 1 else 'softmax'

  model = models.Sequential()
  model.add(layers.Dense(num_hidden, **hidden_kwargs))
  model.add(layers.Dense(num_classes, activation=output_activation))
  return model
459
460
def get_small_functional_mlp(num_hidden, num_classes, input_dim):
  """Builds a two-layer MLP with the functional API.

  Args:
    num_hidden: Number of units in the hidden ReLU layer.
    num_classes: Number of output units. A single output uses a sigmoid
      activation, otherwise softmax is used.
    input_dim: Size of the (1-D) input.

  Returns:
    The functional model.
  """
  model_input = layers.Input(shape=(input_dim,))
  hidden = layers.Dense(num_hidden, activation='relu')(model_input)
  if num_classes == 1:
    output_activation = 'sigmoid'
  else:
    output_activation = 'softmax'
  model_output = layers.Dense(
      num_classes, activation=output_activation)(hidden)
  return models.Model(model_input, model_output)
467
468
class SmallSubclassMLP(models.Model):
  """Small MLP written as a subclassed Keras model.

  The model is a hidden ReLU Dense layer followed by an output Dense layer,
  with optional dropout and batch normalization applied (in that order) to
  the hidden activations.
  """

  def __init__(self, num_hidden, num_classes, use_bn=False, use_dp=False):
    """Creates the layers of the MLP.

    Args:
      num_hidden: Number of units in the hidden layer.
      num_classes: Number of output units. A single output uses a sigmoid
        activation, otherwise softmax is used.
      use_bn: Whether to batch-normalize the hidden activations.
      use_dp: Whether to apply dropout (rate 0.5) to the hidden activations.
    """
    super(SmallSubclassMLP, self).__init__(name='test_model')
    self.use_bn = use_bn
    self.use_dp = use_dp

    output_activation = 'sigmoid' if num_classes == 1 else 'softmax'
    self.layer_a = layers.Dense(num_hidden, activation='relu')
    self.layer_b = layers.Dense(num_classes, activation=output_activation)
    # The optional layers are only created when requested.
    if use_dp:
      self.dp = layers.Dropout(0.5)
    if use_bn:
      self.bn = layers.BatchNormalization(axis=-1)

  def call(self, inputs, **kwargs):
    """Runs the forward pass: hidden -> [dropout] -> [batch norm] -> output."""
    hidden = self.layer_a(inputs)
    if self.use_dp:
      hidden = self.dp(hidden)
    if self.use_bn:
      hidden = self.bn(hidden)
    return self.layer_b(hidden)
492
493
class _SmallSubclassMLPCustomBuild(models.Model):
  """Small subclassed MLP whose layers are created lazily in `build`."""

  def __init__(self, num_hidden, num_classes):
    """Stores the layer sizes; the layers themselves are created in `build`.

    Args:
      num_hidden: Number of units in the hidden layer.
      num_classes: Number of output units.
    """
    super(_SmallSubclassMLPCustomBuild, self).__init__()
    self.layer_a = None
    self.layer_b = None
    self.num_hidden = num_hidden
    self.num_classes = num_classes

  def build(self, input_shape):
    """Creates the hidden and output Dense layers."""
    if self.num_classes == 1:
      output_activation = 'sigmoid'
    else:
      output_activation = 'softmax'
    self.layer_a = layers.Dense(self.num_hidden, activation='relu')
    self.layer_b = layers.Dense(self.num_classes, activation=output_activation)

  def call(self, inputs, **kwargs):
    """Applies the hidden layer followed by the output layer."""
    hidden = self.layer_a(inputs)
    return self.layer_b(hidden)
512
513
def get_small_subclass_mlp(num_hidden, num_classes):
  """Returns a `SmallSubclassMLP` with the given hidden and output sizes."""
  model = SmallSubclassMLP(num_hidden, num_classes)
  return model
516
517
def get_small_subclass_mlp_with_custom_build(num_hidden, num_classes):
  """Returns a `_SmallSubclassMLPCustomBuild` with the given layer sizes."""
  model = _SmallSubclassMLPCustomBuild(num_hidden, num_classes)
  return model
520
521
def get_small_mlp(num_hidden, num_classes, input_dim):
  """Get a small mlp of the model type specified by `get_model_type`.

  Args:
    num_hidden: Number of units in the hidden layer.
    num_classes: Number of output units.
    input_dim: Input dimension; only used by the sequential and functional
      variants.

  Returns:
    A Keras model of the active model type.

  Raises:
    ValueError: if the active model type is not recognized.
  """
  kind = get_model_type()
  if kind == 'subclass':
    model = get_small_subclass_mlp(num_hidden, num_classes)
  elif kind == 'subclass_custom_build':
    model = get_small_subclass_mlp_with_custom_build(num_hidden, num_classes)
  elif kind == 'sequential':
    model = get_small_sequential_mlp(num_hidden, num_classes, input_dim)
  elif kind == 'functional':
    model = get_small_functional_mlp(num_hidden, num_classes, input_dim)
  else:
    raise ValueError('Unknown model type {}'.format(kind))
  return model
534
535
class _SubclassModel(models.Model):
  """A Keras subclass model that chains a given list of layers."""

  def __init__(self, model_layers, *args, **kwargs):
    """Instantiate a model.

    Args:
      model_layers: a list of layers to be added to the model.
      *args: Model's args
      **kwargs: Model's keyword args, at most one of input_tensor -> the input
        tensor required for ragged/sparse input.
    """
    # `input_tensor` is consumed here and not forwarded to the base class.
    input_tensor = kwargs.pop('input_tensor', None)
    super(_SubclassModel, self).__init__(*args, **kwargs)

    # Note that clone and build doesn't support lists of layers in subclassed
    # models, so every layer is attached as its own attribute instead.
    for index, layer in enumerate(model_layers):
      setattr(self, self._layer_name_for_i(index), layer)
    self.num_layers = len(model_layers)

    if input_tensor is not None:
      self._set_inputs(input_tensor)

  def _layer_name_for_i(self, i):
    """Returns the attribute name under which layer number `i` is stored."""
    return 'layer{}'.format(i)

  def call(self, inputs, **kwargs):
    """Applies the stored layers to `inputs` in index order."""
    outputs = inputs
    for index in range(self.num_layers):
      current_layer = getattr(self, self._layer_name_for_i(index))
      outputs = current_layer(outputs)
    return outputs
570
571
class _SubclassModelCustomBuild(models.Model):
  """A Keras subclass model whose layers are obtained lazily in `build`."""

  def __init__(self, layer_generating_func, *args, **kwargs):
    """Stores the layer factory; layers are materialized in `build`.

    Args:
      layer_generating_func: Zero-argument callable returning an iterable of
        layers.
      *args: Model's args.
      **kwargs: Model's keyword args.
    """
    super(_SubclassModelCustomBuild, self).__init__(*args, **kwargs)
    self.all_layers = None
    self._layer_generating_func = layer_generating_func

  def build(self, input_shape):
    """Collects the generated layers into a fresh list on the model."""
    self.all_layers = list(self._layer_generating_func())

  def call(self, inputs, **kwargs):
    """Applies the collected layers to `inputs` in order."""
    outputs = inputs
    for current_layer in self.all_layers:
      outputs = current_layer(outputs)
    return outputs
591
592
def get_model_from_layers(model_layers,
                          input_shape=None,
                          input_dtype=None,
                          name=None,
                          input_ragged=None,
                          input_sparse=None):
  """Builds a model of the active model type from a sequence of layers.

  Args:
    model_layers: The layers used to build the network.
    input_shape: Shape tuple of the input or 'TensorShape' instance.
    input_dtype: Datatype of the input.
    name: Name for the model.
    input_ragged: Boolean, whether the input data is a ragged tensor.
    input_sparse: Boolean, whether the input data is a sparse tensor.

  Returns:
    A Keras model.

  Raises:
    ValueError: if the model type is unknown, or if a functional model is
      requested without an input shape.
  """
  kind = get_model_type()
  # Keyword arguments shared by every input placeholder created below.
  input_options = dict(
      dtype=input_dtype, ragged=input_ragged, sparse=input_sparse)

  if kind == 'subclass':
    # An explicit input tensor is only created for ragged/sparse inputs.
    input_tensor = None
    if input_ragged or input_sparse:
      input_tensor = layers.Input(shape=input_shape, **input_options)
    return _SubclassModel(model_layers, name=name, input_tensor=input_tensor)

  if kind == 'subclass_custom_build':

    def layer_generating_func():
      return model_layers

    return _SubclassModelCustomBuild(layer_generating_func, name=name)

  if kind == 'sequential':
    sequential_model = models.Sequential(name=name)
    if input_shape:
      sequential_model.add(
          layers.InputLayer(input_shape=input_shape, **input_options))
    for layer in model_layers:
      sequential_model.add(layer)
    return sequential_model

  if kind == 'functional':
    if not input_shape:
      raise ValueError('Cannot create a functional model from layers with no '
                       'input shape.')
    model_input = layers.Input(shape=input_shape, **input_options)
    model_output = model_input
    for layer in model_layers:
      model_output = layer(model_output)
    return models.Model(model_input, model_output, name=name)

  raise ValueError('Unknown model type {}'.format(kind))
656
657
class Bias(layers.Layer):
  """Layer that adds a single zero-initialized bias weight to its inputs."""

  def build(self, input_shape):
    """Creates the `bias` weight of shape `(1,)`, initialized to zeros."""
    # `add_weight` is used in place of the deprecated `add_variable` alias.
    self.bias = self.add_weight('bias', (1,), initializer='zeros')

  def call(self, inputs):
    """Returns `inputs` with the bias added."""
    return inputs + self.bias
665
666
class _MultiIOSubclassModel(models.Model):
  """Multi IO Keras subclass model made of two layer branches.

  An optional shared input branch feeds both branches from a single input,
  and an optional shared output branch merges the two branch outputs.
  """

  def __init__(self, branch_a, branch_b, shared_input_branch=None,
               shared_output_branch=None, name=None):
    """Stores the layer sequences that make up the model.

    Args:
      branch_a: Sequence of layers for branch a.
      branch_b: Sequence of layers for branch b.
      shared_input_branch: Optional sequence of layers applied to the single
        input before both branches.
      shared_output_branch: Optional sequence of layers applied to the list of
        both branch outputs.
      name: Optional model name.
    """
    super(_MultiIOSubclassModel, self).__init__(name=name)
    self._shared_input_branch = shared_input_branch
    self._branch_a = branch_a
    self._branch_b = branch_b
    self._shared_output_branch = shared_output_branch

  def call(self, inputs, **kwargs):
    """Runs both branches and, if configured, the shared branches."""
    if self._shared_input_branch:
      shared = inputs
      for shared_layer in self._shared_input_branch:
        shared = shared_layer(shared)
      a = b = shared
    elif isinstance(inputs, dict):
      # Dict inputs are looked up by these fixed key names.
      a, b = inputs['input_1'], inputs['input_2']
    else:
      a, b = inputs

    for branch_layer in self._branch_a:
      a = branch_layer(a)
    for branch_layer in self._branch_b:
      b = branch_layer(b)

    result = [a, b]
    if self._shared_output_branch:
      for merge_layer in self._shared_output_branch:
        result = merge_layer(result)
    return result
701
702
class _MultiIOSubclassModelCustomBuild(models.Model):
  """Multi IO Keras subclass model that uses a custom build method.

  The layer sequences are supplied as zero-argument callables and are only
  materialized when `build` runs.
  """

  def __init__(self, branch_a_func, branch_b_func,
               shared_input_branch_func=None,
               shared_output_branch_func=None):
    """Stores the layer-generating callables.

    Args:
      branch_a_func: Callable returning the layers of branch a.
      branch_b_func: Callable returning the layers of branch b.
      shared_input_branch_func: Optional callable returning the layers applied
        to the single input before both branches, or a falsy value for none.
      shared_output_branch_func: Optional callable returning the layers that
        merge both branch outputs, or a falsy value for none.
    """
    super(_MultiIOSubclassModelCustomBuild, self).__init__()
    self._shared_input_branch_func = shared_input_branch_func
    self._branch_a_func = branch_a_func
    self._branch_b_func = branch_b_func
    self._shared_output_branch_func = shared_output_branch_func

    self._shared_input_branch = None
    self._branch_a = None
    self._branch_b = None
    self._shared_output_branch = None

  def build(self, input_shape):
    """Materializes the layer sequences from the generating callables."""
    # The shared-branch callables default to None, so guard before calling
    # them. Each callable is invoked once so that a generator producing fresh
    # layers is not run a second time just to test truthiness.
    if self._shared_input_branch_func is not None:
      shared_input_branch = self._shared_input_branch_func()
      if shared_input_branch:
        self._shared_input_branch = shared_input_branch
    self._branch_a = self._branch_a_func()
    self._branch_b = self._branch_b_func()

    if self._shared_output_branch_func is not None:
      shared_output_branch = self._shared_output_branch_func()
      if shared_output_branch:
        self._shared_output_branch = shared_output_branch

  def call(self, inputs, **kwargs):
    """Runs both branches and, if configured, the shared branches."""
    if self._shared_input_branch:
      for layer in self._shared_input_branch:
        inputs = layer(inputs)
      a = inputs
      b = inputs
    else:
      a, b = inputs

    for layer in self._branch_a:
      a = layer(a)
    for layer in self._branch_b:
      b = layer(b)
    outs = a, b

    if self._shared_output_branch:
      for layer in self._shared_output_branch:
        outs = layer(outs)

    return outs
749
750
def get_multi_io_model(
    branch_a,
    branch_b,
    shared_input_branch=None,
    shared_output_branch=None):
  """Builds a two-branch multi-io model of the active model type.

  The kind of model produced is the one reported by `get_model_type`.

  Two inputs, two outputs:
    Pass a list of layers for branch a and for branch b and leave both shared
    branches unset. Each branch is applied to its own input and yields its
    own output.

    The first entry of branch_a must be the Keras 'Input' layer for branch a,
    and the first entry of branch_b must be the Keras 'Input' layer for
    branch b.

    example usage:
    ```
    branch_a = [Input(shape=(2,), name='a'), Dense(), Dense()]
    branch_b = [Input(shape=(3,), name='b'), Dense(), Dense()]

    model = get_multi_io_model(branch_a, branch_b)
    ```

  Two inputs, one output:
    Pass a list of layers for branch a and for branch b, plus a shared output
    branch. Each branch is applied to its own input; the shared output branch
    is then applied to a tuple holding both intermediate results and produces
    the single output. The first layer of shared_output_branch must be able
    to merge a tuple of two tensors.

    The first entry of branch_a must be the Keras 'Input' layer for branch a,
    and the first entry of branch_b must be the Keras 'Input' layer for
    branch b.

    example usage:
    ```
    input_branch_a = [Input(shape=(2,), name='a'), Dense(), Dense()]
    input_branch_b = [Input(shape=(3,), name='b'), Dense(), Dense()]
    shared_output_branch = [Concatenate(), Dense(), Dense()]

    model = get_multi_io_model(input_branch_a, input_branch_b,
                               shared_output_branch=shared_output_branch)
    ```

  One input, two outputs:
    Pass a list of layers for branch a and for branch b, plus a shared input
    branch. The model takes a single input and runs the shared input branch
    on it; both branches are then applied to that intermediate result, giving
    two outputs.

    The first entry of shared_input_branch must be the Keras 'Input' layer
    for the whole model. Branch a and branch b should not contain any Input
    layers.

    example usage:
    ```
    shared_input_branch = [Input(shape=(2,), name='in'), Dense(), Dense()]
    output_branch_a = [Dense(), Dense()]
    output_branch_b = [Dense(), Dense()]

    model = get_multi_io_model(output_branch_a, output_branch_b,
                               shared_input_branch=shared_input_branch)
    ```

  Args:
    branch_a: A sequence of layers for branch a of the model.
    branch_b: A sequence of layers for branch b of the model.
    shared_input_branch: An optional sequence of layers to apply to a single
      input, before applying both branches to that intermediate result. If set,
      the model will take only one input instead of two. Defaults to None.
    shared_output_branch: An optional sequence of layers to merge the
      intermediate results produced by branch a and branch b. If set,
      the model will produce only one output instead of two. Defaults to None.

  Returns:
    A multi-io model of the type specified by `get_model_type`, specified
    by the different branches.

  Raises:
    ValueError: if the active model type is 'sequential' or is unknown.
  """
  # Peel the functional Input placeholders off the front of the layer lists.
  if shared_input_branch:
    inputs = shared_input_branch[0]
    shared_input_branch = shared_input_branch[1:]
  else:
    inputs = (branch_a[0], branch_b[0])
    branch_a, branch_b = branch_a[1:], branch_b[1:]

  kind = get_model_type()

  if kind == 'subclass':
    return _MultiIOSubclassModel(branch_a, branch_b, shared_input_branch,
                                 shared_output_branch)

  if kind == 'subclass_custom_build':
    return _MultiIOSubclassModelCustomBuild((lambda: branch_a),
                                            (lambda: branch_b),
                                            (lambda: shared_input_branch),
                                            (lambda: shared_output_branch))

  if kind == 'sequential':
    raise ValueError('Cannot use `get_multi_io_model` to construct '
                     'sequential models')

  if kind == 'functional':
    if shared_input_branch:
      # Both branches start from the output of the shared input branch.
      shared = inputs
      for shared_layer in shared_input_branch:
        shared = shared_layer(shared)
      a = b = shared
    else:
      a, b = inputs

    for branch_layer in branch_a:
      a = branch_layer(a)
    for branch_layer in branch_b:
      b = branch_layer(b)

    outputs = (a, b)
    if shared_output_branch:
      for merge_layer in shared_output_branch:
        outputs = merge_layer(outputs)

    return models.Model(inputs, outputs)

  raise ValueError('Unknown model type {}'.format(kind))
880
881
# Lookup table from optimizer string name to its Keras v2 optimizer class.
# `get_v2_optimizer` lists these keys in its error message.
_V2_OPTIMIZER_MAP = {
    'adadelta': adadelta_v2.Adadelta,
    'adagrad': adagrad_v2.Adagrad,
    'adam': adam_v2.Adam,
    'adamax': adamax_v2.Adamax,
    'nadam': nadam_v2.Nadam,
    'rmsprop': rmsprop_v2.RMSprop,
    'sgd': gradient_descent_v2.SGD,
}
891
892
def get_v2_optimizer(name, **kwargs):
  """Get the v2 optimizer requested.

  This is only necessary until v2 are the default, as we are testing in Eager,
  and Eager + v1 optimizers fail tests. When we are in v2, the strings alone
  should be sufficient, and this mapping can theoretically be removed.

  Args:
    name: string name of Keras v2 optimizer; must be a key of
      `_V2_OPTIMIZER_MAP`.
    **kwargs: any kwargs to pass to the optimizer constructor.

  Returns:
    Initialized Keras v2 optimizer.

  Raises:
    ValueError: if an unknown name was passed.
  """
  # Only the lookup is guarded. A KeyError raised while constructing the
  # optimizer must propagate as-is instead of being misreported as an unknown
  # optimizer name.
  try:
    optimizer_cls = _V2_OPTIMIZER_MAP[name]
  except KeyError:
    raise ValueError(
        'Could not find requested v2 optimizer: {}\nValid choices: {}'.format(
            name, list(_V2_OPTIMIZER_MAP.keys())))
  return optimizer_cls(**kwargs)
916
917
def get_expected_metric_variable_names(var_names, name_suffix=''):
  """Builds the variable names a metric is expected to expose.

  Args:
    var_names: Iterable of base variable names (without the ':0' suffix).
    name_suffix: Suffix inserted before ':0'. Only used when neither TF2
      behavior nor eager execution is active.

  Returns:
    A list with one '<name>[<name_suffix>]:0' string per entry of `var_names`.
  """
  # Only V1 graph mode makes variable names unique with a suffix; in V2 and in
  # V1 eager mode the base names are used unchanged.
  names_are_uniquified = not (tf2.enabled() or context.executing_eagerly())
  suffix = name_suffix if names_are_uniquified else ''
  return [name + suffix + ':0' for name in var_names]
925
926
def enable_v2_dtype_behavior(fn):
  """Decorator that runs the wrapped test with layer V2 dtype behavior on."""
  return _set_v2_dtype_behavior(fn, enabled=True)
930
931
def disable_v2_dtype_behavior(fn):
  """Decorator that runs the wrapped test with layer V2 dtype behavior off."""
  return _set_v2_dtype_behavior(fn, enabled=False)
935
936
def _set_v2_dtype_behavior(fn, enabled):
  """Wraps `fn` so it runs with `V2_DTYPE_BEHAVIOR` forced to `enabled`.

  Args:
    fn: The callable to wrap.
    enabled: Value assigned to `base_layer_utils.V2_DTYPE_BEHAVIOR` for the
      duration of each call to `fn`.

  Returns:
    The result of `tf_decorator.make_decorator` applied to `fn` and the
    wrapping function.
  """

  @functools.wraps(fn)
  def wrapped_fn(*args, **kwargs):
    saved_behavior = base_layer_utils.V2_DTYPE_BEHAVIOR
    base_layer_utils.V2_DTYPE_BEHAVIOR = enabled
    try:
      return fn(*args, **kwargs)
    finally:
      # Put back whatever was set before the call, even if `fn` raised.
      base_layer_utils.V2_DTYPE_BEHAVIOR = saved_behavior

  return tf_decorator.make_decorator(fn, wrapped_fn)
949
950
@contextlib.contextmanager
def device(should_use_gpu):
  """Context manager that places ops on GPU:0 or CPU:0.

  Args:
    should_use_gpu: Whether the GPU is requested. The GPU is only selected
      when this is truthy and `test_util.is_gpu_available()` reports one;
      otherwise the CPU is used.

  Yields:
    Nothing; the body of the `with` statement runs under `ops.device`.
  """
  gpu_requested_and_present = should_use_gpu and test_util.is_gpu_available()
  target = '/device:GPU:0' if gpu_requested_and_present else '/device:CPU:0'
  with ops.device(target):
    yield
960
961
@contextlib.contextmanager
def use_gpu():
  """Context manager that requests the GPU via `device`.

  Equivalent to `device(should_use_gpu=True)`.
  """
  with device(True):
    yield
967
968
969def for_all_test_methods(decorator, *args, **kwargs):
970  """Generate class-level decorator from given method-level decorator.
971
972  It is expected for the given decorator to take some arguments and return
973  a method that is then called on the test method to produce a decorated
974  method.
975
976  Args:
977    decorator: The decorator to apply.
978    *args: Positional arguments
979    **kwargs: Keyword arguments
980  Returns: Function that will decorate a given classes test methods with the
981    decorator.
982  """
983
984  def all_test_methods_impl(cls):
985    """Apply decorator to all test methods in class."""
986    for name in dir(cls):
987      value = getattr(cls, name)
988      if callable(value) and name.startswith('test') and (name !=
989                                                          'test_session'):
990        setattr(cls, name, decorator(*args, **kwargs)(value))
991    return cls
992
993  return all_test_methods_impl
994
995
996# The description is just for documentation purposes.
997def run_without_tensor_float_32(description):  # pylint: disable=unused-argument
998  """Execute test with TensorFloat-32 disabled.
999
1000  While almost every real-world deep learning model runs fine with
1001  TensorFloat-32, many tests use assertAllClose or similar methods.
1002  TensorFloat-32 matmuls typically will cause such methods to fail with the
1003  default tolerances.
1004
1005  Args:
1006    description: A description used for documentation purposes, describing why
1007      the test requires TensorFloat-32 to be disabled.
1008
1009  Returns:
1010    Decorator which runs a test with TensorFloat-32 disabled.
1011  """
1012
1013  def decorator(f):
1014
1015    @functools.wraps(f)
1016    def decorated(self, *args, **kwargs):
1017      allowed = config.tensor_float_32_execution_enabled()
1018      try:
1019        config.enable_tensor_float_32_execution(False)
1020        f(self, *args, **kwargs)
1021      finally:
1022        config.enable_tensor_float_32_execution(allowed)
1023
1024    return decorated
1025
1026  return decorator
1027
1028
# The description is just for documentation purposes.
def run_all_without_tensor_float_32(description):  # pylint: disable=unused-argument
  """Class decorator: run every test method with TensorFloat-32 disabled.

  Args:
    description: Forwarded to `run_without_tensor_float_32` for each test
      method.

  Returns:
    The class-level decorator built by `for_all_test_methods`.
  """
  class_decorator = for_all_test_methods(run_without_tensor_float_32,
                                         description)
  return class_decorator
1033
1034
1035def run_v2_only(func=None):
1036  """Execute the decorated test only if running in v2 mode.
1037
1038  This function is intended to be applied to tests that exercise v2 only
1039  functionality. If the test is run in v1 mode it will simply be skipped.
1040
1041  See go/tf-test-decorator-cheatsheet for the decorators to use in different
1042  v1/v2/eager/graph combinations.
1043
1044  Args:
1045    func: function to be annotated. If `func` is None, this method returns a
1046      decorator the can be applied to a function. If `func` is not None this
1047      returns the decorator applied to `func`.
1048
1049  Returns:
1050    Returns a decorator that will conditionally skip the decorated test method.
1051  """
1052
1053  def decorator(f):
1054    if tf_inspect.isclass(f):
1055      raise ValueError('`run_v2_only` only supports test methods.')
1056
1057    def decorated(self, *args, **kwargs):
1058      if not tf2.enabled():
1059        self.skipTest('Test is only compatible with v2')
1060
1061      return f(self, *args, **kwargs)
1062
1063    return decorated
1064
1065  if func is not None:
1066    return decorator(func)
1067
1068  return decorator
1069
1070
def generate_combinations_with_testcase_name(**kwargs):
  """Builds named parameter combinations for parameterized tests.

  Takes the cartesian product of all option values and attaches a
  'testcase_name' entry to each resulting combination, as required by named
  parameterized tests.

  Args:
    **kwargs: keyword arguments of form `option=[possibilities, ...]` or
      `option=the_only_possibility`. A value that is not a `list` is treated
      as the single possibility for that option.

  Returns:
    A list of `OrderedDict`s, one per combination. Each maps every option name
    (in sorted order) to one of its values, and 'testcase_name' to a string
    built from the alphanumeric characters of the option names and values.
  """

  def _alnum_only(text):
    """Drops every non-alphanumeric character from `text`."""
    return ''.join([ch for ch in text if str.isalnum(ch)])

  # One list of (option, value) pairs per option, with options in sorted order
  # so that the generated names are deterministic.
  per_option_pairs = []
  for option in sorted(kwargs):
    possibilities = kwargs[option]
    if not isinstance(possibilities, list):
      possibilities = [possibilities]
    per_option_pairs.append([(option, choice) for choice in possibilities])

  named_combinations = []
  for pairs in itertools.product(*per_option_pairs):
    name_suffix = ''.join(
        '_{}_{}'.format(_alnum_only(option), _alnum_only(str(choice)))
        for option, choice in pairs)
    named = collections.OrderedDict(pairs)
    named['testcase_name'] = '_test{}'.format(name_suffix)
    named_combinations.append(named)

  return named_combinations
1110