1# Copyright 2020 The TensorFlow Authors. All Rights Reserved.
2#
3# Licensed under the Apache License, Version 2.0 (the "License");
4# you may not use this file except in compliance with the License.
5# You may obtain a copy of the License at
6#
7#     http://www.apache.org/licenses/LICENSE-2.0
8#
9# Unless required by applicable law or agreed to in writing, software
10# distributed under the License is distributed on an "AS IS" BASIS,
11# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
12# See the License for the specific language governing permissions and
13# limitations under the License.
14# ==============================================================================
"""Keras model architectures used by the Model saving/loading tests."""
16from __future__ import absolute_import
17from __future__ import division
18from __future__ import print_function
19
20import collections
21
22from tensorflow.python import keras
23
# Record returned by the model builders in this file: a model together with
# its input shape(s) and target shape(s).
ModelFn = collections.namedtuple(
    'ModelFn', ('model', 'input_shape', 'target_shape'))
27
28
def basic_sequential():
  """Builds a two-layer Dense `Sequential` model with a declared input shape.

  Returns:
    A `ModelFn` with the model, input shape `(None, 3)` and target shape
    `(None, 2)`.
  """
  stack = [
      keras.layers.Dense(3, activation='relu', input_shape=(3,)),
      keras.layers.Dense(2, activation='softmax'),
  ]
  return ModelFn(
      model=keras.Sequential(stack),
      input_shape=(None, 3),
      target_shape=(None, 2))
36
37
def basic_sequential_deferred():
  """Builds a two-layer Dense `Sequential` model without an input shape.

  Neither layer is given `input_shape`, unlike in `basic_sequential`.

  Returns:
    A `ModelFn` with the model, input shape `(None, 3)` and target shape
    `(None, 2)`.
  """
  hidden = keras.layers.Dense(3, activation='relu')
  classifier = keras.layers.Dense(2, activation='softmax')
  return ModelFn(
      model=keras.Sequential([hidden, classifier]),
      input_shape=(None, 3),
      target_shape=(None, 2))
45
46
def stacked_rnn():
  """Builds a functional model: an RNN over three LSTM cells, then a Dense.

  Returns:
    A `ModelFn` with the model, input shape `(None, 4, 3)` and target shape
    `(None, 2)`.
  """
  sequence = keras.Input((None, 3))
  cells = [keras.layers.LSTMCell(2) for _ in range(3)]
  encoded = keras.layers.RNN(cells)(sequence)
  projected = keras.layers.Dense(2)(encoded)
  return ModelFn(
      model=keras.Model(sequence, projected),
      input_shape=(None, 4, 3),
      target_shape=(None, 2))
55
56
def lstm():
  """Builds a functional model of three chained LSTM layers and a Dense.

  The LSTM layers have 4, 3 and 2 units; only the last one is created with
  `return_sequences=False`.

  Returns:
    A `ModelFn` with the model, input shape `(None, 4, 3)` and target shape
    `(None, 2)`.
  """
  sequence = keras.Input((None, 3))
  hidden = sequence
  # (units, return_sequences) for each LSTM layer, in stacking order.
  for units, keep_sequences in ((4, True), (3, True), (2, False)):
    hidden = keras.layers.LSTM(units, return_sequences=keep_sequences)(hidden)
  projected = keras.layers.Dense(2)(hidden)
  return ModelFn(
      model=keras.Model(sequence, projected),
      input_shape=(None, 4, 3),
      target_shape=(None, 2))
