/* Copyright 2016 The TensorFlow Authors. All Rights Reserved.

Licensed under the Apache License, Version 2.0 (the "License");
you may not use this file except in compliance with the License.
You may obtain a copy of the License at

    http://www.apache.org/licenses/LICENSE-2.0

Unless required by applicable law or agreed to in writing, software
distributed under the License is distributed on an "AS IS" BASIS,
WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
See the License for the specific language governing permissions and
limitations under the License.
==============================================================================*/

#include "tensorflow/cc/ops/nn_ops.h"
#include "tensorflow/cc/ops/nn_ops_internal.h"
#include "tensorflow/cc/ops/standard_ops.h"

#include "tensorflow/cc/framework/grad_op_registry.h"
#include "tensorflow/cc/framework/gradients.h"

namespace tensorflow {
namespace ops {
namespace {

Status SoftmaxGrad(const Scope& scope, const Operation& op,
                   const std::vector<Output>& grad_inputs,
                   std::vector<Output>* grad_outputs) {
  // Softmax gradient function.
  // p = softmax(x) maps from [batch, n] to [batch, n] (written [batch, m]
  // below; m == n since softmax preserves shape).
  // dp/dx = [dp0/dx0   ... dp0/dxn-1  ]
  //         [  ...           ...      ]
  //         [dpm-1/dx0 ... dpm-1/dxn-1]
  // dL/dx = dp/dx * dL/dy
  //
  // Using the alternative formula:
  // dL/dx = dL/dy * y - sum(dL/dy * y) * y
  //       = (dL/dy - sum(dL/dy * y)) * y
  auto y = op.output(0);
  auto dyy = Mul(scope, grad_inputs[0], y);
  auto sum = Reshape(scope, Sum(scope, dyy, {1}), {-1, 1});
  auto sub = Sub(scope, grad_inputs[0], sum);
  auto dx = Mul(scope, sub, y);
  grad_outputs->push_back(dx);
  return scope.status();
}
REGISTER_GRADIENT_OP("Softmax", SoftmaxGrad);

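// Returns true if `grad` is statically known to be an all-zeros tensor.
// Currently this only recognizes obviously-zero producers by op name; see the
// note below.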
bool IsZero(const Scope& scope, const Output& grad) {
  string op_type_name = grad.op().node()->type_string();
  if (op_type_name == "ZerosLike" || op_type_name == "Zeros") {
    return true;
  }
  // The Operation we were provided is not named something obvious, so we
  // would need to inspect its contents to decide.
  // The original Python code did this by calling the utility function
  // tensor_util.constant_value.
  // There is no C++ equivalent to tensor_util.constant_value, so we do
  // nothing for the moment.
  return false;
}

// Multiply after broadcasting vec to match dimensions of mat.
//   Args:
//     vec: A 1-D tensor of dimension [D0]
//     mat: A 2-D tensor of dimension [D0, D1]
//
//   Returns:
//     A tensor of dimension [D0, D1], the result of vec * mat.
Output BroadcastMul(const Scope& scope, const Output& vec, const Output& mat) {
  auto reshaped = ExpandDims(scope, vec, -1);
  return Multiply(scope, reshaped, mat);
}

Status SoftmaxCrossEntropyWithLogitsGrad(const Scope& scope,
                                         const Operation& op,
                                         const std::vector<Output>& grad_inputs,
                                         std::vector<Output>* grad_outputs) {
  // Softmax gradient with cross entropy logits function.
  // We multiply the incoming gradient for the loss by the op's precomputed
  // backprop, op.output(1).
  // The gradient for the labels input is -log_softmax(logits), scaled by the
  // incoming loss gradient.

  // The outputs of the network (the logits) are at input index 0.
  auto logits = op.input(0);
  // The op's second output is the softmax gradient, softmax(logits) - labels.
  auto softmax_grad = op.output(1);

  // grad_inputs[0] is the gradient flowing into the loss (output 0);
  // grad_inputs[1] is the gradient flowing into the backprop (output 1).
  auto grad_loss = grad_inputs[0];
  auto grad_grad = grad_inputs[1];

  auto grad = BroadcastMul(scope, grad_loss, softmax_grad);
  if (!IsZero(scope, grad_grad)) {
    std::vector<int> axis;
    auto logits_softmax = Softmax(scope, logits);

    auto grad_grad_expand = ExpandDims(scope, grad_grad, 1);
    auto logits_softmax_expand = ExpandDims(scope, logits_softmax, 2);
    auto matmul_result =
        BatchMatMul(scope, grad_grad_expand, logits_softmax_expand);
    axis.push_back(1);
    auto squeeze_result = Squeeze(scope, matmul_result, Squeeze::Axis(axis));
    auto subtraction_result = Subtract(scope, grad_grad, squeeze_result);
    auto multiply_result = Multiply(scope, subtraction_result, logits_softmax);
    grad = Add(scope, grad, multiply_result);
  }
  auto minus_log_softmax = Multiply(scope, LogSoftmax(scope, logits), -1.0f);
  grad_outputs->push_back(grad);
  grad_outputs->push_back(BroadcastMul(scope, grad_loss, minus_log_softmax));
  return scope.status();
}
REGISTER_GRADIENT_OP("SoftmaxCrossEntropyWithLogits",
                     SoftmaxCrossEntropyWithLogitsGrad);

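// For y = log_softmax(x): dL/dx = dL/dy - exp(y) * sum(dL/dy, axis 1), where
// exp(y) recovers softmax(x) from the op's output.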
Status LogSoftmaxGrad(const Scope& scope, const Operation& op,
                      const std::vector<Output>& grad_inputs,
                      std::vector<Output>* grad_outputs) {
  auto softmax = Exp(scope, op.output(0));
  auto sum = Sum(scope, grad_inputs[0], {1}, Sum::KeepDims(true));
  auto mul = Mul(scope, sum, softmax);
  auto dx = Sub(scope, grad_inputs[0], mul);
  grad_outputs->push_back(dx);
  return scope.status();
}
REGISTER_GRADIENT_OP("LogSoftmax", LogSoftmaxGrad);

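// ReluGrad passes grad_inputs[0] through where the forward input op.input(0)
// was positive and zeroes it elsewhere.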
Status ReluGradHelper(const Scope& scope, const Operation& op,
                      const std::vector<Output>& grad_inputs,
                      std::vector<Output>* grad_outputs) {
  auto dx = internal::ReluGrad(scope, grad_inputs[0], op.input(0));
  grad_outputs->push_back(dx);
  return scope.status();
}
REGISTER_GRADIENT_OP("Relu", ReluGradHelper);

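// Relu6Grad passes the incoming gradient through only where the forward input
// lay in the unclipped region of relu6 (between 0 and 6) and zeroes it
// elsewhere.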
Status Relu6GradHelper(const Scope& scope, const Operation& op,
                       const std::vector<Output>& grad_inputs,
                       std::vector<Output>* grad_outputs) {
  auto dx = internal::Relu6Grad(scope, grad_inputs[0], op.input(0));
  grad_outputs->push_back(dx);
  return scope.status();
}
REGISTER_GRADIENT_OP("Relu6", Relu6GradHelper);

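// d/dx leaky_relu(x) is 1 for x > 0 and alpha otherwise; the forward op's
// `alpha` attribute is forwarded to the LeakyReluGrad kernel.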
Status LeakyReluGradHelper(const Scope& scope, const Operation& op,
                           const std::vector<Output>& grad_inputs,
                           std::vector<Output>* grad_outputs) {
  float alpha;
  TF_RETURN_IF_ERROR(GetNodeAttr(op.node()->attrs(), "alpha", &alpha));
  internal::LeakyReluGrad::Attrs attrs;
  auto dx = internal::LeakyReluGrad(scope, grad_inputs[0], op.input(0),
                                    attrs.Alpha(alpha));
  grad_outputs->push_back(dx);
  return scope.status();
}
REGISTER_GRADIENT_OP("LeakyRelu", LeakyReluGradHelper);

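// LeakyReluGrad is linear in its incoming-gradient argument, so its gradient
// reuses the same kernel with the original features (op.input(1)); the
// features input itself receives no gradient here.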
Status LeakyReluGradGradHelper(const Scope& scope, const Operation& op,
                               const std::vector<Output>& grad_inputs,
                               std::vector<Output>* grad_outputs) {
  float alpha;
  TF_RETURN_IF_ERROR(GetNodeAttr(op.node()->attrs(), "alpha", &alpha));
  internal::LeakyReluGrad::Attrs attrs;
  auto dx = internal::LeakyReluGrad(scope, grad_inputs[0], op.input(1),
                                    attrs.Alpha(alpha));
  grad_outputs->push_back(dx);
  grad_outputs->push_back(NoGradient());
  return scope.status();
}
REGISTER_GRADIENT_OP("LeakyReluGrad", LeakyReluGradGradHelper);

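// d/dx elu(x) is 1 for x > 0 and exp(x) = y + 1 otherwise, so EluGrad can
// compute the gradient from the op's output y alone.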
Status EluGradHelper(const Scope& scope, const Operation& op,
                     const std::vector<Output>& grad_inputs,
                     std::vector<Output>* grad_outputs) {
  auto dx = internal::EluGrad(scope, grad_inputs[0], op.output(0));
  grad_outputs->push_back(dx);
  return scope.status();
}
REGISTER_GRADIENT_OP("Elu", EluGradHelper);

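// As with Elu, the SELU derivative can be recovered from the op's output, so
// SeluGrad takes the incoming gradient and op.output(0).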
Status SeluGradHelper(const Scope& scope, const Operation& op,
                      const std::vector<Output>& grad_inputs,
                      std::vector<Output>* grad_outputs) {
  auto dx = internal::SeluGrad(scope, grad_inputs[0], op.output(0));
  grad_outputs->push_back(dx);
  return scope.status();
}
REGISTER_GRADIENT_OP("Selu", SeluGradHelper);

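// L2Loss(t) = sum(t^2) / 2, so dL/dt = t * grad.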
Status L2LossGrad(const Scope& scope, const Operation& op,
                  const std::vector<Output>& grad_inputs,
                  std::vector<Output>* grad_outputs) {
  grad_outputs->push_back(Mul(scope, op.input(0), grad_inputs[0]));
  return scope.status();
}
REGISTER_GRADIENT_OP("L2Loss", L2LossGrad);

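// The gradient w.r.t. the value input is the incoming gradient unchanged; the
// gradient w.r.t. the bias is the incoming gradient summed over every
// dimension except the feature dimension selected by `data_format`.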
Status BiasAddGradHelper(const Scope& scope, const Operation& op,
                         const std::vector<Output>& grad_inputs,
                         std::vector<Output>* grad_outputs) {
  string data_format;
  TF_RETURN_IF_ERROR(
      GetNodeAttr(op.output(0).node()->attrs(), "data_format", &data_format));
  auto dx_1 =
      BiasAddGrad(scope, grad_inputs[0], BiasAddGrad::DataFormat(data_format));
  grad_outputs->push_back(Identity(scope, grad_inputs[0]));
  grad_outputs->push_back(dx_1);
  return scope.status();
}
REGISTER_GRADIENT_OP("BiasAdd", BiasAddGradHelper);

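// The input and filter gradients are computed by the dedicated
// Conv2DBackpropInput / Conv2DBackpropFilter ops, which need the shape of the
// other argument plus the forward op's convolution attributes.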
Status Conv2DGrad(const Scope& scope, const Operation& op,
                  const std::vector<Output>& grad_inputs,
                  std::vector<Output>* grad_outputs) {
  string data_format;
  string padding;
  std::vector<int32> strides;
  bool use_cudnn_on_gpu;
  auto attrs = op.output(0).node()->attrs();
  TF_RETURN_IF_ERROR(GetNodeAttr(attrs, "data_format", &data_format));
  TF_RETURN_IF_ERROR(GetNodeAttr(attrs, "padding", &padding));
  TF_RETURN_IF_ERROR(GetNodeAttr(attrs, "strides", &strides));
  TF_RETURN_IF_ERROR(GetNodeAttr(attrs, "use_cudnn_on_gpu", &use_cudnn_on_gpu));
  auto dx_1 = Conv2DBackpropInput(scope, Shape(scope, op.input(0)), op.input(1),
                                  grad_inputs[0], strides, padding,
                                  Conv2DBackpropInput::DataFormat(data_format)
                                      .UseCudnnOnGpu(use_cudnn_on_gpu));
  grad_outputs->push_back(dx_1);
  auto dx_2 =
      Conv2DBackpropFilter(scope, op.input(0), Shape(scope, op.input(1)),
                           grad_inputs[0], strides, padding,
                           Conv2DBackpropFilter::DataFormat(data_format)
                               .UseCudnnOnGpu(use_cudnn_on_gpu));
  grad_outputs->push_back(dx_2);
  return scope.status();
}
REGISTER_GRADIENT_OP("Conv2D", Conv2DGrad);

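// MaxPoolGrad routes each incoming gradient value back to the input position
// that produced the maximum within its pooling window.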
Status MaxPoolGradHelper(const Scope& scope, const Operation& op,
                         const std::vector<Output>& grad_inputs,
                         std::vector<Output>* grad_outputs) {
  string data_format;
  string padding;
  std::vector<int32> strides;
  std::vector<int32> ksize;
  auto attrs = op.output(0).node()->attrs();
  TF_RETURN_IF_ERROR(GetNodeAttr(attrs, "data_format", &data_format));
  TF_RETURN_IF_ERROR(GetNodeAttr(attrs, "ksize", &ksize));
  TF_RETURN_IF_ERROR(GetNodeAttr(attrs, "padding", &padding));
  TF_RETURN_IF_ERROR(GetNodeAttr(attrs, "strides", &strides));
  auto dx = internal::MaxPoolGrad(
      scope, op.input(0), op.output(0), grad_inputs[0], ksize, strides, padding,
      internal::MaxPoolGrad::DataFormat(data_format));
  grad_outputs->push_back(dx);
  return scope.status();
}
REGISTER_GRADIENT_OP("MaxPool", MaxPoolGradHelper);

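// Same as MaxPool above, except that ksize and strides are inputs of the op
// (op.input(1) and op.input(2)) rather than attributes, so they receive
// NoGradient entries.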
Status MaxPoolGradV2Helper(const Scope& scope, const Operation& op,
                           const std::vector<Output>& grad_inputs,
                           std::vector<Output>* grad_outputs) {
  string data_format;
  string padding;
  auto attrs = op.output(0).node()->attrs();
  TF_RETURN_IF_ERROR(GetNodeAttr(attrs, "data_format", &data_format));
  TF_RETURN_IF_ERROR(GetNodeAttr(attrs, "padding", &padding));
  auto dx = MaxPoolGradV2(scope, op.input(0), op.output(0), grad_inputs[0],
                          op.input(1), op.input(2), padding,
                          MaxPoolGradV2::DataFormat(data_format));
  grad_outputs->push_back(dx);
  grad_outputs->push_back(NoGradient());
  grad_outputs->push_back(NoGradient());
  return scope.status();
}
REGISTER_GRADIENT_OP("MaxPoolV2", MaxPoolGradV2Helper);

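// The 3-D analogue of the MaxPool gradient above; ksize and strides come from
// the forward op's attributes.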
Status MaxPool3DGradHelper(const Scope& scope, const Operation& op,
                           const std::vector<Output>& grad_inputs,
                           std::vector<Output>* grad_outputs) {
  std::vector<int32> ksize;
  std::vector<int32> strides;
  string padding;
  string data_format;
  auto attrs = op.output(0).node()->attrs();
  TF_RETURN_IF_ERROR(GetNodeAttr(attrs, "ksize", &ksize));
  TF_RETURN_IF_ERROR(GetNodeAttr(attrs, "strides", &strides));
  TF_RETURN_IF_ERROR(GetNodeAttr(attrs, "padding", &padding));
  TF_RETURN_IF_ERROR(GetNodeAttr(attrs, "data_format", &data_format));
  MaxPool3DGrad::Attrs grad_attrs;
  auto dx =
      MaxPool3DGrad(scope, op.input(0), op.output(0), grad_inputs[0], ksize,
                    strides, padding, grad_attrs.DataFormat(data_format));
  grad_outputs->push_back(dx);
  return scope.status();
}
REGISTER_GRADIENT_OP("MaxPool3D", MaxPool3DGradHelper);

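// AvgPoolGrad spreads each incoming gradient value uniformly over its pooling
// window; only the shape of the original input is needed, not its values.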
Status AvgPoolGradHelper(const Scope& scope, const Operation& op,
                         const std::vector<Output>& grad_inputs,
                         std::vector<Output>* grad_outputs) {
  std::vector<int32> ksize;
  std::vector<int32> strides;
  string padding;
  string data_format;
  auto attrs = op.output(0).node()->attrs();
  TF_RETURN_IF_ERROR(GetNodeAttr(attrs, "ksize", &ksize));
  TF_RETURN_IF_ERROR(GetNodeAttr(attrs, "strides", &strides));
  TF_RETURN_IF_ERROR(GetNodeAttr(attrs, "padding", &padding));
  TF_RETURN_IF_ERROR(GetNodeAttr(attrs, "data_format", &data_format));
  internal::AvgPoolGrad::Attrs grad_attrs;
  auto dx = internal::AvgPoolGrad(scope, Shape(scope, op.input(0)),
                                  grad_inputs[0], ksize, strides, padding,
                                  grad_attrs.DataFormat(data_format));
  grad_outputs->push_back(dx);
  return scope.status();
}
REGISTER_GRADIENT_OP("AvgPool", AvgPoolGradHelper);

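// The 3-D analogue of the AvgPool gradient above.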
Status AvgPool3DGradHelper(const Scope& scope, const Operation& op,
                           const std::vector<Output>& grad_inputs,
                           std::vector<Output>* grad_outputs) {
  std::vector<int32> ksize;
  std::vector<int32> strides;
  string padding;
  string data_format;
  auto attrs = op.output(0).node()->attrs();
  TF_RETURN_IF_ERROR(GetNodeAttr(attrs, "ksize", &ksize));
  TF_RETURN_IF_ERROR(GetNodeAttr(attrs, "strides", &strides));
  TF_RETURN_IF_ERROR(GetNodeAttr(attrs, "padding", &padding));
  TF_RETURN_IF_ERROR(GetNodeAttr(attrs, "data_format", &data_format));
  AvgPool3DGrad::Attrs grad_attrs;
  auto dx =
      AvgPool3DGrad(scope, Shape(scope, op.input(0)), grad_inputs[0], ksize,
                    strides, padding, grad_attrs.DataFormat(data_format));
  grad_outputs->push_back(dx);
  return scope.status();
}
REGISTER_GRADIENT_OP("AvgPool3D", AvgPool3DGradHelper);

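// LRNGrad computes the local-response-normalization gradient from the
// incoming gradient together with the forward op's input and output.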
Status LRNGradHelper(const Scope& scope, const Operation& op,
                     const std::vector<Output>& grad_inputs,
                     std::vector<Output>* grad_outputs) {
  auto dx = internal::LRNGrad(scope, grad_inputs[0], op.input(0), op.output(0));
  grad_outputs->push_back(dx);
  return scope.status();
}
REGISTER_GRADIENT_OP("LRN", LRNGradHelper);

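// softplus(x) = log(1 + exp(x)), so d/dx softplus(x) = sigmoid(x);
// SoftplusGrad scales the incoming gradient accordingly.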
Status SoftplusGradHelper(const Scope& scope, const Operation& op,
                          const std::vector<Output>& grad_inputs,
                          std::vector<Output>* grad_outputs) {
  auto dx = internal::SoftplusGrad(scope, grad_inputs[0], op.input(0));
  grad_outputs->push_back(dx);
  return scope.status();
}
REGISTER_GRADIENT_OP("Softplus", SoftplusGradHelper);

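// softsign(x) = x / (1 + |x|), so d/dx softsign(x) = 1 / (1 + |x|)^2;
// SoftsignGrad scales the incoming gradient accordingly.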
Status SoftsignGradHelper(const Scope& scope, const Operation& op,
                          const std::vector<Output>& grad_inputs,
                          std::vector<Output>* grad_outputs) {
  auto dx = internal::SoftsignGrad(scope, grad_inputs[0], op.input(0));
  grad_outputs->push_back(dx);
  return scope.status();
}
REGISTER_GRADIENT_OP("Softsign", SoftsignGradHelper);

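// Like AvgPoolGrad, but the pooling regions are defined by the row and column
// pooling sequences the forward op emits as outputs 1 and 2; only the (int64)
// input shape is needed, not the input values.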
Status FractionalAvgPoolGradHelper(const Scope& scope, const Operation& op,
                                   const std::vector<Output>& grad_inputs,
                                   std::vector<Output>* grad_outputs) {
  bool overlapping;
  TF_RETURN_IF_ERROR(
      GetNodeAttr(op.output(0).node()->attrs(), "overlapping", &overlapping));
  auto dx = internal::FractionalAvgPoolGrad(
      scope, Shape(scope, op.input(0), Shape::OutType(DT_INT64)),
      grad_inputs[0], op.output(1), op.output(2),
      internal::FractionalAvgPoolGrad::Overlapping(overlapping));
  grad_outputs->push_back(dx);
  return scope.status();
}
REGISTER_GRADIENT_OP("FractionalAvgPool", FractionalAvgPoolGradHelper);

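// Like MaxPoolGrad, but uses the row and column pooling sequences (outputs 1
// and 2 of the forward op) to locate the pooling regions.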
Status FractionalMaxPoolGradHelper(const Scope& scope, const Operation& op,
                                   const std::vector<Output>& grad_inputs,
                                   std::vector<Output>* grad_outputs) {
  bool overlapping;
  TF_RETURN_IF_ERROR(
      GetNodeAttr(op.output(0).node()->attrs(), "overlapping", &overlapping));
  auto dx = internal::FractionalMaxPoolGrad(
      scope, op.input(0), op.output(0), grad_inputs[0], op.output(1),
      op.output(2), internal::FractionalMaxPoolGrad::Overlapping(overlapping));
  grad_outputs->push_back(dx);
  return scope.status();
}
REGISTER_GRADIENT_OP("FractionalMaxPool", FractionalMaxPoolGradHelper);

}  // anonymous namespace
}  // namespace ops
}  // namespace tensorflow