# Copyright 2017 The TensorFlow Authors. All Rights Reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
#     http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
# =============================================================================

"""Optimizer that implements cross-shard gradient reduction for TPU."""

from __future__ import absolute_import
from __future__ import division
from __future__ import print_function


from tensorflow.python.framework import ops
from tensorflow.python.ops.losses import losses
from tensorflow.python.platform import tf_logging as logging
from tensorflow.python.tpu import tpu_function
from tensorflow.python.tpu.ops import tpu_ops
from tensorflow.python.training import optimizer


class CrossShardOptimizer(optimizer.Optimizer):
  """An optimizer that averages gradients across TPU shards."""

  def __init__(self,
               opt,
               reduction=losses.Reduction.MEAN,
               name="CrossShardOptimizer",
               group_assignment=None):
    """Construct a new cross-shard optimizer.

    Args:
      opt: An existing `Optimizer` to encapsulate.
      reduction: The reduction to apply to the shard losses.
      name: Optional name prefix for the operations created when applying
        gradients. Defaults to "CrossShardOptimizer".
      group_assignment: Optional 2d int32 list with shape
        [num_groups, num_replicas_per_group] which describes how to apply
        the optimizer to subgroups.

    Raises:
      ValueError: If reduction is not a valid cross-shard reduction.
    """
    if reduction not in (losses.Reduction.SUM, losses.Reduction.MEAN):
      raise ValueError("Unsupported reduction: %s." % reduction)

    super(CrossShardOptimizer, self).__init__(False, name)
    self._opt = opt
    self._reduction = reduction
    self._group_assignment = group_assignment

  def _verify_and_get_subgroup_size(self, group_assignment, num_shards):
    """Verify group_assignment and get the subgroup size.

    Args:
      group_assignment: list of group ids for applying the optimizer
        to subgroups.
      num_shards: The number of TPU shards.

    Returns:
      The size of one subgroup in group_assignment.

    Raises:
      ValueError: If group_assignment is invalid.
    """
    if not group_assignment:
      return None
    if not (isinstance(group_assignment, list) and
            all(isinstance(i, list) for i in group_assignment)):
      raise ValueError(
          "group_assignment must be a list of lists. Got {}".format(
              group_assignment))

    replica_ids = set()
    for g in group_assignment:
      for i in g:
        replica_ids.add(i)

    if set(range(num_shards)) != replica_ids:
      raise ValueError("group_assignment must be a permutation of range({0})."
                       " Got group_assignment={1}".format(
                           num_shards, group_assignment))

    subgroup_size_list = [len(group) for group in group_assignment]
    if all(subgroup_size_list[0] == size for size in subgroup_size_list):
      return subgroup_size_list[0]
    else:
      raise ValueError("The size of each subgroup in group_assignment must "
                       "be equal. Got group_assignment={}".format(
                           group_assignment))

  def compute_gradients(self, loss, var_list=None, **kwargs):
    """Compute gradients of "loss" for the variables in "var_list".

    This simply wraps the compute_gradients() from the real optimizer. The
    gradients are aggregated later, in apply_gradients(), so that the user can
    still modify them first, e.g. clip them with a per-replica global norm.
    A global norm computed over already-aggregated gradients can be
    misleading, because one replica's outsized gradients would distort the
    norm seen by every other replica.

    Args:
      loss: A Tensor containing the value to minimize.
      var_list: Optional list or tuple of `tf.Variable` to update to minimize
        `loss`. Defaults to the list of variables collected in the graph
        under the key `GraphKeys.TRAINABLE_VARIABLES`.
      **kwargs: Keyword arguments for compute_gradients().

    Returns:
      A list of (gradient, variable) pairs.

    Raises:
      ValueError: If not within a tpu_shard_context or group_assignment is
        invalid.
    """
    num_shards = tpu_function.get_tpu_context().number_of_shards
    if num_shards is None:
      logging.warning(
          "CrossShardOptimizer should be used within a tpu_shard_context, but "
          "got unset number_of_shards. Assuming 1.")
      num_shards = 1

    subgroup_size = self._verify_and_get_subgroup_size(self._group_assignment,
                                                       num_shards)

    # For MEAN reduction, pre-scale the per-replica loss so that the
    # cross-replica sum taken in apply_gradients() yields the mean gradient.
    if num_shards > 1 and self._reduction == losses.Reduction.MEAN:
      if self._group_assignment:
        scale = 1.0 / subgroup_size
      else:
        scale = 1.0 / num_shards
      loss *= scale

    return self._opt.compute_gradients(loss, var_list=var_list, **kwargs)

  def apply_gradients(self, grads_and_vars, global_step=None, name=None):
    """Apply gradients to variables.

    Calls tpu_ops.cross_replica_sum() to sum gradient contributions across
    replicas, and then applies the real optimizer.

    Args:
      grads_and_vars: List of (gradient, variable) pairs as returned by
        compute_gradients().
      global_step: Optional Variable to increment by one after the
        variables have been updated.
      name: Optional name for the returned operation. Defaults to the
        name passed to the Optimizer constructor.

    Returns:
      An `Operation` that applies the gradients. If `global_step` was not None,
      that operation also increments `global_step`.

    Raises:
      ValueError: If grads_and_vars is malformed.
    """
    summed_grads_and_vars = []
    for (grad, var) in grads_and_vars:
      if grad is None:
        summed_grads_and_vars.append((grad, var))
      else:
        # Sum the gradient across replicas (or within each subgroup when a
        # group_assignment is given) before handing it to the real optimizer.
        with ops.colocate_with(grad):
          summed_grads_and_vars.append((tpu_ops.cross_replica_sum(
              grad, self._group_assignment), var))
    return self._opt.apply_gradients(summed_grads_and_vars, global_step, name)
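
  # Worked example of the reduction (shard count and grouping below are
  # assumed, illustrative values): with num_shards = 8, reduction = MEAN and
  # no group_assignment, compute_gradients() scales the per-replica loss by
  # 1/8, so each replica produces gradients g_i / 8. apply_gradients() then
  # sums these with cross_replica_sum(), yielding (g_0 + ... + g_7) / 8,
  # i.e. the cross-replica mean gradient. With
  # group_assignment = [[0, 1, 2, 3], [4, 5, 6, 7]], the scale is instead
  # 1 / subgroup_size = 1/4 and the sum runs within each group of four
  # replicas.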

  def get_slot(self, *args, **kwargs):
    """Return a slot named "name" created for "var" by the Optimizer.

    This simply wraps the get_slot() from the actual optimizer.

    Args:
      *args: Arguments for get_slot().
      **kwargs: Keyword arguments for get_slot().

    Returns:
      The `Variable` for the slot if it was created, `None` otherwise.
    """
    return self._opt.get_slot(*args, **kwargs)

  def get_slot_names(self, *args, **kwargs):
    """Return a list of the names of slots created by the `Optimizer`.

    This simply wraps the get_slot_names() from the actual optimizer.

    Args:
      *args: Arguments for get_slot_names().
      **kwargs: Keyword arguments for get_slot_names().

    Returns:
      A list of strings.
    """
    return self._opt.get_slot_names(*args, **kwargs)

  def variables(self):
    """Forwards the variables of the underlying optimizer."""
    return self._opt.variables()
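

# A minimal usage sketch: `_example_train_op`, `base_optimizer`, `loss`, and
# `global_step` are illustrative names only, and the code assumes it is built
# inside a TPU shard context (for example via TPUEstimator or tpu.shard),
# since compute_gradients() reads the shard count from that context.
def _example_train_op(base_optimizer, loss, global_step=None):
  """Wraps `base_optimizer` so gradients are averaged across TPU replicas."""
  cross_shard_opt = CrossShardOptimizer(base_optimizer)
  # minimize() calls the overridden compute_gradients()/apply_gradients(),
  # which scale the loss for MEAN reduction and cross-replica-sum the
  # gradients before applying them.
  return cross_shard_opt.minimize(loss, global_step=global_step)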