# Copyright 2015 The TensorFlow Authors. All Rights Reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
#     http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
# ==============================================================================
"""Gradients for operators defined in array_ops.py."""

from __future__ import absolute_import
from __future__ import division
from __future__ import print_function

from tensorflow.python import pywrap_tensorflow
from tensorflow.python.eager import context
from tensorflow.python.framework import constant_op
from tensorflow.python.framework import dtypes
from tensorflow.python.framework import ops
from tensorflow.python.framework import sparse_tensor
from tensorflow.python.framework import tensor_util
from tensorflow.python.ops import array_ops
from tensorflow.python.ops import control_flow_util
from tensorflow.python.ops import gen_array_ops
from tensorflow.python.ops import math_ops
from tensorflow.python.ops import sparse_ops


@ops.RegisterGradient("Pack")
def _PackGrad(op, grad):
  """Gradient for pack op.

  Unstacks `grad` along the op's `axis` attribute into `N` pieces, one per
  packed input.
  """
  return array_ops.unstack(grad, num=op.get_attr("N"), axis=op.get_attr("axis"))


@ops.RegisterGradient("Unpack")
def _UnpackGrad(op, *grads):
  """Gradient for unpack op.

  Stacks the per-output gradients back together along the op's `axis`.
  """
  return array_ops.stack(grads, axis=op.get_attr("axis"))