66
67
def multi_input_multi_output():
  """Builds a functional model with two named inputs and two named outputs.

  The `body` input goes through an Embedding and an LSTM; the result is
  concatenated with the `tags` input and fed to the `priority` and
  `department` Dense heads.

  Returns:
    A `ModelFn` with the model, input shapes `[(None, 1), (None, 2)]` and
    target shapes `[(None, 2), (None, 3)]`.
  """
  body = keras.Input(shape=(None,), name='body')
  tags = keras.Input(shape=(2,), name='tags')

  embedded = keras.layers.Embedding(10, 4)(body)
  summary = keras.layers.LSTM(5)(embedded)
  merged = keras.layers.concatenate([summary, tags])

  priority = keras.layers.Dense(
      2, activation='sigmoid', name='priority')(merged)
  department = keras.layers.Dense(
      3, activation='softmax', name='department')(merged)

  return ModelFn(
      model=keras.Model(inputs=[body, tags], outputs=[priority, department]),
      input_shape=[(None, 1), (None, 2)],
      target_shape=[(None, 2), (None, 3)])
83
84
def nested_sequential_in_functional():
  """Builds a functional model that calls a `Sequential` model as a layer.

  Returns:
    A `ModelFn` with the model, input shape `(None, 3)` and target shape
    `(None, 2)`.
  """
  inner = keras.Sequential([
      keras.layers.Dense(3, activation='relu', input_shape=(3,)),
      keras.layers.Dense(2, activation='relu'),
  ])

  outer_input = keras.Input(shape=(3,))
  features = inner(outer_input)
  probabilities = keras.layers.Dense(2, activation='softmax')(features)
  return ModelFn(
      model=keras.Model(outer_input, probabilities),
      input_shape=(None, 3),
      target_shape=(None, 2))
97
98
def seq_to_seq():
  """Builds an LSTM encoder-decoder functional model.

  The encoder LSTM's final states are passed to the decoder LSTM as its
  `initial_state`; the decoder's sequence output goes through a softmax Dense.

  Returns:
    A `ModelFn` with the model, input shapes `[(None, 2, 3), (None, 2, 3)]`
    and target shape `(None, 2, 3)`.
  """
  encoder_dim, decoder_dim, state_dim = 3, 3, 2

  source = keras.Input(shape=(None, encoder_dim))
  # Only the encoder's states are used; its output is discarded.
  _, hidden_state, cell_state = keras.layers.LSTM(
      state_dim, return_state=True)(source)

  target = keras.Input(shape=(None, decoder_dim))
  decoder = keras.layers.LSTM(
      state_dim, return_sequences=True, return_state=True)
  decoded, _, _ = decoder(target, initial_state=[hidden_state, cell_state])
  probabilities = keras.layers.Dense(
      decoder_dim, activation='softmax')(decoded)

  return ModelFn(
      model=keras.Model([source, target], probabilities),
      input_shape=[(None, 2, encoder_dim), (None, 2, decoder_dim)],
      target_shape=(None, 2, decoder_dim))
119
120
def shared_layer_functional():
  """Builds a two-input, two-output functional model.

  The LSTM output feeds the `aux_output` head directly and, concatenated with
  `aux_input`, the branch that ends in the `main_output` head.

  Returns:
    A `ModelFn` with the model, input shapes `[(None, 10), (None, 5)]` and
    target shapes `[(None, 1), (None, 1)]`.
  """
  tokens = keras.Input(shape=(10,), dtype='int32', name='main_input')
  embedded = keras.layers.Embedding(
      output_dim=5, input_dim=4, input_length=10)(tokens)
  encoded = keras.layers.LSTM(3)(embedded)
  aux_prediction = keras.layers.Dense(
      1, activation='sigmoid', name='aux_output')(encoded)

  extra = keras.Input(shape=(5,), name='aux_input')
  merged = keras.layers.concatenate([encoded, extra])
  hidden = keras.layers.Dense(2, activation='relu')(merged)
  main_prediction = keras.layers.Dense(
      1, activation='sigmoid', name='main_output')(hidden)

  return ModelFn(
      model=keras.Model(
          inputs=[tokens, extra], outputs=[main_prediction, aux_prediction]),
      input_shape=[(None, 10), (None, 5)],
      target_shape=[(None, 1), (None, 1)])
