1# Copyright 2015 The TensorFlow Authors. All Rights Reserved.
2#
3# Licensed under the Apache License, Version 2.0 (the "License");
4# you may not use this file except in compliance with the License.
5# You may obtain a copy of the License at
6#
7#     http://www.apache.org/licenses/LICENSE-2.0
8#
9# Unless required by applicable law or agreed to in writing, software
10# distributed under the License is distributed on an "AS IS" BASIS,
11# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
12# See the License for the specific language governing permissions and
13# limitations under the License.
14# ==============================================================================
15"""Gradients for operators defined in array_ops.py."""
16
17from __future__ import absolute_import
18from __future__ import division
19from __future__ import print_function
20
21from tensorflow.python import pywrap_tensorflow
22from tensorflow.python.eager import context
23from tensorflow.python.framework import constant_op
24from tensorflow.python.framework import dtypes
25from tensorflow.python.framework import ops
26from tensorflow.python.framework import sparse_tensor
27from tensorflow.python.framework import tensor_util
28from tensorflow.python.ops import array_ops
29from tensorflow.python.ops import control_flow_util
30from tensorflow.python.ops import gen_array_ops
31from tensorflow.python.ops import math_ops
32from tensorflow.python.ops import sparse_ops
33
34
@ops.RegisterGradient("Pack")
def _PackGrad(op, grad):
  """Gradient for the Pack op: unstacks `grad` along the packed axis.

  Args:
    op: The `Pack` operation being differentiated.
    grad: Gradient with respect to the op's output.

  Returns:
    The result of unstacking `grad` into `N` pieces along the op's `axis`
    attribute.
  """
  num_pieces = op.get_attr("N")
  pack_axis = op.get_attr("axis")
  return array_ops.unstack(grad, num=num_pieces, axis=pack_axis)
39
40
@ops.RegisterGradient("Unpack")
def _UnpackGrad(op, *grads):
  """Gradient for the Unpack op: stacks the output gradients back together.

  Args:
    op: The `Unpack` operation being differentiated.
    *grads: Gradients with respect to each output of the op.

  Returns:
    `grads` stacked along the op's `axis` attribute.
  """
  unpack_axis = op.get_attr("axis")
  return array_ops.stack(grads, axis=unpack_axis)
45
46
def _ConcatGradHelper(op, grad, start_value_index, end_value_index, dim_index):
  """Gradient for concat op.

  Shared implementation behind the `Concat` and `ConcatV2` gradients. The
  index arguments describe where the concatenated values and the concat
  dimension sit inside `op.inputs`.

  Args:
    op: An operation.
    grad: `Tensor` or `IndexedSlices` representing the gradients with respect
      to each output of the op.
    start_value_index: An integer index of the first value in the op.inputs.
    end_value_index: An integer index one past the last value in the
      op.inputs; the values are read as
      `op.inputs[start_value_index:end_value_index]`.
    dim_index: An integer index of concat_dim or axis parameter in op.inputs.

  Returns:
    A list with the partial gradients with respect to each value input of the
    op, plus a `None` entry for the concat_dim/axis input (placed last when
    `end_value_index <= dim_index`, first otherwise).

  Raises:
    ValueError: if `grad` is an `IndexedSlices` and concat_dim/axis is not
      statically known, or is negative while the rank of the first value is
      not statically known.
    TypeError: if `grad` is neither a `Tensor` nor an `IndexedSlices`.
  """

  def _CreateDenseMaskAndBegin(sizes, concat_dim):
    """Create variables for iteratively slicing a dense gradients tensor."""
    # Since shape is 1-D, shape_of_shape = [rank-of-inputs]
    shape_of_shape = array_ops.shape(sizes[0])
    # Make a vector of length equal to the input's dimensions,
    # with 0's everywhere and 1 in the concat dim position.
    # Note: Can't use sparse_to_dense since it isn't GPU-capable (for now)
    mask = array_ops.concat([
        array_ops.fill(array_ops.expand_dims(concat_dim, 0), 0), [1],
        array_ops.fill(shape_of_shape - concat_dim - 1, 0)
    ], 0)
    begin = array_ops.fill(shape_of_shape, 0)
    return mask, begin

  def _ExtractInputShapes(inputs):
    """Extract the shapes of a set of input tensors."""
    if context.executing_eagerly():
      return array_ops.shape_n(inputs)
    sizes = []
    fully_known = True
    for x in inputs:
      input_shape = array_ops.shape(x)
      # Per-input shapes are only kept when every one of them is a constant;
      # otherwise a single shape_n call is used for all inputs.
      if not isinstance(input_shape,
                        ops.Tensor) or input_shape.op.type != "Const":
        fully_known = False
        break
      sizes.append(input_shape)

    if fully_known:
      return sizes
    else:
      return array_ops.shape_n(inputs)

  # Degenerate concatenation of a single value: the gradient passes through
  # unchanged. `grad` is a single Tensor/IndexedSlices here, so it must be
  # wrapped in a list before being joined with the `None` gradient of the
  # concat_dim input; concatenating a list onto the tensor itself fails.
  # A list/tuple `grad` is still accepted as-is.
  if len(op.inputs) == 2:
    grads = list(grad) if isinstance(grad, (list, tuple)) else [grad]
    return grads + [None] if end_value_index <= dim_index else [None] + grads

  concat_dim = op.inputs[dim_index]
  input_values = op.inputs[start_value_index:end_value_index]

  out_grads = []
  if isinstance(grad, ops.Tensor):
    if context.executing_eagerly():
      # Using mod here for convenience since concat_dim is already verified
      # in concat implementation to be within the allowed [-rank, rank) range.
      non_neg_concat_dim = (
          concat_dim._numpy().item(0) % input_values[0]._rank())  # pylint: disable=protected-access
      # All inputs are guaranteed to be EagerTensors in eager mode
      sizes = pywrap_tensorflow.TFE_Py_TensorShapeSlice(input_values,
                                                        non_neg_concat_dim)
      out_grads = array_ops.split(grad, sizes, non_neg_concat_dim)
    else:
      if constant_op.is_constant(concat_dim):
        # If concat_dim is a constant defined in a different context,
        # then we duplicate it in the current context to avoid passing it
        # through an Enter node.
        # This is a small optimization in general, but it is required when
        # compiling with XLA, as XLA needs the concat input to be folded into a
        # constant.
        grad_context = control_flow_util.GetOutputContext(grad.op)
        dim_context = control_flow_util.GetOutputContext(concat_dim.op)
        if dim_context != grad_context:
          value = tensor_util.constant_value(concat_dim)
          concat_dim = constant_op.constant(value=value, dtype=concat_dim.dtype)

      # Using mod here for convenience since concat_dim is already verified
      # in concat implementation to be within the allowed [-rank, rank) range.
      non_neg_concat_dim = concat_dim % array_ops.rank(input_values[0])

      # Get the inputs' tensor shapes
      sizes = _ExtractInputShapes(input_values)
      # The magic number of 16 was found through benchmarking a range of sizes
      # on CPUs and a Maxwell TitanX.  A speedup was seen in a large majority of
      # cases when switching implementations at N=16, but it is possible that
      # there will be a small number of performance regressions.
      if len(sizes) > 16:
        # extract the size of each input along the concat dimension
        sizes = array_ops.squeeze(
            array_ops.slice(
                array_ops.stack(sizes, axis=1), [non_neg_concat_dim, 0],
                [1, -1]))
        out_grads = array_ops.split(grad, sizes, non_neg_concat_dim)
      else:
        offset = gen_array_ops.concat_offset(non_neg_concat_dim, sizes)
        for (begin, size) in zip(offset, sizes):
          out_grads.append(array_ops.slice(grad, begin, size))
  elif isinstance(grad, ops.IndexedSlices):
    # Using mod here for convenience since concat_dim is already verified
    # in concat implementation to be within the allowed [-rank, rank) range.
    non_neg_concat_dim = concat_dim % array_ops.rank(input_values[0])
    concat_dim_static = tensor_util.constant_value(concat_dim)
    if concat_dim_static is None:
      raise ValueError("Can only compute IndexedSlices gradient with "
                       "statically-known concat_dim")
    if concat_dim_static < 0:
      rank = tensor_util.constant_value(array_ops.rank(input_values[0]))
      if rank is None:
        raise ValueError("Can only compute IndexedSlices gradient with "
                         "negative concat_dim when first value rank is "
                         "statically-known.")
      concat_dim_static %= rank
    # Get the inputs' tensor shapes
    sizes = [array_ops.shape(x) for x in input_values]
    if concat_dim_static > 0:
      # IndexedSlices, non_neg_concat_dim > 0. Each input gets IndexedSlices
      # gradients with all the indices, but with grad.values sliced accordingly.
      # This is like the Tensor case, except shape(grad.values)[0] is not equal
      # to shape(sizes[i])[0], since only a subset of the dim-0 values are
      # stored.
      mask, begin = _CreateDenseMaskAndBegin(sizes, non_neg_concat_dim)
      for size in sizes:
        new_values = array_ops.slice(
            grad.values, begin,
            array_ops.concat([[-1], array_ops.slice(size, [1], [-1])], 0))
        out_grads.append(ops.IndexedSlices(new_values, grad.indices, size))
        # Lint complains begin = begin + ...
        begin = math_ops.add(begin, size * mask)
    else:
      # IndexedSlices, concat_dim == 0. Each input gets IndexedSlices gradients
      # only for the relevant indices.
      start = constant_op.constant(0, dtype=grad.indices.dtype)
      for size in sizes:
        size_concat_dim = array_ops.gather(size, non_neg_concat_dim)
        if size_concat_dim.dtype != grad.indices.dtype:
          size_concat_dim = math_ops.cast(
              size_concat_dim, dtype=grad.indices.dtype)
        end = start + size_concat_dim
        # Compute the 1-D Tensor of indices relevant for this input.
        indices_to_select = array_ops.squeeze(
            array_ops.where(
                math_ops.logical_and(grad.indices >= start,
                                     grad.indices < end)),
            axis=[1])
        new_indices = array_ops.gather(grad.indices, indices_to_select) - start
        new_values = array_ops.gather(grad.values, indices_to_select)
        out_grads.append(ops.IndexedSlices(new_values, new_indices, size))
        start = end
  else:
    raise TypeError("Expected Tensor or IndexedSlices, got %s" % type(grad))

  return (out_grads + [None]
          if end_value_index <= dim_index else [None] + out_grads)