def _ConcatGradHelper(op, grad, start_value_index, end_value_index, dim_index):
  """Gradient for concat op.

  Args:
    op: An operation.
    grad: `Tensor` or `IndexedSlices` representing the gradients with respect
      to each output of the op.
    start_value_index: An integer index of the first value in the op.inputs.
    end_value_index: An integer index one past the last value in the
      op.inputs. It is used as a Python slice end, so it may be negative
      (e.g. -1 to exclude a trailing axis input).
    dim_index: An integer index of concat_dim or axis parameter in op.inputs.

  Returns:
    A list with one gradient per input of the op: the partial gradients for
    the concatenated values, in input order, plus `None` for the
    concat_dim/axis input. The `None` is appended when
    `end_value_index <= dim_index` and prepended otherwise.

  Raises:
    ValueError: if `grad` is an `IndexedSlices` and concat_dim/axis is not
      statically known, or is negative while the rank of the first value is
      not statically known.
    TypeError: if `grad` is neither a `Tensor` nor an `IndexedSlices`.
  """

  def _CreateDenseMaskAndBegin(sizes, concat_dim):
    """Create tensors for iteratively slicing a dense gradients tensor."""
    # Since shape is 1-D, shape_of_shape = [rank-of-inputs]
    shape_of_shape = array_ops.shape(sizes[0])
    # Make a vector of length equal to the input's dimensions,
    # with 0's everywhere and 1 in the concat dim position.
    # Note: Can't use sparse_to_dense since it isn't GPU-capable (for now)
    mask = array_ops.concat([
        array_ops.fill(array_ops.expand_dims(concat_dim, 0), 0), [1],
        array_ops.fill(shape_of_shape - concat_dim - 1, 0)
    ], 0)
    begin = array_ops.fill(shape_of_shape, 0)
    return mask, begin

  def _ExtractInputShapes(inputs):
    """Extract the shapes of a set of input tensors."""
    if context.executing_eagerly():
      return array_ops.shape_n(inputs)
    sizes = []
    fully_known = True
    for x in inputs:
      input_shape = array_ops.shape(x)
      if not isinstance(input_shape,
                        ops.Tensor) or input_shape.op.type != "Const":
        fully_known = False
        break
      sizes.append(input_shape)

    # If every shape folded to a Const, use those per-input shapes directly;
    # otherwise fall back to a single ShapeN op over all inputs.
    if fully_known:
      return sizes
    else:
      return array_ops.shape_n(inputs)

  # Degenerate concatenation, just return grad.
  # `grad` is a single Tensor/IndexedSlices, so it has to be wrapped in a list
  # before it is combined with the `None` gradient for the axis input
  # (`grad + [None]` would attempt tensor addition instead).
  if len(op.inputs) == 2:
    return [grad, None] if end_value_index <= dim_index else [None, grad]

  concat_dim = op.inputs[dim_index]
  input_values = op.inputs[start_value_index:end_value_index]

  out_grads = []
  if isinstance(grad, ops.Tensor):
    if context.executing_eagerly():
      # Using mod here for convenience since concat_dim is already verified
      # in concat implementation to be within the allowed [-rank, rank) range.
      non_neg_concat_dim = (
          concat_dim._numpy().item(0) % input_values[0]._rank())  # pylint: disable=protected-access
      # All inputs are guaranteed to be EagerTensors in eager mode
      sizes = pywrap_tensorflow.TFE_Py_TensorShapeSlice(input_values,
                                                        non_neg_concat_dim)
      out_grads = array_ops.split(grad, sizes, non_neg_concat_dim)
    else:
      if constant_op.is_constant(concat_dim):
        # If concat_dim is a constant defined in a different context,
        # then we duplicate it in the current context to avoid passing it
        # through an Enter node.
        # This is a small optimization in general, but it is required when
        # compiling with XLA, as XLA needs the concat input to be folded into a
        # constant.
        grad_context = control_flow_util.GetOutputContext(grad.op)
        dim_context = control_flow_util.GetOutputContext(concat_dim.op)
        if dim_context != grad_context:
          value = tensor_util.constant_value(concat_dim)
          concat_dim = constant_op.constant(value=value, dtype=concat_dim.dtype)

      # Using mod here for convenience since concat_dim is already verified
      # in concat implementation to be within the allowed [-rank, rank) range.
      non_neg_concat_dim = concat_dim % array_ops.rank(input_values[0])

      # Get the inputs' tensor shapes
      sizes = _ExtractInputShapes(input_values)
      # The magic number of 16 was found through benchmarking a range of sizes
      # on CPUs and a Maxwell TitanX.  A speedup was seen in a large majority of
      # cases when switching implementations at N=16, but it is possible that
      # there will be a small number of performance regressions.
      if len(sizes) > 16:
        # extract the size of each input along the concat dimension
        sizes = array_ops.squeeze(
            array_ops.slice(
                array_ops.stack(sizes, axis=1), [non_neg_concat_dim, 0],
                [1, -1]))
        out_grads = array_ops.split(grad, sizes, non_neg_concat_dim)
      else:
        offset = gen_array_ops.concat_offset(non_neg_concat_dim, sizes)
        for (begin, size) in zip(offset, sizes):
          out_grads.append(array_ops.slice(grad, begin, size))
  elif isinstance(grad, ops.IndexedSlices):
    # Using mod here for convenience since concat_dim is already verified
    # in concat implementation to be within the allowed [-rank, rank) range.
    non_neg_concat_dim = concat_dim % array_ops.rank(input_values[0])
    concat_dim_static = tensor_util.constant_value(concat_dim)
    if concat_dim_static is None:
      raise ValueError("Can only compute IndexedSlices gradient with "
                       "statically-known concat_dim")
    if concat_dim_static < 0:
      rank = tensor_util.constant_value(array_ops.rank(input_values[0]))
      if rank is None:
        raise ValueError("Can only compute IndexedSlices gradient with "
                         "negative concat_dim when first value rank is "
                         "statically-known.")
      concat_dim_static %= rank
    # Get the inputs' tensor shapes
    sizes = [array_ops.shape(x) for x in input_values]
    if concat_dim_static > 0:
      # IndexedSlices, non_neg_concat_dim > 0. Each input gets IndexedSlices
      # gradients with all the indices, but with grad.values sliced accordingly.
      # This is like the Tensor case, except shape(grad.values)[0] is not equal
      # to shape(sizes[i])[0], since only a subset of the dim-0 values are
      # stored.
      mask, begin = _CreateDenseMaskAndBegin(sizes, non_neg_concat_dim)
      for size in sizes:
        new_values = array_ops.slice(
            grad.values, begin,
            array_ops.concat([[-1], array_ops.slice(size, [1], [-1])], 0))
        out_grads.append(ops.IndexedSlices(new_values, grad.indices, size))
        # Lint complains begin = begin + ...
        begin = math_ops.add(begin, size * mask)
    else:
      # IndexedSlices, concat_dim == 0. Each input gets IndexedSlices gradients
      # only for the relevant indices.
      start = constant_op.constant(0, dtype=grad.indices.dtype)
      for size in sizes:
        size_concat_dim = array_ops.gather(size, non_neg_concat_dim)
        if size_concat_dim.dtype != grad.indices.dtype:
          size_concat_dim = math_ops.cast(
              size_concat_dim, dtype=grad.indices.dtype)
        end = start + size_concat_dim
        # Compute the 1-D Tensor of indices relevant for this input.
        indices_to_select = array_ops.squeeze(
            array_ops.where(
                math_ops.logical_and(grad.indices >= start,
                                     grad.indices < end)),
            axis=[1])
        new_indices = array_ops.gather(grad.indices, indices_to_select) - start
        new_values = array_ops.gather(grad.values, indices_to_select)
        out_grads.append(ops.IndexedSlices(new_values, new_indices, size))
        start = end
  else:
    raise TypeError("Expected Tensor or IndexedSlices, got %s" % type(grad))

  return (out_grads + [None]
          if end_value_index <= dim_index else [None] + out_grads)


@ops.RegisterGradient("Concat")
def _ConcatGrad(op, grad):
  """Gradient for Concat, whose inputs are laid out as (concat_dim, *values)."""
  return _ConcatGradHelper(
      op,
      grad,
      start_value_index=1,
      end_value_index=len(op.inputs),
      dim_index=0)


@ops.RegisterGradient("ConcatV2")
def _ConcatGradV2(op, grad):
  """Gradient for ConcatV2, whose inputs are laid out as (*values, axis)."""
  return _ConcatGradHelper(
      op, grad, start_value_index=0, end_value_index=-1, dim_index=-1)


ops.NotDifferentiable("ConcatOffset")