138
139
def shared_sequential():
  """Builds a functional model that applies one `Sequential` to two inputs.

  The same two-layer Conv2D `Sequential` is called on both inputs; the two
  results are concatenated and globally average-pooled.

  Returns:
    A `ModelFn` with the model, input shapes
    `[(None, 5, 5, 3), (None, 5, 5, 3)]` and target shape `(None, 4)`.
  """
  shared = keras.Sequential(
      [keras.layers.Conv2D(2, 3, activation='relu') for _ in range(2)])
  images = [keras.Input((5, 5, 3)), keras.Input((5, 5, 3))]
  # Both branches go through the same `shared` instance.
  branches = [shared(image) for image in images]
  merged = keras.layers.concatenate(branches)
  pooled = keras.layers.GlobalAveragePooling2D()(merged)
  return ModelFn(
      model=keras.Model(images, pooled),
      input_shape=[(None, 5, 5, 3), (None, 5, 5, 3)],
      target_shape=(None, 4))
154
155
class MySubclassModel(keras.Model):
  """Subclassed model: Dense -> Dropout -> BatchNormalization -> Dense.

  `input_dim` is only stored so that `get_config` / `from_config` can
  round-trip it; no layer in this class reads it.
  """

  def __init__(self, input_dim=3):
    super(MySubclassModel, self).__init__(name='my_subclass_model')
    self._config = {'input_dim': input_dim}
    # Layers are created here and wired together in `call`.
    self.dense1 = keras.layers.Dense(8, activation='relu')
    self.dense2 = keras.layers.Dense(2, activation='softmax')
    self.bn = keras.layers.BatchNormalization()
    self.dp = keras.layers.Dropout(0.5)

  def call(self, inputs, **kwargs):
    """Applies dense1, dropout, batch norm and dense2, in that order."""
    dropped = self.dp(self.dense1(inputs))
    return self.dense2(self.bn(dropped))

  def get_config(self):
    """Returns the constructor-argument dict stored by `__init__`."""
    return self._config

  @classmethod
  def from_config(cls, config):
    """Creates an instance by passing `config` as keyword arguments."""
    return cls(**config)
179
180
def nested_subclassed_model():
  """Builds a subclassed model that contains a `MySubclassModel`.

  Returns:
    A `ModelFn` with the model, input shape `(None, 3)` and target shape
    `(None, 2)`.
  """

  class NestedSubclassModel(keras.Model):
    """Dense -> BatchNormalization -> MySubclassModel -> Dense."""

    def __init__(self):
      super(NestedSubclassModel, self).__init__()
      self.dense1 = keras.layers.Dense(4, activation='relu')
      self.dense2 = keras.layers.Dense(2, activation='relu')
      self.bn = keras.layers.BatchNormalization()
      self.inner_subclass_model = MySubclassModel()

    def call(self, inputs):
      normalized = self.bn(self.dense1(inputs))
      return self.dense2(self.inner_subclass_model(normalized))

  outer = NestedSubclassModel()
  return ModelFn(model=outer, input_shape=(None, 3), target_shape=(None, 2))
201
202
def nested_subclassed_in_functional_model():
  """Builds a functional model that calls a `MySubclassModel` as a layer.

  Returns:
    A `ModelFn` with the model, input shape `(None, 3)` and target shape
    `(None, 2)`.
  """
  subclassed = MySubclassModel()
  source = keras.Input(shape=(3,))
  hidden = subclassed(source)
  hidden = keras.layers.BatchNormalization()(hidden)
  probabilities = keras.layers.Dense(2, activation='softmax')(hidden)
  return ModelFn(
      model=keras.Model(source, probabilities),
      input_shape=(None, 3),
      target_shape=(None, 2))