208
209
@ops.RegisterGradient("Concat")
def _ConcatGrad(op, grad):
  """Gradient for the Concat op.

  Delegates to `_ConcatGradHelper` with the concat dimension at input 0 and
  the concatenated values in the remaining inputs.
  """
  num_inputs = len(op.inputs)
  return _ConcatGradHelper(
      op, grad, dim_index=0, start_value_index=1, end_value_index=num_inputs)
218
219
@ops.RegisterGradient("ConcatV2")
def _ConcatGradV2(op, grad):
  """Gradient for the ConcatV2 op.

  Delegates to `_ConcatGradHelper` with the axis as the last input and the
  concatenated values in all inputs before it.
  """
  return _ConcatGradHelper(
      op,
      grad,
      dim_index=-1,
      start_value_index=0,
      end_value_index=-1)
224
225
# ConcatOffset is registered as an op type with no gradient.
ops.NotDifferentiable("ConcatOffset")
227
228
@ops.RegisterGradient("Slice")
def _SliceGrad(op, grad):
  """Gradient for Slice op.

  Args:
    op: The `Slice` operation being differentiated.
    grad: Gradient with respect to the op's output.

  Returns:
    A 3-tuple: `grad` zero-padded back to the shape of the sliced input,
    followed by `None` for the op's second and third inputs.
  """
  # Create an Nx2 padding where the first column represents how many
  # zeros are to be prepended for each dimension, and the second
  # column indicates how many zeros are appended.
  #
  # The number of zeros to append is the shape of the input
  # elementwise-subtracted by both the begin vector and sizes vector.
  #
  # Some more reshaping is needed to assemble this tensor with the
  # right dimensions.
  input_vec = op.inputs[0]
  begin_vec = op.inputs[1]
  input_rank = array_ops.rank(input_vec)
  # Compute every shape in the dtype of `begin` so the subtraction below never
  # mixes the default int32 shape dtype with an int64 `begin` vector.
  index_dtype = begin_vec.dtype
  slice_size = array_ops.shape(op.outputs[0], out_type=index_dtype)

  shape = array_ops.stack([input_rank, 1])
  before_pad = array_ops.reshape(begin_vec, shape)
  input_shape = array_ops.shape(input_vec, out_type=index_dtype)
  after_pad = array_ops.reshape(input_shape - slice_size - begin_vec, shape)
  paddings = array_ops.concat([before_pad, after_pad], 1)
  return array_ops.pad(grad, paddings), None, None
