# Copyright 2019 The TensorFlow Authors. All Rights Reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
#     http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
# ==============================================================================
# pylint: disable=protected-access
"""Tests reviving models from config and SavedModel.

These tests ensure that a model revived from a combination of config and
SavedModel has the expected structure.
"""
# TODO(kathywu): Move relevant tests from saved_model_test to this file.
# NOTE(review): the destination of the move was missing from the original
# comment; "this file" is an assumption -- confirm.

from __future__ import absolute_import
from __future__ import division
from __future__ import print_function

import shutil

from absl.testing import parameterized
import numpy as np

from tensorflow.python import keras
from tensorflow.python.framework import constant_op
from tensorflow.python.framework import dtypes
from tensorflow.python.framework import ops
from tensorflow.python.framework import sparse_tensor
from tensorflow.python.keras import backend
from tensorflow.python.keras import keras_parameterized
from tensorflow.python.keras import testing_utils
from tensorflow.python.keras.saving.saved_model import load as keras_load
from tensorflow.python.keras.utils import generic_utils
from tensorflow.python.ops import array_ops
from tensorflow.python.ops import math_ops
from tensorflow.python.ops import sparse_ops
from tensorflow.python.ops import string_ops
from tensorflow.python.ops import variables
from tensorflow.python.platform import test


class SubclassedModelNoConfig(keras.Model):
  """Subclassed model that does not override `get_config`.

  The layer stack is created lazily in `build` and contains custom layers
  (with and without config) as well as a nested `Sequential` model.
  """

  def __init__(self, a, b):
    super(SubclassedModelNoConfig, self).__init__()

    self.a = a
    self.b = b
    self.shared = CustomLayerNoConfig(a, b)
    self.all_layers = []

  def build(self, input_shape):
    """Appends the layer stack to `all_layers`, then defers to the parent."""
    self.all_layers.extend([
        self.shared,
        CustomLayerWithConfig(self.a + 1, self.b + 2),
        CustomLayerNoConfig(self.a + 3, self.b + 4),
        keras.Sequential([
            # TODO(b/145029112): Bug with losses when there are shared layers.
            # self.shared,  <-- Enable when bug is fixed.
            CustomLayerNoConfig(self.a + 5, self.b + 6)])])
    super(SubclassedModelNoConfig, self).build(input_shape)

  def call(self, inputs):
    """Applies every layer in `all_layers` in order."""
    x = inputs
    for layer in self.all_layers:
      x = layer(x)
    return x


class SparseDense(keras.layers.Dense):
  """`Dense` layer whose `call` consumes a sparse input.

  The input is flattened to two dimensions, multiplied by the kernel with a
  sparse-dense matmul, and the result is reshaped back to the leading input
  dimensions. `self.bias` is used unconditionally.
  """

  def call(self, inputs):
    # 2-D shape for the matmul: (product of leading dims, kernel rows).
    input_shape = array_ops.stack(
        (math_ops.reduce_prod(array_ops.shape(inputs)[:-1]),
         self.kernel.shape[0]))
    # Final shape: the leading input dims followed by the kernel columns.
    output_shape = array_ops.concat(
        (array_ops.shape(inputs)[:-1], [self.kernel.shape[1]]), -1)
    x = sparse_ops.sparse_reshape(inputs, input_shape)
    return array_ops.reshape(
        self.activation(
            sparse_ops.sparse_tensor_dense_matmul(x, self.kernel) + self.bias),
        output_shape)


class SubclassedSparseModelNoConfig(keras.Model):
  """Subclassed model without config whose first layer is a `SparseDense`."""

  def __init__(self, a, b):
    super(SubclassedSparseModelNoConfig, self).__init__()
    self.a = a
    self.shared = CustomLayerNoConfig(a, b)
    self.all_layers = [SparseDense(4)]

  def call(self, inputs):
    """Runs `all_layers`, adds `a`, then applies the shared custom layer."""
    x = inputs
    for layer in self.all_layers:
      x = layer(x)
    return self.shared(x + self.a)


class SubclassedModelWithConfig(SubclassedModelNoConfig):
  """`SubclassedModelNoConfig` that can be rebuilt from its config."""

  def get_config(self):
    return {'a': self.a,
            'b': self.b}

  @classmethod
  def from_config(cls, config):
    return cls(**config)


