1# Copyright 2015 The TensorFlow Authors. All Rights Reserved.
2#
3# Licensed under the Apache License, Version 2.0 (the "License");
4# you may not use this file except in compliance with the License.
5# You may obtain a copy of the License at
6#
7#     http://www.apache.org/licenses/LICENSE-2.0
8#
9# Unless required by applicable law or agreed to in writing, software
10# distributed under the License is distributed on an "AS IS" BASIS,
11# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
12# See the License for the specific language governing permissions and
13# limitations under the License.
14# ==============================================================================
15"""Gradients for operators defined in nn_ops.py."""
16
17from __future__ import absolute_import
18from __future__ import division
19from __future__ import print_function
20
21from tensorflow.python.eager import backprop
22from tensorflow.python.eager import context
23from tensorflow.python.framework import dtypes
24from tensorflow.python.framework import ops
25from tensorflow.python.framework import tensor_util
26from tensorflow.python.ops import array_ops
27from tensorflow.python.ops import gen_nn_ops
28from tensorflow.python.ops import math_ops
29from tensorflow.python.ops import nn_ops
30
31
@ops.RegisterGradient("Conv2DBackpropInput")
def _Conv2DBackpropInputGrad(op, grad):
  """Computes the gradients of the Conv2DBackpropInput (deconvolution) op.

  Args:
    op: the Conv2DBackpropInput op being differentiated.
    grad: the tensor holding the gradient w.r.t. the op's output.

  Returns:
    A three-element list aligned with the op's inputs: `None` for
    `op.inputs[0]`, the result of `conv2d_backprop_filter` for `op.inputs[1]`
    and the result of `conv2d` for `op.inputs[2]`.
  """
  # Both gradient ops are configured with the attributes of the op being
  # differentiated.
  conv_kwargs = dict(
      dilations=op.get_attr("dilations"),
      strides=op.get_attr("strides"),
      padding=op.get_attr("padding"),
      use_cudnn_on_gpu=op.get_attr("use_cudnn_on_gpu"),
      data_format=op.get_attr("data_format").decode())
  filter_shape = array_ops.shape(op.inputs[1])
  filter_grad = nn_ops.conv2d_backprop_filter(
      grad, filter_shape, op.inputs[2], **conv_kwargs)
  out_backprop_grad = nn_ops.conv2d(grad, op.inputs[1], **conv_kwargs)
  return [None, filter_grad, out_backprop_grad]
63
64
@ops.RegisterGradient("Conv2DBackpropFilter")
def _Conv2DBackpropFilterGrad(op, grad):
  """Computes the gradients of the Conv2DBackpropFilter op.

  Args:
    op: the Conv2DBackpropFilter op being differentiated.
    grad: the tensor holding the gradient w.r.t. the op's output.

  Returns:
    A three-element list aligned with the op's inputs: the result of
    `conv2d_backprop_input` for `op.inputs[0]`, `None` for `op.inputs[1]`
    and the result of `conv2d` for `op.inputs[2]`.
  """
  # Both gradient ops are configured with the attributes of the op being
  # differentiated.
  conv_kwargs = dict(
      dilations=op.get_attr("dilations"),
      strides=op.get_attr("strides"),
      padding=op.get_attr("padding"),
      use_cudnn_on_gpu=op.get_attr("use_cudnn_on_gpu"),
      data_format=op.get_attr("data_format").decode())
  input_shape = array_ops.shape(op.inputs[0])
  input_grad = nn_ops.conv2d_backprop_input(
      input_shape, grad, op.inputs[2], **conv_kwargs)
  out_backprop_grad = nn_ops.conv2d(op.inputs[0], grad, **conv_kwargs)
  return [input_grad, None, out_backprop_grad]
86
87
@ops.RegisterGradient("DepthwiseConv2dNativeBackpropInput")
def _DepthwiseConv2dNativeBackpropInputGrad(op, grad):
  """Computes the gradients of the DepthwiseConv2dNativeBackpropInput op.

  Args:
    op: the DepthwiseConv2dNativeBackpropInput op being differentiated.
    grad: the tensor holding the gradient w.r.t. the op's output.

  Returns:
    A three-element list aligned with the op's inputs: `None` for
    `op.inputs[0]`, the result of `depthwise_conv2d_native_backprop_filter`
    for `op.inputs[1]` and the result of `depthwise_conv2d_native` for
    `op.inputs[2]`.
  """
  # Both gradient ops are configured with the attributes of the op being
  # differentiated.
  conv_kwargs = dict(
      dilations=op.get_attr("dilations"),
      strides=op.get_attr("strides"),
      padding=op.get_attr("padding"),
      data_format=op.get_attr("data_format"))
  filter_shape = array_ops.shape(op.inputs[1])
  filter_grad = nn_ops.depthwise_conv2d_native_backprop_filter(
      grad, filter_shape, op.inputs[2], **conv_kwargs)
  out_backprop_grad = nn_ops.depthwise_conv2d_native(
      grad, op.inputs[1], **conv_kwargs)
  return [None, filter_grad, out_backprop_grad]
117
118
@ops.RegisterGradient("DepthwiseConv2dNativeBackpropFilter")
def _DepthwiseConv2dNativeBackpropFilterGrad(op, grad):
  """Computes the gradients of the DepthwiseConv2dNativeBackpropFilter op.

  Args:
    op: the DepthwiseConv2dNativeBackpropFilter op being differentiated.
    grad: the tensor holding the gradient w.r.t. the op's output.

  Returns:
    A three-element list aligned with the op's inputs: the result of
    `depthwise_conv2d_native_backprop_input` for `op.inputs[0]`, `None` for
    `op.inputs[1]` and the result of `depthwise_conv2d_native` for
    `op.inputs[2]`.
  """
  # Both gradient ops are configured with the attributes of the op being
  # differentiated.
  conv_kwargs = dict(
      dilations=op.get_attr("dilations"),
      strides=op.get_attr("strides"),
      padding=op.get_attr("padding"),
      data_format=op.get_attr("data_format"))
  input_shape = array_ops.shape(op.inputs[0])
  input_grad = nn_ops.depthwise_conv2d_native_backprop_input(
      input_shape, grad, op.inputs[2], **conv_kwargs)
  out_backprop_grad = nn_ops.depthwise_conv2d_native(
      op.inputs[0], grad, **conv_kwargs)
  return [input_grad, None, out_backprop_grad]
138
139
@ops.RegisterGradient("Conv3D")
def _Conv3DGrad(op, grad):
  """Gradient function for Conv3D.

  Args:
    op: the Conv3D op being differentiated.
    grad: the tensor holding the gradient w.r.t. the op's output.

  Returns:
    A two-element list: the result of `conv3d_backprop_input_v2` for
    `op.inputs[0]` and the result of `conv3d_backprop_filter_v2` for
    `op.inputs[1]`.
  """
  conv_kwargs = dict(
      dilations=op.get_attr("dilations"),
      strides=op.get_attr("strides"),
      padding=op.get_attr("padding"),
      # The attribute is read as bytes and decoded before being passed on.
      data_format=op.get_attr("data_format").decode())
  first_in, second_in = op.inputs[0], op.inputs[1]
  input_grad = nn_ops.conv3d_backprop_input_v2(
      array_ops.shape(first_in), second_in, grad, **conv_kwargs)
  filter_grad = nn_ops.conv3d_backprop_filter_v2(
      first_in, array_ops.shape(second_in), grad, **conv_kwargs)
  return [input_grad, filter_grad]
