1# Copyright 2018 The TensorFlow Authors. All Rights Reserved.
2#
3# Licensed under the Apache License, Version 2.0 (the "License");
4# you may not use this file except in compliance with the License.
5# You may obtain a copy of the License at
6#
7#     http://www.apache.org/licenses/LICENSE-2.0
8#
9# Unless required by applicable law or agreed to in writing, software
10# distributed under the License is distributed on an "AS IS" BASIS,
11# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
12# See the License for the specific language governing permissions and
13# limitations under the License.
14# ==============================================================================
15"""Tests specific to `Sequential` model."""
16
17from __future__ import absolute_import
18from __future__ import division
19from __future__ import print_function
20
21from absl.testing import parameterized
22import numpy as np
23
24from tensorflow.python import keras
25from tensorflow.python.data.ops import dataset_ops
26from tensorflow.python.eager import context
27from tensorflow.python.eager import function
28from tensorflow.python.framework import test_util as tf_test_util
29from tensorflow.python.keras import keras_parameterized
30from tensorflow.python.keras import testing_utils
31from tensorflow.python.ops import array_ops
32from tensorflow.python.platform import test
33
34
class TestSequential(keras_parameterized.TestCase):
  """Tests for behavior specific to the `Sequential` model.

  Most Sequential model API tests are covered in `training_test.py`.
  """

  @keras_parameterized.run_all_keras_modes
  def test_basic_methods(self):
    """Checks layer count, weight count and `get_layer` lookup by name."""
    model = keras.models.Sequential()
    model.add(keras.layers.Dense(1, input_dim=2))
    model.add(keras.layers.Dropout(0.3, name='dp'))
    model.add(keras.layers.Dense(2, kernel_regularizer='l2',
                                 kernel_constraint='max_norm'))
    self.assertLen(model.layers, 3)
    # Two Dense layers, each expected to own a kernel and a bias.
    self.assertLen(model.weights, 2 * 2)
    self.assertEqual(model.get_layer(name='dp').name, 'dp')

  @keras_parameterized.run_all_keras_modes
  def test_input_defined_first_layer(self):
    """Checks that a leading `keras.Input` is not counted in `layers`."""
    model = keras.models.Sequential()
    model.add(keras.Input(shape=(2,), name='input_layer'))
    model.add(keras.layers.Dense(1))
    model.add(keras.layers.Dropout(0.3, name='dp'))
    model.add(keras.layers.Dense(2, kernel_regularizer='l2',
                                 kernel_constraint='max_norm'))
    # Four objects were added, but only three are reported as layers.
    self.assertLen(model.layers, 3)
    self.assertLen(model.weights, 2 * 2)
    self.assertEqual(model.get_layer(name='dp').name, 'dp')

  @keras_parameterized.run_all_keras_modes
  def test_sequential_pop(self):
    """Checks `pop()` on a trained model, a one-layer model and an empty one."""
    num_hidden = 5
    input_dim = 3
    batch_size = 5
    num_classes = 2

    model = testing_utils.get_small_sequential_mlp(
        num_hidden, num_classes, input_dim)
    model.compile(
        loss='mse',
        optimizer='rmsprop',
        run_eagerly=testing_utils.should_run_eagerly())
    x = np.random.random((batch_size, input_dim))
    y = np.random.random((batch_size, num_classes))
    model.fit(x, y, epochs=1)

    # After popping the last layer the hidden layer becomes the output, so the
    # model must be recompiled and trained against `num_hidden`-wide targets.
    model.pop()
    self.assertLen(model.layers, 1)
    self.assertEqual(model.output_shape, (None, num_hidden))
    model.compile(
        loss='mse',
        optimizer='rmsprop',
        run_eagerly=testing_utils.should_run_eagerly())
    y = np.random.random((batch_size, num_hidden))
    model.fit(x, y, epochs=1)

    # Test popping single-layer model.
    model = keras.models.Sequential()
    model.add(keras.layers.Dense(num_hidden, input_dim=input_dim))
    model.pop()
    self.assertEqual(model.layers, [])
    self.assertIsNone(model.outputs)

    # Invalid use case: popping from a model that has no layers.
    model = keras.models.Sequential()
    with self.assertRaises(TypeError):
      model.pop()

  @keras_parameterized.run_all_keras_modes
  def test_sequential_deferred_build_with_np_arrays(self):
    """Checks that a model without input shape is built by `fit` on arrays."""
    num_hidden = 5
    input_dim = 3
    batch_size = 5
    num_classes = 2

    # No `input_dim` is passed, so the model starts out unbuilt.
    model = testing_utils.get_small_sequential_mlp(num_hidden, num_classes)
    model.compile(
        loss='mse',
        optimizer='rmsprop',
        metrics=[keras.metrics.CategoricalAccuracy()],
        run_eagerly=testing_utils.should_run_eagerly())
    self.assertLen(model.layers, 2)
    self.assertLen(model.weights, 0)
    self.assertFalse(model.built)

    x = np.random.random((batch_size, input_dim))
    y = np.random.random((batch_size, num_classes))
    model.fit(x, y, epochs=1)
    self.assertTrue(model.built)
    self.assertFalse(model._is_graph_network)
    self.assertLen(model.weights, 2 * 2)

  @keras_parameterized.run_all_keras_modes
  def test_sequential_deferred_build_with_dataset_iterators(self):
    """Checks that a model without input shape is built by `fit` on an iterator."""
    num_hidden = 5
    input_dim = 3
    num_classes = 2
    num_samples = 50
    steps_per_epoch = 10

    # No `input_dim` is passed, so the model starts out unbuilt.
    model = testing_utils.get_small_sequential_mlp(num_hidden, num_classes)
    model.compile(
        loss='mse',
        optimizer='rmsprop',
        metrics=[keras.metrics.CategoricalAccuracy()],
        run_eagerly=testing_utils.should_run_eagerly())
    self.assertLen(model.layers, 2)
    self.assertLen(model.weights, 0)
    self.assertFalse(model.built)

    x = array_ops.ones((num_samples, input_dim))
    y = array_ops.zeros((num_samples, num_classes))
    dataset = dataset_ops.Dataset.from_tensor_slices((x, y))
    dataset = dataset.repeat(100)
    dataset = dataset.batch(10)
    iterator = dataset_ops.make_one_shot_iterator(dataset)

    model.fit(iterator, epochs=1, steps_per_epoch=steps_per_epoch)
    self.assertTrue(model.built)
    self.assertLen(model.weights, 2 * 2)
    self.assertFalse(model._is_graph_network)

  # TODO(kaftan): This test fails w/ run_with_all_keras_modes. File ticket.
  @parameterized.parameters((True,), (False,))
  @tf_test_util.run_deprecated_v1
  def test_training_and_eval_methods_on_symbolic_tensors(self, deferred):
    """Runs the training/eval/predict entry points on symbolic tensors.

    Args:
      deferred: If True, the model is created without `input_dim`; otherwise
        it is created with `input_dim=3`.
    """
    with self.cached_session():

      def get_model():
        """Returns a freshly compiled model for each entry point under test."""
        if deferred:
          model = testing_utils.get_small_sequential_mlp(10, 4)
        else:
          model = testing_utils.get_small_sequential_mlp(10, 4, input_dim=3)
        model.compile(
            optimizer='rmsprop',
            loss='categorical_crossentropy',
            metrics=['accuracy'])
        return model

      inputs = keras.backend.zeros(shape=(10, 3))
      targets = keras.backend.zeros(shape=(10, 4))

      model = get_model()
      model.fit(inputs, targets, epochs=10, steps_per_epoch=30)

      model = get_model()
      model.evaluate(inputs, targets, steps=2, verbose=0)

      model = get_model()
      model.predict(inputs, steps=2)

      model = get_model()
      model.train_on_batch(inputs, targets)

      model = get_model()
      model.test_on_batch(inputs, targets)

      model = get_model()
      model.fit(
          inputs,
          targets,
          epochs=1,
          steps_per_epoch=2,
          verbose=0,
          validation_data=(inputs, targets),
          validation_steps=2)

  @keras_parameterized.run_all_keras_modes
  def test_invalid_use_cases(self):
    """Checks the errors raised when adding invalid objects to the model.

    Model construction and valid `add` calls are kept outside the
    `assertRaises` contexts so that only the call expected to fail can
    satisfy each assertion.
    """
    # Added objects must be layer instances.
    model = keras.models.Sequential()
    with self.assertRaises(TypeError):
      model.add(None)

    # Added layers cannot have multiple outputs.
    class MyLayer(keras.layers.Layer):
      """Layer returning two outputs for a single input."""

      def call(self, inputs):
        return [3 * inputs, 2 * inputs]

      def compute_output_shape(self, input_shape):
        return [input_shape, input_shape]

    # Multi-output layer as the first layer, with a known input shape.
    model = keras.models.Sequential()
    with self.assertRaises(ValueError):
      model.add(MyLayer(input_shape=(3,)))

    # Multi-output layer appended to a model that already has a layer.
    model = keras.models.Sequential()
    model.add(keras.layers.Dense(1, input_dim=1))
    with self.assertRaises(TypeError):
      model.add(MyLayer())

  @keras_parameterized.run_all_keras_modes
  def test_nested_sequential_trainability(self):
    """Checks that toggling `trainable` on a nested model affects the outer one."""
    input_dim = 20
    num_units = 10
    num_classes = 2

    inner_model = keras.models.Sequential()
    inner_model.add(keras.layers.Dense(num_units, input_shape=(input_dim,)))

    model = keras.models.Sequential()
    model.add(inner_model)
    model.add(keras.layers.Dense(num_classes))

    self.assertLen(model.layers, 2)

    self.assertLen(model.trainable_weights, 4)
    # Freezing the inner model removes its weights from the outer model's
    # trainable set; unfreezing restores them.
    inner_model.trainable = False
    self.assertLen(model.trainable_weights, 2)
    inner_model.trainable = True
    self.assertLen(model.trainable_weights, 4)

  def test_sequential_update_disabling(self):
    """Checks that `model.updates` follows the model's `trainable` flag."""
    val_a = np.random.random((10, 4))
    val_out = np.random.random((10, 4))

    with self.cached_session():
      model = keras.models.Sequential()
      model.add(keras.layers.BatchNormalization(input_shape=(4,)))
      self.assertTrue(model.updates)

      model.trainable = False
      self.assertFalse(model.updates)

      model.compile('sgd', 'mse')
      self.assertFalse(model.updates)

      # With the model frozen, a training step must not change predictions.
      x1 = model.predict(val_a)
      model.train_on_batch(val_a, val_out)
      x2 = model.predict(val_a)
      self.assertAllClose(x1, x2, atol=1e-7)

      model.trainable = True
      model.compile('sgd', 'mse')
      self.assertTrue(model.updates)

      # Once trainable again, a training step must change predictions.
      model.train_on_batch(val_a, val_out)
      x2 = model.predict(val_a)
      self.assertGreater(np.abs(np.sum(x1 - x2)), 1e-5)

  @keras_parameterized.run_all_keras_modes
  def test_sequential_deferred_build_serialization(self):
    """Checks that a deferred-build model round-trips through its config."""
    num_hidden = 5
    input_dim = 3
    batch_size = 5
    num_classes = 2

    model = testing_utils.get_small_sequential_mlp(num_hidden, num_classes)
    model.compile(
        loss='mse',
        optimizer='rmsprop',
        metrics=[keras.metrics.CategoricalAccuracy()],
        run_eagerly=testing_utils.should_run_eagerly())
    self.assertFalse(model.built)

    x = np.random.random((batch_size, input_dim))
    y = np.random.random((batch_size, num_classes))
    model.train_on_batch(x, y)
    self.assertTrue(model.built)

    config = model.get_config()
    self.assertIn('build_input_shape', config)

    # The restored model is expected to already have its weights created.
    new_model = keras.models.Sequential.from_config(config)
    self.assertLen(new_model.layers, 2)
    self.assertLen(new_model.weights, 4)

  @keras_parameterized.run_all_keras_modes
  def test_sequential_shape_inference_deferred(self):
    """Checks `compute_output_shape` on a model created without input shape."""
    model = testing_utils.get_small_sequential_mlp(4, 5)
    output_shape = model.compute_output_shape((None, 7))
    self.assertEqual(tuple(output_shape.as_list()), (None, 5))

  @keras_parameterized.run_all_keras_modes
  def test_sequential_build_deferred(self):
    """Checks explicit `build()` on plain and nested deferred models."""
    model = testing_utils.get_small_sequential_mlp(4, 5)

    model.build((None, 10))
    self.assertTrue(model.built)
    self.assertLen(model.weights, 4)

    # Test with nested model.
    model = testing_utils.get_small_sequential_mlp(4, 3)
    inner_model = testing_utils.get_small_sequential_mlp(4, 5)
    model.add(inner_model)

    model.build((None, 10))
    self.assertTrue(model.built)
    self.assertLen(model.weights, 8)

  @keras_parameterized.run_all_keras_modes(always_skip_v1=True)
  def test_sequential_deferred_manual_build(self):
    """Checks that calling the model builds it but leaves `outputs` empty.

    `outputs` is expected to stay empty after the call and after `compile`,
    and to hold one entry only after `train_on_batch`.
    """
    model = testing_utils.get_small_sequential_mlp(4, 5)
    self.assertFalse(model.built)
    model(array_ops.zeros([1, 2]))
    self.assertTrue(model.built)
    self.assertLen(model.outputs, 0)
    model.compile('rmsprop',
                  loss='mse',
                  run_eagerly=testing_utils.should_run_eagerly())
    self.assertLen(model.outputs, 0)
    model.train_on_batch(np.zeros((1, 2)), np.zeros((1, 5)))
    self.assertLen(model.outputs, 1)

  @keras_parameterized.run_all_keras_modes
  def test_sequential_nesting(self):
    """Checks that a model with a nested Sequential model can be trained."""
    model = testing_utils.get_small_sequential_mlp(4, 3)
    inner_model = testing_utils.get_small_sequential_mlp(4, 5)
    model.add(inner_model)

    model.compile(
        loss='mse',
        optimizer='rmsprop',
        run_eagerly=testing_utils.should_run_eagerly())
    x = np.random.random((2, 6))
    y = np.random.random((2, 5))
    model.fit(x, y, epochs=1)

  @keras_parameterized.run_all_keras_modes
  def test_variable_names(self):
    """Checks the variable names produced when the model is first called."""
    model = keras.models.Sequential([keras.layers.Dense(3)])
    model.add(keras.layers.Dense(2))
    model(array_ops.ones([2, 4]))
    self.assertEqual(
        ['sequential/dense/kernel:0', 'sequential/dense/bias:0',
         'sequential/dense_1/kernel:0', 'sequential/dense_1/bias:0'],
        [v.name for v in model.variables])

  @keras_parameterized.run_all_keras_modes
  def test_input_assumptions_propagation(self):
    """Checks that a scalar input is rejected with the layer's ndim error.

    The check only runs when executing eagerly; in other modes this test
    makes no assertion.
    """
    model = keras.models.Sequential()
    model.add(keras.layers.Dense(1))
    if context.executing_eagerly():
      with self.assertRaisesRegexp(ValueError,
                                   'expected min_ndim=2, found ndim=0'):
        model(1.0)