252
253
@ops.RegisterGradient("StridedSlice")
def _StridedSliceGrad(op, grad):
  """Gradient for StridedSlice op.

  Args:
    op: The `StridedSlice` operation being differentiated.
    grad: Gradient with respect to the op's output.

  Returns:
    A 4-tuple: the gradient with respect to the sliced input, followed by
    `None` for the `begin`, `end` and `strides` inputs.
  """
  begin, end, strides = op.inputs[1], op.inputs[2], op.inputs[3]
  # StridedSliceGrad requires the shape, `begin`, `end` and `strides` to share
  # one dtype. The three index inputs are required to have the same dtype, so
  # picking `begin.dtype` for the shape is an arbitrary but valid choice.
  input_shape = array_ops.shape(op.inputs[0], out_type=begin.dtype)

  # Forward the slicing masks of the original op unchanged.
  mask_names = ("begin_mask", "end_mask", "ellipsis_mask", "new_axis_mask",
                "shrink_axis_mask")
  masks = {name: op.get_attr(name) for name in mask_names}
  input_grad = array_ops.strided_slice_grad(
      input_shape, begin, end, strides, grad, **masks)
  return input_grad, None, None, None
278
279
@ops.RegisterGradient("StridedSliceGrad")
def _StridedSliceGradGrad(op, grad):
  """Gradient for StridedSliceGrad op.

  Args:
    op: The `StridedSliceGrad` operation being differentiated.
    grad: Gradient with respect to the op's output.

  Returns:
    A 5-tuple: `None` for the first four inputs, then `grad` strided-sliced
    with the same `begin`, `end`, `strides` and masks as the op.
  """
  begin, end, strides = op.inputs[1], op.inputs[2], op.inputs[3]

  # Forward the slicing masks of the original op unchanged.
  mask_names = ("begin_mask", "end_mask", "ellipsis_mask", "new_axis_mask",
                "shrink_axis_mask")
  masks = {name: op.get_attr(name) for name in mask_names}
  sliced_grad = array_ops.strided_slice(grad, begin, end, strides, **masks)
  return None, None, None, None, sliced_grad
297
298
@ops.RegisterGradient("Split")
def _SplitGrad(op, *grads):
  """Gradient for the Split op: concatenates the output gradients.

  Returns:
    A 2-tuple: `None` for the op's first input (which is reused as the concat
    axis) and the concatenated gradients for the second input.
  """
  split_dim = op.inputs[0]
  value_grad = array_ops.concat(list(grads), split_dim)
  return None, value_grad
302
303
@ops.RegisterGradient("SplitV")
def _SplitVGrad(op, *grads):
  """Gradient for the SplitV op: concatenates the output gradients.

  Returns:
    A list whose first entry is the gradients concatenated along the op's
    third input, followed by `None` for every remaining input of the op.
  """
  split_dim = op.inputs[2]
  value_grad = array_ops.concat(list(grads), split_dim)
  no_grads = [None] * (len(op.inputs) - 1)
  return [value_grad] + no_grads
312
313
# Const is registered as an op type with no gradient.
ops.NotDifferentiable("Const")
315
316
@ops.RegisterGradient("Diag")
def _DiagGrad(_, grad):
  """Gradient for the Diag op: the diagonal part of `grad`."""
  input_grad = array_ops.diag_part(grad)
  return input_grad
320
321
@ops.RegisterGradient("DiagPart")
def _DiagPartGrad(_, grad):
  """Gradient for the DiagPart op: `grad` placed back on a diagonal."""
  input_grad = array_ops.diag(grad)
  return input_grad
325
326
@ops.RegisterGradient("MatrixDiag")
def _MatrixDiagGrad(_, grad):
  """Gradient for the MatrixDiag op: the matrix diagonal part of `grad`."""
  input_grad = array_ops.matrix_diag_part(grad)
  return input_grad
330
331
@ops.RegisterGradient("MatrixDiagPart")
def _MatrixDiagPartGrad(op, grad):
  """Gradient for the MatrixDiagPart op.

  Args:
    op: The `MatrixDiagPart` operation being differentiated.
    grad: Gradient with respect to the op's output.

  Returns:
    `matrix_diag(grad)` when the last two input dimensions are statically
    known and equal; otherwise a zero tensor shaped like the input with its
    diagonal set to `grad`.
  """
  matrix = op.inputs[0]
  trailing_dims = matrix.get_shape()[-2:]
  statically_square = (
      trailing_dims.is_fully_defined() and
      trailing_dims[0] == trailing_dims[1])
  if not statically_square:
    return array_ops.matrix_set_diag(array_ops.zeros_like(matrix), grad)
  return array_ops.matrix_diag(grad)
339
340
@ops.RegisterGradient("MatrixSetDiag")
def _MatrixSetDiagGrad(op, grad):
  """Gradient for MatrixSetDiag.

  Args:
    op: The `MatrixSetDiag` operation being differentiated.
    grad: Gradient with respect to the op's output.

  Returns:
    A 2-tuple: `grad` with its diagonal zeroed (gradient for the first input)
    and the diagonal part of `grad` (gradient for the second input).
  """
  static_input_shape = op.inputs[0].get_shape().merge_with(grad.get_shape())
  static_diag_shape = op.inputs[1].get_shape()
  static_batch = static_input_shape[:-2].merge_with(static_diag_shape[:-1])
  static_matrix = static_input_shape[-2:]

  shapes_known = (
      static_batch.is_fully_defined() and static_matrix.is_fully_defined())
  if shapes_known:
    # Everything is known statically: build the diagonal shape in Python.
    zeros_shape = static_batch.as_list() + [min(static_matrix.as_list())]
  else:
    # Otherwise derive the diagonal shape from the runtime shape of `grad`.
    with ops.colocate_with(grad):
      runtime_shape = array_ops.shape(grad)
      runtime_rank = array_ops.rank(grad)
      batch_dims = array_ops.slice(runtime_shape, [0], [runtime_rank - 2])
      matrix_dims = array_ops.slice(runtime_shape, [runtime_rank - 2], [2])
      shortest_side = math_ops.reduce_min(matrix_dims)
      zeros_shape = array_ops.concat([batch_dims, [shortest_side]], 0)

  zero_diag = array_ops.zeros(zeros_shape, dtype=grad.dtype)
  grad_input = array_ops.matrix_set_diag(grad, zero_diag)
  grad_diag = array_ops.matrix_diag_part(grad)
  return (grad_input, grad_diag)
