# Copyright 2016 The TensorFlow Authors. All Rights Reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
#     http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
# ==============================================================================
"""Tests for locally-connected layers."""

from __future__ import absolute_import
from __future__ import division
from __future__ import print_function

import numpy as np

from tensorflow.python import keras
from tensorflow.python.framework import test_util as tf_test_util
from tensorflow.python.keras import testing_utils
from tensorflow.python.platform import test
from tensorflow.python.training.rmsprop import RMSPropOptimizer


class LocallyConnected1DLayersTest(test.TestCase):
  # TODO(fchollet): investigate why LocallyConnected1D
  # fails inside a graph function in an eager context (fails with error
  # "Incompatible shapes between op input and calculated input gradient").

  @tf_test_util.run_deprecated_v1
  def test_locallyconnected_1d(self):
    with self.cached_session():
      num_samples = 2
      num_steps = 8
      input_dim = 5
      filter_length = 3
      filters = 4

      for padding in ['valid', 'same']:
        for strides in [1]:
          if padding == 'same' and strides != 1:
            continue
          for data_format in ['channels_first', 'channels_last']:
            for implementation in [1, 2]:
              kwargs = {
                  'filters': filters,
                  'kernel_size': filter_length,
                  'padding': padding,
                  'strides': strides,
                  'data_format': data_format,
                  'implementation': implementation
              }

              if padding == 'same' and implementation == 1:
                self.assertRaises(ValueError,
                                  keras.layers.LocallyConnected1D,
                                  **kwargs)
              else:
                testing_utils.layer_test(
                    keras.layers.LocallyConnected1D,
                    kwargs=kwargs,
                    input_shape=(num_samples, num_steps, input_dim))

  def test_locallyconnected_1d_regularization(self):
    num_samples = 2
    num_steps = 8
    input_dim = 5
    filter_length = 3
    filters = 4
    for data_format in ['channels_first', 'channels_last']:
      for padding in ['valid', 'same']:
        for implementation in [1, 2]:
          kwargs = {
              'filters': filters,
              'kernel_size': filter_length,
              'kernel_regularizer': 'l2',
              'bias_regularizer': 'l2',
              'activity_regularizer': 'l2',
              'data_format': data_format,
              'implementation': implementation,
              'padding': padding
          }

          if padding == 'same' and implementation == 1:
            self.assertRaises(ValueError,
                              keras.layers.LocallyConnected1D,
                              **kwargs)
          else:
            with self.cached_session():
              layer = keras.layers.LocallyConnected1D(**kwargs)
              layer.build((num_samples, num_steps, input_dim))
              self.assertEqual(len(layer.losses), 2)
              layer(
                  keras.backend.variable(np.ones((num_samples,
                                                  num_steps,
                                                  input_dim))))
              self.assertEqual(len(layer.losses), 3)

            k_constraint = keras.constraints.max_norm(0.01)
            b_constraint = keras.constraints.max_norm(0.01)
            kwargs = {
                'filters': filters,
                'kernel_size': filter_length,
                'kernel_constraint': k_constraint,
                'bias_constraint': b_constraint,
            }
            with self.cached_session():
              layer = keras.layers.LocallyConnected1D(**kwargs)
              layer.build((num_samples, num_steps, input_dim))
              self.assertEqual(layer.kernel.constraint, k_constraint)
              self.assertEqual(layer.bias.constraint, b_constraint)


class LocallyConnected2DLayersTest(test.TestCase):
  # TODO(fchollet): investigate why LocallyConnected2D
  # fails inside a graph function in an eager context (fails with error
  # "Incompatible shapes between op input and calculated input gradient").

  @tf_test_util.run_deprecated_v1
  def test_locallyconnected_2d(self):
    with self.cached_session():
      num_samples = 8
      filters = 3
      stack_size = 4
      num_row = 6
      num_col = 10

      for padding in ['valid', 'same']:
        for strides in [(1, 1), (2, 2)]:
          for implementation in [1, 2]:
            if padding == 'same' and strides != (1, 1):
              continue

            kwargs = {
                'filters': filters,
                'kernel_size': 3,
                'padding': padding,
                'kernel_regularizer': 'l2',
                'bias_regularizer': 'l2',
                'strides': strides,
                'data_format': 'channels_last',
                'implementation': implementation
            }

            if padding == 'same' and implementation == 1:
              self.assertRaises(ValueError,
                                keras.layers.LocallyConnected2D,
                                **kwargs)
            else:
              testing_utils.layer_test(
                  keras.layers.LocallyConnected2D,
                  kwargs=kwargs,
                  input_shape=(num_samples, num_row, num_col, stack_size))

  @tf_test_util.run_deprecated_v1
  def test_locallyconnected_2d_channels_first(self):
    with self.cached_session():
      num_samples = 8
      filters = 3
      stack_size = 4
      num_row = 6
      num_col = 10

      for implementation in [1, 2]:
        for padding in ['valid', 'same']:
          kwargs = {
              'filters': filters,
              'kernel_size': 3,
              'data_format': 'channels_first',
              'implementation': implementation,
              'padding': padding
          }

          if padding == 'same' and implementation == 1:
            self.assertRaises(ValueError,
                              keras.layers.LocallyConnected2D,
                              **kwargs)
          else:
            testing_utils.layer_test(
                keras.layers.LocallyConnected2D,
                kwargs=kwargs,
                input_shape=(num_samples, num_row, num_col, stack_size))

  def test_locallyconnected_2d_regularization(self):
    num_samples = 2
    filters = 3
    stack_size = 4
    num_row = 6
    num_col = 7
    for implementation in [1, 2]:
      for padding in ['valid', 'same']:
        kwargs = {
            'filters': filters,
            'kernel_size': 3,
            'kernel_regularizer': 'l2',
            'bias_regularizer': 'l2',
            'activity_regularizer': 'l2',
            'implementation': implementation,
            'padding': padding
        }

        if padding == 'same' and implementation == 1:
          self.assertRaises(ValueError,
                            keras.layers.LocallyConnected2D,
                            **kwargs)
        else:
          with self.cached_session():
            layer = keras.layers.LocallyConnected2D(**kwargs)
            layer.build((num_samples, num_row, num_col, stack_size))
            self.assertEqual(len(layer.losses), 2)
            layer(
                keras.backend.variable(
                    np.ones((num_samples, num_row, num_col, stack_size))))
            self.assertEqual(len(layer.losses), 3)

          k_constraint = keras.constraints.max_norm(0.01)
          b_constraint = keras.constraints.max_norm(0.01)
          kwargs = {
              'filters': filters,
              'kernel_size': 3,
              'kernel_constraint': k_constraint,
              'bias_constraint': b_constraint,
          }
          with self.cached_session():
            layer = keras.layers.LocallyConnected2D(**kwargs)
            layer.build((num_samples, num_row, num_col, stack_size))
            self.assertEqual(layer.kernel.constraint, k_constraint)
            self.assertEqual(layer.bias.constraint, b_constraint)


class LocallyConnectedImplementationModeTest(test.TestCase):
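  """Checks that locally-connected implementation modes 1 and 2 agree.

  Models built with `implementation=1` and `implementation=2` are given the
  same weights and are expected to produce matching outputs, both right after
  the weight copy and after a few training steps.
  """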

  @tf_test_util.run_deprecated_v1
  def test_locallyconnected_implementation(self):
    with self.cached_session():
      num_samples = 4
      num_classes = 3
      num_epochs = 2

      np.random.seed(1)
      targets = np.random.randint(0, num_classes, (num_samples,))

      for width in [1, 6]:
        for height in [7]:
          for filters in [2]:
            for data_format in ['channels_first', 'channels_last']:
              inputs = get_inputs(
                  data_format, filters, height, num_samples, width)

              for kernel_x in [(3,)]:
                for kernel_y in [()] if width == 1 else [(2,)]:
                  for stride_x in [(1,)]:
                    for stride_y in [()] if width == 1 else [(3,)]:
                      for layers in [2]:
                        kwargs = {
                            'layers': layers,
                            'filters': filters,
                            'kernel_size': kernel_x + kernel_y,
                            'strides': stride_x + stride_y,
                            'data_format': data_format,
                            'num_classes': num_classes
                        }
                        model_1 = get_model(implementation=1, **kwargs)
                        model_2 = get_model(implementation=2, **kwargs)

                        # Build models.
                        model_1.train_on_batch(inputs, targets)
                        model_2.train_on_batch(inputs, targets)

                        # Copy weights.
                        copy_model_weights(model_2, model_1)

                        # Compare outputs at initialization.
                        out_1 = model_1.call(inputs)
                        out_2 = model_2.call(inputs)
                        self.assertAllCloseAccordingToType(out_1, out_2,
                                                           rtol=1e-5, atol=1e-5)

                        # Train.
                        model_1.fit(x=inputs,
                                    y=targets,
                                    epochs=num_epochs,
                                    batch_size=num_samples)
                        model_2.fit(x=inputs,
                                    y=targets,
                                    epochs=num_epochs,
                                    batch_size=num_samples)

                        # Compare outputs after a few training steps.
                        out_1 = model_1.call(inputs)
                        out_2 = model_2.call(inputs)
                        self.assertAllCloseAccordingToType(
                            out_1, out_2, atol=2e-4)

  @tf_test_util.run_in_graph_and_eager_modes
  def test_make_2d(self):
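    """Checks `keras.layers.local.make_2d` against an equivalent numpy
    reshape, over a variety of input shapes and split dimensions.
    """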
    input_shapes = [
        (0,),
        (0, 0),
        (1,),
        (2,),
        (3,),
        (1, 0),
        (0, 3),
        (1, 1),
        (1, 2),
        (3, 1),
        (2, 2),
        (3, 3),
        (1, 0, 1),
        (5, 2, 3),
        (3, 5, 6, 7, 0),
        (3, 2, 2, 4, 4),
        (1, 2, 3, 4, 7, 2),
    ]
    np.random.seed(1)

    for input_shape in input_shapes:
      inputs = np.random.normal(0, 1, input_shape)
      inputs_tf = keras.backend.variable(inputs)

      split_dim = np.random.randint(0, inputs.ndim + 1)
      shape_2d = (int(np.prod(inputs.shape[:split_dim])),
                  int(np.prod(inputs.shape[split_dim:])))
      inputs_2d = np.reshape(inputs, shape_2d)

      inputs_2d_tf = keras.layers.local.make_2d(inputs_tf, split_dim)
      inputs_2d_tf = keras.backend.get_value(inputs_2d_tf)

      self.assertAllCloseAccordingToType(inputs_2d, inputs_2d_tf)


def get_inputs(data_format, filters, height, num_samples, width):
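  """Generates random normal inputs shaped for the requested data format.

  A 1D spatial shape is produced when `width == 1`, a 2D spatial shape
  otherwise.
  """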
  if data_format == 'channels_first':
    if width == 1:
      input_shape = (filters, height)
    else:
      input_shape = (filters, height, width)

  elif data_format == 'channels_last':
    if width == 1:
      input_shape = (height, filters)
    else:
      input_shape = (height, width, filters)

  else:
    raise NotImplementedError(data_format)

  inputs = np.random.normal(0, 1,
                            (num_samples,) + input_shape).astype(np.float32)
  return inputs


def xent(y_true, y_pred):
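  """Sparse softmax cross-entropy loss on integer class targets."""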
  y_true = keras.backend.cast(
      keras.backend.reshape(y_true, (-1,)),
      keras.backend.dtypes_module.int32)

  return keras.backend.nn.sparse_softmax_cross_entropy_with_logits(
      labels=y_true,
      logits=y_pred)


def get_model(implementation,
              filters,
              kernel_size,
              strides,
              layers,
              num_classes,
              data_format):
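  """Builds a compiled Sequential model of stacked locally-connected layers.

  The stack is followed by Flatten and Dense(num_classes), and the model is
  compiled with RMSProp and the sparse cross-entropy loss defined above.
  """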
  model = keras.Sequential()

  if len(kernel_size) == 1:
    lc_layer = keras.layers.LocallyConnected1D
  elif len(kernel_size) == 2:
    lc_layer = keras.layers.LocallyConnected2D
  else:
    raise NotImplementedError(kernel_size)

  for _ in range(layers):
    model.add(lc_layer(
        padding='valid',
        kernel_initializer=keras.initializers.random_normal(),
        bias_initializer=keras.initializers.random_normal(),
        filters=filters,
        strides=strides,
        kernel_size=kernel_size,
        activation=keras.activations.relu,
        data_format=data_format,
        implementation=implementation))

  model.add(keras.layers.Flatten())
  model.add(keras.layers.Dense(num_classes))
  model.compile(
      optimizer=RMSPropOptimizer(0.01),
      metrics=[keras.metrics.categorical_accuracy],
      loss=xent
  )
  return model


def copy_lc_weights(lc_layer_2_from, lc_layer_1_to):
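  """Copies implementation=2 layer weights into an implementation=1 layer.

  The implementation=2 kernel is masked with `kernel_mask`, permuted into a
  data-format-independent axis order, stripped of its zero (masked) entries,
  and reshaped to the implementation=1 kernel shape before being assigned,
  together with the bias, to the target layer.
  """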
  lc_2_kernel, lc_2_bias = lc_layer_2_from.weights
  lc_2_kernel_masked = lc_2_kernel * lc_layer_2_from.kernel_mask

  data_format = lc_layer_2_from.data_format

  if data_format == 'channels_first':
    if isinstance(lc_layer_2_from, keras.layers.LocallyConnected1D):
      permutation = (3, 0, 1, 2)
    elif isinstance(lc_layer_2_from, keras.layers.LocallyConnected2D):
      permutation = (4, 5, 0, 1, 2, 3)
    else:
      raise NotImplementedError(lc_layer_2_from)

  elif data_format == 'channels_last':
    if isinstance(lc_layer_2_from, keras.layers.LocallyConnected1D):
      permutation = (2, 0, 1, 3)
    elif isinstance(lc_layer_2_from, keras.layers.LocallyConnected2D):
      permutation = (3, 4, 0, 1, 2, 5)
    else:
      raise NotImplementedError(lc_layer_2_from)

  else:
    raise NotImplementedError(data_format)

  lc_2_kernel_masked = keras.backend.permute_dimensions(
      lc_2_kernel_masked, permutation)

  lc_2_kernel_mask = keras.backend.math_ops.not_equal(
      lc_2_kernel_masked, 0)
  lc_2_kernel_flat = keras.backend.array_ops.boolean_mask(
      lc_2_kernel_masked, lc_2_kernel_mask)
  lc_2_kernel_reshaped = keras.backend.reshape(lc_2_kernel_flat,
                                               lc_layer_1_to.kernel.shape)

  lc_2_kernel_reshaped = keras.backend.get_value(lc_2_kernel_reshaped)
  lc_2_bias = keras.backend.get_value(lc_2_bias)

  lc_layer_1_to.set_weights([lc_2_kernel_reshaped, lc_2_bias])


def copy_model_weights(model_2_from, model_1_to):
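  """Copies LocallyConnected and Dense weights between two parallel models."""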
  for l in range(len(model_2_from.layers)):
    layer_2_from = model_2_from.layers[l]
    layer_1_to = model_1_to.layers[l]

    if isinstance(layer_2_from, (keras.layers.LocallyConnected2D,
                                 keras.layers.LocallyConnected1D)):
      copy_lc_weights(layer_2_from, layer_1_to)

    elif isinstance(layer_2_from, keras.layers.Dense):
      weights_2, bias_2 = layer_2_from.weights
      weights_2 = keras.backend.get_value(weights_2)
      bias_2 = keras.backend.get_value(bias_2)
      layer_1_to.set_weights([weights_2, bias_2])

    else:
      continue


if __name__ == '__main__':
  test.main()
