1 /* Copyright 2016 The TensorFlow Authors. All Rights Reserved.
2 
3 Licensed under the Apache License, Version 2.0 (the "License");
4 you may not use this file except in compliance with the License.
5 You may obtain a copy of the License at
6 
7     http://www.apache.org/licenses/LICENSE-2.0
8 
9 Unless required by applicable law or agreed to in writing, software
10 distributed under the License is distributed on an "AS IS" BASIS,
11 WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
12 See the License for the specific language governing permissions and
13 limitations under the License.
14 ==============================================================================*/
15 
16 #include <vector>
17 
18 #include "tensorflow/cc/framework/grad_op_registry.h"
19 #include "tensorflow/cc/framework/gradients.h"
20 #include "tensorflow/cc/ops/array_ops_internal.h"
21 #include "tensorflow/cc/ops/standard_ops.h"
22 #include "tensorflow/core/lib/strings/strcat.h"
23 
24 namespace tensorflow {
25 namespace ops {
26 namespace {
27 
28 REGISTER_NO_GRADIENT_OP("Const");
29 REGISTER_NO_GRADIENT_OP("StopGradient");
30 REGISTER_NO_GRADIENT_OP("ConcatOffset");
31 REGISTER_NO_GRADIENT_OP("EditDistance");
32 REGISTER_NO_GRADIENT_OP("ZerosLike");
33 REGISTER_NO_GRADIENT_OP("InvertPermutation");
34 REGISTER_NO_GRADIENT_OP("Shape");
35 REGISTER_NO_GRADIENT_OP("ShapeN");
36 REGISTER_NO_GRADIENT_OP("Rank");
37 REGISTER_NO_GRADIENT_OP("Size");
38 REGISTER_NO_GRADIENT_OP("BroadcastGradientArgs");
39 REGISTER_NO_GRADIENT_OP("OneHot");
40 
PackGrad(const Scope & scope,const Operation & op,const std::vector<Output> & grad_inputs,std::vector<Output> * grad_outputs)41 Status PackGrad(const Scope& scope, const Operation& op,
42                 const std::vector<Output>& grad_inputs,
43                 std::vector<Output>* grad_outputs) {
44   int N;
45   TF_RETURN_IF_ERROR(GetNodeAttr(op.node()->attrs(), "N", &N));
46   int axis;
47   TF_RETURN_IF_ERROR(GetNodeAttr(op.node()->attrs(), "axis", &axis));
48 
49   grad_outputs->reserve(N);
50   auto grad_op = Unstack(scope, grad_inputs[0], N, Unstack::Axis(axis));
51   for (const Output& o : grad_op.output) {
52     grad_outputs->emplace_back(o);
53   }
54   return scope.status();
55 }
56 REGISTER_GRADIENT_OP("Pack", PackGrad);
57 
UnpackGrad(const Scope & scope,const Operation & op,const std::vector<Output> & grad_inputs,std::vector<Output> * grad_outputs)58 Status UnpackGrad(const Scope& scope, const Operation& op,
59                   const std::vector<Output>& grad_inputs,
60                   std::vector<Output>* grad_outputs) {
61   int axis;
62   TF_RETURN_IF_ERROR(GetNodeAttr(op.node()->attrs(), "axis", &axis));
63   grad_outputs->push_back(Stack(scope, grad_inputs, Stack::Axis(axis)));
64   return scope.status();
65 }
66 REGISTER_GRADIENT_OP("Unpack", UnpackGrad);
67 
IdentityGrad(const Scope & scope,const Operation & op,const std::vector<Output> & grad_inputs,std::vector<Output> * grad_outputs)68 Status IdentityGrad(const Scope& scope, const Operation& op,
69                     const std::vector<Output>& grad_inputs,
70                     std::vector<Output>* grad_outputs) {
71   grad_outputs->push_back(Identity(scope, grad_inputs[0]));
72   return scope.status();
73 }
74 REGISTER_GRADIENT_OP("Identity", IdentityGrad);
75 
RefIdentityGrad(const Scope & scope,const Operation & op,const std::vector<Output> & grad_inputs,std::vector<Output> * grad_outputs)76 Status RefIdentityGrad(const Scope& scope, const Operation& op,
77                        const std::vector<Output>& grad_inputs,
78                        std::vector<Output>* grad_outputs) {
79   grad_outputs->push_back(Identity(scope, grad_inputs[0]));
80   return scope.status();
81 }
82 REGISTER_GRADIENT_OP("RefIdentity", RefIdentityGrad);
83 
QuantizeAndDequantizeGrad(const Scope & scope,const Operation & op,const std::vector<Output> & grad_inputs,std::vector<Output> * grad_outputs)84 Status QuantizeAndDequantizeGrad(const Scope& scope, const Operation& op,
85                                  const std::vector<Output>& grad_inputs,
86                                  std::vector<Output>* grad_outputs) {
87   grad_outputs->push_back(Identity(scope, grad_inputs[0]));
88   return scope.status();
89 }
90 REGISTER_GRADIENT_OP("QuantizeAndDequantize", QuantizeAndDequantizeGrad);
91 
QuantizeAndDequantizeV4GradHelper(const Scope & scope,const Operation & op,const std::vector<Output> & grad_inputs,std::vector<Output> * grad_outputs)92 Status QuantizeAndDequantizeV4GradHelper(const Scope& scope,
93                                          const Operation& op,
94                                          const std::vector<Output>& grad_inputs,
95                                          std::vector<Output>* grad_outputs) {
96   Input input = Shape(scope, op.input(0));
97   Input input_min = op.input(1);
98   Input input_max = op.input(2);
99   int64 axis;
100   TF_RETURN_IF_ERROR(GetNodeAttr(op.node()->attrs(), "axis", &axis));
101   auto qdq_v4_grad = QuantizeAndDequantizeV4Grad(
102       scope, grad_inputs[0], input, input_min, input_max,
103       QuantizeAndDequantizeV4Grad::Axis(axis));
104   grad_outputs->push_back(qdq_v4_grad.input_backprop);
105   grad_outputs->push_back(qdq_v4_grad.input_min_backprop);
106   grad_outputs->push_back(qdq_v4_grad.input_max_backprop);
107   return scope.status();
108 }
109 REGISTER_GRADIENT_OP("QuantizeAndDequantizeV4",
110                      QuantizeAndDequantizeV4GradHelper);
111 
QuantizeAndDequantizeV3Grad(const Scope & scope,const Operation & op,const std::vector<Output> & grad_inputs,std::vector<Output> * grad_outputs)112 Status QuantizeAndDequantizeV3Grad(const Scope& scope, const Operation& op,
113                                    const std::vector<Output>& grad_inputs,
114                                    std::vector<Output>* grad_outputs) {
115   grad_outputs->push_back(Identity(scope, grad_inputs[0]));
116   grad_outputs->push_back(NoGradient());
117   grad_outputs->push_back(NoGradient());
118   grad_outputs->push_back(NoGradient());
119   return scope.status();
120 }
121 REGISTER_GRADIENT_OP("QuantizeAndDequantizeV3", QuantizeAndDequantizeV3Grad);
122 
SplitGrad(const Scope & scope,const Operation & op,const std::vector<Output> & grad_inputs,std::vector<Output> * grad_outputs)123 Status SplitGrad(const Scope& scope, const Operation& op,
124                  const std::vector<Output>& grad_inputs,
125                  std::vector<Output>* grad_outputs) {
126   grad_outputs->push_back(NoGradient());
127   grad_outputs->push_back(Concat(scope, grad_inputs, op.input(0)));
128   return scope.status();
129 }
130 REGISTER_GRADIENT_OP("Split", SplitGrad);
131 
FillGrad(const Scope & scope,const Operation & op,const std::vector<Output> & grad_inputs,std::vector<Output> * grad_outputs)132 Status FillGrad(const Scope& scope, const Operation& op,
133                 const std::vector<Output>& grad_inputs,
134                 std::vector<Output>* grad_outputs) {
135   // y = fill(fill_shape, x)
136   // No gradient returned for the fill_shape argument.
137   grad_outputs->push_back(NoGradient());
138   // The gradient for x (which must be a scalar) is just the sum of
139   // all the gradients from the shape it fills.
140   // We use ReduceSum to implement this, which needs an argument providing
141   // the indices of all the dimensions of the incoming gradient.
142   // grad(x) = reduce_sum(grad(y), [0..rank(grad(y))])
143   auto all_dims = Range(scope, Const(scope, 0), Rank(scope, grad_inputs[0]),
144                         Const(scope, 1));
145   grad_outputs->push_back(ReduceSum(scope, grad_inputs[0], all_dims));
146   return scope.status();
147 }
148 REGISTER_GRADIENT_OP("Fill", FillGrad);
149 
DiagGrad(const Scope & scope,const Operation & op,const std::vector<Output> & grad_inputs,std::vector<Output> * grad_outputs)150 Status DiagGrad(const Scope& scope, const Operation& op,
151                 const std::vector<Output>& grad_inputs,
152                 std::vector<Output>* grad_outputs) {
153   grad_outputs->push_back(DiagPart(scope, grad_inputs[0]));
154   return scope.status();
155 }
156 REGISTER_GRADIENT_OP("Diag", DiagGrad);
157 
DiagPartGrad(const Scope & scope,const Operation & op,const std::vector<Output> & grad_inputs,std::vector<Output> * grad_outputs)158 Status DiagPartGrad(const Scope& scope, const Operation& op,
159                     const std::vector<Output>& grad_inputs,
160                     std::vector<Output>* grad_outputs) {
161   grad_outputs->push_back(Diag(scope, grad_inputs[0]));
162   return scope.status();
163 }
164 REGISTER_GRADIENT_OP("DiagPart", DiagPartGrad);
165 
MatrixDiagGrad(const Scope & scope,const Operation & op,const std::vector<Output> & grad_inputs,std::vector<Output> * grad_outputs)166 Status MatrixDiagGrad(const Scope& scope, const Operation& op,
167                       const std::vector<Output>& grad_inputs,
168                       std::vector<Output>* grad_outputs) {
169   grad_outputs->push_back(MatrixDiagPart(scope, grad_inputs[0]));
170   return scope.status();
171 }
172 REGISTER_GRADIENT_OP("MatrixDiag", MatrixDiagGrad);
173 
MatrixBandPartGrad(const Scope & scope,const Operation & op,const std::vector<Output> & grad_inputs,std::vector<Output> * grad_outputs)174 Status MatrixBandPartGrad(const Scope& scope, const Operation& op,
175                           const std::vector<Output>& grad_inputs,
176                           std::vector<Output>* grad_outputs) {
177   auto num_lower = op.input(1);
178   auto num_upper = op.input(2);
179   grad_outputs->push_back(
180       MatrixBandPart(scope, grad_inputs[0], num_lower, num_upper));
181   grad_outputs->push_back(NoGradient());
182   grad_outputs->push_back(NoGradient());
183   return scope.status();
184 }
185 REGISTER_GRADIENT_OP("MatrixBandPart", MatrixBandPartGrad);
186 
GatherNdGrad(const Scope & scope,const Operation & op,const std::vector<Output> & grad_inputs,std::vector<Output> * grad_outputs)187 Status GatherNdGrad(const Scope& scope, const Operation& op,
188                     const std::vector<Output>& grad_inputs,
189                     std::vector<Output>* grad_outputs) {
190   auto ref = op.input(0);
191   auto ref_shape = Shape(scope, ref);
192   auto indices = op.input(1);
193   grad_outputs->push_back(ScatterNd(scope, indices, grad_inputs[0], ref_shape));
194   grad_outputs->push_back(NoGradient());
195   return scope.status();
196 }
197 REGISTER_GRADIENT_OP("GatherNd", GatherNdGrad);
198 
CheckNumericsGrad(const Scope & scope,const Operation & op,const std::vector<Output> & grad_inputs,std::vector<Output> * grad_outputs)199 Status CheckNumericsGrad(const Scope& scope, const Operation& op,
200                          const std::vector<Output>& grad_inputs,
201                          std::vector<Output>* grad_outputs) {
202   string message;
203   TF_RETURN_IF_ERROR(GetNodeAttr(op.node()->attrs(), "message", &message));
204   string err_msg = strings::StrCat(
205       "Not a number (NaN) or infinity (Inf) values detected in gradient. ",
206       message);
207   grad_outputs->push_back(CheckNumerics(scope, grad_inputs[0], err_msg));
208   return scope.status();
209 }
210 REGISTER_GRADIENT_OP("CheckNumerics", CheckNumericsGrad);
211 
ReshapeGrad(const Scope & scope,const Operation & op,const std::vector<Output> & grad_inputs,std::vector<Output> * grad_outputs)212 Status ReshapeGrad(const Scope& scope, const Operation& op,
213                    const std::vector<Output>& grad_inputs,
214                    std::vector<Output>* grad_outputs) {
215   auto input_shape = Shape(scope, op.input(0));
216   grad_outputs->push_back(Reshape(scope, grad_inputs[0], input_shape));
217   grad_outputs->push_back(NoGradient());
218   return scope.status();
219 }
220 REGISTER_GRADIENT_OP("Reshape", ReshapeGrad);
221 
ExpandDimsGrad(const Scope & scope,const Operation & op,const std::vector<Output> & grad_inputs,std::vector<Output> * grad_outputs)222 Status ExpandDimsGrad(const Scope& scope, const Operation& op,
223                       const std::vector<Output>& grad_inputs,
224                       std::vector<Output>* grad_outputs) {
225   auto input_shape = Shape(scope, op.input(0));
226   grad_outputs->push_back(Reshape(scope, grad_inputs[0], input_shape));
227   grad_outputs->push_back(NoGradient());
228   return scope.status();
229 }
230 REGISTER_GRADIENT_OP("ExpandDims", ExpandDimsGrad);
231 
SqueezeGrad(const Scope & scope,const Operation & op,const std::vector<Output> & grad_inputs,std::vector<Output> * grad_outputs)232 Status SqueezeGrad(const Scope& scope, const Operation& op,
233                    const std::vector<Output>& grad_inputs,
234                    std::vector<Output>* grad_outputs) {
235   auto input_shape = Shape(scope, op.input(0));
236   grad_outputs->push_back(Reshape(scope, grad_inputs[0], input_shape));
237   return scope.status();
238 }
239 REGISTER_GRADIENT_OP("Squeeze", SqueezeGrad);
240 
TransposeGrad(const Scope & scope,const Operation & op,const std::vector<Output> & grad_inputs,std::vector<Output> * grad_outputs)241 Status TransposeGrad(const Scope& scope, const Operation& op,
242                      const std::vector<Output>& grad_inputs,
243                      std::vector<Output>* grad_outputs) {
244   auto inverted_perm = InvertPermutation(scope, op.input(1));
245   grad_outputs->push_back(Transpose(scope, grad_inputs[0], inverted_perm));
246   grad_outputs->push_back(NoGradient());
247   return scope.status();
248 }
249 REGISTER_GRADIENT_OP("Transpose", TransposeGrad);
250 
ReverseSequenceGrad(const Scope & scope,const Operation & op,const std::vector<Output> & grad_inputs,std::vector<Output> * grad_outputs)251 Status ReverseSequenceGrad(const Scope& scope, const Operation& op,
252                            const std::vector<Output>& grad_inputs,
253                            std::vector<Output>* grad_outputs) {
254   auto seq_lengths = op.input(1);
255   int batch_dim;
256   TF_RETURN_IF_ERROR(GetNodeAttr(op.node()->attrs(), "batch_dim", &batch_dim));
257   int seq_dim;
258   TF_RETURN_IF_ERROR(GetNodeAttr(op.node()->attrs(), "seq_dim", &seq_dim));
259   grad_outputs->push_back(
260       ReverseSequence(scope, grad_inputs[0], seq_lengths, seq_dim,
261                       ReverseSequence::BatchDim(batch_dim)));
262   grad_outputs->push_back(NoGradient());
263   return scope.status();
264 }
265 REGISTER_GRADIENT_OP("ReverseSequence", ReverseSequenceGrad);
266 
ReverseGrad(const Scope & scope,const Operation & op,const std::vector<Output> & grad_inputs,std::vector<Output> * grad_outputs)267 Status ReverseGrad(const Scope& scope, const Operation& op,
268                    const std::vector<Output>& grad_inputs,
269                    std::vector<Output>* grad_outputs) {
270   auto reverse_dims = op.input(1);
271   grad_outputs->push_back(Reverse(scope, grad_inputs[0], reverse_dims));
272   grad_outputs->push_back(NoGradient());
273   return scope.status();
274 }
275 REGISTER_GRADIENT_OP("ReverseV2", ReverseGrad);
276 
ScatterNdGrad(const Scope & scope,const Operation & op,const std::vector<Output> & grad_inputs,std::vector<Output> * grad_outputs)277 Status ScatterNdGrad(const Scope& scope, const Operation& op,
278                      const std::vector<Output>& grad_inputs,
279                      std::vector<Output>* grad_outputs) {
280   auto indices = op.input(0);
281   grad_outputs->push_back(NoGradient());
282   grad_outputs->push_back(GatherNd(scope, grad_inputs[0], indices));
283   grad_outputs->push_back(NoGradient());
284   return scope.status();
285 }
286 REGISTER_GRADIENT_OP("ScatterNd", ScatterNdGrad);
287 
ScatterNdNonAliasingAddGrad(const Scope & scope,const Operation & op,const std::vector<Output> & grad_inputs,std::vector<Output> * grad_outputs)288 Status ScatterNdNonAliasingAddGrad(const Scope& scope, const Operation& op,
289                                    const std::vector<Output>& grad_inputs,
290                                    std::vector<Output>* grad_outputs) {
291   auto indices = op.input(1);
292   grad_outputs->push_back(Identity(scope, grad_inputs[0]));
293   grad_outputs->push_back(NoGradient());
294   grad_outputs->push_back(GatherNd(scope, grad_inputs[0], indices));
295   return scope.status();
296 }
297 REGISTER_GRADIENT_OP("ScatterNdNonAliasingAdd", ScatterNdNonAliasingAddGrad);
298 
299 template <bool IsPadV2>
PadGrad(const Scope & scope,const Operation & op,const std::vector<Output> & grad_inputs,std::vector<Output> * grad_outputs)300 Status PadGrad(const Scope& scope, const Operation& op,
301                const std::vector<Output>& grad_inputs,
302                std::vector<Output>* grad_outputs) {
303   auto x = op.input(0);
304   auto a = op.input(1);  // [Rank(x), 2]
305   // Takes a slice of a. The 1st column. [Rank(x), 1].
306   auto size = Stack(scope, {Rank(scope, x), 1});
307   auto pad_before = Slice(scope, a, {0, 0}, size);
308   // Make it a 1-D tensor.
309   auto begin = Reshape(scope, pad_before, {-1});
310   grad_outputs->push_back(Slice(scope, grad_inputs[0], begin, Shape(scope, x)));
311   grad_outputs->push_back(NoGradient());
312   // PadV2 adds a "constant_values" input.
313   if (IsPadV2) {
314     grad_outputs->push_back(NoGradient());
315   }
316   return scope.status();
317 }
318 REGISTER_GRADIENT_OP("Pad", PadGrad<false>);
319 REGISTER_GRADIENT_OP("PadV2", PadGrad<true>);
320 
SpaceToBatchGrad(const Scope & scope,const Operation & op,const std::vector<Output> & grad_inputs,std::vector<Output> * grad_outputs)321 Status SpaceToBatchGrad(const Scope& scope, const Operation& op,
322                         const std::vector<Output>& grad_inputs,
323                         std::vector<Output>* grad_outputs) {
324   int block_size;
325   TF_RETURN_IF_ERROR(
326       GetNodeAttr(op.node()->attrs(), "block_size", &block_size));
327   grad_outputs->push_back(
328       BatchToSpace(scope, grad_inputs[0], op.input(1), block_size));
329   grad_outputs->push_back(NoGradient());
330   return scope.status();
331 }
332 REGISTER_GRADIENT_OP("SpaceToBatch", SpaceToBatchGrad);
333 
SpaceToBatchNDGrad(const Scope & scope,const Operation & op,const std::vector<Output> & grad_inputs,std::vector<Output> * grad_outputs)334 Status SpaceToBatchNDGrad(const Scope& scope, const Operation& op,
335                           const std::vector<Output>& grad_inputs,
336                           std::vector<Output>* grad_outputs) {
337   grad_outputs->push_back(
338       BatchToSpaceND(scope, grad_inputs[0], op.input(1), op.input(2)));
339   grad_outputs->push_back(NoGradient());
340   grad_outputs->push_back(NoGradient());
341   return scope.status();
342 }
343 REGISTER_GRADIENT_OP("SpaceToBatchND", SpaceToBatchNDGrad);
344 
BatchToSpaceGrad(const Scope & scope,const Operation & op,const std::vector<Output> & grad_inputs,std::vector<Output> * grad_outputs)345 Status BatchToSpaceGrad(const Scope& scope, const Operation& op,
346                         const std::vector<Output>& grad_inputs,
347                         std::vector<Output>* grad_outputs) {
348   int block_size;
349   TF_RETURN_IF_ERROR(
350       GetNodeAttr(op.node()->attrs(), "block_size", &block_size));
351   grad_outputs->push_back(
352       SpaceToBatch(scope, grad_inputs[0], op.input(1), block_size));
353   grad_outputs->push_back(NoGradient());
354   return scope.status();
355 }
356 REGISTER_GRADIENT_OP("BatchToSpace", BatchToSpaceGrad);
357 
BatchToSpaceNDGrad(const Scope & scope,const Operation & op,const std::vector<Output> & grad_inputs,std::vector<Output> * grad_outputs)358 Status BatchToSpaceNDGrad(const Scope& scope, const Operation& op,
359                           const std::vector<Output>& grad_inputs,
360                           std::vector<Output>* grad_outputs) {
361   grad_outputs->push_back(
362       SpaceToBatchND(scope, grad_inputs[0], op.input(1), op.input(2)));
363   grad_outputs->push_back(NoGradient());
364   grad_outputs->push_back(NoGradient());
365   return scope.status();
366 }
367 REGISTER_GRADIENT_OP("BatchToSpaceND", BatchToSpaceNDGrad);
368 
SpaceToDepthGrad(const Scope & scope,const Operation & op,const std::vector<Output> & grad_inputs,std::vector<Output> * grad_outputs)369 Status SpaceToDepthGrad(const Scope& scope, const Operation& op,
370                         const std::vector<Output>& grad_inputs,
371                         std::vector<Output>* grad_outputs) {
372   int block_size;
373   TF_RETURN_IF_ERROR(
374       GetNodeAttr(op.node()->attrs(), "block_size", &block_size));
375   grad_outputs->push_back(DepthToSpace(scope, grad_inputs[0], block_size));
376   return scope.status();
377 }
378 REGISTER_GRADIENT_OP("SpaceToDepth", SpaceToDepthGrad);
379 
DepthToSpaceGrad(const Scope & scope,const Operation & op,const std::vector<Output> & grad_inputs,std::vector<Output> * grad_outputs)380 Status DepthToSpaceGrad(const Scope& scope, const Operation& op,
381                         const std::vector<Output>& grad_inputs,
382                         std::vector<Output>* grad_outputs) {
383   int block_size;
384   TF_RETURN_IF_ERROR(
385       GetNodeAttr(op.node()->attrs(), "block_size", &block_size));
386   grad_outputs->push_back(SpaceToDepth(scope, grad_inputs[0], block_size));
387   return scope.status();
388 }
389 REGISTER_GRADIENT_OP("DepthToSpace", DepthToSpaceGrad);
390 
MirrorPadGrad(const Scope & scope,const Operation & op,const std::vector<Output> & grad_inputs,std::vector<Output> * grad_outputs)391 Status MirrorPadGrad(const Scope& scope, const Operation& op,
392                      const std::vector<Output>& grad_inputs,
393                      std::vector<Output>* grad_outputs) {
394   string mode;
395   TF_RETURN_IF_ERROR(GetNodeAttr(op.node()->attrs(), "mode", &mode));
396   grad_outputs->push_back(tensorflow::ops::internal::MirrorPadGrad(
397       scope, grad_inputs[0], op.input(1), mode));
398   grad_outputs->push_back(NoGradient());
399   return scope.status();
400 }
401 REGISTER_GRADIENT_OP("MirrorPad", MirrorPadGrad);
402 
403 // TODO(suharshs): b/34770860. This gradient was within 1e-3 but not 1e-4.
MirrorPadGradGrad(const Scope & scope,const Operation & op,const std::vector<Output> & grad_inputs,std::vector<Output> * grad_outputs)404 Status MirrorPadGradGrad(const Scope& scope, const Operation& op,
405                          const std::vector<Output>& grad_inputs,
406                          std::vector<Output>* grad_outputs) {
407   string mode;
408   TF_RETURN_IF_ERROR(GetNodeAttr(op.node()->attrs(), "mode", &mode));
409   grad_outputs->push_back(MirrorPad(scope, grad_inputs[0], op.input(1), mode));
410   grad_outputs->push_back(NoGradient());
411   return scope.status();
412 }
413 REGISTER_GRADIENT_OP("MirrorPadGrad", MirrorPadGradGrad);
414 
StridedSliceGradHelper(const Scope & scope,const Operation & op,const std::vector<Output> & grad_inputs,std::vector<Output> * grad_outputs)415 Status StridedSliceGradHelper(const Scope& scope, const Operation& op,
416                               const std::vector<Output>& grad_inputs,
417                               std::vector<Output>* grad_outputs) {
418   Input x = Shape(scope, op.input(0));
419   Input begin = op.input(1);
420   Input end = op.input(2);
421   Input strides = op.input(3);
422   int64 begin_mask;
423   int64 end_mask;
424   int64 ellipsis_mask;
425   int64 new_axis_mask;
426   int64 shrink_axis_mask;
427   TF_RETURN_IF_ERROR(
428       GetNodeAttr(op.node()->attrs(), "begin_mask", &begin_mask));
429   TF_RETURN_IF_ERROR(GetNodeAttr(op.node()->attrs(), "end_mask", &end_mask));
430   TF_RETURN_IF_ERROR(
431       GetNodeAttr(op.node()->attrs(), "ellipsis_mask", &ellipsis_mask));
432   TF_RETURN_IF_ERROR(
433       GetNodeAttr(op.node()->attrs(), "new_axis_mask", &new_axis_mask));
434   TF_RETURN_IF_ERROR(
435       GetNodeAttr(op.node()->attrs(), "shrink_axis_mask", &shrink_axis_mask));
436   grad_outputs->push_back(
437       StridedSliceGrad(scope, x, begin, end, strides, grad_inputs[0],
438                        StridedSliceGrad::BeginMask(begin_mask)
439                            .EndMask(end_mask)
440                            .EllipsisMask(ellipsis_mask)
441                            .NewAxisMask(new_axis_mask)
442                            .ShrinkAxisMask(shrink_axis_mask)));
443   // No gradients returned for begin, end and strides
444   grad_outputs->push_back(NoGradient());
445   grad_outputs->push_back(NoGradient());
446   grad_outputs->push_back(NoGradient());
447   return scope.status();
448 }
449 REGISTER_GRADIENT_OP("StridedSlice", StridedSliceGradHelper);
450 
SliceGrad(const Scope & scope,const Operation & op,const std::vector<Output> & grad_inputs,std::vector<Output> * grad_outputs)451 Status SliceGrad(const Scope& scope, const Operation& op,
452                  const std::vector<Output>& grad_inputs,
453                  std::vector<Output>* grad_outputs) {
454   // Propagate the incoming gradient along all the selected values,
455   // and zero everywhere else. Use the Pad operator for this.
456   //
457   // First create an Nx2 padding where N is the number of input
458   // dimensions. The first column is the number of prepended zeros
459   // for each dimension, and the second column is the number of
460   // appended zeros.
461   //
462   // The first column is just the begin vector.
463   // The second column is the shape of the input element-wise
464   // subtracted by begin+size
465 
466   // Running example:
467   // input.shape = [3, 5, 3]
468   // begin = [1, 2, 1], size = [1, 3, 2]
469   Input input = op.input(0);
470   Input begin = op.input(1);
471   // input_rank = 3
472   auto input_rank = Rank(scope, input);
473   // slice_size = [1, 3, 2]
474   auto slice_size = Shape(scope, op.output(0));
475   // padding_shape = [3, 1]
476   auto padding_shape = Stack(scope, {input_rank, 1});
477   // before_padding = [[1]
478   //                   [2]
479   //                   [1]]
480   Input before_padding = Reshape(scope, begin, padding_shape);
481   // after_padding_sizes = shape(input) - slice_size - begin
482   //                     = [3, 5, 3] - [1, 3, 2] - [1, 2, 1]
483   //                     = [1, 0, 0]
484   auto after_padding_sizes =
485       Sub(scope, Sub(scope, Shape(scope, input), slice_size), begin);
486   // after_padding = [[1]
487   //                  [0]
488   //                  [0]]
489   Input after_padding = Reshape(scope, after_padding_sizes, padding_shape);
490   // paddings = [[1 1]
491   //             [2 0]
492   //             [1 0]]
493   auto paddings =
494       Concat(scope, {before_padding, after_padding}, Const(scope, 1));
495   grad_outputs->push_back(Pad(scope, grad_inputs[0], paddings));
496   // Nothing propagated for "begin" and "size" inputs
497   grad_outputs->push_back(NoGradient());
498   grad_outputs->push_back(NoGradient());
499   return scope.status();
500 }
501 REGISTER_GRADIENT_OP("Slice", SliceGrad);
502 
503 }  // anonymous namespace
504 }  // namespace ops
505 }  // namespace tensorflow
506