1# Copyright 2015 The TensorFlow Authors. All Rights Reserved.
2#
3# Licensed under the Apache License, Version 2.0 (the "License");
4# you may not use this file except in compliance with the License.
5# You may obtain a copy of the License at
6#
7#     http://www.apache.org/licenses/LICENSE-2.0
8#
9# Unless required by applicable law or agreed to in writing, software
10# distributed under the License is distributed on an "AS IS" BASIS,
11# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
12# See the License for the specific language governing permissions and
13# limitations under the License.
14# ==============================================================================
15"""Gradients for operators defined in nn_ops.py."""
16
17from __future__ import absolute_import
18from __future__ import division
19from __future__ import print_function
20
21from tensorflow.python.eager import backprop
22from tensorflow.python.framework import dtypes
23from tensorflow.python.framework import ops
24from tensorflow.python.ops import array_ops
25from tensorflow.python.ops import gen_nn_ops
26from tensorflow.python.ops import math_ops
27from tensorflow.python.ops import nn_ops
28
29
@ops.RegisterGradient("Conv2DBackpropInput")
def _Conv2DBackpropInputGrad(op, grad):
  """Gradient function for Conv2DBackpropInput (deconvolution).

  Args:
    op: the Conv2DBackpropInput op being differentiated.
    grad: the tensor holding the gradient w.r.t. the op's output.

  Returns:
    A list with one entry per op input: None for the first input, the result
    of conv2d_backprop_filter for the second input, and the result of conv2d
    for the third input.
  """
  # The gen_nn_ops kernels are called directly rather than the nn_ops
  # wrappers for performance reasons in Eager mode. See _Conv2DGrad.
  conv_attrs = {
      "dilations": op.get_attr("dilations"),
      "strides": op.get_attr("strides"),
      "padding": op.get_attr("padding"),
      "explicit_paddings": op.get_attr("explicit_paddings"),
      "use_cudnn_on_gpu": op.get_attr("use_cudnn_on_gpu"),
      "data_format": op.get_attr("data_format").decode(),
  }
  kernel = op.inputs[1]
  kernel_grad = gen_nn_ops.conv2d_backprop_filter(
      grad, array_ops.shape(kernel), op.inputs[2], **conv_attrs)
  backprop_grad = gen_nn_ops.conv2d(grad, kernel, **conv_attrs)
  return [None, kernel_grad, backprop_grad]
65
66
@ops.RegisterGradient("Conv2DBackpropFilter")
def _Conv2DBackpropFilterGrad(op, grad):
  """Gradient function for Conv2DBackpropFilter.

  Args:
    op: the Conv2DBackpropFilter op being differentiated.
    grad: the tensor holding the gradient w.r.t. the op's output.

  Returns:
    A list with one entry per op input: the result of conv2d_backprop_input
    for the first input, None for the second input, and the result of conv2d
    for the third input.
  """
  # The gen_nn_ops kernels are called directly rather than the nn_ops
  # wrappers for performance reasons in Eager mode. See _Conv2DGrad.
  conv_attrs = {
      "dilations": op.get_attr("dilations"),
      "strides": op.get_attr("strides"),
      "padding": op.get_attr("padding"),
      "explicit_paddings": op.get_attr("explicit_paddings"),
      "use_cudnn_on_gpu": op.get_attr("use_cudnn_on_gpu"),
      "data_format": op.get_attr("data_format").decode(),
  }
  features = op.inputs[0]
  features_grad = gen_nn_ops.conv2d_backprop_input(
      array_ops.shape(features), grad, op.inputs[2], **conv_attrs)
  backprop_grad = gen_nn_ops.conv2d(features, grad, **conv_attrs)
  return [features_grad, None, backprop_grad]
92
93
@ops.RegisterGradient("DepthwiseConv2dNativeBackpropInput")
def _DepthwiseConv2dNativeBackpropInputGrad(op, grad):
  """Gradient function for DepthwiseConv2dNativeBackpropInput.

  Args:
    op: the DepthwiseConv2dNativeBackpropInput op being differentiated.
    grad: the tensor holding the gradient w.r.t. the op's output.

  Returns:
    A list with one entry per op input: None for the first input, the result
    of depthwise_conv2d_native_backprop_filter for the second input, and the
    result of depthwise_conv2d_native for the third input.
  """
  conv_attrs = {
      "dilations": op.get_attr("dilations"),
      "strides": op.get_attr("strides"),
      "padding": op.get_attr("padding"),
      "explicit_paddings": op.get_attr("explicit_paddings"),
      "data_format": op.get_attr("data_format"),
  }
  kernel = op.inputs[1]
  kernel_grad = gen_nn_ops.depthwise_conv2d_native_backprop_filter(
      grad, array_ops.shape(kernel), op.inputs[2], **conv_attrs)
  backprop_grad = gen_nn_ops.depthwise_conv2d_native(
      grad, kernel, **conv_attrs)
  return [None, kernel_grad, backprop_grad]
125
126
@ops.RegisterGradient("DepthwiseConv2dNativeBackpropFilter")
def _DepthwiseConv2dNativeBackpropFilterGrad(op, grad):
  """Gradient function for DepthwiseConv2dNativeBackpropFilter.

  Args:
    op: the DepthwiseConv2dNativeBackpropFilter op being differentiated.
    grad: the tensor holding the gradient w.r.t. the op's output.

  Returns:
    A list with one entry per op input: the result of
    depthwise_conv2d_native_backprop_input for the first input, None for the
    second input, and the result of depthwise_conv2d_native for the third
    input.
  """
  conv_attrs = {
      "dilations": op.get_attr("dilations"),
      "strides": op.get_attr("strides"),
      "padding": op.get_attr("padding"),
      "explicit_paddings": op.get_attr("explicit_paddings"),
      "data_format": op.get_attr("data_format"),
  }
  features = op.inputs[0]
  features_grad = gen_nn_ops.depthwise_conv2d_native_backprop_input(
      array_ops.shape(features), grad, op.inputs[2], **conv_attrs)
  backprop_grad = gen_nn_ops.depthwise_conv2d_native(
      features, grad, **conv_attrs)
  return [features_grad, None, backprop_grad]
148
149
@ops.RegisterGradient("Conv3D")
def _Conv3DGrad(op, grad):
  """Gradient function for Conv3D.

  Args:
    op: the Conv3D op being differentiated.
    grad: the tensor holding the gradient w.r.t. the op's output.

  Returns:
    A two-element list: the result of conv3d_backprop_input_v2 for the first
    input and the result of conv3d_backprop_filter_v2 for the second input.
  """
  conv_attrs = {
      "dilations": op.get_attr("dilations"),
      "strides": op.get_attr("strides"),
      "padding": op.get_attr("padding"),
      "data_format": op.get_attr("data_format").decode(),
  }
  features, kernel = op.inputs[0], op.inputs[1]
  features_grad = nn_ops.conv3d_backprop_input_v2(
      array_ops.shape(features), kernel, grad, **conv_attrs)
  kernel_grad = nn_ops.conv3d_backprop_filter_v2(
      features, array_ops.shape(kernel), grad, **conv_attrs)
  return [features_grad, kernel_grad]
171
172
@ops.RegisterGradient("Conv3DBackpropInputV2")
def _Conv3DBackpropInputGrad(op, grad):
  """Gradient function for Conv3DBackpropInputV2.

  Args:
    op: the Conv3DBackpropInputV2 op being differentiated.
    grad: the tensor holding the gradient w.r.t. the op's output.

  Returns:
    A list with one entry per op input: None for the first input, the result
    of conv3d_backprop_filter_v2 for the second input, and the result of
    conv3d for the third input.
  """
  conv_attrs = {
      "dilations": op.get_attr("dilations"),
      "strides": op.get_attr("strides"),
      "padding": op.get_attr("padding"),
      "data_format": op.get_attr("data_format").decode(),
  }
  kernel = op.inputs[1]
  kernel_grad = nn_ops.conv3d_backprop_filter_v2(
      grad, array_ops.shape(kernel), op.inputs[2], **conv_attrs)
  backprop_grad = nn_ops.conv3d(grad, kernel, **conv_attrs)
  return [None, kernel_grad, backprop_grad]