368
369
class TestSequentialEagerIntegration(keras_parameterized.TestCase):
  """Integration-style tests that train `Sequential` models end to end."""

  @keras_parameterized.run_all_keras_modes
  def test_defun_on_call(self):
    """Trains a `Sequential` subclass whose `call` is wrapped in a `defun`."""

    class MySequential(keras.Sequential):
      """`Sequential` subclass that replaces `call` with a defun'd version."""

      def __init__(self, name=None):
        super(MySequential, self).__init__(name=name)
        self.call = function.defun(self.call)

    net = MySequential()
    stacked_layers = (
        keras.layers.Dense(4, activation='relu'),
        keras.layers.Dense(5, activation='softmax'),
    )
    for layer in stacked_layers:
      net.add(layer)

    net.compile(
        optimizer='rmsprop',
        loss='mse',
        run_eagerly=testing_utils.should_run_eagerly())

    features = np.random.random((2, 6))
    targets = np.random.random((2, 5))
    net.fit(features, targets, epochs=1)

  @keras_parameterized.run_all_keras_modes
  def test_build_before_fit(self):
    """Trains a model that was explicitly built after being compiled.

    Regression test for b/112433577.
    """
    net = testing_utils.get_small_sequential_mlp(4, 5)
    net.compile(
        optimizer='rmsprop',
        loss='mse',
        run_eagerly=testing_utils.should_run_eagerly())

    build_shape = (None, 6)
    net.build(build_shape)

    features = np.random.random((2, 6))
    targets = np.random.random((2, 5))
    net.fit(features, targets, epochs=1)

  @keras_parameterized.run_all_keras_modes
  def test_sequential_model_fails_with_dict_inputs(self):
    """Checks that `fit` rejects a dictionary input with a `ValueError`."""
    num_classes = 5
    net = testing_utils.get_small_sequential_mlp(
        num_hidden=10, num_classes=num_classes)
    net.compile(
        optimizer='rmsprop',
        loss='categorical_crossentropy',
        metrics=['acc'],
        weighted_metrics=['mae'],
        run_eagerly=testing_utils.should_run_eagerly())

    dict_features = {'dense_input': np.random.random((10, 1))}
    labels = np.random.randint(num_classes, size=(10, 1))

    expected_error = (
        'Passing a dictionary input to a Sequential Model which '
        'doesn\'t have FeatureLayer as the first layer is an error')
    with self.assertRaisesRegexp(ValueError, expected_error):
      net.fit(dict_features, labels, batch_size=5, epochs=1)
429
430
# Run the test cases in this module when the file is executed as a script.
if __name__ == '__main__':
  test.main()
433