1# Copyright 2019 The TensorFlow Authors. All Rights Reserved.
3# Licensed under the Apache License, Version 2.0 (the "License");
4# you may not use this file except in compliance with the License.
5# You may obtain a copy of the License at
7#     http://www.apache.org/licenses/LICENSE-2.0
9# Unless required by applicable law or agreed to in writing, software
10# distributed under the License is distributed on an "AS IS" BASIS,
11# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
12# See the License for the specific language governing permissions and
13# limitations under the License.
14# =============================================================================
15"""Methods for rewriting while_v2 grad functions with IndexedSlices output."""
17from __future__ import absolute_import
18from __future__ import division
19from __future__ import print_function
21from tensorflow.python.framework import constant_op
22from tensorflow.python.framework import dtypes
23from tensorflow.python.framework import func_graph
24from tensorflow.python.framework import ops
25from tensorflow.python.framework import tensor_shape
26from tensorflow.python.ops import array_ops
27from tensorflow.python.ops import gen_resource_variable_ops
28from tensorflow.python.util import nest
def rewrite_grad_indexed_slices(grads, body_grad_graph, loop_vars,
                                forward_inputs):
  """Reconciles IndexedSlices outputs of a while gradient body with its inputs.

  A gradient function may return an IndexedSlices rather than a Tensor (the
  gradient of Gather ops is one example). Inside the gradient of a while body
  this leaves the gradient body function with mismatched inputs and outputs:
  the input is a single Tensor, whereas the IndexedSlices output is unnested
  into three Tensors.

  Each IndexedSlices output is handled in one of two ways:
    * If the matching forward input is a resource handle, the gradient body is
      rewritten to take three inputs matching the three outputs, i.e. the
      input Tensor is effectively turned into an input IndexedSlices, and
      `loop_vars` is updated to reflect the new inputs.
    * Otherwise the IndexedSlices output is rewritten to be a Tensor output.

  Args:
    grads: the input gradient Tensors to the while gradient computation.
    body_grad_graph: _WhileBodyGradFuncGraph.
    loop_vars: list of Tensors. The inputs to body_grad_graph.
    forward_inputs: list of Tensors. The (flat) inputs to the forward-pass
      While op.

  Returns:
    The new loop_vars to pass to body_grad_graph.
  """
  # Pair each entry of body_grad_graph.structured_outputs with the forward
  # input it is the gradient of. Only forward inputs whose incoming gradient is
  # not None take part in the pairing.
  #
  # A gradient computation is not expected to produce structured output (e.g.
  # nested lists), so structured_outputs is used without flattening. It may
  # nevertheless hold composite tensors such as IndexedSlices, in contrast to
  # body_grad_graph.outputs, where composite tensors are already flattened.
  differentiated_inputs = []
  for grad, forward_tensor in zip(grads, forward_inputs):
    if grad is not None:
      differentiated_inputs.append(forward_tensor)

  # The leading three outputs are the loop counter, maximum_iterations and the
  # total number of loop iterations; none of them is a gradient.
  grad_outputs = body_grad_graph.structured_outputs[3:]

  for forward_tensor, grad_output in zip(differentiated_inputs, grad_outputs):
    if isinstance(grad_output, ops.IndexedSlices):
      if forward_tensor.dtype == dtypes.resource:
        # TODO(skyewm): In theory we should use this for all captured inputs,
        # not just resource handles (which can only be captured). We can do
        # this by checking that forward_input is passed straight through to its
        # output.
        loop_vars = _rewrite_input_as_indexed_slices(
            body_grad_graph, grad_output, forward_tensor, loop_vars)
      else:
        _rewrite_output_as_tensor(body_grad_graph, grad_output)

  return loop_vars
def _rewrite_output_as_tensor(body_grad_graph, grad_output_slices):
  """Rewrites grad_output_slices to be a Tensor output.

  The IndexedSlices is converted to a Tensor inside body_grad_graph. The
  graph's `structured_outputs` entry holding grad_output_slices is then
  replaced by the converted Tensor, and `outputs` is recomputed by flattening
  the updated `structured_outputs`.

  Args:
    body_grad_graph: _WhileBodyGradFuncGraph.
    grad_output_slices: IndexedSlices output of body_grad_graph.

  Raises:
    ValueError: if grad_output_slices is not an element of
      body_grad_graph.structured_outputs.
  """
  with body_grad_graph.as_default():
    new_output = ops.convert_to_tensor_v2(grad_output_slices)

  # Locate the output by identity rather than with list.index(). index()
  # compares elements using `==`, and tensor-like objects may overload `==` to
  # build a comparison op instead of returning a Python bool, which makes the
  # lookup fail on the elements that precede the one we are looking for.
  for idx, output in enumerate(body_grad_graph.structured_outputs):
    if output is grad_output_slices:
      break
  else:
    raise ValueError(
        "%r is not in body_grad_graph.structured_outputs." %
        (grad_output_slices,))

  body_grad_graph.structured_outputs[idx] = new_output
  body_grad_graph.outputs = func_graph.flatten(
      body_grad_graph.structured_outputs)