194
195
@ops.RegisterGradient("Conv3DBackpropFilterV2")
def _Conv3DBackpropFilterGrad(op, grad):
  """Gradient function for Conv3DBackpropFilterV2.

  Args:
    op: the Conv3DBackpropFilterV2 op being differentiated.
    grad: the tensor holding the gradient w.r.t. the op's output.

  Returns:
    A list with one entry per op input: the result of
    conv3d_backprop_input_v2 for the first input, None for the second input,
    and the result of conv3d for the third input.
  """
  conv_attrs = {
      "dilations": op.get_attr("dilations"),
      "strides": op.get_attr("strides"),
      "padding": op.get_attr("padding"),
      "data_format": op.get_attr("data_format").decode(),
  }
  features = op.inputs[0]
  features_grad = nn_ops.conv3d_backprop_input_v2(
      array_ops.shape(features), grad, op.inputs[2], **conv_attrs)
  backprop_grad = nn_ops.conv3d(features, grad, **conv_attrs)
  return [features_grad, None, backprop_grad]
216
217
@ops.RegisterGradient("AvgPool3D")
def _AvgPool3DGrad(op, grad):
  """Gradient function for AvgPool3D.

  Args:
    op: the AvgPool3D op being differentiated.
    grad: the tensor holding the gradient w.r.t. the op's output.

  Returns:
    The result of avg_pool3d_grad, called with the dynamic shape of the op's
    input and the op's pooling attributes.
  """
  pool_attrs = {
      "ksize": op.get_attr("ksize"),
      "strides": op.get_attr("strides"),
      "padding": op.get_attr("padding"),
      "data_format": op.get_attr("data_format").decode(),
  }
  input_shape = array_ops.shape(op.inputs[0])
  return gen_nn_ops.avg_pool3d_grad(input_shape, grad, **pool_attrs)
227
228
@ops.RegisterGradient("AvgPool3DGrad")
def _AvgPool3DGradGrad(op, grad):
  """Gradient function for AvgPool3DGrad.

  Args:
    op: the AvgPool3DGrad op being differentiated.
    grad: the tensor holding the gradient w.r.t. the op's output.

  Returns:
    A pair: the op's first input wrapped in stop_gradient, and the result of
    applying avg_pool3d to `grad` with the op's pooling attributes.
  """
  first_input_grad = array_ops.stop_gradient(op.inputs[0])
  pooled_grad = gen_nn_ops.avg_pool3d(
      grad,
      op.get_attr("ksize"),
      op.get_attr("strides"),
      op.get_attr("padding"),
      data_format=op.get_attr("data_format").decode())
  return (first_input_grad, pooled_grad)
238
239
@ops.RegisterGradient("MaxPool3D")
def _MaxPool3DGrad(op, grad):
  """Gradient function for MaxPool3D.

  Args:
    op: the MaxPool3D op being differentiated.
    grad: the tensor holding the gradient w.r.t. the op's output.

  Returns:
    The result of max_pool3d_grad, called with the op's input, the op's
    output, `grad`, and the op's pooling attributes.
  """
  pool_attrs = {
      "ksize": op.get_attr("ksize"),
      "strides": op.get_attr("strides"),
      "padding": op.get_attr("padding"),
      "data_format": op.get_attr("data_format").decode(),
  }
  return gen_nn_ops.max_pool3d_grad(
      op.inputs[0], op.outputs[0], grad, **pool_attrs)
250
251
@ops.RegisterGradient("MaxPool3DGrad")
def _MaxPool3DGradGrad(op, grad):
  """Gradient function for MaxPool3DGrad.

  Args:
    op: the MaxPool3DGrad op being differentiated.
    grad: the tensor holding the gradient w.r.t. the op's output.

  Returns:
    A triple: zeros shaped and typed like the op's first input, zeros shaped
    and typed like the op's second input, and the result of
    max_pool3d_grad_grad.
  """
  first_input, second_input = op.inputs[0], op.inputs[1]
  first_zeros = array_ops.zeros(
      shape=array_ops.shape(first_input), dtype=first_input.dtype)
  second_zeros = array_ops.zeros(
      shape=array_ops.shape(second_input), dtype=second_input.dtype)
  backprop_grad = gen_nn_ops.max_pool3d_grad_grad(
      first_input,
      second_input,
      grad,
      op.get_attr("ksize"),
      op.get_attr("strides"),
      padding=op.get_attr("padding"),
      data_format=op.get_attr("data_format").decode())
  return (first_zeros, second_zeros, backprop_grad)
266
267
@ops.RegisterGradient("MaxPool3DGradGrad")
def _MaxPool3DGradGradGrad(op, grad):
  """Gradient function for MaxPool3DGradGrad.

  Args:
    op: the MaxPool3DGradGrad op being differentiated.
    grad: the tensor holding the gradient w.r.t. the op's output.

  Returns:
    A triple: zeros shaped and typed like the op's first input, zeros shaped
    and typed like the op's second input, and the result of max_pool3d_grad.
  """
  first_input, second_input = op.inputs[0], op.inputs[1]
  first_zeros = array_ops.zeros(
      shape=array_ops.shape(first_input), dtype=first_input.dtype)
  second_zeros = array_ops.zeros(
      shape=array_ops.shape(second_input), dtype=second_input.dtype)
  backprop_grad = gen_nn_ops.max_pool3d_grad(
      first_input,
      second_input,
      grad,
      op.get_attr("ksize"),
      op.get_attr("strides"),
      padding=op.get_attr("padding"),
      data_format=op.get_attr("data_format").decode())
  return (first_zeros, second_zeros, backprop_grad)
282
283
@ops.RegisterGradient("Softmax")
def _SoftmaxGrad(op, grad_softmax):
  """Gradient function for the softmax nonlinearity.

  The Jacobian dsoftmax / dx = diag(softmax) - softmax * softmax' is a
  diagonal matrix minus a rank-one matrix, so the product with the incoming
  gradient can be written without materializing it:

    grad_x = (grad_softmax - sum(grad_softmax * softmax)) * softmax

  where the sum runs over the last axis and keeps that axis.

  Args:
     op: the Softmax op.
     grad_softmax: the tensor holding the gradient w.r.t. the softmax output.

  Returns:
     The gradient w.r.t. the input of the softmax.
  """
  probs = op.outputs[0]
  weighted_grad = grad_softmax * probs
  row_totals = math_ops.reduce_sum(weighted_grad, -1, keepdims=True)
  return (grad_softmax - row_totals) * probs
307
308
@ops.RegisterGradient("LogSoftmax")
def _LogSoftmaxGrad(op, grad):
  """Gradient function for log_softmax.

      log_softmax = input - log(sum(exp(input)))
      dlog_softmax/dinput = diag - softmax(input)

  The softmax is recovered by exponentiating the op's own output.

  Args:
    op: The LogSoftmax op.
    grad: The tensor holding the gradient w.r.t. the output.

  Returns:
    The gradient w.r.t. the input.
  """
  probs = math_ops.exp(op.outputs[0])
  grad_total = math_ops.reduce_sum(grad, -1, keepdims=True)
  return grad - grad_total * probs
