1# Copyright 2016 The TensorFlow Authors. All Rights Reserved. 2# 3# Licensed under the Apache License, Version 2.0 (the "License"); 4# you may not use this file except in compliance with the License. 5# You may obtain a copy of the License at 6# 7# http://www.apache.org/licenses/LICENSE-2.0 8# 9# Unless required by applicable law or agreed to in writing, software 10# distributed under the License is distributed on an "AS IS" BASIS, 11# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. 12# See the License for the specific language governing permissions and 13# limitations under the License. 14# ============================================================================== 15"""Tests for Keras initializers.""" 16 17from __future__ import absolute_import 18from __future__ import division 19from __future__ import print_function 20 21import numpy as np 22 23from tensorflow.python.keras import backend 24from tensorflow.python.keras import combinations 25from tensorflow.python.keras import initializers 26from tensorflow.python.keras import models 27from tensorflow.python.keras import testing_utils 28from tensorflow.python.keras.engine import input_layer 29from tensorflow.python.keras.layers import core 30from tensorflow.python.ops import array_ops 31from tensorflow.python.platform import test 32 33 34def _compute_fans(shape): 35 """Computes the number of input and output units for a weight shape. 36 37 Args: 38 shape: Integer shape tuple or TF tensor shape. 39 40 Returns: 41 A tuple of integer scalars (fan_in, fan_out). 42 """ 43 if len(shape) < 1: # Just to avoid errors for constants. 44 fan_in = fan_out = 1 45 elif len(shape) == 1: 46 fan_in = fan_out = shape[0] 47 elif len(shape) == 2: 48 fan_in = shape[0] 49 fan_out = shape[1] 50 else: 51 # Assuming convolution kernels (2D, 3D, or more). 52 # kernel shape: (..., input_depth, depth) 53 receptive_field_size = 1 54 for dim in shape[:-2]: 55 receptive_field_size *= dim 56 fan_in = shape[-2] * receptive_field_size 57 fan_out = shape[-1] * receptive_field_size 58 return int(fan_in), int(fan_out) 59 60 61@combinations.generate(combinations.combine(mode=['graph', 'eager'])) 62class KerasInitializersTest(test.TestCase): 63 64 def _runner(self, init, shape, target_mean=None, target_std=None, 65 target_max=None, target_min=None): 66 variable = backend.variable(init(shape)) 67 output = backend.get_value(variable) 68 # Test serialization (assumes deterministic behavior). 69 config = init.get_config() 70 reconstructed_init = init.__class__.from_config(config) 71 variable = backend.variable(reconstructed_init(shape)) 72 output_2 = backend.get_value(variable) 73 self.assertAllClose(output, output_2, atol=1e-4) 74 75 def test_uniform(self): 76 tensor_shape = (9, 6, 7) 77 with self.cached_session(): 78 self._runner( 79 initializers.RandomUniformV2(minval=-1, maxval=1, seed=124), 80 tensor_shape, 81 target_mean=0., 82 target_max=1, 83 target_min=-1) 84 85 def test_normal(self): 86 tensor_shape = (8, 12, 99) 87 with self.cached_session(): 88 self._runner( 89 initializers.RandomNormalV2(mean=0, stddev=1, seed=153), 90 tensor_shape, 91 target_mean=0., 92 target_std=1) 93 94 def test_truncated_normal(self): 95 tensor_shape = (12, 99, 7) 96 with self.cached_session(): 97 self._runner( 98 initializers.TruncatedNormalV2(mean=0, stddev=1, seed=126), 99 tensor_shape, 100 target_mean=0., 101 target_max=2, 102 target_min=-2) 103 104 def test_constant(self): 105 tensor_shape = (5, 6, 4) 106 with self.cached_session(): 107 self._runner( 108 initializers.ConstantV2(2.), 109 tensor_shape, 110 target_mean=2, 111 target_max=2, 112 target_min=2) 113 114 def test_lecun_uniform(self): 115 tensor_shape = (5, 6, 4, 2) 116 with self.cached_session(): 117 fan_in, _ = _compute_fans(tensor_shape) 118 std = np.sqrt(1. / fan_in) 119 self._runner( 120 initializers.LecunUniformV2(seed=123), 121 tensor_shape, 122 target_mean=0., 123 target_std=std) 124 125 def test_glorot_uniform(self): 126 tensor_shape = (5, 6, 4, 2) 127 with self.cached_session(): 128 fan_in, fan_out = _compute_fans(tensor_shape) 129 std = np.sqrt(2. / (fan_in + fan_out)) 130 self._runner( 131 initializers.GlorotUniformV2(seed=123), 132 tensor_shape, 133 target_mean=0., 134 target_std=std) 135 136 def test_he_uniform(self): 137 tensor_shape = (5, 6, 4, 2) 138 with self.cached_session(): 139 fan_in, _ = _compute_fans(tensor_shape) 140 std = np.sqrt(2. / fan_in) 141 self._runner( 142 initializers.HeUniformV2(seed=123), 143 tensor_shape, 144 target_mean=0., 145 target_std=std) 146 147 def test_lecun_normal(self): 148 tensor_shape = (5, 6, 4, 2) 149 with self.cached_session(): 150 fan_in, _ = _compute_fans(tensor_shape) 151 std = np.sqrt(1. / fan_in) 152 self._runner( 153 initializers.LecunNormalV2(seed=123), 154 tensor_shape, 155 target_mean=0., 156 target_std=std) 157 158 def test_glorot_normal(self): 159 tensor_shape = (5, 6, 4, 2) 160 with self.cached_session(): 161 fan_in, fan_out = _compute_fans(tensor_shape) 162 std = np.sqrt(2. / (fan_in + fan_out)) 163 self._runner( 164 initializers.GlorotNormalV2(seed=123), 165 tensor_shape, 166 target_mean=0., 167 target_std=std) 168 169 def test_he_normal(self): 170 tensor_shape = (5, 6, 4, 2) 171 with self.cached_session(): 172 fan_in, _ = _compute_fans(tensor_shape) 173 std = np.sqrt(2. / fan_in) 174 self._runner( 175 initializers.HeNormalV2(seed=123), 176 tensor_shape, 177 target_mean=0., 178 target_std=std) 179 180 def test_orthogonal(self): 181 tensor_shape = (20, 20) 182 with self.cached_session(): 183 self._runner( 184 initializers.OrthogonalV2(seed=123), tensor_shape, target_mean=0.) 185 186 def test_identity(self): 187 with self.cached_session(): 188 tensor_shape = (3, 4, 5) 189 with self.assertRaises(ValueError): 190 self._runner( 191 initializers.IdentityV2(), 192 tensor_shape, 193 target_mean=1. / tensor_shape[0], 194 target_max=1.) 195 196 tensor_shape = (3, 3) 197 self._runner( 198 initializers.IdentityV2(), 199 tensor_shape, 200 target_mean=1. / tensor_shape[0], 201 target_max=1.) 202 203 def test_zero(self): 204 tensor_shape = (4, 5) 205 with self.cached_session(): 206 self._runner( 207 initializers.ZerosV2(), tensor_shape, target_mean=0., target_max=0.) 208 209 def test_one(self): 210 tensor_shape = (4, 5) 211 with self.cached_session(): 212 self._runner( 213 initializers.OnesV2(), tensor_shape, target_mean=1., target_max=1.) 214 215 def test_default_random_uniform(self): 216 ru = initializers.get('uniform') 217 self.assertEqual(ru.minval, -0.05) 218 self.assertEqual(ru.maxval, 0.05) 219 220 def test_default_random_normal(self): 221 rn = initializers.get('normal') 222 self.assertEqual(rn.mean, 0.0) 223 self.assertEqual(rn.stddev, 0.05) 224 225 def test_default_truncated_normal(self): 226 tn = initializers.get('truncated_normal') 227 self.assertEqual(tn.mean, 0.0) 228 self.assertEqual(tn.stddev, 0.05) 229 230 def test_custom_initializer_saving(self): 231 232 def my_initializer(shape, dtype=None): 233 return array_ops.ones(shape, dtype=dtype) 234 235 inputs = input_layer.Input((10,)) 236 outputs = core.Dense(1, kernel_initializer=my_initializer)(inputs) 237 model = models.Model(inputs, outputs) 238 model2 = model.from_config( 239 model.get_config(), custom_objects={'my_initializer': my_initializer}) 240 self.assertEqual(model2.layers[1].kernel_initializer, my_initializer) 241 242 @testing_utils.run_v2_only 243 def test_load_external_variance_scaling_v2(self): 244 external_serialized_json = { 245 'class_name': 'VarianceScaling', 246 'config': { 247 'distribution': 'normal', 248 'mode': 'fan_avg', 249 'scale': 1.0, 250 'seed': None 251 } 252 } 253 initializer = initializers.deserialize(external_serialized_json) 254 self.assertEqual(initializer.distribution, 'truncated_normal') 255 256 def test_partition(self): 257 with self.cached_session(): 258 partition_enabled_initializers = [ 259 initializers.ZerosV2(), 260 initializers.OnesV2(), 261 initializers.RandomUniformV2(), 262 initializers.RandomNormalV2(), 263 initializers.TruncatedNormalV2(), 264 initializers.LecunUniformV2(), 265 initializers.GlorotUniformV2(), 266 initializers.HeUniformV2() 267 ] 268 for initializer in partition_enabled_initializers: 269 got = initializer( 270 shape=(4, 2), partition_shape=(2, 2), partition_offset=(0, 0)) 271 self.assertEqual(got.shape, (2, 2)) 272 273 partition_forbidden_initializers = [ 274 initializers.OrthogonalV2(), 275 initializers.IdentityV2() 276 ] 277 for initializer in partition_forbidden_initializers: 278 with self.assertRaisesRegex( 279 ValueError, 280 "initializer doesn't support partition-related arguments"): 281 initializer( 282 shape=(4, 2), partition_shape=(2, 2), partition_offset=(0, 0)) 283 284 285if __name__ == '__main__': 286 test.main() 287