161
162
@ops.RegisterGradient("Conv3DBackpropInputV2")
def _Conv3DBackpropInputGrad(op, grad):
  """Gradient function for Conv3DBackpropInputV2.

  Args:
    op: the Conv3DBackpropInputV2 op being differentiated.
    grad: the tensor holding the gradient w.r.t. the op's output.

  Returns:
    A three-element list aligned with the op's inputs: `None` for
    `op.inputs[0]`, the result of `conv3d_backprop_filter_v2` for
    `op.inputs[1]` and the result of `conv3d` for `op.inputs[2]`.
  """
  conv_kwargs = dict(
      dilations=op.get_attr("dilations"),
      strides=op.get_attr("strides"),
      padding=op.get_attr("padding"),
      # The attribute is read as bytes and decoded before being passed on.
      data_format=op.get_attr("data_format").decode())
  filter_shape = array_ops.shape(op.inputs[1])
  filter_grad = nn_ops.conv3d_backprop_filter_v2(
      grad, filter_shape, op.inputs[2], **conv_kwargs)
  out_backprop_grad = nn_ops.conv3d(grad, op.inputs[1], **conv_kwargs)
  return [None, filter_grad, out_backprop_grad]
184
185
@ops.RegisterGradient("Conv3DBackpropFilterV2")
def _Conv3DBackpropFilterGrad(op, grad):
  """Gradient function for Conv3DBackpropFilterV2.

  Args:
    op: the Conv3DBackpropFilterV2 op being differentiated.
    grad: the tensor holding the gradient w.r.t. the op's output.

  Returns:
    A three-element list aligned with the op's inputs: the result of
    `conv3d_backprop_input_v2` for `op.inputs[0]`, `None` for `op.inputs[1]`
    and the result of `conv3d` for `op.inputs[2]`.
  """
  conv_kwargs = dict(
      dilations=op.get_attr("dilations"),
      strides=op.get_attr("strides"),
      padding=op.get_attr("padding"),
      # The attribute is read as bytes and decoded before being passed on.
      data_format=op.get_attr("data_format").decode())
  input_shape = array_ops.shape(op.inputs[0])
  input_grad = nn_ops.conv3d_backprop_input_v2(
      input_shape, grad, op.inputs[2], **conv_kwargs)
  out_backprop_grad = nn_ops.conv3d(op.inputs[0], grad, **conv_kwargs)
  return [input_grad, None, out_backprop_grad]
206
207
@ops.RegisterGradient("AvgPool3D")
def _AvgPool3DGrad(op, grad):
  """Gradient function for AvgPool3D.

  Args:
    op: the AvgPool3D op being differentiated.
    grad: the tensor holding the gradient w.r.t. the op's output.

  Returns:
    The result of `avg_pool3d_grad`, built from the shape of `op.inputs[0]`,
    `grad` and the pooling attributes of `op`.
  """
  input_shape = array_ops.shape(op.inputs[0])
  pool_kwargs = dict(
      ksize=op.get_attr("ksize"),
      strides=op.get_attr("strides"),
      padding=op.get_attr("padding"),
      data_format=op.get_attr("data_format").decode())
  return gen_nn_ops.avg_pool3d_grad(input_shape, grad, **pool_kwargs)
217
218
@ops.RegisterGradient("AvgPool3DGrad")
def _AvgPool3DGradGrad(op, grad):
  """Gradient function for AvgPool3DGrad.

  Args:
    op: the AvgPool3DGrad op being differentiated.
    grad: the tensor holding the gradient w.r.t. the op's output.

  Returns:
    A pair: `op.inputs[0]` wrapped in `stop_gradient`, and the result of
    applying `avg_pool3d` to `grad` with the pooling attributes of `op`.
  """
  first_input_grad = array_ops.stop_gradient(op.inputs[0])
  ksize = op.get_attr("ksize")
  strides = op.get_attr("strides")
  padding = op.get_attr("padding")
  data_format = op.get_attr("data_format").decode()
  pooled_grad = gen_nn_ops.avg_pool3d(
      grad, ksize, strides, padding, data_format=data_format)
  return (first_input_grad, pooled_grad)
228
229
@ops.RegisterGradient("MaxPool3D")
def _MaxPool3DGrad(op, grad):
  """Gradient function for MaxPool3D.

  Args:
    op: the MaxPool3D op being differentiated.
    grad: the tensor holding the gradient w.r.t. the op's output.

  Returns:
    The result of `max_pool3d_grad` applied to the op's input, the op's
    output and `grad`, using the pooling attributes of `op`.
  """
  pool_kwargs = dict(
      ksize=op.get_attr("ksize"),
      strides=op.get_attr("strides"),
      padding=op.get_attr("padding"),
      data_format=op.get_attr("data_format").decode())
  return gen_nn_ops.max_pool3d_grad(
      op.inputs[0], op.outputs[0], grad, **pool_kwargs)
240
241
@ops.RegisterGradient("MaxPool3DGrad")
def _MaxPool3DGradGrad(op, grad):
  """Gradient function for MaxPool3DGrad.

  Args:
    op: the MaxPool3DGrad op being differentiated.
    grad: the tensor holding the gradient w.r.t. the op's output.

  Returns:
    A triple aligned with the op's inputs: zeros shaped and typed like
    `op.inputs[0]`, zeros shaped and typed like `op.inputs[1]`, and the
    result of `max_pool3d_grad_grad`.
  """
  in0, in1 = op.inputs[0], op.inputs[1]
  zeros_for_in0 = array_ops.zeros(shape=array_ops.shape(in0), dtype=in0.dtype)
  zeros_for_in1 = array_ops.zeros(shape=array_ops.shape(in1), dtype=in1.dtype)
  grad_of_grad = gen_nn_ops.max_pool3d_grad_grad(
      in0,
      in1,
      grad,
      op.get_attr("ksize"),
      op.get_attr("strides"),
      padding=op.get_attr("padding"),
      data_format=op.get_attr("data_format").decode())
  return (zeros_for_in0, zeros_for_in1, grad_of_grad)
256
257
@ops.RegisterGradient("MaxPool3DGradGrad")
def _MaxPool3DGradGradGrad(op, grad):
  """Gradient function for MaxPool3DGradGrad.

  Args:
    op: the MaxPool3DGradGrad op being differentiated.
    grad: the tensor holding the gradient w.r.t. the op's output.

  Returns:
    A triple aligned with the op's inputs: zeros shaped and typed like
    `op.inputs[0]`, zeros shaped and typed like `op.inputs[1]`, and the
    result of `max_pool3d_grad`.
  """
  in0, in1 = op.inputs[0], op.inputs[1]
  zeros_for_in0 = array_ops.zeros(shape=array_ops.shape(in0), dtype=in0.dtype)
  zeros_for_in1 = array_ops.zeros(shape=array_ops.shape(in1), dtype=in1.dtype)
  pooled_grad = gen_nn_ops.max_pool3d_grad(
      in0,
      in1,
      grad,
      op.get_attr("ksize"),
      op.get_attr("strides"),
      padding=op.get_attr("padding"),
      data_format=op.get_attr("data_format").decode())
  return (zeros_for_in0, zeros_for_in1, pooled_grad)
