1# Copyright 2019 The TensorFlow Authors. All Rights Reserved.
2#
3# Licensed under the Apache License, Version 2.0 (the "License");
4# you may not use this file except in compliance with the License.
5# You may obtain a copy of the License at
6#
7#     http://www.apache.org/licenses/LICENSE-2.0
8#
9# Unless required by applicable law or agreed to in writing, software
10# distributed under the License is distributed on an "AS IS" BASIS,
11# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
12# See the License for the specific language governing permissions and
13# limitations under the License.
14# ==============================================================================
15# pylint: disable=protected-access
16"""Tests reviving models from config and SavedModel.
17
These tests ensure that a model revived from a combination of config and
SavedModel has the expected structure.
20"""
# TODO(kathywu): Move relevant tests from saved_model_test to this file.
22
23from __future__ import absolute_import
24from __future__ import division
25from __future__ import print_function
26
27import shutil
28
29from absl.testing import parameterized
30import numpy as np
31
32from tensorflow.python import keras
33from tensorflow.python.framework import constant_op
34from tensorflow.python.framework import dtypes
35from tensorflow.python.framework import ops
36from tensorflow.python.framework import sparse_tensor
37from tensorflow.python.keras import backend
38from tensorflow.python.keras import keras_parameterized
39from tensorflow.python.keras import testing_utils
40from tensorflow.python.keras.saving.saved_model import load as keras_load
41from tensorflow.python.keras.utils import generic_utils
42from tensorflow.python.ops import array_ops
43from tensorflow.python.ops import math_ops
44from tensorflow.python.ops import sparse_ops
45from tensorflow.python.ops import string_ops
46from tensorflow.python.ops import variables
47from tensorflow.python.platform import test
48
49
class SubclassedModelNoConfig(keras.Model):
  """Subclassed model without `get_config` that chains custom layers.

  The layer stack is appended to `all_layers` in `build()`; `call()` pipes the
  inputs through every entry of `all_layers` in order.
  """

  def __init__(self, a, b):
    super(SubclassedModelNoConfig, self).__init__()

    self.a = a
    self.b = b
    self.shared = CustomLayerNoConfig(a, b)
    self.all_layers = []

  def build(self, input_shape):
    """Populates `all_layers`, then runs the base class build."""
    layer_with_config = CustomLayerWithConfig(self.a + 1, self.b + 2)
    layer_no_config = CustomLayerNoConfig(self.a + 3, self.b + 4)
    # TODO(b/145029112): Bug with losses when there are shared layers.
    # `self.shared` should become the first entry of this Sequential once the
    # bug is fixed.
    nested_sequential = keras.Sequential(
        [CustomLayerNoConfig(self.a + 5, self.b + 6)])
    self.all_layers.extend([
        self.shared,
        layer_with_config,
        layer_no_config,
        nested_sequential,
    ])
    super(SubclassedModelNoConfig, self).build(input_shape)

  def call(self, inputs):
    """Feeds `inputs` through each layer of `all_layers` in sequence."""
    outputs = inputs
    for sublayer in self.all_layers:
      outputs = sublayer(outputs)
    return outputs
76
77
class SparseDense(keras.layers.Dense):
  """Dense layer whose `call` consumes a sparse input tensor."""

  def call(self, inputs):
    """Applies kernel, bias and activation to sparse `inputs`.

    The input is sparse-reshaped to two dimensions and multiplied with the
    kernel; the activated result is reshaped so that it keeps every input
    dimension except the last, which becomes `kernel.shape[1]`.
    """
    leading_dims = array_ops.shape(inputs)[:-1]
    flat_shape = array_ops.stack(
        (math_ops.reduce_prod(leading_dims), self.kernel.shape[0]))
    result_shape = array_ops.concat(
        (leading_dims, [self.kernel.shape[1]]), -1)
    flat_inputs = sparse_ops.sparse_reshape(inputs, flat_shape)
    projected = sparse_ops.sparse_tensor_dense_matmul(flat_inputs, self.kernel)
    activated = self.activation(projected + self.bias)
    return array_ops.reshape(activated, result_shape)
91
92
class SubclassedSparseModelNoConfig(keras.Model):
  """Subclassed model without `get_config` that begins with `SparseDense`."""

  def __init__(self, a, b):
    super(SubclassedSparseModelNoConfig, self).__init__()
    self.a = a
    self.shared = CustomLayerNoConfig(a, b)
    self.all_layers = [SparseDense(4)]

  def call(self, inputs):
    """Runs `inputs` through `all_layers`, adds `a`, then applies `shared`."""
    hidden = inputs
    for sublayer in self.all_layers:
      hidden = sublayer(hidden)
    shifted = hidden + self.a
    return self.shared(shifted)
