/* Copyright 2016 The TensorFlow Authors. All Rights Reserved.

Licensed under the Apache License, Version 2.0 (the "License");
you may not use this file except in compliance with the License.
You may obtain a copy of the License at

    http://www.apache.org/licenses/LICENSE-2.0

Unless required by applicable law or agreed to in writing, software
distributed under the License is distributed on an "AS IS" BASIS,
WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
See the License for the specific language governing permissions and
limitations under the License.
==============================================================================*/

#include "tensorflow/cc/ops/nn_ops.h"
#include "tensorflow/cc/ops/nn_ops_internal.h"
#include "tensorflow/cc/ops/standard_ops.h"

#include "tensorflow/cc/framework/grad_op_registry.h"
#include "tensorflow/cc/framework/gradients.h"

namespace tensorflow {
namespace ops {
namespace {

Status SoftmaxGrad(const Scope& scope, const Operation& op,
                   const std::vector<Output>& grad_inputs,
                   std::vector<Output>* grad_outputs) {
  // Softmax gradient function.
  // p = softmax(x) maps [batch, n] to [batch, n]; for one example the
  // Jacobian is
  // dp/dx = [dp0/dx0   ... dp0/dxn-1  ]
  //         [  ...           ...      ]
  //         [dpn-1/dx0 ... dpn-1/dxn-1]
  // dL/dx = dp/dx * dL/dy
  //
  // The Jacobian never needs to be materialized; use the equivalent
  // elementwise formula (with y = p, the forward output):
  // dL/dx = dL/dy * y - sum(dL/dy * y) * y
  //       = (dL/dy - sum(dL/dy * y)) * y
  auto y = op.output(0);
  auto dyy = Mul(scope, grad_inputs[0], y);
  auto sum = Reshape(scope, Sum(scope, dyy, {1}), {-1, 1});
  auto sub = Sub(scope, grad_inputs[0], sum);
  auto dx = Mul(scope, sub, y);
  grad_outputs->push_back(dx);
  return scope.status();
}
REGISTER_GRADIENT_OP("Softmax", SoftmaxGrad);

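// Returns true if `grad` can be statically determined to be an all-zero
// tensor (used below to skip the second-order term when the incoming
// backprop gradient is zero).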
bool IsZero(const Scope& scope, const Output& grad) {
  string op_type_name = grad.op().node()->type_string();
  if (op_type_name == "ZerosLike" || op_type_name == "Zeros") {
    return true;
  }
  // The op producing `grad` is not named something obviously zero, so we
  // would have to inspect its contents to be sure. The Python implementation
  // does this with tensor_util.constant_value; there is no C++ equivalent
  // yet, so we conservatively report that the gradient is not zero.
  return false;
}

// Multiply after broadcasting vec to match the dimensions of mat.
// Args:
//   vec: A 1-D tensor of dimension [D0]
//   mat: A 2-D tensor of dimension [D0, D1]
//
// Returns:
//   A tensor of dimension [D0, D1], the result of vec * mat.
Output BroadcastMul(const Scope& scope, const Output& vec, const Output& mat) {
  auto reshaped = ExpandDims(scope, vec, -1);
  return Multiply(scope, reshaped, mat);
}

Status SoftmaxCrossEntropyWithLogitsGrad(const Scope& scope,
                                         const Operation& op,
                                         const std::vector<Output>& grad_inputs,
                                         std::vector<Output>* grad_outputs) {
  // Gradient for SoftmaxCrossEntropyWithLogits.
  // The gradient w.r.t. the logits is the incoming loss gradient multiplied
  // (with broadcasting) by the softmax gradient the forward op already
  // produced as its second output. The gradient w.r.t. the labels is
  // -log_softmax(logits) scaled by the incoming loss gradient.

  // The logits are at input index 0 (the labels are at input index 1).
  auto logits = op.input(0);
  // Output index 1 of the forward op is softmax(logits) - labels.
  auto softmax_grad = op.output(1);

  // Incoming gradients: grad_inputs[0] is w.r.t. the loss (output 0) and
  // grad_inputs[1] is w.r.t. the backprop output (output 1).
  auto grad_loss = grad_inputs[0];
  auto grad_grad = grad_inputs[1];

  auto grad = BroadcastMul(scope, grad_loss, softmax_grad);
  if (!IsZero(scope, grad_grad)) {
    // Second-order term: add (grad_grad - <grad_grad, softmax>) * softmax,
    // with the per-example inner product computed via a batched matmul.
    std::vector<int> axis;
    auto logits_softmax = Softmax(scope, logits);

    auto grad_grad_expand = ExpandDims(scope, grad_grad, 1);
    auto logits_softmax_expand = ExpandDims(scope, logits_softmax, 2);
    auto matmul_result =
        BatchMatMul(scope, grad_grad_expand, logits_softmax_expand);
    axis.push_back(1);
    auto squeeze_result = Squeeze(scope, matmul_result, Squeeze::Axis(axis));
    auto subtraction_result = Subtract(scope, grad_grad, squeeze_result);
    auto multiply_result = Multiply(scope, subtraction_result, logits_softmax);
    grad = Add(scope, grad, multiply_result);
  }
  auto minus_log_softmax = Multiply(scope, LogSoftmax(scope, logits), -1.0f);
  grad_outputs->push_back(grad);
  grad_outputs->push_back(BroadcastMul(scope, grad_loss, minus_log_softmax));
  return scope.status();
}
REGISTER_GRADIENT_OP("SoftmaxCrossEntropyWithLogits",
                     SoftmaxCrossEntropyWithLogitsGrad);

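// LogSoftmax gradient. With y = log_softmax(x), the gradient is
// dL/dx = dL/dy - exp(y) * sum(dL/dy), summed over the class axis.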
Status LogSoftmaxGrad(const Scope& scope, const Operation& op,
                      const std::vector<Output>& grad_inputs,
                      std::vector<Output>* grad_outputs) {
  auto softmax = Exp(scope, op.output(0));
  auto sum = Sum(scope, grad_inputs[0], {1}, Sum::KeepDims(true));
  auto mul = Mul(scope, sum, softmax);
  auto dx = Sub(scope, grad_inputs[0], mul);
  grad_outputs->push_back(dx);
  return scope.status();
}
REGISTER_GRADIENT_OP("LogSoftmax", LogSoftmaxGrad);

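// Relu gradient: dy passes through where the forward input was positive and
// is zero elsewhere.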
Status ReluGradHelper(const Scope& scope, const Operation& op,
                      const std::vector<Output>& grad_inputs,
                      std::vector<Output>* grad_outputs) {
  auto dx = internal::ReluGrad(scope, grad_inputs[0], op.input(0));
  grad_outputs->push_back(dx);
  return scope.status();
}
REGISTER_GRADIENT_OP("Relu", ReluGradHelper);

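// Relu6 gradient: dy where 0 < x < 6, zero elsewhere.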
Status Relu6GradHelper(const Scope& scope, const Operation& op,
                       const std::vector<Output>& grad_inputs,
                       std::vector<Output>* grad_outputs) {
  auto dx = internal::Relu6Grad(scope, grad_inputs[0], op.input(0));
  grad_outputs->push_back(dx);
  return scope.status();
}
REGISTER_GRADIENT_OP("Relu6", Relu6GradHelper);

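// LeakyRelu gradient: dy where x > 0 and alpha * dy otherwise, with alpha
// read from the forward op's attribute.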
Status LeakyReluGradHelper(const Scope& scope, const Operation& op,
                           const std::vector<Output>& grad_inputs,
                           std::vector<Output>* grad_outputs) {
  float alpha;
  TF_RETURN_IF_ERROR(GetNodeAttr(op.node()->attrs(), "alpha", &alpha));
  internal::LeakyReluGrad::Attrs attrs;
  auto dx = internal::LeakyReluGrad(scope, grad_inputs[0], op.input(0),
                                    attrs.Alpha(alpha));
  grad_outputs->push_back(dx);
  return scope.status();
}
REGISTER_GRADIENT_OP("LeakyRelu", LeakyReluGradHelper);

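// Gradient of LeakyReluGrad itself: the incoming gradient is masked with the
// same alpha slope, keyed on the original features (op.input(1)); the
// features input receives no gradient.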
Status LeakyReluGradGradHelper(const Scope& scope, const Operation& op,
                               const std::vector<Output>& grad_inputs,
                               std::vector<Output>* grad_outputs) {
  float alpha;
  TF_RETURN_IF_ERROR(GetNodeAttr(op.node()->attrs(), "alpha", &alpha));
  internal::LeakyReluGrad::Attrs attrs;
  auto dx = internal::LeakyReluGrad(scope, grad_inputs[0], op.input(1),
                                    attrs.Alpha(alpha));
  grad_outputs->push_back(dx);
  grad_outputs->push_back(NoGradient());
  return scope.status();
}
REGISTER_GRADIENT_OP("LeakyReluGrad", LeakyReluGradGradHelper);

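// Elu gradient, computed from the forward output: dy where the output is
// positive, dy * (output + 1) otherwise.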
Status EluGradHelper(const Scope& scope, const Operation& op,
                     const std::vector<Output>& grad_inputs,
                     std::vector<Output>* grad_outputs) {
  auto dx = internal::EluGrad(scope, grad_inputs[0], op.output(0));
  grad_outputs->push_back(dx);
  return scope.status();
}
REGISTER_GRADIENT_OP("Elu", EluGradHelper);

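// Selu gradient, computed by the internal SeluGrad kernel from the forward
// output (scaled dy for positive activations, dy * (output + scale * alpha)
// otherwise).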
Status SeluGradHelper(const Scope& scope, const Operation& op,
                      const std::vector<Output>& grad_inputs,
                      std::vector<Output>* grad_outputs) {
  auto dx = internal::SeluGrad(scope, grad_inputs[0], op.output(0));
  grad_outputs->push_back(dx);
  return scope.status();
}
REGISTER_GRADIENT_OP("Selu", SeluGradHelper);

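// L2Loss(x) = sum(x^2) / 2, so the gradient is simply x scaled by the
// incoming gradient.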
Status L2LossGrad(const Scope& scope, const Operation& op,
                  const std::vector<Output>& grad_inputs,
                  std::vector<Output>* grad_outputs) {
  grad_outputs->push_back(Mul(scope, op.input(0), grad_inputs[0]));
  return scope.status();
}
REGISTER_GRADIENT_OP("L2Loss", L2LossGrad);

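// BiasAdd gradient: the value input receives the incoming gradient unchanged;
// the bias receives the incoming gradient summed over every dimension except
// the channel dimension selected by data_format.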
Status BiasAddGradHelper(const Scope& scope, const Operation& op,
                         const std::vector<Output>& grad_inputs,
                         std::vector<Output>* grad_outputs) {
  string data_format;
  TF_RETURN_IF_ERROR(
      GetNodeAttr(op.output(0).node()->attrs(), "data_format", &data_format));
  auto dx_1 =
      BiasAddGrad(scope, grad_inputs[0], BiasAddGrad::DataFormat(data_format));
  grad_outputs->push_back(Identity(scope, grad_inputs[0]));
  grad_outputs->push_back(dx_1);
  return scope.status();
}
REGISTER_GRADIENT_OP("BiasAdd", BiasAddGradHelper);

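// Conv2D gradient: backprop to the input via Conv2DBackpropInput and to the
// filter via Conv2DBackpropFilter, forwarding the forward op's strides,
// padding, data_format, and use_cudnn_on_gpu attributes.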
Status Conv2DGrad(const Scope& scope, const Operation& op,
                  const std::vector<Output>& grad_inputs,
                  std::vector<Output>* grad_outputs) {
  string data_format;
  string padding;
  std::vector<int32> strides;
  bool use_cudnn_on_gpu;
  auto attrs = op.output(0).node()->attrs();
  TF_RETURN_IF_ERROR(GetNodeAttr(attrs, "data_format", &data_format));
  TF_RETURN_IF_ERROR(GetNodeAttr(attrs, "padding", &padding));
  TF_RETURN_IF_ERROR(GetNodeAttr(attrs, "strides", &strides));
  TF_RETURN_IF_ERROR(GetNodeAttr(attrs, "use_cudnn_on_gpu", &use_cudnn_on_gpu));
  auto dx_1 = Conv2DBackpropInput(scope, Shape(scope, op.input(0)), op.input(1),
                                  grad_inputs[0], strides, padding,
                                  Conv2DBackpropInput::DataFormat(data_format)
                                      .UseCudnnOnGpu(use_cudnn_on_gpu));
  grad_outputs->push_back(dx_1);
  auto dx_2 =
      Conv2DBackpropFilter(scope, op.input(0), Shape(scope, op.input(1)),
                           grad_inputs[0], strides, padding,
                           Conv2DBackpropFilter::DataFormat(data_format)
                               .UseCudnnOnGpu(use_cudnn_on_gpu));
  grad_outputs->push_back(dx_2);
  return scope.status();
}
REGISTER_GRADIENT_OP("Conv2D", Conv2DGrad);

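// MaxPool gradient: routes each output gradient back to the input position
// that held the window maximum, using the forward input and output plus the
// ksize, strides, padding, and data_format attributes.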
Status MaxPoolGradHelper(const Scope& scope, const Operation& op,
                         const std::vector<Output>& grad_inputs,
                         std::vector<Output>* grad_outputs) {
  string data_format;
  string padding;
  std::vector<int32> strides;
  std::vector<int32> ksize;
  auto attrs = op.output(0).node()->attrs();
  TF_RETURN_IF_ERROR(GetNodeAttr(attrs, "data_format", &data_format));
  TF_RETURN_IF_ERROR(GetNodeAttr(attrs, "ksize", &ksize));
  TF_RETURN_IF_ERROR(GetNodeAttr(attrs, "padding", &padding));
  TF_RETURN_IF_ERROR(GetNodeAttr(attrs, "strides", &strides));
  auto dx = internal::MaxPoolGrad(
      scope, op.input(0), op.output(0), grad_inputs[0], ksize, strides, padding,
      internal::MaxPoolGrad::DataFormat(data_format));
  grad_outputs->push_back(dx);
  return scope.status();
}
REGISTER_GRADIENT_OP("MaxPool", MaxPoolGradHelper);

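// MaxPoolV2 gradient: same as MaxPool, except that ksize and strides arrive
// as tensor inputs (op.input(1) and op.input(2)), which receive no gradient.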
Status MaxPoolGradV2Helper(const Scope& scope, const Operation& op,
                           const std::vector<Output>& grad_inputs,
                           std::vector<Output>* grad_outputs) {
  string data_format;
  string padding;
  auto attrs = op.output(0).node()->attrs();
  TF_RETURN_IF_ERROR(GetNodeAttr(attrs, "data_format", &data_format));
  TF_RETURN_IF_ERROR(GetNodeAttr(attrs, "padding", &padding));
  auto dx = MaxPoolGradV2(scope, op.input(0), op.output(0), grad_inputs[0],
                          op.input(1), op.input(2), padding,
                          MaxPoolGradV2::DataFormat(data_format));
  grad_outputs->push_back(dx);
  grad_outputs->push_back(NoGradient());
  grad_outputs->push_back(NoGradient());
  return scope.status();
}
REGISTER_GRADIENT_OP("MaxPoolV2", MaxPoolGradV2Helper);

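// MaxPool3D gradient: the 3-D counterpart of the MaxPool gradient above.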
Status MaxPool3DGradHelper(const Scope& scope, const Operation& op,
                           const std::vector<Output>& grad_inputs,
                           std::vector<Output>* grad_outputs) {
  std::vector<int32> ksize;
  std::vector<int32> strides;
  string padding;
  string data_format;
  auto attrs = op.output(0).node()->attrs();
  TF_RETURN_IF_ERROR(GetNodeAttr(attrs, "ksize", &ksize));
  TF_RETURN_IF_ERROR(GetNodeAttr(attrs, "strides", &strides));
  TF_RETURN_IF_ERROR(GetNodeAttr(attrs, "padding", &padding));
  TF_RETURN_IF_ERROR(GetNodeAttr(attrs, "data_format", &data_format));
  MaxPool3DGrad::Attrs grad_attrs;
  auto dx =
      MaxPool3DGrad(scope, op.input(0), op.output(0), grad_inputs[0], ksize,
                    strides, padding, grad_attrs.DataFormat(data_format));
  grad_outputs->push_back(dx);
  return scope.status();
}
REGISTER_GRADIENT_OP("MaxPool3D", MaxPool3DGradHelper);

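// AvgPool gradient: distributes each output gradient evenly over the cells of
// its pooling window; only the shape of the forward input is needed, not its
// values.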
Status AvgPoolGradHelper(const Scope& scope, const Operation& op,
                         const std::vector<Output>& grad_inputs,
                         std::vector<Output>* grad_outputs) {
  std::vector<int32> ksize;
  std::vector<int32> strides;
  string padding;
  string data_format;
  auto attrs = op.output(0).node()->attrs();
  TF_RETURN_IF_ERROR(GetNodeAttr(attrs, "ksize", &ksize));
  TF_RETURN_IF_ERROR(GetNodeAttr(attrs, "strides", &strides));
  TF_RETURN_IF_ERROR(GetNodeAttr(attrs, "padding", &padding));
  TF_RETURN_IF_ERROR(GetNodeAttr(attrs, "data_format", &data_format));
  internal::AvgPoolGrad::Attrs grad_attrs;
  auto dx = internal::AvgPoolGrad(scope, Shape(scope, op.input(0)),
                                  grad_inputs[0], ksize, strides, padding,
                                  grad_attrs.DataFormat(data_format));
  grad_outputs->push_back(dx);
  return scope.status();
}
REGISTER_GRADIENT_OP("AvgPool", AvgPoolGradHelper);

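// AvgPool3D gradient: the 3-D counterpart of the AvgPool gradient above.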
Status AvgPool3DGradHelper(const Scope& scope, const Operation& op,
                           const std::vector<Output>& grad_inputs,
                           std::vector<Output>* grad_outputs) {
  std::vector<int32> ksize;
  std::vector<int32> strides;
  string padding;
  string data_format;
  auto attrs = op.output(0).node()->attrs();
  TF_RETURN_IF_ERROR(GetNodeAttr(attrs, "ksize", &ksize));
  TF_RETURN_IF_ERROR(GetNodeAttr(attrs, "strides", &strides));
  TF_RETURN_IF_ERROR(GetNodeAttr(attrs, "padding", &padding));
  TF_RETURN_IF_ERROR(GetNodeAttr(attrs, "data_format", &data_format));
  AvgPool3DGrad::Attrs grad_attrs;
  auto dx =
      AvgPool3DGrad(scope, Shape(scope, op.input(0)), grad_inputs[0], ksize,
                    strides, padding, grad_attrs.DataFormat(data_format));
  grad_outputs->push_back(dx);
  return scope.status();
}
REGISTER_GRADIENT_OP("AvgPool3D", AvgPool3DGradHelper);

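// LRN (local response normalization) gradient, computed by the internal
// LRNGrad kernel from the incoming gradient, the forward input, and the
// forward output.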
Status LRNGradHelper(const Scope& scope, const Operation& op,
                     const std::vector<Output>& grad_inputs,
                     std::vector<Output>* grad_outputs) {
  auto dx = internal::LRNGrad(scope, grad_inputs[0], op.input(0), op.output(0));
  grad_outputs->push_back(dx);
  return scope.status();
}
REGISTER_GRADIENT_OP("LRN", LRNGradHelper);

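// Softplus(x) = log(1 + exp(x)), so dx = dy * sigmoid(x); computed by
// internal::SoftplusGrad.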
Status SoftplusGradHelper(const Scope& scope, const Operation& op,
                          const std::vector<Output>& grad_inputs,
                          std::vector<Output>* grad_outputs) {
  auto dx = internal::SoftplusGrad(scope, grad_inputs[0], op.input(0));
  grad_outputs->push_back(dx);
  return scope.status();
}
REGISTER_GRADIENT_OP("Softplus", SoftplusGradHelper);

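// Softsign(x) = x / (1 + |x|), so dx = dy / (1 + |x|)^2; computed by
// internal::SoftsignGrad.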
Status SoftsignGradHelper(const Scope& scope, const Operation& op,
                          const std::vector<Output>& grad_inputs,
                          std::vector<Output>* grad_outputs) {
  auto dx = internal::SoftsignGrad(scope, grad_inputs[0], op.input(0));
  grad_outputs->push_back(dx);
  return scope.status();
}
REGISTER_GRADIENT_OP("Softsign", SoftsignGradHelper);

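// FractionalAvgPool gradient: like AvgPool, it only needs the (int64) input
// shape plus the row and column pooling sequences the forward op emitted as
// outputs 1 and 2; `overlapping` is forwarded from the attribute.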
Status FractionalAvgPoolGradHelper(const Scope& scope, const Operation& op,
                                   const std::vector<Output>& grad_inputs,
                                   std::vector<Output>* grad_outputs) {
  bool overlapping;
  TF_RETURN_IF_ERROR(
      GetNodeAttr(op.output(0).node()->attrs(), "overlapping", &overlapping));
  auto dx = internal::FractionalAvgPoolGrad(
      scope, Shape(scope, op.input(0), Shape::OutType(DT_INT64)),
      grad_inputs[0], op.output(1), op.output(2),
      internal::FractionalAvgPoolGrad::Overlapping(overlapping));
  grad_outputs->push_back(dx);
  return scope.status();
}
REGISTER_GRADIENT_OP("FractionalAvgPool", FractionalAvgPoolGradHelper);

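// FractionalMaxPool gradient: needs the forward input and output plus the
// row and column pooling sequences (outputs 1 and 2) to route gradients back
// to the max positions; `overlapping` is forwarded from the attribute.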
Status FractionalMaxPoolGradHelper(const Scope& scope, const Operation& op,
                                   const std::vector<Output>& grad_inputs,
                                   std::vector<Output>* grad_outputs) {
  bool overlapping;
  TF_RETURN_IF_ERROR(
      GetNodeAttr(op.output(0).node()->attrs(), "overlapping", &overlapping));
  auto dx = internal::FractionalMaxPoolGrad(
      scope, op.input(0), op.output(0), grad_inputs[0], op.output(1),
      op.output(2), internal::FractionalMaxPoolGrad::Overlapping(overlapping));
  grad_outputs->push_back(dx);
  return scope.status();
}
REGISTER_GRADIENT_OP("FractionalMaxPool", FractionalMaxPoolGradHelper);

}  // anonymous namespace
}  // namespace ops
}  // namespace tensorflow