272
273
@ops.RegisterGradient("Softmax")
def _SoftmaxGrad(op, grad_softmax):
  """The derivative of the softmax nonlinearity.

  The Jacobian dsoftmax / dx = (diag(softmax) - softmax * softmax') is a
  diagonal matrix minus a rank-one matrix, so the product with the incoming
  gradient can be computed without materializing it:

    grad_x = (grad_softmax - sum(grad_softmax * softmax)) * softmax

  where the sum runs over the last axis and keeps that axis for broadcasting.

  Args:
     op: the Softmax op.
     grad_softmax:  the tensor representing the gradient w.r.t. the softmax
       output.

  Returns:
     gradient w.r.t the input to the softmax

  """
  probs = op.outputs[0]
  weighted_grad = grad_softmax * probs
  weighted_total = math_ops.reduce_sum(weighted_grad, -1, keepdims=True)
  return (grad_softmax - weighted_total) * probs
297
298
@ops.RegisterGradient("LogSoftmax")
def _LogSoftmaxGrad(op, grad):
  """The gradient for log_softmax.

      log_softmax = input - log(sum(exp(input)))
      dlog_softmax/dinput = diag - softmax(input)

  The softmax is recovered by exponentiating the op's output.

  Args:
    op: The log softmax op.
    grad: The tensor representing the gradient w.r.t. the output.

  Returns:
    The gradients w.r.t. the input.
  """
  probs = math_ops.exp(op.outputs[0])
  grad_total = math_ops.reduce_sum(grad, -1, keepdims=True)
  return grad - grad_total * probs
315
316
@ops.RegisterGradient("BiasAdd")
def _BiasAddGrad(op, received_grad):
  """Return the gradients for the 2 inputs of the BiasAdd op.

  The gradient for the first input is the received gradient, passed through
  unchanged. The gradient for the second input (the bias) is computed by
  `bias_add_grad` from the received gradient and the op's `data_format`.

  Args:
    op: The BiasOp for which we need to generate gradients.
    received_grad: Tensor.  The gradients passed to the BiasOp.

  Returns:
    Two tensors, the first one for the "tensor" input of the BiasOp,
    the second one for the "bias" input of the BiasOp.
  """
  # The attribute lookup raises ValueError when `data_format` is absent; in
  # that case no explicit format is forwarded.
  try:
    data_format = op.get_attr("data_format")
  except ValueError:
    data_format = None
  bias_grad = gen_nn_ops.bias_add_grad(
      out_backprop=received_grad, data_format=data_format)
  return (received_grad, bias_grad)
343
344
@ops.RegisterGradient("BiasAddGrad")
def _BiasAddGradGrad(op, received_grad):
  """Gradient for the BiasAddGrad op.

  The received gradient is reshaped so that every axis except the bias axis
  has size one, then tiled along those axes to the shape of `op.inputs[0]`.
  The bias axis is axis 1 when `data_format` is b"NCHW" and the last axis
  otherwise.

  Args:
    op: BiasAddGrad op for which we are calculating gradients.
    received_grad: The gradients passed to the BiasAddGrad op.

  Returns:
    A single gradient Tensor for the input to BiasAddGrad (which
    is the gradient of the bias term in BiasAdd)
  """
  # The attribute lookup raises ValueError when `data_format` is absent.
  try:
    data_format = op.get_attr("data_format")
  except ValueError:
    data_format = None

  in_shape = array_ops.shape(op.inputs[0])
  grad_shape = array_ops.shape(received_grad)

  if data_format != b"NCHW":
    # Bias axis is last: ones for every leading axis, then the bias shape.
    reshape_to = array_ops.concat(
        [array_ops.ones_like(in_shape[:-1]), grad_shape], 0)
    multiples = array_ops.concat([in_shape[:-1], [1]], 0)
  else:
    # Bias axis is axis 1: ones before and after it.
    reshape_to = array_ops.concat([
        array_ops.ones_like(in_shape[:1]), grad_shape,
        array_ops.ones_like(in_shape[2:])
    ], 0)
    multiples = array_ops.concat([in_shape[:1], [1], in_shape[2:]], 0)

  reshaped_grad = array_ops.reshape(received_grad, reshape_to)
  return array_ops.tile(reshaped_grad, multiples)
379
380
@ops.RegisterGradient("BiasAddV1")
def _BiasAddGradV1(unused_bias_op, received_grad):
  """Return the gradients for the 2 inputs of bias_op.

  The gradient for the first input is the received gradient, passed through
  unchanged. The gradient for the second input (the bias) is the received
  gradient summed over every axis except the last one.

  Args:
    unused_bias_op: The BiasOp for which we need to generate gradients.
    received_grad: Tensor.  The gradients passed to the BiasOp.

  Returns:
    Two tensors, the first one for the "tensor" input of the BiasOp,
    the second one for the "bias" input of the BiasOp.
  """
  leading_axis_count = array_ops.rank(received_grad) - 1
  leading_axes = math_ops.range(leading_axis_count)
  bias_grad = math_ops.reduce_sum(received_grad, leading_axes)
  return (received_grad, bias_grad)
403
404
@ops.RegisterGradient("Relu")
def _ReluGrad(op, grad):
  """Gradient for Relu: `relu_grad` of `grad` and the op's output."""
  activations = op.outputs[0]
  return gen_nn_ops.relu_grad(grad, activations)
408
409
@ops.RegisterGradient("EluGrad")
def _EluGradGrad(op, grad):
  """Gradient function for the EluGrad op.

  EluGrad takes `(gradients, outputs)` = `op.inputs` and, per the op
  documentation, yields `gradients * (outputs + 1)` where `outputs < 0` and
  `gradients` elsewhere. It is linear in `gradients`, so its derivative
  w.r.t. `gradients` is EluGrad itself evaluated at the same `outputs`, i.e.
  at `op.inputs[1]` (not at `op.outputs[0]`, which is the result of EluGrad).
  This matches how the ReluGrad / Relu6Grad / LeakyReluGrad gradient
  functions in this file use `op.inputs[1]`.

  Args:
    op: the EluGrad op being differentiated.
    grad: the tensor holding the gradient w.r.t. the op's output.

  Returns:
    A pair aligned with the op's inputs: the gradient w.r.t. `gradients` and
    the gradient w.r.t. `outputs` (`grad * gradients` where `outputs < 0`,
    zero elsewhere).
  """
  elu_x = op.inputs[1]
  return (gen_nn_ops.elu_grad(grad, elu_x),
          array_ops.where(
              elu_x < 0, grad * op.inputs[0],
              array_ops.zeros(shape=array_ops.shape(elu_x), dtype=elu_x.dtype)))
417
418
@ops.RegisterGradient("SeluGrad")
def _SeluGradGrad(op, grad):
  """Gradient function for the SeluGrad op.

  SeluGrad takes `(gradients, outputs)` = `op.inputs` and, per the op
  documentation, yields `gradients * (outputs + scale * alpha)` where
  `outputs < 0` and `scale * gradients` elsewhere. It is linear in
  `gradients`, so its derivative w.r.t. `gradients` is SeluGrad itself
  evaluated at the same `outputs` (`op.inputs[1]`). Its derivative w.r.t.
  `outputs` is `gradients` where `outputs < 0` and zero elsewhere.

  Args:
    op: the SeluGrad op being differentiated.
    grad: the tensor holding the gradient w.r.t. the op's output.

  Returns:
    A pair aligned with the op's inputs: the gradient w.r.t. `gradients` and
    the gradient w.r.t. `outputs`.
  """
  selu_x = op.inputs[1]
  return (gen_nn_ops.selu_grad(grad, selu_x),
          array_ops.where(
              selu_x < 0., grad * op.inputs[0],
              array_ops.zeros(
                  shape=array_ops.shape(selu_x), dtype=selu_x.dtype)))