363
364
@ops.RegisterGradient("MatrixBandPart")
def _MatrixBandPartGrad(op, grad):
  """Gradient for the MatrixBandPart op.

  Returns:
    A 3-tuple: the same band of `grad` that the op kept, followed by `None`
    for the `num_lower` and `num_upper` inputs.
  """
  num_lower, num_upper = op.inputs[1], op.inputs[2]
  banded_grad = array_ops.matrix_band_part(grad, num_lower, num_upper)
  return (banded_grad, None, None)
370
371
# Edit Distance has no gradient (but can be used to eval seq2seq or CTC), so
# it is registered as not differentiable.
ops.NotDifferentiable("EditDistance")
374
375
@ops.RegisterGradient("Fill")
def _FillGrad(_, grad):
  """Gradient for the Fill op.

  Returns:
    A 2-tuple: `None` for the op's first input and the sum of all elements of
    `grad` for the second input.
  """
  value_grad = math_ops.reduce_sum(grad)
  return None, value_grad
379
380
# ZerosLike and OnesLike are registered as op types with no gradient.
ops.NotDifferentiable("ZerosLike")
ops.NotDifferentiable("OnesLike")
383
384
@ops.RegisterGradient("PreventGradient")
def _PreventGradientGrad(op, _):
  """Always fails: the gradient of PreventGradient is explicitly disabled.

  Raises:
    LookupError: always, carrying the op's `message` attribute as the reason.
  """
  reason = op.get_attr("message")
  raise LookupError("Gradient explicitly disabled. Reason: %s" % reason)
389
390
@ops.RegisterGradient("Gather")
def _GatherGrad(op, grad):
  """Gradient for Gather op.

  Args:
    op: The `Gather` operation being differentiated.
    grad: Gradient with respect to the op's output.

  Returns:
    A 2-element list: an `IndexedSlices` gradient for `params` and `None` for
    the indices input.
  """
  params = op.inputs[0]
  # params can be large, so the shape calculation is colocated with it.
  #
  # params can be very large for sparse models, and array_ops.shape raises an
  # exception on the Windows platform when any dimension is larger than int32.
  # params_shape is not used in optimizer apply_sparse gradients, so it is
  # fine to convert it back to int32 regardless of truncation.
  with ops.colocate_with(params):
    shape_int64 = array_ops.shape(params, out_type=dtypes.int64)
    params_shape = math_ops.cast(shape_int64, dtypes.int32)

  # Build appropriately shaped IndexedSlices: one row of values per index.
  gather_indices = op.inputs[1]
  num_indices = array_ops.expand_dims(array_ops.size(gather_indices), 0)
  row_shape = params_shape[1:]
  flat_values_shape = array_ops.concat([num_indices, row_shape], 0)
  flat_values = array_ops.reshape(grad, flat_values_shape)
  flat_indices = array_ops.reshape(gather_indices, num_indices)
  params_grad = ops.IndexedSlices(flat_values, flat_indices, params_shape)
  return [params_grad, None]
412
413
@ops.RegisterGradient("GatherV2")
def _GatherV2Grad(op, grad):
  """Gradient for GatherV2 op.

  Args:
    op: The `GatherV2` operation being differentiated.
    grad: Gradient with respect to the op's output.

  Returns:
    A 3-element list: the gradient for `params` (an `IndexedSlices` when the
    axis is statically 0, a dense tensor otherwise) and `None` for the
    indices and axis inputs.
  """
  params = op.inputs[0]
  # params can be large, so the shape calculation is colocated with it.
  #
  # params can be very large for sparse models, and array_ops.shape raises an
  # exception on the Windows platform when any dimension is larger than int32.
  # params_shape is not used in optimizer apply_sparse gradients, so it is
  # fine to convert it back to int32 regardless of truncation.
  with ops.colocate_with(params):
    shape_int64 = array_ops.shape(params, out_type=dtypes.int64)
    params_shape = math_ops.cast(shape_int64, dtypes.int32)

  gather_indices = op.inputs[1]
  num_indices = array_ops.expand_dims(array_ops.size(gather_indices), 0)
  axis = op.inputs[2]
  static_axis = tensor_util.constant_value(axis)

  # For axis 0 gathers, build an appropriately shaped IndexedSlices.
  if static_axis == 0:
    if context.executing_eagerly():
      row_shape = params_shape.cpu()[1:]
    else:
      row_shape = params_shape[1:]
    flat_values_shape = array_ops.concat([num_indices, row_shape], 0)
    flat_values = array_ops.reshape(grad, flat_values_shape)
    flat_indices = array_ops.reshape(gather_indices, num_indices)
    sparse_grad = ops.IndexedSlices(flat_values, flat_indices, params_shape)
    return [sparse_grad, None, None]

  # General case: split the params shape around the gather axis.
  shape_before = params_shape[:axis]
  rank_before = array_ops.size(shape_before)
  shape_after = params_shape[axis:][1:]
  rank_after = array_ops.size(shape_after)

  dims_before = math_ops.range(rank_before)
  dims_after = math_ops.range(rank_before + 1, rank_before + 1 + rank_after)

  flat_values_shape = array_ops.concat(
      [shape_before, num_indices, shape_after], 0)
  flat_values = array_ops.reshape(grad, flat_values_shape)
  flat_indices = array_ops.reshape(gather_indices, num_indices)

  # We need to sum up every slice `values[..., i, ....]` corresponding to
  # `params[..., indices[i], ...]`. Since `unsorted_segment_sum` does not
  # support an axis parameter, the gather dimension is transposed to the
  # front, and `unsorted_segment_sum` then builds a
  # [gather_axis, outer_axes, inner_axes] tensor with all the gradients
  # affecting each index in `gather_axis` summed up.
  to_front = array_ops.concat([[rank_before], dims_before, dims_after], 0)
  values_front = array_ops.transpose(flat_values, to_front)
  axis_size = params_shape[axis]

  summed = math_ops.unsorted_segment_sum(values_front, flat_indices, axis_size)

  # Invert the transpose above by moving dimension 0 back to its original
  # position.
  to_original = array_ops.concat([dims_before + 1, [0], dims_after], 0)
  params_grad = array_ops.transpose(summed, to_original)
  return [params_grad, None, None]
