1# Copyright 2016 The TensorFlow Authors. All Rights Reserved. 2# 3# Licensed under the Apache License, Version 2.0 (the "License"); 4# you may not use this file except in compliance with the License. 5# You may obtain a copy of the License at 6# 7# http://www.apache.org/licenses/LICENSE-2.0 8# 9# Unless required by applicable law or agreed to in writing, software 10# distributed under the License is distributed on an "AS IS" BASIS, 11# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. 12# See the License for the specific language governing permissions and 13# limitations under the License. 14# ============================================================================== 15"""Tests for locally-connected layers.""" 16 17from __future__ import absolute_import 18from __future__ import division 19from __future__ import print_function 20 21import numpy as np 22 23from tensorflow.python import keras 24from tensorflow.python.framework import test_util as tf_test_util 25from tensorflow.python.keras import testing_utils 26from tensorflow.python.platform import test 27from tensorflow.python.training.rmsprop import RMSPropOptimizer 28 29 30class LocallyConnected1DLayersTest(test.TestCase): 31 # TODO(fchollet): investigate why LocallyConnected1D 32 # fails inside a graph function in an eager context (fails with error 33 # "Incompatible shapes between op input and calculated input gradient"). 34 35 @tf_test_util.run_deprecated_v1 36 def test_locallyconnected_1d(self): 37 with self.cached_session(): 38 num_samples = 2 39 num_steps = 8 40 input_dim = 5 41 filter_length = 3 42 filters = 4 43 44 for padding in ['valid', 'same']: 45 for strides in [1]: 46 if padding == 'same' and strides != 1: 47 continue 48 for data_format in ['channels_first', 'channels_last']: 49 for implementation in [1, 2]: 50 kwargs = { 51 'filters': filters, 52 'kernel_size': filter_length, 53 'padding': padding, 54 'strides': strides, 55 'data_format': data_format, 56 'implementation': implementation 57 } 58 59 if padding == 'same' and implementation == 1: 60 self.assertRaises(ValueError, 61 keras.layers.LocallyConnected1D, 62 **kwargs) 63 else: 64 testing_utils.layer_test( 65 keras.layers.LocallyConnected1D, 66 kwargs=kwargs, 67 input_shape=(num_samples, num_steps, input_dim)) 68 69 def test_locallyconnected_1d_regularization(self): 70 num_samples = 2 71 num_steps = 8 72 input_dim = 5 73 filter_length = 3 74 filters = 4 75 for data_format in ['channels_first', 'channels_last']: 76 for padding in ['valid', 'same']: 77 for implementation in [1, 2]: 78 kwargs = { 79 'filters': filters, 80 'kernel_size': filter_length, 81 'kernel_regularizer': 'l2', 82 'bias_regularizer': 'l2', 83 'activity_regularizer': 'l2', 84 'data_format': data_format, 85 'implementation': implementation, 86 'padding': padding 87 } 88 89 if padding == 'same' and implementation == 1: 90 self.assertRaises(ValueError, 91 keras.layers.LocallyConnected1D, 92 **kwargs) 93 else: 94 with self.cached_session(): 95 layer = keras.layers.LocallyConnected1D(**kwargs) 96 layer.build((num_samples, num_steps, input_dim)) 97 self.assertEqual(len(layer.losses), 2) 98 layer( 99 keras.backend.variable(np.ones((num_samples, 100 num_steps, 101 input_dim)))) 102 self.assertEqual(len(layer.losses), 3) 103 104 k_constraint = keras.constraints.max_norm(0.01) 105 b_constraint = keras.constraints.max_norm(0.01) 106 kwargs = { 107 'filters': filters, 108 'kernel_size': filter_length, 109 'kernel_constraint': k_constraint, 110 'bias_constraint': b_constraint, 111 } 112 with self.cached_session(): 113 layer = keras.layers.LocallyConnected1D(**kwargs) 114 layer.build((num_samples, num_steps, input_dim)) 115 self.assertEqual(layer.kernel.constraint, k_constraint) 116 self.assertEqual(layer.bias.constraint, b_constraint) 117 118 119class LocallyConnected2DLayersTest(test.TestCase): 120 # TODO(fchollet): investigate why LocallyConnected2D 121 # fails inside a graph function in an eager context (fails with error 122 # "Incompatible shapes between op input and calculated input gradient"). 123 124 @tf_test_util.run_deprecated_v1 125 def test_locallyconnected_2d(self): 126 with self.cached_session(): 127 num_samples = 8 128 filters = 3 129 stack_size = 4 130 num_row = 6 131 num_col = 10 132 133 for padding in ['valid', 'same']: 134 for strides in [(1, 1), (2, 2)]: 135 for implementation in [1, 2]: 136 if padding == 'same' and strides != (1, 1): 137 continue 138 139 kwargs = { 140 'filters': filters, 141 'kernel_size': 3, 142 'padding': padding, 143 'kernel_regularizer': 'l2', 144 'bias_regularizer': 'l2', 145 'strides': strides, 146 'data_format': 'channels_last', 147 'implementation': implementation 148 } 149 150 if padding == 'same' and implementation == 1: 151 self.assertRaises(ValueError, 152 keras.layers.LocallyConnected2D, 153 **kwargs) 154 else: 155 testing_utils.layer_test( 156 keras.layers.LocallyConnected2D, 157 kwargs=kwargs, 158 input_shape=(num_samples, num_row, num_col, stack_size)) 159 160 @tf_test_util.run_deprecated_v1 161 def test_locallyconnected_2d_channels_first(self): 162 with self.cached_session(): 163 num_samples = 8 164 filters = 3 165 stack_size = 4 166 num_row = 6 167 num_col = 10 168 169 for implementation in [1, 2]: 170 for padding in ['valid', 'same']: 171 kwargs = { 172 'filters': filters, 173 'kernel_size': 3, 174 'data_format': 'channels_first', 175 'implementation': implementation, 176 'padding': padding 177 } 178 179 if padding == 'same' and implementation == 1: 180 self.assertRaises(ValueError, 181 keras.layers.LocallyConnected2D, 182 **kwargs) 183 else: 184 testing_utils.layer_test( 185 keras.layers.LocallyConnected2D, 186 kwargs=kwargs, 187 input_shape=(num_samples, num_row, num_col, stack_size)) 188 189 def test_locallyconnected_2d_regularization(self): 190 num_samples = 2 191 filters = 3 192 stack_size = 4 193 num_row = 6 194 num_col = 7 195 for implementation in [1, 2]: 196 for padding in ['valid', 'same']: 197 kwargs = { 198 'filters': filters, 199 'kernel_size': 3, 200 'kernel_regularizer': 'l2', 201 'bias_regularizer': 'l2', 202 'activity_regularizer': 'l2', 203 'implementation': implementation, 204 'padding': padding 205 } 206 207 if padding == 'same' and implementation == 1: 208 self.assertRaises(ValueError, 209 keras.layers.LocallyConnected2D, 210 **kwargs) 211 else: 212 with self.cached_session(): 213 layer = keras.layers.LocallyConnected2D(**kwargs) 214 layer.build((num_samples, num_row, num_col, stack_size)) 215 self.assertEqual(len(layer.losses), 2) 216 layer( 217 keras.backend.variable( 218 np.ones((num_samples, num_row, num_col, stack_size)))) 219 self.assertEqual(len(layer.losses), 3) 220 221 k_constraint = keras.constraints.max_norm(0.01) 222 b_constraint = keras.constraints.max_norm(0.01) 223 kwargs = { 224 'filters': filters, 225 'kernel_size': 3, 226 'kernel_constraint': k_constraint, 227 'bias_constraint': b_constraint, 228 } 229 with self.cached_session(): 230 layer = keras.layers.LocallyConnected2D(**kwargs) 231 layer.build((num_samples, num_row, num_col, stack_size)) 232 self.assertEqual(layer.kernel.constraint, k_constraint) 233 self.assertEqual(layer.bias.constraint, b_constraint) 234 235 236class LocallyConnectedImplementationModeTest(test.TestCase): 237 238 @tf_test_util.run_deprecated_v1 239 def test_locallyconnected_implementation(self): 240 with self.cached_session(): 241 num_samples = 4 242 num_classes = 3 243 num_epochs = 2 244 245 np.random.seed(1) 246 targets = np.random.randint(0, num_classes, (num_samples,)) 247 248 for width in [1, 6]: 249 for height in [7]: 250 for filters in [2]: 251 for data_format in ['channels_first', 'channels_last']: 252 inputs = get_inputs( 253 data_format, filters, height, num_samples, width) 254 255 for kernel_x in [(3,)]: 256 for kernel_y in [()] if width == 1 else [(2,)]: 257 for stride_x in [(1,)]: 258 for stride_y in [()] if width == 1 else [(3,)]: 259 for layers in [2]: 260 kwargs = { 261 'layers': layers, 262 'filters': filters, 263 'kernel_size': kernel_x + kernel_y, 264 'strides': stride_x + stride_y, 265 'data_format': data_format, 266 'num_classes': num_classes 267 } 268 model_1 = get_model(implementation=1, **kwargs) 269 model_2 = get_model(implementation=2, **kwargs) 270 271 # Build models. 272 model_1.train_on_batch(inputs, targets) 273 model_2.train_on_batch(inputs, targets) 274 275 # Copy weights. 276 copy_model_weights(model_2, model_1) 277 278 # Compare outputs at initialization. 279 out_1 = model_1.call(inputs) 280 out_2 = model_2.call(inputs) 281 self.assertAllCloseAccordingToType(out_1, out_2, 282 rtol=1e-5, atol=1e-5) 283 284 # Train. 285 model_1.fit(x=inputs, 286 y=targets, 287 epochs=num_epochs, 288 batch_size=num_samples) 289 model_2.fit(x=inputs, 290 y=targets, 291 epochs=num_epochs, 292 batch_size=num_samples) 293 294 # Compare outputs after a few training steps. 295 out_1 = model_1.call(inputs) 296 out_2 = model_2.call(inputs) 297 self.assertAllCloseAccordingToType( 298 out_1, out_2, atol=2e-4) 299 300 @tf_test_util.run_in_graph_and_eager_modes 301 def test_make_2d(self): 302 input_shapes = [ 303 (0,), 304 (0, 0), 305 (1,), 306 (2,), 307 (3,), 308 (1, 0), 309 (0, 3), 310 (1, 1), 311 (1, 2), 312 (3, 1), 313 (2, 2), 314 (3, 3), 315 (1, 0, 1), 316 (5, 2, 3), 317 (3, 5, 6, 7, 0), 318 (3, 2, 2, 4, 4), 319 (1, 2, 3, 4, 7, 2), 320 ] 321 np.random.seed(1) 322 323 for input_shape in input_shapes: 324 inputs = np.random.normal(0, 1, input_shape) 325 inputs_tf = keras.backend.variable(inputs) 326 327 split_dim = np.random.randint(0, inputs.ndim + 1) 328 shape_2d = (int(np.prod(inputs.shape[:split_dim])), 329 int(np.prod(inputs.shape[split_dim:]))) 330 inputs_2d = np.reshape(inputs, shape_2d) 331 332 inputs_2d_tf = keras.layers.local.make_2d(inputs_tf, split_dim) 333 inputs_2d_tf = keras.backend.get_value(inputs_2d_tf) 334 335 self.assertAllCloseAccordingToType(inputs_2d, inputs_2d_tf) 336 337 338def get_inputs(data_format, filters, height, num_samples, width): 339 if data_format == 'channels_first': 340 if width == 1: 341 input_shape = (filters, height) 342 else: 343 input_shape = (filters, height, width) 344 345 elif data_format == 'channels_last': 346 if width == 1: 347 input_shape = (height, filters) 348 else: 349 input_shape = (height, width, filters) 350 351 else: 352 raise NotImplementedError(data_format) 353 354 inputs = np.random.normal(0, 1, 355 (num_samples,) + input_shape).astype(np.float32) 356 return inputs 357 358 359def xent(y_true, y_pred): 360 y_true = keras.backend.cast( 361 keras.backend.reshape(y_true, (-1,)), 362 keras.backend.dtypes_module.int32) 363 364 return keras.backend.nn.sparse_softmax_cross_entropy_with_logits( 365 labels=y_true, 366 logits=y_pred) 367 368 369def get_model(implementation, 370 filters, 371 kernel_size, 372 strides, 373 layers, 374 num_classes, 375 data_format): 376 model = keras.Sequential() 377 378 if len(kernel_size) == 1: 379 lc_layer = keras.layers.LocallyConnected1D 380 elif len(kernel_size) == 2: 381 lc_layer = keras.layers.LocallyConnected2D 382 else: 383 raise NotImplementedError(kernel_size) 384 385 for _ in range(layers): 386 model.add(lc_layer( 387 padding='valid', 388 kernel_initializer=keras.initializers.random_normal(), 389 bias_initializer=keras.initializers.random_normal(), 390 filters=filters, 391 strides=strides, 392 kernel_size=kernel_size, 393 activation=keras.activations.relu, 394 data_format=data_format, 395 implementation=implementation)) 396 397 model.add(keras.layers.Flatten()) 398 model.add(keras.layers.Dense(num_classes)) 399 model.compile( 400 optimizer=RMSPropOptimizer(0.01), 401 metrics=[keras.metrics.categorical_accuracy], 402 loss=xent 403 ) 404 return model 405 406 407def copy_lc_weights(lc_layer_2_from, lc_layer_1_to): 408 lc_2_kernel, lc_2_bias = lc_layer_2_from.weights 409 lc_2_kernel_masked = lc_2_kernel * lc_layer_2_from.kernel_mask 410 411 data_format = lc_layer_2_from.data_format 412 413 if data_format == 'channels_first': 414 if isinstance(lc_layer_2_from, keras.layers.LocallyConnected1D): 415 permutation = (3, 0, 1, 2) 416 elif isinstance(lc_layer_2_from, keras.layers.LocallyConnected2D): 417 permutation = (4, 5, 0, 1, 2, 3) 418 else: 419 raise NotImplementedError(lc_layer_2_from) 420 421 elif data_format == 'channels_last': 422 if isinstance(lc_layer_2_from, keras.layers.LocallyConnected1D): 423 permutation = (2, 0, 1, 3) 424 elif isinstance(lc_layer_2_from, keras.layers.LocallyConnected2D): 425 permutation = (3, 4, 0, 1, 2, 5) 426 else: 427 raise NotImplementedError(lc_layer_2_from) 428 429 else: 430 raise NotImplementedError(data_format) 431 432 lc_2_kernel_masked = keras.backend.permute_dimensions( 433 lc_2_kernel_masked, permutation) 434 435 lc_2_kernel_mask = keras.backend.math_ops.not_equal( 436 lc_2_kernel_masked, 0) 437 lc_2_kernel_flat = keras.backend.array_ops.boolean_mask( 438 lc_2_kernel_masked, lc_2_kernel_mask) 439 lc_2_kernel_reshaped = keras.backend.reshape(lc_2_kernel_flat, 440 lc_layer_1_to.kernel.shape) 441 442 lc_2_kernel_reshaped = keras.backend.get_value(lc_2_kernel_reshaped) 443 lc_2_bias = keras.backend.get_value(lc_2_bias) 444 445 lc_layer_1_to.set_weights([lc_2_kernel_reshaped, lc_2_bias]) 446 447 448def copy_model_weights(model_2_from, model_1_to): 449 for l in range(len(model_2_from.layers)): 450 layer_2_from = model_2_from.layers[l] 451 layer_1_to = model_1_to.layers[l] 452 453 if isinstance(layer_2_from, (keras.layers.LocallyConnected2D, 454 keras.layers.LocallyConnected1D)): 455 copy_lc_weights(layer_2_from, layer_1_to) 456 457 elif isinstance(layer_2_from, keras.layers.Dense): 458 weights_2, bias_2 = layer_2_from.weights 459 weights_2 = keras.backend.get_value(weights_2) 460 bias_2 = keras.backend.get_value(bias_2) 461 layer_1_to.set_weights([weights_2, bias_2]) 462 463 else: 464 continue 465 466 467if __name__ == '__main__': 468 test.main() 469