class CustomLayerNoConfig(keras.layers.Layer):
  """Custom layer with losses and metrics that does not override `get_config`.

  It registers a callable loss in the constructor and an input-dependent loss
  plus two metrics on every call. `unused_metric` is created but never passed
  to `add_metric`.
  """

  def __init__(self, a, b, name=None):
    super(CustomLayerNoConfig, self).__init__(name=name)
    self.a = variables.Variable(a, name='a')
    self.b = b
    def a_regularizer():
      return self.a * 2
    self.add_loss(a_regularizer)
    self.sum_metric = keras.metrics.Sum(name='inputs_sum')
    self.unused_metric = keras.metrics.Sum(name='not_added_to_metrics')

  def build(self, input_shape):
    """Creates variable `c`: ones shaped like the input minus its first dim."""
    self.c = variables.Variable(
        constant_op.constant(1.0, shape=input_shape[1:]), name=self.name+'_c')

  def call(self, inputs):
    """Records a loss and two metrics, then returns `inputs + c`."""
    self.add_loss(math_ops.reduce_sum(inputs), inputs=inputs)
    self.add_metric(self.sum_metric(inputs))
    self.add_metric(inputs, aggregation='mean', name='mean')

    return inputs + self.c


class CustomLayerWithConfig(CustomLayerNoConfig):
  """`CustomLayerNoConfig` that exposes its constructor arguments as config."""

  def get_config(self):
    return {'a': backend.get_value(self.a),
            'b': self.b,
            'name': self.name}


class CustomNetworkDefaultConfig(keras.Model):
  """Functional-style network (Flatten -> Dense) without a custom config."""

  def __init__(self, num_classes, name=None):
    inputs = keras.Input((2, 3), name='inputs')
    x = keras.layers.Flatten(name='flatten')(inputs)
    y = keras.layers.Dense(num_classes, name='outputs')(x)
    super(CustomNetworkDefaultConfig, self).__init__(inputs, y, name=name)


class CustomNetworkWithConfig(CustomNetworkDefaultConfig):
  """`CustomNetworkDefaultConfig` whose config holds only `num_classes`."""

  def __init__(self, num_classes, name=None):
    super(CustomNetworkWithConfig, self).__init__(num_classes, name=name)
    self._config_dict = dict(num_classes=num_classes)

  def get_config(self):
    return self._config_dict

  @classmethod
  def from_config(cls, config):
    return cls(config['num_classes'], name=config.get('name'))


class CustomNetworkWithConfigName(CustomNetworkWithConfig):
  """`CustomNetworkWithConfig` whose config also stores the model name."""

  def __init__(self, num_classes, name=None):
    super(CustomNetworkWithConfigName, self).__init__(num_classes, name=name)
    self._config_dict['name'] = self.name


class UnregisteredCustomSequentialModel(keras.Sequential):
  """`Sequential` subclass that starts with a (2, 3) `InputLayer`."""
  # This class is *not* registered in the CustomObjectScope.

  def __init__(self, **kwargs):
    super(UnregisteredCustomSequentialModel, self).__init__(**kwargs)
    self.add(keras.layers.InputLayer(input_shape=(2, 3)))


