1# Copyright 2015 The TensorFlow Authors. All Rights Reserved.
2#
3# Licensed under the Apache License, Version 2.0 (the "License");
4# you may not use this file except in compliance with the License.
5# You may obtain a copy of the License at
6#
7#     http://www.apache.org/licenses/LICENSE-2.0
8#
9# Unless required by applicable law or agreed to in writing, software
10# distributed under the License is distributed on an "AS IS" BASIS,
11# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
12# See the License for the specific language governing permissions and
13# limitations under the License.
14# ==============================================================================
15"""Weight initializers for use with layers."""
16
17from __future__ import absolute_import
18from __future__ import division
19from __future__ import print_function
20
21import math
22
23from tensorflow.python.framework import dtypes
24from tensorflow.python.ops import random_ops
25
26
# Public API of this module: the names exported by a wildcard import.
__all__ = ['xavier_initializer', 'xavier_initializer_conv2d',
           'variance_scaling_initializer']
29
30
def xavier_initializer(uniform=True, seed=None, dtype=dtypes.float32):
  """Builds a "Xavier" (Glorot) weight initializer.

  The scheme follows:

  Xavier Glorot and Yoshua Bengio (2010):
           [Understanding the difficulty of training deep feedforward neural
           networks. International conference on artificial intelligence and
           statistics.](
           http://www.jmlr.org/proceedings/papers/v9/glorot10a/glorot10a.pdf)

  The goal is to keep the scale of the gradients roughly equal across layers.
  With a uniform distribution the weights are drawn from `[-x, x]` where
  `x = sqrt(6. / (in + out))`; with a normal distribution a standard deviation
  of `sqrt(2. / (in + out))` is targeted.

  Args:
    uniform: Whether to use uniform or normal distributed random initialization.
    seed: A Python integer. Used to create random seeds. See
          `tf.set_random_seed` for behavior.
    dtype: The data type. Only floating point types are supported.

  Returns:
    An initializer for a weight matrix.
  """
  # Xavier initialization is variance scaling with a unit factor, normalized by
  # the average of the input and output fan.
  xavier_settings = {
      'factor': 1.0,
      'mode': 'FAN_AVG',
      'uniform': uniform,
      'seed': seed,
      'dtype': dtype,
  }
  return variance_scaling_initializer(**xavier_settings)