325
326
@ops.RegisterGradient("BiasAdd")
def _BiasAddGrad(op, received_grad):
  """Return the gradients for the two inputs of BiasAdd.

  The gradient for the first input (the tensor the bias is added to) is the
  received gradient itself. The gradient for the second input (the bias) is
  computed by bias_add_grad from the received gradient and the op's
  `data_format` attribute.

  Args:
    op: The BiasAdd op for which gradients are generated.
    received_grad: Tensor. The gradient passed to the BiasAdd op.

  Returns:
    Two tensors, the first one for the "tensor" input of the BiasAdd op,
    the second one for the "bias" input of the BiasAdd op.
  """
  # When the attribute lookup raises ValueError, None is forwarded to
  # bias_add_grad in place of a data format.
  try:
    data_format = op.get_attr("data_format")
  except ValueError:
    data_format = None
  bias_grad = gen_nn_ops.bias_add_grad(
      out_backprop=received_grad, data_format=data_format)
  return (received_grad, bias_grad)
353
354
@ops.RegisterGradient("BiasAddGrad")
def _BiasAddGradGrad(op, received_grad):
  """Gradient function for the BiasAddGrad op.

  The received gradient is reshaped so that it occupies a single axis of the
  op input's rank (axis 1 for NCHW, the last axis otherwise) with every other
  axis of size one, and is then tiled along those other axes up to the op
  input's shape.

  Args:
    op: BiasAddGrad op for which gradients are calculated.
    received_grad: The gradient passed to the BiasAddGrad op.

  Returns:
    A single gradient Tensor for the input to BiasAddGrad (which
    is the gradient of the bias term in BiasAdd).
  """
  try:
    data_format = op.get_attr("data_format")
  except ValueError:
    data_format = None

  input_shape = array_ops.shape(op.inputs[0])
  bias_shape = array_ops.shape(received_grad)

  if data_format == b"NCHW":
    # The bias sits on axis 1: ones before and after it in the reshape, and
    # a multiplier of 1 for that axis when tiling.
    leading_ones = array_ops.ones_like(input_shape[:1])
    trailing_ones = array_ops.ones_like(input_shape[2:])
    expanded_shape = array_ops.concat(
        [leading_ones, bias_shape, trailing_ones], 0)
    tile_mults = array_ops.concat([input_shape[:1], [1], input_shape[2:]], 0)
  else:
    # The bias sits on the last axis.
    leading_ones = array_ops.ones_like(input_shape[:-1])
    expanded_shape = array_ops.concat([leading_ones, bias_shape], 0)
    tile_mults = array_ops.concat([input_shape[:-1], [1]], 0)

  reshaped_grad = array_ops.reshape(received_grad, expanded_shape)
  return array_ops.tile(reshaped_grad, tile_mults)
389
390
@ops.RegisterGradient("BiasAddV1")
def _BiasAddGradV1(unused_bias_op, received_grad):
  """Return the gradients for the two inputs of BiasAddV1.

  The gradient for the first input (the tensor the bias is added to) is the
  received gradient itself. The gradient for the second input (the bias) is
  the received gradient summed over every axis except the last one.

  Args:
    unused_bias_op: The BiasAddV1 op for which gradients are generated. It is
      not read by this function.
    received_grad: Tensor. The gradient passed to the op.

  Returns:
    Two tensors, the first one for the "tensor" input of the op,
    the second one for the "bias" input of the op.
  """
  leading_axis_count = array_ops.rank(received_grad) - 1
  leading_axes = math_ops.range(leading_axis_count)
  bias_grad = math_ops.reduce_sum(received_grad, leading_axes)
  return (received_grad, bias_grad)
413
414
@ops.RegisterGradient("Relu")
def _ReluGrad(op, grad):
  """Gradient function for Relu: relu_grad applied to the op's output."""
  activations = op.outputs[0]
  return gen_nn_ops.relu_grad(grad, activations)
418
419
@ops.RegisterGradient("EluGrad")
def _EluGradGrad(op, grad):
  """Gradient function for EluGrad.

  Args:
    op: the EluGrad op being differentiated.
    grad: the tensor holding the gradient w.r.t. the op's output.

  Returns:
    A pair: elu_grad of `grad` and the op's second input, and a tensor equal
    to `grad * op.inputs[0]` where the second input is negative and zero
    elsewhere.
  """
  second_input = op.inputs[1]
  first_grad = gen_nn_ops.elu_grad(grad, second_input)
  is_negative = second_input < 0
  scaled_grad = grad * op.inputs[0]
  second_grad = array_ops.where(
      is_negative, scaled_grad, array_ops.zeros_like(second_input))
  return (first_grad, second_grad)
426
427
@ops.RegisterGradient("SeluGrad")
def _SeluGradGrad(op, grad):
  """Gradient function for SeluGrad.

  Args:
    op: the SeluGrad op being differentiated.
    grad: the tensor holding the gradient w.r.t. the op's output.

  Returns:
    A pair: selu_grad of `grad` and the op's second input, and a tensor equal
    to `grad * op.inputs[0]` where the second input is negative and zero
    elsewhere.
  """
  second_input = op.inputs[1]
  first_grad = gen_nn_ops.selu_grad(grad, second_input)
  is_negative = second_input < 0.
  scaled_grad = grad * op.inputs[0]
  second_grad = array_ops.where(
      is_negative, scaled_grad, array_ops.zeros_like(second_input))
  return (first_grad, second_grad)
434
435
@ops.RegisterGradient("Relu6")
def _Relu6Grad(op, grad):
  """Gradient function for Relu6: relu6_grad applied to the op's output."""
  activations = op.outputs[0]
  return gen_nn_ops.relu6_grad(grad, activations)
439
440
@ops.RegisterGradient("Relu6Grad")
def _Relu6GradGrad(op, grad):
  """Gradient function for Relu6Grad.

  Args:
    op: the Relu6Grad op being differentiated.
    grad: the tensor holding the gradient w.r.t. the op's output.

  Returns:
    A pair: relu6_grad of `grad` and the op's second input, and zeros shaped
    and typed like that second input.
  """
  second_input = op.inputs[1]
  first_grad = gen_nn_ops.relu6_grad(grad, second_input)
  second_grad = array_ops.zeros(
      shape=array_ops.shape(second_input), dtype=second_input.dtype)
  return (first_grad, second_grad)
446
447
@ops.RegisterGradient("LeakyRelu")
def _LeakyReluGrad(op, grad):
  """Gradient function for LeakyRelu.

  Calls leaky_relu_grad with `grad`, the op's input and the op's `alpha`
  attribute.
  """
  return gen_nn_ops.leaky_relu_grad(
      grad, op.inputs[0], alpha=op.get_attr("alpha"))
453
454
@ops.RegisterGradient("LeakyReluGrad")
def _LeakyReluGradGrad(op, grad):
  """Gradient function for LeakyReluGrad.

  Args:
    op: the LeakyReluGrad op being differentiated.
    grad: the tensor holding the gradient w.r.t. the op's output.

  Returns:
    A pair: leaky_relu_grad of `grad` and the op's second input (using the
    op's `alpha` attribute), and zeros shaped and typed like that second
    input.
  """
  second_input = op.inputs[1]
  first_grad = gen_nn_ops.leaky_relu_grad(
      grad, second_input, alpha=op.get_attr("alpha"))
  second_grad = array_ops.zeros(
      shape=array_ops.shape(second_input), dtype=second_input.dtype)
  return (first_grad, second_grad)
461
462
@ops.RegisterGradient("Elu")
def _EluGrad(op, grad):
  """Gradient function for Elu: elu_grad applied to the op's output."""
  activations = op.outputs[0]
  return gen_nn_ops.elu_grad(grad, activations)
466
467
@ops.RegisterGradient("Selu")
def _SeluGrad(op, grad):
  """Gradient function for Selu: selu_grad applied to the op's output."""
  activations = op.outputs[0]
  return gen_nn_ops.selu_grad(grad, activations)
471
472
@ops.RegisterGradient("Softplus")
def _SoftplusGrad(op, grad):
  """Gradient function for Softplus: `grad` scaled by sigmoid of the input."""
  input_sigmoid = math_ops.sigmoid(op.inputs[0])
  return grad * input_sigmoid
