1# Copyright 2016 The TensorFlow Authors. All Rights Reserved.
2#
3# Licensed under the Apache License, Version 2.0 (the "License");
4# you may not use this file except in compliance with the License.
5# You may obtain a copy of the License at
6#
7#     http://www.apache.org/licenses/LICENSE-2.0
8#
9# Unless required by applicable law or agreed to in writing, software
10# distributed under the License is distributed on an "AS IS" BASIS,
11# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
12# See the License for the specific language governing permissions and
13# limitations under the License.
14# ==============================================================================
15"""Tests for convolutional recurrent layers."""
16
17from __future__ import absolute_import
18from __future__ import division
19from __future__ import print_function
20
21from absl.testing import parameterized
22import numpy as np
23
24from tensorflow.python import keras
25from tensorflow.python.framework import test_util
26from tensorflow.python.keras import keras_parameterized
27from tensorflow.python.keras import testing_utils
28from tensorflow.python.platform import test
29
30
31@keras_parameterized.run_all_keras_modes
32class ConvLSTMTest(keras_parameterized.TestCase):
33
34  @parameterized.named_parameters(
35      *test_util.generate_combinations_with_testcase_name(
36          data_format=['channels_first', 'channels_last'],
37          return_sequences=[True, False]))
38  def test_conv_lstm(self, data_format, return_sequences):
39    num_row = 3
40    num_col = 3
41    filters = 2
42    num_samples = 1
43    input_channel = 2
44    input_num_row = 5
45    input_num_col = 5
46    sequence_len = 2
47    if data_format == 'channels_first':
48      inputs = np.random.rand(num_samples, sequence_len,
49                              input_channel,
50                              input_num_row, input_num_col)
51    else:
52      inputs = np.random.rand(num_samples, sequence_len,
53                              input_num_row, input_num_col,
54                              input_channel)
55
56    # test for return state:
57    x = keras.Input(batch_shape=inputs.shape)
58    kwargs = {'data_format': data_format,
59              'return_sequences': return_sequences,
60              'return_state': True,
61              'stateful': True,
62              'filters': filters,
63              'kernel_size': (num_row, num_col),
64              'padding': 'valid'}
65    layer = keras.layers.ConvLSTM2D(**kwargs)
66    layer.build(inputs.shape)
67    outputs = layer(x)
68    _, states = outputs[0], outputs[1:]
69    self.assertEqual(len(states), 2)
70    model = keras.models.Model(x, states[0])
71    state = model.predict(inputs)
72
73    self.assertAllClose(
74        keras.backend.eval(layer.states[0]), state, atol=1e-4)
75
76    # test for output shape:
77    testing_utils.layer_test(
78        keras.layers.ConvLSTM2D,
79        kwargs={'data_format': data_format,
80                'return_sequences': return_sequences,
81                'filters': filters,
82                'kernel_size': (num_row, num_col),
83                'padding': 'valid'},
84        input_shape=inputs.shape)
85
86  def test_conv_lstm_statefulness(self):
87    # Tests for statefulness
88    num_row = 3
89    num_col = 3
90    filters = 2
91    num_samples = 1
92    input_channel = 2
93    input_num_row = 5
94    input_num_col = 5
95    sequence_len = 2
96    inputs = np.random.rand(num_samples, sequence_len,
97                            input_num_row, input_num_col,
98                            input_channel)
99
100    with self.cached_session():
101      model = keras.models.Sequential()
102      kwargs = {'data_format': 'channels_last',
103                'return_sequences': False,
104                'filters': filters,
105                'kernel_size': (num_row, num_col),
106                'stateful': True,
107                'batch_input_shape': inputs.shape,
108                'padding': 'same'}
109      layer = keras.layers.ConvLSTM2D(**kwargs)
110
111      model.add(layer)
112      model.compile(optimizer='sgd', loss='mse')
113      out1 = model.predict(np.ones_like(inputs))
114
115      # train once so that the states change
116      model.train_on_batch(np.ones_like(inputs),
117                           np.random.random(out1.shape))
118      out2 = model.predict(np.ones_like(inputs))
119
120      # if the state is not reset, output should be different
121      self.assertNotEqual(out1.max(), out2.max())
122
123      # check that output changes after states are reset
124      # (even though the model itself didn't change)
125      layer.reset_states()
126      out3 = model.predict(np.ones_like(inputs))
127      self.assertNotEqual(out3.max(), out2.max())
128
129      # check that container-level reset_states() works
130      model.reset_states()
131      out4 = model.predict(np.ones_like(inputs))
132      self.assertAllClose(out3, out4, atol=1e-5)
133
134      # check that the call to `predict` updated the states
135      out5 = model.predict(np.ones_like(inputs))
136      self.assertNotEqual(out4.max(), out5.max())
137
138  def test_conv_lstm_regularizers(self):
139    # check regularizers
140    num_row = 3
141    num_col = 3
142    filters = 2
143    num_samples = 1
144    input_channel = 2
145    input_num_row = 5
146    input_num_col = 5
147    sequence_len = 2
148    inputs = np.random.rand(num_samples, sequence_len,
149                            input_num_row, input_num_col,
150                            input_channel)
151
152    with self.cached_session():
153      kwargs = {'data_format': 'channels_last',
154                'return_sequences': False,
155                'kernel_size': (num_row, num_col),
156                'stateful': True,
157                'filters': filters,
158                'batch_input_shape': inputs.shape,
159                'kernel_regularizer': keras.regularizers.L1L2(l1=0.01),
160                'recurrent_regularizer': keras.regularizers.L1L2(l1=0.01),
161                'activity_regularizer': 'l2',
162                'bias_regularizer': 'l2',
163                'kernel_constraint': 'max_norm',
164                'recurrent_constraint': 'max_norm',
165                'bias_constraint': 'max_norm',
166                'padding': 'same'}
167
168      layer = keras.layers.ConvLSTM2D(**kwargs)
169      layer.build(inputs.shape)
170      self.assertEqual(len(layer.losses), 3)
171      layer(keras.backend.variable(np.ones(inputs.shape)))
172      self.assertEqual(len(layer.losses), 4)
173
174  def test_conv_lstm_dropout(self):
175    # check dropout
176    with self.cached_session():
177      testing_utils.layer_test(
178          keras.layers.ConvLSTM2D,
179          kwargs={'data_format': 'channels_last',
180                  'return_sequences': False,
181                  'filters': 2,
182                  'kernel_size': (3, 3),
183                  'padding': 'same',
184                  'dropout': 0.1,
185                  'recurrent_dropout': 0.1},
186          input_shape=(1, 2, 5, 5, 2))
187
188  def test_conv_lstm_cloning(self):
189    with self.cached_session():
190      model = keras.models.Sequential()
191      model.add(keras.layers.ConvLSTM2D(5, 3, input_shape=(None, 5, 5, 3)))
192
193      test_inputs = np.random.random((2, 4, 5, 5, 3))
194      reference_outputs = model.predict(test_inputs)
195      weights = model.get_weights()
196
197    # Use a new graph to clone the model
198    with self.cached_session():
199      clone = keras.models.clone_model(model)
200      clone.set_weights(weights)
201
202      outputs = clone.predict(test_inputs)
203      self.assertAllClose(reference_outputs, outputs, atol=1e-5)
204
205
206if __name__ == '__main__':
207  test.main()
208