427
428
@ops.RegisterGradient("Relu6")
def _Relu6Grad(op, grad):
  """Gradient for Relu6: `relu6_grad` of `grad` and the op's output."""
  activations = op.outputs[0]
  return gen_nn_ops.relu6_grad(grad, activations)
432
433
@ops.RegisterGradient("Relu6Grad")
def _Relu6GradGrad(op, grad):
  """Gradient for Relu6Grad.

  Returns a pair aligned with the op's inputs: `relu6_grad` of `grad` and
  `op.inputs[1]`, and zeros shaped and typed like `op.inputs[1]`.
  """
  second_in = op.inputs[1]
  first_grad = gen_nn_ops.relu6_grad(grad, second_in)
  second_grad = array_ops.zeros(
      shape=array_ops.shape(second_in), dtype=second_in.dtype)
  return (first_grad, second_grad)
439
440
@ops.RegisterGradient("LeakyRelu")
def _LeakyReluGrad(op, grad):
  """Gradient for LeakyRelu: `leaky_relu_grad` with the op's input and alpha."""
  features = op.inputs[0]
  return gen_nn_ops.leaky_relu_grad(
      grad, features, alpha=op.get_attr("alpha"))
446
447
@ops.RegisterGradient("LeakyReluGrad")
def _LeakyReluGradGrad(op, grad):
  """Gradient for LeakyReluGrad.

  Returns a pair aligned with the op's inputs: `leaky_relu_grad` of `grad`
  and `op.inputs[1]` (using the op's `alpha`), and zeros shaped and typed
  like `op.inputs[1]`.
  """
  second_in = op.inputs[1]
  slope = op.get_attr("alpha")
  first_grad = gen_nn_ops.leaky_relu_grad(grad, second_in, alpha=slope)
  second_grad = array_ops.zeros(
      shape=array_ops.shape(second_in), dtype=second_in.dtype)
  return (first_grad, second_grad)
454
455
@ops.RegisterGradient("Elu")
def _EluGrad(op, grad):
  """Gradient for Elu: `elu_grad` of `grad` and the op's output."""
  activations = op.outputs[0]
  return gen_nn_ops.elu_grad(grad, activations)
459
460
@ops.RegisterGradient("Selu")
def _SeluGrad(op, grad):
  """Gradient for Selu: `selu_grad` of `grad` and the op's output."""
  activations = op.outputs[0]
  return gen_nn_ops.selu_grad(grad, activations)
464
465
@ops.RegisterGradient("Softplus")
def _SoftplusGrad(op, grad):
  """Gradient for Softplus: `softplus_grad` of `grad` and the op's input."""
  features = op.inputs[0]
  return gen_nn_ops.softplus_grad(grad, features)
469
470
@ops.RegisterGradient("SoftplusGrad")
def _SoftplusGradGrad(op, grad):
  """Gradient for SoftplusGrad.

  With y = tf.nn.softplus(x) and
  dx = gen_nn_ops.softplus_grad(dy, x) = dy / (1 + exp(-x)),
  this computes (ddy, d2x) from op.inputs == [dy, x] and grad == ddx.
  """
  dy, x = op.inputs
  # Both results are built under a control dependency on the incoming grad.
  with ops.control_dependencies([grad]):
    ddy = gen_nn_ops.softplus_grad(grad, x)
    numerator = grad * dy
    denominator = math_ops.exp(-x) + 2.0 + math_ops.exp(x)
    d2x = numerator / denominator
  return (ddy, d2x)
482
483
@ops.RegisterGradient("Softsign")
def _SoftsignGrad(op, grad):
  """Gradient for Softsign: `softsign_grad` of `grad` and the op's input."""
  features = op.inputs[0]
  return gen_nn_ops.softsign_grad(grad, features)
487
488
@ops.RegisterGradient("ReluGrad")
def _ReluGradGrad(op, grad):
  """Gradient for ReluGrad.

  Returns a pair aligned with the op's inputs: `relu_grad` of `grad` and
  `op.inputs[1]`, and zeros shaped and typed like `op.inputs[1]`.
  """
  second_in = op.inputs[1]
  first_grad = gen_nn_ops.relu_grad(grad, second_in)
  second_grad = array_ops.zeros(
      shape=array_ops.shape(second_in), dtype=second_in.dtype)
  return (first_grad, second_grad)
494
495
def _BroadcastMul(vec, mat):
  """Multiply after broadcasting vec to match dimensions of mat.

  A trailing axis of size one is appended to `vec` so that the product
  broadcasts across the last axis of `mat`.

  Args:
    vec: A 1-D tensor of dimension [D0]
    mat: A 2-D tensor of dimension [D0, D1]

  Returns:
    A tensor of dimension [D0, D1], the result of vec * mat
  """
  # [D0] -> [D0, 1]
  column = array_ops.expand_dims(vec, -1)
  return column * mat
509
510
@ops.RegisterGradient("SoftmaxCrossEntropyWithLogits")
def _SoftmaxCrossEntropyWithLogitsGrad(op, grad_loss, grad_grad):
  """Gradient function for SoftmaxCrossEntropyWithLogits.

  Args:
    op: the SoftmaxCrossEntropyWithLogits op being differentiated.
    grad_loss: the backprop for the cost (the op's first output).
    grad_grad: the backprop for the softmax gradient (the op's second
      output); may be None.

  Returns:
    A pair of gradients aligned with the op's two inputs.
  """

  def _FeedsZeros(g):
    # Introspection to detect a gradient that is known to be all zeros.
    if context.executing_eagerly():
      # TODO(apassos) add an efficient way to detect eager zeros here.
      return False
    if g.op.type in ("ZerosLike", "Zeros"):
      return True
    static_value = tensor_util.constant_value(g)
    if static_value is None:
      return False
    return (static_value == 0).all()

  # The op's second output already holds the gradient of the cost, so the
  # first-order term is that output scaled by grad_loss.
  softmax_grad = op.outputs[1]
  grad = _BroadcastMul(grad_loss, softmax_grad)

  logits = op.inputs[0]
  needs_second_order = grad_grad is not None and not _FeedsZeros(grad_grad)
  if needs_second_order:
    # Second derivative is just softmax derivative w.r.t. logits.
    softmax = nn_ops.softmax(logits)
    per_row_dot = array_ops.squeeze(
        math_ops.matmul(
            array_ops.expand_dims(grad_grad, 1),
            array_ops.expand_dims(softmax, 2)),
        axis=1)
    grad = grad + (grad_grad - per_row_dot) * softmax

  second_input_grad = _BroadcastMul(grad_loss, -nn_ops.log_softmax(logits))
  return grad, second_input_grad