class ReviveTestBase(keras_parameterized.TestCase):
  """Shared fixture and assertions for the revive tests."""

  def setUp(self):
    super(ReviveTestBase, self).setUp()
    self.path = self.get_temp_dir()
    self.addCleanup(shutil.rmtree, self.path, ignore_errors=True)

  def _assert_same_tensor_specs(self, expected, actual):
    """Asserts two sequences of tensors match pairwise in shape and dtype.

    Args:
      expected: Sequence of tensors taken from the original model.
      actual: Sequence of tensors taken from the revived model.
    """
    # zip() stops at the shorter sequence, so compare the lengths first;
    # otherwise a revived model with missing tensors would pass silently.
    self.assertEqual(len(expected), len(actual))
    for expected_tensor, actual_tensor in zip(expected, actual):
      self.assertEqual(expected_tensor.shape.as_list(),
                       actual_tensor.shape.as_list())
      self.assertEqual(expected_tensor.dtype, actual_tensor.dtype)

  def _assert_revived_correctness(self, model, revived):
    """Asserts that `revived` matches `model` in structure and behavior.

    Compares input/output names and (when available) their shapes and dtypes,
    weights, outputs on a random (2, 2, 3) float32 input, losses, metrics and
    the name, dtype, trainable flag and type of every layer.

    Args:
      model: The original Keras model that was saved.
      revived: The model loaded back from the SavedModel.
    """
    self.assertAllEqual(model.input_names, revived.input_names)
    self.assertAllEqual(model.output_names, revived.output_names)
    if model.inputs is not None:
      self._assert_same_tensor_specs(model.inputs, revived.inputs)
      self._assert_same_tensor_specs(model.outputs, revived.outputs)

    self.assertAllClose(self.evaluate(model.weights),
                        self.evaluate(revived.weights))
    input_arr = constant_op.constant(
        np.random.random((2, 2, 3)).astype(np.float32))
    # Feed a sparse tensor when the revived model was saved with a sparse
    # input spec.
    if isinstance(revived._saved_model_inputs_spec,
                  sparse_tensor.SparseTensorSpec):
      input_arr = sparse_ops.from_dense(input_arr)

    self.assertAllClose(model(input_arr), revived(input_arr))
    self.assertAllClose(sum(model.losses), sum(revived.losses))
    # Counts are integers: compare them exactly rather than with a tolerance.
    self.assertEqual(len(model.losses), len(revived.losses))
    self.assertEqual(len(model.metrics), len(revived.metrics))
    # TODO(b/150403085): Investigate why the metric order changes when running
    # this test in tf-nightly.
    self.assertAllClose(sorted([m.result() for m in model.metrics]),
                        sorted([m.result() for m in revived.metrics]))
    model_layers = {layer.name: layer for layer in model.layers}
    revived_layers = {layer.name: layer for layer in revived.layers}
    self.assertAllEqual(model_layers.keys(), revived_layers.keys())

    for name in model_layers:
      model_layer = model_layers[name]
      revived_layer = revived_layers[name]
      self.assertEqual(model_layer.name, revived_layer.name)
      self.assertEqual(model_layer.dtype, revived_layer.dtype)
      self.assertEqual(model_layer.trainable, revived_layer.trainable)
      if 'WithConfig' in type(model_layer).__name__:
        self.assertEqual(type(model_layer), type(revived_layer))
      else:
        # When loading layers from SavedModel, a new class is dynamically
        # created with the same name.
        self.assertEqual(type(model_layer).__name__,
                         type(revived_layer).__name__)


# These tests take a while to run, so each should run in a separate shard
# (putting them in the same TestCase resolves this).
class TestBigModelRevive(ReviveTestBase):
  """Revives one large model composed of every kind of custom layer/model."""

  @keras_parameterized.run_with_all_model_types
  def test_revive(self):
    """Saves and revives a model mixing layers and nested models."""
    input_shape = None
    if testing_utils.get_model_type() == 'functional':
      input_shape = (2, 3)

    layer_with_config = CustomLayerWithConfig(1., 2)
    layer_without_config = CustomLayerNoConfig(3., 4)
    subclassed_with_config = SubclassedModelWithConfig(4., 6.)
    subclassed_without_config = SubclassedModelNoConfig(7., 8.)

    inputs = keras.Input((2, 3))
    x = CustomLayerWithConfig(1., 2)(inputs)
    x = CustomLayerNoConfig(3., 4)(x)
    x = SubclassedModelWithConfig(4., 6.)(x)
    x = SubclassedModelNoConfig(7., 8.)(x)
    inner_model_functional = keras.Model(inputs, x)

    inner_model_sequential = keras.Sequential(
        [CustomLayerWithConfig(1., 2),
         CustomLayerNoConfig(3., 4),
         SubclassedModelWithConfig(4., 6.),
         SubclassedModelNoConfig(7., 8.)])

    class SubclassedModel(keras.Model):

      def __init__(self):
        super(SubclassedModel, self).__init__()
        self.all_layers = [CustomLayerWithConfig(1., 2),
                           CustomLayerNoConfig(3., 4),
                           SubclassedModelWithConfig(4., 6.),
                           SubclassedModelNoConfig(7., 8.)]

      def call(self, inputs):
        x = inputs
        for layer in self.all_layers:
          x = layer(x)
        return x

    inner_model_subclassed = SubclassedModel()

    layers = [layer_with_config,
              layer_without_config,
              subclassed_with_config,
              subclassed_without_config,
              inner_model_functional,
              inner_model_sequential,
              inner_model_subclassed]
    model = testing_utils.get_model_from_layers(
        layers, input_shape=input_shape)
    # Run data through the Model to create save spec and weights.
    model.predict(np.ones((10, 2, 3)), batch_size=10)

    # Test that the correct checkpointed values are loaded, whether the layer is
    # created from the config or SavedModel.
    layer_with_config.c.assign(2 * layer_with_config.c)
    layer_without_config.c.assign(3 * layer_without_config.c)

    model.save(self.path, save_format='tf')
    revived = keras_load.load(self.path)
    self._assert_revived_correctness(model, revived)