@ops.RegisterGradient("Slice")
def _SliceGrad(op, grad):
  """Gradient for Slice op."""
  # Create an Nx2 padding where the first column represents how many
  # zeros are to be prepended for each dimension, and the second
  # column indicates how many zeros are appended.
  #
  # The number of zeros to append is the shape of the input
  # elementwise-subtracted by both the begin vector and sizes vector.
  #
  # Some more reshaping is needed to assemble this tensor with the
  # right dimensions.
  input_vec = op.inputs[0]
  begin_vec = op.inputs[1]
  input_rank = array_ops.rank(input_vec)
  slice_size = array_ops.shape(op.outputs[0])

  shape = array_ops.stack([input_rank, 1])
  before_pad = array_ops.reshape(begin_vec, shape)
  after_pad = array_ops.reshape(
      array_ops.shape(input_vec) - slice_size - begin_vec, shape)
  paddings = array_ops.concat([before_pad, after_pad], 1)
  return array_ops.pad(grad, paddings), None, None


@ops.RegisterGradient("StridedSlice")
def _StridedSliceGrad(op, grad):
  """Gradient for StridedSlice op."""
  begin = op.inputs[1]
  end = op.inputs[2]
  strides = op.inputs[3]
  # StridedSliceGrad requires `x`, `begin`, `end` and `strides` to be of the
  # same dtype so we build a shape of the same type as other args.
  # Note that the choice of `begin` for specifying `out_type` is arbitrary.
  # We could choose any of {begin|end|strides}.dtype since they are required to
  # be the same.
  x = array_ops.shape(op.inputs[0], out_type=begin.dtype)

  return array_ops.strided_slice_grad(
      x,
      begin,
      end,
      strides,
      grad,
      begin_mask=op.get_attr("begin_mask"),
      end_mask=op.get_attr("end_mask"),
      ellipsis_mask=op.get_attr("ellipsis_mask"),
      new_axis_mask=op.get_attr("new_axis_mask"),
      shrink_axis_mask=op.get_attr("shrink_axis_mask")), None, None, None


@ops.RegisterGradient("StridedSliceGrad")
def _StridedSliceGradGrad(op, grad):
  """Gradient for StridedSliceGrad op.

  Only the last input (the incoming gradient) is differentiable; it receives
  the same strided slice of `grad`, using the op's mask attributes.
  """
  begin = op.inputs[1]
  end = op.inputs[2]
  strides = op.inputs[3]

  return None, None, None, None, array_ops.strided_slice(
      grad,
      begin,
      end,
      strides,
      begin_mask=op.get_attr("begin_mask"),
      end_mask=op.get_attr("end_mask"),
      ellipsis_mask=op.get_attr("ellipsis_mask"),
      new_axis_mask=op.get_attr("new_axis_mask"),
      shrink_axis_mask=op.get_attr("shrink_axis_mask"))


@ops.RegisterGradient("Split")
def _SplitGrad(op, *grads):
  """Gradient for Split: concatenates the output gradients along op.inputs[0].

  The split-dimension input itself gets no gradient.
  """
  return None, array_ops.concat(list(grads), op.inputs[0])


@ops.RegisterGradient("SplitV")
def _SplitVGrad(op, *grads):
  """Gradient for SplitV: concatenates output gradients along op.inputs[2].

  Every input other than the value being split gets `None`.
  """
  returnval = array_ops.concat(list(grads), op.inputs[2])
  returnval = [returnval] + [
      None,
  ] * (
      len(op.inputs) - 1)
  return returnval


ops.NotDifferentiable("Const")


@ops.RegisterGradient("Diag")
def _DiagGrad(_, grad):
  """Gradient for Diag: the diagonal part of `grad`."""
  return array_ops.diag_part(grad)


@ops.RegisterGradient("DiagPart")
def _DiagPartGrad(_, grad):
  """Gradient for DiagPart: a diagonal tensor built from `grad`."""
  return array_ops.diag(grad)


@ops.RegisterGradient("MatrixDiag")
def _MatrixDiagGrad(_, grad):
  """Gradient for MatrixDiag: the (batched) diagonal part of `grad`."""
  return array_ops.matrix_diag_part(grad)


@ops.RegisterGradient("MatrixDiagPart")
def _MatrixDiagPartGrad(op, grad):
  """Gradient for MatrixDiagPart.

  If the input's trailing two dimensions are statically known and equal, the
  gradient is `matrix_diag(grad)`. Otherwise `grad` is written onto the
  diagonal of a zeros tensor shaped like the input.
  """
  matrix_shape = op.inputs[0].get_shape()[-2:]
  if matrix_shape.is_fully_defined() and matrix_shape[0] == matrix_shape[1]:
    return array_ops.matrix_diag(grad)
  else:
    return array_ops.matrix_set_diag(array_ops.zeros_like(op.inputs[0]), grad)


