1# Copyright 2020 The TensorFlow Authors. All Rights Reserved. 2# 3# Licensed under the Apache License, Version 2.0 (the "License"); 4# you may not use this file except in compliance with the License. 5# You may obtain a copy of the License at 6# 7# http://www.apache.org/licenses/LICENSE-2.0 8# 9# Unless required by applicable law or agreed to in writing, software 10# distributed under the License is distributed on an "AS IS" BASIS, 11# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. 12# See the License for the specific language governing permissions and 13# limitations under the License. 14# ============================================================================== 15"""Tests for saving/loading function for keras Model.""" 16from __future__ import absolute_import 17from __future__ import division 18from __future__ import print_function 19 20import collections 21 22from tensorflow.python import keras 23 24# Declaring namedtuple() 25ModelFn = collections.namedtuple('ModelFn', 26 ['model', 'input_shape', 'target_shape']) 27 28 29def basic_sequential(): 30 """Basic sequential model.""" 31 model = keras.Sequential([ 32 keras.layers.Dense(3, activation='relu', input_shape=(3,)), 33 keras.layers.Dense(2, activation='softmax'), 34 ]) 35 return ModelFn(model, (None, 3), (None, 2)) 36 37 38def basic_sequential_deferred(): 39 """Sequential model with deferred input shape.""" 40 model = keras.Sequential([ 41 keras.layers.Dense(3, activation='relu'), 42 keras.layers.Dense(2, activation='softmax'), 43 ]) 44 return ModelFn(model, (None, 3), (None, 2)) 45 46 47def stacked_rnn(): 48 """Stacked RNN model.""" 49 inputs = keras.Input((None, 3)) 50 layer = keras.layers.RNN([keras.layers.LSTMCell(2) for _ in range(3)]) 51 x = layer(inputs) 52 outputs = keras.layers.Dense(2)(x) 53 model = keras.Model(inputs, outputs) 54 return ModelFn(model, (None, 4, 3), (None, 2)) 55 56 57def lstm(): 58 """LSTM model.""" 59 inputs = keras.Input((None, 3)) 60 x = keras.layers.LSTM(4, return_sequences=True)(inputs) 61 x = keras.layers.LSTM(3, return_sequences=True)(x) 62 x = keras.layers.LSTM(2, return_sequences=False)(x) 63 outputs = keras.layers.Dense(2)(x) 64 model = keras.Model(inputs, outputs) 65 return ModelFn(model, (None, 4, 3), (None, 2)) 66 67 68def multi_input_multi_output(): 69 """Multi-input Multi-output model.""" 70 body_input = keras.Input(shape=(None,), name='body') 71 tags_input = keras.Input(shape=(2,), name='tags') 72 73 x = keras.layers.Embedding(10, 4)(body_input) 74 body_features = keras.layers.LSTM(5)(x) 75 x = keras.layers.concatenate([body_features, tags_input]) 76 77 pred_1 = keras.layers.Dense(2, activation='sigmoid', name='priority')(x) 78 pred_2 = keras.layers.Dense(3, activation='softmax', name='department')(x) 79 80 model = keras.Model( 81 inputs=[body_input, tags_input], outputs=[pred_1, pred_2]) 82 return ModelFn(model, [(None, 1), (None, 2)], [(None, 2), (None, 3)]) 83 84 85def nested_sequential_in_functional(): 86 """A sequential model nested in a functional model.""" 87 inner_model = keras.Sequential([ 88 keras.layers.Dense(3, activation='relu', input_shape=(3,)), 89 keras.layers.Dense(2, activation='relu'), 90 ]) 91 92 inputs = keras.Input(shape=(3,)) 93 x = inner_model(inputs) 94 outputs = keras.layers.Dense(2, activation='softmax')(x) 95 model = keras.Model(inputs, outputs) 96 return ModelFn(model, (None, 3), (None, 2)) 97 98 99def seq_to_seq(): 100 """Sequence to sequence model.""" 101 num_encoder_tokens = 3 102 num_decoder_tokens = 3 103 latent_dim = 2 104 encoder_inputs = keras.Input(shape=(None, num_encoder_tokens)) 105 encoder = keras.layers.LSTM(latent_dim, return_state=True) 106 _, state_h, state_c = encoder(encoder_inputs) 107 encoder_states = [state_h, state_c] 108 decoder_inputs = keras.Input(shape=(None, num_decoder_tokens)) 109 decoder_lstm = keras.layers.LSTM( 110 latent_dim, return_sequences=True, return_state=True) 111 decoder_outputs, _, _ = decoder_lstm( 112 decoder_inputs, initial_state=encoder_states) 113 decoder_dense = keras.layers.Dense(num_decoder_tokens, activation='softmax') 114 decoder_outputs = decoder_dense(decoder_outputs) 115 model = keras.Model([encoder_inputs, decoder_inputs], decoder_outputs) 116 return ModelFn( 117 model, [(None, 2, num_encoder_tokens), (None, 2, num_decoder_tokens)], 118 (None, 2, num_decoder_tokens)) 119 120 121def shared_layer_functional(): 122 """Shared layer in a functional model.""" 123 main_input = keras.Input(shape=(10,), dtype='int32', name='main_input') 124 x = keras.layers.Embedding( 125 output_dim=5, input_dim=4, input_length=10)(main_input) 126 lstm_out = keras.layers.LSTM(3)(x) 127 auxiliary_output = keras.layers.Dense( 128 1, activation='sigmoid', name='aux_output')(lstm_out) 129 auxiliary_input = keras.Input(shape=(5,), name='aux_input') 130 x = keras.layers.concatenate([lstm_out, auxiliary_input]) 131 x = keras.layers.Dense(2, activation='relu')(x) 132 main_output = keras.layers.Dense( 133 1, activation='sigmoid', name='main_output')(x) 134 model = keras.Model( 135 inputs=[main_input, auxiliary_input], 136 outputs=[main_output, auxiliary_output]) 137 return ModelFn(model, [(None, 10), (None, 5)], [(None, 1), (None, 1)]) 138 139 140def shared_sequential(): 141 """Shared sequential model in a functional model.""" 142 inner_model = keras.Sequential([ 143 keras.layers.Conv2D(2, 3, activation='relu'), 144 keras.layers.Conv2D(2, 3, activation='relu'), 145 ]) 146 inputs_1 = keras.Input((5, 5, 3)) 147 inputs_2 = keras.Input((5, 5, 3)) 148 x1 = inner_model(inputs_1) 149 x2 = inner_model(inputs_2) 150 x = keras.layers.concatenate([x1, x2]) 151 outputs = keras.layers.GlobalAveragePooling2D()(x) 152 model = keras.Model([inputs_1, inputs_2], outputs) 153 return ModelFn(model, [(None, 5, 5, 3), (None, 5, 5, 3)], (None, 4)) 154 155 156class MySubclassModel(keras.Model): 157 """A subclass model.""" 158 159 def __init__(self, input_dim=3): 160 super(MySubclassModel, self).__init__(name='my_subclass_model') 161 self._config = {'input_dim': input_dim} 162 self.dense1 = keras.layers.Dense(8, activation='relu') 163 self.dense2 = keras.layers.Dense(2, activation='softmax') 164 self.bn = keras.layers.BatchNormalization() 165 self.dp = keras.layers.Dropout(0.5) 166 167 def call(self, inputs, **kwargs): 168 x = self.dense1(inputs) 169 x = self.dp(x) 170 x = self.bn(x) 171 return self.dense2(x) 172 173 def get_config(self): 174 return self._config 175 176 @classmethod 177 def from_config(cls, config): 178 return cls(**config) 179 180 181def nested_subclassed_model(): 182 """A subclass model nested in another subclass model.""" 183 184 class NestedSubclassModel(keras.Model): 185 """A nested subclass model.""" 186 187 def __init__(self): 188 super(NestedSubclassModel, self).__init__() 189 self.dense1 = keras.layers.Dense(4, activation='relu') 190 self.dense2 = keras.layers.Dense(2, activation='relu') 191 self.bn = keras.layers.BatchNormalization() 192 self.inner_subclass_model = MySubclassModel() 193 194 def call(self, inputs): 195 x = self.dense1(inputs) 196 x = self.bn(x) 197 x = self.inner_subclass_model(x) 198 return self.dense2(x) 199 200 return ModelFn(NestedSubclassModel(), (None, 3), (None, 2)) 201 202 203def nested_subclassed_in_functional_model(): 204 """A subclass model nested in a functional model.""" 205 inner_subclass_model = MySubclassModel() 206 inputs = keras.Input(shape=(3,)) 207 x = inner_subclass_model(inputs) 208 x = keras.layers.BatchNormalization()(x) 209 outputs = keras.layers.Dense(2, activation='softmax')(x) 210 model = keras.Model(inputs, outputs) 211 return ModelFn(model, (None, 3), (None, 2)) 212 213 214def nested_functional_in_subclassed_model(): 215 """A functional model nested in a subclass model.""" 216 def get_functional_model(): 217 inputs = keras.Input(shape=(4,)) 218 x = keras.layers.Dense(4, activation='relu')(inputs) 219 x = keras.layers.BatchNormalization()(x) 220 outputs = keras.layers.Dense(2)(x) 221 return keras.Model(inputs, outputs) 222 223 class NestedFunctionalInSubclassModel(keras.Model): 224 """A functional nested in subclass model.""" 225 226 def __init__(self): 227 super(NestedFunctionalInSubclassModel, self).__init__( 228 name='nested_functional_in_subclassed_model') 229 self.dense1 = keras.layers.Dense(4, activation='relu') 230 self.dense2 = keras.layers.Dense(2, activation='relu') 231 self.inner_functional_model = get_functional_model() 232 233 def call(self, inputs): 234 x = self.dense1(inputs) 235 x = self.inner_functional_model(x) 236 return self.dense2(x) 237 return ModelFn(NestedFunctionalInSubclassModel(), (None, 3), (None, 2)) 238 239 240def shared_layer_subclassed_model(): 241 """Shared layer in a subclass model.""" 242 243 class SharedLayerSubclassModel(keras.Model): 244 """A subclass model with shared layers.""" 245 246 def __init__(self): 247 super(SharedLayerSubclassModel, self).__init__( 248 name='shared_layer_subclass_model') 249 self.dense = keras.layers.Dense(3, activation='relu') 250 self.dp = keras.layers.Dropout(0.5) 251 self.bn = keras.layers.BatchNormalization() 252 253 def call(self, inputs): 254 x = self.dense(inputs) 255 x = self.dp(x) 256 x = self.bn(x) 257 return self.dense(x) 258 return ModelFn(SharedLayerSubclassModel(), (None, 3), (None, 3)) 259 260 261def functional_with_keyword_args(): 262 """A functional model with keyword args.""" 263 inputs = keras.Input(shape=(3,)) 264 x = keras.layers.Dense(4)(inputs) 265 x = keras.layers.BatchNormalization()(x) 266 outputs = keras.layers.Dense(2)(x) 267 268 model = keras.Model(inputs, outputs, name='m', trainable=False) 269 return ModelFn(model, (None, 3), (None, 2)) 270 271 272ALL_MODELS = [ 273 ('basic_sequential', basic_sequential), 274 ('basic_sequential_deferred', basic_sequential_deferred), 275 ('stacked_rnn', stacked_rnn), 276 ('lstm', lstm), 277 ('multi_input_multi_output', multi_input_multi_output), 278 ('nested_sequential_in_functional', nested_sequential_in_functional), 279 ('seq_to_seq', seq_to_seq), 280 ('shared_layer_functional', shared_layer_functional), 281 ('shared_sequential', shared_sequential), 282 ('nested_subclassed_model', nested_subclassed_model), 283 ('nested_subclassed_in_functional_model', 284 nested_subclassed_in_functional_model), 285 ('nested_functional_in_subclassed_model', 286 nested_functional_in_subclassed_model), 287 ('shared_layer_subclassed_model', shared_layer_subclassed_model), 288 ('functional_with_keyword_args', functional_with_keyword_args) 289] 290 291 292def get_models(exclude_models=None): 293 """Get all models excluding the specified ones.""" 294 models = [model for model in ALL_MODELS 295 if model[0] not in exclude_models] 296 return models 297