106
107
class SubclassedModelWithConfig(SubclassedModelNoConfig):
  """`SubclassedModelNoConfig` extended with a config round trip."""

  def get_config(self):
    """Returns the constructor arguments `a` and `b` as a dict."""
    return dict(a=self.a, b=self.b)

  @classmethod
  def from_config(cls, config):
    """Instantiates the model, passing the config entries as keyword args."""
    model = cls(**config)
    return model
117
118
class CustomLayerNoConfig(keras.layers.Layer):
  """Layer without `get_config` that owns variables, losses and metrics."""

  def __init__(self, a, b, name=None):
    super(CustomLayerNoConfig, self).__init__(name=name)
    self.a = variables.Variable(a, name='a')
    self.b = b

    def _twice_a():
      # Evaluated lazily, so it reads the value of `self.a` at call time.
      return self.a * 2

    self.add_loss(_twice_a)
    self.sum_metric = keras.metrics.Sum(name='inputs_sum')
    self.unused_metric = keras.metrics.Sum(name='not_added_to_metrics')

  def build(self, input_shape):
    """Creates `c`, a variable of ones with shape `input_shape[1:]`."""
    ones = constant_op.constant(1.0, shape=input_shape[1:])
    self.c = variables.Variable(ones, name='{}_c'.format(self.name))

  def call(self, inputs):
    """Records an input-dependent loss and two metrics, then adds `c`."""
    input_total = math_ops.reduce_sum(inputs)
    self.add_loss(input_total, inputs=inputs)
    self.add_metric(self.sum_metric(inputs))
    self.add_metric(inputs, aggregation='mean', name='mean')

    return inputs + self.c
141
142
class CustomLayerWithConfig(CustomLayerNoConfig):
  """`CustomLayerNoConfig` variant that exposes a `get_config`."""

  def get_config(self):
    """Returns the current value of `a`, plus `b` and the layer name."""
    a_value = backend.get_value(self.a)
    return dict(a=a_value, b=self.b, name=self.name)
149
150
class CustomNetworkDefaultConfig(keras.Model):
  """Network wired as `Input((2, 3))` -> `Flatten` -> `Dense(num_classes)`."""

  def __init__(self, num_classes, name=None):
    network_input = keras.Input((2, 3), name='inputs')
    flattened = keras.layers.Flatten(name='flatten')(network_input)
    network_output = keras.layers.Dense(
        num_classes, name='outputs')(flattened)
    super(CustomNetworkDefaultConfig, self).__init__(
        network_input, network_output, name=name)
158
159
class CustomNetworkWithConfig(CustomNetworkDefaultConfig):
  """Network that records its constructor arguments for `get_config`."""

  def __init__(self, num_classes, name=None):
    super(CustomNetworkWithConfig, self).__init__(num_classes, name=name)
    self._config_dict = dict(num_classes=num_classes)

  def get_config(self):
    """Returns a copy of the stored config.

    A fresh dict is handed out so that a caller mutating the result cannot
    alter the config this instance keeps in `_config_dict`.
    """
    return dict(self._config_dict)

  @classmethod
  def from_config(cls, config):
    """Rebuilds the network from `config`; the 'name' entry is optional."""
    return cls(config['num_classes'], name=config.get('name'))
172
173
class CustomNetworkWithConfigName(CustomNetworkWithConfig):
  """`CustomNetworkWithConfig` whose stored config also holds the name."""

  def __init__(self, num_classes, name=None):
    super(CustomNetworkWithConfigName, self).__init__(num_classes, name=name)
    # `self.name` is read after the base constructor has run.
    self._config_dict.update(name=self.name)
179
180
class UnregisteredCustomSequentialModel(keras.Sequential):
  """Sequential subclass that starts with a (2, 3) `InputLayer`.

  This class is *not* registered in the CustomObjectScope.
  """

  def __init__(self, **kwargs):
    super(UnregisteredCustomSequentialModel, self).__init__(**kwargs)
    input_layer = keras.layers.InputLayer(input_shape=(2, 3))
    self.add(input_layer)
