1# Copyright 2015 The TensorFlow Authors. All Rights Reserved.
2#
3# Licensed under the Apache License, Version 2.0 (the "License");
4# you may not use this file except in compliance with the License.
5# You may obtain a copy of the License at
6#
7#     http://www.apache.org/licenses/LICENSE-2.0
8#
9# Unless required by applicable law or agreed to in writing, software
10# distributed under the License is distributed on an "AS IS" BASIS,
11# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
12# See the License for the specific language governing permissions and
13# limitations under the License.
14# ==============================================================================
15"""Operations often used for initializing tensors.
16
17All variable initializers returned by functions in this file should have the
18following signature:
19
20def _initializer(shape, dtype=dtypes.float32, partition_info=None):
21  Args:
22    shape: List of `int` representing the shape of the output `Tensor`. Some
23      initializers may also be able to accept a `Tensor`.
24    dtype: (Optional) Type of the output `Tensor`.
25    partition_info: (Optional) variable_scope._PartitionInfo object holding
26      additional information about how the variable is partitioned. May be
27      `None` if the variable is not partitioned.
28  Returns:
29    A `Tensor` of type `dtype` and `shape`.
30"""
31from __future__ import absolute_import
32from __future__ import division
33from __future__ import print_function
34
35import math
36
37import numpy as np
38
39from tensorflow.python.framework import constant_op
40from tensorflow.python.framework import dtypes
41from tensorflow.python.ops import array_ops
42from tensorflow.python.ops import linalg_ops
43from tensorflow.python.ops import math_ops
44from tensorflow.python.ops import random_ops
45from tensorflow.python.ops import random_ops
46from tensorflow.python.util.deprecation import deprecated
47from tensorflow.python.util.tf_export import tf_export
48
49
@tf_export("keras.initializers.Initializer")
class Initializer(object):
  """Base class shared by every initializer in this module.

  Subclasses override `__call__` to build the initial value, and override
  `get_config` when they take constructor arguments so that `from_config`
  can rebuild an equivalent instance.
  """

  def __call__(self, shape, dtype=None, partition_info=None):
    """Builds the initial value; subclasses must override this.

    Raises:
      NotImplementedError: always, in this base class.
    """
    raise NotImplementedError

  def get_config(self):
    """Serializes the constructor arguments of this initializer.

    Returns:
      A JSON-serializable Python dict. The base class takes no arguments, so
      this is an empty dict.
    """
    return {}

  @classmethod
  def from_config(cls, config):
    """Rebuilds an initializer from the dict produced by `get_config`.

    Example:

    ```python
    initializer = RandomUniform(-1, 1)
    config = initializer.get_config()
    initializer = RandomUniform.from_config(config)
    ```

    Args:
      config: A Python dictionary whose entries are forwarded to the
        constructor as keyword arguments; normally the output of `get_config`.

    Returns:
      A new instance of `cls`.
    """
    return cls(**config)
86
87
@tf_export("keras.initializers.Zeros", "initializers.zeros",
           "zeros_initializer")
class Zeros(Initializer):
  """Initializer whose output tensors are filled entirely with 0.

  Args:
    dtype: Data type used when `__call__` is not given an explicit `dtype`.
  """

  def __init__(self, dtype=dtypes.float32):
    self.dtype = dtypes.as_dtype(dtype)

  def __call__(self, shape, dtype=None, partition_info=None):
    # `partition_info` is part of the initializer signature but unused here.
    out_dtype = self.dtype if dtype is None else dtype
    return array_ops.zeros(shape, out_dtype)

  def get_config(self):
    config = {"dtype": self.dtype.name}
    return config
103
104
@tf_export("keras.initializers.Ones", "initializers.ones", "ones_initializer")
class Ones(Initializer):
  """Initializer whose output tensors are filled entirely with 1.

  Args:
    dtype: Data type used when `__call__` is not given an explicit `dtype`.
  """

  def __init__(self, dtype=dtypes.float32):
    self.dtype = dtypes.as_dtype(dtype)

  def __call__(self, shape, dtype=None, partition_info=None):
    # `partition_info` is part of the initializer signature but unused here.
    out_dtype = self.dtype if dtype is None else dtype
    return array_ops.ones(shape, out_dtype)

  def get_config(self):
    config = {"dtype": self.dtype.name}
    return config
119
120
@tf_export("keras.initializers.Constant", "initializers.constant",
           "constant_initializer")