def _rewrite_input_as_indexed_slices(body_grad_graph, grad_output_slices,
                                     forward_input, loop_vars):
  """Rewrites grad_output_slices's corresponding input to be an IndexedSlices.

  This rewrite requires that forward_input was captured in the forward loop,
  i.e. is not a user-specified loop variable. This is important because the
  rewrite assumes that forward_input is passed through to its corresponding
  output unchanged. This assumption is used in
  _rewrite_grad_indexed_slices_output, which depends on the exact gradient
  structure produced by the input's fanout.

  This can yield a more efficient computation than using
  _rewrite_output_as_tensor, since it preserves the IndexedSlices structure
  instead of converting the IndexedSlices to a dense Tensor.

  Args:
    body_grad_graph: _WhileBodyGradFuncGraph.
    grad_output_slices: IndexedSlices output of body_grad_graph.
    forward_input: the corresponding Tensor input to the forward loop.
    loop_vars: list of Tensors. The inputs to body_grad_graph.

  Returns:
    The new loop_vars to pass to body_grad_graph.

  Raises:
    ValueError: if a tensor captured for the new IndexedSlices input cannot be
      found in body_grad_graph.inputs.
  """
  # Create initial IndexedSlices that will be the input to the grad While
  # op. This will start as zeros, and accumulate the IndexedSlices grad output.
  # Note that because forward_input is captured and not a loop var, its incoming
  # gradient should always be zero.
  init_slices = _create_grad_indexed_slices_init(grad_output_slices,
                                                 forward_input)

  # Create a new version of grad_output_slices's gradient computation that uses
  # the new IndexedSlices input instead of the original Tensor input. We'll
  # return the new computation and leave the old computation as dead code.
  # TODO(skyewm): considering pruning body_grad_graph to remove the old
  # computation.
  with body_grad_graph.as_default():
    input_slices = ops.IndexedSlices(
        values=body_grad_graph.capture(init_slices.values, whitelisted=True),
        indices=body_grad_graph.capture(init_slices.indices, whitelisted=True),
        dense_shape=body_grad_graph.capture(init_slices.dense_shape,
                                            whitelisted=True))

    # Remove the captured tensors from the function inputs. We'll add them back
    # at the correct index in _update_indexed_slices_param.
    for t in _flatten(init_slices):
      captured_t = body_grad_graph.captures.pop(t)
      # Remove by identity rather than with list.remove(). remove() compares
      # elements using `==`, and tensor-like objects may overload `==` to build
      # a comparison op instead of returning a Python bool.
      graph_inputs = body_grad_graph.inputs
      for input_idx, graph_input in enumerate(graph_inputs):
        if graph_input is captured_t:
          del graph_inputs[input_idx]
          break
      else:
        raise ValueError(
            "%r is not in body_grad_graph.inputs." % (captured_t,))

    new_output_slices = _rewrite_grad_indexed_slices_output(grad_output_slices,
                                                            input_slices)

  # Update body_grad_graph's inputs and outputs to reflect the new
  # IndexedSlices computation.
  return _update_indexed_slices_param(
      body_grad_graph, loop_vars, init_slices, input_slices, new_output_slices,
      grad_output_slices)
def _create_grad_indexed_slices_init(grad_output_slices, forward_input):
  """Builds the empty IndexedSlices fed as input to the while grad function.

  The result has a `values` tensor with a leading dimension of 0, an empty
  `indices` tensor, and a `dense_shape` equal to the shape of forward_input.

  Args:
    grad_output_slices: IndexedSlices. The corresponding while grad function
      output.
    forward_input: Tensor. The corresponding input to the forward while op.

  Returns:
    Zeros IndexedSlices, created in current Graph.
  """
  assert isinstance(grad_output_slices, ops.IndexedSlices)
  assert isinstance(forward_input, ops.Tensor)
  out_values = grad_output_slices.values
  out_indices = grad_output_slices.indices
  # Resource handles need the dedicated variable_shape op to get their shape.
  input_is_resource = forward_input.dtype == dtypes.resource

  # Initial values tensor: zero rows. The trailing dimensions come from the
  # static shape of the output values when it is fully defined, and from the
  # dynamic shape of forward_input otherwise.
  if out_values.shape.is_fully_defined():
    trailing_dims = out_values.shape.as_list()[1:]
    init_values_shape = tensor_shape.TensorShape([0] + trailing_dims)
  else:
    if input_is_resource:
      dynamic_shape = gen_resource_variable_ops.variable_shape(forward_input)
    else:
      dynamic_shape = array_ops.shape(forward_input)
    init_values_shape = array_ops.concat([[0], dynamic_shape[1:]], 0)
  init_values = array_ops.zeros(init_values_shape, dtype=out_values.dtype,
                                name="values_init")

  # Initial indices tensor: empty, with the dtype of the output indices.
  init_indices = constant_op.constant([], out_indices.dtype,
                                      name="indices_init")

  # Initial dense_shape tensor. It is assumed to be the same as the shape of
  # forward_input, since captured tensors don't change shape across loop
  # iterations.
  if input_is_resource:
    init_dense_shape = gen_resource_variable_ops.variable_shape(
        forward_input, name="shape_init")
  else:
    init_dense_shape = array_ops.shape(forward_input, name="shape_init")

  return ops.IndexedSlices(values=init_values, indices=init_indices,
                           dense_shape=init_dense_shape)
