1# Copyright 2015 The TensorFlow Authors. All Rights Reserved.
2#
3# Licensed under the Apache License, Version 2.0 (the "License");
4# you may not use this file except in compliance with the License.
5# You may obtain a copy of the License at
6#
7#     http://www.apache.org/licenses/LICENSE-2.0
8#
9# Unless required by applicable law or agreed to in writing, software
10# distributed under the License is distributed on an "AS IS" BASIS,
11# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
12# See the License for the specific language governing permissions and
13# limitations under the License.
14# ==============================================================================
15
16"""Operations for clipping (gradient, weight) tensors to min/max values."""
17from __future__ import absolute_import
18from __future__ import division
19from __future__ import print_function
20
21import collections
22
23import six
24
25from tensorflow.python.framework import constant_op
26from tensorflow.python.framework import dtypes
27from tensorflow.python.framework import ops
28from tensorflow.python.ops import array_ops
29from tensorflow.python.ops import gen_nn_ops
30from tensorflow.python.ops import math_ops
31from tensorflow.python.util.tf_export import tf_export
32
33
@tf_export("clip_by_value")
def clip_by_value(t, clip_value_min, clip_value_max,
                  name=None):
  """Clamps the elements of a tensor into `[clip_value_min, clip_value_max]`.

  The result has the same type and shape as `t`. Elements above
  `clip_value_max` are replaced by `clip_value_max`, and elements below
  `clip_value_min` are replaced by `clip_value_min`. The upper bound is applied
  first and the lower bound second, i.e. the result is
  `maximum(minimum(t, clip_value_max), clip_value_min)`.

  Args:
    t: A `Tensor`.
    clip_value_min: The lower bound. Either a 0-D (scalar) `Tensor` or a
      `Tensor` with the same shape as `t`.
    clip_value_max: The upper bound. Either a 0-D (scalar) `Tensor` or a
      `Tensor` with the same shape as `t`.
    name: A name for the operation (optional).

  Returns:
    A clipped `Tensor`.

  Raises:
    ValueError: if broadcasting against either bound would produce a result
      whose shape is incompatible with the shape of `t`.
  """
  with ops.name_scope(name, "clip_by_value",
                      [t, clip_value_min, clip_value_max]) as scope:
    tensor = ops.convert_to_tensor(t, name="t")
    input_shape = tensor.shape

    # Apply the upper bound, then check that broadcasting against
    # `clip_value_max` did not change the shape of the input.
    capped = math_ops.minimum(tensor, clip_value_max)
    input_shape.merge_with(capped.shape)

    # Apply the lower bound and repeat the shape check for `clip_value_min`.
    clipped = math_ops.maximum(capped, clip_value_min, name=scope)
    input_shape.merge_with(clipped.shape)

    return clipped
73
74
@tf_export("clip_by_norm")
def clip_by_norm(t, clip_norm, axes=None, name=None):
  """Clips tensor values to a maximum L2-norm.

  Given a tensor `t`, and a maximum clip value `clip_norm`, this operation
  normalizes `t` so that its L2-norm is less than or equal to `clip_norm`,
  along the dimensions given in `axes`. Specifically, in the default case
  where all dimensions are used for calculation, if the L2-norm of `t` is
  already less than or equal to `clip_norm`, then `t` is not modified. If
  the L2-norm is greater than `clip_norm`, then this operation returns a
  tensor of the same type and shape as `t` with its values set to:

  `t * clip_norm / l2norm(t)`

  In this case, the L2-norm of the output tensor is `clip_norm`.

  As another example, if `t` is a matrix and `axes == [1]`, then each row
  of the output will have L2-norm at most `clip_norm`. If `axes == [0]`
  instead, each column of the output will be clipped.

  This operation is typically used to clip gradients before applying them with
  an optimizer.

  Args:
    t: A `Tensor`.
    clip_norm: A 0-D (scalar) `Tensor` > 0. A maximum clipping value.
    axes: A 1-D (vector) `Tensor` of type int32 containing the dimensions
      to use for computing the L2-norm. If `None` (the default), uses all
      dimensions.
    name: A name for the operation (optional).

  Returns:
    A clipped `Tensor`.

  Raises:
    ValueError: if multiplying `t` by `clip_norm` would broadcast to a shape
      that is incompatible with the shape of `t`.
  """
  with ops.name_scope(name, "clip_by_norm", [t, clip_norm]) as name:
    t = ops.convert_to_tensor(t, name="t")

    # Calculate L2-norm, clip elements by ratio of clip_norm to L2-norm.
    l2sum = math_ops.reduce_sum(t * t, axes, keepdims=True)
    pred = l2sum > 0
    # The derivative of sqrt is infinite at 0, which turns the gradient of an
    # all-zero slice into NaN. Feed sqrt a harmless 1 wherever the sum is 0,
    # then select the original 0 for those positions, so the forward value is
    # unchanged while the NaN never enters the backward pass.
    l2sum_safe = array_ops.where(pred, l2sum, array_ops.ones_like(l2sum))
    l2norm = array_ops.where(pred, math_ops.sqrt(l2sum_safe), l2sum)
    intermediate = t * clip_norm
    # Assert that the shape is compatible with the initial shape,
    # to prevent unintentional broadcasting.
    _ = t.shape.merge_with(intermediate.shape)
    # Dividing by max(l2norm, clip_norm) leaves `t` untouched when its norm is
    # already within bounds and rescales it to `clip_norm` otherwise.
    tclip = array_ops.identity(
        intermediate / math_ops.maximum(l2norm, clip_norm), name=name)

  return tclip
