1# Copyright 2020 The TensorFlow Authors. All Rights Reserved.
2#
3# Licensed under the Apache License, Version 2.0 (the "License");
4# you may not use this file except in compliance with the License.
5# You may obtain a copy of the License at
6#
7#     http://www.apache.org/licenses/LICENSE-2.0
8#
9# Unless required by applicable law or agreed to in writing, software
10# distributed under the License is distributed on an "AS IS" BASIS,
11# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
12# See the License for the specific language governing permissions and
13# limitations under the License.
14# ==============================================================================
15"""Optimizer utilities."""
16
17from __future__ import absolute_import
18from __future__ import division
19from __future__ import print_function
20
21from tensorflow.python.distribute import central_storage_strategy
22from tensorflow.python.distribute import distribution_strategy_context as distribute_ctx
23from tensorflow.python.distribute import reduce_util as ds_reduce_util
24from tensorflow.python.ops import clip_ops
25from tensorflow.python.platform import tf_logging as logging
26
27
def all_reduce_sum_gradients(grads_and_vars):
  """Sums gradients across replicas while keeping `None` gradients in place.

  Pairs whose gradient is `None` are excluded from the reduction and are
  re-inserted, still as `None`, at their original positions in the result.

  Uses `distribute_ctx.get_replica_context().merge_call`, so it expects to be
  called where a replica context is available.

  Args:
    grads_and_vars: Iterable of (gradient, variable) pairs.

  Returns:
    List of (gradient, variable) pairs in input order, where every non-`None`
    gradient has been replaced by its all-reduced (summed) value.

  Raises:
    ValueError: If the input is non-empty but every gradient is `None`
      (raised by `filter_empty_gradients`).
  """
  pairs = list(grads_and_vars)
  non_empty_pairs = filter_empty_gradients(pairs)
  # The reduction runs inside `merge_call` (i.e. in a cross-replica context)
  # because a bug causes IndexedSlices to be converted to dense tensors when
  # all-reduced in a replica context.
  # TODO(b/150507409): Do not switch to a cross-replica context once the bug
  # is fixed.
  if not non_empty_pairs:
    summed = []
  else:
    replica_ctx = distribute_ctx.get_replica_context()
    summed = replica_ctx.merge_call(
        _all_reduce_sum_fn, args=(non_empty_pairs,))

  # Walk the original pairs and consume the reduced gradients in order, one
  # per non-`None` gradient, so the output lines up with the input.
  result = []
  consumed = 0
  for grad, var in pairs:
    if grad is not None:
      result.append((summed[consumed], var))
      consumed += 1
    else:
      result.append((None, var))
  assert consumed == len(summed), "Failed to add all gradients"
  return result
60
61
def filter_empty_gradients(grads_and_vars):
  """Drops `(grad, var)` pairs whose gradient is `None`.

  Args:
    grads_and_vars: Iterable of `(gradient, variable)` pairs.

  Returns:
    Tuple of the pairs whose gradient is not `None`, in input order. An empty
    input yields an empty tuple.

  Raises:
    ValueError: If the input is non-empty and every gradient is `None`.
  """
  pairs = tuple(grads_and_vars)
  if not pairs:
    return pairs

  kept = tuple((grad, var) for grad, var in pairs if grad is not None)
  missing_vars = [var for grad, var in pairs if grad is None]

  if not kept:
    all_names = [var.name for _, var in pairs]
    raise ValueError("No gradients provided for any variable: %s." %
                     (all_names,))
  if missing_vars:
    # Only some variables lack a gradient: warn, then continue with the rest.
    logging.warning(
        "Gradients do not exist for variables %s when minimizing the loss.",
        [var.name for var in missing_vars])
  return kept
85
86
def make_gradient_clipnorm_fn(clipnorm):
  """Creates a gradient transformation function for clipping by norm.

  Each gradient is clipped independently with `clip_ops.clip_by_norm`, using
  `clipnorm` as the maximum norm.

  Args:
    clipnorm: Maximum norm for each individual gradient, or `None` to disable
      clipping.

  Returns:
    A callable taking an iterable of `(gradient, variable)` pairs. If
    `clipnorm` is `None`, the callable returns its input unchanged. Otherwise
    it returns a list of `(clipped_gradient, variable)` pairs, where `None`
    gradients are passed through as `None`.

  The returned callable raises `ValueError` when the current distribution
  strategy is a `CentralStorageStrategy` (V1 or V2).
  """
  if clipnorm is None:
    return lambda grads_and_vars: grads_and_vars

  def gradient_clipnorm_fn(grads_and_vars):
    # This transformation is not supported under CentralStorageStrategy;
    # reject it up front.
    if isinstance(distribute_ctx.get_strategy(),
                  (central_storage_strategy.CentralStorageStrategy,
                   central_storage_strategy.CentralStorageStrategyV1)):
      raise ValueError(
          "`clipnorm` is not supported with `CentralStorageStrategy`")

    # A `None` gradient (the variable did not receive a gradient) cannot be
    # clipped, so it is passed through untouched. This mirrors
    # `clip_by_global_norm`, which ignores `None` entries.
    clipped_grads_and_vars = [
        (None if g is None else clip_ops.clip_by_norm(g, clipnorm), v)
        for g, v in grads_and_vars
    ]
    return clipped_grads_and_vars

  return gradient_clipnorm_fn