476
477
@ops.RegisterGradient("SoftplusGrad")
def _SoftplusGradGrad(op, grad):
  """Gradient function for SoftplusGrad.

  With
    y = tf.nn.softplus(x)
    dx = gen_nn_ops.softplus_grad(dy, x) = dy / (1 + exp(-x))
  this computes (ddy, d2x) from op.inputs == [dy, x] and grad == ddx.

  Args:
    op: the SoftplusGrad op being differentiated.
    grad: the tensor holding the gradient w.r.t. the op's output (ddx).

  Returns:
    The pair (ddy, d2x), the gradients for the op's two inputs.
  """
  upstream, features = op.inputs
  with ops.control_dependencies([grad]):
    ddy = gen_nn_ops.softplus_grad(grad, features)
    numerator = grad * upstream
    denominator = math_ops.exp(-features) + 2.0 + math_ops.exp(features)
    d2x = numerator / denominator
    return (ddy, d2x)
489
490
@ops.RegisterGradient("Softsign")
def _SoftsignGrad(op, grad):
  """Gradient function for Softsign: softsign_grad applied to the input."""
  features = op.inputs[0]
  return gen_nn_ops.softsign_grad(grad, features)
494
495
@ops.RegisterGradient("ReluGrad")
def _ReluGradGrad(op, grad):
  """Gradient function for ReluGrad.

  Args:
    op: the ReluGrad op being differentiated.
    grad: the tensor holding the gradient w.r.t. the op's output.

  Returns:
    A pair: relu_grad of `grad` and the op's second input, and zeros shaped
    and typed like that second input.
  """
  second_input = op.inputs[1]
  first_grad = gen_nn_ops.relu_grad(grad, second_input)
  second_grad = array_ops.zeros(
      shape=array_ops.shape(second_input), dtype=second_input.dtype)
  return (first_grad, second_grad)
501
502
def _BroadcastMul(vec, mat):
  """Multiply `mat` by `vec` after giving `vec` a trailing unit axis.

  Args:
    vec: A 1-D tensor of dimension [D0]
    mat: A 2-D tensor of dimension [D0, D1]

  Returns:
    A tensor of dimension [D0, D1], the result of vec * mat
  """
  # [D0] -> [D0, 1], so the product broadcasts across the columns of mat.
  column = array_ops.expand_dims(vec, -1)
  return column * mat
516
517
@ops.RegisterGradient("SoftmaxCrossEntropyWithLogits")
def _SoftmaxCrossEntropyWithLogitsGrad(op, grad_loss, grad_grad):
  """Gradient function for SoftmaxCrossEntropyWithLogits.

  Args:
    op: the SoftmaxCrossEntropyWithLogits op.
    grad_loss: the backprop for the cost; it is multiplied with the op's
      second output (the gradients computed in the forward pass).
    grad_grad: the backprop for the softmax gradient. May be None or a
      tensor flagged with `_is_zeros_tensor`, in which case the second-order
      term is skipped.

  Returns:
    A pair: the gradient for the op's first input (the logits), and the
    gradient for the op's second input, computed as `grad_loss` broadcast
    against `-log_softmax(logits)`.
  """
  logits = op.inputs[0]
  grad = _BroadcastMul(grad_loss, op.outputs[1])

  needs_second_order = (
      grad_grad is not None and
      not getattr(grad_grad, "_is_zeros_tensor", False))
  if needs_second_order:
    # The second derivative is just the softmax derivative w.r.t. logits.
    softmax = nn_ops.softmax(logits)
    # Batched matmul of grad_grad with softmax, squeezed back along axis 1.
    projected = array_ops.squeeze(
        math_ops.matmul(
            array_ops.expand_dims(grad_grad, 1),
            array_ops.expand_dims(softmax, 2)),
        axis=1)
    grad += (grad_grad - projected) * softmax

  labels_grad = _BroadcastMul(grad_loss, -nn_ops.log_softmax(logits))
  return grad, labels_grad
541
542
@ops.RegisterGradient("SparseSoftmaxCrossEntropyWithLogits")
def _SparseSoftmaxCrossEntropyWithLogitsGrad(op, grad_loss, grad_grad):
  """Gradient function for SparseSoftmaxCrossEntropyWithLogits.

  Args:
    op: the SparseSoftmaxCrossEntropyWithLogits op.
    grad_loss: the backprop for the cost; it is multiplied with the op's
      second output (the gradients computed in the forward pass).
    grad_grad: the backprop for the softmax gradient. May be None or a
      tensor flagged with `_is_zeros_tensor`, in which case the second-order
      term is skipped.

  Returns:
    A pair: the gradient for the op's first input (the logits), and None
    because there is no gradient for the labels.
  """
  logits = op.inputs[0]
  grad = _BroadcastMul(grad_loss, op.outputs[1])

  needs_second_order = (
      grad_grad is not None and
      not getattr(grad_grad, "_is_zeros_tensor", False))
  if needs_second_order:
    # The second derivative is just the softmax derivative w.r.t. logits.
    softmax = nn_ops.softmax(logits)
    # Batched matmul of grad_grad with softmax, squeezed back along axis 1.
    projected = array_ops.squeeze(
        math_ops.matmul(
            array_ops.expand_dims(grad_grad, 1),
            array_ops.expand_dims(softmax, 2)),
        axis=1)
    grad += (grad_grad - projected) * softmax

  return grad, None
567
568
@ops.RegisterGradient("Conv2D")
def _Conv2DGrad(op, grad):
  """Gradient function for Conv2D.

  Args:
    op: the Conv2D op being differentiated.
    grad: the tensor holding the gradient w.r.t. the op's output.

  Returns:
    A two-element list: the result of conv2d_backprop_input for the first
    input and the result of conv2d_backprop_filter for the second input.
  """
  conv_attrs = {
      "dilations": op.get_attr("dilations"),
      "strides": op.get_attr("strides"),
      "padding": op.get_attr("padding"),
      "explicit_paddings": op.get_attr("explicit_paddings"),
      "use_cudnn_on_gpu": op.get_attr("use_cudnn_on_gpu"),
      "data_format": op.get_attr("data_format"),
  }
  features, kernel = op.inputs[0], op.inputs[1]
  features_shape, kernel_shape = array_ops.shape_n([features, kernel])

  # The gen_nn_ops backprop functions are used instead of the nn_ops ones for
  # performance reasons in Eager mode. gen_nn_ops functions take an
  # `explicit_paddings` parameter, but nn_ops functions do not, so using the
  # nn_ops functions would require folding `padding` and `explicit_paddings`
  # into a single `padding` parameter, increasing overhead in Eager mode.
  features_grad = gen_nn_ops.conv2d_backprop_input(
      features_shape, kernel, grad, **conv_attrs)
  kernel_grad = gen_nn_ops.conv2d_backprop_filter(
      features, kernel_shape, grad, **conv_attrs)
  return [features_grad, kernel_grad]
608
609
@ops.RegisterGradient("DepthwiseConv2dNative")
def _DepthwiseConv2dNativeGrad(op, grad):
  """Gradient function for DepthwiseConv2dNative.

  Args:
    op: the DepthwiseConv2dNative op being differentiated.
    grad: the tensor holding the gradient w.r.t. the op's output.

  Returns:
    A two-element list: the result of depthwise_conv2d_native_backprop_input
    for the first input and the result of
    depthwise_conv2d_native_backprop_filter for the second input.
  """
  conv_attrs = {
      "dilations": op.get_attr("dilations"),
      "strides": op.get_attr("strides"),
      "padding": op.get_attr("padding"),
      "explicit_paddings": op.get_attr("explicit_paddings"),
      "data_format": op.get_attr("data_format"),
  }
  features, kernel = op.inputs[0], op.inputs[1]
  features_grad = gen_nn_ops.depthwise_conv2d_native_backprop_input(
      array_ops.shape(features), kernel, grad, **conv_attrs)
  kernel_grad = gen_nn_ops.depthwise_conv2d_native_backprop_filter(
      features, array_ops.shape(kernel), grad, **conv_attrs)
  return [features_grad, kernel_grad]