122
123
@tf_export("global_norm")
def global_norm(t_list, name=None):
  """Computes the global norm of multiple tensors.

  Given a tuple or list of tensors `t_list`, this operation returns the
  global norm of the elements in all tensors in `t_list`. The global norm is
  computed as:

  `global_norm = sqrt(sum([l2norm(t)**2 for t in t_list]))`

  Any entries in `t_list` that are of type None are ignored.

  Args:
    t_list: A tuple or list of mixed `Tensors`, `IndexedSlices`, or None.
    name: A name for the operation (optional).

  Returns:
    A 0-D (scalar) `Tensor` of type `float`.

  Raises:
    TypeError: If `t_list` is not a sequence.
  """
  # The `collections.Sequence` alias was removed in Python 3.10; the ABC lives
  # in `collections.abc`. Fall back to `collections` itself on Python 2, where
  # the `abc` submodule does not exist.
  try:
    from collections import abc as collections_abc  # pylint: disable=g-import-not-at-top
  except ImportError:
    collections_abc = collections

  # Strings are sequences too, but never a valid list of tensors.
  if (not isinstance(t_list, collections_abc.Sequence)
      or isinstance(t_list, six.string_types)):
    raise TypeError("t_list should be a sequence")
  t_list = list(t_list)
  with ops.name_scope(name, "global_norm", t_list) as name:
    # For `IndexedSlices` only the stored `values` contribute to the norm.
    # `None` entries are kept as placeholders and skipped below.
    values = [
        ops.convert_to_tensor(
            t.values if isinstance(t, ops.IndexedSlices) else t,
            name="t_%d" % i)
        if t is not None else t
        for i, t in enumerate(t_list)]
    half_squared_norms = []
    for v in values:
      if v is not None:
        # Compute each partial norm on the device that holds the tensor.
        with ops.colocate_with(v):
          half_squared_norms.append(gen_nn_ops.l2_loss(v))

    half_squared_norm = math_ops.reduce_sum(array_ops.stack(half_squared_norms))

    # `l2_loss` yields sum(v ** 2) / 2, so multiply by 2 before the sqrt.
    norm = math_ops.sqrt(
        half_squared_norm *
        constant_op.constant(2.0, dtype=half_squared_norm.dtype),
        name="global_norm")

  return norm