class TestModelRevive(ReviveTestBase):
  """Revive tests for individual, smaller models."""

  def test_revive_subclassed_with_nested_model(self):
    """Revives a subclassed model that contains a nested `Sequential`."""
    model = SubclassedModelNoConfig(1., 2.)
    # Run data through the Model to create save spec and weights.
    model.predict(np.ones((10, 2, 3)), batch_size=10)
    model.save(self.path, save_format='tf')
    revived = keras_load.load(self.path)
    self._assert_revived_correctness(model, revived)

  def test_revive_subclassed_with_sparse_model(self):
    """Revives a subclassed model that was run on a sparse input."""
    model = SubclassedSparseModelNoConfig(1., 2.)
    # Run data through the Model to create save spec and weights.
    x = sparse_ops.from_dense(np.ones((10, 2, 3), dtype=np.float32))
    model.predict(x, batch_size=10)
    model.save(self.path, save_format='tf')
    revived = keras_load.load(self.path)
    self._assert_revived_correctness(model, revived)

  def test_revive_unregistered_sequential(self):
    """Revives a `Sequential` subclass absent from the custom object scope."""
    model = UnregisteredCustomSequentialModel()
    x = np.random.random((2, 2, 3)).astype(np.float32)
    model(x)
    model.save(self.path, save_format='tf')
    revived = keras_load.load(self.path)
    self._assert_revived_correctness(model, revived)

  def test_revive_sequential_inputs(self):
    """Checks the first revived layer keeps the string input dtype."""
    model = keras.models.Sequential([
        keras.Input((None,), dtype=dtypes.string),
        keras.layers.Lambda(string_ops.string_lower)
    ])
    model.save(self.path, save_format='tf')
    revived = keras_load.load(self.path)
    revived_layers = list(
        revived._flatten_layers(include_self=False, recursive=False))
    self.assertEqual(dtypes.string, revived_layers[0].dtype)

  @parameterized.named_parameters(
      ('default_config', CustomNetworkDefaultConfig),
      ('with_config', CustomNetworkWithConfig),
      ('with_config_name', CustomNetworkWithConfigName))
  def test_revive_network(self, model_cls):
    """Revives each custom network variant without compiling it."""
    model = model_cls(8)
    model.save(self.path, include_optimizer=False, save_format='tf')
    revived = keras_load.load(self.path, compile=False)
    self._assert_revived_correctness(model, revived)

  def test_load_compiled_metrics(self):
    """Checks a revived compiled model evaluates like the original."""
    model = testing_utils.get_small_sequential_mlp(1, 3)

    # Compile with dense categorical accuracy
    # NOTE(review): both phases pass the same generic 'acc' alias; presumably
    # Keras picks the dense or sparse variant from the target shape, which is
    # the only thing that differs below -- confirm against Model.compile docs.
    model.compile('rmsprop', 'mse', 'acc')
    x = np.random.random((5, 10)).astype(np.float32)
    y_true = np.random.random((5, 3)).astype(np.float32)
    model.train_on_batch(x, y_true)

    model.save(self.path, include_optimizer=True, save_format='tf')
    revived = keras_load.load(self.path, compile=True)
    self.assertAllClose(model.test_on_batch(x, y_true),
                        revived.test_on_batch(x, y_true))

    # Compile with sparse categorical accuracy
    model.compile('rmsprop', 'mse', 'acc')
    y_true = np.random.randint(0, 3, (5, 1)).astype(np.float32)
    model.train_on_batch(x, y_true)
    model.save(self.path, include_optimizer=True, save_format='tf')
    revived = keras_load.load(self.path, compile=True)
    self.assertAllClose(model.test_on_batch(x, y_true),
                        revived.test_on_batch(x, y_true))

  def test_revived_model_has_save_spec(self):
    """Checks the revived model reports the same save spec as the original."""
    model = SubclassedModelWithConfig(2, 3)
    model.predict(np.random.random((5, 10)).astype(np.float32))
    model.save(self.path, save_format='tf')
    revived = keras_load.load(self.path, compile=True)
    self.assertAllEqual(
        model._get_save_spec(dynamic_batch=False),
        revived._get_save_spec(dynamic_batch=False))


if __name__ == '__main__':
  ops.enable_eager_execution()
  # Only the classes that provide a config are registered; the remaining
  # custom classes are not in the scope.
  with generic_utils.CustomObjectScope({
      'CustomLayerWithConfig': CustomLayerWithConfig,
      'CustomNetworkWithConfig': CustomNetworkWithConfig,
      'CustomNetworkWithConfigName': CustomNetworkWithConfigName,
      'SubclassedModelWithConfig': SubclassedModelWithConfig
  }):
    test.main()