632
633
@ops.RegisterGradient("Dilation2D")
def _Dilation2DGrad(op, grad):
  """Gradient function for Dilation2D.

  Args:
    op: the Dilation2D op being differentiated.
    grad: the tensor holding the gradient w.r.t. the op's output.

  Returns:
    A two-element list: the result of dilation2d_backprop_input for the first
    input and the result of dilation2d_backprop_filter for the second input.
  """
  # Both backprop kernels take the same positional argument list.
  shared_args = (op.inputs[0], op.inputs[1], grad, op.get_attr("strides"),
                 op.get_attr("rates"), op.get_attr("padding"))
  first_grad = nn_ops.dilation2d_backprop_input(*shared_args)
  second_grad = nn_ops.dilation2d_backprop_filter(*shared_args)
  return [first_grad, second_grad]
646
647
@ops.RegisterGradient("LRN")
def _LRNGrad(op, grad):
  """Gradient function for LRN.

  Args:
    op: the LRN op being differentiated.
    grad: the tensor holding the gradient w.r.t. the op's output.

  Returns:
    A one-element list holding the result of lrn_grad, called with `grad`,
    the op's input, the op's output and the op's four attributes.
  """
  attr_values = [
      op.get_attr(attr_name)
      for attr_name in ("depth_radius", "bias", "alpha", "beta")
  ]
  input_grad = gen_nn_ops.lrn_grad(grad, op.inputs[0], op.outputs[0],
                                   *attr_values)
  return [input_grad]
658
659
@ops.RegisterGradient("AvgPool")
def _AvgPoolGrad(op, grad):
  """Gradient function for AvgPool.

  Args:
    op: the AvgPool op being differentiated.
    grad: the tensor holding the gradient w.r.t. the op's output.

  Returns:
    The result of avg_pool_grad, called with the dynamic shape of the op's
    input, `grad`, and the op's pooling attributes.
  """
  input_shape = array_ops.shape(op.inputs[0])
  ksize = op.get_attr("ksize")
  strides = op.get_attr("strides")
  padding = op.get_attr("padding")
  return gen_nn_ops.avg_pool_grad(
      input_shape, grad, ksize, strides, padding,
      data_format=op.get_attr("data_format"))
669
670
@ops.RegisterGradient("AvgPoolGrad")
def _AvgPoolGradGrad(op, grad):
  """Gradient function for AvgPoolGrad.

  Args:
    op: the AvgPoolGrad op being differentiated.
    grad: the tensor holding the gradient w.r.t. the op's output.

  Returns:
    A pair: the op's first input wrapped in stop_gradient, and the result of
    applying avg_pool to `grad` with the op's pooling attributes.
  """
  first_input_grad = array_ops.stop_gradient(op.inputs[0])
  pooled_grad = gen_nn_ops.avg_pool(
      grad,
      op.get_attr("ksize"),
      op.get_attr("strides"),
      op.get_attr("padding"),
      data_format=op.get_attr("data_format"))
  return (first_input_grad, pooled_grad)
680
681
@ops.RegisterGradient("MaxPool")
def _MaxPoolGrad(op, grad):
  """Gradient function for MaxPool.

  Args:
    op: the MaxPool op being differentiated.
    grad: the tensor holding the gradient w.r.t. the op's output.

  Returns:
    The result of max_pool_grad, called with the op's input, the op's
    output, `grad`, and the op's pooling attributes.
  """
  ksize = op.get_attr("ksize")
  strides = op.get_attr("strides")
  keyword_attrs = {
      "padding": op.get_attr("padding"),
      "explicit_paddings": op.get_attr("explicit_paddings"),
      "data_format": op.get_attr("data_format"),
  }
  return gen_nn_ops.max_pool_grad(
      op.inputs[0], op.outputs[0], grad, ksize, strides, **keyword_attrs)
693
694
@ops.RegisterGradient("MaxPoolV2")
def _MaxPoolGradV2(op, grad):
  """Gradient function for MaxPoolV2.

  Args:
    op: the MaxPoolV2 op being differentiated; its second and third inputs
      are forwarded to the gradient kernel as ksize and strides.
    grad: the tensor holding the gradient w.r.t. the op's output.

  Returns:
    A triple: the result of max_pool_grad_v2 for the first input, and None
    for the second and third inputs.
  """
  pooled_input, ksize, strides = op.inputs[0], op.inputs[1], op.inputs[2]
  input_grad = gen_nn_ops.max_pool_grad_v2(
      pooled_input,
      op.outputs[0],
      grad,
      ksize,
      strides,
      padding=op.get_attr("padding"),
      data_format=op.get_attr("data_format"))
  return input_grad, None, None
707
708
@ops.RegisterGradient("MaxPoolWithArgmax")
def _MaxPoolGradWithArgmax(op, grad, unused_argmax_grad):
  """Gradient function for MaxPoolWithArgmax.

  Args:
    op: the MaxPoolWithArgmax op being differentiated.
    grad: the tensor holding the gradient w.r.t. the op's first output.
    unused_argmax_grad: the gradient w.r.t. the op's second output; it is
      discarded.

  Returns:
    The result of max_pool_grad_with_argmax, called with the op's input,
    `grad`, the op's second output, and the op's pooling attributes.
  """
  del unused_argmax_grad
  argmax = op.outputs[1]
  ksize = op.get_attr("ksize")
  strides = op.get_attr("strides")
  return gen_nn_ops.max_pool_grad_with_argmax(
      op.inputs[0], grad, argmax, ksize, strides,
      padding=op.get_attr("padding"),
      include_batch_in_index=op.get_attr("include_batch_in_index"))
720
721
@ops.RegisterGradient("MaxPoolGrad")
def _MaxPoolGradGrad(op, grad):
  """Gradient function for MaxPoolGrad.

  Args:
    op: the MaxPoolGrad op being differentiated.
    grad: the tensor holding the gradient w.r.t. the op's output.

  Returns:
    A triple: zeros shaped and typed like the op's first input, zeros shaped
    and typed like the op's second input, and the result of
    max_pool_grad_grad.
  """
  first_input, second_input = op.inputs[0], op.inputs[1]
  first_zeros = array_ops.zeros(
      shape=array_ops.shape(first_input), dtype=first_input.dtype)
  second_zeros = array_ops.zeros(
      shape=array_ops.shape(second_input), dtype=second_input.dtype)
  backprop_grad = gen_nn_ops.max_pool_grad_grad(
      first_input,
      second_input,
      grad,
      op.get_attr("ksize"),
      op.get_attr("strides"),
      padding=op.get_attr("padding"),
      data_format=op.get_attr("data_format"))
  return (first_zeros, second_zeros, backprop_grad)
736
737
@ops.RegisterGradient("MaxPoolGradV2")
def _MaxPoolGradGradV2(op, grad):
  """Gradient function for MaxPoolGradV2.

  Args:
    op: the MaxPoolGradV2 op being differentiated; its fourth and fifth
      inputs are forwarded to the gradient kernel as ksize and strides.
    grad: the tensor holding the gradient w.r.t. the op's output.

  Returns:
    A 5-tuple: zeros shaped and typed like the op's first input, zeros
    shaped and typed like the op's second input, the result of
    max_pool_grad_grad_v2, and None for the fourth and fifth inputs.
  """
  first_input, second_input = op.inputs[0], op.inputs[1]
  ksize, strides = op.inputs[3], op.inputs[4]
  first_zeros = array_ops.zeros(
      shape=array_ops.shape(first_input), dtype=first_input.dtype)
  second_zeros = array_ops.zeros(
      shape=array_ops.shape(second_input), dtype=second_input.dtype)
  backprop_grad = gen_nn_ops.max_pool_grad_grad_v2(
      first_input,
      second_input,
      grad,
      ksize,
      strides,
      padding=op.get_attr("padding"),
      data_format=op.get_attr("data_format"))
  return (first_zeros, second_zeros, backprop_grad, None, None)
