1# Copyright 2016 The TensorFlow Authors. All Rights Reserved.
2#
3# Licensed under the Apache License, Version 2.0 (the "License");
4# you may not use this file except in compliance with the License.
5# You may obtain a copy of the License at
6#
7#     http://www.apache.org/licenses/LICENSE-2.0
8#
9# Unless required by applicable law or agreed to in writing, software
10# distributed under the License is distributed on an "AS IS" BASIS,
11# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
12# See the License for the specific language governing permissions and
13# limitations under the License.
14# ==============================================================================
15"""Tests for Keras initializers."""
16
17from __future__ import absolute_import
18from __future__ import division
19from __future__ import print_function
20
21import numpy as np
22
23from tensorflow.python.keras import backend
24from tensorflow.python.keras import combinations
25from tensorflow.python.keras import initializers
26from tensorflow.python.keras import models
27from tensorflow.python.keras import testing_utils
28from tensorflow.python.keras.engine import input_layer
29from tensorflow.python.keras.layers import core
30from tensorflow.python.ops import array_ops
31from tensorflow.python.platform import test
32
33
34def _compute_fans(shape):
35  """Computes the number of input and output units for a weight shape.
36
37  Args:
38    shape: Integer shape tuple or TF tensor shape.
39
40  Returns:
41    A tuple of integer scalars (fan_in, fan_out).
42  """
43  if len(shape) < 1:  # Just to avoid errors for constants.
44    fan_in = fan_out = 1
45  elif len(shape) == 1:
46    fan_in = fan_out = shape[0]
47  elif len(shape) == 2:
48    fan_in = shape[0]
49    fan_out = shape[1]
50  else:
51    # Assuming convolution kernels (2D, 3D, or more).
52    # kernel shape: (..., input_depth, depth)
53    receptive_field_size = 1
54    for dim in shape[:-2]:
55      receptive_field_size *= dim
56    fan_in = shape[-2] * receptive_field_size
57    fan_out = shape[-1] * receptive_field_size
58  return int(fan_in), int(fan_out)
59
60
@combinations.generate(combinations.combine(mode=['graph', 'eager']))
class KerasInitializersTest(test.TestCase):
  """Tests for the V2 Keras initializers, run in graph and eager modes."""

  def _runner(self, init, shape, target_mean=None, target_std=None,
              target_max=None, target_min=None):
    """Checks an initializer's output statistics and its config round trip.

    Args:
      init: Initializer instance under test.
      shape: Shape of the tensor to generate.
      target_mean: Optional expected mean of the generated values.
      target_std: Optional expected standard deviation of the generated values.
      target_max: Optional upper bound that no generated value may exceed.
      target_min: Optional lower bound that no generated value may fall below.
    """
    variable = backend.variable(init(shape))
    output = backend.get_value(variable)

    # Mean and standard deviation are compared with a tolerance of at least
    # `lim`, widened to six standard errors of the sample mean so that small
    # tensors do not fail on ordinary sampling noise.
    lim = 3e-2
    tolerance = max(lim, 6. * float(output.std()) / np.sqrt(output.size))
    if target_mean is not None:
      self.assertLess(abs(float(output.mean()) - target_mean), tolerance)
    if target_std is not None:
      self.assertLess(abs(float(output.std()) - target_std), tolerance)

    # The extremes are hard bounds; `eps` only absorbs float rounding.
    eps = 1e-6
    if target_max is not None:
      self.assertLessEqual(float(output.max()), target_max + eps)
    if target_min is not None:
      self.assertGreaterEqual(float(output.min()), target_min - eps)

    # Test serialization (assumes deterministic behavior).
    config = init.get_config()
    reconstructed_init = init.__class__.from_config(config)
    variable = backend.variable(reconstructed_init(shape))
    output_2 = backend.get_value(variable)
    self.assertAllClose(output, output_2, atol=1e-4)

  def test_uniform(self):
    """RandomUniformV2 on [-1, 1] with a fixed seed."""
    tensor_shape = (9, 6, 7)
    with self.cached_session():
      self._runner(
          initializers.RandomUniformV2(minval=-1, maxval=1, seed=124),
          tensor_shape,
          target_mean=0.,
          target_max=1,
          target_min=-1)

  def test_normal(self):
    """RandomNormalV2 with mean 0 and stddev 1."""
    tensor_shape = (8, 12, 99)
    with self.cached_session():
      self._runner(
          initializers.RandomNormalV2(mean=0, stddev=1, seed=153),
          tensor_shape,
          target_mean=0.,
          target_std=1)

  def test_truncated_normal(self):
    """TruncatedNormalV2 with mean 0 and stddev 1, bounded by [-2, 2]."""
    tensor_shape = (12, 99, 7)
    with self.cached_session():
      self._runner(
          initializers.TruncatedNormalV2(mean=0, stddev=1, seed=126),
          tensor_shape,
          target_mean=0.,
          target_max=2,
          target_min=-2)

  def test_constant(self):
    """ConstantV2 fills the tensor with the requested value."""
    tensor_shape = (5, 6, 4)
    with self.cached_session():
      self._runner(
          initializers.ConstantV2(2.),
          tensor_shape,
          target_mean=2,
          target_max=2,
          target_min=2)

  def test_lecun_uniform(self):
    """LecunUniformV2 against a target std of sqrt(1 / fan_in)."""
    tensor_shape = (5, 6, 4, 2)
    with self.cached_session():
      fan_in, _ = _compute_fans(tensor_shape)
      std = np.sqrt(1. / fan_in)
      self._runner(
          initializers.LecunUniformV2(seed=123),
          tensor_shape,
          target_mean=0.,
          target_std=std)

  def test_glorot_uniform(self):
    """GlorotUniformV2 against a target std of sqrt(2 / (fan_in + fan_out))."""
    tensor_shape = (5, 6, 4, 2)
    with self.cached_session():
      fan_in, fan_out = _compute_fans(tensor_shape)
      std = np.sqrt(2. / (fan_in + fan_out))
      self._runner(
          initializers.GlorotUniformV2(seed=123),
          tensor_shape,
          target_mean=0.,
          target_std=std)

  def test_he_uniform(self):
    """HeUniformV2 against a target std of sqrt(2 / fan_in)."""
    tensor_shape = (5, 6, 4, 2)
    with self.cached_session():
      fan_in, _ = _compute_fans(tensor_shape)
      std = np.sqrt(2. / fan_in)
      self._runner(
          initializers.HeUniformV2(seed=123),
          tensor_shape,
          target_mean=0.,
          target_std=std)

  def test_lecun_normal(self):
    """LecunNormalV2 against a target std of sqrt(1 / fan_in)."""
    tensor_shape = (5, 6, 4, 2)
    with self.cached_session():
      fan_in, _ = _compute_fans(tensor_shape)
      std = np.sqrt(1. / fan_in)
      self._runner(
          initializers.LecunNormalV2(seed=123),
          tensor_shape,
          target_mean=0.,
          target_std=std)

  def test_glorot_normal(self):
    """GlorotNormalV2 against a target std of sqrt(2 / (fan_in + fan_out))."""
    tensor_shape = (5, 6, 4, 2)
    with self.cached_session():
      fan_in, fan_out = _compute_fans(tensor_shape)
      std = np.sqrt(2. / (fan_in + fan_out))
      self._runner(
          initializers.GlorotNormalV2(seed=123),
          tensor_shape,
          target_mean=0.,
          target_std=std)

  def test_he_normal(self):
    """HeNormalV2 against a target std of sqrt(2 / fan_in)."""
    tensor_shape = (5, 6, 4, 2)
    with self.cached_session():
      fan_in, _ = _compute_fans(tensor_shape)
      std = np.sqrt(2. / fan_in)
      self._runner(
          initializers.HeNormalV2(seed=123),
          tensor_shape,
          target_mean=0.,
          target_std=std)

  def test_orthogonal(self):
    """OrthogonalV2 on a square 20x20 matrix."""
    tensor_shape = (20, 20)
    with self.cached_session():
      self._runner(
          initializers.OrthogonalV2(seed=123), tensor_shape, target_mean=0.)

  def test_identity(self):
    """IdentityV2 rejects a rank-3 shape and works for a 3x3 matrix."""
    with self.cached_session():
      tensor_shape = (3, 4, 5)
      with self.assertRaises(ValueError):
        self._runner(
            initializers.IdentityV2(),
            tensor_shape,
            target_mean=1. / tensor_shape[0],
            target_max=1.)

      tensor_shape = (3, 3)
      self._runner(
          initializers.IdentityV2(),
          tensor_shape,
          target_mean=1. / tensor_shape[0],
          target_max=1.)

  def test_zero(self):
    """ZerosV2 produces an all-zero tensor."""
    tensor_shape = (4, 5)
    with self.cached_session():
      self._runner(
          initializers.ZerosV2(), tensor_shape, target_mean=0., target_max=0.)

  def test_one(self):
    """OnesV2 produces an all-one tensor."""
    tensor_shape = (4, 5)
    with self.cached_session():
      self._runner(
          initializers.OnesV2(), tensor_shape, target_mean=1., target_max=1.)

  def test_default_random_uniform(self):
    """The 'uniform' alias resolves to the range [-0.05, 0.05]."""
    ru = initializers.get('uniform')
    self.assertEqual(ru.minval, -0.05)
    self.assertEqual(ru.maxval, 0.05)

  def test_default_random_normal(self):
    """The 'normal' alias resolves to mean 0.0 and stddev 0.05."""
    rn = initializers.get('normal')
    self.assertEqual(rn.mean, 0.0)
    self.assertEqual(rn.stddev, 0.05)

  def test_default_truncated_normal(self):
    """The 'truncated_normal' alias resolves to mean 0.0 and stddev 0.05."""
    tn = initializers.get('truncated_normal')
    self.assertEqual(tn.mean, 0.0)
    self.assertEqual(tn.stddev, 0.05)

  def test_custom_initializer_saving(self):
    """A function initializer survives a model config round trip."""

    def my_initializer(shape, dtype=None):
      return array_ops.ones(shape, dtype=dtype)

    inputs = input_layer.Input((10,))
    outputs = core.Dense(1, kernel_initializer=my_initializer)(inputs)
    model = models.Model(inputs, outputs)
    # The custom function must be supplied by name when rebuilding the model.
    model2 = model.from_config(
        model.get_config(), custom_objects={'my_initializer': my_initializer})
    self.assertEqual(model2.layers[1].kernel_initializer, my_initializer)

  @testing_utils.run_v2_only
  def test_load_external_variance_scaling_v2(self):
    """A 'normal' VarianceScaling config loads as 'truncated_normal'."""
    external_serialized_json = {
        'class_name': 'VarianceScaling',
        'config': {
            'distribution': 'normal',
            'mode': 'fan_avg',
            'scale': 1.0,
            'seed': None
        }
    }
    initializer = initializers.deserialize(external_serialized_json)
    self.assertEqual(initializer.distribution, 'truncated_normal')

  def test_partition(self):
    """Checks which initializers accept partition-related arguments."""
    with self.cached_session():
      partition_enabled_initializers = [
          initializers.ZerosV2(),
          initializers.OnesV2(),
          initializers.RandomUniformV2(),
          initializers.RandomNormalV2(),
          initializers.TruncatedNormalV2(),
          initializers.LecunUniformV2(),
          initializers.GlorotUniformV2(),
          initializers.HeUniformV2()
      ]
      # These are expected to return a tensor of the partition's shape.
      for initializer in partition_enabled_initializers:
        got = initializer(
            shape=(4, 2), partition_shape=(2, 2), partition_offset=(0, 0))
        self.assertEqual(got.shape, (2, 2))

      partition_forbidden_initializers = [
          initializers.OrthogonalV2(),
          initializers.IdentityV2()
      ]
      # These are expected to reject the partition arguments outright.
      for initializer in partition_forbidden_initializers:
        with self.assertRaisesRegex(
            ValueError,
            "initializer doesn't support partition-related arguments"):
          initializer(
              shape=(4, 2), partition_shape=(2, 2), partition_offset=(0, 0))
284
# When executed as a script, hand control to the `test` module's `main()`
# entry point (imported from tensorflow.python.platform).
if __name__ == '__main__':
  test.main()
287