1# Copyright 2015 The TensorFlow Authors. All Rights Reserved.
2#
3# Licensed under the Apache License, Version 2.0 (the "License");
4# you may not use this file except in compliance with the License.
5# You may obtain a copy of the License at
6#
7#     http://www.apache.org/licenses/LICENSE-2.0
8#
9# Unless required by applicable law or agreed to in writing, software
10# distributed under the License is distributed on an "AS IS" BASIS,
11# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
12# See the License for the specific language governing permissions and
13# limitations under the License.
14# ==============================================================================
15"""Initializers for TF 2."""
16from __future__ import absolute_import
17from __future__ import division
18from __future__ import print_function
19
20import math
21
22import numpy as np
23
24from tensorflow.python.framework import constant_op
25from tensorflow.python.framework import dtypes
26from tensorflow.python.ops import array_ops
27from tensorflow.python.ops import gen_linalg_ops
28from tensorflow.python.ops import linalg_ops_impl
29from tensorflow.python.ops import math_ops
30from tensorflow.python.ops import random_ops
31from tensorflow.python.ops import stateless_random_ops
32from tensorflow.python.ops.init_ops import _compute_fans
33from tensorflow.python.util.tf_export import tf_export
34
35_PARTITION_SHAPE = "partition_shape"
36_PARTITION_OFFSET = "partition_offset"
37
38
39class Initializer(object):
40  """Initializer base class: all initializers inherit from this class.
41
42  Initializers should implement a `__call__` method with the following
43  signature:
44
45  ```python
46  def __call__(self, shape, dtype=None, **kwargs):
47    # returns a tensor of shape `shape` and dtype `dtype`
48    # containing values drawn from a distribution of your choice.
49  ```
50  """
51
52  def __call__(self, shape, dtype=None, **kwargs):
53    """Returns a tensor object initialized as specified by the initializer.
54
55    Args:
56      shape: Shape of the tensor.
57      dtype: Optional dtype of the tensor. If not provided will return tensor
58        of `tf.float32`.
59      **kwargs: Additional keyword arguments. Accepted values:
60        `partition_shape` and `partition_offset`. Used when creating a single
61        partition in a partitioned variable. `partition_shape` is the shape of
62        the partition (i.e. the shape of the returned tensor) and
63        `partition_offset` is a tuple of `int` specifying the offset of this
64        partition w.r.t each axis. For example, a tensor of shape `(30, 100)`
65        can be partitioned into two partitions: `p0` of shape `(10, 100)` and
66        `p1` of shape `(20, 100)`; if the initializer is called with
67        `partition_shape=(20, 100)` and `partition_offset=(10, 0)`, it should
68        return the value for `p1`.
69    """
70    raise NotImplementedError
71
72  def get_config(self):
73    """Returns the configuration of the initializer as a JSON-serializable dict.
74
75    Returns:
76      A JSON-serializable Python dict.
77    """
78    return {}
79
80  @classmethod
81  def from_config(cls, config):
82    """Instantiates an initializer from a configuration dictionary.
83
84    Example:
85
86    ```python
87    initializer = RandomUniform(-1, 1)
88    config = initializer.get_config()
89    initializer = RandomUniform.from_config(config)
90    ```
91
92    Args:
93      config: A Python dictionary.
94        It will typically be the output of `get_config`.
95
96    Returns:
97      An Initializer instance.
98    """
99    config.pop("dtype", None)
100    return cls(**config)
101
102  def _validate_kwargs(self, kwargs, support_partition=True):
103    for kwarg in kwargs:
104      if kwarg not in [_PARTITION_SHAPE, _PARTITION_OFFSET]:
105        raise TypeError("Unknown keyword arguments: %s" % kwarg)
106      elif not support_partition:
107        raise ValueError("%s initializer doesn't support partition-related"
108                         " arguments" % self.__class__.__name__)
109
110
111@tf_export("zeros_initializer", v1=[])
112class Zeros(Initializer):
113  """Initializer that generates tensors initialized to 0.
114
115  Initializers allow you to pre-specify an initialization strategy, encoded in
116  the Initializer object, without knowing the shape and dtype of the variable
117  being initialized.
118
119  Examples:
120
121  >>> def make_variables(k, initializer):
122  ...   return (tf.Variable(initializer(shape=[k], dtype=tf.float32)),
123  ...           tf.Variable(initializer(shape=[k, k], dtype=tf.float32)))
124  >>> v1, v2 = make_variables(3, tf.zeros_initializer())
125  >>> v1
126  <tf.Variable ... shape=(3,) ... numpy=array([0., 0., 0.], dtype=float32)>
127  >>> v2
128  <tf.Variable ... shape=(3, 3) ... numpy=
129  array([[0., 0., 0.],
130         [0., 0., 0.],
131         [0., 0., 0.]], dtype=float32)>
132  >>> make_variables(4, tf.random_uniform_initializer(minval=-1., maxval=1.))
133  (<tf.Variable...shape=(4,) dtype=float32...>, <tf.Variable...shape=(4, 4) ...
134  """
135
136  def __call__(self, shape, dtype=dtypes.float32, **kwargs):
137    """Returns a tensor object initialized as specified by the initializer.
138
139    Args:
140      shape: Shape of the tensor.
141      dtype: Optional dtype of the tensor. Only numeric or boolean dtypes are
142       supported.
143      **kwargs: Additional keyword arguments.
144
145    Raises:
146      ValuesError: If the dtype is not numeric or boolean.
147    """
148    self._validate_kwargs(kwargs)
149    dtype = dtypes.as_dtype(dtype)
150    if not dtype.is_numpy_compatible or dtype == dtypes.string:
151      raise ValueError("Expected numeric or boolean dtype, got %s." % dtype)
152    if _PARTITION_SHAPE in kwargs:
153      shape = kwargs[_PARTITION_SHAPE]
154    return array_ops.zeros(shape, dtype)
155
156
157@tf_export("ones_initializer", v1=[])
158class Ones(Initializer):
159  """Initializer that generates tensors initialized to 1.
160
161  Initializers allow you to pre-specify an initialization strategy, encoded in
162  the Initializer object, without knowing the shape and dtype of the variable
163  being initialized.
164
165  Examples:
166
167  >>> def make_variables(k, initializer):
168  ...   return (tf.Variable(initializer(shape=[k], dtype=tf.float32)),
169  ...           tf.Variable(initializer(shape=[k, k], dtype=tf.float32)))
170  >>> v1, v2 = make_variables(3, tf.ones_initializer())
171  >>> v1
172  <tf.Variable ... shape=(3,) ... numpy=array([1., 1., 1.], dtype=float32)>
173  >>> v2
174  <tf.Variable ... shape=(3, 3) ... numpy=
175  array([[1., 1., 1.],
176         [1., 1., 1.],
177         [1., 1., 1.]], dtype=float32)>
178  >>> make_variables(4, tf.random_uniform_initializer(minval=-1., maxval=1.))
179  (<tf.Variable...shape=(4,) dtype=float32...>, <tf.Variable...shape=(4, 4) ...
180  """
181
182  def __call__(self, shape, dtype=dtypes.float32, **kwargs):
183    """Returns a tensor object initialized as specified by the initializer.
184
185    Args:
186      shape: Shape of the tensor.
187      dtype: Optional dtype of the tensor. Only numeric or boolean dtypes are
188        supported.
189      **kwargs: Additional keyword arguments.
190
191    Raises:
192      ValuesError: If the dtype is not numeric or boolean.
193    """
194    self._validate_kwargs(kwargs)
195    dtype = dtypes.as_dtype(dtype)
196    if not dtype.is_numpy_compatible or dtype == dtypes.string:
197      raise ValueError("Expected numeric or boolean dtype, got %s." % dtype)
198    if _PARTITION_SHAPE in kwargs:
199      shape = kwargs[_PARTITION_SHAPE]
200    return array_ops.ones(shape, dtype)
201
202
203@tf_export("constant_initializer", v1=[])
204class Constant(Initializer):
205  """Initializer that generates tensors with constant values.
206
207  Initializers allow you to pre-specify an initialization strategy, encoded in
208  the Initializer object, without knowing the shape and dtype of the variable
209  being initialized.
210
211  `tf.constant_initializer` returns an object which when called returns a tensor
212  populated with the `value` specified in the constructor. This `value` must be
213  convertible to the requested `dtype`.
214
215  The argument `value` can be a scalar constant value, or a list of
216  values. Scalars broadcast to whichever shape is requested from the
217  initializer.
218
219  If `value` is a list, then the length of the list must be equal to the number
220  of elements implied by the desired shape of the tensor. If the total number of
221  elements in `value` is not equal to the number of elements required by the
222  tensor shape, the initializer will raise a `TypeError`.
223
224  Examples:
225
226  >>> def make_variables(k, initializer):
227  ...   return (tf.Variable(initializer(shape=[k], dtype=tf.float32)),
228  ...           tf.Variable(initializer(shape=[k, k], dtype=tf.float32)))
229  >>> v1, v2 = make_variables(3, tf.constant_initializer(2.))
230  >>> v1
231  <tf.Variable ... shape=(3,) ... numpy=array([2., 2., 2.], dtype=float32)>
232  >>> v2
233  <tf.Variable ... shape=(3, 3) ... numpy=
234  array([[2., 2., 2.],
235         [2., 2., 2.],
236         [2., 2., 2.]], dtype=float32)>
237  >>> make_variables(4, tf.random_uniform_initializer(minval=-1., maxval=1.))
238  (<tf.Variable...shape=(4,) dtype=float32...>, <tf.Variable...shape=(4, 4) ...
239
240  >>> value = [0, 1, 2, 3, 4, 5, 6, 7]
241  >>> init = tf.constant_initializer(value)
242  >>> # Fitting shape
243  >>> tf.Variable(init(shape=[2, 4], dtype=tf.float32))
244  <tf.Variable ...
245  array([[0., 1., 2., 3.],
246         [4., 5., 6., 7.]], dtype=float32)>
247  >>> # Larger shape
248  >>> tf.Variable(init(shape=[3, 4], dtype=tf.float32))
249  Traceback (most recent call last):
250  ...
251  TypeError: ...value has 8 elements, shape is (3, 4) with 12 elements...
252  >>> # Smaller shape
253  >>> tf.Variable(init(shape=[2, 3], dtype=tf.float32))
254  Traceback (most recent call last):
255  ...
256  TypeError: ...value has 8 elements, shape is (2, 3) with 6 elements...
257
258  Args:
259    value: A Python scalar, list or tuple of values, or a N-dimensional numpy
260      array. All elements of the initialized variable will be set to the
261      corresponding value in the `value` argument.
262
263  Raises:
264    TypeError: If the input `value` is not one of the expected types.
265  """
266
267  def __init__(self, value=0):
268    if not (np.isscalar(value) or isinstance(value, (list, tuple, np.ndarray))):
269      raise TypeError(
270          "Invalid type for initial value: %s (expected Python scalar, list or "
271          "tuple of values, or numpy.ndarray)." % type(value))
272    self.value = value
273
274  def __call__(self, shape, dtype=None, **kwargs):
275    """Returns a tensor object initialized as specified by the initializer.
276
277    Args:
278      shape: Shape of the tensor.
279      dtype: Optional dtype of the tensor. If not provided the dtype of the
280        tensor created will be the type of the inital value.
281      **kwargs: Additional keyword arguments.
282
283    Raises:
284      TypeError: If the initializer cannot create a tensor of the requested
285       dtype.
286    """
287    self._validate_kwargs(kwargs, support_partition=False)
288    if dtype is not None:
289      dtype = dtypes.as_dtype(dtype)
290    return constant_op.constant(self.value, dtype=dtype, shape=shape)
291
292  def get_config(self):
293    return {"value": self.value}
294
295
296@tf_export("random_uniform_initializer", v1=[])
297class RandomUniform(Initializer):
298  """Initializer that generates tensors with a uniform distribution.
299
300  Initializers allow you to pre-specify an initialization strategy, encoded in
301  the Initializer object, without knowing the shape and dtype of the variable
302  being initialized.
303
304  Examples:
305
306  >>> def make_variables(k, initializer):
307  ...   return (tf.Variable(initializer(shape=[k], dtype=tf.float32)),
308  ...           tf.Variable(initializer(shape=[k, k], dtype=tf.float32)))
309  >>> v1, v2 = make_variables(3, tf.ones_initializer())
310  >>> v1
311  <tf.Variable ... shape=(3,) ... numpy=array([1., 1., 1.], dtype=float32)>
312  >>> v2
313  <tf.Variable ... shape=(3, 3) ... numpy=
314  array([[1., 1., 1.],
315         [1., 1., 1.],
316         [1., 1., 1.]], dtype=float32)>
317  >>> make_variables(4, tf.random_uniform_initializer(minval=-1., maxval=1.))
318  (<tf.Variable...shape=(4,) dtype=float32...>, <tf.Variable...shape=(4, 4) ...
319
320  Args:
321    minval: A python scalar or a scalar tensor. Lower bound of the range of
322      random values to generate (inclusive).
323    maxval: A python scalar or a scalar tensor. Upper bound of the range of
324      random values to generate (exclusive).
325    seed: A Python integer. Used to create random seeds. See
326      `tf.random.set_seed` for behavior.
327  """
328
329  def __init__(self, minval=-0.05, maxval=0.05, seed=None):
330    self.minval = minval
331    self.maxval = maxval
332    self.seed = seed
333    self._random_generator = _RandomGenerator(seed)
334
335  def __call__(self, shape, dtype=dtypes.float32, **kwargs):
336    """Returns a tensor object initialized as specified by the initializer.
337
338    Args:
339      shape: Shape of the tensor.
340      dtype: Optional dtype of the tensor. Only floating point and integer
341        types are supported.
342      **kwargs: Additional keyword arguments.
343
344    Raises:
345      ValueError: If the dtype is not numeric.
346    """
347    self._validate_kwargs(kwargs)
348    dtype = dtypes.as_dtype(dtype)
349    if not dtype.is_floating and not dtype.is_integer:
350      raise ValueError("Expected float or integer dtype, got %s." % dtype)
351    if _PARTITION_SHAPE in kwargs:
352      shape = kwargs[_PARTITION_SHAPE]
353    return self._random_generator.random_uniform(shape, self.minval,
354                                                 self.maxval, dtype)
355
356  def get_config(self):
357    return {
358        "minval": self.minval,
359        "maxval": self.maxval,
360        "seed": self.seed
361    }
362
363
364@tf_export("random_normal_initializer", v1=[])
365class RandomNormal(Initializer):
366  """Initializer that generates tensors with a normal distribution.
367
368  Initializers allow you to pre-specify an initialization strategy, encoded in
369  the Initializer object, without knowing the shape and dtype of the variable
370  being initialized.
371
372  Examples:
373
374  >>> def make_variables(k, initializer):
375  ...   return (tf.Variable(initializer(shape=[k], dtype=tf.float32)),
376  ...           tf.Variable(initializer(shape=[k, k], dtype=tf.float32)))
377  >>> v1, v2 = make_variables(3,
378  ...                         tf.random_normal_initializer(mean=1., stddev=2.))
379  >>> v1
380  <tf.Variable ... shape=(3,) ... numpy=array([...], dtype=float32)>
381  >>> v2
382  <tf.Variable ... shape=(3, 3) ... numpy=
383  ...
384  >>> make_variables(4, tf.random_uniform_initializer(minval=-1., maxval=1.))
385  (<tf.Variable...shape=(4,) dtype=float32...>, <tf.Variable...shape=(4, 4) ...
386
387  Args:
388    mean: a python scalar or a scalar tensor. Mean of the random values to
389      generate.
390    stddev: a python scalar or a scalar tensor. Standard deviation of the random
391      values to generate.
392    seed: A Python integer. Used to create random seeds. See
393      `tf.random.set_seed` for behavior.
394
395  """
396
397  def __init__(self, mean=0.0, stddev=0.05, seed=None):
398    self.mean = mean
399    self.stddev = stddev
400    self.seed = seed
401    self._random_generator = _RandomGenerator(seed)
402
403  def __call__(self, shape, dtype=dtypes.float32, **kwargs):
404    """Returns a tensor object initialized as specified by the initializer.
405
406    Args:
407      shape: Shape of the tensor.
408      dtype: Optional dtype of the tensor. Only floating point types are
409        supported.
410      **kwargs: Additional keyword arguments.
411
412    Raises:
413      ValueError: If the dtype is not floating point
414    """
415    self._validate_kwargs(kwargs)
416    dtype = _assert_float_dtype(dtype)
417    if _PARTITION_SHAPE in kwargs:
418      shape = kwargs[_PARTITION_SHAPE]
419    return self._random_generator.random_normal(shape, self.mean, self.stddev,
420                                                dtype)
421
422  def get_config(self):
423    return {
424        "mean": self.mean,
425        "stddev": self.stddev,
426        "seed": self.seed
427    }
428
429
430class TruncatedNormal(Initializer):
431  """Initializer that generates a truncated normal distribution.
432
433  Initializers allow you to pre-specify an initialization strategy, encoded in
434  the Initializer object, without knowing the shape and dtype of the variable
435  being initialized.
436
437  These values are similar to values from a `tf.initializers.RandomNormal`
438  except that values more than two standard deviations from the mean are
439  discarded and re-drawn. This is the recommended initializer for neural network
440  weights and filters.
441
442  Examples:
443
444  >>> def make_variables(k, initializer):
445  ...   return (tf.Variable(initializer(shape=[k], dtype=tf.float32)),
446  ...           tf.Variable(initializer(shape=[k, k], dtype=tf.float32)))
447  >>> v1, v2 = make_variables(
448  ...     3, tf.initializers.TruncatedNormal(mean=1., stddev=2.))
449  >>> v1
450  <tf.Variable ... shape=(3,) ... numpy=array([...], dtype=float32)>
451  >>> v2
452  <tf.Variable ... shape=(3, 3) ... numpy=
453  ...
454  >>> make_variables(4, tf.initializers.RandomUniform(minval=-1., maxval=1.))
455  (<tf.Variable...shape=(4,) dtype=float32...>, <tf.Variable...shape=(4, 4) ...
456
457  Args:
458    mean: a python scalar or a scalar tensor. Mean of the random values
459      to generate.
460    stddev: a python scalar or a scalar tensor. Standard deviation of the
461      random values to generate.
462    seed: A Python integer. Used to create random seeds. See
463      `tf.random.set_seed` for behavior.
464  """
465
466  def __init__(self, mean=0.0, stddev=0.05, seed=None):
467    self.mean = mean
468    self.stddev = stddev
469    self.seed = seed
470    self._random_generator = _RandomGenerator(seed)
471
472  def __call__(self, shape, dtype=dtypes.float32, **kwargs):
473    """Returns a tensor object initialized as specified by the initializer.
474
475    Args:
476      shape: Shape of the tensor.
477      dtype: Optional dtype of the tensor. Only floating point types are
478        supported.
479      **kwargs: Additional keyword arguments.
480
481    Raises:
482      ValueError: If the dtype is not floating point
483    """
484    self._validate_kwargs(kwargs)
485    dtype = _assert_float_dtype(dtype)
486    if _PARTITION_SHAPE in kwargs:
487      shape = kwargs[_PARTITION_SHAPE]
488    return self._random_generator.truncated_normal(shape, self.mean,
489                                                   self.stddev, dtype)
490
491  def get_config(self):
492    return {
493        "mean": self.mean,
494        "stddev": self.stddev,
495        "seed": self.seed
496    }
497
498
499class VarianceScaling(Initializer):
500  """Initializer capable of adapting its scale to the shape of weights tensors.
501
502  Initializers allow you to pre-specify an initialization strategy, encoded in
503  the Initializer object, without knowing the shape and dtype of the variable
504  being initialized.
505
506  With `distribution="truncated_normal" or "untruncated_normal"`, samples are
507  drawn from a truncated/untruncated normal distribution with a mean of zero and
508  a standard deviation (after truncation, if used) `stddev = sqrt(scale / n)`
509  where n is:
510
511    - number of input units in the weight tensor, if mode = "fan_in"
512    - number of output units, if mode = "fan_out"
513    - average of the numbers of input and output units, if mode = "fan_avg"
514
515  With `distribution="uniform"`, samples are drawn from a uniform distribution
516  within [-limit, limit], with `limit = sqrt(3 * scale / n)`.
517
518  Examples:
519
520  >>> def make_variables(k, initializer):
521  ...   return (tf.Variable(initializer(shape=[k], dtype=tf.float32)),
522  ...           tf.Variable(initializer(shape=[k, k], dtype=tf.float32)))
523  >>> v1, v2 = make_variables(3, tf.initializers.VarianceScaling(scale=1.))
524  >>> v1
525  <tf.Variable ... shape=(3,) ... numpy=array([...], dtype=float32)>
526  >>> v2
527  <tf.Variable ... shape=(3, 3) ... numpy=
528  ...
529  >>> make_variables(4, tf.initializers.VarianceScaling(distribution='uniform'))
530  (<tf.Variable...shape=(4,) dtype=float32...>, <tf.Variable...shape=(4, 4) ...
531
532  Args:
533    scale: Scaling factor (positive float).
534    mode: One of "fan_in", "fan_out", "fan_avg".
535    distribution: Random distribution to use. One of "truncated_normal",
536      "untruncated_normal" and  "uniform".
537    seed: A Python integer. Used to create random seeds. See
538      `tf.random.set_seed` for behavior.
539
540  Raises:
541    ValueError: In case of an invalid value for the "scale", mode" or
542      "distribution" arguments.
543  """
544
545  def __init__(self,
546               scale=1.0,
547               mode="fan_in",
548               distribution="truncated_normal",
549               seed=None):
550    if scale <= 0.:
551      raise ValueError("`scale` must be positive float.")
552    if mode not in {"fan_in", "fan_out", "fan_avg"}:
553      raise ValueError("Invalid `mode` argument:", mode)
554    distribution = distribution.lower()
555    # Compatibility with keras-team/keras.
556    if distribution == "normal":
557      distribution = "truncated_normal"
558    if distribution not in {"uniform", "truncated_normal",
559                            "untruncated_normal"}:
560      raise ValueError("Invalid `distribution` argument:", distribution)
561    self.scale = scale
562    self.mode = mode
563    self.distribution = distribution
564    self.seed = seed
565    self._random_generator = _RandomGenerator(seed)
566
567  def __call__(self, shape, dtype=dtypes.float32, **kwargs):
568    """Returns a tensor object initialized as specified by the initializer.
569
570    Args:
571      shape: Shape of the tensor.
572      dtype: Optional dtype of the tensor. Only floating point types are
573        supported.
574      **kwargs: Additional keyword arguments.
575
576    Raises:
577      ValueError: If the dtype is not floating point
578    """
579    self._validate_kwargs(kwargs)
580    dtype = _assert_float_dtype(dtype)
581    scale = self.scale
582    fan_in, fan_out = _compute_fans(shape)
583    if _PARTITION_SHAPE in kwargs:
584      shape = kwargs[_PARTITION_SHAPE]
585    if self.mode == "fan_in":
586      scale /= max(1., fan_in)
587    elif self.mode == "fan_out":
588      scale /= max(1., fan_out)
589    else:
590      scale /= max(1., (fan_in + fan_out) / 2.)
591    if self.distribution == "truncated_normal":
592      # constant from scipy.stats.truncnorm.std(a=-2, b=2, loc=0., scale=1.)
593      stddev = math.sqrt(scale) / .87962566103423978
594      return self._random_generator.truncated_normal(shape, 0.0, stddev, dtype)
595    elif self.distribution == "untruncated_normal":
596      stddev = math.sqrt(scale)
597      return self._random_generator.random_normal(shape, 0.0, stddev, dtype)
598    else:
599      limit = math.sqrt(3.0 * scale)
600      return self._random_generator.random_uniform(shape, -limit, limit, dtype)
601
602  def get_config(self):
603    return {
604        "scale": self.scale,
605        "mode": self.mode,
606        "distribution": self.distribution,
607        "seed": self.seed
608    }
609
610
611class Orthogonal(Initializer):
612  """Initializer that generates an orthogonal matrix.
613
614  Initializers allow you to pre-specify an initialization strategy, encoded in
615  the Initializer object, without knowing the shape and dtype of the variable
616  being initialized.
617
618  If the shape of the tensor to initialize is two-dimensional, it is initialized
619  with an orthogonal matrix obtained from the QR decomposition of a matrix of
620  random numbers drawn from a normal distribution.
621  If the matrix has fewer rows than columns then the output will have orthogonal
622  rows. Otherwise, the output will have orthogonal columns.
623
624  If the shape of the tensor to initialize is more than two-dimensional,
625  a matrix of shape `(shape[0] * ... * shape[n - 2], shape[n - 1])`
626  is initialized, where `n` is the length of the shape vector.
627  The matrix is subsequently reshaped to give a tensor of the desired shape.
628
629  Examples:
630
631  >>> def make_variables(k, initializer):
632  ...   return (tf.Variable(initializer(shape=[k, k], dtype=tf.float32)),
633  ...           tf.Variable(initializer(shape=[k, k, k], dtype=tf.float32)))
634  >>> v1, v2 = make_variables(3, tf.initializers.Orthogonal())
635  >>> v1
636  <tf.Variable ... shape=(3, 3) ...
637  >>> v2
638  <tf.Variable ... shape=(3, 3, 3) ...
639  >>> make_variables(4, tf.initializers.Orthogonal(gain=0.5))
640  (<tf.Variable ... shape=(4, 4) dtype=float32...
641   <tf.Variable ... shape=(4, 4, 4) dtype=float32...
642
643  Args:
644    gain: multiplicative factor to apply to the orthogonal matrix
645    seed: A Python integer. Used to create random seeds. See
646      `tf.random.set_seed` for behavior.
647
648  References:
649      [Saxe et al., 2014](https://openreview.net/forum?id=_wzZwKpTDF_9C)
650      ([pdf](https://arxiv.org/pdf/1312.6120.pdf))
651  """
652
653  def __init__(self, gain=1.0, seed=None):
654    self.gain = gain
655    self.seed = seed
656    self._random_generator = _RandomGenerator(seed)
657
658  def __call__(self, shape, dtype=dtypes.float32, **kwargs):
659    """Returns a tensor object initialized as specified by the initializer.
660
661    Args:
662      shape: Shape of the tensor.
663      dtype: Optional dtype of the tensor. Only floating point types are
664        supported.
665      **kwargs: Additional keyword arguments.
666
667    Raises:
668      ValueError: If the dtype is not floating point or the input shape is not
669       valid.
670    """
671    self._validate_kwargs(kwargs, support_partition=False)
672    dtype = _assert_float_dtype(dtype)
673    # Check the shape
674    if len(shape) < 2:
675      raise ValueError("The tensor to initialize must be "
676                       "at least two-dimensional")
677    # Flatten the input shape with the last dimension remaining
678    # its original shape so it works for conv2d
679    num_rows = 1
680    for dim in shape[:-1]:
681      num_rows *= dim
682    num_cols = shape[-1]
683    flat_shape = (max(num_cols, num_rows), min(num_cols, num_rows))
684
685    # Generate a random matrix
686    a = self._random_generator.random_normal(flat_shape, dtype=dtype)
687    # Compute the qr factorization
688    q, r = gen_linalg_ops.qr(a, full_matrices=False)
689    # Make Q uniform
690    d = array_ops.diag_part(r)
691    q *= math_ops.sign(d)
692    if num_rows < num_cols:
693      q = array_ops.matrix_transpose(q)
694    return self.gain * array_ops.reshape(q, shape)
695
696  def get_config(self):
697    return {"gain": self.gain, "seed": self.seed}
698
699
700class Identity(Initializer):
701  """Initializer that generates the identity matrix.
702
703  Initializers allow you to pre-specify an initialization strategy, encoded in
704  the Initializer object, without knowing the shape and dtype of the variable
705  being initialized.
706
707  Only usable for generating 2D matrices.
708
709  Examples:
710
711  >>> def make_variable(k, initializer):
712  ...   return tf.Variable(initializer(shape=[k, k], dtype=tf.float32))
713  >>> make_variable(2, tf.initializers.Identity())
714  <tf.Variable ... shape=(2, 2) dtype=float32, numpy=
715  array([[1., 0.],
716         [0., 1.]], dtype=float32)>
717  >>> make_variable(3, tf.initializers.Identity(gain=0.5))
718  <tf.Variable ... shape=(3, 3) dtype=float32, numpy=
719  array([[0.5, 0. , 0. ],
720         [0. , 0.5, 0. ],
721         [0. , 0. , 0.5]], dtype=float32)>
722
723  Args:
724    gain: Multiplicative factor to apply to the identity matrix.
725  """
726
727  def __init__(self, gain=1.0):
728    self.gain = gain
729
730  def __call__(self, shape, dtype=dtypes.float32, **kwargs):
731    """Returns a tensor object initialized as specified by the initializer.
732
733    Args:
734      shape: Shape of the tensor.
735      dtype: Optional dtype of the tensor. Only floating point types are
736       supported.
737      **kwargs: Additional keyword arguments.
738
739    Raises:
740      ValueError: If the dtype is not floating point
741      ValueError: If the requested shape does not have exactly two axes.
742    """
743    self._validate_kwargs(kwargs, support_partition=False)
744    dtype = _assert_float_dtype(dtype)
745    if len(shape) != 2:
746      raise ValueError(
747          "Identity matrix initializer can only be used for 2D matrices.")
748    initializer = linalg_ops_impl.eye(*shape, dtype=dtype)
749    return self.gain * initializer
750
751  def get_config(self):
752    return {"gain": self.gain}
753
754
755class GlorotUniform(VarianceScaling):
756  """The Glorot uniform initializer, also called Xavier uniform initializer.
757
758  Initializers allow you to pre-specify an initialization strategy, encoded in
759  the Initializer object, without knowing the shape and dtype of the variable
760  being initialized.
761
762  Draws samples from a uniform distribution within [-limit, limit] where `limit`
763  is `sqrt(6 / (fan_in + fan_out))` where `fan_in` is the number of input units
764  in the weight tensor and `fan_out` is the number of output units in the weight
765  tensor.
766
767  Examples:
768
769  >>> def make_variables(k, initializer):
770  ...   return (tf.Variable(initializer(shape=[k, k], dtype=tf.float32)),
771  ...           tf.Variable(initializer(shape=[k, k, k], dtype=tf.float32)))
772  >>> v1, v2 = make_variables(3, tf.initializers.GlorotUniform())
773  >>> v1
774  <tf.Variable ... shape=(3, 3) ...
775  >>> v2
776  <tf.Variable ... shape=(3, 3, 3) ...
777  >>> make_variables(4, tf.initializers.RandomNormal())
778  (<tf.Variable ... shape=(4, 4) dtype=float32...
779   <tf.Variable ... shape=(4, 4, 4) dtype=float32...
780
781  Args:
782    seed: A Python integer. Used to create random seeds. See
783      `tf.random.set_seed` for behavior.
784
785  References:
786      [Glorot et al., 2010](http://proceedings.mlr.press/v9/glorot10a.html)
787      ([pdf](http://jmlr.org/proceedings/papers/v9/glorot10a/glorot10a.pdf))
788  """
789
790  def __init__(self, seed=None):
791    super(GlorotUniform, self).__init__(
792        scale=1.0,
793        mode="fan_avg",
794        distribution="uniform",
795        seed=seed)
796
797  def get_config(self):
798    return {"seed": self.seed}
799
800
801class GlorotNormal(VarianceScaling):
802  """The Glorot normal initializer, also called Xavier normal initializer.
803
804  Initializers allow you to pre-specify an initialization strategy, encoded in
805  the Initializer object, without knowing the shape and dtype of the variable
806  being initialized.
807
808  Draws samples from a truncated normal distribution centered on 0 with `stddev
809  = sqrt(2 / (fan_in + fan_out))` where `fan_in` is the number of input units in
810  the weight tensor and `fan_out` is the number of output units in the weight
811  tensor.
812
813  Examples:
814
815  >>> def make_variables(k, initializer):
816  ...   return (tf.Variable(initializer(shape=[k, k], dtype=tf.float32)),
817  ...           tf.Variable(initializer(shape=[k, k, k], dtype=tf.float32)))
818  >>> v1, v2 = make_variables(3, tf.initializers.GlorotNormal())
819  >>> v1
820  <tf.Variable ... shape=(3, 3) ...
821  >>> v2
822  <tf.Variable ... shape=(3, 3, 3) ...
823  >>> make_variables(4, tf.initializers.RandomNormal())
824  (<tf.Variable ... shape=(4, 4) dtype=float32...
825   <tf.Variable ... shape=(4, 4, 4) dtype=float32...
826
827  Args:
828    seed: A Python integer. Used to create random seeds. See
829      `tf.random.set_seed` for behavior.
830
831  References:
832      [Glorot et al., 2010](http://proceedings.mlr.press/v9/glorot10a.html)
833      ([pdf](http://jmlr.org/proceedings/papers/v9/glorot10a/glorot10a.pdf))
834  """
835
836  def __init__(self, seed=None):
837    super(GlorotNormal, self).__init__(
838        scale=1.0,
839        mode="fan_avg",
840        distribution="truncated_normal",
841        seed=seed)
842
843  def get_config(self):
844    return {"seed": self.seed}
845
846
847# Aliases.
848
849# pylint: disable=invalid-name
850zeros_initializer = Zeros
851ones_initializer = Ones
852constant_initializer = Constant
853random_uniform_initializer = RandomUniform
854random_normal_initializer = RandomNormal
855truncated_normal_initializer = TruncatedNormal
856variance_scaling_initializer = VarianceScaling
857glorot_uniform_initializer = GlorotUniform
858glorot_normal_initializer = GlorotNormal
859orthogonal_initializer = Orthogonal
860identity_initializer = Identity
861# pylint: enable=invalid-name
862
863
864def lecun_normal(seed=None):
865  """LeCun normal initializer.
866
867  Initializers allow you to pre-specify an initialization strategy, encoded in
868  the Initializer object, without knowing the shape and dtype of the variable
869  being initialized.
870
871  Draws samples from a truncated normal distribution centered on 0 with `stddev
872  = sqrt(1 / fan_in)` where `fan_in` is the number of input units in the weight
873  tensor.
874
875  Examples:
876
877  >>> def make_variables(k, initializer):
878  ...   return (tf.Variable(initializer(shape=[k, k], dtype=tf.float32)),
879  ...           tf.Variable(initializer(shape=[k, k, k], dtype=tf.float32)))
880  >>> v1, v2 = make_variables(3, tf.initializers.lecun_normal())
881  >>> v1
882  <tf.Variable ... shape=(3, 3) ...
883  >>> v2
884  <tf.Variable ... shape=(3, 3, 3) ...
885  >>> make_variables(4, tf.initializers.RandomNormal())
886  (<tf.Variable ... shape=(4, 4) dtype=float32...
887   <tf.Variable ... shape=(4, 4, 4) dtype=float32...
888
889  Args:
890    seed: A Python integer. Used to seed the random generator.
891
892  Returns:
893    A callable Initializer with `shape` and `dtype` arguments which generates a
894    tensor.
895
896  References:
897      - Self-Normalizing Neural Networks,
898      [Klambauer et al., 2017]
899      (https://papers.nips.cc/paper/6698-self-normalizing-neural-networks)
900      ([pdf]
901      (https://papers.nips.cc/paper/6698-self-normalizing-neural-networks.pdf))
902      - Efficient Backprop,
903      [Lecun et al., 1998](http://yann.lecun.com/exdb/publis/pdf/lecun-98b.pdf)
904  """
905  return VarianceScaling(
906      scale=1., mode="fan_in", distribution="truncated_normal", seed=seed)
907
908
909def lecun_uniform(seed=None):
910  """LeCun uniform initializer.
911
912  Initializers allow you to pre-specify an initialization strategy, encoded in
913  the Initializer object, without knowing the shape and dtype of the variable
914  being initialized.
915
916  Draws samples from a uniform distribution within [-limit, limit] where `limit`
917  is `sqrt(3 / fan_in)` where `fan_in` is the number of input units in the
918  weight tensor.
919
920  Examples:
921
922  >>> def make_variables(k, initializer):
923  ...   return (tf.Variable(initializer(shape=[k, k], dtype=tf.float32)),
924  ...           tf.Variable(initializer(shape=[k, k, k], dtype=tf.float32)))
925  >>> v1, v2 = make_variables(3, tf.initializers.lecun_uniform())
926  >>> v1
927  <tf.Variable ... shape=(3, 3) ...
928  >>> v2
929  <tf.Variable ... shape=(3, 3, 3) ...
930  >>> make_variables(4, tf.initializers.RandomNormal())
931  (<tf.Variable ... shape=(4, 4) dtype=float32...
932   <tf.Variable ... shape=(4, 4, 4) dtype=float32...
933
934  Args:
935    seed: A Python integer. Used to seed the random generator.
936
937  Returns:
938    A callable Initializer with `shape` and `dtype` arguments which generates a
939    tensor.
940
941  References:
942      - Self-Normalizing Neural Networks,
943      [Klambauer et al., 2017](https://papers.nips.cc/paper/6698-self-normalizing-neural-networks) # pylint: disable=line-too-long
944      ([pdf](https://papers.nips.cc/paper/6698-self-normalizing-neural-networks.pdf))
945      - Efficient Backprop,
946      [Lecun et al., 1998](http://yann.lecun.com/exdb/publis/pdf/lecun-98b.pdf)
947  """
948  return VarianceScaling(
949      scale=1., mode="fan_in", distribution="uniform", seed=seed)
950
951
952def he_normal(seed=None):
953  """He normal initializer.
954
955  Initializers allow you to pre-specify an initialization strategy, encoded in
956  the Initializer object, without knowing the shape and dtype of the variable
957  being initialized.
958
959  It draws samples from a truncated normal distribution centered on 0 with
960  `stddev = sqrt(2 / fan_in)` where `fan_in` is the number of input units in the
961  weight tensor.
962
963  Examples:
964
965  >>> def make_variables(k, initializer):
966  ...   return (tf.Variable(initializer(shape=[k, k], dtype=tf.float32)),
967  ...           tf.Variable(initializer(shape=[k, k, k], dtype=tf.float32)))
968  >>> v1, v2 = make_variables(3, tf.initializers.he_normal())
969  >>> v1
970  <tf.Variable ... shape=(3, 3) ...
971  >>> v2
972  <tf.Variable ... shape=(3, 3, 3) ...
973  >>> make_variables(4, tf.initializers.RandomNormal())
974  (<tf.Variable ... shape=(4, 4) dtype=float32...
975   <tf.Variable ... shape=(4, 4, 4) dtype=float32...
976
977  Args:
978    seed: A Python integer. Used to seed the random generator.
979
980  Returns:
981    A callable Initializer with `shape` and `dtype` arguments which generates a
982    tensor.
983
984  References:
985      [He et al., 2015](https://www.cv-foundation.org/openaccess/content_iccv_2015/html/He_Delving_Deep_into_ICCV_2015_paper.html) # pylint: disable=line-too-long
986      ([pdf](https://www.cv-foundation.org/openaccess/content_iccv_2015/papers/He_Delving_Deep_into_ICCV_2015_paper.pdf))
987  """
988  return VarianceScaling(
989      scale=2., mode="fan_in", distribution="truncated_normal", seed=seed)
990
991
992def he_uniform(seed=None):
993  """He uniform variance scaling initializer.
994
995  Initializers allow you to pre-specify an initialization strategy, encoded in
996  the Initializer object, without knowing the shape and dtype of the variable
997  being initialized.
998
999  Draws samples from a uniform distribution within [-limit, limit] where `limit`
1000  is `sqrt(6 / fan_in)` where `fan_in` is the number of input units in the
1001  weight tensor.
1002
1003  Examples:
1004
1005  >>> def make_variables(k, initializer):
1006  ...   return (tf.Variable(initializer(shape=[k, k], dtype=tf.float32)),
1007  ...           tf.Variable(initializer(shape=[k, k, k], dtype=tf.float32)))
1008  >>> v1, v2 = make_variables(3, tf.initializers.he_uniform())
1009  >>> v1
1010  <tf.Variable ... shape=(3, 3) ...
1011  >>> v2
1012  <tf.Variable ... shape=(3, 3, 3) ...
1013  >>> make_variables(4, tf.initializers.RandomNormal())
1014  (<tf.Variable ... shape=(4, 4) dtype=float32...
1015   <tf.Variable ... shape=(4, 4, 4) dtype=float32...
1016
1017  Args:
1018    seed: A Python integer. Used to seed the random generator.
1019
1020  Returns:
1021    A callable Initializer with `shape` and `dtype` arguments which generates a
1022    tensor.
1023
1024  References:
1025      [He et al., 2015](https://www.cv-foundation.org/openaccess/content_iccv_2015/html/He_Delving_Deep_into_ICCV_2015_paper.html) # pylint: disable=line-too-long
1026      ([pdf](https://www.cv-foundation.org/openaccess/content_iccv_2015/papers/He_Delving_Deep_into_ICCV_2015_paper.pdf))
1027  """
1028  return VarianceScaling(
1029      scale=2., mode="fan_in", distribution="uniform", seed=seed)
1030
1031
1032# Utility functions.
1033
1034
1035def _assert_float_dtype(dtype):
1036  """Validate and return floating point type based on `dtype`.
1037
1038  `dtype` must be a floating point type.
1039
1040  Args:
1041    dtype: The data type to validate.
1042
1043  Returns:
1044    Validated type.
1045
1046  Raises:
1047    ValueError: if `dtype` is not a floating point type.
1048  """
1049  dtype = dtypes.as_dtype(dtype)
1050  if not dtype.is_floating:
1051    raise ValueError("Expected floating point type, got %s." % dtype)
1052  return dtype
1053
1054
1055class _RandomGenerator(object):
1056  """Random generator that selects appropriate random ops."""
1057
1058  def __init__(self, seed=None):
1059    super(_RandomGenerator, self).__init__()
1060    if seed is not None:
1061      # Stateless random ops requires 2-int seed.
1062      self.seed = [seed, 0]
1063    else:
1064      self.seed = None
1065
1066  def random_normal(self, shape, mean=0.0, stddev=1, dtype=dtypes.float32):
1067    """A deterministic random normal if seed is passed."""
1068    if self.seed:
1069      op = stateless_random_ops.stateless_random_normal
1070    else:
1071      op = random_ops.random_normal
1072    return op(
1073        shape=shape, mean=mean, stddev=stddev, dtype=dtype, seed=self.seed)
1074
1075  def random_uniform(self, shape, minval, maxval, dtype):
1076    """A deterministic random uniform if seed is passed."""
1077    if self.seed:
1078      op = stateless_random_ops.stateless_random_uniform
1079    else:
1080      op = random_ops.random_uniform
1081    return op(
1082        shape=shape, minval=minval, maxval=maxval, dtype=dtype, seed=self.seed)
1083
1084  def truncated_normal(self, shape, mean, stddev, dtype):
1085    """A deterministic truncated normal if seed is passed."""
1086    if self.seed:
1087      op = stateless_random_ops.stateless_truncated_normal
1088    else:
1089      op = random_ops.truncated_normal
1090    return op(
1091        shape=shape, mean=mean, stddev=stddev, dtype=dtype, seed=self.seed)
1092
1093# Compatibility aliases
1094
1095# pylint: disable=invalid-name
1096zero = zeros = Zeros
1097one = ones = Ones
1098constant = Constant
1099uniform = random_uniform = RandomUniform
1100normal = random_normal = RandomNormal
1101truncated_normal = TruncatedNormal
1102identity = Identity
1103orthogonal = Orthogonal
1104glorot_normal = GlorotNormal
1105glorot_uniform = GlorotUniform
1106