1# Copyright 2015 The TensorFlow Authors. All Rights Reserved.
2#
3# Licensed under the Apache License, Version 2.0 (the "License");
4# you may not use this file except in compliance with the License.
5# You may obtain a copy of the License at
6#
7#     http://www.apache.org/licenses/LICENSE-2.0
8#
9# Unless required by applicable law or agreed to in writing, software
10# distributed under the License is distributed on an "AS IS" BASIS,
11# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
12# See the License for the specific language governing permissions and
13# limitations under the License.
14# ==============================================================================
15"""Operations often used for initializing tensors.
16
17All variable initializers returned by functions in this file should have the
18following signature:
19
20def _initializer(shape, dtype=dtypes.float32, partition_info=None):
21  Args:
22    shape: List of `int` representing the shape of the output `Tensor`. Some
23      initializers may also be able to accept a `Tensor`.
24    dtype: (Optional) Type of the output `Tensor`.
25    partition_info: (Optional) variable_scope._PartitionInfo object holding
26      additional information about how the variable is partitioned. May be
27      `None` if the variable is not partitioned.
28  Returns:
29    A `Tensor` of type `dtype` and `shape`.
30"""
31from __future__ import absolute_import
32from __future__ import division
33from __future__ import print_function
34
35import math
36
37import numpy as np
38
39from tensorflow.python.framework import constant_op
40from tensorflow.python.framework import dtypes
41from tensorflow.python.framework import tensor_shape
42from tensorflow.python.ops import array_ops
43from tensorflow.python.ops import gen_linalg_ops
44from tensorflow.python.ops import linalg_ops_impl
45from tensorflow.python.ops import math_ops
46from tensorflow.python.ops import random_ops
47from tensorflow.python.util import deprecation
48from tensorflow.python.util.deprecation import deprecated
49from tensorflow.python.util.deprecation import  deprecated_arg_values
50from tensorflow.python.util.deprecation import  deprecated_args
51from tensorflow.python.util.tf_export import tf_export
52
53
class Initializer(object):
  """Base class shared by every initializer in this module.

  Subclasses override `__call__` to build the initial-value tensor and
  `get_config` to describe the arguments needed to rebuild them.
  """

  def __call__(self, shape, dtype=None, partition_info=None):
    """Builds a tensor initialized as specified by this initializer.

    Args:
      shape: Shape of the tensor to produce.
      dtype: Optional dtype of the tensor. When omitted, the initializer's own
        dtype is used.
      partition_info: Optional information about how the tensor is
        partitioned, if it is.

    Raises:
      NotImplementedError: Always; concrete subclasses must override this.
    """
    raise NotImplementedError

  def get_config(self):
    """Describes this initializer as a JSON-serializable dict.

    The base class carries no configuration, so the dict is empty.

    Returns:
      A JSON-serializable Python dict.
    """
    return dict()

  @classmethod
  def from_config(cls, config):
    """Rebuilds an initializer from a configuration dictionary.

    Example:

    ```python
    initializer = RandomUniform(-1, 1)
    config = initializer.get_config()
    initializer = RandomUniform.from_config(config)
    ```

    Args:
      config: A Python dictionary, typically the output of `get_config`. Its
        entries are forwarded to the constructor as keyword arguments.

    Returns:
      An Initializer instance.
    """
    return cls(**config)
98
99
@tf_export(v1=["initializers.zeros", "zeros_initializer"])
@deprecation.deprecated_endpoints("initializers.zeros")
class Zeros(Initializer):
  """Initializer that produces tensors filled with 0."""

  @deprecated_args(None,
                   "Call initializer instance with the dtype argument instead "
                   "of passing it to the constructor",
                   "dtype")
  def __init__(self, dtype=dtypes.float32):
    # Stored as a `DType` so that `get_config` can report its `.name`.
    self.dtype = dtypes.as_dtype(dtype)

  def __call__(self, shape, dtype=None, partition_info=None):
    """Returns an all-zeros tensor of `shape`.

    `dtype` falls back to the constructor's dtype; `partition_info` is unused.
    """
    out_dtype = self.dtype if dtype is None else dtype
    return array_ops.zeros(shape, out_dtype)

  def get_config(self):
    config = {"dtype": self.dtype.name}
    return config
119
120
@tf_export(v1=["initializers.ones", "ones_initializer"])
@deprecation.deprecated_endpoints("initializers.ones", "ones_initializer")
class Ones(Initializer):
  """Initializer that produces tensors filled with 1."""

  @deprecated_args(None,
                   "Call initializer instance with the dtype argument instead "
                   "of passing it to the constructor",
                   "dtype")
  def __init__(self, dtype=dtypes.float32):
    # Stored as a `DType` so that `get_config` can report its `.name`.
    self.dtype = dtypes.as_dtype(dtype)

  def __call__(self, shape, dtype=None, partition_info=None):
    """Returns an all-ones tensor of `shape`.

    `dtype` falls back to the constructor's dtype; `partition_info` is unused.
    """
    out_dtype = self.dtype if dtype is None else dtype
    return array_ops.ones(shape, out_dtype)

  def get_config(self):
    config = {"dtype": self.dtype.name}
    return config
140
141
@tf_export(v1=["initializers.constant", "constant_initializer"])
@deprecation.deprecated_endpoints("constant_initializer")
class Constant(Initializer):
  """Initializer that generates tensors with constant values.

  The resulting tensor is filled with values of type `dtype`, taken from the
  `value` argument and laid out to match the requested `shape` of the new
  tensor (see the examples below).

  `value` may be a single constant or a list of values of type `dtype`. When
  `value` is a list, its length must not exceed the number of elements implied
  by the requested shape. If it has fewer elements than the shape requires,
  the last element of `value` fills the remaining entries. If it has more
  elements than the shape requires, the initializer raises a `ValueError`.

  Args:
    value: A Python scalar, list or tuple of values, or a N-dimensional numpy
      array. All elements of the initialized variable will be set to the
      corresponding value in the `value` argument.
    dtype: Default data type, used if no `dtype` argument is provided when
      calling the initializer.
    verify_shape: Boolean that enables verification of the shape of `value`. If
      `True`, the initializer will throw an error if the shape of `value` is not
      compatible with the shape of the initialized tensor.

  Raises:
    TypeError: If the input `value` is not one of the expected types.

  Examples:
    The following example can be rewritten using a numpy.ndarray instead
    of the `value` list, even reshaped, as shown in the two commented lines
    below the `value` list initialization.

  ```python
    >>> import numpy as np
    >>> import tensorflow as tf

    >>> value = [0, 1, 2, 3, 4, 5, 6, 7]
    >>> # value = np.array(value)
    >>> # value = value.reshape([2, 4])
    >>> init = tf.constant_initializer(value)

    >>> print('fitting shape:')
    >>> with tf.Session():
    >>>   x = tf.get_variable('x', shape=[2, 4], initializer=init)
    >>>   x.initializer.run()
    >>>   print(x.eval())

    fitting shape:
    [[ 0.  1.  2.  3.]
     [ 4.  5.  6.  7.]]

    >>> print('larger shape:')
    >>> with tf.Session():
    >>>   x = tf.get_variable('x', shape=[3, 4], initializer=init)
    >>>   x.initializer.run()
    >>>   print(x.eval())

    larger shape:
    [[ 0.  1.  2.  3.]
     [ 4.  5.  6.  7.]
     [ 7.  7.  7.  7.]]

    >>> print('smaller shape:')
    >>> with tf.Session():
    >>>   x = tf.get_variable('x', shape=[2, 3], initializer=init)

    ValueError: Too many elements provided. Needed at most 6, but received 8

    >>> print('shape verification:')
    >>> init_verify = tf.constant_initializer(value, verify_shape=True)
    >>> with tf.Session():
    >>>   x = tf.get_variable('x', shape=[3, 4], initializer=init_verify)

    TypeError: Expected Tensor's shape: (3, 4), got (8,).
  ```
  """

  @deprecated_args(None,
                   "Call initializer instance with the dtype argument instead "
                   "of passing it to the constructor",
                   "dtype")
  @deprecated_args(None,
                   "Objects must now be the required shape or no shape "
                   "can be specified",
                   "verify_shape")
  def __init__(self, value=0, dtype=dtypes.float32, verify_shape=False):
    is_supported = (np.isscalar(value) or
                    isinstance(value, (list, tuple, np.ndarray)))
    if not is_supported:
      raise TypeError(
          "Invalid type for initial value: %s (expected Python scalar, list or "
          "tuple of values, or numpy.ndarray)." % type(value))

    self.value = value
    self.dtype = dtypes.as_dtype(dtype)
    self._verify_shape = verify_shape

  def __call__(self, shape, dtype=None, partition_info=None, verify_shape=None):
    """Builds the constant tensor; `None` arguments fall back to the defaults
    captured by the constructor. `partition_info` is unused."""
    out_dtype = self.dtype if dtype is None else dtype
    check_shape = self._verify_shape if verify_shape is None else verify_shape
    return constant_op.constant_v1(
        self.value, dtype=out_dtype, shape=shape, verify_shape=check_shape)

  def get_config(self):
    # `verify_shape` is deliberately left out for Keras compatibility: it is
    # meant to be passed to `__call__` rather than to the constructor, since
    # it is not conceptually a property of the initializer.
    return {"value": self.value, "dtype": self.dtype.name}