543
544
@ops.RegisterGradient("SparseSoftmaxCrossEntropyWithLogits")
def _SparseSoftmaxCrossEntropyWithLogitsGrad(op, grad_0, _):
  """Gradient function for SparseSoftmaxCrossEntropyWithLogits.

  Args:
    op: the SparseSoftmaxCrossEntropyWithLogits op being differentiated.
    grad_0: the backprop for the cost (the op's first output).
    _: the backprop for the op's second output; ignored.

  Returns:
    A pair: `op.outputs[1]` scaled by `grad_0` for the first input, and
    `None` for the labels, which have no gradient.
  """
  # Taking the second derivative of this op is not currently possible because
  # of how the fused implementation interacts with tf.gradients(). Wrapping
  # the precomputed gradient in prevent_gradient turns a request for it into
  # an explicit error rather than a silently incorrect result.
  guarded_softmax_grad = array_ops.prevent_gradient(
      op.outputs[1],
      message="Currently there is no way to take the second "
      "derivative of sparse_softmax_cross_entropy_with_logits due to the fused "
      "implementation's interaction with tf.gradients()")
  first_input_grad = _BroadcastMul(grad_0, guarded_softmax_grad)
  return first_input_grad, None
562
563
@ops.RegisterGradient("Conv2D")
def _Conv2DGrad(op, grad):
  """Gradient function for Conv2D.

  Args:
    op: the Conv2D op being differentiated.
    grad: the tensor holding the gradient w.r.t. the op's output.

  Returns:
    A two-element list: the result of `conv2d_backprop_input` for
    `op.inputs[0]` and the result of `conv2d_backprop_filter` for
    `op.inputs[1]`.
  """
  conv_kwargs = dict(
      dilations=op.get_attr("dilations"),
      strides=op.get_attr("strides"),
      padding=op.get_attr("padding"),
      explicit_paddings=op.get_attr("explicit_paddings"),
      use_cudnn_on_gpu=op.get_attr("use_cudnn_on_gpu"),
      data_format=op.get_attr("data_format"))
  first_in, second_in = op.inputs[0], op.inputs[1]
  first_shape, second_shape = array_ops.shape_n([first_in, second_in])

  # We call the gen_nn_ops backprop functions instead of nn_ops backprop
  # functions for performance reasons in Eager mode. gen_nn_ops functions take
  # an `explicit_paddings` parameter, but nn_ops functions do not. So if we
  # were to use the nn_ops functions, we would have to convert `padding` and
  # `explicit_paddings` into a single `padding` parameter, increasing overhead
  # in Eager mode.
  input_grad = gen_nn_ops.conv2d_backprop_input(
      first_shape, second_in, grad, **conv_kwargs)
  filter_grad = gen_nn_ops.conv2d_backprop_filter(
      first_in, second_shape, grad, **conv_kwargs)
  return [input_grad, filter_grad]
603
604
@ops.RegisterGradient("DepthwiseConv2dNative")
def _DepthwiseConv2dNativeGrad(op, grad):
  """Gradient function for DepthwiseConv2dNative.

  The backprop ops are configured with the forward op's `dilations`,
  `strides`, `padding` and `data_format` attributes. `dilations` is forwarded
  explicitly, as the other depthwise gradient functions in this file do, so
  the backprop ops do not silently use their own default dilation for a
  dilated forward op.

  Args:
    op: the DepthwiseConv2dNative op being differentiated.
    grad: the tensor holding the gradient w.r.t. the op's output.

  Returns:
    A two-element list: the result of
    `depthwise_conv2d_native_backprop_input` for `op.inputs[0]` and the
    result of `depthwise_conv2d_native_backprop_filter` for `op.inputs[1]`.
  """
  conv_attrs = dict(
      dilations=op.get_attr("dilations"),
      strides=op.get_attr("strides"),
      padding=op.get_attr("padding"),
      data_format=op.get_attr("data_format"))
  return [
      nn_ops.depthwise_conv2d_native_backprop_input(
          array_ops.shape(op.inputs[0]), op.inputs[1], grad, **conv_attrs),
      nn_ops.depthwise_conv2d_native_backprop_filter(
          op.inputs[0], array_ops.shape(op.inputs[1]), grad, **conv_attrs)
  ]
623
624
@ops.RegisterGradient("Dilation2D")
def _Dilation2DGrad(op, grad):
  """Gradient function for Dilation2D.

  Returns a two-element list: the result of `dilation2d_backprop_input` and
  the result of `dilation2d_backprop_filter`, each called with both op
  inputs, `grad` and the op's `strides`, `rates` and `padding` attributes.
  """
  first_in, second_in = op.inputs[0], op.inputs[1]
  strides = op.get_attr("strides")
  rates = op.get_attr("rates")
  padding = op.get_attr("padding")
  input_grad = nn_ops.dilation2d_backprop_input(
      first_in, second_in, grad, strides, rates, padding)
  filter_grad = nn_ops.dilation2d_backprop_filter(
      first_in, second_in, grad, strides, rates, padding)
  return [input_grad, filter_grad]
637
638
@ops.RegisterGradient("LRN")
def _LRNGrad(op, grad):
  """Gradient function for LRN.

  Returns a one-element list holding the result of `lrn_grad`, called with
  `grad`, the op's input and output, and the op's `depth_radius`, `bias`,
  `alpha` and `beta` attributes.
  """
  lrn_attrs = [
      op.get_attr(attr_name)
      for attr_name in ("depth_radius", "bias", "alpha", "beta")
  ]
  input_grad = gen_nn_ops.lrn_grad(
      grad, op.inputs[0], op.outputs[0], *lrn_attrs)
  return [input_grad]
649
650
@ops.RegisterGradient("AvgPool")
def _AvgPoolGrad(op, grad):
  """Gradient function for AvgPool.

  Returns the result of `avg_pool_grad`, built from the shape of
  `op.inputs[0]`, `grad` and the pooling attributes of `op`.
  """
  input_shape = array_ops.shape(op.inputs[0])
  ksize = op.get_attr("ksize")
  strides = op.get_attr("strides")
  padding = op.get_attr("padding")
  data_format = op.get_attr("data_format")
  return gen_nn_ops.avg_pool_grad(
      input_shape, grad, ksize, strides, padding, data_format=data_format)
660
661
@ops.RegisterGradient("AvgPoolGrad")
def _AvgPoolGradGrad(op, grad):
  """Gradient function for AvgPoolGrad.

  Returns a pair: `op.inputs[0]` wrapped in `stop_gradient`, and the result
  of applying `avg_pool` to `grad` with the pooling attributes of `op`.
  """
  first_input_grad = array_ops.stop_gradient(op.inputs[0])
  ksize = op.get_attr("ksize")
  strides = op.get_attr("strides")
  padding = op.get_attr("padding")
  data_format = op.get_attr("data_format")
  pooled_grad = gen_nn_ops.avg_pool(
      grad, ksize, strides, padding, data_format=data_format)
  return (first_input_grad, pooled_grad)
671
672
@ops.RegisterGradient("MaxPool")
def _MaxPoolGrad(op, grad):
  """Gradient function for MaxPool.

  Returns the result of `max_pool_grad` applied to the op's input, the op's
  output and `grad`, using the pooling attributes of `op`.
  """
  ksize = op.get_attr("ksize")
  strides = op.get_attr("strides")
  pool_kwargs = dict(
      padding=op.get_attr("padding"),
      data_format=op.get_attr("data_format"))
  return gen_nn_ops.max_pool_grad(
      op.inputs[0], op.outputs[0], grad, ksize, strides, **pool_kwargs)