class Constant(Initializer):
  """Initializer that fills tensors with constant values.

  The generated tensor has the requested `shape` and type `dtype`, and its
  contents are taken from `value` (see the examples below).

  `value` may be a single constant or a list of values of type `dtype`. A list
  must not hold more elements than the requested shape contains. If it holds
  fewer, the remaining entries are filled with its last element. If it holds
  more, the initializer raises a `ValueError`.

  Args:
    value: A Python scalar, list or tuple of values, or an N-dimensional numpy
      array. Every element of the initialized variable is set to the
      corresponding element of `value`.
    dtype: The data type.
    verify_shape: Boolean; when `True`, an error is raised if the shape of
      `value` is not compatible with the shape of the initialized tensor.

  Raises:
    TypeError: If `value` is not one of the accepted types.

  Examples:
    The `value` list below could equally be a numpy.ndarray, optionally
    reshaped, as shown in the two commented-out lines that follow it.

  ```python
    >>> import numpy as np
    >>> import tensorflow as tf

    >>> value = [0, 1, 2, 3, 4, 5, 6, 7]
    >>> # value = np.array(value)
    >>> # value = value.reshape([2, 4])
    >>> init = tf.constant_initializer(value)

    >>> print('fitting shape:')
    >>> with tf.Session():
    >>>   x = tf.get_variable('x', shape=[2, 4], initializer=init)
    >>>   x.initializer.run()
    >>>   print(x.eval())

    fitting shape:
    [[ 0.  1.  2.  3.]
     [ 4.  5.  6.  7.]]

    >>> print('larger shape:')
    >>> with tf.Session():
    >>>   x = tf.get_variable('x', shape=[3, 4], initializer=init)
    >>>   x.initializer.run()
    >>>   print(x.eval())

    larger shape:
    [[ 0.  1.  2.  3.]
     [ 4.  5.  6.  7.]
     [ 7.  7.  7.  7.]]

    >>> print('smaller shape:')
    >>> with tf.Session():
    >>>   x = tf.get_variable('x', shape=[2, 3], initializer=init)

    ValueError: Too many elements provided. Needed at most 6, but received 8

    >>> print('shape verification:')
    >>> init_verify = tf.constant_initializer(value, verify_shape=True)
    >>> with tf.Session():
    >>>   x = tf.get_variable('x', shape=[3, 4], initializer=init_verify)

    TypeError: Expected Tensor's shape: (3, 4), got (8,).
  ```
  """

  def __init__(self, value=0, dtype=dtypes.float32, verify_shape=False):
    accepted = np.isscalar(value) or isinstance(
        value, (list, tuple, np.ndarray))
    if not accepted:
      raise TypeError(
          "Invalid type for initial value: %s (expected Python scalar, list or "
          "tuple of values, or numpy.ndarray)." % type(value))

    self.value = value
    self.dtype = dtypes.as_dtype(dtype)
    self._verify_shape = verify_shape

  def __call__(self, shape, dtype=None, partition_info=None, verify_shape=None):
    # Per-call arguments override the defaults captured at construction time.
    out_dtype = self.dtype if dtype is None else dtype
    check_shape = self._verify_shape if verify_shape is None else verify_shape
    return constant_op.constant(
        self.value, dtype=out_dtype, shape=shape, verify_shape=check_shape)

  def get_config(self):
    # `verify_shape` is deliberately left out for Keras compatibility: it is
    # meant to be passed to `__call__`, not treated as a property of the
    # initializer itself.
    config = {"value": self.value, "dtype": self.dtype.name}
    return config
225
226
@tf_export("keras.initializers.RandomUniform", "initializers.random_uniform",
           "random_uniform_initializer")
class RandomUniform(Initializer):
  """Initializer drawing values from a uniform distribution.

  Args:
    minval: A python scalar or a scalar tensor; lower bound of the generated
      values.
    maxval: A python scalar or a scalar tensor; upper bound of the generated
      values. Defaults to 1 for float types.
    seed: A Python integer used to create random seeds. See
      @{tf.set_random_seed}
      for behavior.
    dtype: The data type.
  """

  def __init__(self, minval=0, maxval=None, seed=None, dtype=dtypes.float32):
    self.minval = minval
    self.maxval = maxval
    self.seed = seed
    self.dtype = dtypes.as_dtype(dtype)

  def __call__(self, shape, dtype=None, partition_info=None):
    out_dtype = self.dtype if dtype is None else dtype
    return random_ops.random_uniform(
        shape,
        minval=self.minval,
        maxval=self.maxval,
        dtype=out_dtype,
        seed=self.seed)

  def get_config(self):
    return dict(
        minval=self.minval,
        maxval=self.maxval,
        seed=self.seed,
        dtype=self.dtype.name)