@ops.RegisterGradient("MatrixSetDiag")
def _MatrixSetDiagGrad(op, grad):
  """Gradient for MatrixSetDiag.

  The input gradient is `grad` with its diagonal zeroed out. The diagonal
  gradient is the diagonal part of `grad`.
  """
  input_shape = op.inputs[0].get_shape().merge_with(grad.get_shape())
  diag_shape = op.inputs[1].get_shape()
  batch_shape = input_shape[:-2].merge_with(diag_shape[:-1])
  matrix_shape = input_shape[-2:]
  if batch_shape.is_fully_defined() and matrix_shape.is_fully_defined():
    diag_shape = batch_shape.as_list() + [min(matrix_shape.as_list())]
  else:
    # Shapes are not fully static: compute the diagonal shape at runtime.
    with ops.colocate_with(grad):
      grad_shape = array_ops.shape(grad)
      grad_rank = array_ops.rank(grad)
      batch_shape = array_ops.slice(grad_shape, [0], [grad_rank - 2])
      matrix_shape = array_ops.slice(grad_shape, [grad_rank - 2], [2])
      min_dim = math_ops.reduce_min(matrix_shape)
      diag_shape = array_ops.concat([batch_shape, [min_dim]], 0)
  grad_input = array_ops.matrix_set_diag(grad,
                                         array_ops.zeros(
                                             diag_shape, dtype=grad.dtype))
  grad_diag = array_ops.matrix_diag_part(grad)
  return (grad_input, grad_diag)


@ops.RegisterGradient("MatrixBandPart")
def _MatrixBandPartGrad(op, grad):
  """Gradient for MatrixBandPart: applies the same band mask to `grad`."""
  num_lower = op.inputs[1]
  num_upper = op.inputs[2]
  return (array_ops.matrix_band_part(grad, num_lower, num_upper), None, None)


# Edit Distance has no gradient (but can be used to eval seq2seq or CTC).
ops.NotDifferentiable("EditDistance")


@ops.RegisterGradient("Fill")
def _FillGrad(_, grad):
  """Gradient for Fill: the fill value receives the sum of `grad`."""
  return None, math_ops.reduce_sum(grad)


ops.NotDifferentiable("ZerosLike")
ops.NotDifferentiable("OnesLike")


@ops.RegisterGradient("PreventGradient")
def _PreventGradientGrad(op, _):
  """Always raises `LookupError`, quoting the op's `message` attribute."""
  raise LookupError(
      "Gradient explicitly disabled. Reason: %s" % op.get_attr("message"))


@ops.RegisterGradient("Gather")
def _GatherGrad(op, grad):
  """Gradient for Gather op."""
  # params can be large, so colocate the shape calculation with it.
  #
  # params can be very large for sparse model, array_ops.shape raises
  # exception on the Windows platform when any dimension is larger than
  # int32. params_shape is not used in optimizer apply_sparse gradients,
  # so it's fine to convert it back to int32 regardless of truncation.
  params = op.inputs[0]
  with ops.colocate_with(params):
    params_shape = array_ops.shape(params, out_type=ops.dtypes.int64)
    params_shape = math_ops.cast(params_shape, dtypes.int32)

  # Build appropriately shaped IndexedSlices
  indices = op.inputs[1]
  size = array_ops.expand_dims(array_ops.size(indices), 0)
  values_shape = array_ops.concat([size, params_shape[1:]], 0)
  values = array_ops.reshape(grad, values_shape)
  indices = array_ops.reshape(indices, size)
  return [ops.IndexedSlices(values, indices, params_shape), None]


@ops.RegisterGradient("GatherV2")
def _GatherV2Grad(op, grad):
  """Gradient for GatherV2 op."""
  # params can be large, so colocate the shape calculation with it.
  #
  # params can be very large for sparse model, array_ops.shape raises
  # exception on the Windows platform when any dimension is larger than
  # int32. params_shape is not used in optimizer apply_sparse gradients,
  # so it's fine to convert it back to int32 regardless of truncation.
  params = op.inputs[0]
  with ops.colocate_with(params):
    params_shape = array_ops.shape(params, out_type=ops.dtypes.int64)
    params_shape = math_ops.cast(params_shape, dtypes.int32)

  indices = op.inputs[1]
  indices_size = array_ops.expand_dims(array_ops.size(indices), 0)
  axis = op.inputs[2]
  axis_static = tensor_util.constant_value(axis)

  # For axis 0 gathers, build an appropriately shaped IndexedSlices.
  if axis_static == 0:
    if context.executing_eagerly():
      params_tail_shape = params_shape.cpu()[1:]
    else:
      params_tail_shape = params_shape[1:]
    values_shape = array_ops.concat([indices_size, params_tail_shape], 0)
    values = array_ops.reshape(grad, values_shape)
    indices = array_ops.reshape(indices, indices_size)
    return [ops.IndexedSlices(values, indices, params_shape), None, None]

  outer_shape = params_shape[:axis]
  outer_dims = array_ops.size(outer_shape)
  inner_shape = params_shape[axis:][1:]
  inner_dims = array_ops.size(inner_shape)

  outer_axes_indices = math_ops.range(outer_dims)
  inner_axes_indices = math_ops.range(outer_dims + 1,
                                      outer_dims + 1 + inner_dims)

  values_shape = array_ops.concat([outer_shape, indices_size, inner_shape], 0)
  values = array_ops.reshape(grad, values_shape)
  indices = array_ops.reshape(indices, indices_size)

  # We need to sum up every slice `values[..., i, ....]` corresponding to
  # `params[..., indices[i], ...]`. Since `unsorted_segment_sum` does not
  # support an axis parameter, we transpose the gather dimension to the front,
  # then use `unsorted_segment_sum` to build a
  # [gather_axis, outer_axes, inner_axes] tensor with all the gradients
  # affecting each index in `gather_axis` summed up.
  transpose_dims = array_ops.concat(
      [[outer_dims], outer_axes_indices, inner_axes_indices], 0)
  values_transpose = array_ops.transpose(values, transpose_dims)
  num_segments = params_shape[axis]

  params_grad = math_ops.unsorted_segment_sum(values_transpose, indices,
                                              num_segments)

  # Inverts the above transpose by moving dimension 0 back to its original
  # position.
  invert_transpose_dims = array_ops.concat(
      [outer_axes_indices + 1, [0], inner_axes_indices], 0)
  params_grad = array_ops.transpose(params_grad, invert_transpose_dims)
  return [params_grad, None, None]