106
107
def make_global_gradient_clipnorm_fn(clipnorm):
  """Creates a gradient transformation function for clipping by global norm.

  All gradients are rescaled together with `clip_ops.clip_by_global_norm`, so
  their combined (global) norm is bounded by `clipnorm`.

  Args:
    clipnorm: Maximum global norm across all gradients, or `None` to disable
      clipping.

  Returns:
    A callable taking an iterable of `(gradient, variable)` pairs. If
    `clipnorm` is `None`, the callable returns its input unchanged. Otherwise
    it returns a list of `(clipped_gradient, variable)` pairs in input order.
    An empty input yields an empty list.

  The returned callable raises `ValueError` when the current distribution
  strategy is a `CentralStorageStrategy` (V1 or V2).
  """
  if clipnorm is None:
    return lambda grads_and_vars: grads_and_vars

  def gradient_clipnorm_fn(grads_and_vars):
    # This transformation is not supported under CentralStorageStrategy;
    # reject it up front.
    if isinstance(distribute_ctx.get_strategy(),
                  (central_storage_strategy.CentralStorageStrategy,
                   central_storage_strategy.CentralStorageStrategyV1)):
      raise ValueError(
          "`global_clipnorm` is not supported with `CentralStorageStrategy`")

    grads_and_vars = list(grads_and_vars)
    # `zip(*[])` produces nothing to unpack into two sequences and would raise,
    # so an empty input is returned as-is.
    if not grads_and_vars:
      return grads_and_vars
    grads, variables = zip(*grads_and_vars)
    clipped_grads, _ = clip_ops.clip_by_global_norm(grads, clipnorm)
    return list(zip(clipped_grads, variables))

  return gradient_clipnorm_fn
127
128
def make_gradient_clipvalue_fn(clipvalue):
  """Creates a gradient transformation function for clipping by value.

  Each gradient is element-wise clipped with `clip_ops.clip_by_value` to the
  range `[-clipvalue, clipvalue]`.

  Args:
    clipvalue: Bound for the symmetric clipping range, or `None` to disable
      clipping.

  Returns:
    A callable taking an iterable of `(gradient, variable)` pairs. If
    `clipvalue` is `None`, the callable returns its input unchanged. Otherwise
    it returns a list of `(clipped_gradient, variable)` pairs, where `None`
    gradients are passed through as `None`.

  The returned callable raises `ValueError` when the current distribution
  strategy is a `CentralStorageStrategy` (V1 or V2).
  """
  if clipvalue is None:
    return lambda grads_and_vars: grads_and_vars

  def gradient_clipvalue_fn(grads_and_vars):
    # This transformation is not supported under CentralStorageStrategy;
    # reject it up front.
    if isinstance(distribute_ctx.get_strategy(),
                  (central_storage_strategy.CentralStorageStrategy,
                   central_storage_strategy.CentralStorageStrategyV1)):
      raise ValueError(
          "`clipvalue` is not supported with `CentralStorageStrategy`")

    # A `None` gradient (the variable did not receive a gradient) cannot be
    # clipped, so it is passed through untouched.
    clipped_grads_and_vars = [
        (None if g is None else clip_ops.clip_by_value(g, -clipvalue,
                                                       clipvalue), v)
        for g, v in grads_and_vars
    ]
    return clipped_grads_and_vars

  return gradient_clipvalue_fn
148
149
def _all_reduce_sum_fn(distribution, grads_and_vars):
  """Sum-reduces a batch of gradients across replicas.

  Intended as the `merge_call` target used by `all_reduce_sum_gradients`.

  Args:
    distribution: The distribution strategy passed in by `merge_call`.
    grads_and_vars: Sequence of (gradient, variable) pairs to reduce.

  Returns:
    The result of `distribution.extended.batch_reduce_to` with a SUM reduction.
  """
  extended = distribution.extended
  return extended.batch_reduce_to(ds_reduce_util.ReduceOp.SUM, grads_and_vars)
153