58
# Alias: `xavier_initializer_conv2d` is bound to the very same function object
# as `xavier_initializer`, so both names behave identically.
xavier_initializer_conv2d = xavier_initializer
60
61
def variance_scaling_initializer(factor=2.0, mode='FAN_IN', uniform=False,
                                 seed=None, dtype=dtypes.float32):
  """Returns an initializer that generates tensors without scaling variance.

  When initializing a deep network, it is in principle advantageous to keep
  the scale of the input variance constant, so it does not explode or diminish
  by reaching the final layer. This initializer uses the following formula:

  ```python
    if mode='FAN_IN': # Count only number of input connections.
      n = fan_in
    elif mode='FAN_OUT': # Count only number of output connections.
      n = fan_out
    elif mode='FAN_AVG': # Average number of inputs and output connections.
      n = (fan_in + fan_out)/2.0

    # Both branches target a standard deviation of sqrt(factor / n).
    if uniform:
      limit = sqrt(3.0 * factor / n)
      random_uniform(shape, -limit, limit)
    else:
      truncated_normal(shape, 0.0, stddev=sqrt(1.3 * factor / n))
  ```

  `fan_in` and `fan_out` are estimated from the requested shape:

  * rank >= 2: `fan_in = shape[-2] * prod(shape[:-2])` and
    `fan_out = shape[-1] * prod(shape[:-2])`.
  * rank 1: `fan_in = fan_out = shape[-1]`.
  * empty shape: `fan_in = fan_out = 1.0`.

  * To get [Delving Deep into Rectifiers](
     http://arxiv.org/pdf/1502.01852v1.pdf) (also known as the "MSRA
     initialization"), use (Default):<br/>
    `factor=2.0 mode='FAN_IN' uniform=False`
  * To get [Convolutional Architecture for Fast Feature Embedding](
     http://arxiv.org/abs/1408.5093), use:<br/>
    `factor=1.0 mode='FAN_IN' uniform=True`
  * To get [Understanding the difficulty of training deep feedforward neural
    networks](http://jmlr.org/proceedings/papers/v9/glorot10a/glorot10a.pdf),
    use:<br/>
    `factor=1.0 mode='FAN_AVG' uniform=True.`
  * To get `xavier_initializer` use either:<br/>
    `factor=1.0 mode='FAN_AVG' uniform=True`, or<br/>
    `factor=1.0 mode='FAN_AVG' uniform=False`.

  Args:
    factor: Float.  A multiplicative factor.
    mode: String.  'FAN_IN', 'FAN_OUT', 'FAN_AVG'.
    uniform: Whether to use uniform or normal distributed random initialization.
    seed: A Python integer. Used to create random seeds. See
          `tf.set_random_seed` for behavior.
    dtype: The data type. Only floating point types are supported.

  Returns:
    An initializer that generates tensors with unit variance.

  Raises:
    TypeError: if `dtype` is not a floating point type.
    TypeError: if `mode` is not in ['FAN_IN', 'FAN_OUT', 'FAN_AVG'].
  """
  if not dtype.is_floating:
    raise TypeError('Cannot create initializer for non-floating point type.')
  if mode not in ['FAN_IN', 'FAN_OUT', 'FAN_AVG']:
    # Interpolate `mode` into the message; passing it as a second exception
    # argument would leave the '%s' placeholder unformatted. The 1-tuple keeps
    # the formatting safe even if `mode` is itself a tuple.
    raise TypeError('Unknown mode %s [FAN_IN, FAN_OUT, FAN_AVG]' % (mode,))

  # pylint: disable=unused-argument
  def _initializer(shape, dtype=dtype, partition_info=None):
    """Creates a random tensor of `shape` scaled according to its fan.

    Args:
      shape: Sequence of dimension sizes of the tensor to create.
      dtype: The data type. Defaults to the `dtype` given to the enclosing
        factory; must be a floating point type.
      partition_info: Unused; accepted so the function matches the initializer
        calling convention.

    Returns:
      The result of `random_ops.random_uniform` (when `uniform` is true) or
      `random_ops.truncated_normal` (otherwise) for the requested shape.

    Raises:
      TypeError: if `dtype` is not a floating point type.
    """
    # The per-call dtype can differ from the factory's, so validate it again.
    if not dtype.is_floating:
      raise TypeError('Cannot create initializer for non-floating point type.')
    # Estimating fan_in and fan_out is not possible to do perfectly, but we try.
    # This is the right thing for matrix multiply and convolutions.
    if shape:
      # For rank-1 shapes the single dimension serves as both fan_in and
      # fan_out.
      fan_in = float(shape[-2]) if len(shape) > 1 else float(shape[-1])
      fan_out = float(shape[-1])
    else:
      # Empty (scalar) shape.
      fan_in = 1.0
      fan_out = 1.0
    # Every leading dimension (everything before the last two) multiplies both
    # fans; this slice is empty for shapes of rank <= 2.
    for dim in shape[:-2]:
      fan_in *= float(dim)
      fan_out *= float(dim)
    if mode == 'FAN_IN':
      # Count only number of input connections.
      n = fan_in
    elif mode == 'FAN_OUT':
      # Count only number of output connections.
      n = fan_out
    elif mode == 'FAN_AVG':
      # Average number of inputs and output connections.
      n = (fan_in + fan_out) / 2.0
    if uniform:
      # To get stddev = math.sqrt(factor / n) need to adjust for uniform:
      # U[-limit, limit] has variance limit**2 / 3.
      limit = math.sqrt(3.0 * factor / n)
      return random_ops.random_uniform(shape, -limit, limit,
                                       dtype, seed=seed)
    else:
      # To get stddev = math.sqrt(factor / n) need to adjust for truncated:
      # the variance is scaled up by 1.3 as an approximate correction for the
      # truncation.
      trunc_stddev = math.sqrt(1.3 * factor / n)
      return random_ops.truncated_normal(shape, 0.0, trunc_stddev, dtype,
                                         seed=seed)
  # pylint: enable=unused-argument

  return _initializer
154