262
263
@tf_export("keras.initializers.RandomNormal", "initializers.random_normal",
           "random_normal_initializer")
class RandomNormal(Initializer):
  """Initializer drawing values from a normal distribution.

  Args:
    mean: a python scalar or a scalar tensor; mean of the generated values.
    stddev: a python scalar or a scalar tensor; standard deviation of the
      generated values.
    seed: A Python integer used to create random seeds. See
      @{tf.set_random_seed}
      for behavior.
    dtype: The data type. Only floating point types are supported.
  """

  def __init__(self, mean=0.0, stddev=1.0, seed=None, dtype=dtypes.float32):
    self.mean = mean
    self.stddev = stddev
    self.seed = seed
    self.dtype = _assert_float_dtype(dtypes.as_dtype(dtype))

  def __call__(self, shape, dtype=None, partition_info=None):
    out_dtype = self.dtype if dtype is None else dtype
    return random_ops.random_normal(
        shape,
        mean=self.mean,
        stddev=self.stddev,
        dtype=out_dtype,
        seed=self.seed)

  def get_config(self):
    return dict(
        mean=self.mean,
        stddev=self.stddev,
        seed=self.seed,
        dtype=self.dtype.name)
299
300
@tf_export("keras.initializers.TruncatedNormal",
           "initializers.truncated_normal", "truncated_normal_initializer")
class TruncatedNormal(Initializer):
  """Initializer drawing values from a truncated normal distribution.

  Samples behave like those of a `random_normal_initializer`, except that any
  value farther than two standard deviations from the mean is discarded and
  drawn again. This is the recommended initializer for neural network weights
  and filters.

  Args:
    mean: a python scalar or a scalar tensor; mean of the generated values.
    stddev: a python scalar or a scalar tensor; standard deviation of the
      generated values.
    seed: A Python integer used to create random seeds. See
      @{tf.set_random_seed}
      for behavior.
    dtype: The data type. Only floating point types are supported.
  """

  def __init__(self, mean=0.0, stddev=1.0, seed=None, dtype=dtypes.float32):
    self.mean = mean
    self.stddev = stddev
    self.seed = seed
    self.dtype = _assert_float_dtype(dtypes.as_dtype(dtype))

  def __call__(self, shape, dtype=None, partition_info=None):
    out_dtype = self.dtype if dtype is None else dtype
    return random_ops.truncated_normal(
        shape,
        mean=self.mean,
        stddev=self.stddev,
        dtype=out_dtype,
        seed=self.seed)

  def get_config(self):
    return dict(
        mean=self.mean,
        stddev=self.stddev,
        seed=self.seed,
        dtype=self.dtype.name)
341
342
@tf_export("initializers.uniform_unit_scaling",
           "uniform_unit_scaling_initializer")
class UniformUnitScaling(Initializer):
  """Initializer that generates tensors without scaling variance.

  In a deep network it is helpful to keep the variance of the signal roughly
  constant from layer to layer, so that it neither explodes nor vanishes by
  the time it reaches the output. For an input `x` multiplied by a weight
  matrix `W` drawn uniformly at random, that means sampling `W` from

      [-sqrt(3) / sqrt(dim), sqrt(3) / sqrt(dim)]

  where `dim = W.shape[0]` (the input size). The same reasoning applied to
  convolutions gives an analogous bound with `dim` equal to the product of the
  first 3 dimensions. Nonlinearities require multiplying the bound by a
  constant `factor`. See [Sussillo et al., 2014](https://arxiv.org/abs/1412.6558)
  ([pdf](http://arxiv.org/pdf/1412.6558.pdf)) for motivation, experiments and
  the derivation of the constants. Section 2.3 of that paper computes them
  numerically: 1.0 for a linear layer, ~1.43 for relu, ~1.15 for tanh.

  Args:
    factor: Float. Multiplicative factor applied to the sampling bound.
    seed: A Python integer used to create random seeds. See
      @{tf.set_random_seed}
      for behavior.
    dtype: The data type. Only floating point types are supported.
  """

  @deprecated(None,
              "Use tf.initializers.variance_scaling instead with distribution="
              "uniform to get equivalent behavior.")
  def __init__(self, factor=1.0, seed=None, dtype=dtypes.float32):
    self.factor = factor
    self.seed = seed
    self.dtype = _assert_float_dtype(dtypes.as_dtype(dtype))

  def __call__(self, shape, dtype=None, partition_info=None):
    if dtype is None:
      dtype = self.dtype
    # For a partitioned variable, scale from the full (unpartitioned) shape.
    full_shape = shape if partition_info is None else partition_info.full_shape

    # The input size cannot be known exactly. The product of every dimension
    # except the last is correct for matrix multiplies and convolutions.
    fan_in = 1.0
    for size in full_shape[:-1]:
      fan_in *= float(size)
    # Clamp so zero-size tensors do not cause a division by zero below.
    fan_in = max(fan_in, 1.0)

    limit = math.sqrt(3 / fan_in) * self.factor
    return random_ops.random_uniform(
        shape, minval=-limit, maxval=limit, dtype=dtype, seed=self.seed)

  def get_config(self):
    return dict(factor=self.factor, seed=self.seed, dtype=self.dtype.name)