683
684
@ops.RegisterGradient("MaxPoolV2")
def _MaxPoolGradV2(op, grad):
  """Gradient function for MaxPoolV2.

  Here `ksize` and `strides` are read from the op's inputs (`op.inputs[1]`
  and `op.inputs[2]`) rather than from attributes.

  Returns:
    A triple aligned with the op's inputs: the result of `max_pool_grad_v2`
    for `op.inputs[0]`, and `None` for the `ksize` and `strides` inputs.
  """
  pooled_input, ksize, strides = op.inputs[0], op.inputs[1], op.inputs[2]
  input_grad = gen_nn_ops.max_pool_grad_v2(
      pooled_input,
      op.outputs[0],
      grad,
      ksize,
      strides,
      padding=op.get_attr("padding"),
      data_format=op.get_attr("data_format"))
  return input_grad, None, None
697
698
@ops.RegisterGradient("MaxPoolWithArgmax")
def _MaxPoolGradWithArgmax(op, grad, unused_argmax_grad):
  """Gradient function for MaxPoolWithArgmax.

  Args:
    op: the MaxPoolWithArgmax op being differentiated.
    grad: the gradient w.r.t. the op's first output.
    unused_argmax_grad: the gradient w.r.t. the op's second output; discarded.

  Returns:
    The result of `max_pool_grad_with_argmax`, called with the op's input,
    `grad`, the op's second output and the op's pooling attributes.
  """
  del unused_argmax_grad
  ksize = op.get_attr("ksize")
  strides = op.get_attr("strides")
  extra_kwargs = dict(
      padding=op.get_attr("padding"),
      include_batch_in_index=op.get_attr("include_batch_in_index"))
  return gen_nn_ops.max_pool_grad_with_argmax(
      op.inputs[0], grad, op.outputs[1], ksize, strides, **extra_kwargs)
710
711
@ops.RegisterGradient("MaxPoolGrad")
def _MaxPoolGradGrad(op, grad):
  """Gradient function for MaxPoolGrad.

  Returns a triple aligned with the op's inputs: zeros shaped and typed like
  `op.inputs[0]`, zeros shaped and typed like `op.inputs[1]`, and the result
  of `max_pool_grad_grad`.
  """
  in0, in1 = op.inputs[0], op.inputs[1]
  zeros_for_in0 = array_ops.zeros(shape=array_ops.shape(in0), dtype=in0.dtype)
  zeros_for_in1 = array_ops.zeros(shape=array_ops.shape(in1), dtype=in1.dtype)
  grad_of_grad = gen_nn_ops.max_pool_grad_grad(
      in0,
      in1,
      grad,
      op.get_attr("ksize"),
      op.get_attr("strides"),
      padding=op.get_attr("padding"),
      data_format=op.get_attr("data_format"))
  return (zeros_for_in0, zeros_for_in1, grad_of_grad)
726
727
@ops.RegisterGradient("MaxPoolGradV2")
def _MaxPoolGradGradV2(op, grad):
  """Gradient function for MaxPoolGradV2.

  Here `ksize` and `strides` are read from the op's inputs (`op.inputs[3]`
  and `op.inputs[4]`) rather than from attributes.

  Returns:
    A five-element tuple aligned with the op's inputs: zeros shaped and typed
    like `op.inputs[0]`, zeros shaped and typed like `op.inputs[1]`, the
    result of `max_pool_grad_grad_v2`, and `None` for the `ksize` and
    `strides` inputs.
  """
  in0, in1 = op.inputs[0], op.inputs[1]
  ksize, strides = op.inputs[3], op.inputs[4]
  zeros_for_in0 = array_ops.zeros(shape=array_ops.shape(in0), dtype=in0.dtype)
  zeros_for_in1 = array_ops.zeros(shape=array_ops.shape(in1), dtype=in1.dtype)
  grad_of_grad = gen_nn_ops.max_pool_grad_grad_v2(
      in0,
      in1,
      grad,
      ksize,
      strides,
      padding=op.get_attr("padding"),
      data_format=op.get_attr("data_format"))
  return (zeros_for_in0, zeros_for_in1, grad_of_grad, None, None)
744
745
@ops.RegisterGradient("MaxPoolGradGrad")
def _MaxPoolGradGradGrad(op, grad):
  """Gradient function for MaxPoolGradGrad.

  Returns a triple aligned with the op's inputs: zeros shaped and typed like
  `op.inputs[0]`, zeros shaped and typed like `op.inputs[1]`, and the result
  of `max_pool_grad`.
  """
  in0, in1 = op.inputs[0], op.inputs[1]
  zeros_for_in0 = array_ops.zeros(shape=array_ops.shape(in0), dtype=in0.dtype)
  zeros_for_in1 = array_ops.zeros(shape=array_ops.shape(in1), dtype=in1.dtype)
  pooled_grad = gen_nn_ops.max_pool_grad(
      in0,
      in1,
      grad,
      op.get_attr("ksize"),
      op.get_attr("strides"),
      padding=op.get_attr("padding"),
      data_format=op.get_attr("data_format"))
  return (zeros_for_in0, zeros_for_in1, pooled_grad)
760
761
@ops.RegisterGradient("FractionalMaxPool")
def _FractionalMaxPoolGrad(op, grad_0, unused_grad_1, unused_grad_2):
  """Returns gradient for FractionalMaxPool.

  FractionalMaxPool has three outputs, so three gradients are passed in, one
  per output. Only the first one is used; the other two are empty.

  Args:
    op: The FractionalMaxPoolOp.
    grad_0: Gradient with respect to op.outputs[0]
    unused_grad_1: Gradient with respect to op.outputs[1]/row_seq. It is empty.
    unused_grad_2: Gradient with respect to op.outputs[2]/col_seq. It is empty.

  Returns:
    Input backprop for FractionalMaxPool op.
  """
  pooled, row_seq, col_seq = op.outputs[0], op.outputs[1], op.outputs[2]
  overlapping = op.get_attr("overlapping")
  return gen_nn_ops.fractional_max_pool_grad(
      op.inputs[0], pooled, grad_0, row_seq, col_seq, overlapping)
782
783
@ops.RegisterGradient("FractionalAvgPool")
def _FractionalAvgPoolGrad(op, grad_0, unused_grad_1, unused_grad_2):
  """Returns gradient for FractionalAvgPool.

  FractionalAvgPool has three outputs, so three gradients are passed in, one
  per output. Only the first one is used; the other two are empty.

  Args:
    op: The FractionalAvgPoolOp.
    grad_0: Gradient with respect to op.outputs[0]
    unused_grad_1: Gradient with respect to op.outputs[1]/row_seq. It is empty.
    unused_grad_2: Gradient with respect to op.outputs[2]/col_seq. It is empty.

  Returns:
    Input backprop for FractionalAvgPool op.
  """
  # The input shape is taken from `get_shape()` on the input tensor.
  input_shape = op.inputs[0].get_shape()
  row_seq, col_seq = op.outputs[1], op.outputs[2]
  overlapping = op.get_attr("overlapping")
  return gen_nn_ops.fractional_avg_pool_grad(
      input_shape, grad_0, row_seq, col_seq, overlapping)