754
755
@ops.RegisterGradient("MaxPoolGradGrad")
def _MaxPoolGradGradGrad(op, grad):
  """Gradient function for MaxPoolGradGrad.

  Args:
    op: the MaxPoolGradGrad op being differentiated.
    grad: the tensor holding the gradient w.r.t. the op's output.

  Returns:
    A triple: zeros shaped and typed like the op's first input, zeros shaped
    and typed like the op's second input, and the result of max_pool_grad.
  """
  first_input, second_input = op.inputs[0], op.inputs[1]
  first_zeros = array_ops.zeros(
      shape=array_ops.shape(first_input), dtype=first_input.dtype)
  second_zeros = array_ops.zeros(
      shape=array_ops.shape(second_input), dtype=second_input.dtype)
  backprop_grad = gen_nn_ops.max_pool_grad(
      first_input,
      second_input,
      grad,
      op.get_attr("ksize"),
      op.get_attr("strides"),
      padding=op.get_attr("padding"),
      data_format=op.get_attr("data_format"))
  return (first_zeros, second_zeros, backprop_grad)
770
771
@ops.RegisterGradient("FractionalMaxPool")
def _FractionalMaxPoolGrad(op, grad_0, unused_grad_1, unused_grad_2):
  """Returns gradient for FractionalMaxPool.

  FractionalMaxPool has three outputs, so three gradients are passed in, one
  per output. Only the first one is used; the other two are empty.

  Args:
    op: The FractionalMaxPoolOp.
    grad_0: Gradient with respect to op.outputs[0]
    unused_grad_1: Gradient with respect to op.outputs[1]/row_seq. It is empty.
    unused_grad_2: Gradient with respect to op.outputs[2]/col_seq. It is empty.

  Returns:
    Input backprop for FractionalMaxPool op.
  """
  pooled, row_seq, col_seq = op.outputs[0], op.outputs[1], op.outputs[2]
  overlapping = op.get_attr("overlapping")
  return gen_nn_ops.fractional_max_pool_grad(
      op.inputs[0], pooled, grad_0, row_seq, col_seq, overlapping)
792
793
@ops.RegisterGradient("FractionalAvgPool")
def _FractionalAvgPoolGrad(op, grad_0, unused_grad_1, unused_grad_2):
  """Returns gradient for FractionalAvgPool.

  FractionalAvgPool has three outputs, so three gradients are passed in, one
  per output. Only the first one is used; the other two are empty.

  The gradient kernel needs the shape of the original input. The static
  shape is used when it is fully defined; otherwise the shape is read at
  runtime, so inputs with unknown dimensions are supported as well.

  Args:
    op: The FractionalAvgPoolOp.
    grad_0: Gradient with respect to op.outputs[0]
    unused_grad_1: Gradient with respect to op.outputs[1]/row_seq. It is empty.
    unused_grad_2: Gradient with respect to op.outputs[2]/col_seq. It is empty.

  Returns:
    Input backprop for FractionalAvgPool op.
  """
  orig_input = op.inputs[0]
  orig_input_shape = orig_input.get_shape()
  if not orig_input_shape.is_fully_defined():
    # A partially known TensorShape cannot be converted to a tensor, so fall
    # back to the dynamic shape. int64 matches the dtype the gradient kernel
    # declares for its shape argument.
    orig_input_shape = array_ops.shape(orig_input, out_type=dtypes.int64)
  return gen_nn_ops.fractional_avg_pool_grad(orig_input_shape, grad_0,
                                             op.outputs[1], op.outputs[2],
                                             op.get_attr("overlapping"))
814
815
@ops.RegisterGradient("BatchNormWithGlobalNormalization")
def _BatchNormWithGlobalNormalizationGrad(op, grad):
  """Return the gradients for the 5 inputs of BatchNormWithGlobalNormalization.

  All five gradients come from a single call to
  batch_norm_with_global_normalization_grad, which receives the op's first
  three inputs, its fifth input, `grad`, and the op's `variance_epsilon` and
  `scale_after_normalization` attributes. The op's fourth input is not passed
  to the kernel.

  Args:
    op: The BatchNormOp for which we need to generate gradients.
    grad: Tensor.  The gradients passed to the BatchNormOp.

  Returns:
    The five kernel outputs, in kernel order, as a tuple (dx, dm, dv, db, dg):
    backprop for the input, mean, variance, beta and gamma respectively.
  """
  first, second, third, fifth = (op.inputs[0], op.inputs[1], op.inputs[2],
                                 op.inputs[4])
  variance_epsilon = op.get_attr("variance_epsilon")
  scale_after_normalization = op.get_attr("scale_after_normalization")
  dx, dm, dv, db, dg = gen_nn_ops.batch_norm_with_global_normalization_grad(
      first, second, third, fifth, grad, variance_epsilon,
      scale_after_normalization)
  return dx, dm, dv, db, dg
841
842
def _BaseFusedBatchNormGrad(op, version, *grad):
  """Return the gradients for the 3 inputs of BatchNorm.

  In training mode the gradient kernel is called with the op's own data
  format and the reserve-space outputs of the forward op. In freeze
  (non-training) mode the kernel is always called in a channels-last layout
  with the population mean and variance, so channels-first tensors are
  transposed before the call and the input gradient is transposed back.

  Args:
    op: The BatchNormOp for which we need to compute gradients.
    version: Integer indicating which version to use of the fused batch
      norm gradient: 2 selects fused_batch_norm_grad_v3, 1 selects
      fused_batch_norm_grad_v2, anything else selects fused_batch_norm_grad.
    *grad: An argument list for tensors of gradients wrt the outputs
      with grad[0] as grad_y.

  Returns:
    A 5-tuple (grad_x, grad_scale, grad_offset, None, None).

    grad_x: gradient for x, which is scale * rsqrt(variance + epsilon) *
            [grad_y - mean(grad_y) - (x - mean(x)) *
            mean(grad_y * (x - mean(x))) / (variance + epsilon)]
            in training mode; grad_y * scale * rsqrt(pop_variance + epsilon)
            in freeze mode.

    grad_scale: gradient for scale, which is sum(grad_y * (x - mean(x)) *
                rsqrt(variance + epsilon)) in training mode;
                sum(grad_y * (x - pop_mean) * rsqrt(pop_variance + epsilon))
                in freeze mode.

    grad_offset: gradient for offset, which is sum(grad_y) in training mode;
                 sum(grad_y) in freeze mode.
  """
  grad_y = grad[0]
  x = op.inputs[0]
  scale = op.inputs[1]
  epsilon = op.get_attr("epsilon")
  data_format = op.get_attr("data_format")
  is_training = op.get_attr("is_training")

  uses_v3 = version == 2
  if uses_v3:
    grad_fun = gen_nn_ops.fused_batch_norm_grad_v3
  else:
    grad_fun = (
        gen_nn_ops.fused_batch_norm_grad_v2
        if version == 1 else gen_nn_ops.fused_batch_norm_grad)

  if is_training:
    kernel_args = dict(
        y_backprop=grad_y,
        x=x,
        scale=scale,
        reserve_space_1=op.outputs[3],
        reserve_space_2=op.outputs[4],
        epsilon=epsilon,
        data_format=data_format,
        is_training=is_training)
    if uses_v3:
      kernel_args["reserve_space_3"] = op.outputs[5]
    dx, dscale, doffset, _, _ = grad_fun(**kernel_args)
    return dx, dscale, doffset, None, None

  # Freeze mode. Pick the permutations that move the channel axis to the end
  # and back again; channels-last formats need no transposition.
  if data_format == b"NCHW":
    to_channels_last, to_channels_first = [0, 2, 3, 1], [0, 3, 1, 2]
  elif data_format == b"NCDHW":
    to_channels_last, to_channels_first = [0, 2, 3, 4, 1], [0, 4, 1, 2, 3]
  else:
    to_channels_last = to_channels_first = None

  if to_channels_last is not None:
    x = array_ops.transpose(x, to_channels_last)
    grad_y = array_ops.transpose(grad_y, to_channels_last)

  if data_format in (b"NCHW", b"NHWC"):
    target_data_format = "NHWC"
  else:
    target_data_format = "NDHWC"

  kernel_args = dict(
      y_backprop=grad_y,
      x=x,
      scale=scale,
      reserve_space_1=op.inputs[3],  # population mean
      reserve_space_2=op.inputs[4],  # population variance
      epsilon=epsilon,
      data_format=target_data_format,
      is_training=is_training)
  if uses_v3:
    kernel_args["reserve_space_3"] = op.outputs[5]
  dx, dscale, doffset, _, _ = grad_fun(**kernel_args)

  if to_channels_first is not None:
    dx = array_ops.transpose(dx, to_channels_first)
  return dx, dscale, doffset, None, None