401
402
@tf_export("keras.initializers.VarianceScaling",
           "initializers.variance_scaling", "variance_scaling_initializer")
class VarianceScaling(Initializer):
  """Initializer capable of adapting its scale to the shape of weights tensors.

  With `distribution="normal"`, samples are drawn from a truncated normal
  distribution centered on zero, with `stddev = sqrt(scale / n)`
  where n is:
    - number of input units in the weight tensor, if mode = "fan_in"
    - number of output units, if mode = "fan_out"
    - average of the numbers of input and output units, if mode = "fan_avg"

  Here `stddev` is the parameter passed to the truncated normal. Values beyond
  two standard deviations are re-drawn, so the empirical spread of the samples
  is somewhat smaller than `stddev`.

  With `distribution="uniform"`, samples are drawn from a uniform distribution
  within [-limit, limit], with `limit = sqrt(3 * scale / n)`.

  In both cases `n` is clamped to be at least 1. For a partitioned variable,
  `n` is computed from the full (unpartitioned) shape.

  Args:
    scale: Scaling factor (positive float).
    mode: One of "fan_in", "fan_out", "fan_avg".
    distribution: Random distribution to use. One of "normal", "uniform"
      (matched case-insensitively).
    seed: A Python integer. Used to create random seeds. See
      @{tf.set_random_seed}
      for behavior.
    dtype: The data type. Only floating point types are supported.

  Raises:
    ValueError: In case of an invalid value for the "scale", "mode" or
      "distribution" arguments.
  """

  def __init__(self,
               scale=1.0,
               mode="fan_in",
               distribution="normal",
               seed=None,
               dtype=dtypes.float32):
    if scale <= 0.:
      raise ValueError("`scale` must be positive float.")
    if mode not in {"fan_in", "fan_out", "fan_avg"}:
      # Format the value into the message; passing it as a second ValueError
      # argument would make the exception render as a tuple.
      raise ValueError("Invalid `mode` argument: %s" % (mode,))
    distribution = distribution.lower()
    if distribution not in {"normal", "uniform"}:
      raise ValueError("Invalid `distribution` argument: %s" % (distribution,))
    self.scale = scale
    self.mode = mode
    self.distribution = distribution
    self.seed = seed
    self.dtype = _assert_float_dtype(dtypes.as_dtype(dtype))

  def __call__(self, shape, dtype=None, partition_info=None):
    if dtype is None:
      dtype = self.dtype
    scale = self.scale
    # For a partitioned variable, derive the fans from the full shape so every
    # shard is sampled with the same scale.
    scale_shape = shape
    if partition_info is not None:
      scale_shape = partition_info.full_shape
    fan_in, fan_out = _compute_fans(scale_shape)
    # Clamp the divisor to at least 1 to avoid division by zero on
    # zero-size shapes.
    if self.mode == "fan_in":
      scale /= max(1., fan_in)
    elif self.mode == "fan_out":
      scale /= max(1., fan_out)
    else:
      scale /= max(1., (fan_in + fan_out) / 2.)
    if self.distribution == "normal":
      stddev = math.sqrt(scale)
      return random_ops.truncated_normal(
          shape, 0.0, stddev, dtype, seed=self.seed)
    else:
      limit = math.sqrt(3.0 * scale)
      return random_ops.random_uniform(
          shape, -limit, limit, dtype, seed=self.seed)

  def get_config(self):
    return {
        "scale": self.scale,
        "mode": self.mode,
        "distribution": self.distribution,
        "seed": self.seed,
        "dtype": self.dtype.name
    }