804
805
@ops.RegisterGradient("BatchNormWithGlobalNormalization")
def _BatchNormWithGlobalNormalizationGrad(op, grad):
  """Return the gradients for the 5 inputs of BatchNormWithGlobalNormalization.

  The computation is delegated to `batch_norm_with_global_normalization_grad`,
  which receives op inputs 0, 1, 2 and 4, the incoming gradient, and the op's
  `variance_epsilon` and `scale_after_normalization` attributes.

  NOTE(review): the earlier docstring said mean and var are intentionally not
  backpropagated, yet `dm` and `dv` are returned as produced by the grad op --
  confirm the intended contract.

  Args:
    op: The BatchNormOp for which we need to generate gradients.
    grad: Tensor.  The gradients passed to the BatchNormOp.

  Returns:
    dx: Backprop for input, which is (grad * (g * rsqrt(v + epsilon)))
    dm: Backprop for mean, which is
        sum_over_rest(grad * g) * (-1 / rsqrt(v + epsilon))
    dv: Backprop for variance, which is
        sum_over_rest(grad * g * (x - m)) * (-1/2) * (v + epsilon) ^ (-3/2)
    db: Backprop for beta, which is grad reduced in all except the
        last dimension.
    dg: Backprop for gamma, which is (grad * ((x - m) * rsqrt(v + epsilon)))
  """
  variance_epsilon = op.get_attr("variance_epsilon")
  scale_after_normalization = op.get_attr("scale_after_normalization")
  backprops = gen_nn_ops.batch_norm_with_global_normalization_grad(
      op.inputs[0], op.inputs[1], op.inputs[2], op.inputs[4], grad,
      variance_epsilon, scale_after_normalization)
  # Unpack to enforce exactly five results and return them as a plain tuple.
  dx, dm, dv, db, dg = backprops
  return dx, dm, dv, db, dg
831
832
def _BaseFusedBatchNormGrad(op, use_v2, *grad):
  """Return the gradients for the 3 inputs of BatchNorm.

  In training mode the grad op is fed `op.outputs[3]` and `op.outputs[4]` and
  its result is returned directly. Otherwise it is fed `op.inputs[3]` and
  `op.inputs[4]` (the population mean and variance) and is always run in
  "NHWC" format, with b"NCHW" data transposed on the way in and out.

  Args:
    op: The BatchNormOp for which we need to compute gradients.
    use_v2: Boolean indicating whether to use the V2 version of the fused batch
      norm gradient.
    *grad: An argument list for tensors of gradients wrt the outputs with
      grad[0] as grad_y.

  Returns:
    grad_x: gradient for x, which is scale * rsqrt(variance + epsilon) *
            [grad_y - mean(grad_y) - (x - mean(x)) *
            mean(grad_y * (x - mean(x))) / (variance + epsilon)]
            in training mode; grad_y * scale * rsqrt(pop_variance + epsilon)
            in freeze mode.

    grad_scale: gradient for scale, which is sum(grad_y * (x - mean(x)) *
                rsqrt(variance + epsilon)) in training mode;
                sum(grad_y * (x - pop_mean) * rsqrt(pop_variance + epsilon))
                in freeze mode.

    grad_offset: gradient for offset, which is sum(grad_y) in training mode;
                 sum(grad_y) in freeze mode.
  """
  x = op.inputs[0]
  grad_y = grad[0]
  scale = op.inputs[1]
  epsilon = op.get_attr("epsilon")
  data_format = op.get_attr("data_format")
  is_training = op.get_attr("is_training")
  if use_v2:
    grad_fun = gen_nn_ops.fused_batch_norm_grad_v2
  else:
    grad_fun = gen_nn_ops.fused_batch_norm_grad

  if is_training:
    return grad_fun(
        grad_y,
        x,
        scale,
        op.outputs[3],
        op.outputs[4],
        epsilon=epsilon,
        data_format=data_format,
        is_training=is_training)

  # Freeze mode: use the population statistics supplied as op inputs.
  pop_mean = op.inputs[3]
  pop_var = op.inputs[4]
  channels_first = data_format == b"NCHW"
  if channels_first:
    # Move the channel axis last so the grad op can run in NHWC.
    x = array_ops.transpose(x, [0, 2, 3, 1])
    grad_y = array_ops.transpose(grad_y, [0, 2, 3, 1])
  results = grad_fun(
      grad_y,
      x,
      scale,
      pop_mean,
      pop_var,
      epsilon=epsilon,
      data_format="NHWC",
      is_training=is_training)
  dx, dscale, doffset, _, _ = results
  if channels_first:
    # Restore the original NCHW layout for the input gradient.
    dx = array_ops.transpose(dx, [0, 3, 1, 2])
  return dx, dscale, doffset, None, None
895
896
@ops.RegisterGradient("FusedBatchNorm")
def _FusedBatchNormGrad(op, *grad):
  """Gradient function registered for the "FusedBatchNorm" op.

  Forwards to `_BaseFusedBatchNormGrad` with the V2 gradient kernel disabled.
  """
  use_v2 = False
  return _BaseFusedBatchNormGrad(op, use_v2, *grad)
900
901
@ops.RegisterGradient("FusedBatchNormV2")
def _FusedBatchNormV2Grad(op, *grad):
  """Gradient function registered for the "FusedBatchNormV2" op.

  Forwards to `_BaseFusedBatchNormGrad` with the V2 gradient kernel enabled.
  """
  use_v2 = True
  return _BaseFusedBatchNormGrad(op, use_v2, *grad)
905
906
def _BatchNormGrad(grad_y,
                   x,
                   scale,
                   pop_mean,
                   pop_var,
                   epsilon,
                   data_format,
                   is_training=True):
  """Returns the gradients for the 3 inputs of BatchNorm.

  The computation is expressed with regular differentiable ops.

  Args:
    grad_y: A `Tensor` of 4 dimensions for gradient for y.
    x: A `Tensor` of 4 dimensions for x.
    scale: A `Tensor` of 1 dimension for scaling.
    pop_mean: A `Tensor` of 1 dimension for the population mean. Only used when
      is_training=False.
    pop_var: A `Tensor` of 1 dimension for the population variance. Only used
      when is_training=False.
    epsilon: A small float number added to the variance of x.
    data_format: The data format for input. Either b"NHWC" or b"NCHW".
    is_training: A bool value to indicate the operation is for training
      (default) or inference.

  Returns:
    A tuple (grad_x, grad_scale, grad_offset), where grad_x is the gradient
    for x (cast back to the dtype of x), grad_scale the gradient for scale,
    and grad_offset the gradient for offset.
  """
  x_dtype = x.dtype.base_dtype
  if x_dtype == dtypes.float16:
    # float16 math is too imprecise, so we do the batch norm gradient
    # computations in float32.
    x = math_ops.cast(x, dtypes.float32)
    grad_y = math_ops.cast(grad_y, dtypes.float32)
  if is_training:
    if data_format == b"NHWC":
      keepdims = False
      reduce_axis = [0, 1, 2]
    else:
      # Channels are on axis 1: keep the reduced axes so the per-channel
      # statistics broadcast against x, and reshape scale to match.
      keepdims = True
      reduce_axis = [0, 2, 3]
      shape = [1, array_ops.size(scale), 1, 1]
      scale = array_ops.reshape(scale, shape)
    mean_grad_y = math_ops.reduce_mean(grad_y, reduce_axis, keepdims=keepdims)
    mean_x = math_ops.reduce_mean(x, reduce_axis, keepdims=keepdims)
    var_x = math_ops.reduce_mean(
        math_ops.squared_difference(x, array_ops.stop_gradient(mean_x)),
        reduce_axis,
        keepdims=keepdims)
    grad_y_offset = grad_y - mean_grad_y
    x_offset = x - mean_x
    mean = math_ops.reduce_mean(
        grad_y * x_offset, axis=reduce_axis, keepdims=keepdims)
    grad_x = scale * math_ops.rsqrt(var_x + epsilon) * (
        grad_y_offset - math_ops.reciprocal(var_x + epsilon) * mean * x_offset)
    grad_scale = math_ops.rsqrt(var_x + epsilon) * math_ops.reduce_sum(
        grad_y * x_offset, axis=reduce_axis, keepdims=keepdims)
    if data_format == b"NCHW":
      # Drop only the kept reduction axes. An axis-less squeeze would also
      # remove the channel axis when there is a single channel, turning the
      # 1-D gradient into a scalar.
      grad_scale = array_ops.squeeze(grad_scale, axis=reduce_axis)
    grad_offset = math_ops.reduce_sum(grad_y, axis=reduce_axis)
    return math_ops.cast(grad_x, x_dtype), grad_scale, grad_offset
  else:
    if data_format == b"NHWC":
      reduce_axis = [0, 1, 2]
    else:
      # Reshape the per-channel vectors so they broadcast along axis 1.
      reduce_axis = [0, 2, 3]
      shape = [1, array_ops.size(pop_mean), 1, 1]
      pop_mean = array_ops.reshape(pop_mean, shape)
      pop_var = array_ops.reshape(pop_var, shape)
      scale = array_ops.reshape(scale, shape)

    grad_offset = math_ops.reduce_sum(grad_y, axis=reduce_axis)
    var_rsqrt = math_ops.rsqrt(pop_var + epsilon)
    grad_scale = math_ops.reduce_sum(
        grad_y * (x - pop_mean) * var_rsqrt, axis=reduce_axis)
    grad_x = grad_y * scale * var_rsqrt
    return math_ops.cast(grad_x, x_dtype), grad_scale, grad_offset
