# Copyright 2016 The TensorFlow Authors. All Rights Reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
#     http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
# ==============================================================================
"""Tests for normalization layers."""

from __future__ import absolute_import
from __future__ import division
from __future__ import print_function

from absl.testing import parameterized
import numpy as np

from tensorflow.python import keras
from tensorflow.python.framework import constant_op
from tensorflow.python.framework import test_util as tf_test_util
from tensorflow.python.keras import keras_parameterized
from tensorflow.python.keras import testing_utils
from tensorflow.python.keras.layers import normalization
from tensorflow.python.keras.layers import normalization_v2
from tensorflow.python.keras.mixed_precision.experimental import policy
from tensorflow.python.platform import test
from tensorflow.python.training import gradient_descent


class BatchNormalizationTest(keras_parameterized.TestCase):

  @keras_parameterized.run_all_keras_modes
  def test_basic_batchnorm(self):
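    # Exercise the layer through layer_test with regularizers, custom
    # initializers, and with scale/center disabled.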
    testing_utils.layer_test(
        keras.layers.BatchNormalization,
        kwargs={
            'momentum': 0.9,
            'epsilon': 0.1,
            'gamma_regularizer': keras.regularizers.l2(0.01),
            'beta_regularizer': keras.regularizers.l2(0.01)
        },
        input_shape=(3, 4, 2))
    testing_utils.layer_test(
        keras.layers.BatchNormalization,
        kwargs={
            'gamma_initializer': 'ones',
            'beta_initializer': 'ones',
            'moving_mean_initializer': 'zeros',
            'moving_variance_initializer': 'ones'
        },
        input_shape=(3, 4, 2))
    testing_utils.layer_test(
        keras.layers.BatchNormalization,
        kwargs={'scale': False,
                'center': False},
        input_shape=(3, 3))

  @tf_test_util.run_in_graph_and_eager_modes
  def test_batchnorm_weights(self):
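    # With scale and center disabled, only the two non-trainable moving
    # statistics remain; the defaults also add trainable gamma and beta.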
    layer = keras.layers.BatchNormalization(scale=False, center=False)
    layer.build((None, 3, 4))
    self.assertEqual(len(layer.trainable_weights), 0)
    self.assertEqual(len(layer.weights), 2)

    layer = keras.layers.BatchNormalization()
    layer.build((None, 3, 4))
    self.assertEqual(len(layer.trainable_weights), 2)
    self.assertEqual(len(layer.weights), 4)

  @tf_test_util.run_in_graph_and_eager_modes
  def test_batchnorm_regularization(self):
    layer = keras.layers.BatchNormalization(
        gamma_regularizer='l1', beta_regularizer='l1')
    layer.build((None, 3, 4))
    self.assertEqual(len(layer.losses), 2)
    max_norm = keras.constraints.max_norm
    layer = keras.layers.BatchNormalization(
        gamma_constraint=max_norm, beta_constraint=max_norm)
    layer.build((None, 3, 4))
    self.assertEqual(layer.gamma.constraint, max_norm)
    self.assertEqual(layer.beta.constraint, max_norm)

  @keras_parameterized.run_all_keras_modes
  def test_batchnorm_convnet(self):
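    # Channels-first (axis=1) batchnorm on a GPU: after fitting, undoing beta
    # and gamma should leave roughly zero mean and unit variance per channel.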
    if test.is_gpu_available(cuda_only=True):
      with self.session(use_gpu=True):
        model = keras.models.Sequential()
        norm = keras.layers.BatchNormalization(
            axis=1, input_shape=(3, 4, 4), momentum=0.8)
        model.add(norm)
        model.compile(loss='mse',
                      optimizer=gradient_descent.GradientDescentOptimizer(0.01),
                      run_eagerly=testing_utils.should_run_eagerly())

        # centered on 5.0, standard deviation 10.0
        x = np.random.normal(loc=5.0, scale=10.0, size=(1000, 3, 4, 4))
        model.fit(x, x, epochs=4, verbose=0)
        out = model.predict(x)
        out -= np.reshape(keras.backend.eval(norm.beta), (1, 3, 1, 1))
        out /= np.reshape(keras.backend.eval(norm.gamma), (1, 3, 1, 1))

        np.testing.assert_allclose(np.mean(out, axis=(0, 2, 3)), 0.0, atol=1e-1)
        np.testing.assert_allclose(np.std(out, axis=(0, 2, 3)), 1.0, atol=1e-1)

  @keras_parameterized.run_all_keras_modes
  def test_batchnorm_convnet_channel_last(self):
    model = keras.models.Sequential()
    norm = keras.layers.BatchNormalization(
        axis=-1, input_shape=(4, 4, 3), momentum=0.8)
    model.add(norm)
    model.compile(loss='mse',
                  optimizer=gradient_descent.GradientDescentOptimizer(0.01),
                  run_eagerly=testing_utils.should_run_eagerly())

    # centered on 5.0, standard deviation 10.0
    x = np.random.normal(loc=5.0, scale=10.0, size=(1000, 4, 4, 3))
    model.fit(x, x, epochs=4, verbose=0)
    out = model.predict(x)
    out -= np.reshape(keras.backend.eval(norm.beta), (1, 1, 1, 3))
    out /= np.reshape(keras.backend.eval(norm.gamma), (1, 1, 1, 3))

    np.testing.assert_allclose(np.mean(out, axis=(0, 1, 2)), 0.0, atol=1e-1)
    np.testing.assert_allclose(np.std(out, axis=(0, 1, 2)), 1.0, atol=1e-1)

  @keras_parameterized.run_all_keras_modes
  def test_batchnorm_correctness(self):
    _run_batchnorm_correctness_test(
        normalization.BatchNormalization, dtype='float32')
    _run_batchnorm_correctness_test(
        normalization_v2.BatchNormalization, dtype='float32')

  @keras_parameterized.run_all_keras_modes
  def test_batchnorm_mixed_precision(self):
    _run_batchnorm_correctness_test(
        normalization.BatchNormalization, dtype='float16')
    _run_batchnorm_correctness_test(
        normalization_v2.BatchNormalization, dtype='float16')

  @tf_test_util.run_in_graph_and_eager_modes
  def test_batchnorm_policy(self):
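    # Under the 'infer_float32_vars' mixed-precision policy, float16 inputs
    # should yield float16 outputs while beta and gamma are kept in float32.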
    norm = keras.layers.BatchNormalization(
        axis=-1,
        input_shape=(4, 4, 3),
        momentum=0.8,
        dtype=policy.Policy('infer_float32_vars'))
    x = np.random.normal(size=(10, 4, 4, 3)).astype('float16')
    y = norm(x)
    self.assertEqual(y.dtype, 'float16')
    self.assertEqual(norm.beta.dtype.base_dtype, 'float32')
    self.assertEqual(norm.gamma.dtype.base_dtype, 'float32')


class BatchNormalizationV1Test(test.TestCase):

  @tf_test_util.run_in_graph_and_eager_modes
  def test_v1_fused_attribute(self):
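    # The V1 layer defaults to fused=True; setting virtual_batch_size makes it
    # fall back to the unfused implementation when the layer is built.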
    norm = normalization.BatchNormalization()
    inp = keras.layers.Input((4, 4, 4))
    norm(inp)
    self.assertEqual(norm.fused, True)

    norm = normalization.BatchNormalization(fused=False)
    self.assertEqual(norm.fused, False)
    inp = keras.layers.Input(shape=(4, 4, 4))
    norm(inp)
    self.assertEqual(norm.fused, False)

    norm = normalization.BatchNormalization(virtual_batch_size=2)
    self.assertEqual(norm.fused, True)
    inp = keras.layers.Input(shape=(2, 2, 2))
    norm(inp)
    self.assertEqual(norm.fused, False)


class BatchNormalizationV2Test(keras_parameterized.TestCase):

  @keras_parameterized.run_all_keras_modes
  def test_basic_batchnorm_v2(self):
    testing_utils.layer_test(
        normalization_v2.BatchNormalization,
        kwargs={'fused': True},
        input_shape=(3, 3, 3, 3))
    testing_utils.layer_test(
        normalization_v2.BatchNormalization,
        kwargs={'fused': None},
        input_shape=(3, 3, 3))

  @tf_test_util.run_in_graph_and_eager_modes
  def test_v2_fused_attribute(self):
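    # The V2 layer leaves fused=None until build time, then enables the fused
    # implementation only for 4D inputs, and rejects fused=True when combined
    # with incompatible arguments.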
    norm = normalization_v2.BatchNormalization()
    self.assertEqual(norm.fused, None)
    inp = keras.layers.Input(shape=(4, 4, 4))
    norm(inp)
    self.assertEqual(norm.fused, True)

    norm = normalization_v2.BatchNormalization()
    self.assertEqual(norm.fused, None)
    inp = keras.layers.Input(shape=(4, 4))
    norm(inp)
    self.assertEqual(norm.fused, False)

    norm = normalization_v2.BatchNormalization(virtual_batch_size=2)
    self.assertEqual(norm.fused, False)
    inp = keras.layers.Input(shape=(4, 4, 4))
    norm(inp)
    self.assertEqual(norm.fused, False)

    norm = normalization_v2.BatchNormalization(fused=False)
    self.assertEqual(norm.fused, False)
    inp = keras.layers.Input(shape=(4, 4, 4))
    norm(inp)
    self.assertEqual(norm.fused, False)

    norm = normalization_v2.BatchNormalization(fused=True, axis=[3])
    self.assertEqual(norm.fused, True)
    inp = keras.layers.Input(shape=(4, 4, 4))
    norm(inp)
    self.assertEqual(norm.fused, True)

    with self.assertRaisesRegexp(ValueError, 'fused.*renorm'):
      normalization_v2.BatchNormalization(fused=True, renorm=True)

    with self.assertRaisesRegexp(ValueError, 'fused.*when axis is 1 or 3'):
      normalization_v2.BatchNormalization(fused=True, axis=2)

    with self.assertRaisesRegexp(ValueError, 'fused.*when axis is 1 or 3'):
      normalization_v2.BatchNormalization(fused=True, axis=[1, 3])

    with self.assertRaisesRegexp(ValueError, 'fused.*virtual_batch_size'):
      normalization_v2.BatchNormalization(fused=True, virtual_batch_size=2)

    with self.assertRaisesRegexp(ValueError, 'fused.*adjustment'):
      normalization_v2.BatchNormalization(fused=True,
                                          adjustment=lambda _: (1, 0))

    norm = normalization_v2.BatchNormalization(fused=True)
    self.assertEqual(norm.fused, True)
    inp = keras.layers.Input(shape=(4, 4))
    with self.assertRaisesRegexp(ValueError, '4D input tensors'):
      norm(inp)


def _run_batchnorm_correctness_test(layer, dtype='float32', fused=False):
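  """Fits a one-layer batchnorm model and checks the normalized output.

  The inputs are drawn with mean 5.0 and standard deviation 10.0; after
  training, undoing beta and gamma should leave predictions with roughly
  zero mean and unit variance.
  """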
  model = keras.models.Sequential()
  model.add(keras.Input(shape=(2, 2, 2), dtype=dtype))
  norm = layer(momentum=0.8, fused=fused)
  model.add(norm)
  if dtype == 'float16':
    # Keras models require float32 losses.
    model.add(keras.layers.Lambda(lambda x: keras.backend.cast(x, 'float32')))
  model.compile(loss='mse',
                optimizer=gradient_descent.GradientDescentOptimizer(0.01),
                run_eagerly=testing_utils.should_run_eagerly())

  # centered on 5.0, standard deviation 10.0
  x = (np.random.normal(loc=5.0, scale=10.0, size=(1000, 2, 2, 2))
       .astype(dtype))
  model.fit(x, x, epochs=4, verbose=0)
  out = model.predict(x)
  out -= keras.backend.eval(norm.beta)
  out /= keras.backend.eval(norm.gamma)

  np.testing.assert_allclose(out.mean(), 0.0, atol=1e-1)
  np.testing.assert_allclose(out.std(), 1.0, atol=1e-1)


@parameterized.parameters(
    [normalization.BatchNormalization, normalization_v2.BatchNormalization])
class NormalizationLayersGraphModeOnlyTest(
    test.TestCase, parameterized.TestCase):

  def test_shared_batchnorm(self, layer):
    """Test that a BN layer can be shared across different data streams."""
    with self.cached_session():
      # Test single layer reuse
      bn = layer()
      x1 = keras.layers.Input(shape=(10,))
      _ = bn(x1)

      x2 = keras.layers.Input(shape=(10,))
      y2 = bn(x2)

      x = np.random.normal(loc=5.0, scale=10.0, size=(2, 10))
      model = keras.models.Model(x2, y2)

      model.compile(gradient_descent.GradientDescentOptimizer(0.01), 'mse')
      model.train_on_batch(x, x)

      self.assertEqual(len(bn.updates), 4)
      self.assertEqual(len(model.updates), 2)
      self.assertEqual(len(model.get_updates_for(x2)), 2)

      # Test model-level reuse
      x3 = keras.layers.Input(shape=(10,))
      y3 = model(x3)
      new_model = keras.models.Model(x3, y3, name='new_model')

      self.assertEqual(len(new_model.updates), 2)
      self.assertEqual(len(model.updates), 4)
      self.assertEqual(len(new_model.get_updates_for(x3)), 2)
      new_model.compile(gradient_descent.GradientDescentOptimizer(0.01), 'mse')
      new_model.train_on_batch(x, x)

  def test_that_trainable_disables_updates(self, layer):
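    # Freezing the model (or just the layer) must drop the moving-statistics
    # updates, so training steps leave subsequent predictions unchanged.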
    with self.cached_session():
      val_a = np.random.random((10, 4))
      val_out = np.random.random((10, 4))

      a = keras.layers.Input(shape=(4,))
      layer = layer(input_shape=(4,))
      b = layer(a)
      model = keras.models.Model(a, b)

      model.trainable = False
      assert not model.updates

      model.compile(gradient_descent.GradientDescentOptimizer(0.01), 'mse')
      assert not model.updates

      x1 = model.predict(val_a)
      model.train_on_batch(val_a, val_out)
      x2 = model.predict(val_a)
      self.assertAllClose(x1, x2, atol=1e-7)

      model.trainable = True
      model.compile(gradient_descent.GradientDescentOptimizer(0.01), 'mse')
      assert model.updates

      model.train_on_batch(val_a, val_out)
      x2 = model.predict(val_a)
      assert np.abs(np.sum(x1 - x2)) > 1e-5

      layer.trainable = False
      model.compile(gradient_descent.GradientDescentOptimizer(0.01), 'mse')
      assert not model.updates

      x1 = model.predict(val_a)
      model.train_on_batch(val_a, val_out)
      x2 = model.predict(val_a)
      self.assertAllClose(x1, x2, atol=1e-7)

  @tf_test_util.run_deprecated_v1
  def test_batchnorm_trainable(self, layer):
    """Tests that batchnorm layer is trainable when learning phase is enabled.

    Computes mean and std for current inputs then
    applies batch normalization using them.

    Args:
      layer: Either V1 or V2 of BatchNormalization layer.
    """
    # TODO(fchollet): enable in all execution modes when issue with
    # learning phase setting is resolved.
    with self.cached_session():
      bn_mean = 0.5
      bn_std = 10.
      val_a = np.expand_dims(np.arange(10.), axis=1)

      def get_model(bn_mean, bn_std):
        inp = keras.layers.Input(shape=(1,))
        x = layer()(inp)
        model1 = keras.models.Model(inp, x)
        model1.set_weights([
            np.array([1.]),
            np.array([0.]),
            np.array([bn_mean]),
            np.array([bn_std**2])
        ])
        return model1

      # Simulates training-mode with trainable layer.
      # Should use mini-batch statistics.
      with keras.backend.learning_phase_scope(1):
        model = get_model(bn_mean, bn_std)
        model.compile(loss='mse', optimizer='rmsprop')
        out = model.predict(val_a)
        self.assertAllClose(
            (val_a - np.mean(val_a)) / np.std(val_a), out, atol=1e-3)


def _run_layernorm_correctness_test(layer, dtype='float32'):
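  """Fits a one-layer LayerNormalization model and checks the output.

  As in the batchnorm helper above, undoing beta and gamma should leave
  predictions with roughly zero mean and unit variance.
  """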
  model = keras.models.Sequential()
  norm = layer(input_shape=(2, 2, 2))
  model.add(norm)
  model.compile(loss='mse',
                optimizer=gradient_descent.GradientDescentOptimizer(0.01),
                run_eagerly=testing_utils.should_run_eagerly())

  # centered on 5.0, standard deviation 10.0
  x = (np.random.normal(loc=5.0, scale=10.0, size=(1000, 2, 2, 2))
       .astype(dtype))
  model.fit(x, x, epochs=4, verbose=0)
  out = model.predict(x)
  out -= keras.backend.eval(norm.beta)
  out /= keras.backend.eval(norm.gamma)

  np.testing.assert_allclose(out.mean(), 0.0, atol=1e-1)
  np.testing.assert_allclose(out.std(), 1.0, atol=1e-1)


class LayerNormalizationTest(keras_parameterized.TestCase):

  @keras_parameterized.run_all_keras_modes
  def test_basic_layernorm(self):
    testing_utils.layer_test(
        keras.layers.LayerNormalization,
        kwargs={
            'gamma_regularizer': keras.regularizers.l2(0.01),
            'beta_regularizer': keras.regularizers.l2(0.01)
        },
        input_shape=(3, 4, 2))
    testing_utils.layer_test(
        keras.layers.LayerNormalization,
        kwargs={
            'gamma_initializer': 'ones',
            'beta_initializer': 'ones',
        },
        input_shape=(3, 4, 2))
    testing_utils.layer_test(
        keras.layers.LayerNormalization,
        kwargs={'scale': False,
                'center': False},
        input_shape=(3, 3))

  @tf_test_util.run_in_graph_and_eager_modes
  def test_layernorm_weights(self):
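    # LayerNormalization keeps no moving statistics: with scale and center
    # disabled it has no weights at all; the defaults add only gamma and beta.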
    layer = keras.layers.LayerNormalization(scale=False, center=False)
    layer.build((None, 3, 4))
    self.assertEqual(len(layer.trainable_weights), 0)
    self.assertEqual(len(layer.weights), 0)

    layer = keras.layers.LayerNormalization()
    layer.build((None, 3, 4))
    self.assertEqual(len(layer.trainable_weights), 2)
    self.assertEqual(len(layer.weights), 2)

  @tf_test_util.run_in_graph_and_eager_modes
  def test_layernorm_regularization(self):
    layer = keras.layers.LayerNormalization(
        gamma_regularizer='l1', beta_regularizer='l1')
    layer.build((None, 3, 4))
    self.assertEqual(len(layer.losses), 2)
    max_norm = keras.constraints.max_norm
    layer = keras.layers.LayerNormalization(
        gamma_constraint=max_norm, beta_constraint=max_norm)
    layer.build((None, 3, 4))
    self.assertEqual(layer.gamma.constraint, max_norm)
    self.assertEqual(layer.beta.constraint, max_norm)

  @keras_parameterized.run_all_keras_modes
  def test_layernorm_convnet(self):
    if test.is_gpu_available(cuda_only=True):
      with self.session(use_gpu=True):
        model = keras.models.Sequential()
        norm = keras.layers.LayerNormalization(
            input_shape=(3, 4, 4), params_axis=1)
        model.add(norm)
        model.compile(loss='mse',
                      optimizer=gradient_descent.GradientDescentOptimizer(0.01),
                      run_eagerly=testing_utils.should_run_eagerly())

        # centered on 5.0, standard deviation 10.0
        x = np.random.normal(loc=5.0, scale=10.0, size=(1000, 3, 4, 4))
        model.fit(x, x, epochs=4, verbose=0)
        out = model.predict(x)
        out -= np.reshape(keras.backend.eval(norm.beta), (1, 3, 1, 1))
        out /= np.reshape(keras.backend.eval(norm.gamma), (1, 3, 1, 1))

        np.testing.assert_allclose(np.mean(out, axis=(0, 2, 3)), 0.0, atol=1e-1)
        np.testing.assert_allclose(np.std(out, axis=(0, 2, 3)), 1.0, atol=1e-1)

  @keras_parameterized.run_all_keras_modes
  def test_layernorm_convnet_channel_last(self):
    model = keras.models.Sequential()
    norm = keras.layers.LayerNormalization(input_shape=(4, 4, 3))
    model.add(norm)
    model.compile(loss='mse',
                  optimizer=gradient_descent.GradientDescentOptimizer(0.01),
                  run_eagerly=testing_utils.should_run_eagerly())

    # centered on 5.0, standard deviation 10.0
    x = np.random.normal(loc=5.0, scale=10.0, size=(1000, 4, 4, 3))
    model.fit(x, x, epochs=4, verbose=0)
    out = model.predict(x)
    out -= np.reshape(keras.backend.eval(norm.beta), (1, 1, 1, 3))
    out /= np.reshape(keras.backend.eval(norm.gamma), (1, 1, 1, 3))

    np.testing.assert_allclose(np.mean(out, axis=(0, 1, 2)), 0.0, atol=1e-1)
    np.testing.assert_allclose(np.std(out, axis=(0, 1, 2)), 1.0, atol=1e-1)

  @keras_parameterized.run_all_keras_modes
  def test_layernorm_correctness(self):
    _run_layernorm_correctness_test(
        normalization.LayerNormalization, dtype='float32')

  @keras_parameterized.run_all_keras_modes
  def test_layernorm_mixed_precision(self):
    _run_layernorm_correctness_test(
        normalization.LayerNormalization, dtype='float16')

  def doOutputTest(self,
                   input_shape,
                   tol=1e-5,
                   norm_axis=None,
                   params_axis=-1,
                   dtype=None):
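    """Checks LayerNormalization output statistics against a numpy reference.

    Normalizes random inputs over `norm_axis` and verifies that the output
    has roughly zero mean and unit variance along those axes and matches a
    numpy reimplementation of layer normalization.
    """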
    ndim = len(input_shape)
    if norm_axis is None:
      moments_axis = range(1, ndim)
    elif isinstance(norm_axis, int):
      if norm_axis < 0:
        moments_axis = [norm_axis + ndim]
      else:
        moments_axis = [norm_axis]
    else:
      moments_axis = []
      for dim in norm_axis:
        if dim < 0:
          dim = dim + ndim
        moments_axis.append(dim)

    moments_axis = tuple(moments_axis)
    expected_shape = []
    for i in range(ndim):
      if i not in moments_axis:
        expected_shape.append(input_shape[i])

    expected_mean = np.zeros(expected_shape)
    expected_var = np.ones(expected_shape)
    for mu in [0.0, 1e2]:
      for sigma in [1.0, 0.1]:
        inputs = np.random.randn(*input_shape) * sigma + mu
        inputs_t = constant_op.constant(inputs, shape=input_shape)
        layer = normalization.LayerNormalization(
            norm_axis=norm_axis, params_axis=params_axis, dtype=dtype)
        outputs = layer(inputs_t)
        beta = layer.beta
        gamma = layer.gamma
        for weight in layer.weights:
          self.evaluate(weight.initializer)
        outputs = self.evaluate(outputs)
        beta = self.evaluate(beta)
        gamma = self.evaluate(gamma)

        # The mean and variance of the output should be close to 0 and 1
        # respectively.

        # Make sure that there are no NaNs
        self.assertFalse(np.isnan(outputs).any())
        mean = np.mean(outputs, axis=moments_axis)
        var = np.var(outputs, axis=moments_axis)
        # Layer-norm implemented in numpy
        eps = 1e-12
        expected_out = (
            (gamma * (inputs - np.mean(
                inputs, axis=moments_axis, keepdims=True)) /
             np.sqrt(eps + np.var(
                 inputs, axis=moments_axis, keepdims=True))) + beta)
        self.assertAllClose(expected_mean, mean, atol=tol, rtol=tol)
        self.assertAllClose(expected_var, var, atol=tol)
        # The full computation gets a bigger tolerance
        self.assertAllClose(expected_out, outputs, atol=5 * tol)

  @tf_test_util.run_in_graph_and_eager_modes
  def testOutput2DInput(self):
    self.doOutputTest((10, 300))
    self.doOutputTest((10, 300), norm_axis=[0])
    self.doOutputTest((10, 300), params_axis=[0, 1])

  @tf_test_util.run_in_graph_and_eager_modes
  def testOutput2DInputDegenerateNormAxis(self):
    with self.assertRaisesRegexp(ValueError, r'Invalid axis: 2'):
      self.doOutputTest((10, 300), norm_axis=2)

  @tf_test_util.run_in_graph_and_eager_modes
  def testOutput4DInput(self):
    self.doOutputTest((100, 10, 10, 3))

  @tf_test_util.run_in_graph_and_eager_modes
  def testOutput4DInputNormOnInnermostAxis(self):
    # Equivalent tests
    shape = (100, 10, 10, 3)
    self.doOutputTest(
        shape, norm_axis=list(range(3, len(shape))), tol=1e-4, dtype='float64')
    self.doOutputTest(shape, norm_axis=-1, tol=1e-4, dtype='float64')

  @tf_test_util.run_in_graph_and_eager_modes
  def testOutputSmallInput(self):
    self.doOutputTest((10, 10, 10, 30))

  @tf_test_util.run_in_graph_and_eager_modes
  def testOutputSmallInputNormOnInnermostAxis(self):
    self.doOutputTest((10, 10, 10, 30), norm_axis=3)

  @tf_test_util.run_in_graph_and_eager_modes
  def testOutputSmallInputNormOnMixedAxes(self):
    self.doOutputTest((10, 10, 10, 30), norm_axis=[0, 3])
    self.doOutputTest((10, 10, 10, 30), params_axis=[-2, -1])
    self.doOutputTest((10, 10, 10, 30), norm_axis=[0, 3],
                      params_axis=[-3, -2, -1])

  @tf_test_util.run_in_graph_and_eager_modes
  def testOutputBigInput(self):
    self.doOutputTest((1, 100, 100, 1))
    self.doOutputTest((1, 100, 100, 1), norm_axis=[1, 2])
    self.doOutputTest((1, 100, 100, 1), norm_axis=[1, 2],
                      params_axis=[-2, -1])


if __name__ == '__main__':
  test.main()