187
188
class ReviveTestBase(keras_parameterized.TestCase):
  """Base class with a temp-dir fixture and revived-model assertions."""

  def setUp(self):
    """Points `self.path` at the test temp dir and schedules its removal."""
    super(ReviveTestBase, self).setUp()
    self.path = self.get_temp_dir()
    self.addCleanup(shutil.rmtree, self.path, ignore_errors=True)

  def _assert_same_tensor_specs(self, expected, actual):
    """Asserts two tensor sequences agree in length, shapes and dtypes.

    The length is checked explicitly because `zip` stops at the shorter
    sequence, which would otherwise let a missing tensor go unnoticed.

    Args:
      expected: Sequence of tensors from the original model.
      actual: Sequence of tensors from the revived model.
    """
    self.assertEqual(len(expected), len(actual))
    for expected_tensor, actual_tensor in zip(expected, actual):
      self.assertEqual(expected_tensor.shape.as_list(),
                       actual_tensor.shape.as_list())
      self.assertEqual(expected_tensor.dtype, actual_tensor.dtype)

  def _assert_revived_correctness(self, model, revived):
    """Asserts that `revived` matches the original `model`.

    Compares input/output names and tensor specs, weights, the outputs on a
    random input, losses, metrics, and each layer's name, dtype, trainable
    flag and type.

    Args:
      model: The original Keras model.
      revived: The model loaded back from the SavedModel.
    """
    self.assertAllEqual(model.input_names, revived.input_names)
    self.assertAllEqual(model.output_names, revived.output_names)
    if model.inputs is not None:
      self._assert_same_tensor_specs(model.inputs, revived.inputs)
      self._assert_same_tensor_specs(model.outputs, revived.outputs)

    self.assertAllClose(self.evaluate(model.weights),
                        self.evaluate(revived.weights))
    input_arr = constant_op.constant(
        np.random.random((2, 2, 3)).astype(np.float32))
    if isinstance(revived._saved_model_inputs_spec,
                  sparse_tensor.SparseTensorSpec):
      input_arr = sparse_ops.from_dense(input_arr)

    self.assertAllClose(model(input_arr), revived(input_arr))
    self.assertAllClose(sum(model.losses), sum(revived.losses))
    # Counts are integers, so compare them exactly rather than with tolerance.
    self.assertEqual(len(model.losses), len(revived.losses))
    self.assertEqual(len(model.metrics), len(revived.metrics))
    # TODO(b/150403085): Investigate why the metric order changes when running
    # this test in tf-nightly.
    self.assertAllClose(sorted([m.result() for m in model.metrics]),
                        sorted([m.result() for m in revived.metrics]))
    model_layers = {layer.name: layer for layer in model.layers}
    revived_layers = {layer.name: layer for layer in revived.layers}
    # Sorted lists give an order-insensitive comparison with a readable diff.
    self.assertEqual(sorted(model_layers), sorted(revived_layers))

    for name, model_layer in model_layers.items():
      revived_layer = revived_layers[name]
      self.assertEqual(model_layer.name, revived_layer.name)
      self.assertEqual(model_layer.dtype, revived_layer.dtype)
      self.assertEqual(model_layer.trainable, revived_layer.trainable)
      if 'WithConfig' in type(model_layer).__name__:
        self.assertEqual(type(model_layer), type(revived_layer))
      else:
        # When loading layers from SavedModel, a new class is dynamically
        # created with the same name.
        self.assertEqual(type(model_layer).__name__,
                         type(revived_layer).__name__)
244
245
246# These tests take a while to run, so each should run in a separate shard
247# (putting them in the same TestCase resolves this).
class TestBigModelRevive(ReviveTestBase):
  """Revives one large model assembled from every custom layer/model kind."""

  @keras_parameterized.run_with_all_model_types
  def test_revive(self):
    is_functional = testing_utils.get_model_type() == 'functional'
    input_shape = (2, 3) if is_functional else None

    # Constructor and arguments of every custom layer/model under test.
    layer_specs = (
        (CustomLayerWithConfig, (1., 2)),
        (CustomLayerNoConfig, (3., 4)),
        (SubclassedModelWithConfig, (4., 6.)),
        (SubclassedModelNoConfig, (7., 8.)),
    )

    def new_layers():
      # Instantiates a fresh copy of each layer, in `layer_specs` order.
      return [layer_cls(*args) for layer_cls, args in layer_specs]

    (layer_with_config,
     layer_without_config,
     subclassed_with_config,
     subclassed_without_config) = new_layers()

    # Functional inner model: each layer is created and applied right away.
    inputs = keras.Input((2, 3))
    outputs = inputs
    for layer_cls, args in layer_specs:
      outputs = layer_cls(*args)(outputs)
    inner_model_functional = keras.Model(inputs, outputs)

    inner_model_sequential = keras.Sequential(new_layers())

    class SubclassedModel(keras.Model):

      def __init__(self):
        super(SubclassedModel, self).__init__()
        self.all_layers = new_layers()

      def call(self, inputs):
        result = inputs
        for sublayer in self.all_layers:
          result = sublayer(result)
        return result

    inner_model_subclassed = SubclassedModel()

    top_level_layers = [
        layer_with_config,
        layer_without_config,
        subclassed_with_config,
        subclassed_without_config,
        inner_model_functional,
        inner_model_sequential,
        inner_model_subclassed,
    ]
    model = testing_utils.get_model_from_layers(
        top_level_layers, input_shape=input_shape)
    # Run data through the Model to create save spec and weights.
    model.predict(np.ones((10, 2, 3)), batch_size=10)

    # Test that the correct checkpointed values are loaded, whether the layer is
    # created from the config or SavedModel.
    layer_with_config.c.assign(2 * layer_with_config.c)
    layer_without_config.c.assign(3 * layer_without_config.c)

    model.save(self.path, save_format='tf')
    revived = keras_load.load(self.path)
    self._assert_revived_correctness(model, revived)