477
478
@ops.RegisterGradient("GatherNd")
def _GatherNdGrad(op, grad):
  """Gradient for the GatherNd op.

  Args:
    op: The `GatherNd` operation being differentiated.
    grad: Gradient with respect to the op's output.

  Returns:
    A 2-element list: the gradient for the first input and `None` for the
    indices. When the indices are statically rank 2 with a last dimension of
    1, the gradient is an `IndexedSlices`; otherwise it is built with
    `scatter_nd`.
  """
  ref, indices = op.inputs[0], op.inputs[1]
  ref_shape = array_ops.shape(ref, out_type=indices.dtype)
  indexes_single_dim = (
      indices.shape.ndims == 2 and indices.shape.dims[-1].value == 1)
  if indexes_single_dim:
    flat_indices = array_ops.squeeze(indices, axis=-1)
    ref_grad = ops.IndexedSlices(grad, flat_indices, ref_shape)
  else:
    ref_grad = array_ops.scatter_nd(indices, grad, ref_shape)
  return [ref_grad, None]
490
491
@ops.RegisterGradient("CheckNumerics")
def _CheckNumericsGrad(op, grad):
  """Gradient for check_numerics op.

  Returns:
    `grad` wrapped in its own `check_numerics`, whose message is a fixed
    prefix followed by the original op's `message` attribute.
  """
  gradient_message = (
      "Not a number (NaN) or infinity (Inf) values detected in gradient. %s" %
      op.get_attr("message"))
  return array_ops.check_numerics(grad, gradient_message)
499
500
@ops.RegisterGradient("PlaceholderWithDefault")
@ops.RegisterGradient("Identity")
def _IdGrad(_, grad):
  """Passes the incoming gradient through unchanged.

  Registered for both the `Identity` and `PlaceholderWithDefault` op types.
  """
  return grad
505
506
@ops.RegisterGradient("RefIdentity")
def _RefIdGrad(_, grad):
  """Passes the incoming gradient through unchanged for `RefIdentity`."""
  return grad
510
511
@ops.RegisterGradient("IdentityN")
def _IdNGrad(_, *grad):
  """Returns the tuple of incoming gradients unchanged for `IdentityN`."""
  return grad
515
516
# StopGradient is registered as an op type with no gradient.
ops.NotDifferentiable("StopGradient")
518
519
@ops.RegisterGradient("Reshape")
def _ReshapeGrad(op, grad):
  """Gradient for the Reshape op.

  Returns:
    A 2-element list: `grad` reshaped to the shape of the op's first input,
    and `None` for the second input.
  """
  original_shape = array_ops.shape(op.inputs[0])
  input_grad = array_ops.reshape(grad, original_shape)
  return [input_grad, None]
523
524
# InvertPermutation is registered as an op type with no gradient.
ops.NotDifferentiable("InvertPermutation")
526
527
def _ReshapeToInput(op, grad):
  """Reshapes the gradient to the shape of the original input.

  Args:
    op: The operation whose first input provides the target shape.
    grad: The gradient tensor to reshape.

  Returns:
    `grad` reshaped to the runtime shape of `op.inputs[0]`.
  """
  original_shape = array_ops.shape(op.inputs[0])
  return array_ops.reshape(grad, original_shape)
531
532
@ops.RegisterGradient("ExpandDims")
def _ExpandDimsGrad(op, grad):
  """Gradient for the ExpandDims op.

  Returns:
    A 2-element list: `grad` reshaped to the shape of the op's first input,
    and `None` for the second input.
  """
  input_grad = _ReshapeToInput(op, grad)
  return [input_grad, None]
536
537
@ops.RegisterGradient("Squeeze")
def _SqueezeGrad(op, grad):
  """Gradient for the Squeeze op: `grad` reshaped to the input's shape."""
  input_grad = _ReshapeToInput(op, grad)
  return input_grad
541
542
@ops.RegisterGradient("Transpose")
def _TransposeGrad(op, grad):
  """Returns unshuffle(grad).

  The gradient is `grad` transposed with the inverse of the op's permutation
  input; the permutation input itself gets `None`.
  """
  perm = op.inputs[1]
  inverse_perm = array_ops.invert_permutation(perm)
  return [array_ops.transpose(grad, inverse_perm), None]
548
549
@ops.RegisterGradient("ConjugateTranspose")
def _ConjugateTransposeGrad(op, grad):
  """Returns conj(unshuffle(grad)).

  The gradient is `grad` transposed with the inverse of the op's permutation
  input and conjugated; the permutation input itself gets `None`.
  """
  perm = op.inputs[1]
  inverse_perm = array_ops.invert_permutation(perm)
  input_grad = array_ops.transpose(grad, inverse_perm, conjugate=True)
  return [input_grad, None]
558
559
# Shape, ShapeN, Rank and Size are registered as op types with no gradient.
ops.NotDifferentiable("Shape")

ops.NotDifferentiable("ShapeN")

ops.NotDifferentiable("Rank")

ops.NotDifferentiable("Size")
567
568
@ops.RegisterGradient("Tile")
def _TileGrad(op, grad):
  """Sum reduces grad along the tiled dimensions.

  Args:
    op: The `Tile` operation being differentiated.
    grad: `Tensor` or `IndexedSlices` gradient with respect to the op's output.

  Returns:
    A 2-element list: the gradient for the tiled input and `None` for the
    multiples input.
  """
  tiled_input = op.inputs[0]
  multiples = op.inputs[1]
  input_shape = array_ops.shape(tiled_input, out_type=multiples.dtype)
  # Interleaving multiples and input_shape gives split_shape. Reshaping grad
  # to split_shape and reducing along all even dimensions (the tiled
  # dimensions) yields a result with shape input_shape.  For example
  #   input_shape = [20, 30, 40]
  #   multiples = [2, 3, 4]
  #   split_shape = [2, 20, 3, 30, 4, 40]
  #   axes = [0, 2, 4]
  interleaved = array_ops.transpose(array_ops.stack([multiples, input_shape]))
  split_shape = array_ops.reshape(interleaved, [-1])
  tiled_axes = math_ops.range(0, array_ops.size(split_shape), 2)
  # For IndexedSlices, sum reduce grad along the first dimension up front.
  if isinstance(grad, ops.IndexedSlices):
    first_dim = math_ops.cast(input_shape[0], grad.indices.dtype)
    wrapped_indices = math_ops.mod(grad.indices, first_dim)
    grad = math_ops.unsorted_segment_sum(grad.values, wrapped_indices,
                                         first_dim)
    split_shape = array_ops.concat([[1], split_shape[1:]], axis=0)
  split_grad = array_ops.reshape(grad, split_shape)
  input_grad = math_ops.reduce_sum(split_grad, tiled_axes)
  # Fix shape inference
  if not context.executing_eagerly():
    input_grad.set_shape(tiled_input.get_shape())
  return [input_grad, None]