984
985
@ops.RegisterGradient("FusedBatchNormGrad")
def _FusedBatchNormGradGrad(op, *grad):
  """Second-order gradient for the "FusedBatchNormGrad" op.

  The first-order gradient is recomputed with `_BatchNormGrad` under a
  `GradientTape`, and the tape is then differentiated wrt grad_y, x and scale
  using the incoming gradients as output gradients.

  Args:
    op: The FusedBatchNormGradOp for which we need to compute gradients.
    *grad: Gradients wrt the op outputs, with grad[0] as grad_grad_x, grad[1]
      as grad_grad_scale and grad[2] as grad_grad_offset.

  Returns:
    A tuple (grad_grad_y, grad_x, grad_scale, None, None): the gradients for
    the grad_y, x and scale inputs, followed by None for the remaining two
    inputs.
  """
  grad_y, x, scale = op.inputs[0], op.inputs[1], op.inputs[2]
  pop_mean, pop_var = op.inputs[3], op.inputs[4]
  data_format = op.get_attr("data_format")
  epsilon = op.get_attr("epsilon")
  is_training = op.get_attr("is_training")

  # Incoming gradients for (grad_x, grad_scale, grad_offset), in that order.
  upstream = [grad[0], grad[1], grad[2]]

  with backprop.GradientTape() as tape:
    for source in (grad_y, x, scale):
      tape.watch(source)
    dx, dscale, doffset = _BatchNormGrad(
        grad_y, x, scale, pop_mean, pop_var, epsilon, data_format, is_training)

  d_grad_y, d_x, d_scale = tape.gradient(
      [dx, dscale, doffset], [grad_y, x, scale], upstream)
  return d_grad_y, d_x, d_scale, None, None
1022
1023
@ops.RegisterGradient("FusedBatchNormGradV2")
def _FusedBatchNormGradGradV2(op, *grad):
  """Second-order gradient registered for the "FusedBatchNormGradV2" op.

  Reuses `_FusedBatchNormGradGrad` unchanged.
  """
  second_order_grads = _FusedBatchNormGradGrad(op, *grad)
  return second_order_grads
1027
1028
@ops.RegisterGradient("L2Loss")
def _L2LossGrad(op, grad):
  """Gradient function registered for the "L2Loss" op.

  Args:
    op: The L2LossOp for which we need to generate gradients.
    grad: Tensor containing a single number.

  Returns:
    The gradient wrt the op input, which is (x * grad).
  """
  x = op.inputs[0]
  return x * grad
1041
1042
@ops.RegisterGradient("TopK")
@ops.RegisterGradient("TopKV2")
def _TopKGrad(op, grad, _):
  """Gradient function registered for the "TopK" and "TopKV2" ops.

  Args:
    op: The TopKOp for which we need to generate gradients.
    grad: Tensor. The gradients passed to the TopKOp for its first output.
    _: Gradient for the second (indices) output; unused.

  Returns:
    A list of two tensors. The first is the gradient wrt the input: `grad`
    scattered to the positions given by the op's indices output, zeros
    elsewhere, reshaped to the input shape. The second is a scalar int32 zero
    (presumably a placeholder for the `k` input of TopKV2 -- TODO confirm).
  """
  input_shape = array_ops.shape(op.inputs[0])
  indices = op.outputs[1]
  indices_shape = array_ops.shape(indices)

  # int32 is not supported on GPU hence up-casting
  indices_last_dim = array_ops.gather(
      math_ops.cast(indices_shape, dtypes.int64),
      array_ops.size(indices_shape) - 1)
  # View the indices as [rows, indices_last_dim].
  indices_2d = array_ops.reshape(
      indices, array_ops.stack([-1, indices_last_dim]))

  input_last_dim = array_ops.gather(
      math_ops.cast(input_shape, dtypes.int64),
      array_ops.size(input_shape) - 1)
  num_rows = array_ops.shape(indices_2d)[0]

  # Offset of the first element of every row in the flattened input.
  row_starts = math_ops.range(
      0, math_ops.cast(num_rows, dtypes.int64) * input_last_dim,
      input_last_dim)
  row_starts = math_ops.cast(
      array_ops.expand_dims(row_starts, -1), dtypes.int32)
  # Linear indices into the flattened input, as a 1-D tensor.
  flat_indices = array_ops.reshape(indices_2d + row_starts, [-1])

  # Place grad at the selected locations, zero everywhere else, then restore
  # the original input shape.
  flat_input_grad = array_ops.scatter_nd(
      array_ops.expand_dims(flat_indices, -1),
      array_ops.reshape(grad, [-1]),
      [math_ops.reduce_prod(input_shape)])
  return [
      array_ops.reshape(flat_input_grad, input_shape),
      array_ops.zeros([], dtype=dtypes.int32),
  ]
1087
1088
@ops.RegisterGradient("NthElement")
def _NthElementGrad(op, grad):
  """Gradient function registered for the "NthElement" op.

  Args:
    op: The NthElementOp for which we need to generate gradients.
    grad: Tensor. The gradients passed to the NthElementOp

  Returns:
    A list of two entries: the gradient wrt the input, and None for the
    second input (N).
  """
  values = op.inputs[0]
  nth_values = op.outputs[0]

  # 1 where an input element equals the selected output along the last
  # dimension, 0 elsewhere. When several elements match, the gradient is
  # split evenly between them.
  matches = math_ops.equal(array_ops.expand_dims(nth_values, -1), values)
  match_mask = math_ops.cast(matches, grad.dtype)

  grad_expanded = array_ops.expand_dims(grad, -1)
  match_count = math_ops.reduce_sum(match_mask, -1)
  match_count = array_ops.expand_dims(match_count, -1)

  return [math_ops.div(match_mask, match_count) * grad_expanded, None]
1114