@ops.RegisterGradient("GatherNd")
def _GatherNdGrad(op, grad):
  """Gradient for GatherNd.

  When `indices` is statically [N, 1], the gradient is an `IndexedSlices`
  over the squeezed indices. Otherwise `grad` is scattered back into a dense
  tensor of the reference's shape.
  """
  ref = op.inputs[0]
  indices = op.inputs[1]
  ref_shape = array_ops.shape(ref, out_type=indices.dtype)
  if indices.shape.ndims == 2 and indices.shape.dims[-1].value == 1:
    ref_grad = ops.IndexedSlices(grad, array_ops.squeeze(indices, axis=-1),
                                 ref_shape)
  else:
    ref_grad = array_ops.scatter_nd(indices, grad, ref_shape)
  return [ref_grad, None]


@ops.RegisterGradient("CheckNumerics")
def _CheckNumericsGrad(op, grad):
  """Gradient for check_numerics op."""
  return array_ops.check_numerics(
      grad,
      "Not a number (NaN) or infinity (Inf) values detected in gradient. %s" %
      op.get_attr("message"))


@ops.RegisterGradient("PlaceholderWithDefault")
@ops.RegisterGradient("Identity")
def _IdGrad(_, grad):
  """Identity gradient: passes `grad` through unchanged."""
  return grad


@ops.RegisterGradient("RefIdentity")
def _RefIdGrad(_, grad):
  """Identity gradient for RefIdentity: passes `grad` through unchanged."""
  return grad


@ops.RegisterGradient("IdentityN")
def _IdNGrad(_, *grad):
  """Identity gradient for IdentityN: passes every output gradient through."""
  return grad


ops.NotDifferentiable("StopGradient")


@ops.RegisterGradient("Reshape")
def _ReshapeGrad(op, grad):
  """Gradient for Reshape: reshape `grad` back to the input's shape."""
  return [_ReshapeToInput(op, grad), None]


ops.NotDifferentiable("InvertPermutation")


def _ReshapeToInput(op, grad):
  """Reshapes the gradient to the shape of the original input."""
  return array_ops.reshape(grad, array_ops.shape(op.inputs[0]))


@ops.RegisterGradient("ExpandDims")
def _ExpandDimsGrad(op, grad):
  """Gradient for ExpandDims: reshape `grad` back to the input's shape."""
  return [_ReshapeToInput(op, grad), None]


@ops.RegisterGradient("Squeeze")
def _SqueezeGrad(op, grad):
  """Gradient for Squeeze: reshape `grad` back to the input's shape."""
  return _ReshapeToInput(op, grad)


@ops.RegisterGradient("Transpose")
def _TransposeGrad(op, grad):
  """Returns unshuffle(grad)."""
  p = op.inputs[1]
  return [array_ops.transpose(grad, array_ops.invert_permutation(p)), None]


@ops.RegisterGradient("ConjugateTranspose")
def _ConjugateTransposeGrad(op, grad):
  """Returns conj(unshuffle(grad))."""
  p = op.inputs[1]
  return [
      array_ops.transpose(
          grad, array_ops.invert_permutation(p), conjugate=True), None
  ]


ops.NotDifferentiable("Shape")

ops.NotDifferentiable("ShapeN")

ops.NotDifferentiable("Rank")

ops.NotDifferentiable("Size")


@ops.RegisterGradient("Tile")
def _TileGrad(op, grad):
  """Sum reduces grad along the tiled dimensions."""
  input_shape = array_ops.shape(op.inputs[0], out_type=op.inputs[1].dtype)
  # We interleave multiples and input_shape to get split_shape,
  # reshape grad to split_shape, and reduce along all even
  # dimensions (the tiled dimensions) to get the result
  # with shape input_shape.  For example
  #   input_shape = [20, 30, 40]
  #   multiples = [2, 3, 4]
  #   split_shape = [2, 20, 3, 30, 4, 40]
  #   axes = [0, 2, 4]
  split_shape = array_ops.reshape(
      array_ops.transpose(array_ops.stack([op.inputs[1], input_shape])), [-1])
  axes = math_ops.range(0, array_ops.size(split_shape), 2)
  # Sum reduces grad along the first dimension for IndexedSlices
  if isinstance(grad, ops.IndexedSlices):
    input_shape_0 = math_ops.cast(input_shape[0], grad.indices.dtype)
    grad = math_ops.unsorted_segment_sum(
        grad.values,
        math_ops.mod(grad.indices, input_shape_0),
        input_shape_0)
    split_shape = array_ops.concat([[1], split_shape[1:]], axis=0)
  input_grad = math_ops.reduce_sum(array_ops.reshape(grad, split_shape), axes)
  # Fix shape inference
  if not context.executing_eagerly():
    input_grad.set_shape(op.inputs[0].get_shape())
  return [input_grad, None]