923
924
@ops.RegisterGradient("FusedBatchNorm")
def _FusedBatchNormGrad(op, *grad):
  """Gradient for FusedBatchNorm; delegates to the shared helper (version 0)."""
  kernel_version = 0
  return _BaseFusedBatchNormGrad(op, kernel_version, *grad)
928
929
@ops.RegisterGradient("FusedBatchNormV2")
def _FusedBatchNormV2Grad(op, *grad):
  """Gradient for FusedBatchNormV2; delegates to the shared helper (version 1)."""
  kernel_version = 1
  return _BaseFusedBatchNormGrad(op, kernel_version, *grad)
933
934
@ops.RegisterGradient("FusedBatchNormV3")
def _FusedBatchNormV3Grad(op, *grad):
  """Gradient for FusedBatchNormV3; delegates to the shared helper (version 2)."""
  kernel_version = 2
  return _BaseFusedBatchNormGrad(op, kernel_version, *grad)
938
939
def _BatchNormGrad(grad_y,
                   x,
                   scale,
                   pop_mean,
                   pop_var,
                   epsilon,
                   data_format,
                   is_training=True):
  """Returns the gradients for the 3 inputs of BatchNorm.

  This is a composition of differentiable ops, used to differentiate through
  the batch norm gradient itself.

  Args:
    grad_y: A `Tensor` of 4 or 5 dimensions for gradient for y.
    x: A `Tensor` of 4 or 5 dimensions for x.
    scale: A `Tensor` of 1 dimension for scaling.
    pop_mean: A `Tensor` of 1 dimension for the population mean. Only used when
      is_training=False.
    pop_var: A `Tensor` of 1 dimension for the population variance. Only used
      when is_training=False.
    epsilon: A small float number added to the variance of x.
    data_format: The data format for input. One of b"NHWC", b"NDHWC", b"NCHW"
      or b"NCDHW"; any other value uses the b"NCDHW" reduction axes.
    is_training: A bool value to indicate the operation is for training
      (default) or inference.

  Returns:
    A tuple (grad_x, grad_scale, grad_offset), where grad_x is the gradient
    for x, grad_scale the gradient for scale, and grad_offset the gradient
    for offset. grad_x is cast back to the dtype of x; grad_scale and
    grad_offset are left in the computation dtype (float32 when x is float16).
    For the four listed data formats grad_scale and grad_offset have one entry
    per channel.
  """
  x_dtype = x.dtype.base_dtype
  if x_dtype == dtypes.float16:
    # float16 math is too imprecise, so we do the batch norm gradient
    # computations in float32.
    x = math_ops.cast(x, dtypes.float32)
    grad_y = math_ops.cast(grad_y, dtypes.float32)
  if is_training:
    # Channels-last formats broadcast against the 1-D scale directly, so the
    # reductions can drop dimensions. Channels-first formats keep the reduced
    # dimensions and reshape scale so that everything broadcasts along axis 1.
    if data_format == b"NHWC":
      keepdims = False
      reduce_axis = [0, 1, 2]
    elif data_format == b"NDHWC":
      keepdims = False
      reduce_axis = [0, 1, 2, 3]
    elif data_format == b"NCHW":
      keepdims = True
      reduce_axis = [0, 2, 3]
      shape = [1, array_ops.size(scale), 1, 1]
      scale = array_ops.reshape(scale, shape)
    else:
      keepdims = True
      reduce_axis = [0, 2, 3, 4]
      shape = [1, array_ops.size(scale), 1, 1, 1]
      scale = array_ops.reshape(scale, shape)
    mean_grad_y = math_ops.reduce_mean(grad_y, reduce_axis, keepdims=keepdims)
    mean_x = math_ops.reduce_mean(x, reduce_axis, keepdims=keepdims)
    var_x = math_ops.reduce_mean(
        math_ops.squared_difference(x, array_ops.stop_gradient(mean_x)),
        reduce_axis,
        keepdims=keepdims)
    grad_y_offset = grad_y - mean_grad_y
    x_offset = x - mean_x
    mean = math_ops.reduce_mean(
        grad_y * x_offset, axis=reduce_axis, keepdims=keepdims)
    grad_x = scale * math_ops.rsqrt(var_x + epsilon) * (
        grad_y_offset - math_ops.reciprocal(var_x + epsilon) * mean * x_offset)
    grad_scale = math_ops.rsqrt(var_x + epsilon) * math_ops.reduce_sum(
        grad_y * x_offset, axis=reduce_axis, keepdims=keepdims)
    if data_format == b"NCHW" or data_format == b"NCDHW":
      # Squeeze only the reduced axes. An unrestricted squeeze would also drop
      # the channel axis when there is a single channel, producing a scalar
      # instead of a length-1 vector.
      grad_scale = array_ops.squeeze(grad_scale, axis=reduce_axis)
    grad_offset = math_ops.reduce_sum(grad_y, axis=reduce_axis)
    return math_ops.cast(grad_x, x_dtype), grad_scale, grad_offset
  else:
    # Inference mode: the statistics are constants, so only the reduction axes
    # and (for channels-first) the broadcast shapes depend on the layout.
    if data_format == b"NHWC":
      reduce_axis = [0, 1, 2]
    elif data_format == b"NDHWC":
      reduce_axis = [0, 1, 2, 3]
    elif data_format == b"NCHW":
      reduce_axis = [0, 2, 3]
      shape = [1, array_ops.size(pop_mean), 1, 1]
      pop_mean = array_ops.reshape(pop_mean, shape)
      pop_var = array_ops.reshape(pop_var, shape)
      scale = array_ops.reshape(scale, shape)
    else:
      reduce_axis = [0, 2, 3, 4]
      shape = [1, array_ops.size(pop_mean), 1, 1, 1]
      pop_mean = array_ops.reshape(pop_mean, shape)
      pop_var = array_ops.reshape(pop_var, shape)
      scale = array_ops.reshape(scale, shape)

    grad_offset = math_ops.reduce_sum(grad_y, axis=reduce_axis)
    var_rsqrt = math_ops.rsqrt(pop_var + epsilon)
    grad_scale = math_ops.reduce_sum(
        grad_y * (x - pop_mean) * var_rsqrt, axis=reduce_axis)
    grad_x = grad_y * scale * var_rsqrt
    return math_ops.cast(grad_x, x_dtype), grad_scale, grad_offset