482
483
@tf_export("keras.initializers.Orthogonal", "initializers.orthogonal",
           "orthogonal_initializer")
class Orthogonal(Initializer):
  """Initializer that generates an orthogonal matrix.

  If the shape of the tensor to initialize is two-dimensional, it is initialized
  with an orthogonal matrix obtained from the QR decomposition of a matrix of
  normally distributed random numbers. If the matrix has fewer rows than
  columns then the output will have orthogonal rows. Otherwise, the output will
  have orthogonal columns.

  If the shape of the tensor to initialize is more than two-dimensional,
  a matrix of shape `(shape[0] * ... * shape[n - 2], shape[n - 1])`
  is initialized, where `n` is the length of the shape vector.
  The matrix is subsequently reshaped to give a tensor of the desired shape.

  Args:
    gain: multiplicative factor to apply to the orthogonal matrix
    seed: A Python integer. Used to create random seeds. See
      @{tf.set_random_seed}
      for behavior.
    dtype: The type of the output. Only floating point types are supported.

  Raises:
    ValueError: (from `__call__`) if the requested shape has fewer than two
      dimensions.
  """

  def __init__(self, gain=1.0, seed=None, dtype=dtypes.float32):
    self.gain = gain
    self.dtype = _assert_float_dtype(dtypes.as_dtype(dtype))
    self.seed = seed

  def __call__(self, shape, dtype=None, partition_info=None):
    # NOTE(review): `partition_info` is not consulted; the matrix is built from
    # the `shape` passed in, so a partitioned variable gets per-shard matrices
    # rather than slices of one full orthogonal matrix.
    if dtype is None:
      dtype = self.dtype
    # Check the shape
    if len(shape) < 2:
      raise ValueError("The tensor to initialize must be "
                       "at least two-dimensional")
    # Flatten every dimension except the last into the row count, so that
    # convolution kernels are treated as a (fan_in, out_channels) matrix.
    num_rows = 1
    for dim in shape[:-1]:
      num_rows *= dim
    num_cols = shape[-1]
    # QR below needs a tall (rows >= cols) matrix; transpose back at the end.
    flat_shape = (num_cols, num_rows) if num_rows < num_cols else (num_rows,
                                                                   num_cols)

    # Generate a random (normally distributed) matrix.
    a = random_ops.random_normal(flat_shape, dtype=dtype, seed=self.seed)
    # Compute the qr factorization
    q, r = linalg_ops.qr(a, full_matrices=False)
    # Multiply each column of Q by the sign of the matching diagonal entry of
    # R so the result is uniformly distributed over orthogonal matrices.
    d = array_ops.diag_part(r)
    ph = d / math_ops.abs(d)
    q *= ph
    if num_rows < num_cols:
      q = array_ops.matrix_transpose(q)
    return self.gain * array_ops.reshape(q, shape)

  def get_config(self):
    return {"gain": self.gain, "seed": self.seed, "dtype": self.dtype.name}
543
544
@tf_export("keras.initializers.Identity", "initializers.identity")
class Identity(Initializer):
  """Initializer producing a (scaled) identity matrix.

  Valid only for 2D shapes; `__call__` raises `ValueError` otherwise.

  Args:
    gain: Factor by which the identity matrix is multiplied.
    dtype: The type of the output. Only floating point types are supported.
  """

  def __init__(self, gain=1.0, dtype=dtypes.float32):
    self.gain = gain
    self.dtype = _assert_float_dtype(dtypes.as_dtype(dtype))

  def __call__(self, shape, dtype=None, partition_info=None):
    # Build the identity for the whole variable, even if only a shard is
    # requested, so every shard gets its matching block.
    if partition_info is None:
      full_shape = shape
    else:
      full_shape = partition_info.full_shape
    if len(full_shape) != 2:
      raise ValueError(
          "Identity matrix initializer can only be used for 2D matrices.")
    if dtype is None:
      dtype = self.dtype
    num_rows, num_cols = full_shape
    eye = linalg_ops.eye(num_rows, num_cols, dtype=dtype)
    if partition_info is not None:
      # Cut out this shard's region, starting at its offset in the full matrix.
      eye = array_ops.slice(eye, partition_info.var_offset, shape)
    return self.gain * eye

  def get_config(self):
    return dict(gain=self.gain, dtype=self.dtype.name)