ops.NotDifferentiable("BroadcastGradientArgs")


def _PadGrad(op, grad):
  """Gradient for Pad."""
  # Pad introduces values around the original tensor, so the gradient function
  # slices the original shape out of the gradient.
  x = op.inputs[0]
  a = op.inputs[1]  # [Rank(x), 2]
  # Takes a slice of a. The 1st column. [Rank(x), 1].
  pad_before = array_ops.slice(a, [0, 0],
                               array_ops.stack([array_ops.rank(x), 1]))
  # Make it a 1-D tensor.
  begin = array_ops.reshape(pad_before, [-1])
  sizes = array_ops.shape(x)
  x_grad = array_ops.slice(grad, begin, sizes)
  # PadV2 (also registered below) carries an extra third input; it gets None.
  if len(op.inputs) == 3:
    return x_grad, None, None
  else:
    return x_grad, None


ops.RegisterGradient("Pad")(_PadGrad)
ops.RegisterGradient("PadV2")(_PadGrad)


# ReverseSequence is just a permutation.  The gradient permutes back.
@ops.RegisterGradient("ReverseSequence")
def _ReverseSequenceGrad(op, grad):
  """Gradient for ReverseSequence: reverse `grad` with the same lengths/axes."""
  seq_lengths = op.inputs[1]
  return [
      array_ops.reverse_sequence(
          grad,
          batch_axis=op.get_attr("batch_dim"),
          seq_axis=op.get_attr("seq_dim"),
          seq_lengths=seq_lengths), None
  ]


@ops.RegisterGradient("Reverse")
def _ReverseGrad(op, grad):
  """Gradient for Reverse: reverse `grad` along the same dimensions."""
  reverse_dims = op.inputs[1]
  return gen_array_ops.reverse(grad, reverse_dims), None


@ops.RegisterGradient("ReverseV2")
def _ReverseV2Grad(op, grad):
  """Gradient for ReverseV2: reverse `grad` along the same axes."""
  axis = op.inputs[1]
  return array_ops.reverse_v2(grad, axis), None


@ops.RegisterGradient("SpaceToBatch")
def _SpaceToBatchGrad(op, grad):
  # Its gradient is the opposite op: BatchToSpace.
  block_size = op.get_attr("block_size")
  return [
      array_ops.batch_to_space(grad, op.inputs[1], block_size=block_size), None
  ]


@ops.RegisterGradient("SpaceToBatchND")
def _SpaceToBatchNDGrad(op, grad):
  # Its gradient is the opposite op: BatchToSpaceND.
  return [
      array_ops.batch_to_space_nd(grad, op.inputs[1], op.inputs[2]), None, None
  ]


@ops.RegisterGradient("BatchToSpace")
def _BatchToSpaceGrad(op, grad):
  # Its gradient is the opposite op: SpaceToBatch.
  block_size = op.get_attr("block_size")
  return [
      array_ops.space_to_batch(grad, op.inputs[1], block_size=block_size), None
  ]


@ops.RegisterGradient("BatchToSpaceND")
def _BatchToSpaceNDGrad(op, grad):
  # Its gradient is the opposite op: SpaceToBatchND.
  return [
      array_ops.space_to_batch_nd(grad, op.inputs[1], op.inputs[2]), None, None
  ]


@ops.RegisterGradient("SpaceToDepth")
def _SpaceToDepthGrad(op, grad):
  # Its gradient is the opposite op: DepthToSpace.
  block_size = op.get_attr("block_size")
  data_format = op.get_attr("data_format")
  if data_format == "NCHW_VECT_C":
    raise ValueError("Cannot compute SpaceToDepth gradient with NCHW_VECT_C. "
                     "NCHW_VECT_C requires qint8 data type.")
  return array_ops.depth_to_space(grad, block_size, data_format=data_format)


@ops.RegisterGradient("DepthToSpace")
def _DepthToSpaceGrad(op, grad):
  # Its gradient is the opposite op: SpaceToDepth.
  block_size = op.get_attr("block_size")
  data_format = op.get_attr("data_format")
  if data_format == "NCHW_VECT_C":
    raise ValueError("Cannot compute DepthToSpace gradient with NCHW_VECT_C. "
                     "NCHW_VECT_C requires qint8 data type.")
  return array_ops.space_to_depth(grad, block_size, data_format=data_format)


ops.NotDifferentiable("OneHot")


@ops.RegisterGradient("MirrorPad")
def _MirrorPadGrad(op, grad):
  """Gradient for MirrorPad: MirrorPadGrad with the same paddings and mode."""
  mode = op.get_attr("mode")
  return [gen_array_ops.mirror_pad_grad(grad, op.inputs[1], mode=mode), None]


@ops.RegisterGradient("MirrorPadGrad")
def _MirrorPadGradGrad(op, grad):
  """Gradient for MirrorPadGrad: MirrorPad with the same paddings and mode."""
  mode = op.get_attr("mode")
  return [gen_array_ops.mirror_pad(grad, op.inputs[1], mode=mode), None]