212
213
def nested_functional_in_subclassed_model():
  """Builds a subclassed model that contains a functional model.

  Returns:
    A `ModelFn` with the model, input shape `(None, 3)` and target shape
    `(None, 2)`.
  """

  def build_inner_model():
    """Returns a functional Dense -> BatchNormalization -> Dense model."""
    inner_input = keras.Input(shape=(4,))
    hidden = keras.layers.Dense(4, activation='relu')(inner_input)
    hidden = keras.layers.BatchNormalization()(hidden)
    return keras.Model(inner_input, keras.layers.Dense(2)(hidden))

  class NestedFunctionalInSubclassModel(keras.Model):
    """Dense -> inner functional model -> Dense."""

    def __init__(self):
      super(NestedFunctionalInSubclassModel, self).__init__(
          name='nested_functional_in_subclassed_model')
      self.dense1 = keras.layers.Dense(4, activation='relu')
      self.dense2 = keras.layers.Dense(2, activation='relu')
      self.inner_functional_model = build_inner_model()

    def call(self, inputs):
      inner_output = self.inner_functional_model(self.dense1(inputs))
      return self.dense2(inner_output)

  outer = NestedFunctionalInSubclassModel()
  return ModelFn(model=outer, input_shape=(None, 3), target_shape=(None, 2))
238
239
def shared_layer_subclassed_model():
  """Builds a subclassed model that applies one Dense layer twice.

  Returns:
    A `ModelFn` with the model, input shape `(None, 3)` and target shape
    `(None, 3)`.
  """

  class SharedLayerSubclassModel(keras.Model):
    """Dense -> Dropout -> BatchNormalization -> the same Dense again."""

    def __init__(self):
      super(SharedLayerSubclassModel, self).__init__(
          name='shared_layer_subclass_model')
      self.dense = keras.layers.Dense(3, activation='relu')
      self.dp = keras.layers.Dropout(0.5)
      self.bn = keras.layers.BatchNormalization()

    def call(self, inputs):
      first_pass = self.dense(inputs)
      normalized = self.bn(self.dp(first_pass))
      # Second application of the same `self.dense` layer.
      return self.dense(normalized)

  shared = SharedLayerSubclassModel()
  return ModelFn(model=shared, input_shape=(None, 3), target_shape=(None, 3))
259
260
def functional_with_keyword_args():
  """Builds a functional model constructed with `name` and `trainable`.

  Returns:
    A `ModelFn` with the model, input shape `(None, 3)` and target shape
    `(None, 2)`.
  """
  source = keras.Input(shape=(3,))
  hidden = keras.layers.Dense(4)(source)
  hidden = keras.layers.BatchNormalization()(hidden)
  result = keras.layers.Dense(2)(hidden)

  # Extra keyword arguments passed to the `keras.Model` constructor.
  model_kwargs = {'name': 'm', 'trainable': False}
  return ModelFn(
      model=keras.Model(source, result, **model_kwargs),
      input_shape=(None, 3),
      target_shape=(None, 2))
270
271
# Registry of `(name, builder)` pairs. Each name is the builder's own function
# name, and the entries keep the order in which the builders are listed here.
ALL_MODELS = [
    (model_fn.__name__, model_fn) for model_fn in (
        basic_sequential,
        basic_sequential_deferred,
        stacked_rnn,
        lstm,
        multi_input_multi_output,
        nested_sequential_in_functional,
        seq_to_seq,
        shared_layer_functional,
        shared_sequential,
        nested_subclassed_model,
        nested_subclassed_in_functional_model,
        nested_functional_in_subclassed_model,
        shared_layer_subclassed_model,
        functional_with_keyword_args,
    )
]
290
291
def get_models(exclude_models=None):
  """Returns the `ALL_MODELS` entries whose name is not excluded.

  Args:
    exclude_models: Optional container of model names to leave out. `None`
      (the default) excludes nothing.

  Returns:
    A new list of the `ALL_MODELS` entries, in their original order, skipping
    every entry whose first element is in `exclude_models`.
  """
  # `name not in None` raises TypeError, so map the default to an empty
  # container instead of testing membership against `None`.
  if exclude_models is None:
    exclude_models = ()
  return [entry for entry in ALL_MODELS if entry[0] not in exclude_models]
297