311
312
class TestModelRevive(ReviveTestBase):
  """Revive tests for individual model flavours."""

  def _save_and_revive(self, model):
    """Writes `model` to `self.path` in TF format and loads it back."""
    model.save(self.path, save_format='tf')
    return keras_load.load(self.path)

  def test_revive_subclassed_with_nested_model(self):
    model = SubclassedModelNoConfig(1., 2.)
    # Run data through the Model to create save spec and weights.
    model.predict(np.ones((10, 2, 3)), batch_size=10)
    revived = self._save_and_revive(model)
    self._assert_revived_correctness(model, revived)

  def test_revive_subclassed_with_sparse_model(self):
    model = SubclassedSparseModelNoConfig(1., 2.)
    # Run data through the Model to create save spec and weights.
    dense_ones = np.ones((10, 2, 3), dtype=np.float32)
    sparse_ones = sparse_ops.from_dense(dense_ones)
    model.predict(sparse_ones, batch_size=10)
    revived = self._save_and_revive(model)
    self._assert_revived_correctness(model, revived)

  def test_revive_unregistered_sequential(self):
    model = UnregisteredCustomSequentialModel()
    sample = np.random.random((2, 2, 3)).astype(np.float32)
    model(sample)
    revived = self._save_and_revive(model)
    self._assert_revived_correctness(model, revived)

  def test_revive_sequential_inputs(self):
    string_input = keras.Input((None,), dtype=dtypes.string)
    lowercase = keras.layers.Lambda(string_ops.string_lower)
    model = keras.models.Sequential([string_input, lowercase])
    revived = self._save_and_revive(model)
    direct_sublayers = list(
        revived._flatten_layers(include_self=False, recursive=False))
    self.assertEqual(dtypes.string, direct_sublayers[0].dtype)

  @parameterized.named_parameters(
      ('default_config', CustomNetworkDefaultConfig),
      ('with_config', CustomNetworkWithConfig),
      ('with_config_name', CustomNetworkWithConfigName))
  def test_revive_network(self, model_cls):
    model = model_cls(8)
    model.save(self.path, include_optimizer=False, save_format='tf')
    revived = keras_load.load(self.path, compile=False)
    self._assert_revived_correctness(model, revived)

  def test_load_compiled_metrics(self):
    model = testing_utils.get_small_sequential_mlp(1, 3)

    def train_save_and_compare(features, labels):
      # One training step, a save/load round trip, then both models must
      # report the same evaluation results on the same batch.
      model.train_on_batch(features, labels)
      model.save(self.path, include_optimizer=True, save_format='tf')
      revived = keras_load.load(self.path, compile=True)
      self.assertAllClose(model.test_on_batch(features, labels),
                          revived.test_on_batch(features, labels))

    # Compile with dense categorical accuracy
    model.compile('rmsprop', 'mse', 'acc')
    x = np.random.random((5, 10)).astype(np.float32)
    dense_labels = np.random.random((5, 3)).astype(np.float32)
    train_save_and_compare(x, dense_labels)

    # Compile with sparse categorical accuracy
    model.compile('rmsprop', 'mse', 'acc')
    sparse_labels = np.random.randint(0, 3, (5, 1)).astype(np.float32)
    train_save_and_compare(x, sparse_labels)

  def test_revived_model_has_save_spec(self):
    model = SubclassedModelWithConfig(2, 3)
    model.predict(np.random.random((5, 10)).astype(np.float32))
    model.save(self.path, save_format='tf')
    revived = keras_load.load(self.path, compile=True)
    expected_spec = model._get_save_spec(dynamic_batch=False)
    revived_spec = revived._get_save_spec(dynamic_batch=False)
    self.assertAllEqual(expected_spec, revived_spec)
392
393
if __name__ == '__main__':
  ops.enable_eager_execution()
  # The config-backed custom classes are registered for the whole test run.
  custom_objects = {
      'CustomLayerWithConfig': CustomLayerWithConfig,
      'CustomNetworkWithConfig': CustomNetworkWithConfig,
      'CustomNetworkWithConfigName': CustomNetworkWithConfigName,
      'SubclassedModelWithConfig': SubclassedModelWithConfig,
  }
  with generic_utils.CustomObjectScope(custom_objects):
    test.main()
403