@ops.RegisterGradient("QuantizeAndDequantize")
def _QuantizeAndDequantizeGrad(_, grad):
  """Passes `grad` through unchanged to the input."""
  return grad


@ops.RegisterGradient("QuantizeAndDequantizeV2")
def _QuantizeAndDequantizeV2Grad(_, grad):
  """Passes `grad` through to the first input; the other inputs get None."""
  return [grad, None, None]


@ops.RegisterGradient("QuantizeAndDequantizeV3")
def _QuantizeAndDequantizeV3Grad(_, grad):
  """Passes `grad` through to the first input; the other inputs get None."""
  # Only propagate the gradient for the unquantized input.
  return [grad, None, None, None]


@ops.RegisterGradient("ExtractImagePatches")
def _ExtractImagePatchesGrad(op, grad):
  """Gradient for ExtractImagePatches.

  Builds a sparse matrix with ones at every (input pixel index, output patch
  element index) pair produced by running the same patch extraction over an
  index image. Multiplying it by the flattened output gradient accumulates,
  for each input pixel, the gradients of every patch element it was copied
  into.

  The spatial sizes of the input and output are read from the static shapes
  (`dim.value`) and used in Python arithmetic, so they must be statically
  known. Batch size and channels are taken from the dynamic shape.
  """
  # Only the static spatial dims are needed; batch and channels come from the
  # dynamic shape below.
  _, rows_in, cols_in, _ = [dim.value for dim in op.inputs[0].shape.dims]
  input_bhwc = array_ops.shape(op.inputs[0])
  batch_size = input_bhwc[0]
  channels = input_bhwc[3]

  # Create indices matrix for input tensor.
  # Note that 0 is preserved for padding location,
  # so indices for input run from 1 up to (but excluding) 1 + rows_in * cols_in.
  input_indices_num = 1 + rows_in * cols_in
  input_idx = array_ops.reshape(math_ops.range(1, input_indices_num,
                                               dtype=ops.dtypes.int64),
                                (1, rows_in, cols_in, 1))
  input_idx_patched = gen_array_ops.extract_image_patches(
      input_idx,
      op.get_attr("ksizes"),
      op.get_attr("strides"),
      op.get_attr("rates"),
      op.get_attr("padding"))

  # Create indices matrix for output tensor.
  _, rows_out, cols_out, _ = [dim.value for dim in op.outputs[0].shape.dims]
  _, ksize_r, ksize_c, _ = op.get_attr("ksizes")
  # Indices for output start from 0.
  output_indices_num = rows_out * cols_out * ksize_r * ksize_c
  output_idx = array_ops.reshape(math_ops.range(output_indices_num,
                                                dtype=ops.dtypes.int64),
                                 (1, rows_out, cols_out, ksize_r * ksize_c))

  # Construct mapping table for indices: (input -> output).
  idx_matrix = array_ops.concat(
      [array_ops.expand_dims(input_idx_patched, axis=-1),
       array_ops.expand_dims(output_idx, axis=-1)],
      axis=-1)
  idx_map = array_ops.reshape(idx_matrix, (-1, 2))

  sp_shape = (input_indices_num, output_indices_num)
  sp_mat_full = sparse_tensor.SparseTensor(
      idx_map,
      array_ops.ones([output_indices_num], dtype=grad.dtype),
      sp_shape)
  # Remove all padding locations [0, :].
  sp_mat = sparse_ops.sparse_slice(sp_mat_full,
                                   (1, 0),
                                   (input_indices_num - 1, output_indices_num))

  grad_expanded = array_ops.transpose(
      array_ops.reshape(
          grad, (batch_size, rows_out, cols_out, ksize_r, ksize_c, channels)),
      (1, 2, 3, 4, 0, 5))
  grad_flat = array_ops.reshape(grad_expanded, (-1, batch_size * channels))

  jac = sparse_ops.sparse_tensor_dense_matmul(sp_mat, grad_flat)

  grad_out = array_ops.reshape(jac, (rows_in, cols_in, batch_size, channels))
  grad_out = array_ops.transpose(grad_out, (2, 0, 1, 3))

  return [grad_out]