255
256
@tf_export(v1=["initializers.random_uniform", "random_uniform_initializer"])
@deprecation.deprecated_endpoints("initializers.random_uniform")
class RandomUniform(Initializer):
  """Initializer that generates tensors with a uniform distribution.

  Args:
    minval: A python scalar or a scalar tensor. Lower bound of the range
      of random values to generate.
    maxval: A python scalar or a scalar tensor. Upper bound of the range
      of random values to generate.  Defaults to 1 for float types.
    seed: A Python integer. Used to create random seeds. See
      `tf.set_random_seed` for behavior.
    dtype: Default data type, used if no `dtype` argument is provided when
      calling the initializer.
  """

  @deprecated_args(None,
                   "Call initializer instance with the dtype argument instead "
                   "of passing it to the constructor",
                   "dtype")
  def __init__(self, minval=0, maxval=None, seed=None, dtype=dtypes.float32):
    self.minval, self.maxval = minval, maxval
    self.seed = seed
    self.dtype = dtypes.as_dtype(dtype)

  def __call__(self, shape, dtype=None, partition_info=None):
    """Samples `shape` values uniformly from [minval, maxval)."""
    out_dtype = self.dtype if dtype is None else dtype
    return random_ops.random_uniform(
        shape, self.minval, self.maxval, out_dtype, seed=self.seed)

  def get_config(self):
    return dict(minval=self.minval,
                maxval=self.maxval,
                seed=self.seed,
                dtype=self.dtype.name)
297
298
@tf_export(v1=["initializers.random_normal", "random_normal_initializer"])
@deprecation.deprecated_endpoints("initializers.random_normal")
class RandomNormal(Initializer):
  """Initializer that generates tensors with a normal distribution.

  Args:
    mean: a python scalar or a scalar tensor. Mean of the random values
      to generate.
    stddev: a python scalar or a scalar tensor. Standard deviation of the
      random values to generate.
    seed: A Python integer. Used to create random seeds. See
      `tf.set_random_seed` for behavior.
    dtype: Default data type, used if no `dtype` argument is provided when
      calling the initializer. Only floating point types are supported.
  """

  @deprecated_args(None,
                   "Call initializer instance with the dtype argument instead "
                   "of passing it to the constructor",
                   "dtype")
  def __init__(self, mean=0.0, stddev=1.0, seed=None, dtype=dtypes.float32):
    self.mean, self.stddev = mean, stddev
    self.seed = seed
    self.dtype = _assert_float_dtype(dtypes.as_dtype(dtype))

  def __call__(self, shape, dtype=None, partition_info=None):
    """Samples `shape` values from N(mean, stddev**2)."""
    out_dtype = self.dtype if dtype is None else dtype
    return random_ops.random_normal(
        shape, self.mean, self.stddev, out_dtype, seed=self.seed)

  def get_config(self):
    return dict(mean=self.mean,
                stddev=self.stddev,
                seed=self.seed,
                dtype=self.dtype.name)
339
340
@tf_export(v1=["initializers.truncated_normal", "truncated_normal_initializer"])
@deprecation.deprecated_endpoints("initializers.truncated_normal",
                                  "truncated_normal_initializer")
class TruncatedNormal(Initializer):
  """Initializer that generates a truncated normal distribution.

  Values resemble those of a `random_normal_initializer`, except that any
  value lying more than two standard deviations from the mean is discarded
  and drawn again. This is the recommended initializer for neural network
  weights and filters.

  Args:
    mean: a python scalar or a scalar tensor. Mean of the random values
      to generate.
    stddev: a python scalar or a scalar tensor. Standard deviation of the
      random values to generate.
    seed: A Python integer. Used to create random seeds. See
      `tf.set_random_seed` for behavior.
    dtype: Default data type, used if no `dtype` argument is provided when
      calling the initializer. Only floating point types are supported.
  """

  @deprecated_args(None,
                   "Call initializer instance with the dtype argument instead "
                   "of passing it to the constructor",
                   "dtype")
  def __init__(self, mean=0.0, stddev=1.0, seed=None, dtype=dtypes.float32):
    self.mean, self.stddev = mean, stddev
    self.seed = seed
    self.dtype = _assert_float_dtype(dtypes.as_dtype(dtype))

  def __call__(self, shape, dtype=None, partition_info=None):
    """Samples `shape` values from the truncated normal distribution."""
    out_dtype = self.dtype if dtype is None else dtype
    return random_ops.truncated_normal(
        shape, self.mean, self.stddev, out_dtype, seed=self.seed)

  def get_config(self):
    return dict(mean=self.mean,
                stddev=self.stddev,
                seed=self.seed,
                dtype=self.dtype.name)
387
388
@tf_export(v1=["initializers.uniform_unit_scaling",
               "uniform_unit_scaling_initializer"])
@deprecation.deprecated_endpoints("uniform_unit_scaling_initializer",
                                  "initializers.uniform_unit_scaling")