597
598
# BroadcastGradientArgs is registered as an op type with no gradient.
ops.NotDifferentiable("BroadcastGradientArgs")
600
601
def _PadGrad(op, grad):
  """Gradient for Pad.

  Pad introduces values around the original tensor, so the gradient slices
  the original shape back out of `grad`.

  Args:
    op: The `Pad` or `PadV2` operation being differentiated.
    grad: Gradient with respect to the op's output.

  Returns:
    A tuple with the gradient for the padded input followed by `None` for the
    remaining inputs (two `None`s when the op has three inputs, one otherwise).
  """
  padded_input = op.inputs[0]
  paddings = op.inputs[1]  # [Rank(x), 2]
  # The first column of the paddings, as a [Rank(x), 1] slice.
  first_column_size = array_ops.stack([array_ops.rank(padded_input), 1])
  pad_before = array_ops.slice(paddings, [0, 0], first_column_size)
  # Flatten it to 1-D so it can serve as the slice start.
  slice_begin = array_ops.reshape(pad_before, [-1])
  slice_sizes = array_ops.shape(padded_input)
  x_grad = array_ops.slice(grad, slice_begin, slice_sizes)
  no_grads = (None, None) if len(op.inputs) == 3 else (None,)
  return (x_grad,) + no_grads


# The same gradient function serves both the Pad and PadV2 op types.
ops.RegisterGradient("Pad")(_PadGrad)
ops.RegisterGradient("PadV2")(_PadGrad)
623
624
# ReverseSequence is just a permutation.  The gradient permutes back.
@ops.RegisterGradient("ReverseSequence")
def _ReverseSequenceGrad(op, grad):
  """Gradient for the ReverseSequence op.

  Returns:
    A 2-element list: `grad` reversed with the same sequence lengths and the
    op's `batch_dim`/`seq_dim` attributes, and `None` for the lengths input.
  """
  seq_lengths = op.inputs[1]
  batch_axis = op.get_attr("batch_dim")
  seq_axis = op.get_attr("seq_dim")
  input_grad = array_ops.reverse_sequence(
      grad, batch_axis=batch_axis, seq_axis=seq_axis, seq_lengths=seq_lengths)
  return [input_grad, None]
636
637
@ops.RegisterGradient("Reverse")
def _ReverseGrad(op, grad):
  """Gradient for the Reverse op: reverses `grad` with the same dims input."""
  reverse_dims = op.inputs[1]
  input_grad = gen_array_ops.reverse(grad, reverse_dims)
  return input_grad, None
642
643
@ops.RegisterGradient("ReverseV2")
def _ReverseV2Grad(op, grad):
  """Gradient for the ReverseV2 op: reverses `grad` along the same axis."""
  reverse_axis = op.inputs[1]
  input_grad = array_ops.reverse_v2(grad, reverse_axis)
  return input_grad, None
648
649
@ops.RegisterGradient("SpaceToBatch")
def _SpaceToBatchGrad(op, grad):
  """Gradient for the SpaceToBatch op.

  Returns:
    A 2-element list: `grad` passed through `batch_to_space` with the op's
    second input and `block_size` attribute, and `None` for the second input.
  """
  # Its gradient is the opposite op: BatchToSpace.
  block_size = op.get_attr("block_size")
  second_input = op.inputs[1]
  input_grad = array_ops.batch_to_space(
      grad, second_input, block_size=block_size)
  return [input_grad, None]
657
658
@ops.RegisterGradient("SpaceToBatchND")
def _SpaceToBatchNDGrad(op, grad):
  """Gradient for the SpaceToBatchND op.

  Returns:
    A 3-element list: `grad` passed through `batch_to_space_nd` with the op's
    second and third inputs, and `None` for those two inputs.
  """
  # Its gradient is the opposite op: BatchToSpaceND.
  second_input, third_input = op.inputs[1], op.inputs[2]
  input_grad = array_ops.batch_to_space_nd(grad, second_input, third_input)
  return [input_grad, None, None]
665
666
@ops.RegisterGradient("BatchToSpace")
def _BatchToSpaceGrad(op, grad):
  """Gradient for the BatchToSpace op.

  Returns:
    A 2-element list: `grad` passed through `space_to_batch` with the op's
    second input and `block_size` attribute, and `None` for the second input.
  """
  # Its gradient is the opposite op: SpaceToBatch.
  block_size = op.get_attr("block_size")
  second_input = op.inputs[1]
  input_grad = array_ops.space_to_batch(
      grad, second_input, block_size=block_size)
  return [input_grad, None]
674
675
@ops.RegisterGradient("BatchToSpaceND")
def _BatchToSpaceNDGrad(op, grad):
  """Gradient for the BatchToSpaceND op.

  Returns:
    A 3-element list: `grad` passed through `space_to_batch_nd` with the op's
    second and third inputs, and `None` for those two inputs.
  """
  # Its gradient is the opposite op: SpaceToBatchND.
  second_input, third_input = op.inputs[1], op.inputs[2]
  input_grad = array_ops.space_to_batch_nd(grad, second_input, third_input)
  return [input_grad, None, None]
682
683
@ops.RegisterGradient("SpaceToDepth")
def _SpaceToDepthGrad(op, grad):
  """Gradient for the SpaceToDepth op.

  Returns:
    `grad` passed through `depth_to_space` with the op's `block_size` and
    `data_format` attributes.

  Raises:
    ValueError: if the op's `data_format` attribute is "NCHW_VECT_C".
  """
  # Its gradient is the opposite op: DepthToSpace.
  block_size = op.get_attr("block_size")
  data_format = op.get_attr("data_format")
  unsupported_format = data_format == "NCHW_VECT_C"
  if unsupported_format:
    raise ValueError("Cannot compute SpaceToDepth gradient with NCHW_VECT_C. "
                     "NCHW_VECT_C requires qint8 data type.")
  input_grad = array_ops.depth_to_space(
      grad, block_size, data_format=data_format)
  return input_grad
