# Copyright 2016 The TensorFlow Authors. All Rights Reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
#     http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
# ==============================================================================
"""Tests for convolutional recurrent layers."""

from __future__ import absolute_import
from __future__ import division
from __future__ import print_function

from absl.testing import parameterized
import numpy as np

from tensorflow.python import keras
from tensorflow.python.framework import test_util
from tensorflow.python.keras import keras_parameterized
from tensorflow.python.keras import testing_utils
from tensorflow.python.platform import test


# Geometry shared by the ConvLSTM2D tests below. These were previously
# re-declared as locals in every test; keeping them in one place ensures the
# tests exercise the same configuration.
_KERNEL_SIZE = (3, 3)
_FILTERS = 2
_NUM_SAMPLES = 1
_INPUT_CHANNEL = 2
_INPUT_NUM_ROW = 5
_INPUT_NUM_COL = 5
_SEQUENCE_LEN = 2


def _random_inputs(data_format='channels_last'):
  """Returns a uniform-random 5-D input batch for the shared test geometry.

  Args:
    data_format: 'channels_first' gives shape
      (samples, time, channels, rows, cols); any other value gives
      (samples, time, rows, cols, channels).

  Returns:
    A float64 numpy array drawn with `np.random.rand`.
  """
  if data_format == 'channels_first':
    shape = (_NUM_SAMPLES, _SEQUENCE_LEN, _INPUT_CHANNEL,
             _INPUT_NUM_ROW, _INPUT_NUM_COL)
  else:
    shape = (_NUM_SAMPLES, _SEQUENCE_LEN, _INPUT_NUM_ROW,
             _INPUT_NUM_COL, _INPUT_CHANNEL)
  return np.random.rand(*shape)