def _rewrite_grad_indexed_slices_output(old_output_slices, new_input_slices):
  """Creates a new version of old_output_slices with new_input_slices as input.

  This method assumes that old_output_slices.{values,indices} are produced by
  concatenating the incoming gradient Tensor input with the IndexedSlices
  produced by the gradient computation of the while body. See
  gradients_impl._AggregateIndexedSlicesGradients for where these concats are
  constructed. We build new concats that use new_input_slices instead of the
  original Tensor input.

  Args:
    old_output_slices: original IndexedSlices output of while gradient.
    new_input_slices: new IndexedSlices to use as input to while gradient.

  Returns:
    A new IndexedSlices to replace old_output_slices. Its values and indices
    are the rebuilt concats; its dense_shape is new_input_slices.dense_shape.
  """

  def rewrite(old_output, new_input):
    """Rebuilds the concat feeding old_output with new_input as first arg.

    Args:
      old_output: the "Identity" op producing one component of
        old_output_slices. Its input must be produced by a "ConcatV2" op.
      new_input: Tensor that replaces the first argument of that concat.

    Returns:
      The output of a new concat along axis 0.
    """
    assert old_output.type == "Identity"
    concat_op = old_output.inputs[0].op
    assert concat_op.type == "ConcatV2"
    # Don't include axis arg
    old_concat_args = concat_op.inputs[:-1]
    # We assume that the original gradient input was the first argument to the
    # concat op.
    # TODO(skyewm): do this in a more robust way.
    # The remaining args are copied into a list explicitly so that the
    # concatenation below works whether slicing the op's inputs yields a list
    # or another sequence type such as a tuple.
    remaining_args = list(old_concat_args[1:])
    return array_ops.concat([new_input] + remaining_args, 0)

  values = rewrite(old_output_slices.values.op, new_input_slices.values)
  indices = rewrite(old_output_slices.indices.op, new_input_slices.indices)
  return ops.IndexedSlices(values=values, indices=indices,
                           dense_shape=new_input_slices.dense_shape)
def _update_indexed_slices_param(graph, loop_vars, init_slices, input_slices,
                                 output_slices, old_output_slices):
  """Updates graph with new IndexedSlices input/output.

  Updates graph's metadata to output the gradient computation defined by
  init_slices, input_slices, and output_slices, instead of outputting
  old_output_slices. Also returns a new version of loop_vars with init_slices
  replacing the old input.

  Args:
    graph: _WhileBodyGradFuncGraph.
    loop_vars: the inputs to graph.
    init_slices: the new IndexedSlices to use as input to graph.
    input_slices: the new IndexedSlices in graph that should be fed by
      init_slices.
    output_slices: the new IndexedSlices in graph that should be the
      corresponding output to input_slices.
    old_output_slices: the IndexedSlices in graph that are currently
      being output.

  Returns:
    New loop_vars to pass to graph.

  Raises:
    ValueError: if old_output_slices is not in graph.structured_outputs, or
      its first component tensor is not in graph.outputs.
  """

  def index_by_identity(sequence, target):
    """Returns the index of the first element of sequence that is target.

    Identity is used instead of list.index() because index() compares with
    `==`, and tensor-like objects may overload `==` to build a comparison op
    instead of returning a Python bool.
    """
    for i, element in enumerate(sequence):
      if element is target:
        return i
    raise ValueError("%r is not in the given sequence." % (target,))

  structured_idx = index_by_identity(graph.structured_outputs,
                                     old_output_slices)
  # We assume that the component tensors of old_output_slices appear
  # sequentially in graph.outputs. We use the first of these tensors
  # as the reference index.
  flat_idx = index_by_identity(graph.outputs,
                               func_graph.flatten(old_output_slices)[0])

  graph.structured_outputs[structured_idx] = output_slices
  graph.outputs = func_graph.flatten(
      graph.structured_outputs)

  # The single input at flat_idx is replaced by the flattened components of
  # the new IndexedSlices input; loop_vars is updated at the same position.
  graph.inputs = (graph.inputs[:flat_idx] + _flatten(input_slices) +
                  graph.inputs[flat_idx + 1:])

  return loop_vars[:flat_idx] + _flatten(init_slices) + loop_vars[flat_idx + 1:]
def _flatten(arg):
  """Flattens `arg` via nest.flatten with composite expansion enabled."""
  flat_components = nest.flatten(arg, expand_composites=True)
  return flat_components