# Copyright 2019 The TensorFlow Authors. All Rights Reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
#     http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
# =============================================================================
"""Methods for rewriting while_v2 grad functions with IndexedSlices output."""

from __future__ import absolute_import
from __future__ import division
from __future__ import print_function

from tensorflow.python.framework import constant_op
from tensorflow.python.framework import dtypes
from tensorflow.python.framework import func_graph
from tensorflow.python.framework import ops
from tensorflow.python.framework import tensor_shape
from tensorflow.python.ops import array_ops
from tensorflow.python.ops import gen_resource_variable_ops
from tensorflow.python.util import nest


def rewrite_grad_indexed_slices(grads, body_grad_graph, loop_vars,
                                forward_inputs):
  """Handles special case of IndexedSlices returned from while gradient.

  Some gradient functions return IndexedSlices instead of a Tensor (e.g. the
  gradient of Gather ops). When this happens in the gradient of a while body,
  the resulting gradient body function will have mismatched inputs and outputs,
  since the input is a single Tensor, but the IndexedSlices gets unnested into
  three output Tensors.

  This function fixes this by rewriting the gradient body to have three inputs
  to match the three outputs, i.e., it effectively converts the input Tensor
  into an input IndexedSlices. It also returns new `loop_vars` to reflect the
  new inputs.

  Args:
    grads: the input gradient Tensors to the while gradient computation.
    body_grad_graph: _WhileBodyGradFuncGraph.
    loop_vars: list of Tensors. The inputs to body_grad_graph.
    forward_inputs: list of Tensors. The (flat) inputs to the forward-pass
      While op.

  Returns:
    The new loop_vars to pass to body_grad_graph.
  """
  # Match up body_grad_graph.structured_outputs with the corresponding
  # forward_inputs.
  #
  # Note that we don't expect a gradient computation to have structured output
  # (e.g. no nested lists), so no need to flatten
  # body_grad_graph.structured_outputs. However, structured_outputs may still
  # contain composite tensors such as IndexedSlices, unlike
  # body_grad_graph.outputs, which contains flattened composite tensors.
  inputs_with_grads = [t for g, t in zip(grads, forward_inputs)
                       if g is not None]
  # Skip loop counter, maximum_iterations and total number of loop iterations.
  structured_outputs = body_grad_graph.structured_outputs[3:]

  for forward_input, output in zip(inputs_with_grads, structured_outputs):
    if not isinstance(output, ops.IndexedSlices): continue

    if forward_input.dtype == dtypes.resource:
      # TODO(skyewm): In theory we should use this for all captured inputs, not
      # just resource handles (which can only be captured). We can do this by
      # checking that forward_input is passed straight through to its output.
      loop_vars = _rewrite_input_as_indexed_slices(body_grad_graph, output,
                                                   forward_input, loop_vars)
    else:
      _rewrite_output_as_tensor(body_grad_graph, output)

  return loop_vars
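# Illustrative sketch of the mismatch described above (shapes are assumptions
# for exposition only; this block is not executed). If the gradient of the
# body contains a Gather on a forward input `x` of shape [10, 3], the grad
# body's structured output for `x` is an IndexedSlices:
#
#   grad_x_in:  Tensor, shape [10, 3]                # one flat input
#   grad_x_out: IndexedSlices(values=Tensor[?, 3],   # flattens into three
#                             indices=Tensor[?],     # output Tensors
#                             dense_shape=Tensor[2])
#
# The rewrite replaces the single `grad_x_in` input with three inputs
# (values, indices, dense_shape) so inputs and outputs line up again.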
def _rewrite_output_as_tensor(body_grad_graph, grad_output_slices):
  """Rewrites grad_output_slices to be a Tensor output.

  Args:
    body_grad_graph: _WhileBodyGradFuncGraph.
    grad_output_slices: IndexedSlices output of body_grad_graph.
  """
  with body_grad_graph.as_default():
    new_output = ops.convert_to_tensor_v2(grad_output_slices)

  idx = body_grad_graph.structured_outputs.index(grad_output_slices)
  body_grad_graph.structured_outputs[idx] = new_output
  body_grad_graph.outputs = func_graph.flatten(
      body_grad_graph.structured_outputs)


def _rewrite_input_as_indexed_slices(body_grad_graph, grad_output_slices,
                                     forward_input, loop_vars):
  """Rewrites grad_output_slices's corresponding input to be an IndexedSlices.

  This rewrite requires that forward_input was captured in the forward loop,
  i.e. is not a user-specified loop variable. This is important because the
  rewrite assumes that forward_input is passed through to its corresponding
  output unchanged. This assumption is used in
  _rewrite_grad_indexed_slices_output, which depends on the exact gradient
  structure produced by the input's fanout.

  This can yield a more efficient computation than using
  _rewrite_output_as_tensor, since it preserves the IndexedSlices structure
  instead of converting the IndexedSlices to a dense Tensor.

  Args:
    body_grad_graph: _WhileBodyGradFuncGraph.
    grad_output_slices: IndexedSlices output of body_grad_graph.
    forward_input: the corresponding Tensor input to the forward loop.
    loop_vars: list of Tensors. The inputs to body_grad_graph.

  Returns:
    The new loop_vars to pass to body_grad_graph.
  """
  # Create the initial IndexedSlices that will be the input to the grad While
  # op. This will start as zeros, and accumulate the IndexedSlices grad output.
  # Note that because forward_input is captured and not a loop var, its
  # incoming gradient should always be zero.
  init_slices = _create_grad_indexed_slices_init(grad_output_slices,
                                                 forward_input)

  # Create a new version of grad_output_slices's gradient computation that uses
  # the new IndexedSlices input instead of the original Tensor input. We'll
  # return the new computation and leave the old computation as dead code.
  # TODO(skyewm): consider pruning body_grad_graph to remove the old
  # computation.
  with body_grad_graph.as_default():
    input_slices = ops.IndexedSlices(
        values=body_grad_graph.capture(init_slices.values, whitelisted=True),
        indices=body_grad_graph.capture(init_slices.indices, whitelisted=True),
        dense_shape=body_grad_graph.capture(init_slices.dense_shape,
                                            whitelisted=True))

    # Remove the captured tensors from the function inputs. We'll add them back
    # at the correct index in _update_indexed_slices_param.
    for t in _flatten(init_slices):
      captured_t = body_grad_graph.captures.pop(t)
      body_grad_graph.inputs.remove(captured_t)

    new_output_slices = _rewrite_grad_indexed_slices_output(grad_output_slices,
                                                            input_slices)

  # Update body_grad_graph's inputs and outputs to reflect the new
  # IndexedSlices computation.
  return _update_indexed_slices_param(
      body_grad_graph, loop_vars, init_slices, input_slices, new_output_slices,
      grad_output_slices)
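# Data flow of the rewrite above (illustrative sketch, not executed):
#
#   outer graph:      init_slices  = zeros IndexedSlices (new loop inputs)
#   body_grad_graph:  input_slices  <- captured init_slices components,
#                                      spliced into graph.inputs in place of
#                                      the old single Tensor input
#                     output_slices <- new concats built from input_slices,
#                                      spliced into graph.outputs in place of
#                                      the old IndexedSlices output
#
# Each loop iteration then accumulates the gradient by concatenating that
# iteration's slices onto the running values/indices, while dense_shape
# stays fixed across iterations.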
def _create_grad_indexed_slices_init(grad_output_slices, forward_input):
  """Creates an IndexedSlices to pass as input to the while grad function.

  Args:
    grad_output_slices: IndexedSlices. The corresponding while grad function
      output.
    forward_input: Tensor. The corresponding input to the forward while op.

  Returns:
    Zeros IndexedSlices, created in the current Graph.
  """
  assert isinstance(grad_output_slices, ops.IndexedSlices)
  assert isinstance(forward_input, ops.Tensor)
  values_out = grad_output_slices.values
  indices_out = grad_output_slices.indices

  # Create the initial values tensor.
  if values_out.shape.is_fully_defined():
    values_shape = tensor_shape.TensorShape([0] +
                                            values_out.shape.as_list()[1:])
    values = array_ops.zeros(values_shape, dtype=values_out.dtype,
                             name="values_init")
  else:
    if forward_input.dtype == dtypes.resource:
      forward_shape = gen_resource_variable_ops.variable_shape(forward_input)
    else:
      forward_shape = array_ops.shape(forward_input)
    values_shape = array_ops.concat([[0], forward_shape[1:]], 0)
    values = array_ops.zeros(values_shape, dtype=values_out.dtype,
                             name="values_init")

  # Create the initial indices tensor.
  indices = constant_op.constant([], indices_out.dtype, name="indices_init")

  # Create the initial dense_shape tensor. We assume it is the same shape as
  # forward_input, since captured tensors don't change shape across loop
  # iterations.
  if forward_input.dtype == dtypes.resource:
    shape = gen_resource_variable_ops.variable_shape(forward_input,
                                                     name="shape_init")
  else:
    shape = array_ops.shape(forward_input, name="shape_init")

  return ops.IndexedSlices(values=values, indices=indices, dense_shape=shape)


def _rewrite_grad_indexed_slices_output(old_output_slices, new_input_slices):
  """Creates a new version of old_output_slices with new_input_slices as input.

  This method assumes that old_output_slices.{values,indices} are produced by
  concatenating the incoming gradient Tensor input with the IndexedSlices
  produced by the gradient computation of the while body. See
  gradients_impl._AggregateIndexedSlicesGradients for where these concats are
  constructed. We build new concats that use new_input_slices instead of the
  original Tensor input.

  Args:
    old_output_slices: original IndexedSlices output of while gradient.
    new_input_slices: new IndexedSlices to use as input to while gradient.

  Returns:
    A new IndexedSlices to replace old_output_slices.
  """

  def rewrite(old_output, new_input):
    assert old_output.type == "Identity"
    concat_op = old_output.inputs[0].op
    assert concat_op.type == "ConcatV2"
    # Don't include the axis arg.
    old_concat_args = concat_op.inputs[:-1]
    # We assume that the original gradient input was the first argument to the
    # concat op.
    # TODO(skyewm): do this in a more robust way.
    return array_ops.concat([new_input] + old_concat_args[1:], 0)

  values = rewrite(old_output_slices.values.op, new_input_slices.values)
  indices = rewrite(old_output_slices.indices.op, new_input_slices.indices)
  return ops.IndexedSlices(values=values, indices=indices,
                           dense_shape=new_input_slices.dense_shape)
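# Before/after sketch of the rewrite performed by rewrite() above (op layout
# is illustrative; names are not real graph node names):
#
#   before: values = Identity(ConcatV2(incoming_grad_tensor, body_values, axis))
#   after:  values = ConcatV2(new_input_slices.values, body_values, axis=0)
#
# i.e. the aggregation built by _AggregateIndexedSlicesGradients is rebuilt
# with the new IndexedSlices component substituted for the original dense
# incoming gradient. The old Identity/ConcatV2 ops remain in the graph as
# dead code (see the pruning TODO above).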
def _update_indexed_slices_param(graph, loop_vars, init_slices, input_slices,
                                 output_slices, old_output_slices):
  """Updates graph with new IndexedSlices input/output.

  Updates graph's metadata to output the gradient computation defined by
  init_slices, input_slices, and output_slices, instead of outputting
  old_output_slices. Also returns a new version of loop_vars with init_slices
  replacing the old input.

  Args:
    graph: _WhileBodyGradFuncGraph.
    loop_vars: the inputs to graph.
    init_slices: the new IndexedSlices to use as input to graph.
    input_slices: the new IndexedSlices in graph that should be fed by
      init_slices.
    output_slices: the new IndexedSlices in graph that should be the
      corresponding output to input_slices.
    old_output_slices: the IndexedSlices in graph that are currently being
      output.

  Returns:
    New loop_vars to pass to graph.
  """
  structured_idx = graph.structured_outputs.index(old_output_slices)
  # We assume that the component tensors of old_output_slices appear
  # sequentially in graph.outputs. We use the first of these tensors as the
  # reference index.
  flat_idx = graph.outputs.index(func_graph.flatten(old_output_slices)[0])

  graph.structured_outputs[structured_idx] = output_slices
  graph.outputs = func_graph.flatten(graph.structured_outputs)

  graph.inputs = (graph.inputs[:flat_idx] + _flatten(input_slices) +
                  graph.inputs[flat_idx + 1:])

  return loop_vars[:flat_idx] + _flatten(init_slices) + loop_vars[flat_idx + 1:]


def _flatten(arg):
  return nest.flatten(arg, expand_composites=True)
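# Hypothetical caller sketch (for illustration only; the variable names and
# the exact call site in while_v2 are assumptions, not part of this module):
#
#   loop_vars = rewrite_grad_indexed_slices(
#       grads, body_grad_graph, loop_vars, while_op.inputs)
#
# after which the gradient While op is built with the returned loop_vars,
# whose extra entries feed the zeros-initialized IndexedSlices components
# created by _create_grad_indexed_slices_init.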