class UniformUnitScaling(Initializer):
  """Initializer that generates tensors without scaling variance.

  When initializing a deep network, it is in principle advantageous to keep
  the scale of the input variance constant, so that it neither explodes nor
  vanishes by the time it reaches the final layer. If the input is `x` and the
  operation is `x * W`, and `W` is to be initialized uniformly at random, `W`
  should be drawn from

      [-sqrt(3) / sqrt(dim), sqrt(3) / sqrt(dim)]

  to keep the scale intact, where `dim = W.shape[0]` (the size of the input).
  A similar calculation for convolutional networks gives an analogous result
  with `dim` equal to the product of the first 3 dimensions.  When
  nonlinearities are present, this must be multiplied by a constant `factor`.
  See (Sussillo et al., 2014) for deeper motivation, experiments
  and the calculation of constants. In section 2.3 there, the constants were
  numerically computed: for a linear layer it's 1.0, relu: ~1.43, tanh: ~1.15.

  Args:
    factor: Float.  A multiplicative factor by which the values will be scaled.
    seed: A Python integer. Used to create random seeds. See
      `tf.set_random_seed` for behavior.
    dtype: Default data type, used if no `dtype` argument is provided when
      calling the initializer. Only floating point types are supported.

  References:
      [Sussillo et al., 2014](https://arxiv.org/abs/1412.6558)
      ([pdf](http://arxiv.org/pdf/1412.6558.pdf))
  """

  @deprecated_args(None,
                   "Call initializer instance with the dtype argument instead "
                   "of passing it to the constructor",
                   "dtype")
  @deprecated(None,
              "Use tf.initializers.variance_scaling instead with distribution="
              "uniform to get equivalent behavior.")
  def __init__(self, factor=1.0, seed=None, dtype=dtypes.float32):
    self.factor = factor
    self.seed = seed
    self.dtype = _assert_float_dtype(dtypes.as_dtype(dtype))

  def __call__(self, shape, dtype=None, partition_info=None):
    """Samples uniformly from [-max_val, max_val], scaled by the input size."""
    out_dtype = self.dtype if dtype is None else dtype
    # For partitioned variables, scale by the full (unpartitioned) shape.
    scale_shape = shape if partition_info is None else partition_info.full_shape

    # The input size cannot be estimated perfectly; the product of every
    # dimension except the last is right for matrix multiply and convolutions.
    fan_in = 1.0
    for extent in scale_shape[:-1]:
      fan_in *= float(extent)
    # Clamp to 1.0 so that zero-size tensors do not divide by zero.
    fan_in = max(fan_in, 1.0)

    max_val = math.sqrt(3 / fan_in) * self.factor
    return random_ops.random_uniform(
        shape, -max_val, max_val, out_dtype, seed=self.seed)

  def get_config(self):
    return dict(factor=self.factor, seed=self.seed, dtype=self.dtype.name)
457
458
@tf_export(v1=["initializers.variance_scaling", "variance_scaling_initializer"])
@deprecation.deprecated_endpoints("initializers.variance_scaling",
                                  "variance_scaling_initializer")
class VarianceScaling(Initializer):
  """Initializer capable of adapting its scale to the shape of weights tensors.

  With `distribution="truncated_normal"` or `"untruncated_normal"`, samples
  are drawn from a truncated/untruncated normal distribution with a mean of
  zero and a standard deviation (after truncation, if used)
  `stddev = sqrt(scale / n)`, where n is:
    - number of input units in the weight tensor, if mode = "fan_in"
    - number of output units, if mode = "fan_out"
    - average of the numbers of input and output units, if mode = "fan_avg"
  n is clamped to be at least 1.

  With `distribution="uniform"`, samples are drawn from a uniform distribution
  within [-limit, limit], with `limit = sqrt(3 * scale / n)`.

  Args:
    scale: Scaling factor (positive float).
    mode: One of "fan_in", "fan_out", "fan_avg".
    distribution: Random distribution to use. One of "truncated_normal",
      "untruncated_normal" or "uniform" (matched case-insensitively).
      "normal" is accepted as a deprecated alias for "truncated_normal".
    seed: A Python integer. Used to create random seeds. See
      `tf.set_random_seed` for behavior.
    dtype: Default data type, used if no `dtype` argument is provided when
      calling the initializer. Only floating point types are supported.

  Raises:
    ValueError: In case of an invalid value for the "scale", "mode" or
      "distribution" arguments.
  """

  @deprecated_args(None,
                   "Call initializer instance with the dtype argument instead "
                   "of passing it to the constructor",
                   "dtype")
  @deprecated_arg_values(
      None,
      "`normal` is a deprecated alias for `truncated_normal`",
      distribution="normal")
  def __init__(self,
               scale=1.0,
               mode="fan_in",
               distribution="truncated_normal",
               seed=None,
               dtype=dtypes.float32):
    if scale <= 0.:
      raise ValueError("`scale` must be positive float.")
    if mode not in {"fan_in", "fan_out", "fan_avg"}:
      # Format the value into one message; passing it as a second positional
      # argument to ValueError would render the message as a tuple repr.
      raise ValueError("Invalid `mode` argument: %s" % (mode,))
    distribution = distribution.lower()
    if distribution not in {"normal", "uniform",
                            "truncated_normal", "untruncated_normal"}:
      raise ValueError("Invalid `distribution` argument: %s" % (distribution,))
    self.scale = scale
    self.mode = mode
    self.distribution = distribution
    self.seed = seed
    self.dtype = _assert_float_dtype(dtypes.as_dtype(dtype))

  def __call__(self, shape, dtype=None, partition_info=None):
    """Samples a tensor of `shape` whose variance is scaled by its fans."""
    if dtype is None:
      dtype = self.dtype
    scale = self.scale
    scale_shape = shape
    if partition_info is not None:
      # Fans come from the full (unpartitioned) variable shape.
      scale_shape = partition_info.full_shape
    fan_in, fan_out = _compute_fans(scale_shape)
    if self.mode == "fan_in":
      scale /= max(1., fan_in)
    elif self.mode == "fan_out":
      scale /= max(1., fan_out)
    else:
      scale /= max(1., (fan_in + fan_out) / 2.)
    if self.distribution == "normal" or self.distribution == "truncated_normal":
      # constant taken from scipy.stats.truncnorm.std(a=-2, b=2, loc=0., scale=1.)
      # Dividing by it makes the standard deviation *after* truncation equal
      # sqrt(scale), as documented above.
      stddev = math.sqrt(scale) / .87962566103423978
      return random_ops.truncated_normal(
          shape, 0.0, stddev, dtype, seed=self.seed)
    elif self.distribution == "untruncated_normal":
      stddev = math.sqrt(scale)
      return random_ops.random_normal(
          shape, 0.0, stddev, dtype, seed=self.seed)
    else:
      limit = math.sqrt(3.0 * scale)
      return random_ops.random_uniform(
          shape, -limit, limit, dtype, seed=self.seed)

  def get_config(self):
    return {
        "scale": self.scale,
        "mode": self.mode,
        "distribution": self.distribution,
        "seed": self.seed,
        "dtype": self.dtype.name
    }
556
557
@tf_export(v1=["initializers.orthogonal", "orthogonal_initializer"])
@deprecation.deprecated_endpoints("initializers.orthogonal",
                                  "orthogonal_initializer")