693
694
@ops.RegisterGradient("DepthToSpace")
def _DepthToSpaceGrad(op, grad):
  """Gradient for the DepthToSpace op.

  Returns:
    `grad` passed through `space_to_depth` with the op's `block_size` and
    `data_format` attributes.

  Raises:
    ValueError: if the op's `data_format` attribute is "NCHW_VECT_C".
  """
  # Its gradient is the opposite op: SpaceToDepth.
  block_size = op.get_attr("block_size")
  data_format = op.get_attr("data_format")
  unsupported_format = data_format == "NCHW_VECT_C"
  if unsupported_format:
    raise ValueError("Cannot compute DepthToSpace gradient with NCHW_VECT_C. "
                     "NCHW_VECT_C requires qint8 data type.")
  input_grad = array_ops.space_to_depth(
      grad, block_size, data_format=data_format)
  return input_grad
704
705
# OneHot is registered as an op type with no gradient.
ops.NotDifferentiable("OneHot")
707
708
@ops.RegisterGradient("MirrorPad")
def _MirrorPadGrad(op, grad):
  """Gradient for the MirrorPad op.

  Returns:
    A 2-element list: `mirror_pad_grad` applied to `grad` with the op's second
    input and `mode` attribute, and `None` for the second input.
  """
  pad_mode = op.get_attr("mode")
  second_input = op.inputs[1]
  input_grad = gen_array_ops.mirror_pad_grad(grad, second_input, mode=pad_mode)
  return [input_grad, None]
713
714
@ops.RegisterGradient("MirrorPadGrad")
def _MirrorPadGradGrad(op, grad):
  """Gradient for the MirrorPadGrad op.

  Returns:
    A 2-element list: `mirror_pad` applied to `grad` with the op's second
    input and `mode` attribute, and `None` for the second input.
  """
  pad_mode = op.get_attr("mode")
  second_input = op.inputs[1]
  input_grad = gen_array_ops.mirror_pad(grad, second_input, mode=pad_mode)
  return [input_grad, None]
719
720
@ops.RegisterGradient("QuantizeAndDequantize")
def _QuantizeAndDequantizeGrad(_, grad):
  """Passes the incoming gradient through unchanged."""
  return grad
724
725
@ops.RegisterGradient("QuantizeAndDequantizeV2")
def _QuantizeAndDequantizeV2Grad(_, grad):
  """Passes `grad` to the first input; the other two inputs get `None`."""
  return [grad, None, None]
729
730
@ops.RegisterGradient("QuantizeAndDequantizeV3")
def _QuantizeAndDequantizeV3Grad(_, grad):
  """Passes `grad` to the first input; the other three inputs get `None`."""
  # Only propagate the gradient for the unquantized input.
  return [grad, None, None, None]
735
736
@ops.RegisterGradient("ExtractImagePatches")
def _ExtractImagePatchesGrad(op, grad):
  """Gradient for the ExtractImagePatches op.

  Builds a sparse matrix of ones that links every input position to the
  output patch positions it was copied to, then multiplies that matrix with
  the flattened gradient to sum the contributions for each input position.

  Args:
    op: The `ExtractImagePatches` operation being differentiated.
    grad: Gradient with respect to the op's output.

  Returns:
    A single-element list holding the gradient with respect to the input.
  """
  images = op.inputs[0]
  # The spatial sizes come from the static shape and are used as Python
  # values; batch and channel sizes are read from the runtime shape.
  _, rows_in, cols_in, _ = [dim.value for dim in images.shape.dims]
  runtime_shape = array_ops.shape(images)
  batch_size = runtime_shape[0]
  channels = runtime_shape[3]

  # Create indices matrix for input tensor.
  # Note that 0 is preserved for padding location,
  # so indices for input start from 1 to 1 + rows_in * cols_in.
  num_input_ids = 1 + rows_in * cols_in
  input_ids = array_ops.reshape(
      math_ops.range(1, num_input_ids, dtype=dtypes.int64),
      (1, rows_in, cols_in, 1))
  ksizes = op.get_attr("ksizes")
  patched_input_ids = gen_array_ops.extract_image_patches(
      input_ids, ksizes, op.get_attr("strides"), op.get_attr("rates"),
      op.get_attr("padding"))

  # Create indices matrix for output tensor.
  _, rows_out, cols_out, _ = [dim.value for dim in op.outputs[0].shape.dims]
  _, ksize_r, ksize_c, _ = ksizes
  # Indices for output start from 0.
  num_output_ids = rows_out * cols_out * ksize_r * ksize_c
  output_ids = array_ops.reshape(
      math_ops.range(num_output_ids, dtype=dtypes.int64),
      (1, rows_out, cols_out, ksize_r * ksize_c))

  # Construct mapping table for indices: (input -> output).
  id_pairs = array_ops.concat([
      array_ops.expand_dims(patched_input_ids, axis=-1),
      array_ops.expand_dims(output_ids, axis=-1)
  ], axis=-1)
  id_map = array_ops.reshape(id_pairs, (-1, 2))

  full_shape = (num_input_ids, num_output_ids)
  full_matrix = sparse_tensor.SparseTensor(
      id_map, array_ops.ones([num_output_ids], dtype=grad.dtype), full_shape)
  # Remove all padding locations [0, :].
  mapping = sparse_ops.sparse_slice(
      full_matrix, (1, 0), (num_input_ids - 1, num_output_ids))

  # Move batch next to channels so both can be flattened into one column axis.
  grad_by_patch = array_ops.reshape(
      grad, (batch_size, rows_out, cols_out, ksize_r, ksize_c, channels))
  grad_reordered = array_ops.transpose(grad_by_patch, (1, 2, 3, 4, 0, 5))
  grad_flat = array_ops.reshape(grad_reordered, (-1, batch_size * channels))

  summed = sparse_ops.sparse_tensor_dense_matmul(mapping, grad_flat)

  grad_out = array_ops.reshape(
      summed, (rows_in, cols_in, batch_size, channels))
  grad_out = array_ops.transpose(grad_out, (2, 0, 1, 3))

  return [grad_out]