575
# Aliases: module-level snake_case names bound to the initializer classes.

# pylint: disable=invalid-name
# Constant-valued initializers.
zeros_initializer = Zeros
ones_initializer = Ones
constant_initializer = Constant
# Purely random initializers.
random_uniform_initializer = RandomUniform
random_normal_initializer = RandomNormal
truncated_normal_initializer = TruncatedNormal
# Initializers whose output depends on the requested shape.
uniform_unit_scaling_initializer = UniformUnitScaling
variance_scaling_initializer = VarianceScaling
orthogonal_initializer = Orthogonal
identity_initializer = Identity
# pylint: enable=invalid-name
591
592
@tf_export("glorot_uniform_initializer")
def glorot_uniform_initializer(seed=None, dtype=dtypes.float32):
  """Returns a Glorot (a.k.a. Xavier) uniform initializer.

  Samples are drawn uniformly from [-limit, limit] with
  `limit = sqrt(6 / (fan_in + fan_out))`, where `fan_in` and `fan_out` are the
  numbers of input and output units of the weight tensor.

  Reference: http://jmlr.org/proceedings/papers/v9/glorot10a/glorot10a.pdf

  Args:
    seed: A Python integer used to create random seeds. See
      @{tf.set_random_seed}
      for behavior.
    dtype: The data type. Only floating point types are supported.

  Returns:
    An initializer.
  """
  # Equivalent to variance scaling with scale 1 over the averaged fans and a
  # uniform draw: sqrt(3 * 1 / ((fan_in + fan_out) / 2)).
  return variance_scaling_initializer(
      seed=seed,
      dtype=dtype,
      scale=1.0,
      mode="fan_avg",
      distribution="uniform")
615
616
@tf_export("glorot_normal_initializer")
def glorot_normal_initializer(seed=None, dtype=dtypes.float32):
  """Returns a Glorot (a.k.a. Xavier) normal initializer.

  Samples come from a truncated normal distribution centered on 0 with
  `stddev = sqrt(2 / (fan_in + fan_out))`, where `fan_in` and `fan_out` are
  the numbers of input and output units of the weight tensor.

  Reference: http://jmlr.org/proceedings/papers/v9/glorot10a/glorot10a.pdf

  Args:
    seed: A Python integer used to create random seeds. See
      @{tf.set_random_seed}
      for behavior.
    dtype: The data type. Only floating point types are supported.

  Returns:
    An initializer.
  """
  # Equivalent to variance scaling with scale 1 over the averaged fans and a
  # truncated-normal draw.
  return variance_scaling_initializer(
      seed=seed,
      dtype=dtype,
      scale=1.0,
      mode="fan_avg",
      distribution="normal")
639
640
641# Utility functions.
642
643
644def _compute_fans(shape):
645  """Computes the number of input and output units for a weight shape.
646
647  Args:
648    shape: Integer shape tuple or TF tensor shape.
649
650  Returns:
651    A tuple of scalars (fan_in, fan_out).
652  """
653  if len(shape) < 1:  # Just to avoid errors for constants.
654    fan_in = fan_out = 1
655  elif len(shape) == 1:
656    fan_in = fan_out = shape[0]
657  elif len(shape) == 2:
658    fan_in = shape[0]
659    fan_out = shape[1]
660  else:
661    # Assuming convolution kernels (2D, 3D, or more).
662    # kernel shape: (..., input_depth, depth)
663    receptive_field_size = 1.
664    for dim in shape[:-2]:
665      receptive_field_size *= dim
666    fan_in = shape[-2] * receptive_field_size
667    fan_out = shape[-1] * receptive_field_size
668  return fan_in, fan_out
669
670
def _assert_float_dtype(dtype):
  """Checks that `dtype` is a floating point type and returns it unchanged.

  Args:
    dtype: A `DType`-like object exposing an `is_floating` attribute.

  Returns:
    The same `dtype` object that was passed in.

  Raises:
    ValueError: If `dtype.is_floating` is falsy.
  """
  if dtype.is_floating:
    return dtype
  raise ValueError("Expected floating point type, got %s." % dtype)
688