class Orthogonal(Initializer):
  """Initializer that generates an orthogonal matrix.

  For a two-dimensional shape, the tensor is set to an orthogonal matrix taken
  from the QR decomposition of a matrix of normally distributed random
  numbers. When the matrix has fewer rows than columns, its rows are
  orthogonal; otherwise its columns are.

  For a shape with more than two dimensions, a matrix of shape
  `(shape[0] * ... * shape[n - 2], shape[n - 1])` is initialized instead,
  where `n` is the length of the shape vector, and then reshaped to the
  requested shape.

  Args:
    gain: multiplicative factor to apply to the orthogonal matrix
    seed: A Python integer. Used to create random seeds. See
      `tf.set_random_seed` for behavior.
    dtype: Default data type, used if no `dtype` argument is provided when
      calling the initializer. Only floating point types are supported.

  References:
      [Saxe et al., 2014](https://openreview.net/forum?id=_wzZwKpTDF_9C)
      ([pdf](https://arxiv.org/pdf/1312.6120.pdf))
  """

  @deprecated_args(None,
                   "Call initializer instance with the dtype argument instead "
                   "of passing it to the constructor",
                   "dtype")
  def __init__(self, gain=1.0, seed=None, dtype=dtypes.float32):
    self.gain = gain
    self.dtype = _assert_float_dtype(dtypes.as_dtype(dtype))
    self.seed = seed

  def __call__(self, shape, dtype=None, partition_info=None):
    """Builds a `gain`-scaled orthogonal tensor of `shape` (rank >= 2)."""
    if dtype is None:
      dtype = self.dtype
    if len(shape) < 2:
      raise ValueError("The tensor to initialize must be "
                       "at least two-dimensional")

    # Collapse all leading dimensions into rows and keep the last one as
    # columns, so that e.g. conv2d kernels are handled as well.
    rows = 1
    for extent in shape[:-1]:
      rows *= extent
    rows = int(rows)
    cols = int(shape[-1])
    wide = rows < cols
    flat_shape = (cols, rows) if wide else (rows, cols)

    gaussian = random_ops.random_normal(flat_shape, dtype=dtype, seed=self.seed)
    q, r = gen_linalg_ops.qr(gaussian, full_matrices=False)
    # "Make Q uniform": flip each column of Q by the sign of the matching
    # diagonal entry of R.
    q *= math_ops.sign(array_ops.diag_part(r))
    if wide:
      q = array_ops.matrix_transpose(q)
    return self.gain * array_ops.reshape(q, shape)

  def get_config(self):
    return dict(gain=self.gain, seed=self.seed, dtype=self.dtype.name)