@keras_parameterized.run_all_keras_modes
class ConvLSTMTest(keras_parameterized.TestCase):
  """Tests for `keras.layers.ConvLSTM2D`."""

  @parameterized.named_parameters(
      *test_util.generate_combinations_with_testcase_name(
          data_format=['channels_first', 'channels_last'],
          return_sequences=[True, False]))
  def test_conv_lstm(self, data_format, return_sequences):
    """Checks `return_state` outputs and output shapes for each data format."""
    inputs = _random_inputs(data_format)

    # test for return state:
    x = keras.Input(batch_shape=inputs.shape)
    kwargs = {'data_format': data_format,
              'return_sequences': return_sequences,
              'return_state': True,
              'stateful': True,
              'filters': _FILTERS,
              'kernel_size': _KERNEL_SIZE,
              'padding': 'valid'}
    layer = keras.layers.ConvLSTM2D(**kwargs)
    layer.build(inputs.shape)
    outputs = layer(x)
    _, states = outputs[0], outputs[1:]
    # With return_state=True the layer emits two state tensors after the
    # main output.
    self.assertEqual(len(states), 2)
    model = keras.models.Model(x, states[0])
    state = model.predict(inputs)

    # The layer is stateful, so the first state returned by `predict` should
    # match the value the layer keeps in `layer.states[0]`.
    self.assertAllClose(
        keras.backend.eval(layer.states[0]), state, atol=1e-4)

    # test for output shape:
    testing_utils.layer_test(
        keras.layers.ConvLSTM2D,
        kwargs={'data_format': data_format,
                'return_sequences': return_sequences,
                'filters': _FILTERS,
                'kernel_size': _KERNEL_SIZE,
                'padding': 'valid'},
        input_shape=inputs.shape)

  def test_conv_lstm_statefulness(self):
    """Checks that layer state persists across calls and resets correctly."""
    inputs = _random_inputs()
    # Only the shape of `inputs` matters here; feed constant ones so that any
    # change in output comes from the carried state or the weights.
    ones_input = np.ones_like(inputs)

    with self.cached_session():
      model = keras.models.Sequential()
      kwargs = {'data_format': 'channels_last',
                'return_sequences': False,
                'filters': _FILTERS,
                'kernel_size': _KERNEL_SIZE,
                'stateful': True,
                'batch_input_shape': inputs.shape,
                'padding': 'same'}
      layer = keras.layers.ConvLSTM2D(**kwargs)

      model.add(layer)
      model.compile(optimizer='sgd', loss='mse')
      out1 = model.predict(ones_input)

      # train once so that the states change
      model.train_on_batch(ones_input,
                           np.random.random(out1.shape))
      out2 = model.predict(ones_input)

      # if the state is not reset, output should be different
      self.assertNotEqual(out1.max(), out2.max())

      # check that output changes after states are reset
      # (even though the model itself didn't change)
      layer.reset_states()
      out3 = model.predict(ones_input)
      self.assertNotEqual(out3.max(), out2.max())

      # check that container-level reset_states() works
      model.reset_states()
      out4 = model.predict(ones_input)
      self.assertAllClose(out3, out4, atol=1e-5)

      # check that the call to `predict` updated the states
      out5 = model.predict(ones_input)
      self.assertNotEqual(out4.max(), out5.max())

  def test_conv_lstm_regularizers(self):
    """Checks that regularizers register the expected number of losses."""
    inputs = _random_inputs()

    with self.cached_session():
      kwargs = {'data_format': 'channels_last',
                'return_sequences': False,
                'kernel_size': _KERNEL_SIZE,
                'stateful': True,
                'filters': _FILTERS,
                'batch_input_shape': inputs.shape,
                'kernel_regularizer': keras.regularizers.L1L2(l1=0.01),
                'recurrent_regularizer': keras.regularizers.L1L2(l1=0.01),
                'activity_regularizer': 'l2',
                'bias_regularizer': 'l2',
                'kernel_constraint': 'max_norm',
                'recurrent_constraint': 'max_norm',
                'bias_constraint': 'max_norm',
                'padding': 'same'}

      layer = keras.layers.ConvLSTM2D(**kwargs)
      layer.build(inputs.shape)
      # The three weight regularizers (kernel, recurrent, bias) contribute
      # losses once the weights exist...
      self.assertEqual(len(layer.losses), 3)
      layer(keras.backend.variable(np.ones(inputs.shape)))
      # ...and the activity regularizer adds one more once the layer is called.
      self.assertEqual(len(layer.losses), 4)

  def test_conv_lstm_dropout(self):
    """Checks that input and recurrent dropout build and run."""
    with self.cached_session():
      testing_utils.layer_test(
          keras.layers.ConvLSTM2D,
          kwargs={'data_format': 'channels_last',
                  'return_sequences': False,
                  'filters': _FILTERS,
                  'kernel_size': _KERNEL_SIZE,
                  'padding': 'same',
                  'dropout': 0.1,
                  'recurrent_dropout': 0.1},
          input_shape=(_NUM_SAMPLES, _SEQUENCE_LEN, _INPUT_NUM_ROW,
                       _INPUT_NUM_COL, _INPUT_CHANNEL))

  def test_conv_lstm_cloning(self):
    """Checks that a cloned model with copied weights gives the same outputs."""
    with self.cached_session():
      model = keras.models.Sequential()
      model.add(keras.layers.ConvLSTM2D(5, 3, input_shape=(None, 5, 5, 3)))

      test_inputs = np.random.random((2, 4, 5, 5, 3))
      reference_outputs = model.predict(test_inputs)
      weights = model.get_weights()

    # Clone in a second `cached_session()` block. Note that `cached_session`
    # reuses the test's cached session rather than creating a new graph, so
    # this verifies that `clone_model` + `set_weights` reproduces the
    # original outputs; it does not test cloning across graphs.
    with self.cached_session():
      clone = keras.models.clone_model(model)
      clone.set_weights(weights)

      outputs = clone.predict(test_inputs)
      self.assertAllClose(reference_outputs, outputs, atol=1e-5)


if __name__ == '__main__':
  test.main()