798
799
@ops.RegisterGradient("ExtractVolumePatches")
def _ExtractVolumePatchesGrad(op, grad):
  """Gradient for the ExtractVolumePatches op.

  Builds a sparse matrix of ones that links every input position to the
  output patch positions it was copied to, then multiplies that matrix with
  the flattened gradient to sum the contributions for each input position.

  Args:
    op: The `ExtractVolumePatches` operation being differentiated.
    grad: Gradient with respect to the op's output.

  Returns:
    A single-element list holding the gradient with respect to the input.
  """
  volumes = op.inputs[0]
  # The spatial sizes come from the static shape and are used as Python
  # values; batch and channel sizes are read from the runtime shape.
  _, planes_in, rows_in, cols_in, _ = [
      dim.value for dim in volumes.shape.dims
  ]
  runtime_shape = array_ops.shape(volumes)
  batch_size = runtime_shape[0]
  channels = runtime_shape[4]

  # Create indices matrix for input tensor.
  # Note that 0 is preserved for padding location,
  # so indices for input start from 1 to 1 + planes_in * rows_in * cols_in.
  num_input_ids = 1 + planes_in * rows_in * cols_in
  input_ids = array_ops.reshape(
      math_ops.range(1, num_input_ids, dtype=dtypes.int64),
      (1, planes_in, rows_in, cols_in, 1))
  ksizes = op.get_attr("ksizes")
  patched_input_ids = gen_array_ops.extract_volume_patches(
      input_ids, ksizes, op.get_attr("strides"), op.get_attr("padding"))

  # Create indices matrix for output tensor.
  _, planes_out, rows_out, cols_out, _ = [
      dim.value for dim in op.outputs[0].shape.dims
  ]
  _, ksize_p, ksize_r, ksize_c, _ = ksizes
  # Indices for output start from 0.
  num_positions = planes_out * rows_out * cols_out
  num_output_ids = num_positions * ksize_p * ksize_r * ksize_c
  output_ids = array_ops.reshape(
      math_ops.range(num_output_ids, dtype=dtypes.int64),
      (1, planes_out, rows_out, cols_out, ksize_p * ksize_r * ksize_c))

  # Construct mapping table for indices: (input -> output).
  id_pairs = array_ops.concat([
      array_ops.expand_dims(patched_input_ids, axis=-1),
      array_ops.expand_dims(output_ids, axis=-1)
  ], axis=-1)
  id_map = array_ops.reshape(id_pairs, (-1, 2))

  full_shape = (num_input_ids, num_output_ids)
  full_matrix = sparse_tensor.SparseTensor(
      id_map, array_ops.ones([num_output_ids], dtype=grad.dtype), full_shape)
  # Remove all padding locations [0, :].
  mapping = sparse_ops.sparse_slice(
      full_matrix, (1, 0), (num_input_ids - 1, num_output_ids))

  # Move batch next to channels so both can be flattened into one column axis.
  grad_by_patch = array_ops.reshape(
      grad, (batch_size, planes_out, rows_out, cols_out, ksize_p, ksize_r,
             ksize_c, channels))
  grad_reordered = array_ops.transpose(grad_by_patch,
                                       (1, 2, 3, 4, 5, 6, 0, 7))
  grad_flat = array_ops.reshape(grad_reordered, (-1, batch_size * channels))

  summed = sparse_ops.sparse_tensor_dense_matmul(mapping, grad_flat)

  grad_out = array_ops.reshape(
      summed, (planes_in, rows_in, cols_in, batch_size, channels))
  grad_out = array_ops.transpose(grad_out, (3, 0, 1, 2, 4))

  return [grad_out]
860
861
@ops.RegisterGradient("ScatterNd")
def _ScatterNdGrad(op, grad):
  """Gradient for the ScatterNd op.

  Returns:
    A 3-element list: `None` for the indices, `grad` gathered at the indices
    for the second input, and `None` for the third input.
  """
  scatter_indices = op.inputs[0]
  gathered_grad = array_ops.gather_nd(grad, scatter_indices)
  return [None, gathered_grad, None]
867
868
@ops.RegisterGradient("TensorScatterUpdate")
def _TensorScatterUpdateGrad(op, grad):
  """Gradient for the TensorScatterUpdate op.

  Returns:
    A 3-element list: `grad` with zeros scattered in at the indices (gradient
    for the first input), `None` for the indices, and `grad` gathered at the
    indices (gradient for the third input).
  """
  scatter_indices = op.inputs[1]
  updates_grad = array_ops.gather_nd(grad, scatter_indices)
  grad_copy = array_ops.identity(grad)
  zero_updates = array_ops.zeros_like(op.inputs[2], dtype=grad.dtype)
  tensor_grad = array_ops.tensor_scatter_update(grad_copy, scatter_indices,
                                                zero_updates)
  return [tensor_grad, None, updates_grad]
877
878
@ops.RegisterGradient("TensorScatterAdd")
def _TensorScatterAddGrad(op, grad):
  """Gradient for the TensorScatterAdd op.

  Returns:
    A 3-element list: an identity of `grad` for the first input, `None` for
    the indices, and `grad` gathered at the indices for the third input.
  """
  scatter_indices = op.inputs[1]
  gathered_grad = array_ops.gather_nd(grad, scatter_indices)
  passthrough_grad = array_ops.identity(grad)
  return [passthrough_grad, None, gathered_grad]
885
886
@ops.RegisterGradient("TensorScatterSub")
def _TensorScatterSubGrad(op, grad):
  """Gradient for the TensorScatterSub op.

  Returns:
    A 3-element list: an identity of `grad` for the first input, `None` for
    the indices, and the negated `grad` gathered at the indices for the third
    input.
  """
  scatter_indices = op.inputs[1]
  gathered_grad = array_ops.gather_nd(grad, scatter_indices)
  passthrough_grad = array_ops.identity(grad)
  negated_updates_grad = -gathered_grad
  return [passthrough_grad, None, negated_updates_grad]
893
894
@ops.RegisterGradient("ScatterNdNonAliasingAdd")
def _ScatterNdNonAliasingAddGrad(op, grad):
  """Gradient for the ScatterNdNonAliasingAdd op.

  Returns:
    A 3-element list: `grad` itself for the first input, `None` for the
    indices, and `grad` gathered at the indices for the third input.
  """
  scatter_indices = op.inputs[1]
  gathered_grad = array_ops.gather_nd(grad, scatter_indices)
  return [grad, None, gathered_grad]
900
901
@ops.RegisterGradient("BroadcastTo")
def _BroadcastToGrad(op, grad):
  """Gradient for the BroadcastTo op.

  Sums `grad` over the axes reported by `broadcast_gradient_args` and reshapes
  the result to the shape of the original input.

  Returns:
    A 2-element list: the gradient for the broadcast input and `None` for the
    shape input.
  """
  source, target_shape = op.inputs[0], op.inputs[1]
  source_shape = array_ops.shape(source)
  _, summed_axes = gen_array_ops.broadcast_gradient_args(
      target_shape, source_shape)
  reduced_grad = math_ops.reduce_sum(grad, axis=summed_axes, keepdims=True)
  input_grad = array_ops.reshape(reduced_grad, source_shape)
  return [input_grad, None]
914