171
172
@tf_export("clip_by_global_norm")
def clip_by_global_norm(t_list, clip_norm, use_norm=None, name=None):
  """Clips values of multiple tensors by the ratio of the sum of their norms.

  Given a tuple or list of tensors `t_list`, and a clipping ratio `clip_norm`,
  this operation returns a list of clipped tensors `list_clipped`
  and the global norm (`global_norm`) of all tensors in `t_list`. Optionally,
  if you've already computed the global norm for `t_list`, you can specify
  the global norm with `use_norm`.

  To perform the clipping, the values `t_list[i]` are set to:

      t_list[i] * clip_norm / max(global_norm, clip_norm)

  where:

      global_norm = sqrt(sum([l2norm(t)**2 for t in t_list]))

  If `clip_norm > global_norm` then the entries in `t_list` remain as they are,
  otherwise they're all shrunk by the global ratio.

  Any of the entries of `t_list` that are of type `None` are ignored.

  This is the correct way to perform gradient clipping (for example, see
  [Pascanu et al., 2012](http://arxiv.org/abs/1211.5063)
  ([pdf](http://arxiv.org/pdf/1211.5063.pdf))).

  However, it is slower than `clip_by_norm()` because all the parameters must be
  ready before the clipping operation can be performed.

  Args:
    t_list: A tuple or list of mixed `Tensors`, `IndexedSlices`, or None.
    clip_norm: A 0-D (scalar) `Tensor` > 0. The clipping ratio.
    use_norm: A 0-D (scalar) `Tensor` of type `float` (optional). The global
      norm to use. If not provided, `global_norm()` is used to compute the norm.
    name: A name for the operation (optional).

  Returns:
    list_clipped: A list of `Tensors` of the same type as `t_list`.
    global_norm: A 0-D (scalar) `Tensor` representing the global norm.

  Raises:
    TypeError: If `t_list` is not a sequence.
  """
  # The `collections.Sequence` alias was removed in Python 3.10; the ABC lives
  # in `collections.abc`. Fall back to `collections` itself on Python 2, where
  # the `abc` submodule does not exist.
  try:
    from collections import abc as collections_abc  # pylint: disable=g-import-not-at-top
  except ImportError:
    collections_abc = collections

  # Strings are sequences too, but never a valid list of tensors.
  if (not isinstance(t_list, collections_abc.Sequence)
      or isinstance(t_list, six.string_types)):
    raise TypeError("t_list should be a sequence")
  t_list = list(t_list)
  if use_norm is None:
    use_norm = global_norm(t_list, name)

  with ops.name_scope(name, "clip_by_global_norm",
                      t_list + [clip_norm]) as name:
    # Calculate L2-norm, clip elements by ratio of clip_norm to L2-norm.
    # clip_norm * min(1 / use_norm, 1 / clip_norm) equals
    # clip_norm / max(use_norm, clip_norm).
    scale = clip_norm * math_ops.minimum(
        1.0 / use_norm,
        constant_op.constant(1.0, dtype=use_norm.dtype) / clip_norm)

    # For `IndexedSlices` only the stored `values` are scaled. `None` entries
    # are kept so that the output lines up index-for-index with `t_list`.
    values = [
        ops.convert_to_tensor(
            t.values if isinstance(t, ops.IndexedSlices) else t,
            name="t_%d" % i)
        if t is not None else t
        for i, t in enumerate(t_list)]

    values_clipped = []
    for i, v in enumerate(values):
      if v is None:
        values_clipped.append(None)
      else:
        # Scale each tensor on the device that holds it.
        with ops.colocate_with(v):
          values_clipped.append(
              array_ops.identity(v * scale, name="%s_%d" % (name, i)))

    # Re-wrap scaled values whose input was an `IndexedSlices`, reusing the
    # original indices and dense shape.
    list_clipped = [
        ops.IndexedSlices(c_v, t.indices, t.dense_shape)
        if isinstance(t, ops.IndexedSlices)
        else c_v
        for (c_v, t) in zip(values_clipped, t_list)]

  return list_clipped, use_norm
254
255
@tf_export("clip_by_average_norm")
def clip_by_average_norm(t, clip_norm, name=None):
  """Clips tensor values to a maximum average L2-norm.

  Given a tensor `t`, and a maximum clip value `clip_norm`, this operation
  normalizes `t` so that its average L2-norm is less than or equal to
  `clip_norm`. Specifically, if the average L2-norm is already less than or
  equal to `clip_norm`, then `t` is not modified. If the average L2-norm is
  greater than `clip_norm`, then this operation returns a tensor of the same
  type and shape as `t` with its values set to:

  `t * clip_norm / l2norm_avg(t)`

  where `l2norm_avg(t)` is the L2-norm of `t` divided by its number of
  elements. In this case, the average L2-norm of the output tensor is
  `clip_norm`.

  This operation is typically used to clip gradients before applying them with
  an optimizer.

  Args:
    t: A floating point `Tensor`.
    clip_norm: A 0-D (scalar) `Tensor` > 0. A maximum clipping value.
    name: A name for the operation (optional).

  Returns:
    A clipped `Tensor`.
  """
  with ops.name_scope(name, "clip_by_average_norm", [t, clip_norm]) as name:
    t = ops.convert_to_tensor(t, name="t")
    # Build the helper scalars in the dtype of `t` rather than a hard-coded
    # float32, so that non-float32 inputs do not hit a dtype mismatch in the
    # arithmetic below.
    dtype = t.dtype.base_dtype

    # Calculate L2-norm per element, clip elements by ratio of clip_norm to
    # L2-norm per element
    n_element = math_ops.cast(array_ops.size(t), dtype)
    l2norm_inv = math_ops.rsqrt(
        math_ops.reduce_sum(t * t, math_ops.range(array_ops.rank(t))))
    # clip_norm * min(n / l2norm, 1 / clip_norm) equals
    # clip_norm / max(l2norm / n, clip_norm).
    tclip = array_ops.identity(
        t * clip_norm * math_ops.minimum(
            l2norm_inv * n_element,
            constant_op.constant(1.0, dtype=dtype) / clip_norm),
        name=name)

  return tclip
296