1# Copyright 2017 The TensorFlow Authors. All Rights Reserved.
2#
3# Licensed under the Apache License, Version 2.0 (the "License");
4# you may not use this file except in compliance with the License.
5# You may obtain a copy of the License at
6#
7#     http://www.apache.org/licenses/LICENSE-2.0
8#
9# Unless required by applicable law or agreed to in writing, software
10# distributed under the License is distributed on an "AS IS" BASIS,
11# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
12# See the License for the specific language governing permissions and
13# limitations under the License.
14# =============================================================================
15
16"""Optimizer that implements cross-shard gradient reduction for TPU."""
17
18from __future__ import absolute_import
19from __future__ import division
20from __future__ import print_function
21
22
23from tensorflow.python.framework import ops
24from tensorflow.python.ops.losses import losses
25from tensorflow.python.platform import tf_logging as logging
26from tensorflow.python.tpu import tpu_function
27from tensorflow.python.tpu.ops import tpu_ops
28from tensorflow.python.training import optimizer
29
30
31class CrossShardOptimizer(optimizer.Optimizer):
32  """An optimizer that averages gradients across TPU shards."""
33
34  def __init__(self,
35               opt,
36               reduction=losses.Reduction.MEAN,
37               name="CrossShardOptimizer",
38               group_assignment=None):
39    """Construct a new cross-shard optimizer.
40
41    Args:
42      opt: An existing `Optimizer` to encapsulate.
43      reduction: The reduction to apply to the shard losses.
44      name: Optional name prefix for the operations created when applying
45        gradients. Defaults to "CrossShardOptimizer".
46      group_assignment: Optional 2d int32 lists with shape
47        [num_groups, num_replicas_per_group] which describles how to apply
48        optimizer to subgroups.
49
50    Raises:
51      ValueError: If reduction is not a valid cross-shard reduction.
52    """
53    if reduction not in (losses.Reduction.SUM, losses.Reduction.MEAN):
54      raise ValueError("Unsupported reduction: %s." % reduction)
55
56    super(CrossShardOptimizer, self).__init__(False, name)
57    self._opt = opt
58    self._reduction = reduction
59    self._group_assignment = group_assignment
60
61  def _verify_and_get_subgroup_size(self, group_assignment, num_shards):
62    """Verify group_assignment and get the subgroup size".
63
64    Args:
65      group_assignment: list of group ids for applying the optimizer
66        to subgroups.
67      num_shards: The number of TPU shards.
68
69    Returns:
70      The size of one subgroup in group_assignment.
71
72    Raises:
73      ValueError: If group_assignment is invalid.
74    """
75    if not group_assignment:
76      return None
77    if not (isinstance(group_assignment, list) and
78            all(isinstance(i, list) for i in group_assignment)):
79      raise ValueError("group_assignment must be a list of list. Got {}".format(
80          group_assignment))
81
82    replica_ids = set()
83    for g in group_assignment:
84      for i in g:
85        replica_ids.add(i)
86
87    if set(range(num_shards)) != replica_ids:
88      raise ValueError("group_assignment must be a permutation of range({0})."
89                       " Got group_assignment={1}".format(
90                           num_shards, group_assignment))
91
92    subgroup_size_list = [len(group) for group in group_assignment]
93    if all(subgroup_size_list[0] == size for size in subgroup_size_list):
94      return subgroup_size_list[0]
95    else:
96      raise ValueError("The size of each subgroup in group_assignment must "
97                       "be equal. Got group_assignment={}".format(
98                           self._group_assignment))
99
100  def compute_gradients(self, loss, var_list=None, **kwargs):
101    """Compute gradients of "loss" for the variables in "var_list".
102
103    This simply wraps the compute_gradients() from the real optimizer. The
104    gradients will be aggregated in the apply_gradients() so that user can
105    modify the gradients like clipping with per replica global norm if needed.
106    The global norm with aggregated gradients can be bad as one replica's huge
107    gradients can hurt the gradients from other replicas.
108
109    Args:
110      loss: A Tensor containing the value to minimize.
111      var_list: Optional list or tuple of `tf.Variable` to update to minimize
112        `loss`.  Defaults to the list of variables collected in the graph
113        under the key `GraphKey.TRAINABLE_VARIABLES`.
114      **kwargs: Keyword arguments for compute_gradients().
115
116    Returns:
117      A list of (gradient, variable) pairs.
118
119    Raises:
120      ValueError: If not within a tpu_shard_context or group_assignment is
121        invalid.
122    """
123    num_shards = tpu_function.get_tpu_context().number_of_shards
124    if num_shards is None:
125      logging.warning(
126          "CrossShardOptimizer should be used within a tpu_shard_context, but "
127          "got unset number_of_shards. Assuming 1.")
128      num_shards = 1
129
130    subgroup_size = self._verify_and_get_subgroup_size(self._group_assignment,
131                                                       num_shards)
132
133    if num_shards > 1 and self._reduction == losses.Reduction.MEAN:
134      if self._group_assignment:
135        scale = 1.0 / subgroup_size
136      else:
137        scale = 1.0 / num_shards
138      loss *= scale
139
140    return self._opt.compute_gradients(loss, var_list=var_list, **kwargs)
141
142  def apply_gradients(self, grads_and_vars, global_step=None, name=None):
143    """Apply gradients to variables.
144
145    Calls tpu_ops.cross_replica_sum() to sum gradient contributions across
146    replicas, and then applies the real optimizer.
147
148    Args:
149      grads_and_vars: List of (gradient, variable) pairs as returned by
150        compute_gradients().
151      global_step: Optional Variable to increment by one after the
152        variables have been updated.
153      name: Optional name for the returned operation.  Default to the
154        name passed to the Optimizer constructor.
155
156    Returns:
157      An `Operation` that applies the gradients. If `global_step` was not None,
158      that operation also increments `global_step`.
159
160    Raises:
161      ValueError: If the grads_and_vars is malformed.
162    """
163    summed_grads_and_vars = []
164    for (grad, var) in grads_and_vars:
165      if grad is None:
166        summed_grads_and_vars.append((grad, var))
167      else:
168        with ops.colocate_with(grad):
169          summed_grads_and_vars.append((tpu_ops.cross_replica_sum(
170              grad, self._group_assignment), var))
171    return self._opt.apply_gradients(summed_grads_and_vars, global_step, name)
172
173  def get_slot(self, *args, **kwargs):
174    """Return a slot named "name" created for "var" by the Optimizer.
175
176    This simply wraps the get_slot() from the actual optimizer.
177
178    Args:
179      *args: Arguments for get_slot().
180      **kwargs: Keyword arguments for get_slot().
181
182    Returns:
183      The `Variable` for the slot if it was created, `None` otherwise.
184    """
185    return self._opt.get_slot(*args, **kwargs)
186
187  def get_slot_names(self, *args, **kwargs):
188    """Return a list of the names of slots created by the `Optimizer`.
189
190    This simply wraps the get_slot_names() from the actual optimizer.
191
192    Args:
193      *args: Arguments for get_slot().
194      **kwargs: Keyword arguments for get_slot().
195
196    Returns:
197      A list of strings.
198    """
199    return self._opt.get_slot_names(*args, **kwargs)
200
201  def variables(self):
202    """Forwarding the variables from the underlying optimizer."""
203    return self._opt.variables()
204