@ops.RegisterGradient("ExtractVolumePatches")
def _ExtractVolumePatchesGrad(op, grad):
  """Gradient for ExtractVolumePatches.

  Same approach as the ExtractImagePatches gradient, with an extra planes
  dimension: a sparse (input index -> output index) matrix of ones is
  multiplied by the flattened output gradient.

  The planes/rows/cols sizes of the input and output are read from the
  static shapes (`dim.value`) and must be statically known.
  """
  # Only the static spatial dims are needed; batch and channels come from the
  # dynamic shape below.
  _, planes_in, rows_in, cols_in, _ = [
      dim.value for dim in op.inputs[0].shape.dims
  ]
  input_bphwc = array_ops.shape(op.inputs[0])
  batch_size = input_bphwc[0]
  channels = input_bphwc[4]

  # Create indices matrix for input tensor.
  # Note that 0 is preserved for padding location, so indices for input run
  # from 1 up to (but excluding) 1 + planes_in * rows_in * cols_in.
  input_indices_num = 1 + planes_in * rows_in * cols_in
  input_idx = array_ops.reshape(
      math_ops.range(1, input_indices_num, dtype=ops.dtypes.int64),
      (1, planes_in, rows_in, cols_in, 1))
  input_idx_patched = gen_array_ops.extract_volume_patches(
      input_idx, op.get_attr("ksizes"), op.get_attr("strides"),
      op.get_attr("padding"))

  # Create indices matrix for output tensor.
  _, planes_out, rows_out, cols_out, _ = [
      dim.value for dim in op.outputs[0].shape.dims
  ]
  _, ksize_p, ksize_r, ksize_c, _ = op.get_attr("ksizes")
  # Indices for output start from 0.
  prc_indices_num = planes_out * rows_out * cols_out
  output_indices_num = prc_indices_num * ksize_p * ksize_r * ksize_c
  output_idx = array_ops.reshape(
      math_ops.range(output_indices_num, dtype=ops.dtypes.int64),
      (1, planes_out, rows_out, cols_out, ksize_p * ksize_r * ksize_c))

  # Construct mapping table for indices: (input -> output).
  idx_matrix = array_ops.concat([
      array_ops.expand_dims(input_idx_patched, axis=-1),
      array_ops.expand_dims(output_idx, axis=-1)
  ],
                                axis=-1)
  idx_map = array_ops.reshape(idx_matrix, (-1, 2))

  sp_shape = (input_indices_num, output_indices_num)
  sp_mat_full = sparse_tensor.SparseTensor(
      idx_map, array_ops.ones([output_indices_num], dtype=grad.dtype), sp_shape)
  # Remove all padding locations [0, :].
  sp_mat = sparse_ops.sparse_slice(sp_mat_full, (1, 0),
                                   (input_indices_num - 1, output_indices_num))

  grad_expanded = array_ops.transpose(
      array_ops.reshape(grad, (batch_size, planes_out, rows_out, cols_out,
                               ksize_p, ksize_r, ksize_c, channels)),
      (1, 2, 3, 4, 5, 6, 0, 7))
  grad_flat = array_ops.reshape(grad_expanded, (-1, batch_size * channels))

  jac = sparse_ops.sparse_tensor_dense_matmul(sp_mat, grad_flat)

  grad_out = array_ops.reshape(
      jac, (planes_in, rows_in, cols_in, batch_size, channels))
  grad_out = array_ops.transpose(grad_out, (3, 0, 1, 2, 4))

  return [grad_out]


@ops.RegisterGradient("ScatterNd")
def _ScatterNdGrad(op, grad):
  """Gradient for ScatterNd: updates receive `grad` gathered at `indices`."""
  indices = op.inputs[0]
  updates_grad = array_ops.gather_nd(grad, indices)
  return [None, updates_grad, None]


@ops.RegisterGradient("TensorScatterUpdate")
def _TensorScatterUpdateGrad(op, grad):
  """Gradient for TensorScatterUpdate.

  The tensor gradient is `grad` with the positions at `indices` zeroed. The
  updates gradient is `grad` gathered at `indices`.
  """
  indices = op.inputs[1]
  updates_grad = array_ops.gather_nd(grad, indices)
  tensor_grad = array_ops.tensor_scatter_update(
      array_ops.identity(grad), indices,
      array_ops.zeros_like(op.inputs[2], dtype=grad.dtype))
  return [tensor_grad, None, updates_grad]


@ops.RegisterGradient("TensorScatterAdd")
def _TensorScatterAddGrad(op, grad):
  """Gradient for TensorScatterAdd: tensor gets `grad`, updates get gathered."""
  indices = op.inputs[1]
  updates_grad = array_ops.gather_nd(grad, indices)
  tensor_grad = array_ops.identity(grad)
  return [tensor_grad, None, updates_grad]


@ops.RegisterGradient("TensorScatterSub")
def _TensorScatterSubGrad(op, grad):
  """Gradient for TensorScatterSub: updates get the negated gathered `grad`."""
  indices = op.inputs[1]
  updates_grad = array_ops.gather_nd(grad, indices)
  tensor_grad = array_ops.identity(grad)
  return [tensor_grad, None, -updates_grad]


@ops.RegisterGradient("ScatterNdNonAliasingAdd")
def _ScatterNdNonAliasingAddGrad(op, grad):
  """Gradient for ScatterNdNonAliasingAdd: like TensorScatterAdd."""
  indices = op.inputs[1]
  updates_grad = array_ops.gather_nd(grad, indices)
  return [grad, None, updates_grad]


@ops.RegisterGradient("BroadcastTo")
def _BroadcastToGrad(op, grad):
  """Gradient for BroadcastTo.

  Sums `grad` over the axes that `broadcast_gradient_args` reports as
  broadcast for the input, then reshapes the result to the input's shape.
  """
  input_value = op.inputs[0]
  broadcast_shape = op.inputs[1]
  input_value_shape = array_ops.shape(input_value)
  _, reduction_axes = gen_array_ops.broadcast_gradient_args(broadcast_shape,
                                                            input_value_shape)
  updates_grad_reshaped = math_ops.reduce_sum(grad,
                                              axis=reduction_axes,
                                              keepdims=True)
  updates_grad = array_ops.reshape(updates_grad_reshaped, input_value_shape)
  return [updates_grad, None]