1 /* Copyright 2016 The TensorFlow Authors. All Rights Reserved.
2
3 Licensed under the Apache License, Version 2.0 (the "License");
4 you may not use this file except in compliance with the License.
5 You may obtain a copy of the License at
6
7 http://www.apache.org/licenses/LICENSE-2.0
8
9 Unless required by applicable law or agreed to in writing, software
10 distributed under the License is distributed on an "AS IS" BASIS,
11 WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
12 See the License for the specific language governing permissions and
13 limitations under the License.
14 ==============================================================================*/
15
16 #include <atomic>
17
18 #define EIGEN_USE_THREADS
19
20 #if GOOGLE_CUDA || TENSORFLOW_USE_ROCM
21 #define EIGEN_USE_GPU
22 #if GOOGLE_CUDA
23 #include "third_party/gpus/cudnn/cudnn.h"
24 #endif // GOOGLE_CUDA
25
26 #include "tensorflow/core/kernels/conv_2d.h"
27 #include "tensorflow/core/platform/stream_executor.h"
28 #include "tensorflow/core/util/stream_executor_util.h"
29 #endif // GOOGLE_CUDA || TENSORFLOW_USE_ROCM
30
31 #include "third_party/eigen3/unsupported/Eigen/CXX11/Tensor"
32 #include "tensorflow/core/framework/op_kernel.h"
33 #include "tensorflow/core/framework/register_types.h"
34 #include "tensorflow/core/framework/tensor.h"
35 #include "tensorflow/core/framework/tensor_types.h"
36 #include "tensorflow/core/kernels/fill_functor.h"
37 #include "tensorflow/core/kernels/fused_batch_norm_op.h"
38 #include "tensorflow/core/kernels/redux_functor.h"
39 #include "tensorflow/core/kernels/transpose_functor.h"
40 #include "tensorflow/core/lib/core/blocking_counter.h"
41 #include "tensorflow/core/util/env_var.h"
42 #include "tensorflow/core/util/tensor_format.h"
43
44 namespace tensorflow {
45 using CPUDevice = Eigen::ThreadPoolDevice;
46 using GPUDevice = Eigen::GpuDevice;
47
48 namespace functor {
49
50 #if GOOGLE_CUDA || TENSORFLOW_USE_ROCM
51 using se::DeviceMemory;
52 using se::ScratchAllocator;
53 using se::Stream;
54 using se::port::StatusOr;
55 #endif
56
57 string ToString(FusedBatchNormActivationMode activation_mode) {
58 switch (activation_mode) {
59 case FusedBatchNormActivationMode::kIdentity:
60 return "Identity";
61 case FusedBatchNormActivationMode::kRelu:
62 return "Relu";
63 }
64 }
65
66 Status ParseActivationMode(OpKernelConstruction* context,
67 FusedBatchNormActivationMode* activation_mode) {
68 string activation_mode_str;
69 TF_RETURN_IF_ERROR(context->GetAttr("activation_mode", &activation_mode_str));
70
71 if (activation_mode_str == "Identity") {
72 *activation_mode = FusedBatchNormActivationMode::kIdentity;
73 return Status::OK();
74 }
75 if (activation_mode_str == "Relu") {
76 *activation_mode = FusedBatchNormActivationMode::kRelu;
77 return Status::OK();
78 }
79 return errors::InvalidArgument("Unsupported activation mode: ",
80 activation_mode_str);
81 }
82
83 // Functor used by FusedBatchNormOp to do the computations.
84 template <typename Device, typename T, typename U, bool is_training>
85 struct FusedBatchNorm;
86 // Functor used by FusedBatchNormGradOp to do the computations when
87 // is_training=True.
88 template <typename Device, typename T, typename U>
89 struct FusedBatchNormGrad;
90
91 template <typename T, typename U>
92 struct FusedBatchNorm<CPUDevice, T, U, /* is_training= */ true> {
93   void operator()(OpKernelContext* context, const Tensor& x_input,
94 const Tensor& scale_input, const Tensor& offset_input,
95 const Tensor& running_mean_input,
96 const Tensor& running_variance_input,
97 const Tensor* side_input, U epsilon, U exponential_avg_factor,
98 FusedBatchNormActivationMode activation_mode,
99 Tensor* y_output, Tensor* running_mean_output,
100 Tensor* running_var_output, Tensor* saved_batch_mean_output,
101 Tensor* saved_batch_var_output, TensorFormat tensor_format,
102 bool use_reserved_space) {
103 OP_REQUIRES(context, side_input == nullptr,
104 errors::Internal(
105 "The CPU implementation of FusedBatchNorm does not support "
106 "side input."));
107 OP_REQUIRES(context,
108 activation_mode == FusedBatchNormActivationMode::kIdentity,
109 errors::Internal("The CPU implementation of FusedBatchNorm "
110 "does not support activations."));
111
112 if (use_reserved_space) {
113 Tensor* dummy_reserve_space = nullptr;
114 OP_REQUIRES_OK(context,
115 context->allocate_output(5, {}, &dummy_reserve_space));
116 // Initialize the memory, to avoid sanitizer alerts.
117 dummy_reserve_space->flat<U>()(0) = U();
118 }
119 Tensor transformed_x;
120 Tensor transformed_y;
121 if (tensor_format == FORMAT_NCHW) {
122 const int64 in_batch = GetTensorDim(x_input, tensor_format, 'N');
123 const int64 in_rows = GetTensorDim(x_input, tensor_format, 'H');
124 const int64 in_cols = GetTensorDim(x_input, tensor_format, 'W');
125 const int64 in_depths = GetTensorDim(x_input, tensor_format, 'C');
126 OP_REQUIRES_OK(context, context->allocate_temp(
127 DataTypeToEnum<T>::value,
128 ShapeFromFormat(FORMAT_NHWC, in_batch,
129 in_rows, in_cols, in_depths),
130 &transformed_x));
131 OP_REQUIRES_OK(context, context->allocate_temp(
132 DataTypeToEnum<T>::value,
133 ShapeFromFormat(FORMAT_NHWC, in_batch,
134 in_rows, in_cols, in_depths),
135 &transformed_y));
136 // Perform NCHW to NHWC
137 std::vector<int32> perm = {0, 2, 3, 1};
138 OP_REQUIRES_OK(
139 context, ::tensorflow::DoTranspose(context->eigen_device<CPUDevice>(),
140 x_input, perm, &transformed_x));
141 } else {
142 transformed_x = x_input;
143 transformed_y = *y_output;
144 }
145 typename TTypes<T, 4>::Tensor x(transformed_x.tensor<T, 4>());
146 typename TTypes<U>::ConstVec scale(scale_input.vec<U>());
147 typename TTypes<U>::ConstVec offset(offset_input.vec<U>());
148 typename TTypes<U>::ConstVec old_mean(running_mean_input.vec<U>());
149 typename TTypes<U>::ConstVec old_variance(running_variance_input.vec<U>());
150 typename TTypes<T, 4>::Tensor y(transformed_y.tensor<T, 4>());
151 typename TTypes<U>::Vec new_mean(running_mean_output->vec<U>());
152 typename TTypes<U>::Vec new_variance(running_var_output->vec<U>());
153 typename TTypes<U>::Vec saved_batch_mean(saved_batch_mean_output->vec<U>());
154 typename TTypes<U>::Vec saved_batch_var(saved_batch_var_output->vec<U>());
155
156 const CPUDevice& d = context->eigen_device<CPUDevice>();
157
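    // View the (NHWC) input as a 2-D [rest_size, depth] matrix, where depth is
    // the channel dimension and rest_size = batch * height * width.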
158 const int depth = x.dimension(3);
159 const int size = x.size();
160 const int rest_size = size / depth;
161 Eigen::DSizes<Eigen::Index, 2> rest_by_depth(rest_size, depth);
162
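    // When Eigen::IndexList is available, dimensions known at compile time are
    // encoded as type2index constants so Eigen can optimize the broadcasts and
    // reductions; otherwise plain runtime-sized DSizes/arrays are used.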
163 #if !defined(EIGEN_HAS_INDEX_LIST)
164 Eigen::DSizes<Eigen::Index, 2> one_by_depth(1, depth);
165 Eigen::array<int, 1> reduce_dims({0});
166 Eigen::array<int, 2> bcast_spec({rest_size, 1});
167 #else
168 Eigen::IndexList<Eigen::type2index<1>, Eigen::Index> one_by_depth;
169 one_by_depth.set(1, depth);
170 Eigen::IndexList<Eigen::type2index<0>> reduce_dims;
171 Eigen::IndexList<Eigen::Index, Eigen::type2index<1>> bcast_spec;
172 bcast_spec.set(0, rest_size);
173 #endif
174
175 auto x_rest_by_depth = x.reshape(rest_by_depth).template cast<U>();
176 const int rest_size_minus_one = (rest_size > 1) ? (rest_size - 1) : 1;
177 U rest_size_inv = static_cast<U>(1.0f / static_cast<U>(rest_size));
178 // This adjustment is for Bessel's correction
179 U rest_size_adjust =
180 static_cast<U>(rest_size) / static_cast<U>(rest_size_minus_one);
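    // With N = rest_size, the unbiased (running) variance computed below is
    // batch_variance * N / (N - 1).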
181
182 Eigen::Tensor<U, 1, Eigen::RowMajor> batch_mean(depth);
183 Eigen::Tensor<U, 1, Eigen::RowMajor> batch_variance(depth);
184
185 batch_mean.device(d) = (x_rest_by_depth.sum(reduce_dims) * rest_size_inv);
186 auto x_centered = x_rest_by_depth -
187 batch_mean.reshape(one_by_depth).broadcast(bcast_spec);
188
189 batch_variance.device(d) =
190 x_centered.square().sum(reduce_dims) * rest_size_inv;
191 auto scaling_factor = ((batch_variance + epsilon).rsqrt() * scale)
192 .eval()
193 .reshape(one_by_depth)
194 .broadcast(bcast_spec);
195 auto x_scaled = x_centered * scaling_factor;
196 auto x_shifted =
197 (x_scaled + offset.reshape(one_by_depth).broadcast(bcast_spec))
198 .template cast<T>();
199
200 y.reshape(rest_by_depth).device(d) = x_shifted;
201 if (exponential_avg_factor == U(1.0)) {
202 saved_batch_var.device(d) = batch_variance;
203 saved_batch_mean.device(d) = batch_mean;
204 new_variance.device(d) = batch_variance * rest_size_adjust;
205 new_mean.device(d) = batch_mean;
206 } else {
207 U one_minus_factor = U(1) - exponential_avg_factor;
208 saved_batch_var.device(d) = batch_variance;
209 saved_batch_mean.device(d) = batch_mean;
210 new_variance.device(d) =
211 one_minus_factor * old_variance +
212 (exponential_avg_factor * rest_size_adjust) * batch_variance;
213 new_mean.device(d) =
214 one_minus_factor * old_mean + exponential_avg_factor * batch_mean;
215 }
216
217 if (tensor_format == FORMAT_NCHW) {
218 // Perform NHWC to NCHW
219 const std::vector<int32> perm = {0, 3, 1, 2};
220 const Status s = ::tensorflow::DoTranspose(
221 context->eigen_device<CPUDevice>(), transformed_y, perm, y_output);
222 if (!s.ok()) {
223 context->SetStatus(errors::InvalidArgument("Transpose failed: ", s));
224 }
225 }
226 }
227 };
228
229 template <typename T, typename U>
230 struct FusedBatchNorm<CPUDevice, T, U, /* is_training= */ false> {
231   void operator()(OpKernelContext* context, const Tensor& x_input,
232 const Tensor& scale_input, const Tensor& offset_input,
233 const Tensor& estimated_mean_input,
234 const Tensor& estimated_variance_input,
235 const Tensor* side_input, U epsilon, U exponential_avg_factor,
236 FusedBatchNormActivationMode activation_mode,
237 Tensor* y_output, Tensor* batch_mean_output,
238 Tensor* batch_var_output, Tensor* saved_mean_output,
239 Tensor* saved_var_output, TensorFormat tensor_format,
240 bool use_reserved_space) {
241 OP_REQUIRES(context, side_input == nullptr,
242 errors::Internal(
243 "The CPU implementation of FusedBatchNorm does not support "
244 "side input."));
245 OP_REQUIRES(context,
246 activation_mode == FusedBatchNormActivationMode::kIdentity,
247 errors::Internal("The CPU implementation of FusedBatchNorm "
248 "does not support activations."));
249
250 if (use_reserved_space) {
251 Tensor* dummy_reserve_space = nullptr;
252 OP_REQUIRES_OK(context,
253 context->allocate_output(5, {}, &dummy_reserve_space));
254 // Initialize the memory, to avoid sanitizer alerts.
255 dummy_reserve_space->flat<U>()(0) = U();
256 }
257 Tensor transformed_x;
258 Tensor transformed_y;
259 if (tensor_format == FORMAT_NCHW) {
260 const int64 in_batch = GetTensorDim(x_input, tensor_format, 'N');
261 const int64 in_rows = GetTensorDim(x_input, tensor_format, 'H');
262 const int64 in_cols = GetTensorDim(x_input, tensor_format, 'W');
263 const int64 in_depths = GetTensorDim(x_input, tensor_format, 'C');
264 OP_REQUIRES_OK(context, context->allocate_temp(
265 DataTypeToEnum<T>::value,
266 ShapeFromFormat(FORMAT_NHWC, in_batch,
267 in_rows, in_cols, in_depths),
268 &transformed_x));
269 OP_REQUIRES_OK(context, context->allocate_temp(
270 DataTypeToEnum<T>::value,
271 ShapeFromFormat(FORMAT_NHWC, in_batch,
272 in_rows, in_cols, in_depths),
273 &transformed_y));
274 // Perform NCHW to NHWC
275 std::vector<int32> perm = {0, 2, 3, 1};
276 OP_REQUIRES_OK(
277 context, ::tensorflow::DoTranspose(context->eigen_device<CPUDevice>(),
278 x_input, perm, &transformed_x));
279 } else {
280 transformed_x = x_input;
281 transformed_y = *y_output;
282 }
283 typename TTypes<T, 4>::Tensor x(transformed_x.tensor<T, 4>());
284 typename TTypes<U>::ConstVec scale(scale_input.vec<U>());
285 typename TTypes<U>::ConstVec offset(offset_input.vec<U>());
286 typename TTypes<U>::ConstVec estimated_mean(estimated_mean_input.vec<U>());
287 typename TTypes<U>::ConstVec estimated_variance(
288 estimated_variance_input.vec<U>());
289 typename TTypes<T, 4>::Tensor y(transformed_y.tensor<T, 4>());
290 typename TTypes<U>::Vec batch_mean(batch_mean_output->vec<U>());
291 typename TTypes<U>::Vec batch_variance(batch_var_output->vec<U>());
292
293 const CPUDevice& d = context->eigen_device<CPUDevice>();
294
295 const int depth = x.dimension(3);
296 const int size = x.size();
297 const int rest_size = size / depth;
298 Eigen::DSizes<Eigen::Index, 2> rest_by_depth(rest_size, depth);
299
300 #if !defined(EIGEN_HAS_INDEX_LIST)
301 Eigen::DSizes<Eigen::Index, 2> one_by_depth(1, depth);
302 Eigen::array<int, 1> reduce_dims({0});
303 Eigen::array<int, 2> bcast_spec({rest_size, 1});
304 #else
305 Eigen::IndexList<Eigen::type2index<1>, Eigen::Index> one_by_depth;
306 one_by_depth.set(1, depth);
307 Eigen::IndexList<Eigen::Index, Eigen::type2index<1>> bcast_spec;
308 bcast_spec.set(0, rest_size);
309 #endif
310
311 auto x_rest_by_depth = x.reshape(rest_by_depth).template cast<U>();
312 auto x_centered =
313 x_rest_by_depth -
314 estimated_mean.reshape(one_by_depth).broadcast(bcast_spec);
315 auto scaling_factor = ((estimated_variance + epsilon).rsqrt() * scale)
316 .eval()
317 .reshape(one_by_depth)
318 .broadcast(bcast_spec);
319 auto x_scaled = x_centered * scaling_factor;
320 auto x_shifted =
321 (x_scaled + offset.reshape(one_by_depth).broadcast(bcast_spec))
322 .template cast<T>();
323
324 y.reshape(rest_by_depth).device(d) = x_shifted;
325 batch_mean.device(d) = estimated_mean;
326 batch_variance.device(d) = estimated_variance;
327
328 if (tensor_format == FORMAT_NCHW) {
329 // Perform NHWC to NCHW
330 const std::vector<int32> perm = {0, 3, 1, 2};
331 const Status s = ::tensorflow::DoTranspose(
332 context->eigen_device<CPUDevice>(), transformed_y, perm, y_output);
333 if (!s.ok()) {
334 context->SetStatus(errors::InvalidArgument("Transpose failed: ", s));
335 }
336 }
337 }
338 };
339
340 template <typename T, typename U>
341 struct FusedBatchNormGrad<CPUDevice, T, U> {
342   void operator()(OpKernelContext* context, const Tensor& y_backprop_input,
343 const Tensor& x_input, const Tensor& scale_input,
344 const Tensor& mean_input, const Tensor& variance_input,
345 U epsilon, Tensor* x_backprop_output,
346 Tensor* scale_backprop_output, Tensor* offset_backprop_output,
347 bool use_reserved_space, TensorFormat tensor_format) {
348 Tensor transformed_y_backprop_input;
349 Tensor transformed_x_input;
350 Tensor transformed_x_backprop_output;
351 if (tensor_format == FORMAT_NCHW) {
352 const int64 in_batch = GetTensorDim(x_input, tensor_format, 'N');
353 const int64 in_rows = GetTensorDim(x_input, tensor_format, 'H');
354 const int64 in_cols = GetTensorDim(x_input, tensor_format, 'W');
355 const int64 in_depths = GetTensorDim(x_input, tensor_format, 'C');
356 OP_REQUIRES_OK(context, context->allocate_temp(
357 DataTypeToEnum<T>::value,
358 ShapeFromFormat(FORMAT_NHWC, in_batch,
359 in_rows, in_cols, in_depths),
360 &transformed_y_backprop_input));
361 OP_REQUIRES_OK(context, context->allocate_temp(
362 DataTypeToEnum<T>::value,
363 ShapeFromFormat(FORMAT_NHWC, in_batch,
364 in_rows, in_cols, in_depths),
365 &transformed_x_input));
366 OP_REQUIRES_OK(context, context->allocate_temp(
367 DataTypeToEnum<T>::value,
368 ShapeFromFormat(FORMAT_NHWC, in_batch,
369 in_rows, in_cols, in_depths),
370 &transformed_x_backprop_output));
371 // Perform NCHW to NHWC
372 std::vector<int32> perm = {0, 2, 3, 1};
373 OP_REQUIRES_OK(
374 context, ::tensorflow::DoTranspose(context->eigen_device<CPUDevice>(),
375 y_backprop_input, perm,
376 &transformed_y_backprop_input));
377 OP_REQUIRES_OK(context, ::tensorflow::DoTranspose(
378 context->eigen_device<CPUDevice>(), x_input,
379 perm, &transformed_x_input));
380 } else {
381 transformed_y_backprop_input = y_backprop_input;
382 transformed_x_input = x_input;
383 transformed_x_backprop_output = *x_backprop_output;
384 }
385 typename TTypes<T, 4>::Tensor y_backprop(
386 transformed_y_backprop_input.tensor<T, 4>());
387 typename TTypes<T, 4>::Tensor x(transformed_x_input.tensor<T, 4>());
388 typename TTypes<U>::ConstVec scale(scale_input.vec<U>());
389 typename TTypes<U>::ConstVec mean(mean_input.vec<U>());
390 typename TTypes<U>::ConstVec variance(variance_input.vec<U>());
391 typename TTypes<T, 4>::Tensor x_backprop(
392 transformed_x_backprop_output.tensor<T, 4>());
393 typename TTypes<U>::Vec offset_backprop(offset_backprop_output->vec<U>());
394
395 // Note: the following formulas are used to compute the gradients for
396 // back propagation.
397 // x_backprop = scale * rsqrt(variance + epsilon) *
398 // [y_backprop - mean(y_backprop) - (x - mean(x)) *
399 // mean(y_backprop * (x - mean(x))) / (variance + epsilon)]
400 // scale_backprop = sum(y_backprop *
401 // (x - mean(x)) * rsqrt(variance + epsilon))
402 // offset_backprop = sum(y_backprop)
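    // Here sum(...) and mean(...) reduce over all dimensions except the
    // channel dimension.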
403
404 const CPUDevice& d = context->eigen_device<CPUDevice>();
405 const int depth = x.dimension(3);
406 const int size = x.size();
407 const int rest_size = size / depth;
408 Eigen::DSizes<Eigen::Index, 2> rest_by_depth(rest_size, depth);
409
410 #if !defined(EIGEN_HAS_INDEX_LIST)
411 Eigen::DSizes<Eigen::Index, 2> one_by_depth(1, depth);
412 Eigen::array<int, 2> bcast_spec({rest_size, 1});
413 #else
414 Eigen::IndexList<Eigen::type2index<1>, Eigen::Index> one_by_depth;
415 one_by_depth.set(1, depth);
416 Eigen::IndexList<Eigen::Index, Eigen::type2index<1>> bcast_spec;
417 bcast_spec.set(0, rest_size);
418 #endif
419
420 auto x_rest_by_depth = x.reshape(rest_by_depth).template cast<U>();
421 U rest_size_inv = static_cast<U>(1.0f / static_cast<U>(rest_size));
422
423     // Eigen is notoriously bad at reducing the outer dimension, so we
424     // materialize all temporary tensors that require reduction and then use
425     // the Eigen redux functor, which is optimized for this particular task.
426 //
427 // All reductions are of this type: [rest_size, depth] -> [depth].
428 using ScalarSum = Eigen::internal::scalar_sum_op<U>;
429 const functor::ReduceOuterDimensions<T, U, U, ScalarSum> redux_sum_t;
430 const functor::ReduceOuterDimensions<U, U, U, ScalarSum> redux_sum_u;
431
432 auto scratch_dtype = DataTypeToEnum<U>::value;
433
434 // Allocate a temporary workspace of [depth] shape.
435 Tensor scratch_one_by_depth;
436 OP_REQUIRES_OK(context, context->allocate_temp(scratch_dtype, {depth},
437 &scratch_one_by_depth));
438
439 // Maybe allocate a temporary workspace of [rest_size, depth] shape.
440 Tensor scratch_rest_by_depth;
441 if (std::is_same<T, U>::value) {
442 OP_REQUIRES(context,
443 scratch_rest_by_depth.CopyFrom(transformed_x_backprop_output,
444 {rest_size, depth}),
445 errors::Internal("Failed to copy a tensor"));
446 } else {
447 OP_REQUIRES_OK(context,
448 context->allocate_temp(scratch_dtype, {rest_size, depth},
449 &scratch_rest_by_depth));
450 }
451
452 typename TTypes<U, 2>::Tensor scratch_tensor(
453 scratch_rest_by_depth.tensor<U, 2>());
454 typename TTypes<U>::Vec scratch_vector(scratch_one_by_depth.vec<U>());
455
456 auto x_mean_rest_by_depth =
457 mean.reshape(one_by_depth).broadcast(bcast_spec);
458 auto x_centered = (x_rest_by_depth - x_mean_rest_by_depth);
459 auto coef0_one_by_depth =
460 (variance.reshape(one_by_depth) + epsilon).rsqrt();
461 auto coef0_rest_by_depth = coef0_one_by_depth.broadcast(bcast_spec);
462 auto x_scaled = x_centered * coef0_rest_by_depth;
463
464 auto y_backprop_rest_by_depth =
465 y_backprop.reshape(rest_by_depth).template cast<U>();
466
467 // Compute `scale_backprop_output`:
468 // scale_backprop =
469 // (y_backprop_rest_by_depth * x_scaled).sum(reduce_dims)
470 scratch_tensor.device(d) = y_backprop_rest_by_depth * x_scaled;
471 redux_sum_u(d, rest_by_depth, scratch_rest_by_depth, scale_backprop_output);
472
473 // Compute 'offset_backprop_output':
474 // offset_backprop =
475 // y_backprop_rest_by_depth.sum(reduce_dims)
476 redux_sum_t(d, rest_by_depth, transformed_y_backprop_input,
477 offset_backprop_output);
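    // offset_backprop already holds sum(y_backprop), so it is reused below
    // instead of computing the reduction a second time.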
478 auto y_backprop_sum = offset_backprop;
479
480 auto y_backprop_sum_one_by_depth = y_backprop_sum.reshape(one_by_depth);
481 auto y_backprop_mean_one_by_depth =
482 y_backprop_sum_one_by_depth * rest_size_inv;
483 auto y_backprop_mean_rest_by_depth =
484 y_backprop_mean_one_by_depth.broadcast(bcast_spec);
485 auto y_backprop_centered =
486 y_backprop_rest_by_depth - y_backprop_mean_rest_by_depth;
487
488 // Compute expression:
489 // y_backprop_centered_mean =
490 // (y_backprop_rest_by_depth * x_centered).mean(reduce_dims)
491 scratch_tensor.device(d) = y_backprop_rest_by_depth * x_centered;
492 redux_sum_u(d, rest_by_depth, scratch_rest_by_depth, &scratch_one_by_depth);
493 auto y_backprop_centered_mean =
494 scratch_vector.reshape(one_by_depth) / static_cast<U>(rest_size);
495
496 auto coef1 = (scale.reshape(one_by_depth) * coef0_one_by_depth)
497 .broadcast(bcast_spec);
498 auto coef2 = (coef0_one_by_depth.square() * y_backprop_centered_mean)
499 .broadcast(bcast_spec);
500
501 x_backprop.reshape(rest_by_depth).device(d) =
502 (coef1 * (y_backprop_centered - x_centered * coef2)).template cast<T>();
503
504 if (tensor_format == FORMAT_NCHW) {
505 // Perform NHWC to NCHW
506 std::vector<int32> perm = {0, 3, 1, 2};
507 OP_REQUIRES_OK(
508 context, ::tensorflow::DoTranspose(context->eigen_device<CPUDevice>(),
509 transformed_x_backprop_output,
510 perm, x_backprop_output));
511 }
512 }
513 };
514
515 template <typename T, typename U>
516 struct FusedBatchNormFreezeGrad<CPUDevice, T, U> {
517   void operator()(OpKernelContext* context, const Tensor& y_backprop_input,
518 const Tensor& x_input, const Tensor& scale_input,
519 const Tensor& pop_mean_input,
520 const Tensor& pop_variance_input, U epsilon,
521 Tensor* x_backprop_output, Tensor* scale_backprop_output,
522 Tensor* offset_backprop_output) {
523 typename TTypes<T, 4>::ConstTensor y_backprop(
524 y_backprop_input.tensor<T, 4>());
525 typename TTypes<T, 4>::ConstTensor input(x_input.tensor<T, 4>());
526 typename TTypes<U>::ConstVec scale(scale_input.vec<U>());
527 typename TTypes<U>::ConstVec pop_mean(pop_mean_input.vec<U>());
528 typename TTypes<U>::ConstVec pop_var(pop_variance_input.vec<U>());
529 typename TTypes<T, 4>::Tensor x_backprop(x_backprop_output->tensor<T, 4>());
530 typename TTypes<U>::Vec scale_backprop(scale_backprop_output->vec<U>());
531
532 const int depth = pop_mean.dimension(0);
533 const int rest_size = input.size() / depth;
534
535 const CPUDevice& d = context->eigen_device<CPUDevice>();
536
537 // Allocate two temporary workspaces of [depth] shape.
538 Tensor scratch1_vec, scratch2_vec;
539 OP_REQUIRES_OK(context, context->allocate_temp(DataTypeToEnum<U>::value,
540 {depth}, &scratch1_vec));
541 OP_REQUIRES_OK(context, context->allocate_temp(DataTypeToEnum<U>::value,
542 {depth}, &scratch2_vec));
543
544 // Maybe allocate a temporary workspace of [rest_size, depth] shape.
545 Tensor scratch3_tensor;
546 if (std::is_same<T, U>::value) {
547 OP_REQUIRES(
548 context,
549 scratch3_tensor.CopyFrom(*x_backprop_output, {rest_size, depth}),
550 errors::Internal("Failed to copy a tensor"));
551 } else {
552 OP_REQUIRES_OK(context, context->allocate_temp(DataTypeToEnum<U>::value,
553 {rest_size, depth},
554 &scratch3_tensor));
555 }
556
557 typename TTypes<U>::Vec scratch1(scratch1_vec.vec<U>());
558 typename TTypes<U>::Vec scratch2(scratch2_vec.vec<U>());
559 typename TTypes<U, 2>::Tensor scratch3(scratch3_tensor.tensor<U, 2>());
560
561 Eigen::DSizes<Eigen::Index, 2> rest_by_depth(rest_size, depth);
562 #if !defined(EIGEN_HAS_INDEX_LIST)
563 Eigen::DSizes<Eigen::Index, 2> one_by_depth(1, depth);
564 Eigen::array<int, 2> rest_by_one({rest_size, 1});
565 #else
566 Eigen::IndexList<Eigen::type2index<1>, Eigen::Index> one_by_depth;
567 one_by_depth.set(1, depth);
568 Eigen::IndexList<Eigen::Index, Eigen::type2index<1>> rest_by_one;
569 rest_by_one.set(0, rest_size);
570 #endif
571
572 // Sum reduction along the 0th dimension using custom CPU functor.
573 using ScalarSum = Eigen::internal::scalar_sum_op<U>;
574 const functor::ReduceOuterDimensions<T, U, U, ScalarSum> redux_sum_t;
575 const functor::ReduceOuterDimensions<U, U, U, ScalarSum> redux_sum_u;
576
577 // offset_backprop = sum(y_backprop)
578 // scale_backprop = y_backprop * ((x - pop_mean) * rsqrt(pop_var + epsilon))
579 // x_backprop = y_backprop * (scale * rsqrt(pop_var + epsilon))
580
581     // NOTE: A DEFAULT DEVICE comment is added to expression assignments that
582     // we do not want executed in the thread pool.
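    // Assignments without .device(d) are evaluated inline on the calling
    // thread, which is cheaper for these small per-channel vectors.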
583
584 auto y_backprop_rest_by_depth =
585 y_backprop.reshape(rest_by_depth).template cast<U>();
586 auto input_rest_by_depth = input.reshape(rest_by_depth).template cast<U>();
587
588 // offset_backprop = sum(y_backprop)
589 redux_sum_t(d, rest_by_depth, y_backprop_input, offset_backprop_output);
590
591 // scratch1 = rsqrt(pop_var + epsilon)
592 scratch1 = (pop_var + pop_var.constant(epsilon)).rsqrt(); // DEFAULT DEVICE
593
594 // scratch2 = sum(y_backprop * (x - mean))
595 scratch3.device(d) =
596 y_backprop_rest_by_depth *
597 (input_rest_by_depth -
598 pop_mean.reshape(one_by_depth).broadcast(rest_by_one));
599 redux_sum_u(d, rest_by_depth, scratch3_tensor, &scratch2_vec);
600
601 x_backprop.reshape(rest_by_depth).device(d) =
602 (y_backprop_rest_by_depth *
603 ((scratch1.reshape(one_by_depth) * scale.reshape(one_by_depth))
604 .broadcast(rest_by_one)))
605 .template cast<T>();
606 scale_backprop = scratch2 * scratch1; // DEFAULT DEVICE
607 }
608 };
609
610 #if !GOOGLE_CUDA
611 namespace {
612 // See the implementation under the GOOGLE_CUDA #ifdef below. This is a
613 // CUDA-specific feature; do not enable it for non-CUDA builds.
614 bool BatchnormSpatialPersistentEnabled() { return false; }
615 } // namespace
616 #endif
617
618 #if GOOGLE_CUDA || TENSORFLOW_USE_ROCM
619
620 namespace {
621
622 se::dnn::ActivationMode AsDnnActivationMode(
623 const FusedBatchNormActivationMode activation_mode) {
624 switch (activation_mode) {
625 case FusedBatchNormActivationMode::kIdentity:
626 return se::dnn::ActivationMode::kNone;
627 case FusedBatchNormActivationMode::kRelu:
628 return se::dnn::ActivationMode::kRelu;
629 }
630 }
631
632 #if GOOGLE_CUDA
633 // NOTE(ezhulenev): See the `BatchnormSpatialPersistentEnabled` documentation
634 // in `cuda_dnn.cc` for details.
635 bool BatchnormSpatialPersistentEnabled() {
636 #if CUDNN_VERSION >= 7402
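  // The environment variable is read only once; the result is cached in a
  // function-local static for the lifetime of the process.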
637 static bool is_enabled = [] {
638 bool is_enabled = false;
639 TF_CHECK_OK(tensorflow::ReadBoolFromEnvVar(
640 "TF_USE_CUDNN_BATCHNORM_SPATIAL_PERSISTENT",
641 /*default_val=*/false, &is_enabled));
642 return is_enabled;
643 }();
644 return is_enabled;
645 #else
646 return false;
647 #endif
648 }
649 #endif
650
651 } // namespace
652
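// Reinterprets a tensor's underlying buffer as DeviceMemory<U> of the same
// byte size.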
653 template <typename U, typename T>
654 DeviceMemory<U> CastDeviceMemory(Tensor* tensor) {
655 return DeviceMemory<U>::MakeFromByteSize(
656 tensor->template flat<T>().data(),
657 tensor->template flat<T>().size() * sizeof(T));
658 }
659
660 // A helper to allocate temporary scratch memory for Cudnn BatchNormEx ops. It
661 // takes ownership of the underlying memory. The expectation is that the memory
662 // stays alive for the duration of the Cudnn BatchNormEx call itself.
663 template <typename T>
664 class CudnnBatchNormAllocatorInTemp : public ScratchAllocator {
665 public:
666 ~CudnnBatchNormAllocatorInTemp() override = default;
667
668   explicit CudnnBatchNormAllocatorInTemp(OpKernelContext* context)
669 : context_(context) {}
670
671   int64 GetMemoryLimitInBytes() override {
672 return std::numeric_limits<int64>::max();
673 }
674
675   StatusOr<DeviceMemory<uint8>> AllocateBytes(int64 byte_size) override {
676 Tensor temporary_memory;
677 const DataType tf_data_type = DataTypeToEnum<T>::v();
678 int64 allocate_count =
679 Eigen::divup(byte_size, static_cast<int64>(sizeof(T)));
680 Status allocation_status(context_->allocate_temp(
681 tf_data_type, TensorShape({allocate_count}), &temporary_memory));
682 if (!allocation_status.ok()) {
683 return allocation_status;
684 }
685     // Hold a reference to the allocated tensor until the allocator is
686     // destroyed.
687 allocated_tensors_.push_back(temporary_memory);
688 total_byte_size_ += byte_size;
689 return DeviceMemory<uint8>::MakeFromByteSize(
690 temporary_memory.template flat<T>().data(),
691 temporary_memory.template flat<T>().size() * sizeof(T));
692 }
693
694   int64 TotalByteSize() const { return total_byte_size_; }
695
696   Tensor get_allocated_tensor(int index) const {
697 return allocated_tensors_[index];
698 }
699
700 private:
701 int64 total_byte_size_ = 0;
702 OpKernelContext* context_; // not owned
703 std::vector<Tensor> allocated_tensors_;
704 };
705
706 // A helper to allocate memory for Cudnn BatchNormEx as a kernel output. It is
707 // used by the forward-pass kernel to feed the output to the backward pass.
708 // The memory is expected to stay alive until the backward pass has finished
709 // with it.
710 template <typename T>
711 class CudnnBatchNormAllocatorInOutput : public ScratchAllocator {
712 public:
713   ~CudnnBatchNormAllocatorInOutput() override {
714 if (!output_allocated) {
715 Tensor* dummy_reserve_space = nullptr;
716 OP_REQUIRES_OK(context_, context_->allocate_output(output_index_, {},
717 &dummy_reserve_space));
718 }
719 }
720
721   CudnnBatchNormAllocatorInOutput(OpKernelContext* context, int output_index)
722 : context_(context), output_index_(output_index) {}
723
724   int64 GetMemoryLimitInBytes() override {
725 return std::numeric_limits<int64>::max();
726 }
727
728   StatusOr<DeviceMemory<uint8>> AllocateBytes(int64 byte_size) override {
729 output_allocated = true;
730 DCHECK(total_byte_size_ == 0)
731 << "Reserve space allocator can only be called once";
732 int64 allocate_count =
733 Eigen::divup(byte_size, static_cast<int64>(sizeof(T)));
734
735 Tensor* temporary_memory = nullptr;
736 Status allocation_status(context_->allocate_output(
737 output_index_, TensorShape({allocate_count}), &temporary_memory));
738 if (!allocation_status.ok()) {
739 return allocation_status;
740 }
741 total_byte_size_ += byte_size;
742 auto memory_uint8 = DeviceMemory<uint8>::MakeFromByteSize(
743 temporary_memory->template flat<T>().data(),
744 temporary_memory->template flat<T>().size() * sizeof(T));
745 return StatusOr<DeviceMemory<uint8>>(memory_uint8);
746 }
747
748   int64 TotalByteSize() { return total_byte_size_; }
749
750 private:
751 int64 total_byte_size_ = 0;
752 OpKernelContext* context_; // not owned
753 int output_index_;
754 bool output_allocated = false;
755 };
756
757 template <typename T, typename U, bool is_training>
758 struct FusedBatchNorm<GPUDevice, T, U, is_training> {
759   void operator()(OpKernelContext* context, const Tensor& x,
760 const Tensor& scale, const Tensor& offset,
761 const Tensor& estimated_mean,
762 const Tensor& estimated_variance, const Tensor* side_input,
763 U epsilon, U exponential_avg_factor,
764 FusedBatchNormActivationMode activation_mode, Tensor* y,
765 Tensor* batch_mean, Tensor* batch_var, Tensor* saved_mean,
766 Tensor* saved_inv_var, TensorFormat tensor_format,
767 bool use_reserved_space) {
768 auto* stream = context->op_device_context()->stream();
769 OP_REQUIRES(context, stream, errors::Internal("No GPU stream available"));
770
771 const int64 batch_size = GetTensorDim(x, tensor_format, 'N');
772 const int64 channels = GetTensorDim(x, tensor_format, 'C');
773 const int64 height = GetTensorDim(x, tensor_format, 'H');
774 const int64 width = GetTensorDim(x, tensor_format, 'W');
775
776     // If use_reserved_space is true, we have the reserve_space_3 output (only
777     // in the FusedBatchNormV3 op).
778
779 #if GOOGLE_CUDA
780 // Check if cuDNN batch normalization has a fast NHWC implementation:
781 // (1) In inference mode it's always fast.
782     // (2) TensorFlow enabled batchnorm spatial persistence, and we are called
783     //     from FusedBatchNormV3,
784     //     i.e. use_reserved_space is true.
785 const bool fast_nhwc_batch_norm =
786 !is_training ||
787 (BatchnormSpatialPersistentEnabled() &&
788 DataTypeToEnum<T>::value == DT_HALF && use_reserved_space);
789 #else
790 // fast NHWC implementation is a CUDA only feature
791 const bool fast_nhwc_batch_norm = false;
792 #endif
793
794 // If input tensor is in NHWC format, and we have a fast cuDNN
795 // implementation, there is no need to do data format conversion.
796 TensorFormat compute_format =
797 fast_nhwc_batch_norm && tensor_format == FORMAT_NHWC ? FORMAT_NHWC
798 : FORMAT_NCHW;
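    // When compute_format differs from tensor_format, the input is transposed
    // to NCHW below and the result is transposed back to NHWC after the cuDNN
    // call.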
799
800 VLOG(2) << "FusedBatchNorm:"
801 << " batch_size: " << batch_size << " channels: " << channels
802 << " height: " << height << " width:" << width
803 << " x shape: " << x.shape().DebugString()
804 << " scale shape: " << scale.shape().DebugString()
805 << " offset shape: " << offset.shape().DebugString()
806 << " activation mode: " << ToString(activation_mode)
807 << " tensor format: " << ToString(tensor_format)
808 << " compute format: " << ToString(compute_format);
809
810 auto maybe_make_dummy_output = [context, use_reserved_space]() -> Status {
811 if (use_reserved_space) {
812 Tensor* dummy_reserve_space = nullptr;
813 return context->allocate_output(5, {}, &dummy_reserve_space);
814 }
815 return Status::OK();
816 };
817
818 // If input is empty, return NaN mean/variance
819 if (x.shape().num_elements() == 0) {
820 OP_REQUIRES_OK(context, maybe_make_dummy_output());
821 functor::SetNanFunctor<U> f;
822 f(context->eigen_device<GPUDevice>(), batch_mean->flat<U>());
823 f(context->eigen_device<GPUDevice>(), batch_var->flat<U>());
824 return;
825 }
826
827     // In inference mode we use a custom CUDA kernel, because cuDNN does not
828     // support side input and activations for inference.
829 const bool has_side_input = side_input != nullptr;
830 const bool has_activation =
831 activation_mode != FusedBatchNormActivationMode::kIdentity;
832
833 if (!is_training && (has_side_input || has_activation)) {
834 OP_REQUIRES_OK(context, maybe_make_dummy_output());
835 FusedBatchNormInferenceFunctor<GPUDevice, T, U> inference_functor;
836
837 if (has_side_input) {
838 inference_functor(context, tensor_format, x.tensor<T, 4>(),
839 scale.vec<U>(), offset.vec<U>(),
840 estimated_mean.vec<U>(), estimated_variance.vec<U>(),
841 side_input->tensor<T, 4>(), epsilon, activation_mode,
842 y->tensor<T, 4>());
843 } else {
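        // No side input: pass an empty tensor view to the inference functor.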
844 typename TTypes<T, 4>::ConstTensor empty_tensor(nullptr, 0, 0, 0, 0);
845 inference_functor(context, tensor_format, x.tensor<T, 4>(),
846 scale.vec<U>(), offset.vec<U>(),
847 estimated_mean.vec<U>(), estimated_variance.vec<U>(),
848 empty_tensor, epsilon, activation_mode,
849 y->tensor<T, 4>());
850 }
851 return;
852 }
853
854 Tensor x_maybe_transformed = x;
855 Tensor x_transformed;
856 Tensor y_transformed;
857 se::DeviceMemory<T> y_ptr;
858
859 if (tensor_format == compute_format) {
860 y_ptr = StreamExecutorUtil::AsDeviceMemory<T>(*y);
861 } else if (tensor_format == FORMAT_NHWC && compute_format == FORMAT_NCHW) {
862 OP_REQUIRES_OK(context, context->allocate_temp(
863 DataTypeToEnum<T>::value,
864 ShapeFromFormat(compute_format, batch_size,
865 height, width, channels),
866 &x_transformed));
867 functor::NHWCToNCHW<GPUDevice, T, 4>()(
868 context->eigen_device<GPUDevice>(),
869 const_cast<const Tensor&>(x_maybe_transformed).tensor<T, 4>(),
870 x_transformed.tensor<T, 4>());
871 x_maybe_transformed = x_transformed;
872
873 OP_REQUIRES_OK(context, context->allocate_temp(
874 DataTypeToEnum<T>::value,
875 ShapeFromFormat(compute_format, batch_size,
876 height, width, channels),
877 &y_transformed));
878 y_ptr = StreamExecutorUtil::AsDeviceMemory<T>(y_transformed);
879 } else {
880 context->SetStatus(errors::Internal(
881 "Unsupported tensor format: ", ToString(tensor_format),
882 " and compute format: ", ToString(compute_format)));
883 return;
884 }
885
886 const se::dnn::DataLayout data_layout =
887 compute_format == FORMAT_NHWC ? se::dnn::DataLayout::kBatchYXDepth
888 : se::dnn::DataLayout::kBatchDepthYX;
889
890 se::dnn::BatchDescriptor x_desc;
891 x_desc.set_count(batch_size)
892 .set_feature_map_count(channels)
893 .set_height(height)
894 .set_width(width)
895 .set_layout(data_layout);
896
897 se::dnn::BatchDescriptor scale_offset_desc;
898 scale_offset_desc.set_count(1)
899 .set_feature_map_count(channels)
900 .set_height(1)
901 .set_width(1)
902 .set_layout(se::dnn::DataLayout::kBatchDepthYX);
903
904 auto x_ptr = StreamExecutorUtil::AsDeviceMemory<T>(x_maybe_transformed);
905 auto scale_ptr = StreamExecutorUtil::AsDeviceMemory<U>(scale);
906 auto offset_ptr = StreamExecutorUtil::AsDeviceMemory<U>(offset);
907 auto estimated_mean_ptr =
908 StreamExecutorUtil::AsDeviceMemory<U>(estimated_mean);
909 auto estimated_variance_ptr =
910 StreamExecutorUtil::AsDeviceMemory<U>(estimated_variance);
911 auto side_input_ptr =
912 side_input != nullptr
913 ? StreamExecutorUtil::AsDeviceMemory<U>(*side_input)
914 : se::DeviceMemory<U>();
915 auto batch_mean_ptr = StreamExecutorUtil::AsDeviceMemory<U>(*batch_mean);
916
917 auto batch_var_ptr = StreamExecutorUtil::AsDeviceMemory<U>(*batch_var);
918 auto saved_mean_ptr = StreamExecutorUtil::AsDeviceMemory<U>(*saved_mean);
919 auto saved_inv_var_ptr =
920 StreamExecutorUtil::AsDeviceMemory<U>(*saved_inv_var);
921
922 std::unique_ptr<functor::CudnnBatchNormAllocatorInOutput<U>>
923 reserve_space_allocator;
924 std::unique_ptr<functor::CudnnBatchNormAllocatorInTemp<uint8>>
925 workspace_allocator;
926 if (use_reserved_space) {
927 reserve_space_allocator.reset(
928 new functor::CudnnBatchNormAllocatorInOutput<U>(context, 5));
929 workspace_allocator.reset(
930 new functor::CudnnBatchNormAllocatorInTemp<uint8>(context));
931 }
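    // When exponential_avg_factor != 1, cuDNN blends the new batch statistics
    // into the existing running mean/variance, so the current running values
    // must already be present in the batch_mean/batch_var output buffers
    // before the call below.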
932 if (!batch_mean->SharesBufferWith(estimated_mean) &&
933 exponential_avg_factor != 1.0f) {
934 OP_REQUIRES(
935 context,
936 stream
937 ->ThenMemcpyD2D(&batch_mean_ptr, estimated_mean_ptr,
938 estimated_mean.NumElements() * sizeof(U))
939 .ok(),
940                 errors::Internal("FusedBatchNorm: failed to copy the running "
941                                  "mean into the batch_mean output buffer"));
942 }
943 if (!batch_var->SharesBufferWith(estimated_variance) &&
944 exponential_avg_factor != 1.0f) {
945 OP_REQUIRES(
946 context,
947 stream
948 ->ThenMemcpyD2D(&batch_var_ptr, estimated_variance_ptr,
949 estimated_variance.NumElements() * sizeof(U))
950 .ok(),
951                 errors::Internal("FusedBatchNorm: failed to copy the running "
952                                  "variance into the batch_var output buffer"));
953 }
954 bool cudnn_launch_status =
955 stream
956 ->ThenBatchNormalizationForward(
957 x_ptr, scale_ptr, offset_ptr, estimated_mean_ptr,
958 estimated_variance_ptr, side_input_ptr, x_desc,
959 scale_offset_desc, static_cast<double>(epsilon),
960 static_cast<double>(exponential_avg_factor),
961 AsDnnActivationMode(activation_mode), &y_ptr, &batch_mean_ptr,
962 &batch_var_ptr, &saved_mean_ptr, &saved_inv_var_ptr,
963 is_training, reserve_space_allocator.get(),
964 workspace_allocator.get())
965 .ok();
966
967 if (!cudnn_launch_status) {
968 context->SetStatus(
969 errors::Internal("cuDNN launch failure : input shape (",
970 x.shape().DebugString(), ")"));
971 return;
972 }
973
974 if (tensor_format == FORMAT_NHWC && compute_format == FORMAT_NCHW) {
975 functor::NCHWToNHWC<GPUDevice, T, 4>()(
976 context->eigen_device<GPUDevice>(),
977 const_cast<const Tensor&>(y_transformed).tensor<T, 4>(),
978 y->tensor<T, 4>());
979 }
980 }
981 };
982
983 template <typename T, typename U>
984 struct FusedBatchNormGrad<GPUDevice, T, U> {
985   void operator()(OpKernelContext* context, const Tensor& y_backprop,
986 const Tensor& x, const Tensor& scale, const Tensor& mean,
987 const Tensor& inv_variance, U epsilon, Tensor* x_backprop,
988 Tensor* scale_backprop, Tensor* offset_backprop,
989 bool use_reserved_space, TensorFormat tensor_format) {
990 auto* stream = context->op_device_context()->stream();
991 OP_REQUIRES(context, stream, errors::Internal("No GPU stream available"));
992
993 const int64 batch_size = GetTensorDim(x, tensor_format, 'N');
994 const int64 channels = GetTensorDim(x, tensor_format, 'C');
995 const int64 height = GetTensorDim(x, tensor_format, 'H');
996 const int64 width = GetTensorDim(x, tensor_format, 'W');
997
998 #if GOOGLE_CUDA
999 // Check if cuDNN batch normalization has a fast NHWC implementation:
1000 // (1) Tensorflow enabled batchnorm spatial persistence, and
1001 // FusedBatchNormGradV3 passed non-null reserve space and allocator.
1002 const bool fast_nhwc_batch_norm = BatchnormSpatialPersistentEnabled() &&
1003 DataTypeToEnum<T>::value == DT_HALF &&
1004 use_reserved_space;
1005 #else
1006 // fast NHWC implementation is a CUDA only feature
1007 const bool fast_nhwc_batch_norm = false;
1008 #endif
1009
1010 // If input tensor is in NHWC format, and we have a fast cuDNN
1011 // implementation, there is no need to do data format conversion.
1012 TensorFormat compute_format =
1013 fast_nhwc_batch_norm && tensor_format == FORMAT_NHWC ? FORMAT_NHWC
1014 : FORMAT_NCHW;
1015
1016 VLOG(2) << "FusedBatchNormGrad:"
1017 << " batch_size: " << batch_size << " channels: " << channels
1018 << " height: " << height << " width: " << width
1019 << " y_backprop shape: " << y_backprop.shape().DebugString()
1020 << " x shape: " << x.shape().DebugString()
1021 << " scale shape: " << scale.shape().DebugString()
1022 << " tensor format: " << ToString(tensor_format)
1023 << " compute format: " << ToString(compute_format);
1024
1025 // Inputs
1026 Tensor y_backprop_maybe_transformed = y_backprop;
1027 Tensor x_maybe_transformed = x;
1028 Tensor y_backprop_transformed;
1029 Tensor x_transformed;
1030
1031 // Outputs
1032 Tensor x_backprop_transformed;
1033 se::DeviceMemory<T> x_backprop_ptr;
1034
1035 if (tensor_format == compute_format) {
1036 x_backprop_ptr = StreamExecutorUtil::AsDeviceMemory<T>(*x_backprop);
1037 } else if (tensor_format == FORMAT_NHWC && compute_format == FORMAT_NCHW) {
1038 // Transform inputs from 'NHWC' to 'NCHW'
1039 OP_REQUIRES_OK(context, context->allocate_temp(
1040 DataTypeToEnum<T>::value,
1041 ShapeFromFormat(FORMAT_NCHW, batch_size,
1042 height, width, channels),
1043 &y_backprop_transformed));
1044 functor::NHWCToNCHW<GPUDevice, T, 4>()(
1045 context->eigen_device<GPUDevice>(),
1046 const_cast<const Tensor&>(y_backprop_maybe_transformed)
1047 .tensor<T, 4>(),
1048 y_backprop_transformed.tensor<T, 4>());
1049 y_backprop_maybe_transformed = y_backprop_transformed;
1050
1051 OP_REQUIRES_OK(context, context->allocate_temp(
1052 DataTypeToEnum<T>::value,
1053 ShapeFromFormat(FORMAT_NCHW, batch_size,
1054 height, width, channels),
1055 &x_transformed));
1056 functor::NHWCToNCHW<GPUDevice, T, 4>()(
1057 context->eigen_device<GPUDevice>(),
1058 const_cast<const Tensor&>(x_maybe_transformed).tensor<T, 4>(),
1059 x_transformed.tensor<T, 4>());
1060 x_maybe_transformed = x_transformed;
1061
1062 // Allocate memory for transformed outputs in 'NCHW'
1063 OP_REQUIRES_OK(context, context->allocate_temp(
1064 DataTypeToEnum<T>::value,
1065 ShapeFromFormat(FORMAT_NCHW, batch_size,
1066 height, width, channels),
1067 &x_backprop_transformed));
1068 x_backprop_ptr =
1069 StreamExecutorUtil::AsDeviceMemory<T>(x_backprop_transformed);
1070 } else {
1071 context->SetStatus(errors::Internal(
1072 "Unsupported tensor format: ", ToString(tensor_format),
1073 " and compute format: ", ToString(compute_format)));
1074 return;
1075 }
1076
1077 const se::dnn::DataLayout data_layout =
1078 compute_format == FORMAT_NHWC ? se::dnn::DataLayout::kBatchYXDepth
1079 : se::dnn::DataLayout::kBatchDepthYX;
1080
1081 se::dnn::BatchDescriptor x_desc;
1082 x_desc.set_count(batch_size)
1083 .set_feature_map_count(channels)
1084 .set_height(height)
1085 .set_width(width)
1086 .set_layout(data_layout);
1087
1088 se::dnn::BatchDescriptor scale_offset_desc;
1089 scale_offset_desc.set_count(1)
1090 .set_feature_map_count(channels)
1091 .set_height(1)
1092 .set_width(1)
1093 .set_layout(se::dnn::DataLayout::kBatchDepthYX);
1094
1095 auto y_backprop_ptr =
1096 StreamExecutorUtil::AsDeviceMemory<T>(y_backprop_maybe_transformed);
1097 auto x_ptr = StreamExecutorUtil::AsDeviceMemory<T>(x_maybe_transformed);
1098 auto scale_ptr = StreamExecutorUtil::AsDeviceMemory<U>(scale);
1099 auto mean_ptr = StreamExecutorUtil::AsDeviceMemory<U>(mean);
1100 auto inv_variance_ptr = StreamExecutorUtil::AsDeviceMemory<U>(inv_variance);
1101 auto scale_backprop_ptr =
1102 StreamExecutorUtil::AsDeviceMemory<U>(*scale_backprop);
1103 auto offset_backprop_ptr =
1104 StreamExecutorUtil::AsDeviceMemory<U>(*offset_backprop);
1105
1106 std::unique_ptr<functor::CudnnBatchNormAllocatorInTemp<uint8>>
1107 workspace_allocator;
1108 DeviceMemory<uint8>* reserve_space_data_ptr = nullptr;
1109 DeviceMemory<uint8> reserve_space_data;
1110 #if CUDNN_VERSION >= 7402
1111 if (use_reserved_space) {
1112 const Tensor& reserve_space = context->input(5);
1113 workspace_allocator.reset(
1114 new functor::CudnnBatchNormAllocatorInTemp<uint8>(context));
1115
1116       // The cuDNN kernel outputs the inverse variance in the forward pass and
1117       // reuses it in the backward pass.
1118 if (reserve_space.dims() != 0) {
1119 reserve_space_data = functor::CastDeviceMemory<uint8, U>(
1120 const_cast<Tensor*>(&reserve_space));
1121 reserve_space_data_ptr = &reserve_space_data;
1122 }
1123 }
1124 #endif // CUDNN_VERSION >= 7402
1125
1126 bool cudnn_launch_status =
1127 stream
1128 ->ThenBatchNormalizationBackward(
1129 y_backprop_ptr, x_ptr, scale_ptr, mean_ptr, inv_variance_ptr,
1130 x_desc, scale_offset_desc, static_cast<double>(epsilon),
1131 &x_backprop_ptr, &scale_backprop_ptr, &offset_backprop_ptr,
1132 reserve_space_data_ptr, workspace_allocator.get())
1133 .ok();
1134
1135 if (!cudnn_launch_status) {
1136 context->SetStatus(
1137 errors::Internal("cuDNN launch failure : input shape (",
1138 x.shape().DebugString(), ")"));
1139 }
1140 if (tensor_format == FORMAT_NHWC && compute_format == FORMAT_NCHW) {
1141 functor::NCHWToNHWC<GPUDevice, T, 4>()(
1142 context->eigen_device<GPUDevice>(),
1143 const_cast<const Tensor&>(x_backprop_transformed).tensor<T, 4>(),
1144 x_backprop->tensor<T, 4>());
1145 }
1146 }
1147 };
1148
1149 // Forward declarations of the functor specializations for GPU.
1150 #define DECLARE_GPU_SPEC(T, U) \
1151 template <> \
1152 void FusedBatchNormFreezeGrad<GPUDevice, T, U>::operator()( \
1153 OpKernelContext* context, const Tensor& y_backprop_input, \
1154 const Tensor& x_input, const Tensor& scale_input, \
1155 const Tensor& mean_input, const Tensor& variance_input, U epsilon, \
1156 Tensor* x_backprop_output, Tensor* scale_backprop_output, \
1157 Tensor* offset_backprop_output); \
1158 extern template struct FusedBatchNormFreezeGrad<GPUDevice, T, U>; \
1159 template <> \
1160 void FusedBatchNormInferenceFunctor<GPUDevice, T, U>::operator()( \
1161 OpKernelContext* context, TensorFormat tensor_format, \
1162 typename TTypes<T, 4>::ConstTensor in, \
1163 typename TTypes<U>::ConstVec scale, typename TTypes<U>::ConstVec offset, \
1164 typename TTypes<U>::ConstVec estimated_mean, \
1165 typename TTypes<U>::ConstVec estimated_variance, \
1166 typename TTypes<T, 4>::ConstTensor side_input, U epsilon, \
1167 FusedBatchNormActivationMode activation_mode, \
1168 typename TTypes<T, 4>::Tensor out); \
1169 extern template struct FusedBatchNormInferenceFunctor<GPUDevice, T, U>;
1170
1171 DECLARE_GPU_SPEC(float, float);
1172 DECLARE_GPU_SPEC(Eigen::half, float);
1173
1174 #undef DECLARE_GPU_SPEC
1175
1176 #endif // GOOGLE_CUDA || TENSORFLOW_USE_ROCM
1177 } // namespace functor
1178
1179 template <typename Device, typename T, typename U>
1180 class FusedBatchNormOpBase : public OpKernel {
1181 using FbnActivationMode = functor::FusedBatchNormActivationMode;
1182
1183 protected:
1184   explicit FusedBatchNormOpBase(OpKernelConstruction* context,
1185 bool is_batch_norm_ex = false)
1186 : OpKernel(context) {
1187 float epsilon;
1188 OP_REQUIRES_OK(context, context->GetAttr("epsilon", &epsilon));
1189 epsilon_ = U(epsilon);
1190 float exponential_avg_factor;
1191 OP_REQUIRES_OK(context, context->GetAttr("exponential_avg_factor",
1192 &exponential_avg_factor));
1193 exponential_avg_factor_ = U(exponential_avg_factor);
1194 string tensor_format;
1195 OP_REQUIRES_OK(context, context->GetAttr("data_format", &tensor_format));
1196 OP_REQUIRES(context, FormatFromString(tensor_format, &tensor_format_),
1197 errors::InvalidArgument("Invalid data format"));
1198 OP_REQUIRES_OK(context, context->GetAttr("is_training", &is_training_));
1199
1200 if (!is_batch_norm_ex) {
1201 has_side_input_ = false;
1202 activation_mode_ = FbnActivationMode::kIdentity;
1203 } else {
1204 OP_REQUIRES_OK(context, ParseActivationMode(context, &activation_mode_));
1205
1206 int num_side_inputs;
1207 OP_REQUIRES_OK(context,
1208 context->GetAttr("num_side_inputs", &num_side_inputs));
1209 OP_REQUIRES(context, num_side_inputs >= 0 && num_side_inputs <= 1,
1210 errors::InvalidArgument(
1211 "FusedBatchNorm accepts at most one side input."));
1212 has_side_input_ = (num_side_inputs == 1);
1213 if (has_side_input_ && is_training_) {
1214 OP_REQUIRES(
1215 context, activation_mode_ != FbnActivationMode::kIdentity,
1216 errors::InvalidArgument("Identity activation is not supported with "
1217 "non-empty side input"));
1218 }
1219 }
1220
1221 if (activation_mode_ != FbnActivationMode::kIdentity && is_training_) {
1222       // NOTE(ezhulenev): The following requirements come from implementation
1223       // details of cudnnBatchNormalizationForwardTrainingEx, which is used in
1224       // training mode. In inference mode we call a custom CUDA kernel that
1225       // supports all data formats and data types.
1226 OP_REQUIRES(context, DataTypeToEnum<T>::value == DT_HALF,
1227 errors::InvalidArgument("FusedBatchNorm with activation "
1228 "supports only DT_HALF data type."));
1229 OP_REQUIRES(context, tensor_format_ == FORMAT_NHWC,
1230 errors::InvalidArgument("FusedBatchNorm with activation "
1231 "supports only NHWC tensor format."));
1232 OP_REQUIRES(context, functor::BatchnormSpatialPersistentEnabled(),
1233 errors::InvalidArgument(
1234 "FusedBatchNorm with activation must run with cuDNN "
1235 "spatial persistence mode enabled."));
1236 }
1237 }
1238
1239   // If use_reserved_space is true, we need to handle the 5th output (a reserve
1240   // space), and the newer cuDNN batch norm API is used when the cuDNN version
1241   // is 7.4.2 or later. If use_reserved_space is false, there is no 5th output.
1242   virtual void ComputeWithReservedSpace(OpKernelContext* context,
1243 bool use_reserved_space) {
1244 Tensor x = context->input(0);
1245 const Tensor& scale = context->input(1);
1246 const Tensor& offset = context->input(2);
1247 const Tensor& estimated_mean = context->input(3);
1248 const Tensor& estimated_variance = context->input(4);
1249 const Tensor* side_input = has_side_input_ ? &context->input(5) : nullptr;
1250
1251 OP_REQUIRES(context, x.dims() == 4 || x.dims() == 5,
1252 errors::InvalidArgument("input must be 4 or 5-dimensional",
1253 x.shape().DebugString()));
1254 OP_REQUIRES(context, scale.dims() == 1,
1255 errors::InvalidArgument("scale must be 1-dimensional",
1256 scale.shape().DebugString()));
1257 OP_REQUIRES(context, offset.dims() == 1,
1258 errors::InvalidArgument("offset must be 1-dimensional",
1259 offset.shape().DebugString()));
1260 OP_REQUIRES(context, estimated_mean.dims() == 1,
1261 errors::InvalidArgument("estimated_mean must be 1-dimensional",
1262 estimated_mean.shape().DebugString()));
1263 OP_REQUIRES(
1264 context, estimated_variance.dims() == 1,
1265 errors::InvalidArgument("estimated_variance must be 1-dimensional",
1266 estimated_variance.shape().DebugString()));
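    // 5-D inputs are collapsed to 4-D by merging the last two spatial
    // dimensions into one; the per-channel statistics are unaffected by this
    // reshape.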
1267 bool use_reshape = (x.dims() == 5);
1268 auto x_shape = x.shape();
1269 TensorShape dest_shape;
1270 if (use_reshape) {
1271 const int64 in_batch = GetTensorDim(x, tensor_format_, 'N');
1272 int64 in_planes = GetTensorDim(x, tensor_format_, '0');
1273 int64 in_rows = GetTensorDim(x, tensor_format_, '1');
1274 int64 in_cols = GetTensorDim(x, tensor_format_, '2');
1275 const int64 in_depth = GetTensorDim(x, tensor_format_, 'C');
1276 dest_shape = ShapeFromFormat(tensor_format_, in_batch,
1277 {{in_planes, in_rows * in_cols}}, in_depth);
1278 OP_REQUIRES(context, x.CopyFrom(x, dest_shape),
1279 errors::InvalidArgument("Error during tensor copy."));
1280 }
1281
1282 if (has_side_input_) {
1283 OP_REQUIRES(context, side_input->shape() == x.shape(),
1284 errors::InvalidArgument(
1285 "side_input shape must be equal to input shape: ",
1286 side_input->shape().DebugString(),
1287 " != ", x.shape().DebugString()));
1288 }
1289
1290 if (activation_mode_ != FbnActivationMode::kIdentity) {
1291 // NOTE(ezhulenev): This requirement is coming from implementation
1292 // details of cudnnBatchNormalizationForwardTrainingEx.
1293 OP_REQUIRES(
1294 context, !is_training_ || x.dim_size(3) % 4 == 0,
1295 errors::InvalidArgument("FusedBatchNorm with activation requires "
1296 "channel dimension to be a multiple of 4."));
1297 }
1298
1299 Tensor* y = nullptr;
1300 auto alloc_shape = use_reshape ? dest_shape : x_shape;
1301 OP_REQUIRES_OK(context, context->forward_input_or_allocate_output(
1302 {0}, 0, alloc_shape, &y));
1303
1304 Tensor* batch_mean = nullptr;
1305 OP_REQUIRES_OK(context, context->forward_input_or_allocate_output(
1306 {3}, 1, scale.shape(), &batch_mean));
1307 Tensor* batch_var = nullptr;
1308 OP_REQUIRES_OK(context, context->forward_input_or_allocate_output(
1309 {4}, 2, scale.shape(), &batch_var));
1310 Tensor* saved_mean = nullptr;
1311 OP_REQUIRES_OK(context,
1312 context->allocate_output(3, scale.shape(), &saved_mean));
1313 Tensor* saved_maybe_inv_var = nullptr;
1314 OP_REQUIRES_OK(context, context->allocate_output(4, scale.shape(),
1315 &saved_maybe_inv_var));
1316
1317 if (is_training_) {
1318 functor::FusedBatchNorm<Device, T, U, true>()(
1319 context, x, scale, offset, estimated_mean, estimated_variance,
1320 side_input, epsilon_, exponential_avg_factor_, activation_mode_, y,
1321 batch_mean, batch_var, saved_mean, saved_maybe_inv_var,
1322 tensor_format_, use_reserved_space);
1323 } else {
1324 functor::FusedBatchNorm<Device, T, U, false>()(
1325 context, x, scale, offset, estimated_mean, estimated_variance,
1326 side_input, epsilon_, exponential_avg_factor_, activation_mode_, y,
1327 batch_mean, batch_var, saved_mean, saved_maybe_inv_var,
1328 tensor_format_, use_reserved_space);
1329 }
1330 if (use_reshape) {
1331 OP_REQUIRES(context, y->CopyFrom(*y, x_shape),
1332 errors::InvalidArgument("Error during tensor copy."));
1333 }
1334 }
1335
1336 private:
1337 U epsilon_;
1338 U exponential_avg_factor_;
1339 TensorFormat tensor_format_;
1340 bool is_training_;
1341 bool has_side_input_;
1342 FbnActivationMode activation_mode_;
1343 };
1344
1345 template <typename Device, typename T, typename U>
1346 class FusedBatchNormOp : public FusedBatchNormOpBase<Device, T, U> {
1347 public:
1348   explicit FusedBatchNormOp(OpKernelConstruction* context)
1349 : FusedBatchNormOpBase<Device, T, U>(context) {}
1350
1351   void Compute(OpKernelContext* context) override {
1352 FusedBatchNormOpBase<Device, T, U>::ComputeWithReservedSpace(context,
1353 false);
1354 }
1355 };
1356
1357 template <typename Device, typename T, typename U>
1358 class FusedBatchNormOpV3 : public FusedBatchNormOpBase<Device, T, U> {
1359 public:
1360   explicit FusedBatchNormOpV3(OpKernelConstruction* context)
1361 : FusedBatchNormOpBase<Device, T, U>(context) {}
1362
1363   void Compute(OpKernelContext* context) override {
1364 FusedBatchNormOpBase<Device, T, U>::ComputeWithReservedSpace(context, true);
1365 }
1366 };
1367
1368 template <typename Device, typename T, typename U>
1369 class FusedBatchNormOpEx : public FusedBatchNormOpBase<Device, T, U> {
1370 static constexpr bool kWithSideInputAndActivation = true;
1371
1372 public:
1373   explicit FusedBatchNormOpEx(OpKernelConstruction* context)
1374 : FusedBatchNormOpBase<Device, T, U>(context,
1375 kWithSideInputAndActivation) {}
1376
1377   void Compute(OpKernelContext* context) override {
1378 FusedBatchNormOpBase<Device, T, U>::ComputeWithReservedSpace(context, true);
1379 }
1380 };
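// Summary of the forward kernels above:
//  * FusedBatchNormOp   - serves the V1/V2 ops; no reserve_space output.
//  * FusedBatchNormOpV3 - additionally emits the reserve_space output that
//    FusedBatchNormGradV3 consumes.
//  * FusedBatchNormOpEx - V3 behaviour plus fused side-input addition and
//    activation; registered below (GPU only) as the internal _FusedBatchNormEx
//    op, which is normally created by graph rewrites rather than by users.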

template <typename Device, typename T, typename U>
class FusedBatchNormGradOpBase : public OpKernel {
 protected:
  explicit FusedBatchNormGradOpBase(OpKernelConstruction* context)
      : OpKernel(context) {
    float epsilon;
    OP_REQUIRES_OK(context, context->GetAttr("epsilon", &epsilon));
    epsilon_ = U(epsilon);
    string tensor_format;
    OP_REQUIRES_OK(context, context->GetAttr("data_format", &tensor_format));
    OP_REQUIRES(context, FormatFromString(tensor_format, &tensor_format_),
                errors::InvalidArgument("Invalid data format"));
    OP_REQUIRES_OK(context, context->GetAttr("is_training", &is_training_));
  }

  virtual void ComputeWithReservedSpace(OpKernelContext* context,
                                        bool use_reserved_space) {
    Tensor y_backprop = context->input(0);
    Tensor x = context->input(1);
    const Tensor& scale = context->input(2);
    // When is_training=True, batch mean and variance/inverted variance are
    // saved in the forward pass to be reused here. When is_training=False,
    // population mean and variance need to be forwarded here to compute the
    // gradients.
    const Tensor& saved_mean_or_pop_mean = context->input(3);
    // The Eigen implementation saves variance in the forward pass, while cuDNN
    // saves inverted variance.
    const Tensor& saved_maybe_inv_var_or_pop_var = context->input(4);
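    // "Inverted variance" refers to the cuDNN-style cached value
    // 1 / sqrt(variance + epsilon) (the reciprocal standard deviation), as
    // opposed to the raw batch variance cached by the CPU/Eigen path.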

    OP_REQUIRES(context, y_backprop.dims() == 4 || y_backprop.dims() == 5,
                errors::InvalidArgument("input must be 4- or 5-dimensional: ",
                                        y_backprop.shape().DebugString()));
    OP_REQUIRES(context, x.dims() == 4 || x.dims() == 5,
                errors::InvalidArgument("input must be 4- or 5-dimensional: ",
                                        x.shape().DebugString()));
    OP_REQUIRES(context, scale.dims() == 1,
                errors::InvalidArgument("scale must be 1-dimensional: ",
                                        scale.shape().DebugString()));
    OP_REQUIRES(
        context, saved_mean_or_pop_mean.dims() == 1,
        errors::InvalidArgument("saved mean must be 1-dimensional: ",
                                saved_mean_or_pop_mean.shape().DebugString()));
    OP_REQUIRES(context, saved_maybe_inv_var_or_pop_var.dims() == 1,
                errors::InvalidArgument(
                    "saved variance must be 1-dimensional: ",
                    saved_maybe_inv_var_or_pop_var.shape().DebugString()));
    bool use_reshape = (x.dims() == 5);
    auto x_shape = x.shape();
    TensorShape dest_shape;
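    // A 5D input is collapsed to 4D so the 4D kernels can be reused: e.g. an
    // NDHWC tensor of shape [2, 4, 8, 8, 16] is viewed below as the 4D shape
    // [2, 4, 64, 16] (H and W merged); x_backprop is reshaped back to 5D at
    // the end of this function.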
    if (use_reshape) {
      const int64 in_batch = GetTensorDim(x, tensor_format_, 'N');
      int64 in_planes = GetTensorDim(x, tensor_format_, '0');
      int64 in_rows = GetTensorDim(x, tensor_format_, '1');
      int64 in_cols = GetTensorDim(x, tensor_format_, '2');
      const int64 in_depth = GetTensorDim(x, tensor_format_, 'C');
      dest_shape = ShapeFromFormat(tensor_format_, in_batch,
                                   {{in_planes, in_rows * in_cols}}, in_depth);
      OP_REQUIRES(context, x.CopyFrom(x, dest_shape),
                  errors::InvalidArgument("Error during tensor copy."));
      OP_REQUIRES(context, y_backprop.CopyFrom(y_backprop, dest_shape),
                  errors::InvalidArgument("Error during tensor copy."));
    }

    Tensor* x_backprop = nullptr;
    auto alloc_shape = use_reshape ? dest_shape : x_shape;
    OP_REQUIRES_OK(context,
                   context->allocate_output(0, alloc_shape, &x_backprop));

    const TensorShape& scale_offset_shape = scale.shape();
    Tensor* scale_backprop = nullptr;
    OP_REQUIRES_OK(context, context->allocate_output(1, scale_offset_shape,
                                                     &scale_backprop));
    Tensor* offset_backprop = nullptr;
    OP_REQUIRES_OK(context, context->allocate_output(2, scale_offset_shape,
                                                     &offset_backprop));
    // Two placeholders for estimated_mean and estimated_variance, which are
    // used for inference and thus not needed here for gradient computation.
    // They are allocated as empty (zero-element) tensors to satisfy the op's
    // output signature.
    Tensor* placeholder_1 = nullptr;
    OP_REQUIRES_OK(
        context, context->allocate_output(3, TensorShape({0}), &placeholder_1));
    Tensor* placeholder_2 = nullptr;
    OP_REQUIRES_OK(
        context, context->allocate_output(4, TensorShape({0}), &placeholder_2));

    // If input is empty, set gradients w.r.t scale/offset to zero.
    if (x.shape().num_elements() == 0) {
      functor::SetZeroFunctor<Device, U> f;
      f(context->eigen_device<Device>(), scale_backprop->flat<U>());
      f(context->eigen_device<Device>(), offset_backprop->flat<U>());
      return;
    }

    if (is_training_) {
      functor::FusedBatchNormGrad<Device, T, U>()(
          context, y_backprop, x, scale, saved_mean_or_pop_mean,
          saved_maybe_inv_var_or_pop_var, epsilon_, x_backprop, scale_backprop,
          offset_backprop, use_reserved_space, tensor_format_);
    } else {
      // The necessary layout conversion is currently done in Python.
      OP_REQUIRES(context, tensor_format_ == FORMAT_NHWC,
                  errors::InvalidArgument(
                      "The implementation of "
                      "FusedBatchNormGrad with is_training=False only supports "
                      "NHWC tensor format for now."));
      functor::FusedBatchNormFreezeGrad<Device, T, U>()(
          context, y_backprop, x, scale, saved_mean_or_pop_mean,
          saved_maybe_inv_var_or_pop_var, epsilon_, x_backprop, scale_backprop,
          offset_backprop);
    }
    if (use_reshape) {
      OP_REQUIRES(context, x_backprop->CopyFrom(*x_backprop, x_shape),
                  errors::InvalidArgument("Error during tensor copy."));
    }
  }

 private:
  U epsilon_;
  TensorFormat tensor_format_;
  bool is_training_;
};

template <typename Device, typename T, typename U>
class FusedBatchNormGradOp : public FusedBatchNormGradOpBase<Device, T, U> {
 public:
  explicit FusedBatchNormGradOp(OpKernelConstruction* context)
      : FusedBatchNormGradOpBase<Device, T, U>(context) {}

  void Compute(OpKernelContext* context) override {
    FusedBatchNormGradOpBase<Device, T, U>::ComputeWithReservedSpace(context,
                                                                     false);
  }
};

template <typename Device, typename T, typename U>
class FusedBatchNormGradOpV3 : public FusedBatchNormGradOpBase<Device, T, U> {
 public:
  explicit FusedBatchNormGradOpV3(OpKernelConstruction* context)
      : FusedBatchNormGradOpBase<Device, T, U>(context) {}

  void Compute(OpKernelContext* context) override {
    FusedBatchNormGradOpBase<Device, T, U>::ComputeWithReservedSpace(context,
                                                                     true);
  }
};
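// FusedBatchNormGradOp serves the V1/V2 gradient ops and does not use a
// reserve_space input; FusedBatchNormGradOpV3 additionally consumes the
// reserve_space produced by the V3/Ex forward kernels.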

REGISTER_KERNEL_BUILDER(
    Name("FusedBatchNorm").Device(DEVICE_CPU).TypeConstraint<float>("T"),
    FusedBatchNormOp<CPUDevice, float, float>);

REGISTER_KERNEL_BUILDER(
    Name("FusedBatchNormGrad").Device(DEVICE_CPU).TypeConstraint<float>("T"),
    FusedBatchNormGradOp<CPUDevice, float, float>);

REGISTER_KERNEL_BUILDER(Name("FusedBatchNormV2")
                            .Device(DEVICE_CPU)
                            .TypeConstraint<float>("T")
                            .TypeConstraint<float>("U"),
                        FusedBatchNormOp<CPUDevice, float, float>);

REGISTER_KERNEL_BUILDER(Name("FusedBatchNormGradV2")
                            .Device(DEVICE_CPU)
                            .TypeConstraint<float>("T")
                            .TypeConstraint<float>("U"),
                        FusedBatchNormGradOp<CPUDevice, float, float>);

REGISTER_KERNEL_BUILDER(Name("FusedBatchNormV2")
                            .Device(DEVICE_CPU)
                            .TypeConstraint<Eigen::half>("T")
                            .TypeConstraint<float>("U"),
                        FusedBatchNormOp<CPUDevice, Eigen::half, float>);

REGISTER_KERNEL_BUILDER(Name("FusedBatchNormGradV2")
                            .Device(DEVICE_CPU)
                            .TypeConstraint<Eigen::half>("T")
                            .TypeConstraint<float>("U"),
                        FusedBatchNormGradOp<CPUDevice, Eigen::half, float>);

REGISTER_KERNEL_BUILDER(Name("FusedBatchNormV3")
                            .Device(DEVICE_CPU)
                            .TypeConstraint<float>("T")
                            .TypeConstraint<float>("U"),
                        FusedBatchNormOpV3<CPUDevice, float, float>);

REGISTER_KERNEL_BUILDER(Name("FusedBatchNormGradV3")
                            .Device(DEVICE_CPU)
                            .TypeConstraint<float>("T")
                            .TypeConstraint<float>("U"),
                        FusedBatchNormGradOpV3<CPUDevice, float, float>);

REGISTER_KERNEL_BUILDER(Name("FusedBatchNormV3")
                            .Device(DEVICE_CPU)
                            .TypeConstraint<Eigen::half>("T")
                            .TypeConstraint<float>("U"),
                        FusedBatchNormOpV3<CPUDevice, Eigen::half, float>);

REGISTER_KERNEL_BUILDER(Name("FusedBatchNormGradV3")
                            .Device(DEVICE_CPU)
                            .TypeConstraint<Eigen::half>("T")
                            .TypeConstraint<float>("U"),
                        FusedBatchNormGradOpV3<CPUDevice, Eigen::half, float>);

#if GOOGLE_CUDA || TENSORFLOW_USE_ROCM

REGISTER_KERNEL_BUILDER(
    Name("FusedBatchNorm").Device(DEVICE_GPU).TypeConstraint<float>("T"),
    FusedBatchNormOp<GPUDevice, float, float>);

REGISTER_KERNEL_BUILDER(
    Name("FusedBatchNormGrad").Device(DEVICE_GPU).TypeConstraint<float>("T"),
    FusedBatchNormGradOp<GPUDevice, float, float>);

REGISTER_KERNEL_BUILDER(Name("FusedBatchNormV2")
                            .Device(DEVICE_GPU)
                            .TypeConstraint<float>("T")
                            .TypeConstraint<float>("U"),
                        FusedBatchNormOp<GPUDevice, float, float>);

REGISTER_KERNEL_BUILDER(Name("FusedBatchNormGradV2")
                            .Device(DEVICE_GPU)
                            .TypeConstraint<float>("T")
                            .TypeConstraint<float>("U"),
                        FusedBatchNormGradOp<GPUDevice, float, float>);

REGISTER_KERNEL_BUILDER(Name("FusedBatchNormV2")
                            .Device(DEVICE_GPU)
                            .TypeConstraint<Eigen::half>("T")
                            .TypeConstraint<float>("U"),
                        FusedBatchNormOp<GPUDevice, Eigen::half, float>);

REGISTER_KERNEL_BUILDER(Name("FusedBatchNormGradV2")
                            .Device(DEVICE_GPU)
                            .TypeConstraint<Eigen::half>("T")
                            .TypeConstraint<float>("U"),
                        FusedBatchNormGradOp<GPUDevice, Eigen::half, float>);

REGISTER_KERNEL_BUILDER(Name("FusedBatchNormV3")
                            .Device(DEVICE_GPU)
                            .TypeConstraint<float>("T")
                            .TypeConstraint<float>("U"),
                        FusedBatchNormOpV3<GPUDevice, float, float>);

REGISTER_KERNEL_BUILDER(Name("_FusedBatchNormEx")
                            .Device(DEVICE_GPU)
                            .TypeConstraint<float>("T")
                            .TypeConstraint<float>("U"),
                        FusedBatchNormOpEx<GPUDevice, float, float>);

REGISTER_KERNEL_BUILDER(Name("FusedBatchNormGradV3")
                            .Device(DEVICE_GPU)
                            .TypeConstraint<float>("T")
                            .TypeConstraint<float>("U"),
                        FusedBatchNormGradOpV3<GPUDevice, float, float>);

REGISTER_KERNEL_BUILDER(Name("FusedBatchNormV3")
                            .Device(DEVICE_GPU)
                            .TypeConstraint<Eigen::half>("T")
                            .TypeConstraint<float>("U"),
                        FusedBatchNormOpV3<GPUDevice, Eigen::half, float>);

REGISTER_KERNEL_BUILDER(Name("_FusedBatchNormEx")
                            .Device(DEVICE_GPU)
                            .TypeConstraint<Eigen::half>("T")
                            .TypeConstraint<float>("U"),
                        FusedBatchNormOpEx<GPUDevice, Eigen::half, float>);

REGISTER_KERNEL_BUILDER(Name("FusedBatchNormGradV3")
                            .Device(DEVICE_GPU)
                            .TypeConstraint<Eigen::half>("T")
                            .TypeConstraint<float>("U"),
                        FusedBatchNormGradOpV3<GPUDevice, Eigen::half, float>);
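// Registrations for further type combinations would follow the same pattern.
// As an illustration only (hypothetical; no such kernel is registered or
// instantiated in this file), a double-precision V3 kernel would read:
//
//   REGISTER_KERNEL_BUILDER(Name("FusedBatchNormV3")
//                               .Device(DEVICE_GPU)
//                               .TypeConstraint<double>("T")
//                               .TypeConstraint<float>("U"),
//                           FusedBatchNormOpV3<GPUDevice, double, float>);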

#endif  // GOOGLE_CUDA || TENSORFLOW_USE_ROCM

}  // namespace tensorflow