1033
1034
@ops.RegisterGradient("FusedBatchNormGrad")
def _FusedBatchNormGradGrad(op, *grad):
  """Returns the gradients for the 3 inputs of FusedBatchNormGrad.

  The first-order gradient is recomputed with `_BatchNormGrad` under a
  `GradientTape`, and the tape is then used to differentiate that computation
  with respect to grad_y, x and scale.

  Args:
    op: The FusedBatchNormGradOp for which we need to compute gradients.
    *grad: An argument list for tensors of gradients wrt the outputs with
      grad[0] as grad_grad_x, grad[1] as grad_grad_scale, grad[2] as
      grad_grad_offset.

  Returns:
    A tuple (grad_grad_y, grad_x, grad_scale, None, None), where grad_grad_y
    is the gradient for grad_y, grad_x the gradient for x, grad_scale the
    gradient for scale.
  """
  grad_y, x, scale, pop_mean, pop_var = [op.inputs[i] for i in range(5)]
  # Upstream gradients for (grad_x, grad_scale, grad_offset), in that order.
  upstream = [grad[0], grad[1], grad[2]]
  epsilon = op.get_attr("epsilon")
  data_format = op.get_attr("data_format")
  is_training = op.get_attr("is_training")

  with backprop.GradientTape() as tape:
    for watched in (grad_y, x, scale):
      tape.watch(watched)
    first_order = _BatchNormGrad(
        grad_y, x, scale, pop_mean, pop_var, epsilon, data_format, is_training)

  second_order = tape.gradient(list(first_order), [grad_y, x, scale], upstream)
  d_grad_y, d_x, d_scale = second_order
  return d_grad_y, d_x, d_scale, None, None
1071
1072
@ops.RegisterGradient("FusedBatchNormGradV2")
def _FusedBatchNormGradGradV2(op, *grad):
  """Gradient for FusedBatchNormGradV2; identical to the FusedBatchNormGrad one."""
  second_order_grads = _FusedBatchNormGradGrad(op, *grad)
  return second_order_grads
1076
1077
@ops.RegisterGradient("FusedBatchNormGradV3")
def _FusedBatchNormGradGradV3(op, *grad):
  """Gradient for FusedBatchNormGradV3.

  Reuses the FusedBatchNormGrad gradient and returns one more trailing `None`
  than it does, giving six entries in total.
  """
  second_order = _FusedBatchNormGradGrad(op, *grad)
  d_grad_y, d_x, d_scale, _, _ = second_order
  return d_grad_y, d_x, d_scale, None, None, None
1082
1083
@ops.RegisterGradient("L2Loss")
def _L2LossGrad(op, grad):
  """Computes the gradient of L2Loss with respect to its input.

  Args:
    op: The L2LossOp for which we need to generate gradients.
    grad: Tensor containing a single number.

  Returns:
    The gradient, which is (x * grad).
  """
  x = op.inputs[0]
  return x * grad
1096
1097
@ops.RegisterGradient("TopK")
@ops.RegisterGradient("TopKV2")
def _TopKGrad(op, grad, _):
  """Return the gradients for TopK.

  The incoming gradient is scattered back to the positions recorded in the
  op's indices output; every other position of the input gradient is zero.

  Args:
    op: The TopKOp for which we need to generate gradients.
    grad: Tensor. The gradients passed to the TopKOp.
    _: Ignored. The incoming gradient for the op's second output (the
      indices).

  Returns:
    A list of two tensors: the gradient w.r.t. the input, shaped like the
    input, and a scalar int32 zero.
  """
  indices = op.outputs[1]
  input_shape = array_ops.shape(op.inputs[0])
  indices_shape = array_ops.shape(indices)

  # int32 is not supported on GPU hence up-casting
  indices_last_dim = array_ops.gather(
      math_ops.cast(indices_shape, dtypes.int64),
      array_ops.size(indices_shape) - 1)
  # Collapse all leading dimensions so each row holds one set of indices.
  indices_2d = array_ops.reshape(
      indices, array_ops.stack([-1, indices_last_dim]))

  input_last_dim = array_ops.gather(
      math_ops.cast(input_shape, dtypes.int64),
      array_ops.size(input_shape) - 1)
  num_rows = math_ops.cast(array_ops.shape(indices_2d)[0], dtypes.int64)

  # Offset of the first element of each row in the flattened input.
  row_starts = math_ops.range(0, num_rows * input_last_dim, input_last_dim)
  row_starts = math_ops.cast(
      array_ops.expand_dims(row_starts, -1), dtypes.int32)
  # Linear indices into the flattened input, themselves flattened to 1D.
  flat_indices = array_ops.reshape(indices_2d + row_starts, [-1])

  # Place grad at the selected locations, zeros elsewhere, then restore the
  # original input shape.
  flat_input_grad = array_ops.scatter_nd(
      array_ops.expand_dims(flat_indices, -1),
      array_ops.reshape(grad, [-1]),
      [math_ops.reduce_prod(input_shape)])
  input_grad = array_ops.reshape(flat_input_grad, input_shape)
  return [input_grad, array_ops.zeros([], dtype=dtypes.int32)]
1142
1143
@ops.RegisterGradient("NthElement")
def _NthElementGrad(op, grad):
  """Return the gradients for NthElement.

  Args:
    op: The NthElementOp for which we need to generate gradients.
    grad: Tensor. The gradients passed to the NthElementOp

  Returns:
    A list of two tensors, the first being the gradient w.r.t. the input,
    the second being the gradient w.r.t. the N (None).
  """
  values = op.inputs[0]
  selected = array_ops.expand_dims(op.outputs[0], -1)

  # Mark every input element equal to the selected value along the last
  # axis. When several elements tie, the gradient is split evenly among them.
  is_selected = math_ops.equal(selected, values)
  indicators = math_ops.cast(is_selected, grad.dtype)

  expanded_grad = array_ops.expand_dims(grad, -1)
  tie_count = math_ops.reduce_sum(indicators, -1)
  tie_count = array_ops.expand_dims(tie_count, -1)

  input_grad = math_ops.divide(indicators, tie_count) * expanded_grad
  return [input_grad, None]
1169
1170
def _MeanAggregator(inputs, segments):
  """Replaces each segment with its mean along the last axis.

  Every value of `inputs` is replaced by the mean of the values in the same
  row that carry the same segment id. Rows are processed one at a time.

  Args:
   inputs: A 2-tensor. Aggregation is done over dimension 1.
   segments: A 2-tensor, same shape as `inputs`.

  Returns:
    The result, same shape and type as `inputs`.
  """

  def _row_segment_means(row_values, row_segments):
    """Returns, for one row, the mean of each element's segment as a vector."""
    # Note that we do not use tf.math.segment_mean, as it has no TPU support.
    segment_count = math_ops.reduce_max(row_segments) + 1
    segment_means = math_ops.unsorted_segment_mean(
        row_values, row_segments, num_segments=segment_count)
    # Broadcast each segment's mean back onto its members.
    per_element = array_ops.gather(segment_means, row_segments)
    return array_ops.reshape(per_element, [-1])

  # NOTE(review): the split count is read from the static shape, so this
  # presumably needs a statically known leading dimension -- confirm.
  value_rows = array_ops.split(inputs, inputs.shape[0])
  segment_rows = array_ops.split(segments, segments.shape[0])
  aggregated_rows = [
      _row_segment_means(row_values, row_segments)
      for row_values, row_segments in zip(value_rows, segment_rows)
  ]
  return array_ops.stack(aggregated_rows, axis=0)
1194
1195
1196# We have to register the gradients for these ops so that tensorflow will know
1197# how to differentiate them.
@ops.RegisterGradient("IsotonicRegression")
def _IsotonicRegressionGrad(op, grad_output, grad_segments):
  """Gradient for the isotonic regression function.

  The incoming gradient is averaged within each segment reported by the op's
  second output.

  Args:
    op: The IsotonicRegression tensorflow op.
    grad_output: Tensor of incoming gradients with respect to the output.
    grad_segments: Tensor of incoming gradients with respect to the segments.
      Unused.

  Returns:
    A tensor, same size as `grad_output` with the gradient with respect to
    the input.
  """
  # Segment ids are discrete and non-differentiable, so their gradient is
  # dropped.
  del grad_segments
  return _MeanAggregator(grad_output, op.outputs[1])
1214