629
630
# Note: these initializers haven't been ported to TF 2.0. They are not
# currently visible (they carry no `tf_export`), and their tests are
# non-trivial to port.
class ConvolutionDeltaOrthogonal(Initializer):
  """Initializer that generates a delta orthogonal kernel for ConvNets.

  The tensor shape must have length 3, 4 or 5, and the number of input
  filters must not exceed the number of output filters. The center pixel of
  the kernel holds an orthogonal matrix; every other pixel is zero. See
  algorithm 2 in (Xiao et al., 2018).


  Args:
    gain: Multiplicative factor to apply to the orthogonal
      matrix. Default is 1. The 2-norm of an input is multiplied by a factor of
      `gain` after applying this convolution.
    seed: A Python integer. Used to create random seeds. See
      `tf.set_random_seed` for behavior.
    dtype: Default data type, used if no `dtype` argument is provided when
      calling the initializer. Only floating point types are supported.

  References:
      [Xiao et al., 2018](http://proceedings.mlr.press/v80/xiao18a.html)
      ([pdf](http://proceedings.mlr.press/v80/xiao18a/xiao18a.pdf))
  """

  def __init__(self, gain=1.0, seed=None, dtype=dtypes.float32):
    self.gain = gain
    self.dtype = _assert_float_dtype(dtypes.as_dtype(dtype))
    self.seed = seed

  def __call__(self, shape, dtype=None, partition_info=None):
    """Builds a kernel of `shape` that is zero except at its center pixel."""
    if dtype is None:
      dtype = self.dtype
    rank = len(shape)
    if not 3 <= rank <= 5:
      raise ValueError("The tensor to initialize must be at least "
                       "three-dimensional and at most five-dimensional")

    in_filters, out_filters = shape[-2], shape[-1]
    if in_filters > out_filters:
      raise ValueError("In_filters cannot be greater than out_filters.")

    # Orthogonal (out_filters x out_filters) matrix from a QR decomposition,
    # with column signs fixed by the signs of diag(R).
    gaussian = random_ops.random_normal([out_filters, out_filters],
                                        dtype=dtype, seed=self.seed)
    q, r = gen_linalg_ops.qr(gaussian, full_matrices=False)
    q *= math_ops.sign(array_ops.diag_part(r))
    q = q[:in_filters, :]
    q *= math_ops.cast(self.gain, dtype=dtype)

    # Index of the central pixel along each of the (rank - 2) spatial dims.
    center = [(shape[axis] - 1) // 2 for axis in range(rank - 2)]
    return array_ops.scatter_nd([center], array_ops.expand_dims(q, 0), shape)

  def get_config(self):
    return dict(gain=self.gain, seed=self.seed, dtype=self.dtype.name)
696
697
class ConvolutionOrthogonal(Initializer):
  """Initializer that generates orthogonal kernel for ConvNets.

  Base class used to construct 1D, 2D and 3D orthogonal kernels for convolution.

  Args:
    gain: multiplicative factor to apply to the orthogonal
      matrix. Default is 1. The 2-norm of an input is multiplied by a factor of
      `gain` after applying this convolution.
    seed: A Python integer. Used to create random seeds. See
      `tf.set_random_seed` for behavior. When a seed is given, the helper
      methods below advance `self.seed` by one after each random op they
      create, so `get_config` reports the advanced value.
    dtype: Default data type, used if no `dtype` argument is provided when
      calling the initializer. Only floating point types are supported.

  References:
      [Xiao et al., 2018](http://proceedings.mlr.press/v80/xiao18a.html)
      ([pdf](http://proceedings.mlr.press/v80/xiao18a/xiao18a.pdf))
  """

  def __init__(self, gain=1.0, seed=None, dtype=dtypes.float32):
    self.gain = gain
    self.dtype = _assert_float_dtype(dtypes.as_dtype(dtype))
    self.seed = seed

  def __call__(self, shape, dtype=None, partition_info=None):
    raise NotImplementedError

  def get_config(self):
    return {"gain": self.gain, "seed": self.seed, "dtype": self.dtype.name}

  # Helper functions.
  def _orthogonal_matrix(self, n):
    """Construct an n x n orthogonal matrix.

    Args:
      n: Dimension.
    Returns:
      A n x n orthogonal matrix.
    """
    a = random_ops.random_normal([n, n], dtype=self.dtype, seed=self.seed)
    # Advance the seed so the next random op gets a different op seed. Test
    # against None explicitly: 0 is a valid seed, and a truthiness test would
    # leave it unchanged, making every random op reuse the same seed.
    if self.seed is not None:
      self.seed += 1
    q, r = gen_linalg_ops.qr(a)
    d = array_ops.diag_part(r)
    # make q uniform
    q *= math_ops.sign(d)
    return q

  def _symmetric_projection(self, n):
    """Compute a n x n symmetric projection matrix.

    Args:
      n: Dimension.
    Returns:
      A n x n symmetric projection matrix, i.e. a matrix P s.t. P=P*P, P=P^T.
    """
    q = self._orthogonal_matrix(n)
    # randomly zeroing out some columns
    mask = math_ops.cast(random_ops.random_normal([n], seed=self.seed) > 0,
                         self.dtype)
    # Same explicit None check as in `_orthogonal_matrix`, so seed 0 advances.
    if self.seed is not None:
      self.seed += 1
    c = math_ops.multiply(q, mask)
    return math_ops.matmul(c, array_ops.matrix_transpose(c))
762
763
class ConvolutionOrthogonal2D(ConvolutionOrthogonal):
  """Initializer that generates a 2D orthogonal kernel for ConvNets.

  The shape of the tensor must have length 4. The number of input
  filters must not exceed the number of output filters.
  The orthogonality(==isometry) is exact when the inputs are circular padded.
  There are finite-width effects with non-circular padding (e.g. zero padding).
  See algorithm 1 in (Xiao et al., 2018).

  Args:
    gain: Multiplicative factor to apply to the orthogonal
      matrix. Default is 1. This has the effect of scaling the output 2-norm by
      a factor of `gain`.
    seed: A Python integer. Used to create random seeds. See
      `tf.set_random_seed` for behavior.
    dtype: Default data type, used if no `dtype` argument is provided when
      calling the initializer. Only floating point types are supported.

  References:
      [Xiao et al., 2018](http://proceedings.mlr.press/v80/xiao18a.html)
      ([pdf](http://proceedings.mlr.press/v80/xiao18a/xiao18a.pdf))
  """

  def __call__(self, shape, dtype=None, partition_info=None):
    """Returns a `[k, k, in_filters, out_filters]` orthogonal kernel.

    Args:
      shape: Shape of the kernel, `[k, k, in_filters, out_filters]`.
      dtype: (Optional) Data type of the result. Defaults to `self.dtype`.
      partition_info: Unused; accepted for interface compatibility.

    Returns:
      The orthogonal kernel scaled by `self.gain`, cast to `dtype`.

    Raises:
      ValueError: If `shape` is not four-dimensional, if in_filters exceeds
        out_filters, or if the two spatial kernel sizes differ.
    """
    if dtype is None:
      dtype = self.dtype
    if len(shape) != 4:
      raise ValueError("The tensor to initialize must be four-dimensional")

    if shape[-2] > shape[-1]:
      raise ValueError("In_filters cannot be greater than out_filters.")

    if shape[0] != shape[1]:
      raise ValueError("Kernel sizes must be equal.")

    kernel = self._orthogonal_kernel(shape[0], shape[2], shape[3])
    # The helpers build the kernel in `self.dtype`. Cast it to the requested
    # `dtype` so that a call-time dtype override does not cause a dtype
    # mismatch when the (cast) gain is applied.
    kernel = math_ops.cast(kernel, dtype)
    kernel *= math_ops.cast(self.gain, dtype=dtype)
    return kernel

  def _dict_to_tensor(self, x, k1, k2):
    """Convert a dictionary to a tensor.

    Args:
      x: A k1 * k2 dictionary keyed by `(i, j)`, each value a tensor.
      k1: First dimension of x.
      k2: Second dimension of x.
    Returns:
      A tensor whose leading dimensions are `[k1, k2]`; entry `[i, j]` is
      `x[i, j]`.
    """
    return array_ops.stack([array_ops.stack([x[i, j] for j in range(k2)])
                            for i in range(k1)])

  def _block_orth(self, p1, p2):
    """Construct a 2 x 2 kernel. Used to construct orthogonal kernel.

    Args:
      p1: A symmetric projection matrix.
      p2: A symmetric projection matrix.
    Returns:
      A 2 x 2 kernel, as a dictionary keyed by `(row, col)`:
        [[p1p2,         p1(1-p2)],
         [(1-p1)p2, (1-p1)(1-p2)]].
    Raises:
      ValueError: If the dimensions of p1 and p2 are different.
    """
    if p1.shape.as_list() != p2.shape.as_list():
      raise ValueError("The dimension of the matrices must be the same.")
    n = p1.shape.as_list()[0]
    eye = linalg_ops_impl.eye(n, dtype=self.dtype)
    # Complements (1 - p1) and (1 - p2), each reused twice below.
    q1 = eye - p1
    q2 = eye - p2
    kernel2x2 = {}
    kernel2x2[0, 0] = math_ops.matmul(p1, p2)
    kernel2x2[0, 1] = math_ops.matmul(p1, q2)
    kernel2x2[1, 0] = math_ops.matmul(q1, p2)
    kernel2x2[1, 1] = math_ops.matmul(q1, q2)
    return kernel2x2

  def _matrix_conv(self, m1, m2):
    """Matrix convolution.

    Args:
      m1: A k x k dictionary, each element is a n x n matrix.
      m2: A l x l dictionary, each element is a n x n matrix.

    Returns:
      (k + l - 1) * (k + l - 1) dictionary each element is a n x n matrix.
    Raises:
      ValueError: if the entries of m1 and m2 are of different dimensions.
    """
    n = (m1[0, 0]).shape.as_list()[0]
    if n != (m2[0, 0]).shape.as_list()[0]:
      raise ValueError("The entries in matrices m1 and m2 "
                       "must have the same dimensions!")
    # The dictionaries hold k*k and l*l entries; round before truncating so
    # floating-point error in the root cannot shrink the side length.
    k = int(round(np.sqrt(len(m1))))
    l = int(round(np.sqrt(len(m2))))
    result = {}
    size = k + l - 1
    # Compute matrix convolution between m1 and m2.
    for i in range(size):
      for j in range(size):
        result[i, j] = array_ops.zeros([n, n], self.dtype)
        for index1 in range(min(k, i + 1)):
          for index2 in range(min(k, j + 1)):
            if (i - index1) < l and (j - index2) < l:
              result[i, j] += math_ops.matmul(m1[index1, index2],
                                              m2[i - index1, j - index2])
    return result

  def _orthogonal_kernel(self, ksize, cin, cout):
    """Construct orthogonal kernel for convolution.

    Args:
      ksize: Kernel size.
      cin: Number of input channels.
      cout: Number of output channels.
    Returns:
      An [ksize, ksize, cin, cout] orthogonal kernel.
    Raises:
      ValueError: If cin > cout.
    """
    if cin > cout:
      raise ValueError("The number of input channels cannot exceed "
                       "the number of output channels.")
    orth = self._orthogonal_matrix(cout)[0:cin, :]
    if ksize == 1:
      return array_ops.expand_dims(array_ops.expand_dims(orth, 0), 0)

    # Start from a 2 x 2 block and grow it by one in each spatial dimension
    # per convolution with another random 2 x 2 block.
    p = self._block_orth(self._symmetric_projection(cout),
                         self._symmetric_projection(cout))
    for _ in range(ksize - 2):
      temp = self._block_orth(self._symmetric_projection(cout),
                              self._symmetric_projection(cout))
      p = self._matrix_conv(p, temp)
    # Project every (cout x cout) tap down to (cin x cout).
    for i in range(ksize):
      for j in range(ksize):
        p[i, j] = math_ops.matmul(orth, p[i, j])

    return self._dict_to_tensor(p, ksize, ksize)
903
904
class ConvolutionOrthogonal1D(ConvolutionOrthogonal):
  """Initializer that generates a 1D orthogonal kernel for ConvNets.

  The shape of the tensor must have length 3. The number of input
  filters must not exceed the number of output filters.
  The orthogonality(==isometry) is exact when the inputs are circular padded.
  There are finite-width effects with non-circular padding (e.g. zero padding).
  See algorithm 1 in (Xiao et al., 2018).

  Args:
    gain: Multiplicative factor to apply to the orthogonal
      matrix. Default is 1. The 2-norm of an input is multiplied by a factor of
      `gain` after applying this convolution.
    seed: A Python integer. Used to create random seeds. See
      `tf.set_random_seed` for behavior.
    dtype: Default data type, used if no `dtype` argument is provided when
      calling the initializer. Only floating point types are supported.

  References:
      [Xiao et al., 2018](http://proceedings.mlr.press/v80/xiao18a.html)
      ([pdf](http://proceedings.mlr.press/v80/xiao18a/xiao18a.pdf))
  """

  def __call__(self, shape, dtype=None, partition_info=None):
    """Returns a `[k, in_filters, out_filters]` orthogonal kernel.

    Args:
      shape: Shape of the kernel, `[k, in_filters, out_filters]`.
      dtype: (Optional) Data type of the result. Defaults to `self.dtype`.
      partition_info: Unused; accepted for interface compatibility.

    Returns:
      The orthogonal kernel scaled by `self.gain`, cast to `dtype`.

    Raises:
      ValueError: If `shape` is not three-dimensional or if in_filters
        exceeds out_filters.
    """
    if dtype is None:
      dtype = self.dtype
    if len(shape) != 3:
      raise ValueError("The tensor to initialize must be three-dimensional")

    if shape[-2] > shape[-1]:
      raise ValueError("In_filters cannot be greater than out_filters.")

    kernel = self._orthogonal_kernel(shape[0], shape[-2], shape[-1])
    # The helpers build the kernel in `self.dtype`. Cast it to the requested
    # `dtype` so that a call-time dtype override does not cause a dtype
    # mismatch when the (cast) gain is applied.
    kernel = math_ops.cast(kernel, dtype)
    kernel *= math_ops.cast(self.gain, dtype=dtype)
    return kernel

  def _dict_to_tensor(self, x, k):
    """Convert a dictionary to a tensor.

    Args:
      x: A dictionary of length k, keyed by 0..k-1.
      k: Dimension of x.
    Returns:
      A tensor whose leading dimension has size k; entry `[i]` is `x[i]`.
    """
    return array_ops.stack([x[i] for i in range(k)])

  def _block_orth(self, projection_matrix):
    """Construct a kernel. Used to construct orthogonal kernel.

    Args:
      projection_matrix: A symmetric projection matrix of size n x n.
    Returns:
      A dictionary `{0: projection_matrix, 1: (1 - projection_matrix)}`.
    """
    n = projection_matrix.shape.as_list()[0]
    eye = linalg_ops_impl.eye(n, dtype=self.dtype)
    return {0: projection_matrix, 1: eye - projection_matrix}

  def _matrix_conv(self, m1, m2):
    """Matrix convolution.

    Args:
      m1: A dictionary of length k, each element is a n x n matrix.
      m2: A dictionary of length l, each element is a n x n matrix.

    Returns:
      (k + l - 1)  dictionary each element is a n x n matrix.
    Raises:
      ValueError: If the entries of m1 and m2 are of different dimensions.
    """
    n = (m1[0]).shape.as_list()[0]
    if n != (m2[0]).shape.as_list()[0]:
      raise ValueError("The entries in matrices m1 and m2 "
                       "must have the same dimensions!")
    k = len(m1)
    l = len(m2)
    result = {}
    size = k + l - 1
    # Compute matrix convolution between m1 and m2.
    for i in range(size):
      result[i] = array_ops.zeros([n, n], self.dtype)
      for index in range(min(k, i + 1)):
        if (i - index) < l:
          result[i] += math_ops.matmul(m1[index], m2[i - index])
    return result

  def _orthogonal_kernel(self, ksize, cin, cout):
    """Construct orthogonal kernel for convolution.

    Args:
      ksize: Kernel size.
      cin: Number of input channels.
      cout: Number of output channels.
    Returns:
      An [ksize, cin, cout] orthogonal kernel.
    Raises:
      ValueError: If cin > cout.
    """
    if cin > cout:
      raise ValueError("The number of input channels cannot exceed "
                       "the number of output channels.")
    orth = self._orthogonal_matrix(cout)[0:cin, :]
    if ksize == 1:
      return array_ops.expand_dims(orth, 0)

    # Start from a length-2 block and grow it by one tap per convolution with
    # another random length-2 block.
    p = self._block_orth(self._symmetric_projection(cout))
    for _ in range(ksize - 2):
      temp = self._block_orth(self._symmetric_projection(cout))
      p = self._matrix_conv(p, temp)
    # Project every (cout x cout) tap down to (cin x cout).
    for i in range(ksize):
      p[i] = math_ops.matmul(orth, p[i])

    return self._dict_to_tensor(p, ksize)
1025
1026
class ConvolutionOrthogonal3D(ConvolutionOrthogonal):
  """Initializer that generates a 3D orthogonal kernel for ConvNets.

  The shape of the tensor must have length 5. The number of input
  filters must not exceed the number of output filters.
  The orthogonality(==isometry) is exact when the inputs are circular padded.
  There are finite-width effects with non-circular padding (e.g. zero padding).
  See algorithm 1 (Xiao et al., 2018).

  Args:
    gain: Multiplicative factor to apply to the orthogonal
      matrix. Default is 1. The 2-norm of an input is multiplied by a factor of
      `gain` after applying this convolution.
    seed: A Python integer. Used to create random seeds. See
      `tf.set_random_seed` for behavior.
    dtype: Default data type, used if no `dtype` argument is provided when
      calling the initializer. Only floating point types are supported.

  References:
      [Xiao et al., 2018](http://proceedings.mlr.press/v80/xiao18a.html)
      ([pdf](http://proceedings.mlr.press/v80/xiao18a/xiao18a.pdf))
  """

  def __call__(self, shape, dtype=None, partition_info=None):
    """Returns a `[k, k, k, in_filters, out_filters]` orthogonal kernel.

    Args:
      shape: Shape of the kernel, `[k, k, k, in_filters, out_filters]`.
      dtype: (Optional) Data type of the result. Defaults to `self.dtype`.
      partition_info: Unused; accepted for interface compatibility.

    Returns:
      The orthogonal kernel scaled by `self.gain`, cast to `dtype`.

    Raises:
      ValueError: If `shape` is not five-dimensional, if in_filters exceeds
        out_filters, or if the three spatial kernel sizes are not all equal.
    """
    if dtype is None:
      dtype = self.dtype
    if len(shape) != 5:
      raise ValueError("The tensor to initialize must be five-dimensional")

    if shape[-2] > shape[-1]:
      raise ValueError("In_filters cannot be greater than out_filters.")

    if shape[0] != shape[1] or shape[0] != shape[2]:
      raise ValueError("Kernel sizes must be equal.")

    kernel = self._orthogonal_kernel(shape[0], shape[-2], shape[-1])
    # The helpers build the kernel in `self.dtype`. Cast it to the requested
    # `dtype` so that a call-time dtype override does not cause a dtype
    # mismatch when the (cast) gain is applied.
    kernel = math_ops.cast(kernel, dtype)
    kernel *= math_ops.cast(self.gain, dtype=dtype)
    return kernel

  def _dict_to_tensor(self, x, k1, k2, k3):
    """Convert a dictionary to a tensor.

    Args:
      x: A k1 * k2 * k3 dictionary keyed by `(i, j, k)`.
      k1: First dimension of x.
      k2: Second dimension of x.
      k3: Third dimension of x.
    Returns:
      A tensor whose leading dimensions are `[k1, k2, k3]`; entry `[i, j, k]`
      is `x[i, j, k]`.
    """
    return array_ops.stack([array_ops.stack(
        [array_ops.stack([x[i, j, k] for k in range(k3)])
         for j in range(k2)]) for i in range(k1)])

  def _block_orth(self, p1, p2, p3):
    """Construct a 2 x 2 x 2 kernel. Used to construct orthogonal kernel.

    Args:
      p1: A symmetric projection matrix.
      p2: A symmetric projection matrix.
      p3: A symmetric projection matrix.
    Returns:
      A 2 x 2 x 2 kernel, as a dictionary keyed by `(i, j, k)` with each index
      in {0, 1}. Entry `(i, j, k)` is the product `a1 a2 a3`, where `a_m` is
      `p_m` when the corresponding index is 1 and `(1 - p_m)` when it is 0.
    Raises:
      ValueError: If the dimensions of p1, p2 and p3 are different.
    """
    p1_shape = p1.shape.as_list()
    if p1_shape != p2.shape.as_list() or p1_shape != p3.shape.as_list():
      raise ValueError("The dimension of the matrices must be the same.")
    n = p1_shape[0]
    eye = linalg_ops_impl.eye(n, dtype=self.dtype)
    # factors[m][idx]: index 0 selects the complement (1 - p), index 1
    # selects p itself.
    factors = [(eye - p, p) for p in (p1, p2, p3)]
    kernel2x2x2 = {}
    for i in [0, 1]:
      for j in [0, 1]:
        for k in [0, 1]:
          kernel2x2x2[i, j, k] = math_ops.matmul(
              math_ops.matmul(factors[0][i], factors[1][j]), factors[2][k])
    return kernel2x2x2

  def _matrix_conv(self, m1, m2):
    """Matrix convolution.

    Args:
      m1: is a k x k x k  dictionary, each element is a n x n matrix.
      m2: is a l x l x l dictionary, each element is a n x n matrix.

    Returns:
      (k + l - 1) x (k + l - 1) x (k + l - 1) dictionary each
      element is a n x n matrix.
    Raises:
      ValueError: if the entries of m1 and m2 are of different dimensions.
    """
    n = (m1[0, 0, 0]).shape.as_list()[0]
    if n != (m2[0, 0, 0]).shape.as_list()[0]:
      raise ValueError("The entries in matrices m1 and m2 "
                       "must have the same dimensions!")
    # The dictionaries hold k**3 and l**3 entries. Round before truncating:
    # a cube root that lands one ulp below an integer would otherwise be
    # truncated to the wrong side length.
    k = int(round(np.cbrt(len(m1))))
    l = int(round(np.cbrt(len(m2))))
    result = {}
    size = k + l - 1
    # Compute matrix convolution between m1 and m2.
    for i in range(size):
      for j in range(size):
        for r in range(size):
          result[i, j, r] = array_ops.zeros([n, n], self.dtype)
          for index1 in range(min(k, i + 1)):
            for index2 in range(min(k, j + 1)):
              for index3 in range(min(k, r + 1)):
                if (i - index1) < l and (j - index2) < l and (r - index3) < l:
                  result[i, j, r] += math_ops.matmul(m1[index1, index2, index3],
                                                     m2[i - index1, j - index2,
                                                        r - index3])
    return result

  def _orthogonal_kernel(self, ksize, cin, cout):
    """Construct orthogonal kernel for convolution.

    Args:
      ksize: Kernel size.
      cin: Number of input channels.
      cout: Number of output channels.
    Returns:
      An [ksize, ksize, ksize, cin, cout] orthogonal kernel.
    Raises:
      ValueError: If cin > cout.
    """
    if cin > cout:
      raise ValueError("The number of input channels cannot exceed "
                       "the number of output channels.")
    orth = self._orthogonal_matrix(cout)[0:cin, :]
    if ksize == 1:
      return array_ops.expand_dims(
          array_ops.expand_dims(
              array_ops.expand_dims(orth, 0), 0), 0)

    # Start from a 2 x 2 x 2 block and grow it by one in each spatial
    # dimension per convolution with another random 2 x 2 x 2 block.
    p = self._block_orth(self._symmetric_projection(cout),
                         self._symmetric_projection(cout),
                         self._symmetric_projection(cout))
    for _ in range(ksize - 2):
      temp = self._block_orth(self._symmetric_projection(cout),
                              self._symmetric_projection(cout),
                              self._symmetric_projection(cout))
      p = self._matrix_conv(p, temp)
    # Project every (cout x cout) tap down to (cin x cout).
    for i in range(ksize):
      for j in range(ksize):
        for k in range(ksize):
          p[i, j, k] = math_ops.matmul(orth, p[i, j, k])

    return self._dict_to_tensor(p, ksize, ksize, ksize)
1182
1183
@tf_export(v1=["initializers.identity"])
@deprecation.deprecated_endpoints("initializers.identity")
class Identity(Initializer):
  """Initializer producing a (gain-scaled) identity matrix.

  Only 2D shapes are accepted. When the variable is partitioned, the full
  identity matrix is built and the slice belonging to this partition is
  returned.

  Args:
    gain: Multiplicative factor to apply to the identity matrix.
    dtype: Default data type, used if no `dtype` argument is provided when
      calling the initializer. Only floating point types are supported.
  """

  @deprecated_args(None,
                   "Call initializer instance with the dtype argument instead "
                   "of passing it to the constructor",
                   "dtype")
  def __init__(self, gain=1.0, dtype=dtypes.float32):
    self.gain = gain
    self.dtype = _assert_float_dtype(dtypes.as_dtype(dtype))

  def __call__(self, shape, dtype=None, partition_info=None):
    # Work on the unpartitioned shape so the slice offsets line up.
    if partition_info is None:
      full_shape = shape
    else:
      full_shape = partition_info.full_shape
    if len(full_shape) != 2:
      raise ValueError(
          "Identity matrix initializer can only be used for 2D matrices.")
    out_dtype = self.dtype if dtype is None else dtype
    if isinstance(full_shape, tensor_shape.TensorShape):
      full_shape = full_shape.as_list()
    num_rows, num_columns = full_shape
    eye = linalg_ops_impl.eye(num_rows, num_columns, dtype=out_dtype)
    if partition_info is None:
      return self.gain * eye
    # Only this partition's block of the full identity matrix is returned.
    shard = array_ops.slice(eye, partition_info.var_offset, shape)
    return self.gain * shard

  def get_config(self):
    return dict(gain=self.gain, dtype=self.dtype.name)
1222
1223
@tf_export(v1=["glorot_uniform_initializer", "initializers.glorot_uniform"])
@deprecation.deprecated_endpoints("glorot_uniform_initializer",
                                  "initializers.glorot_uniform")
class GlorotUniform(VarianceScaling):
  """Glorot (also known as Xavier) uniform initializer.

  Samples are drawn uniformly from `[-limit, limit]` with
  `limit = sqrt(6 / (fan_in + fan_out))`, where `fan_in` and `fan_out` are
  the numbers of input and output units of the weight tensor. Equivalent to
  `VarianceScaling(scale=1.0, mode="fan_avg", distribution="uniform")`.

  Args:
    seed: A Python integer. Used to create random seeds. See
      `tf.set_random_seed` for behavior.
    dtype: Default data type, used if no `dtype` argument is provided when
      calling the initializer. Only floating point types are supported.

  References:
      [Glorot et al., 2010](http://proceedings.mlr.press/v9/glorot10a.html)
      ([pdf](http://jmlr.org/proceedings/papers/v9/glorot10a/glorot10a.pdf))
  """

  @deprecated_args(None,
                   "Call initializer instance with the dtype argument instead "
                   "of passing it to the constructor",
                   "dtype")
  def __init__(self, seed=None, dtype=dtypes.float32):
    super(GlorotUniform, self).__init__(
        distribution="uniform",
        mode="fan_avg",
        scale=1.0,
        dtype=dtype,
        seed=seed)

  def get_config(self):
    config = {"seed": self.seed}
    config["dtype"] = self.dtype.name
    return config
1261
1262
@tf_export(v1=["glorot_normal_initializer", "initializers.glorot_normal"])
@deprecation.deprecated_endpoints("glorot_normal_initializer",
                                  "initializers.glorot_normal")
class GlorotNormal(VarianceScaling):
  """Glorot (also known as Xavier) normal initializer.

  Samples come from a truncated normal distribution centered on 0 whose
  standard deviation (after truncation) is `sqrt(2 / (fan_in + fan_out))`,
  where `fan_in` and `fan_out` are the numbers of input and output units of
  the weight tensor. Equivalent to
  `VarianceScaling(scale=1.0, mode="fan_avg", distribution="truncated_normal")`.

  Args:
    seed: A Python integer. Used to create random seeds. See
      `tf.set_random_seed` for behavior.
    dtype: Default data type, used if no `dtype` argument is provided when
      calling the initializer. Only floating point types are supported.

  References:
      [Glorot et al., 2010](http://proceedings.mlr.press/v9/glorot10a.html)
      ([pdf](http://jmlr.org/proceedings/papers/v9/glorot10a/glorot10a.pdf))
  """

  @deprecated_args(None,
                   "Call initializer instance with the dtype argument instead "
                   "of passing it to the constructor",
                   "dtype")
  def __init__(self, seed=None, dtype=dtypes.float32):
    super(GlorotNormal, self).__init__(
        distribution="truncated_normal",
        mode="fan_avg",
        scale=1.0,
        dtype=dtype,
        seed=seed)

  def get_config(self):
    config = {"seed": self.seed}
    config["dtype"] = self.dtype.name
    return config
1300
1301
# Aliases.

# pylint: disable=invalid-name
# Zero / one / constant / identity initializers.
zeros_initializer = Zeros
ones_initializer = Ones
constant_initializer = Constant
identity_initializer = Identity
# Distribution-named random initializers.
random_uniform_initializer = RandomUniform
random_normal_initializer = RandomNormal
truncated_normal_initializer = TruncatedNormal
# Scaling-based initializers.
uniform_unit_scaling_initializer = UniformUnitScaling
variance_scaling_initializer = VarianceScaling
glorot_uniform_initializer = GlorotUniform
glorot_normal_initializer = GlorotNormal
# Orthogonal initializers (dense and convolutional).
orthogonal_initializer = Orthogonal
convolutional_delta_orthogonal = ConvolutionDeltaOrthogonal
convolutional_orthogonal_1d = ConvolutionOrthogonal1D
convolutional_orthogonal_2d = ConvolutionOrthogonal2D
convolutional_orthogonal_3d = ConvolutionOrthogonal3D
# pylint: enable=invalid-name
1322
1323
@tf_export(v1=["initializers.lecun_normal"])
def lecun_normal(seed=None):
  """LeCun normal initializer.

  Returns a `VarianceScaling` initializer with `scale=1.0`, `mode="fan_in"`
  and `distribution="truncated_normal"`. Samples come from a truncated normal
  distribution centered on 0 with standard deviation (after truncation)
  `sqrt(1 / fan_in)`, where `fan_in` is the number of input units in the
  weight tensor.

  Args:
      seed: A Python integer. Used to seed the random generator.

  Returns:
      An initializer.

  References:
      - Self-Normalizing Neural Networks,
      [Klambauer et al., 2017](https://papers.nips.cc/paper/6698-self-normalizing-neural-networks)  # pylint: disable=line-too-long
      ([pdf](https://papers.nips.cc/paper/6698-self-normalizing-neural-networks.pdf))
      - Efficient Backprop,
      [Lecun et al., 1998](http://yann.lecun.com/exdb/publis/pdf/lecun-98b.pdf)
  """
  return VarianceScaling(
      seed=seed,
      distribution="truncated_normal",
      mode="fan_in",
      scale=1.)
1348
1349
@tf_export(v1=["initializers.lecun_uniform"])
def lecun_uniform(seed=None):
  """LeCun uniform initializer.

  Returns a `VarianceScaling` initializer with `scale=1.0`, `mode="fan_in"`
  and `distribution="uniform"`. Samples are drawn uniformly from
  `[-limit, limit]` with `limit = sqrt(3 / fan_in)`, where `fan_in` is the
  number of input units in the weight tensor.

  Args:
      seed: A Python integer. Used to seed the random generator.

  Returns:
      An initializer.

  References:
      - Self-Normalizing Neural Networks,
      [Klambauer et al., 2017](https://papers.nips.cc/paper/6698-self-normalizing-neural-networks)  # pylint: disable=line-too-long
      ([pdf](https://papers.nips.cc/paper/6698-self-normalizing-neural-networks.pdf))
      - Efficient Backprop,
      [Lecun et al., 1998](http://yann.lecun.com/exdb/publis/pdf/lecun-98b.pdf)
  """
  return VarianceScaling(
      seed=seed,
      distribution="uniform",
      mode="fan_in",
      scale=1.)
1373
1374
@tf_export(v1=["initializers.he_normal"])
def he_normal(seed=None):
  """He normal initializer.

  Returns a `VarianceScaling` initializer with `scale=2.0`, `mode="fan_in"`
  and `distribution="truncated_normal"`. Samples come from a truncated normal
  distribution centered on 0 with standard deviation (after truncation)
  `sqrt(2 / fan_in)`, where `fan_in` is the number of input units in the
  weight tensor.

  Args:
      seed: A Python integer. Used to seed the random generator.

  Returns:
      An initializer.

  References:
      [He et al., 2015]
      (https://www.cv-foundation.org/openaccess/content_iccv_2015/html/He_Delving_Deep_into_ICCV_2015_paper.html)  # pylint: disable=line-too-long
      ([pdf](https://www.cv-foundation.org/openaccess/content_iccv_2015/papers/He_Delving_Deep_into_ICCV_2015_paper.pdf))
  """
  return VarianceScaling(
      seed=seed,
      distribution="truncated_normal",
      mode="fan_in",
      scale=2.)
1397
1398
@tf_export(v1=["initializers.he_uniform"])
def he_uniform(seed=None):
  """He uniform variance scaling initializer.

  Returns a `VarianceScaling` initializer with `scale=2.0`, `mode="fan_in"`
  and `distribution="uniform"`. Samples are drawn uniformly from
  `[-limit, limit]` with `limit = sqrt(6 / fan_in)`, where `fan_in` is the
  number of input units in the weight tensor.

  Args:
      seed: A Python integer. Used to seed the random generator.

  Returns:
      An initializer.

  References:
      [He et al., 2015]
      (https://www.cv-foundation.org/openaccess/content_iccv_2015/html/He_Delving_Deep_into_ICCV_2015_paper.html)  # pylint: disable=line-too-long
      ([pdf](https://www.cv-foundation.org/openaccess/content_iccv_2015/papers/He_Delving_Deep_into_ICCV_2015_paper.pdf))
  """
  return VarianceScaling(
      seed=seed,
      distribution="uniform",
      mode="fan_in",
      scale=2.)
1420
1421
1422# Utility functions.
1423
1424
1425def _compute_fans(shape):
1426  """Computes the number of input and output units for a weight shape.
1427
1428  Args:
1429    shape: Integer shape tuple or TF tensor shape.
1430
1431  Returns:
1432    A tuple of integer scalars (fan_in, fan_out).
1433  """
1434  if len(shape) < 1:  # Just to avoid errors for constants.
1435    fan_in = fan_out = 1
1436  elif len(shape) == 1:
1437    fan_in = fan_out = shape[0]
1438  elif len(shape) == 2:
1439    fan_in = shape[0]
1440    fan_out = shape[1]
1441  else:
1442    # Assuming convolution kernels (2D, 3D, or more).
1443    # kernel shape: (..., input_depth, depth)
1444    receptive_field_size = 1
1445    for dim in shape[:-2]:
1446      receptive_field_size *= dim
1447    fan_in = shape[-2] * receptive_field_size
1448    fan_out = shape[-1] * receptive_field_size
1449  return int(fan_in), int(fan_out)
1450
1451
1452def _assert_float_dtype(dtype):
1453  """Validate and return floating point type based on `dtype`.
1454
1455  `dtype` must be a floating point type.
1456
1457  Args:
1458    dtype: The data type to validate.
1459
1460  Returns:
1461    Validated type.
1462
1463  Raises:
1464    ValueError: if `dtype` is not a floating point type.
1465  """
1466  if not dtype.is_floating:
1467    raise ValueError("Expected floating point type, got %s." % dtype)
1468  return dtype
1469