1 /* Copyright 2016 The TensorFlow Authors. All Rights Reserved.
2 
3 Licensed under the Apache License, Version 2.0 (the "License");
4 you may not use this file except in compliance with the License.
5 You may obtain a copy of the License at
6 
7     http://www.apache.org/licenses/LICENSE-2.0
8 
9 Unless required by applicable law or agreed to in writing, software
10 distributed under the License is distributed on an "AS IS" BASIS,
11 WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
12 See the License for the specific language governing permissions and
13 limitations under the License.
14 ==============================================================================*/
15 
16 #include <atomic>
17 
18 #define EIGEN_USE_THREADS
19 
20 #if GOOGLE_CUDA || TENSORFLOW_USE_ROCM
21 #define EIGEN_USE_GPU
22 #if GOOGLE_CUDA
23 #include "third_party/gpus/cudnn/cudnn.h"
24 #endif  // GOOGLE_CUDA
25 
26 #include "tensorflow/core/kernels/conv_2d.h"
27 #include "tensorflow/core/platform/stream_executor.h"
28 #include "tensorflow/core/util/stream_executor_util.h"
29 #endif  // GOOGLE_CUDA || TENSORFLOW_USE_ROCM
30 
31 #include "third_party/eigen3/unsupported/Eigen/CXX11/Tensor"
32 #include "tensorflow/core/framework/op_kernel.h"
33 #include "tensorflow/core/framework/register_types.h"
34 #include "tensorflow/core/framework/tensor.h"
35 #include "tensorflow/core/framework/tensor_types.h"
36 #include "tensorflow/core/kernels/fill_functor.h"
37 #include "tensorflow/core/kernels/fused_batch_norm_op.h"
38 #include "tensorflow/core/kernels/redux_functor.h"
39 #include "tensorflow/core/kernels/transpose_functor.h"
40 #include "tensorflow/core/lib/core/blocking_counter.h"
41 #include "tensorflow/core/util/env_var.h"
42 #include "tensorflow/core/util/tensor_format.h"
43 
44 namespace tensorflow {
45 using CPUDevice = Eigen::ThreadPoolDevice;
46 using GPUDevice = Eigen::GpuDevice;
47 
48 namespace functor {
49 
50 #if GOOGLE_CUDA || TENSORFLOW_USE_ROCM
51 using se::DeviceMemory;
52 using se::ScratchAllocator;
53 using se::Stream;
54 using se::port::StatusOr;
55 #endif
56 
ToString(FusedBatchNormActivationMode activation_mode)57 string ToString(FusedBatchNormActivationMode activation_mode) {
58   switch (activation_mode) {
59     case FusedBatchNormActivationMode::kIdentity:
60       return "Identity";
61     case FusedBatchNormActivationMode::kRelu:
62       return "Relu";
63   }
64 }
65 
ParseActivationMode(OpKernelConstruction * context,FusedBatchNormActivationMode * activation_mode)66 Status ParseActivationMode(OpKernelConstruction* context,
67                            FusedBatchNormActivationMode* activation_mode) {
68   string activation_mode_str;
69   TF_RETURN_IF_ERROR(context->GetAttr("activation_mode", &activation_mode_str));
70 
71   if (activation_mode_str == "Identity") {
72     *activation_mode = FusedBatchNormActivationMode::kIdentity;
73     return Status::OK();
74   }
75   if (activation_mode_str == "Relu") {
76     *activation_mode = FusedBatchNormActivationMode::kRelu;
77     return Status::OK();
78   }
79   return errors::InvalidArgument("Unsupported activation mode: ",
80                                  activation_mode_str);
81 }
82 
83 // Functor used by FusedBatchNormOp to do the computations.
84 template <typename Device, typename T, typename U, bool is_training>
85 struct FusedBatchNorm;
86 // Functor used by FusedBatchNormGradOp to do the computations when
87 // is_training=True.
88 template <typename Device, typename T, typename U>
89 struct FusedBatchNormGrad;
90 
91 template <typename T, typename U>
92 struct FusedBatchNorm<CPUDevice, T, U, /* is_training= */ true> {
operator ()tensorflow::functor::FusedBatchNorm93   void operator()(OpKernelContext* context, const Tensor& x_input,
94                   const Tensor& scale_input, const Tensor& offset_input,
95                   const Tensor& running_mean_input,
96                   const Tensor& running_variance_input,
97                   const Tensor* side_input, U epsilon, U exponential_avg_factor,
98                   FusedBatchNormActivationMode activation_mode,
99                   Tensor* y_output, Tensor* running_mean_output,
100                   Tensor* running_var_output, Tensor* saved_batch_mean_output,
101                   Tensor* saved_batch_var_output, TensorFormat tensor_format,
102                   bool use_reserved_space) {
103     OP_REQUIRES(context, side_input == nullptr,
104                 errors::Internal(
105                     "The CPU implementation of FusedBatchNorm does not support "
106                     "side input."));
107     OP_REQUIRES(context,
108                 activation_mode == FusedBatchNormActivationMode::kIdentity,
109                 errors::Internal("The CPU implementation of FusedBatchNorm "
110                                  "does not support activations."));
111 
112     if (use_reserved_space) {
113       Tensor* dummy_reserve_space = nullptr;
114       OP_REQUIRES_OK(context,
115                      context->allocate_output(5, {}, &dummy_reserve_space));
116       // Initialize the memory, to avoid sanitizer alerts.
117       dummy_reserve_space->flat<U>()(0) = U();
118     }
119     Tensor transformed_x;
120     Tensor transformed_y;
121     if (tensor_format == FORMAT_NCHW) {
122       const int64 in_batch = GetTensorDim(x_input, tensor_format, 'N');
123       const int64 in_rows = GetTensorDim(x_input, tensor_format, 'H');
124       const int64 in_cols = GetTensorDim(x_input, tensor_format, 'W');
125       const int64 in_depths = GetTensorDim(x_input, tensor_format, 'C');
126       OP_REQUIRES_OK(context, context->allocate_temp(
127                                   DataTypeToEnum<T>::value,
128                                   ShapeFromFormat(FORMAT_NHWC, in_batch,
129                                                   in_rows, in_cols, in_depths),
130                                   &transformed_x));
131       OP_REQUIRES_OK(context, context->allocate_temp(
132                                   DataTypeToEnum<T>::value,
133                                   ShapeFromFormat(FORMAT_NHWC, in_batch,
134                                                   in_rows, in_cols, in_depths),
135                                   &transformed_y));
136       // Perform NCHW to NHWC
137       std::vector<int32> perm = {0, 2, 3, 1};
138       OP_REQUIRES_OK(
139           context, ::tensorflow::DoTranspose(context->eigen_device<CPUDevice>(),
140                                              x_input, perm, &transformed_x));
141     } else {
142       transformed_x = x_input;
143       transformed_y = *y_output;
144     }
145     typename TTypes<T, 4>::Tensor x(transformed_x.tensor<T, 4>());
146     typename TTypes<U>::ConstVec scale(scale_input.vec<U>());
147     typename TTypes<U>::ConstVec offset(offset_input.vec<U>());
148     typename TTypes<U>::ConstVec old_mean(running_mean_input.vec<U>());
149     typename TTypes<U>::ConstVec old_variance(running_variance_input.vec<U>());
150     typename TTypes<T, 4>::Tensor y(transformed_y.tensor<T, 4>());
151     typename TTypes<U>::Vec new_mean(running_mean_output->vec<U>());
152     typename TTypes<U>::Vec new_variance(running_var_output->vec<U>());
153     typename TTypes<U>::Vec saved_batch_mean(saved_batch_mean_output->vec<U>());
154     typename TTypes<U>::Vec saved_batch_var(saved_batch_var_output->vec<U>());
155 
156     const CPUDevice& d = context->eigen_device<CPUDevice>();
157 
158     const int depth = x.dimension(3);
159     const int size = x.size();
160     const int rest_size = size / depth;
161     Eigen::DSizes<Eigen::Index, 2> rest_by_depth(rest_size, depth);
162 
163 #if !defined(EIGEN_HAS_INDEX_LIST)
164     Eigen::DSizes<Eigen::Index, 2> one_by_depth(1, depth);
165     Eigen::array<int, 1> reduce_dims({0});
166     Eigen::array<int, 2> bcast_spec({rest_size, 1});
167 #else
168     Eigen::IndexList<Eigen::type2index<1>, Eigen::Index> one_by_depth;
169     one_by_depth.set(1, depth);
170     Eigen::IndexList<Eigen::type2index<0>> reduce_dims;
171     Eigen::IndexList<Eigen::Index, Eigen::type2index<1>> bcast_spec;
172     bcast_spec.set(0, rest_size);
173 #endif
174 
175     auto x_rest_by_depth = x.reshape(rest_by_depth).template cast<U>();
176     const int rest_size_minus_one = (rest_size > 1) ? (rest_size - 1) : 1;
177     U rest_size_inv = static_cast<U>(1.0f / static_cast<U>(rest_size));
178     // This adjustment is for Bessel's correction
179     U rest_size_adjust =
180         static_cast<U>(rest_size) / static_cast<U>(rest_size_minus_one);
181 
182     Eigen::Tensor<U, 1, Eigen::RowMajor> batch_mean(depth);
183     Eigen::Tensor<U, 1, Eigen::RowMajor> batch_variance(depth);
184 
185     batch_mean.device(d) = (x_rest_by_depth.sum(reduce_dims) * rest_size_inv);
186     auto x_centered = x_rest_by_depth -
187                       batch_mean.reshape(one_by_depth).broadcast(bcast_spec);
188 
189     batch_variance.device(d) =
190         x_centered.square().sum(reduce_dims) * rest_size_inv;
191     auto scaling_factor = ((batch_variance + epsilon).rsqrt() * scale)
192                               .eval()
193                               .reshape(one_by_depth)
194                               .broadcast(bcast_spec);
195     auto x_scaled = x_centered * scaling_factor;
196     auto x_shifted =
197         (x_scaled + offset.reshape(one_by_depth).broadcast(bcast_spec))
198             .template cast<T>();
199 
200     y.reshape(rest_by_depth).device(d) = x_shifted;
201     if (exponential_avg_factor == U(1.0)) {
202       saved_batch_var.device(d) = batch_variance;
203       saved_batch_mean.device(d) = batch_mean;
204       new_variance.device(d) = batch_variance * rest_size_adjust;
205       new_mean.device(d) = batch_mean;
206     } else {
207       U one_minus_factor = U(1) - exponential_avg_factor;
208       saved_batch_var.device(d) = batch_variance;
209       saved_batch_mean.device(d) = batch_mean;
210       new_variance.device(d) =
211           one_minus_factor * old_variance +
212           (exponential_avg_factor * rest_size_adjust) * batch_variance;
213       new_mean.device(d) =
214           one_minus_factor * old_mean + exponential_avg_factor * batch_mean;
215     }
216 
217     if (tensor_format == FORMAT_NCHW) {
218       // Perform NHWC to NCHW
219       const std::vector<int32> perm = {0, 3, 1, 2};
220       const Status s = ::tensorflow::DoTranspose(
221           context->eigen_device<CPUDevice>(), transformed_y, perm, y_output);
222       if (!s.ok()) {
223         context->SetStatus(errors::InvalidArgument("Transpose failed: ", s));
224       }
225     }
226   }
227 };
228 
229 template <typename T, typename U>
230 struct FusedBatchNorm<CPUDevice, T, U, /* is_training= */ false> {
operator ()tensorflow::functor::FusedBatchNorm231   void operator()(OpKernelContext* context, const Tensor& x_input,
232                   const Tensor& scale_input, const Tensor& offset_input,
233                   const Tensor& estimated_mean_input,
234                   const Tensor& estimated_variance_input,
235                   const Tensor* side_input, U epsilon, U exponential_avg_factor,
236                   FusedBatchNormActivationMode activation_mode,
237                   Tensor* y_output, Tensor* batch_mean_output,
238                   Tensor* batch_var_output, Tensor* saved_mean_output,
239                   Tensor* saved_var_output, TensorFormat tensor_format,
240                   bool use_reserved_space) {
241     OP_REQUIRES(context, side_input == nullptr,
242                 errors::Internal(
243                     "The CPU implementation of FusedBatchNorm does not support "
244                     "side input."));
245     OP_REQUIRES(context,
246                 activation_mode == FusedBatchNormActivationMode::kIdentity,
247                 errors::Internal("The CPU implementation of FusedBatchNorm "
248                                  "does not support activations."));
249 
250     if (use_reserved_space) {
251       Tensor* dummy_reserve_space = nullptr;
252       OP_REQUIRES_OK(context,
253                      context->allocate_output(5, {}, &dummy_reserve_space));
254       // Initialize the memory, to avoid sanitizer alerts.
255       dummy_reserve_space->flat<U>()(0) = U();
256     }
257     Tensor transformed_x;
258     Tensor transformed_y;
259     if (tensor_format == FORMAT_NCHW) {
260       const int64 in_batch = GetTensorDim(x_input, tensor_format, 'N');
261       const int64 in_rows = GetTensorDim(x_input, tensor_format, 'H');
262       const int64 in_cols = GetTensorDim(x_input, tensor_format, 'W');
263       const int64 in_depths = GetTensorDim(x_input, tensor_format, 'C');
264       OP_REQUIRES_OK(context, context->allocate_temp(
265                                   DataTypeToEnum<T>::value,
266                                   ShapeFromFormat(FORMAT_NHWC, in_batch,
267                                                   in_rows, in_cols, in_depths),
268                                   &transformed_x));
269       OP_REQUIRES_OK(context, context->allocate_temp(
270                                   DataTypeToEnum<T>::value,
271                                   ShapeFromFormat(FORMAT_NHWC, in_batch,
272                                                   in_rows, in_cols, in_depths),
273                                   &transformed_y));
274       // Perform NCHW to NHWC
275       std::vector<int32> perm = {0, 2, 3, 1};
276       OP_REQUIRES_OK(
277           context, ::tensorflow::DoTranspose(context->eigen_device<CPUDevice>(),
278                                              x_input, perm, &transformed_x));
279     } else {
280       transformed_x = x_input;
281       transformed_y = *y_output;
282     }
283     typename TTypes<T, 4>::Tensor x(transformed_x.tensor<T, 4>());
284     typename TTypes<U>::ConstVec scale(scale_input.vec<U>());
285     typename TTypes<U>::ConstVec offset(offset_input.vec<U>());
286     typename TTypes<U>::ConstVec estimated_mean(estimated_mean_input.vec<U>());
287     typename TTypes<U>::ConstVec estimated_variance(
288         estimated_variance_input.vec<U>());
289     typename TTypes<T, 4>::Tensor y(transformed_y.tensor<T, 4>());
290     typename TTypes<U>::Vec batch_mean(batch_mean_output->vec<U>());
291     typename TTypes<U>::Vec batch_variance(batch_var_output->vec<U>());
292 
293     const CPUDevice& d = context->eigen_device<CPUDevice>();
294 
295     const int depth = x.dimension(3);
296     const int size = x.size();
297     const int rest_size = size / depth;
298     Eigen::DSizes<Eigen::Index, 2> rest_by_depth(rest_size, depth);
299 
300 #if !defined(EIGEN_HAS_INDEX_LIST)
301     Eigen::DSizes<Eigen::Index, 2> one_by_depth(1, depth);
302     Eigen::array<int, 1> reduce_dims({0});
303     Eigen::array<int, 2> bcast_spec({rest_size, 1});
304 #else
305     Eigen::IndexList<Eigen::type2index<1>, Eigen::Index> one_by_depth;
306     one_by_depth.set(1, depth);
307     Eigen::IndexList<Eigen::Index, Eigen::type2index<1>> bcast_spec;
308     bcast_spec.set(0, rest_size);
309 #endif
310 
311     auto x_rest_by_depth = x.reshape(rest_by_depth).template cast<U>();
312     auto x_centered =
313         x_rest_by_depth -
314         estimated_mean.reshape(one_by_depth).broadcast(bcast_spec);
315     auto scaling_factor = ((estimated_variance + epsilon).rsqrt() * scale)
316                               .eval()
317                               .reshape(one_by_depth)
318                               .broadcast(bcast_spec);
319     auto x_scaled = x_centered * scaling_factor;
320     auto x_shifted =
321         (x_scaled + offset.reshape(one_by_depth).broadcast(bcast_spec))
322             .template cast<T>();
323 
324     y.reshape(rest_by_depth).device(d) = x_shifted;
325     batch_mean.device(d) = estimated_mean;
326     batch_variance.device(d) = estimated_variance;
327 
328     if (tensor_format == FORMAT_NCHW) {
329       // Perform NHWC to NCHW
330       const std::vector<int32> perm = {0, 3, 1, 2};
331       const Status s = ::tensorflow::DoTranspose(
332           context->eigen_device<CPUDevice>(), transformed_y, perm, y_output);
333       if (!s.ok()) {
334         context->SetStatus(errors::InvalidArgument("Transpose failed: ", s));
335       }
336     }
337   }
338 };
339 
340 template <typename T, typename U>
341 struct FusedBatchNormGrad<CPUDevice, T, U> {
operator ()tensorflow::functor::FusedBatchNormGrad342   void operator()(OpKernelContext* context, const Tensor& y_backprop_input,
343                   const Tensor& x_input, const Tensor& scale_input,
344                   const Tensor& mean_input, const Tensor& variance_input,
345                   U epsilon, Tensor* x_backprop_output,
346                   Tensor* scale_backprop_output, Tensor* offset_backprop_output,
347                   bool use_reserved_space, TensorFormat tensor_format) {
348     Tensor transformed_y_backprop_input;
349     Tensor transformed_x_input;
350     Tensor transformed_x_backprop_output;
351     if (tensor_format == FORMAT_NCHW) {
352       const int64 in_batch = GetTensorDim(x_input, tensor_format, 'N');
353       const int64 in_rows = GetTensorDim(x_input, tensor_format, 'H');
354       const int64 in_cols = GetTensorDim(x_input, tensor_format, 'W');
355       const int64 in_depths = GetTensorDim(x_input, tensor_format, 'C');
356       OP_REQUIRES_OK(context, context->allocate_temp(
357                                   DataTypeToEnum<T>::value,
358                                   ShapeFromFormat(FORMAT_NHWC, in_batch,
359                                                   in_rows, in_cols, in_depths),
360                                   &transformed_y_backprop_input));
361       OP_REQUIRES_OK(context, context->allocate_temp(
362                                   DataTypeToEnum<T>::value,
363                                   ShapeFromFormat(FORMAT_NHWC, in_batch,
364                                                   in_rows, in_cols, in_depths),
365                                   &transformed_x_input));
366       OP_REQUIRES_OK(context, context->allocate_temp(
367                                   DataTypeToEnum<T>::value,
368                                   ShapeFromFormat(FORMAT_NHWC, in_batch,
369                                                   in_rows, in_cols, in_depths),
370                                   &transformed_x_backprop_output));
371       // Perform NCHW to NHWC
372       std::vector<int32> perm = {0, 2, 3, 1};
373       OP_REQUIRES_OK(
374           context, ::tensorflow::DoTranspose(context->eigen_device<CPUDevice>(),
375                                              y_backprop_input, perm,
376                                              &transformed_y_backprop_input));
377       OP_REQUIRES_OK(context, ::tensorflow::DoTranspose(
378                                   context->eigen_device<CPUDevice>(), x_input,
379                                   perm, &transformed_x_input));
380     } else {
381       transformed_y_backprop_input = y_backprop_input;
382       transformed_x_input = x_input;
383       transformed_x_backprop_output = *x_backprop_output;
384     }
385     typename TTypes<T, 4>::Tensor y_backprop(
386         transformed_y_backprop_input.tensor<T, 4>());
387     typename TTypes<T, 4>::Tensor x(transformed_x_input.tensor<T, 4>());
388     typename TTypes<U>::ConstVec scale(scale_input.vec<U>());
389     typename TTypes<U>::ConstVec mean(mean_input.vec<U>());
390     typename TTypes<U>::ConstVec variance(variance_input.vec<U>());
391     typename TTypes<T, 4>::Tensor x_backprop(
392         transformed_x_backprop_output.tensor<T, 4>());
393     typename TTypes<U>::Vec offset_backprop(offset_backprop_output->vec<U>());
394 
395     // Note: the following formulas are used to compute the gradients for
396     // back propagation.
397     // x_backprop = scale * rsqrt(variance + epsilon) *
398     //              [y_backprop - mean(y_backprop) - (x - mean(x)) *
399     //              mean(y_backprop * (x - mean(x))) / (variance + epsilon)]
400     // scale_backprop = sum(y_backprop *
401     //                  (x - mean(x)) * rsqrt(variance + epsilon))
402     // offset_backprop = sum(y_backprop)
403 
404     const CPUDevice& d = context->eigen_device<CPUDevice>();
405     const int depth = x.dimension(3);
406     const int size = x.size();
407     const int rest_size = size / depth;
408     Eigen::DSizes<Eigen::Index, 2> rest_by_depth(rest_size, depth);
409 
410 #if !defined(EIGEN_HAS_INDEX_LIST)
411     Eigen::DSizes<Eigen::Index, 2> one_by_depth(1, depth);
412     Eigen::array<int, 2> bcast_spec({rest_size, 1});
413 #else
414     Eigen::IndexList<Eigen::type2index<1>, Eigen::Index> one_by_depth;
415     one_by_depth.set(1, depth);
416     Eigen::IndexList<Eigen::Index, Eigen::type2index<1>> bcast_spec;
417     bcast_spec.set(0, rest_size);
418 #endif
419 
420     auto x_rest_by_depth = x.reshape(rest_by_depth).template cast<U>();
421     U rest_size_inv = static_cast<U>(1.0f / static_cast<U>(rest_size));
422 
423     // Eigen is notoriously bad at reducing outer dimension, so we materialize
424     // all temporary tensors that require reduction, and then use Eigen redux
425     // functor, that is optimized for this particular task.
426     //
427     // All reductions are of this type: [rest_size, depth] -> [depth].
428     using ScalarSum = Eigen::internal::scalar_sum_op<U>;
429     const functor::ReduceOuterDimensions<T, U, U, ScalarSum> redux_sum_t;
430     const functor::ReduceOuterDimensions<U, U, U, ScalarSum> redux_sum_u;
431 
432     auto scratch_dtype = DataTypeToEnum<U>::value;
433 
434     // Allocate a temporary workspace of [depth] shape.
435     Tensor scratch_one_by_depth;
436     OP_REQUIRES_OK(context, context->allocate_temp(scratch_dtype, {depth},
437                                                    &scratch_one_by_depth));
438 
439     // Maybe allocate a temporary workspace of [rest_size, depth] shape.
440     Tensor scratch_rest_by_depth;
441     if (std::is_same<T, U>::value) {
442       OP_REQUIRES(context,
443                   scratch_rest_by_depth.CopyFrom(transformed_x_backprop_output,
444                                                  {rest_size, depth}),
445                   errors::Internal("Failed to copy a tensor"));
446     } else {
447       OP_REQUIRES_OK(context,
448                      context->allocate_temp(scratch_dtype, {rest_size, depth},
449                                             &scratch_rest_by_depth));
450     }
451 
452     typename TTypes<U, 2>::Tensor scratch_tensor(
453         scratch_rest_by_depth.tensor<U, 2>());
454     typename TTypes<U>::Vec scratch_vector(scratch_one_by_depth.vec<U>());
455 
456     auto x_mean_rest_by_depth =
457         mean.reshape(one_by_depth).broadcast(bcast_spec);
458     auto x_centered = (x_rest_by_depth - x_mean_rest_by_depth);
459     auto coef0_one_by_depth =
460         (variance.reshape(one_by_depth) + epsilon).rsqrt();
461     auto coef0_rest_by_depth = coef0_one_by_depth.broadcast(bcast_spec);
462     auto x_scaled = x_centered * coef0_rest_by_depth;
463 
464     auto y_backprop_rest_by_depth =
465         y_backprop.reshape(rest_by_depth).template cast<U>();
466 
467     // Compute `scale_backprop_output`:
468     //   scale_backprop =
469     //     (y_backprop_rest_by_depth * x_scaled).sum(reduce_dims)
470     scratch_tensor.device(d) = y_backprop_rest_by_depth * x_scaled;
471     redux_sum_u(d, rest_by_depth, scratch_rest_by_depth, scale_backprop_output);
472 
473     // Compute 'offset_backprop_output':
474     //   offset_backprop =
475     //     y_backprop_rest_by_depth.sum(reduce_dims)
476     redux_sum_t(d, rest_by_depth, transformed_y_backprop_input,
477                 offset_backprop_output);
478     auto y_backprop_sum = offset_backprop;
479 
480     auto y_backprop_sum_one_by_depth = y_backprop_sum.reshape(one_by_depth);
481     auto y_backprop_mean_one_by_depth =
482         y_backprop_sum_one_by_depth * rest_size_inv;
483     auto y_backprop_mean_rest_by_depth =
484         y_backprop_mean_one_by_depth.broadcast(bcast_spec);
485     auto y_backprop_centered =
486         y_backprop_rest_by_depth - y_backprop_mean_rest_by_depth;
487 
488     // Compute expression:
489     //   y_backprop_centered_mean =
490     //     (y_backprop_rest_by_depth * x_centered).mean(reduce_dims)
491     scratch_tensor.device(d) = y_backprop_rest_by_depth * x_centered;
492     redux_sum_u(d, rest_by_depth, scratch_rest_by_depth, &scratch_one_by_depth);
493     auto y_backprop_centered_mean =
494         scratch_vector.reshape(one_by_depth) / static_cast<U>(rest_size);
495 
496     auto coef1 = (scale.reshape(one_by_depth) * coef0_one_by_depth)
497                      .broadcast(bcast_spec);
498     auto coef2 = (coef0_one_by_depth.square() * y_backprop_centered_mean)
499                      .broadcast(bcast_spec);
500 
501     x_backprop.reshape(rest_by_depth).device(d) =
502         (coef1 * (y_backprop_centered - x_centered * coef2)).template cast<T>();
503 
504     if (tensor_format == FORMAT_NCHW) {
505       // Perform NHWC to NCHW
506       std::vector<int32> perm = {0, 3, 1, 2};
507       OP_REQUIRES_OK(
508           context, ::tensorflow::DoTranspose(context->eigen_device<CPUDevice>(),
509                                              transformed_x_backprop_output,
510                                              perm, x_backprop_output));
511     }
512   }
513 };
514 
515 template <typename T, typename U>
516 struct FusedBatchNormFreezeGrad<CPUDevice, T, U> {
operator ()tensorflow::functor::FusedBatchNormFreezeGrad517   void operator()(OpKernelContext* context, const Tensor& y_backprop_input,
518                   const Tensor& x_input, const Tensor& scale_input,
519                   const Tensor& pop_mean_input,
520                   const Tensor& pop_variance_input, U epsilon,
521                   Tensor* x_backprop_output, Tensor* scale_backprop_output,
522                   Tensor* offset_backprop_output) {
523     typename TTypes<T, 4>::ConstTensor y_backprop(
524         y_backprop_input.tensor<T, 4>());
525     typename TTypes<T, 4>::ConstTensor input(x_input.tensor<T, 4>());
526     typename TTypes<U>::ConstVec scale(scale_input.vec<U>());
527     typename TTypes<U>::ConstVec pop_mean(pop_mean_input.vec<U>());
528     typename TTypes<U>::ConstVec pop_var(pop_variance_input.vec<U>());
529     typename TTypes<T, 4>::Tensor x_backprop(x_backprop_output->tensor<T, 4>());
530     typename TTypes<U>::Vec scale_backprop(scale_backprop_output->vec<U>());
531 
532     const int depth = pop_mean.dimension(0);
533     const int rest_size = input.size() / depth;
534 
535     const CPUDevice& d = context->eigen_device<CPUDevice>();
536 
537     // Allocate two temporary workspaces of [depth] shape.
538     Tensor scratch1_vec, scratch2_vec;
539     OP_REQUIRES_OK(context, context->allocate_temp(DataTypeToEnum<U>::value,
540                                                    {depth}, &scratch1_vec));
541     OP_REQUIRES_OK(context, context->allocate_temp(DataTypeToEnum<U>::value,
542                                                    {depth}, &scratch2_vec));
543 
544     // Maybe allocate a temporary workspace of [rest_size, depth] shape.
545     Tensor scratch3_tensor;
546     if (std::is_same<T, U>::value) {
547       OP_REQUIRES(
548           context,
549           scratch3_tensor.CopyFrom(*x_backprop_output, {rest_size, depth}),
550           errors::Internal("Failed to copy a tensor"));
551     } else {
552       OP_REQUIRES_OK(context, context->allocate_temp(DataTypeToEnum<U>::value,
553                                                      {rest_size, depth},
554                                                      &scratch3_tensor));
555     }
556 
557     typename TTypes<U>::Vec scratch1(scratch1_vec.vec<U>());
558     typename TTypes<U>::Vec scratch2(scratch2_vec.vec<U>());
559     typename TTypes<U, 2>::Tensor scratch3(scratch3_tensor.tensor<U, 2>());
560 
561     Eigen::DSizes<Eigen::Index, 2> rest_by_depth(rest_size, depth);
562 #if !defined(EIGEN_HAS_INDEX_LIST)
563     Eigen::DSizes<Eigen::Index, 2> one_by_depth(1, depth);
564     Eigen::array<int, 2> rest_by_one({rest_size, 1});
565 #else
566     Eigen::IndexList<Eigen::type2index<1>, Eigen::Index> one_by_depth;
567     one_by_depth.set(1, depth);
568     Eigen::IndexList<Eigen::Index, Eigen::type2index<1>> rest_by_one;
569     rest_by_one.set(0, rest_size);
570 #endif
571 
572     // Sum reduction along the 0th dimension using custom CPU functor.
573     using ScalarSum = Eigen::internal::scalar_sum_op<U>;
574     const functor::ReduceOuterDimensions<T, U, U, ScalarSum> redux_sum_t;
575     const functor::ReduceOuterDimensions<U, U, U, ScalarSum> redux_sum_u;
576 
577     // offset_backprop  = sum(y_backprop)
578     // scale_backprop = y_backprop * ((x - pop_mean) * rsqrt(pop_var + epsilon))
579     // x_backprop = y_backprop * (scale * rsqrt(pop_var + epsilon))
580 
581     // NOTE: DEFAULT DEVICE comment is added to expression assignments that
582     // we don't want to be executed in a thread pool.
583 
584     auto y_backprop_rest_by_depth =
585         y_backprop.reshape(rest_by_depth).template cast<U>();
586     auto input_rest_by_depth = input.reshape(rest_by_depth).template cast<U>();
587 
588     // offset_backprop  = sum(y_backprop)
589     redux_sum_t(d, rest_by_depth, y_backprop_input, offset_backprop_output);
590 
591     // scratch1 = rsqrt(pop_var + epsilon)
592     scratch1 = (pop_var + pop_var.constant(epsilon)).rsqrt();  // DEFAULT DEVICE
593 
594     // scratch2 = sum(y_backprop * (x - mean))
595     scratch3.device(d) =
596         y_backprop_rest_by_depth *
597         (input_rest_by_depth -
598          pop_mean.reshape(one_by_depth).broadcast(rest_by_one));
599     redux_sum_u(d, rest_by_depth, scratch3_tensor, &scratch2_vec);
600 
601     x_backprop.reshape(rest_by_depth).device(d) =
602         (y_backprop_rest_by_depth *
603          ((scratch1.reshape(one_by_depth) * scale.reshape(one_by_depth))
604               .broadcast(rest_by_one)))
605             .template cast<T>();
606     scale_backprop = scratch2 * scratch1;  // DEFAULT DEVICE
607   }
608 };
609 
610 #if !GOOGLE_CUDA
611 namespace {
612 // See implementation under GOOGLE_CUDA #ifdef below.
613 // This is a CUDA specific feature, do not enable it for non-CUDA builds
BatchnormSpatialPersistentEnabled()614 bool BatchnormSpatialPersistentEnabled() { return false; }
615 }  // namespace
616 #endif
617 
618 #if GOOGLE_CUDA || TENSORFLOW_USE_ROCM
619 
620 namespace {
621 
AsDnnActivationMode(const FusedBatchNormActivationMode activation_mode)622 se::dnn::ActivationMode AsDnnActivationMode(
623     const FusedBatchNormActivationMode activation_mode) {
624   switch (activation_mode) {
625     case FusedBatchNormActivationMode::kIdentity:
626       return se::dnn::ActivationMode::kNone;
627     case FusedBatchNormActivationMode::kRelu:
628       return se::dnn::ActivationMode::kRelu;
629   }
630 }
631 
632 #if GOOGLE_CUDA
633 // NOTE(ezhulenev): See `BatchnormSpatialPersistentEnabled` documentation in the
634 // `cuda_dnn.cc` for details.
BatchnormSpatialPersistentEnabled()635 bool BatchnormSpatialPersistentEnabled() {
636 #if CUDNN_VERSION >= 7402
637   static bool is_enabled = [] {
638     bool is_enabled = false;
639     TF_CHECK_OK(tensorflow::ReadBoolFromEnvVar(
640         "TF_USE_CUDNN_BATCHNORM_SPATIAL_PERSISTENT",
641         /*default_val=*/false, &is_enabled));
642     return is_enabled;
643   }();
644   return is_enabled;
645 #else
646   return false;
647 #endif
648 }
649 #endif
650 
651 }  // namespace
652 
653 template <typename U, typename T>
CastDeviceMemory(Tensor * tensor)654 DeviceMemory<U> CastDeviceMemory(Tensor* tensor) {
655   return DeviceMemory<U>::MakeFromByteSize(
656       tensor->template flat<T>().data(),
657       tensor->template flat<T>().size() * sizeof(T));
658 }
659 
660 // A helper to allocate temporary scratch memory for Cudnn BatchNormEx ops. It
661 // takes the ownership of the underlying memory. The expectation is that the
662 // memory should be alive for the span of the Cudnn BatchNormEx itself.
663 template <typename T>
664 class CudnnBatchNormAllocatorInTemp : public ScratchAllocator {
665  public:
666   ~CudnnBatchNormAllocatorInTemp() override = default;
667 
CudnnBatchNormAllocatorInTemp(OpKernelContext * context)668   explicit CudnnBatchNormAllocatorInTemp(OpKernelContext* context)
669       : context_(context) {}
670 
GetMemoryLimitInBytes()671   int64 GetMemoryLimitInBytes() override {
672     return std::numeric_limits<int64>::max();
673   }
674 
AllocateBytes(int64 byte_size)675   StatusOr<DeviceMemory<uint8>> AllocateBytes(int64 byte_size) override {
676     Tensor temporary_memory;
677     const DataType tf_data_type = DataTypeToEnum<T>::v();
678     int64 allocate_count =
679         Eigen::divup(byte_size, static_cast<int64>(sizeof(T)));
680     Status allocation_status(context_->allocate_temp(
681         tf_data_type, TensorShape({allocate_count}), &temporary_memory));
682     if (!allocation_status.ok()) {
683       return allocation_status;
684     }
685     // Hold the reference of the allocated tensors until the end of the
686     // allocator.
687     allocated_tensors_.push_back(temporary_memory);
688     total_byte_size_ += byte_size;
689     return DeviceMemory<uint8>::MakeFromByteSize(
690         temporary_memory.template flat<T>().data(),
691         temporary_memory.template flat<T>().size() * sizeof(T));
692   }
693 
TotalByteSize() const694   int64 TotalByteSize() const { return total_byte_size_; }
695 
get_allocated_tensor(int index) const696   Tensor get_allocated_tensor(int index) const {
697     return allocated_tensors_[index];
698   }
699 
700  private:
701   int64 total_byte_size_ = 0;
702   OpKernelContext* context_;  // not owned
703   std::vector<Tensor> allocated_tensors_;
704 };
705 
706 // A helper to allocate memory for Cudnn BatchNormEx as a kernel output. It is
707 // used by forward pass kernel to feed the output to the backward pass.
708 // The memory is expected to live long enough after the backward pass is
709 // finished.
710 template <typename T>
711 class CudnnBatchNormAllocatorInOutput : public ScratchAllocator {
712  public:
~CudnnBatchNormAllocatorInOutput()713   ~CudnnBatchNormAllocatorInOutput() override {
714     if (!output_allocated) {
715       Tensor* dummy_reserve_space = nullptr;
716       OP_REQUIRES_OK(context_, context_->allocate_output(output_index_, {},
717                                                          &dummy_reserve_space));
718     }
719   }
720 
CudnnBatchNormAllocatorInOutput(OpKernelContext * context,int output_index)721   CudnnBatchNormAllocatorInOutput(OpKernelContext* context, int output_index)
722       : context_(context), output_index_(output_index) {}
723 
GetMemoryLimitInBytes()724   int64 GetMemoryLimitInBytes() override {
725     return std::numeric_limits<int64>::max();
726   }
727 
AllocateBytes(int64 byte_size)728   StatusOr<DeviceMemory<uint8>> AllocateBytes(int64 byte_size) override {
729     output_allocated = true;
730     DCHECK(total_byte_size_ == 0)
731         << "Reserve space allocator can only be called once";
732     int64 allocate_count =
733         Eigen::divup(byte_size, static_cast<int64>(sizeof(T)));
734 
735     Tensor* temporary_memory = nullptr;
736     Status allocation_status(context_->allocate_output(
737         output_index_, TensorShape({allocate_count}), &temporary_memory));
738     if (!allocation_status.ok()) {
739       return allocation_status;
740     }
741     total_byte_size_ += byte_size;
742     auto memory_uint8 = DeviceMemory<uint8>::MakeFromByteSize(
743         temporary_memory->template flat<T>().data(),
744         temporary_memory->template flat<T>().size() * sizeof(T));
745     return StatusOr<DeviceMemory<uint8>>(memory_uint8);
746   }
747 
TotalByteSize()748   int64 TotalByteSize() { return total_byte_size_; }
749 
750  private:
751   int64 total_byte_size_ = 0;
752   OpKernelContext* context_;  // not owned
753   int output_index_;
754   bool output_allocated = false;
755 };
756 
757 template <typename T, typename U, bool is_training>
758 struct FusedBatchNorm<GPUDevice, T, U, is_training> {
operator ()tensorflow::functor::FusedBatchNorm759   void operator()(OpKernelContext* context, const Tensor& x,
760                   const Tensor& scale, const Tensor& offset,
761                   const Tensor& estimated_mean,
762                   const Tensor& estimated_variance, const Tensor* side_input,
763                   U epsilon, U exponential_avg_factor,
764                   FusedBatchNormActivationMode activation_mode, Tensor* y,
765                   Tensor* batch_mean, Tensor* batch_var, Tensor* saved_mean,
766                   Tensor* saved_inv_var, TensorFormat tensor_format,
767                   bool use_reserved_space) {
768     auto* stream = context->op_device_context()->stream();
769     OP_REQUIRES(context, stream, errors::Internal("No GPU stream available"));
770 
771     const int64 batch_size = GetTensorDim(x, tensor_format, 'N');
772     const int64 channels = GetTensorDim(x, tensor_format, 'C');
773     const int64 height = GetTensorDim(x, tensor_format, 'H');
774     const int64 width = GetTensorDim(x, tensor_format, 'W');
775 
776     // If use_reserved_space we have reserve_space_3 output (only in
777     // FusedBatchNormV3 op).
778 
779 #if GOOGLE_CUDA
780     // Check if cuDNN batch normalization has a fast NHWC implementation:
781     //   (1) In inference mode it's always fast.
782     //   (2) Tensorflow enabled batchnorm spatial persistence, we are called
783     //   from
784     //       FusedBatchNormV3, i.e. use_reserved_space is true.
785     const bool fast_nhwc_batch_norm =
786         !is_training ||
787         (BatchnormSpatialPersistentEnabled() &&
788          DataTypeToEnum<T>::value == DT_HALF && use_reserved_space);
789 #else
790     // fast NHWC implementation is a CUDA only feature
791     const bool fast_nhwc_batch_norm = false;
792 #endif
793 
794     // If input tensor is in NHWC format, and we have a fast cuDNN
795     // implementation, there is no need to do data format conversion.
796     TensorFormat compute_format =
797         fast_nhwc_batch_norm && tensor_format == FORMAT_NHWC ? FORMAT_NHWC
798                                                              : FORMAT_NCHW;
799 
800     VLOG(2) << "FusedBatchNorm:"
801             << " batch_size: " << batch_size << " channels: " << channels
802             << " height: " << height << " width:" << width
803             << " x shape: " << x.shape().DebugString()
804             << " scale shape: " << scale.shape().DebugString()
805             << " offset shape: " << offset.shape().DebugString()
806             << " activation mode: " << ToString(activation_mode)
807             << " tensor format: " << ToString(tensor_format)
808             << " compute format: " << ToString(compute_format);
809 
810     auto maybe_make_dummy_output = [context, use_reserved_space]() -> Status {
811       if (use_reserved_space) {
812         Tensor* dummy_reserve_space = nullptr;
813         return context->allocate_output(5, {}, &dummy_reserve_space);
814       }
815       return Status::OK();
816     };
817 
818     // If input is empty, return NaN mean/variance
819     if (x.shape().num_elements() == 0) {
820       OP_REQUIRES_OK(context, maybe_make_dummy_output());
821       functor::SetNanFunctor<U> f;
822       f(context->eigen_device<GPUDevice>(), batch_mean->flat<U>());
823       f(context->eigen_device<GPUDevice>(), batch_var->flat<U>());
824       return;
825     }
826 
827     // In inference mode we use custom CUDA kernel, because cuDNN does not
828     // support side input and activations for inference.
829     const bool has_side_input = side_input != nullptr;
830     const bool has_activation =
831         activation_mode != FusedBatchNormActivationMode::kIdentity;
832 
833     if (!is_training && (has_side_input || has_activation)) {
834       OP_REQUIRES_OK(context, maybe_make_dummy_output());
835       FusedBatchNormInferenceFunctor<GPUDevice, T, U> inference_functor;
836 
837       if (has_side_input) {
838         inference_functor(context, tensor_format, x.tensor<T, 4>(),
839                           scale.vec<U>(), offset.vec<U>(),
840                           estimated_mean.vec<U>(), estimated_variance.vec<U>(),
841                           side_input->tensor<T, 4>(), epsilon, activation_mode,
842                           y->tensor<T, 4>());
843       } else {
844         typename TTypes<T, 4>::ConstTensor empty_tensor(nullptr, 0, 0, 0, 0);
845         inference_functor(context, tensor_format, x.tensor<T, 4>(),
846                           scale.vec<U>(), offset.vec<U>(),
847                           estimated_mean.vec<U>(), estimated_variance.vec<U>(),
848                           empty_tensor, epsilon, activation_mode,
849                           y->tensor<T, 4>());
850       }
851       return;
852     }
853 
854     Tensor x_maybe_transformed = x;
855     Tensor x_transformed;
856     Tensor y_transformed;
857     se::DeviceMemory<T> y_ptr;
858 
859     if (tensor_format == compute_format) {
860       y_ptr = StreamExecutorUtil::AsDeviceMemory<T>(*y);
861     } else if (tensor_format == FORMAT_NHWC && compute_format == FORMAT_NCHW) {
862       OP_REQUIRES_OK(context, context->allocate_temp(
863                                   DataTypeToEnum<T>::value,
864                                   ShapeFromFormat(compute_format, batch_size,
865                                                   height, width, channels),
866                                   &x_transformed));
867       functor::NHWCToNCHW<GPUDevice, T, 4>()(
868           context->eigen_device<GPUDevice>(),
869           const_cast<const Tensor&>(x_maybe_transformed).tensor<T, 4>(),
870           x_transformed.tensor<T, 4>());
871       x_maybe_transformed = x_transformed;
872 
873       OP_REQUIRES_OK(context, context->allocate_temp(
874                                   DataTypeToEnum<T>::value,
875                                   ShapeFromFormat(compute_format, batch_size,
876                                                   height, width, channels),
877                                   &y_transformed));
878       y_ptr = StreamExecutorUtil::AsDeviceMemory<T>(y_transformed);
879     } else {
880       context->SetStatus(errors::Internal(
881           "Unsupported tensor format: ", ToString(tensor_format),
882           " and compute format: ", ToString(compute_format)));
883       return;
884     }
885 
886     const se::dnn::DataLayout data_layout =
887         compute_format == FORMAT_NHWC ? se::dnn::DataLayout::kBatchYXDepth
888                                       : se::dnn::DataLayout::kBatchDepthYX;
889 
890     se::dnn::BatchDescriptor x_desc;
891     x_desc.set_count(batch_size)
892         .set_feature_map_count(channels)
893         .set_height(height)
894         .set_width(width)
895         .set_layout(data_layout);
896 
897     se::dnn::BatchDescriptor scale_offset_desc;
898     scale_offset_desc.set_count(1)
899         .set_feature_map_count(channels)
900         .set_height(1)
901         .set_width(1)
902         .set_layout(se::dnn::DataLayout::kBatchDepthYX);
903 
904     auto x_ptr = StreamExecutorUtil::AsDeviceMemory<T>(x_maybe_transformed);
905     auto scale_ptr = StreamExecutorUtil::AsDeviceMemory<U>(scale);
906     auto offset_ptr = StreamExecutorUtil::AsDeviceMemory<U>(offset);
907     auto estimated_mean_ptr =
908         StreamExecutorUtil::AsDeviceMemory<U>(estimated_mean);
909     auto estimated_variance_ptr =
910         StreamExecutorUtil::AsDeviceMemory<U>(estimated_variance);
911     auto side_input_ptr =
912         side_input != nullptr
913             ? StreamExecutorUtil::AsDeviceMemory<U>(*side_input)
914             : se::DeviceMemory<U>();
915     auto batch_mean_ptr = StreamExecutorUtil::AsDeviceMemory<U>(*batch_mean);
916 
917     auto batch_var_ptr = StreamExecutorUtil::AsDeviceMemory<U>(*batch_var);
918     auto saved_mean_ptr = StreamExecutorUtil::AsDeviceMemory<U>(*saved_mean);
919     auto saved_inv_var_ptr =
920         StreamExecutorUtil::AsDeviceMemory<U>(*saved_inv_var);
921 
922     std::unique_ptr<functor::CudnnBatchNormAllocatorInOutput<U>>
923         reserve_space_allocator;
924     std::unique_ptr<functor::CudnnBatchNormAllocatorInTemp<uint8>>
925         workspace_allocator;
926     if (use_reserved_space) {
927       reserve_space_allocator.reset(
928           new functor::CudnnBatchNormAllocatorInOutput<U>(context, 5));
929       workspace_allocator.reset(
930           new functor::CudnnBatchNormAllocatorInTemp<uint8>(context));
931     }
932     if (!batch_mean->SharesBufferWith(estimated_mean) &&
933         exponential_avg_factor != 1.0f) {
934       OP_REQUIRES(
935           context,
936           stream
937               ->ThenMemcpyD2D(&batch_mean_ptr, estimated_mean_ptr,
938                               estimated_mean.NumElements() * sizeof(U))
939               .ok(),
940           errors::Internal("MatrixTriangularSolveOp: failed to copy rhs "
941                            "from device"));
942     }
943     if (!batch_var->SharesBufferWith(estimated_variance) &&
944         exponential_avg_factor != 1.0f) {
945       OP_REQUIRES(
946           context,
947           stream
948               ->ThenMemcpyD2D(&batch_var_ptr, estimated_variance_ptr,
949                               estimated_variance.NumElements() * sizeof(U))
950               .ok(),
951           errors::Internal("MatrixTriangularSolveOp: failed to copy rhs "
952                            "from device"));
953     }
954     bool cudnn_launch_status =
955         stream
956             ->ThenBatchNormalizationForward(
957                 x_ptr, scale_ptr, offset_ptr, estimated_mean_ptr,
958                 estimated_variance_ptr, side_input_ptr, x_desc,
959                 scale_offset_desc, static_cast<double>(epsilon),
960                 static_cast<double>(exponential_avg_factor),
961                 AsDnnActivationMode(activation_mode), &y_ptr, &batch_mean_ptr,
962                 &batch_var_ptr, &saved_mean_ptr, &saved_inv_var_ptr,
963                 is_training, reserve_space_allocator.get(),
964                 workspace_allocator.get())
965             .ok();
966 
967     if (!cudnn_launch_status) {
968       context->SetStatus(
969           errors::Internal("cuDNN launch failure : input shape (",
970                            x.shape().DebugString(), ")"));
971       return;
972     }
973 
974     if (tensor_format == FORMAT_NHWC && compute_format == FORMAT_NCHW) {
975       functor::NCHWToNHWC<GPUDevice, T, 4>()(
976           context->eigen_device<GPUDevice>(),
977           const_cast<const Tensor&>(y_transformed).tensor<T, 4>(),
978           y->tensor<T, 4>());
979     }
980   }
981 };
982 
983 template <typename T, typename U>
984 struct FusedBatchNormGrad<GPUDevice, T, U> {
operator ()tensorflow::functor::FusedBatchNormGrad985   void operator()(OpKernelContext* context, const Tensor& y_backprop,
986                   const Tensor& x, const Tensor& scale, const Tensor& mean,
987                   const Tensor& inv_variance, U epsilon, Tensor* x_backprop,
988                   Tensor* scale_backprop, Tensor* offset_backprop,
989                   bool use_reserved_space, TensorFormat tensor_format) {
990     auto* stream = context->op_device_context()->stream();
991     OP_REQUIRES(context, stream, errors::Internal("No GPU stream available"));
992 
993     const int64 batch_size = GetTensorDim(x, tensor_format, 'N');
994     const int64 channels = GetTensorDim(x, tensor_format, 'C');
995     const int64 height = GetTensorDim(x, tensor_format, 'H');
996     const int64 width = GetTensorDim(x, tensor_format, 'W');
997 
998 #if GOOGLE_CUDA
999     // Check if cuDNN batch normalization has a fast NHWC implementation:
1000     //   (1) Tensorflow enabled batchnorm spatial persistence, and
1001     //       FusedBatchNormGradV3 passed non-null reserve space and allocator.
1002     const bool fast_nhwc_batch_norm = BatchnormSpatialPersistentEnabled() &&
1003                                       DataTypeToEnum<T>::value == DT_HALF &&
1004                                       use_reserved_space;
1005 #else
1006     // fast NHWC implementation is a CUDA only feature
1007     const bool fast_nhwc_batch_norm = false;
1008 #endif
1009 
1010     // If input tensor is in NHWC format, and we have a fast cuDNN
1011     // implementation, there is no need to do data format conversion.
1012     TensorFormat compute_format =
1013         fast_nhwc_batch_norm && tensor_format == FORMAT_NHWC ? FORMAT_NHWC
1014                                                              : FORMAT_NCHW;
1015 
1016     VLOG(2) << "FusedBatchNormGrad:"
1017             << " batch_size: " << batch_size << " channels: " << channels
1018             << " height: " << height << " width: " << width
1019             << " y_backprop shape: " << y_backprop.shape().DebugString()
1020             << " x shape: " << x.shape().DebugString()
1021             << " scale shape: " << scale.shape().DebugString()
1022             << " tensor format: " << ToString(tensor_format)
1023             << " compute format: " << ToString(compute_format);
1024 
1025     // Inputs
1026     Tensor y_backprop_maybe_transformed = y_backprop;
1027     Tensor x_maybe_transformed = x;
1028     Tensor y_backprop_transformed;
1029     Tensor x_transformed;
1030 
1031     // Outputs
1032     Tensor x_backprop_transformed;
1033     se::DeviceMemory<T> x_backprop_ptr;
1034 
1035     if (tensor_format == compute_format) {
1036       x_backprop_ptr = StreamExecutorUtil::AsDeviceMemory<T>(*x_backprop);
1037     } else if (tensor_format == FORMAT_NHWC && compute_format == FORMAT_NCHW) {
1038       // Transform inputs from 'NHWC' to 'NCHW'
1039       OP_REQUIRES_OK(context, context->allocate_temp(
1040                                   DataTypeToEnum<T>::value,
1041                                   ShapeFromFormat(FORMAT_NCHW, batch_size,
1042                                                   height, width, channels),
1043                                   &y_backprop_transformed));
1044       functor::NHWCToNCHW<GPUDevice, T, 4>()(
1045           context->eigen_device<GPUDevice>(),
1046           const_cast<const Tensor&>(y_backprop_maybe_transformed)
1047               .tensor<T, 4>(),
1048           y_backprop_transformed.tensor<T, 4>());
1049       y_backprop_maybe_transformed = y_backprop_transformed;
1050 
1051       OP_REQUIRES_OK(context, context->allocate_temp(
1052                                   DataTypeToEnum<T>::value,
1053                                   ShapeFromFormat(FORMAT_NCHW, batch_size,
1054                                                   height, width, channels),
1055                                   &x_transformed));
1056       functor::NHWCToNCHW<GPUDevice, T, 4>()(
1057           context->eigen_device<GPUDevice>(),
1058           const_cast<const Tensor&>(x_maybe_transformed).tensor<T, 4>(),
1059           x_transformed.tensor<T, 4>());
1060       x_maybe_transformed = x_transformed;
1061 
1062       // Allocate memory for transformed outputs in 'NCHW'
1063       OP_REQUIRES_OK(context, context->allocate_temp(
1064                                   DataTypeToEnum<T>::value,
1065                                   ShapeFromFormat(FORMAT_NCHW, batch_size,
1066                                                   height, width, channels),
1067                                   &x_backprop_transformed));
1068       x_backprop_ptr =
1069           StreamExecutorUtil::AsDeviceMemory<T>(x_backprop_transformed);
1070     } else {
1071       context->SetStatus(errors::Internal(
1072           "Unsupported tensor format: ", ToString(tensor_format),
1073           " and compute format: ", ToString(compute_format)));
1074       return;
1075     }
1076 
1077     const se::dnn::DataLayout data_layout =
1078         compute_format == FORMAT_NHWC ? se::dnn::DataLayout::kBatchYXDepth
1079                                       : se::dnn::DataLayout::kBatchDepthYX;
1080 
1081     se::dnn::BatchDescriptor x_desc;
1082     x_desc.set_count(batch_size)
1083         .set_feature_map_count(channels)
1084         .set_height(height)
1085         .set_width(width)
1086         .set_layout(data_layout);
1087 
1088     se::dnn::BatchDescriptor scale_offset_desc;
1089     scale_offset_desc.set_count(1)
1090         .set_feature_map_count(channels)
1091         .set_height(1)
1092         .set_width(1)
1093         .set_layout(se::dnn::DataLayout::kBatchDepthYX);
1094 
1095     auto y_backprop_ptr =
1096         StreamExecutorUtil::AsDeviceMemory<T>(y_backprop_maybe_transformed);
1097     auto x_ptr = StreamExecutorUtil::AsDeviceMemory<T>(x_maybe_transformed);
1098     auto scale_ptr = StreamExecutorUtil::AsDeviceMemory<U>(scale);
1099     auto mean_ptr = StreamExecutorUtil::AsDeviceMemory<U>(mean);
1100     auto inv_variance_ptr = StreamExecutorUtil::AsDeviceMemory<U>(inv_variance);
1101     auto scale_backprop_ptr =
1102         StreamExecutorUtil::AsDeviceMemory<U>(*scale_backprop);
1103     auto offset_backprop_ptr =
1104         StreamExecutorUtil::AsDeviceMemory<U>(*offset_backprop);
1105 
1106     std::unique_ptr<functor::CudnnBatchNormAllocatorInTemp<uint8>>
1107         workspace_allocator;
1108     DeviceMemory<uint8>* reserve_space_data_ptr = nullptr;
1109     DeviceMemory<uint8> reserve_space_data;
1110 #if CUDNN_VERSION >= 7402
1111     if (use_reserved_space) {
1112       const Tensor& reserve_space = context->input(5);
1113       workspace_allocator.reset(
1114           new functor::CudnnBatchNormAllocatorInTemp<uint8>(context));
1115 
1116       // the cudnn kernel outputs inverse variance in forward and reuse it in
1117       // backward
1118       if (reserve_space.dims() != 0) {
1119         reserve_space_data = functor::CastDeviceMemory<uint8, U>(
1120             const_cast<Tensor*>(&reserve_space));
1121         reserve_space_data_ptr = &reserve_space_data;
1122       }
1123     }
1124 #endif  // CUDNN_VERSION >= 7402
1125 
1126     bool cudnn_launch_status =
1127         stream
1128             ->ThenBatchNormalizationBackward(
1129                 y_backprop_ptr, x_ptr, scale_ptr, mean_ptr, inv_variance_ptr,
1130                 x_desc, scale_offset_desc, static_cast<double>(epsilon),
1131                 &x_backprop_ptr, &scale_backprop_ptr, &offset_backprop_ptr,
1132                 reserve_space_data_ptr, workspace_allocator.get())
1133             .ok();
1134 
1135     if (!cudnn_launch_status) {
1136       context->SetStatus(
1137           errors::Internal("cuDNN launch failure : input shape (",
1138                            x.shape().DebugString(), ")"));
1139     }
1140     if (tensor_format == FORMAT_NHWC && compute_format == FORMAT_NCHW) {
1141       functor::NCHWToNHWC<GPUDevice, T, 4>()(
1142           context->eigen_device<GPUDevice>(),
1143           const_cast<const Tensor&>(x_backprop_transformed).tensor<T, 4>(),
1144           x_backprop->tensor<T, 4>());
1145     }
1146   }
1147 };
1148 
1149 // Forward declarations of the functor specializations for GPU.
1150 #define DECLARE_GPU_SPEC(T, U)                                                 \
1151   template <>                                                                  \
1152   void FusedBatchNormFreezeGrad<GPUDevice, T, U>::operator()(                  \
1153       OpKernelContext* context, const Tensor& y_backprop_input,                \
1154       const Tensor& x_input, const Tensor& scale_input,                        \
1155       const Tensor& mean_input, const Tensor& variance_input, U epsilon,       \
1156       Tensor* x_backprop_output, Tensor* scale_backprop_output,                \
1157       Tensor* offset_backprop_output);                                         \
1158   extern template struct FusedBatchNormFreezeGrad<GPUDevice, T, U>;            \
1159   template <>                                                                  \
1160   void FusedBatchNormInferenceFunctor<GPUDevice, T, U>::operator()(            \
1161       OpKernelContext* context, TensorFormat tensor_format,                    \
1162       typename TTypes<T, 4>::ConstTensor in,                                   \
1163       typename TTypes<U>::ConstVec scale, typename TTypes<U>::ConstVec offset, \
1164       typename TTypes<U>::ConstVec estimated_mean,                             \
1165       typename TTypes<U>::ConstVec estimated_variance,                         \
1166       typename TTypes<T, 4>::ConstTensor side_input, U epsilon,                \
1167       FusedBatchNormActivationMode activation_mode,                            \
1168       typename TTypes<T, 4>::Tensor out);                                      \
1169   extern template struct FusedBatchNormInferenceFunctor<GPUDevice, T, U>;
1170 
1171 DECLARE_GPU_SPEC(float, float);
1172 DECLARE_GPU_SPEC(Eigen::half, float);
1173 
1174 #undef DECLARE_GPU_SPEC
1175 
1176 #endif  // GOOGLE_CUDA || TENSORFLOW_USE_ROCM
1177 }  // namespace functor
1178 
1179 template <typename Device, typename T, typename U>
1180 class FusedBatchNormOpBase : public OpKernel {
1181   using FbnActivationMode = functor::FusedBatchNormActivationMode;
1182 
1183  protected:
FusedBatchNormOpBase(OpKernelConstruction * context,bool is_batch_norm_ex=false)1184   explicit FusedBatchNormOpBase(OpKernelConstruction* context,
1185                                 bool is_batch_norm_ex = false)
1186       : OpKernel(context) {
1187     float epsilon;
1188     OP_REQUIRES_OK(context, context->GetAttr("epsilon", &epsilon));
1189     epsilon_ = U(epsilon);
1190     float exponential_avg_factor;
1191     OP_REQUIRES_OK(context, context->GetAttr("exponential_avg_factor",
1192                                              &exponential_avg_factor));
1193     exponential_avg_factor_ = U(exponential_avg_factor);
1194     string tensor_format;
1195     OP_REQUIRES_OK(context, context->GetAttr("data_format", &tensor_format));
1196     OP_REQUIRES(context, FormatFromString(tensor_format, &tensor_format_),
1197                 errors::InvalidArgument("Invalid data format"));
1198     OP_REQUIRES_OK(context, context->GetAttr("is_training", &is_training_));
1199 
1200     if (!is_batch_norm_ex) {
1201       has_side_input_ = false;
1202       activation_mode_ = FbnActivationMode::kIdentity;
1203     } else {
1204       OP_REQUIRES_OK(context, ParseActivationMode(context, &activation_mode_));
1205 
1206       int num_side_inputs;
1207       OP_REQUIRES_OK(context,
1208                      context->GetAttr("num_side_inputs", &num_side_inputs));
1209       OP_REQUIRES(context, num_side_inputs >= 0 && num_side_inputs <= 1,
1210                   errors::InvalidArgument(
1211                       "FusedBatchNorm accepts at most one side input."));
1212       has_side_input_ = (num_side_inputs == 1);
1213       if (has_side_input_ && is_training_) {
1214         OP_REQUIRES(
1215             context, activation_mode_ != FbnActivationMode::kIdentity,
1216             errors::InvalidArgument("Identity activation is not supported with "
1217                                     "non-empty side input"));
1218       }
1219     }
1220 
1221     if (activation_mode_ != FbnActivationMode::kIdentity && is_training_) {
1222       // NOTE(ezhulenev): Following requirements are coming from implementation
1223       // details of cudnnBatchNormalizationForwardTrainingEx used in training
1224       // mode. In inference mode we call custom CUDA kernel that supports all
1225       // data formats and data types.
1226       OP_REQUIRES(context, DataTypeToEnum<T>::value == DT_HALF,
1227                   errors::InvalidArgument("FusedBatchNorm with activation "
1228                                           "supports only DT_HALF data type."));
1229       OP_REQUIRES(context, tensor_format_ == FORMAT_NHWC,
1230                   errors::InvalidArgument("FusedBatchNorm with activation "
1231                                           "supports only NHWC tensor format."));
1232       OP_REQUIRES(context, functor::BatchnormSpatialPersistentEnabled(),
1233                   errors::InvalidArgument(
1234                       "FusedBatchNorm with activation must run with cuDNN "
1235                       "spatial persistence mode enabled."));
1236     }
1237   }
1238 
1239   // If use_reserved_space is true, we need to handle the 5th output (a reserved
1240   // space) and a new cudnn batch norm will be called if the version > 7.4.2.
1241   // If use_reserved_space is false, we don't have 5th output.
ComputeWithReservedSpace(OpKernelContext * context,bool use_reserved_space)1242   virtual void ComputeWithReservedSpace(OpKernelContext* context,
1243                                         bool use_reserved_space) {
1244     Tensor x = context->input(0);
1245     const Tensor& scale = context->input(1);
1246     const Tensor& offset = context->input(2);
1247     const Tensor& estimated_mean = context->input(3);
1248     const Tensor& estimated_variance = context->input(4);
1249     const Tensor* side_input = has_side_input_ ? &context->input(5) : nullptr;
1250 
1251     OP_REQUIRES(context, x.dims() == 4 || x.dims() == 5,
1252                 errors::InvalidArgument("input must be 4 or 5-dimensional",
1253                                         x.shape().DebugString()));
1254     OP_REQUIRES(context, scale.dims() == 1,
1255                 errors::InvalidArgument("scale must be 1-dimensional",
1256                                         scale.shape().DebugString()));
1257     OP_REQUIRES(context, offset.dims() == 1,
1258                 errors::InvalidArgument("offset must be 1-dimensional",
1259                                         offset.shape().DebugString()));
1260     OP_REQUIRES(context, estimated_mean.dims() == 1,
1261                 errors::InvalidArgument("estimated_mean must be 1-dimensional",
1262                                         estimated_mean.shape().DebugString()));
1263     OP_REQUIRES(
1264         context, estimated_variance.dims() == 1,
1265         errors::InvalidArgument("estimated_variance must be 1-dimensional",
1266                                 estimated_variance.shape().DebugString()));
1267     bool use_reshape = (x.dims() == 5);
1268     auto x_shape = x.shape();
1269     TensorShape dest_shape;
1270     if (use_reshape) {
1271       const int64 in_batch = GetTensorDim(x, tensor_format_, 'N');
1272       int64 in_planes = GetTensorDim(x, tensor_format_, '0');
1273       int64 in_rows = GetTensorDim(x, tensor_format_, '1');
1274       int64 in_cols = GetTensorDim(x, tensor_format_, '2');
1275       const int64 in_depth = GetTensorDim(x, tensor_format_, 'C');
1276       dest_shape = ShapeFromFormat(tensor_format_, in_batch,
1277                                    {{in_planes, in_rows * in_cols}}, in_depth);
1278       OP_REQUIRES(context, x.CopyFrom(x, dest_shape),
1279                   errors::InvalidArgument("Error during tensor copy."));
1280     }
1281 
1282     if (has_side_input_) {
1283       OP_REQUIRES(context, side_input->shape() == x.shape(),
1284                   errors::InvalidArgument(
1285                       "side_input shape must be equal to input shape: ",
1286                       side_input->shape().DebugString(),
1287                       " != ", x.shape().DebugString()));
1288     }
1289 
1290     if (activation_mode_ != FbnActivationMode::kIdentity) {
1291       // NOTE(ezhulenev): This requirement is coming from implementation
1292       // details of cudnnBatchNormalizationForwardTrainingEx.
1293       OP_REQUIRES(
1294           context, !is_training_ || x.dim_size(3) % 4 == 0,
1295           errors::InvalidArgument("FusedBatchNorm with activation requires "
1296                                   "channel dimension to be a multiple of 4."));
1297     }
1298 
1299     Tensor* y = nullptr;
1300     auto alloc_shape = use_reshape ? dest_shape : x_shape;
1301     OP_REQUIRES_OK(context, context->forward_input_or_allocate_output(
1302                                 {0}, 0, alloc_shape, &y));
1303 
1304     Tensor* batch_mean = nullptr;
1305     OP_REQUIRES_OK(context, context->forward_input_or_allocate_output(
1306                                 {3}, 1, scale.shape(), &batch_mean));
1307     Tensor* batch_var = nullptr;
1308     OP_REQUIRES_OK(context, context->forward_input_or_allocate_output(
1309                                 {4}, 2, scale.shape(), &batch_var));
1310     Tensor* saved_mean = nullptr;
1311     OP_REQUIRES_OK(context,
1312                    context->allocate_output(3, scale.shape(), &saved_mean));
1313     Tensor* saved_maybe_inv_var = nullptr;
1314     OP_REQUIRES_OK(context, context->allocate_output(4, scale.shape(),
1315                                                      &saved_maybe_inv_var));
1316 
1317     if (is_training_) {
1318       functor::FusedBatchNorm<Device, T, U, true>()(
1319           context, x, scale, offset, estimated_mean, estimated_variance,
1320           side_input, epsilon_, exponential_avg_factor_, activation_mode_, y,
1321           batch_mean, batch_var, saved_mean, saved_maybe_inv_var,
1322           tensor_format_, use_reserved_space);
1323     } else {
1324       functor::FusedBatchNorm<Device, T, U, false>()(
1325           context, x, scale, offset, estimated_mean, estimated_variance,
1326           side_input, epsilon_, exponential_avg_factor_, activation_mode_, y,
1327           batch_mean, batch_var, saved_mean, saved_maybe_inv_var,
1328           tensor_format_, use_reserved_space);
1329     }
1330     if (use_reshape) {
1331       OP_REQUIRES(context, y->CopyFrom(*y, x_shape),
1332                   errors::InvalidArgument("Error during tensor copy."));
1333     }
1334   }
1335 
1336  private:
1337   U epsilon_;
1338   U exponential_avg_factor_;
1339   TensorFormat tensor_format_;
1340   bool is_training_;
1341   bool has_side_input_;
1342   FbnActivationMode activation_mode_;
1343 };
1344 
1345 template <typename Device, typename T, typename U>
1346 class FusedBatchNormOp : public FusedBatchNormOpBase<Device, T, U> {
1347  public:
FusedBatchNormOp(OpKernelConstruction * context)1348   explicit FusedBatchNormOp(OpKernelConstruction* context)
1349       : FusedBatchNormOpBase<Device, T, U>(context) {}
1350 
Compute(OpKernelContext * context)1351   void Compute(OpKernelContext* context) override {
1352     FusedBatchNormOpBase<Device, T, U>::ComputeWithReservedSpace(context,
1353                                                                  false);
1354   }
1355 };
1356 
1357 template <typename Device, typename T, typename U>
1358 class FusedBatchNormOpV3 : public FusedBatchNormOpBase<Device, T, U> {
1359  public:
FusedBatchNormOpV3(OpKernelConstruction * context)1360   explicit FusedBatchNormOpV3(OpKernelConstruction* context)
1361       : FusedBatchNormOpBase<Device, T, U>(context) {}
1362 
Compute(OpKernelContext * context)1363   void Compute(OpKernelContext* context) override {
1364     FusedBatchNormOpBase<Device, T, U>::ComputeWithReservedSpace(context, true);
1365   }
1366 };
1367 
1368 template <typename Device, typename T, typename U>
1369 class FusedBatchNormOpEx : public FusedBatchNormOpBase<Device, T, U> {
1370   static constexpr bool kWithSideInputAndActivation = true;
1371 
1372  public:
FusedBatchNormOpEx(OpKernelConstruction * context)1373   explicit FusedBatchNormOpEx(OpKernelConstruction* context)
1374       : FusedBatchNormOpBase<Device, T, U>(context,
1375                                            kWithSideInputAndActivation) {}
1376 
Compute(OpKernelContext * context)1377   void Compute(OpKernelContext* context) override {
1378     FusedBatchNormOpBase<Device, T, U>::ComputeWithReservedSpace(context, true);
1379   }
1380 };
1381 
1382 template <typename Device, typename T, typename U>
1383 class FusedBatchNormGradOpBase : public OpKernel {
1384  protected:
FusedBatchNormGradOpBase(OpKernelConstruction * context)1385   explicit FusedBatchNormGradOpBase(OpKernelConstruction* context)
1386       : OpKernel(context) {
1387     float epsilon;
1388     OP_REQUIRES_OK(context, context->GetAttr("epsilon", &epsilon));
1389     epsilon_ = U(epsilon);
1390     string tensor_format;
1391     OP_REQUIRES_OK(context, context->GetAttr("data_format", &tensor_format));
1392     OP_REQUIRES(context, FormatFromString(tensor_format, &tensor_format_),
1393                 errors::InvalidArgument("Invalid data format"));
1394     OP_REQUIRES_OK(context, context->GetAttr("is_training", &is_training_));
1395   }
1396 
ComputeWithReservedSpace(OpKernelContext * context,bool use_reserved_space)1397   virtual void ComputeWithReservedSpace(OpKernelContext* context,
1398                                         bool use_reserved_space) {
1399     Tensor y_backprop = context->input(0);
1400     Tensor x = context->input(1);
1401     const Tensor& scale = context->input(2);
1402     // When is_training=True, batch mean and variance/inverted variance are
1403     // saved in the forward pass to be reused here. When is_training=False,
1404     // population mean and variance need to be forwarded here to compute the
1405     // gradients.
1406     const Tensor& saved_mean_or_pop_mean = context->input(3);
1407     // The Eigen implementation saves variance in the forward pass, while cuDNN
1408     // saves inverted variance.
1409     const Tensor& saved_maybe_inv_var_or_pop_var = context->input(4);
1410 
1411     OP_REQUIRES(context, y_backprop.dims() == 4 || y_backprop.dims() == 5,
1412                 errors::InvalidArgument("input must be 4 or 5-dimensional",
1413                                         y_backprop.shape().DebugString()));
1414     OP_REQUIRES(context, x.dims() == 4 || x.dims() == 5,
1415                 errors::InvalidArgument("input must be 4 or 5-dimensional",
1416                                         x.shape().DebugString()));
1417     OP_REQUIRES(context, scale.dims() == 1,
1418                 errors::InvalidArgument("scale must be 1-dimensional",
1419                                         scale.shape().DebugString()));
1420     OP_REQUIRES(
1421         context, saved_mean_or_pop_mean.dims() == 1,
1422         errors::InvalidArgument("saved mean must be 1-dimensional",
1423                                 saved_mean_or_pop_mean.shape().DebugString()));
1424     OP_REQUIRES(context, saved_maybe_inv_var_or_pop_var.dims() == 1,
1425                 errors::InvalidArgument(
1426                     "saved variance must be 1-dimensional",
1427                     saved_maybe_inv_var_or_pop_var.shape().DebugString()));
1428     bool use_reshape = (x.dims() == 5);
1429     auto x_shape = x.shape();
1430     TensorShape dest_shape;
1431     if (use_reshape) {
1432       const int64 in_batch = GetTensorDim(x, tensor_format_, 'N');
1433       int64 in_planes = GetTensorDim(x, tensor_format_, '0');
1434       int64 in_rows = GetTensorDim(x, tensor_format_, '1');
1435       int64 in_cols = GetTensorDim(x, tensor_format_, '2');
1436       const int64 in_depth = GetTensorDim(x, tensor_format_, 'C');
1437       dest_shape = ShapeFromFormat(tensor_format_, in_batch,
1438                                    {{in_planes, in_rows * in_cols}}, in_depth);
1439       OP_REQUIRES(context, x.CopyFrom(x, dest_shape),
1440                   errors::InvalidArgument("Error during tensor copy."));
1441       OP_REQUIRES(context, y_backprop.CopyFrom(y_backprop, dest_shape),
1442                   errors::InvalidArgument("Error during tensor copy."));
1443     }
1444 
1445     Tensor* x_backprop = nullptr;
1446     auto alloc_shape = use_reshape ? dest_shape : x_shape;
1447     OP_REQUIRES_OK(context,
1448                    context->allocate_output(0, alloc_shape, &x_backprop));
1449 
1450     const TensorShape& scale_offset_shape = scale.shape();
1451     Tensor* scale_backprop = nullptr;
1452     OP_REQUIRES_OK(context, context->allocate_output(1, scale_offset_shape,
1453                                                      &scale_backprop));
1454     Tensor* offset_backprop = nullptr;
1455     OP_REQUIRES_OK(context, context->allocate_output(2, scale_offset_shape,
1456                                                      &offset_backprop));
1457     // Two placeholders for estimated_mean and estimated_variance, which are
1458     // used for inference and thus not needed here for gradient computation.
1459     // They are filled with zeros so as to avoid NaN outputs.
1460     Tensor* placeholder_1 = nullptr;
1461     OP_REQUIRES_OK(
1462         context, context->allocate_output(3, TensorShape({0}), &placeholder_1));
1463     Tensor* placeholder_2 = nullptr;
1464     OP_REQUIRES_OK(
1465         context, context->allocate_output(4, TensorShape({0}), &placeholder_2));
1466 
1467     // If input is empty, set gradients w.r.t scale/offset to zero.
1468     if (x.shape().num_elements() == 0) {
1469       functor::SetZeroFunctor<Device, U> f;
1470       f(context->eigen_device<Device>(), scale_backprop->flat<U>());
1471       f(context->eigen_device<Device>(), offset_backprop->flat<U>());
1472       return;
1473     }
1474 
1475     if (is_training_) {
1476       functor::FusedBatchNormGrad<Device, T, U>()(
1477           context, y_backprop, x, scale, saved_mean_or_pop_mean,
1478           saved_maybe_inv_var_or_pop_var, epsilon_, x_backprop, scale_backprop,
1479           offset_backprop, use_reserved_space, tensor_format_);
1480     } else {
1481       // Necessary layout conversion is currently done in python.
1482       OP_REQUIRES(context, tensor_format_ == FORMAT_NHWC,
1483                   errors::InvalidArgument(
1484                       "The implementation of "
1485                       "FusedBatchNormGrad with is_training=False only support "
1486                       "NHWC tensor format for now."));
1487       functor::FusedBatchNormFreezeGrad<Device, T, U>()(
1488           context, y_backprop, x, scale, saved_mean_or_pop_mean,
1489           saved_maybe_inv_var_or_pop_var, epsilon_, x_backprop, scale_backprop,
1490           offset_backprop);
1491     }
1492     if (use_reshape) {
1493       OP_REQUIRES(context, x_backprop->CopyFrom(*x_backprop, x_shape),
1494                   errors::InvalidArgument("Error during tensor copy."));
1495     }
1496   }
1497 
1498  private:
1499   U epsilon_;
1500   TensorFormat tensor_format_;
1501   bool is_training_;
1502 };
1503 
1504 template <typename Device, typename T, typename U>
1505 class FusedBatchNormGradOp : public FusedBatchNormGradOpBase<Device, T, U> {
1506  public:
FusedBatchNormGradOp(OpKernelConstruction * context)1507   explicit FusedBatchNormGradOp(OpKernelConstruction* context)
1508       : FusedBatchNormGradOpBase<Device, T, U>(context) {}
1509 
Compute(OpKernelContext * context)1510   void Compute(OpKernelContext* context) override {
1511     FusedBatchNormGradOpBase<Device, T, U>::ComputeWithReservedSpace(context,
1512                                                                      false);
1513   }
1514 };
1515 
1516 template <typename Device, typename T, typename U>
1517 class FusedBatchNormGradOpV3 : public FusedBatchNormGradOpBase<Device, T, U> {
1518  public:
FusedBatchNormGradOpV3(OpKernelConstruction * context)1519   explicit FusedBatchNormGradOpV3(OpKernelConstruction* context)
1520       : FusedBatchNormGradOpBase<Device, T, U>(context) {}
1521 
Compute(OpKernelContext * context)1522   void Compute(OpKernelContext* context) override {
1523     FusedBatchNormGradOpBase<Device, T, U>::ComputeWithReservedSpace(context,
1524                                                                      true);
1525   }
1526 };
1527 
1528 REGISTER_KERNEL_BUILDER(
1529     Name("FusedBatchNorm").Device(DEVICE_CPU).TypeConstraint<float>("T"),
1530     FusedBatchNormOp<CPUDevice, float, float>);
1531 
1532 REGISTER_KERNEL_BUILDER(
1533     Name("FusedBatchNormGrad").Device(DEVICE_CPU).TypeConstraint<float>("T"),
1534     FusedBatchNormGradOp<CPUDevice, float, float>);
1535 
1536 REGISTER_KERNEL_BUILDER(Name("FusedBatchNormV2")
1537                             .Device(DEVICE_CPU)
1538                             .TypeConstraint<float>("T")
1539                             .TypeConstraint<float>("U"),
1540                         FusedBatchNormOp<CPUDevice, float, float>);
1541 
1542 REGISTER_KERNEL_BUILDER(Name("FusedBatchNormGradV2")
1543                             .Device(DEVICE_CPU)
1544                             .TypeConstraint<float>("T")
1545                             .TypeConstraint<float>("U"),
1546                         FusedBatchNormGradOp<CPUDevice, float, float>);
1547 
1548 REGISTER_KERNEL_BUILDER(Name("FusedBatchNormV2")
1549                             .Device(DEVICE_CPU)
1550                             .TypeConstraint<Eigen::half>("T")
1551                             .TypeConstraint<float>("U"),
1552                         FusedBatchNormOp<CPUDevice, Eigen::half, float>);
1553 
1554 REGISTER_KERNEL_BUILDER(Name("FusedBatchNormGradV2")
1555                             .Device(DEVICE_CPU)
1556                             .TypeConstraint<Eigen::half>("T")
1557                             .TypeConstraint<float>("U"),
1558                         FusedBatchNormGradOp<CPUDevice, Eigen::half, float>);
1559 
1560 REGISTER_KERNEL_BUILDER(Name("FusedBatchNormV3")
1561                             .Device(DEVICE_CPU)
1562                             .TypeConstraint<float>("T")
1563                             .TypeConstraint<float>("U"),
1564                         FusedBatchNormOpV3<CPUDevice, float, float>);
1565 
1566 REGISTER_KERNEL_BUILDER(Name("FusedBatchNormGradV3")
1567                             .Device(DEVICE_CPU)
1568                             .TypeConstraint<float>("T")
1569                             .TypeConstraint<float>("U"),
1570                         FusedBatchNormGradOpV3<CPUDevice, float, float>);
1571 
1572 REGISTER_KERNEL_BUILDER(Name("FusedBatchNormV3")
1573                             .Device(DEVICE_CPU)
1574                             .TypeConstraint<Eigen::half>("T")
1575                             .TypeConstraint<float>("U"),
1576                         FusedBatchNormOpV3<CPUDevice, Eigen::half, float>);
1577 
1578 REGISTER_KERNEL_BUILDER(Name("FusedBatchNormGradV3")
1579                             .Device(DEVICE_CPU)
1580                             .TypeConstraint<Eigen::half>("T")
1581                             .TypeConstraint<float>("U"),
1582                         FusedBatchNormGradOpV3<CPUDevice, Eigen::half, float>);
1583 
1584 #if GOOGLE_CUDA || TENSORFLOW_USE_ROCM
1585 
1586 REGISTER_KERNEL_BUILDER(
1587     Name("FusedBatchNorm").Device(DEVICE_GPU).TypeConstraint<float>("T"),
1588     FusedBatchNormOp<GPUDevice, float, float>);
1589 
1590 REGISTER_KERNEL_BUILDER(
1591     Name("FusedBatchNormGrad").Device(DEVICE_GPU).TypeConstraint<float>("T"),
1592     FusedBatchNormGradOp<GPUDevice, float, float>);
1593 
1594 REGISTER_KERNEL_BUILDER(Name("FusedBatchNormV2")
1595                             .Device(DEVICE_GPU)
1596                             .TypeConstraint<float>("T")
1597                             .TypeConstraint<float>("U"),
1598                         FusedBatchNormOp<GPUDevice, float, float>);
1599 
1600 REGISTER_KERNEL_BUILDER(Name("FusedBatchNormGradV2")
1601                             .Device(DEVICE_GPU)
1602                             .TypeConstraint<float>("T")
1603                             .TypeConstraint<float>("U"),
1604                         FusedBatchNormGradOp<GPUDevice, float, float>);
1605 
1606 REGISTER_KERNEL_BUILDER(Name("FusedBatchNormV2")
1607                             .Device(DEVICE_GPU)
1608                             .TypeConstraint<Eigen::half>("T")
1609                             .TypeConstraint<float>("U"),
1610                         FusedBatchNormOp<GPUDevice, Eigen::half, float>);
1611 
1612 REGISTER_KERNEL_BUILDER(Name("FusedBatchNormGradV2")
1613                             .Device(DEVICE_GPU)
1614                             .TypeConstraint<Eigen::half>("T")
1615                             .TypeConstraint<float>("U"),
1616                         FusedBatchNormGradOp<GPUDevice, Eigen::half, float>);
1617 
1618 REGISTER_KERNEL_BUILDER(Name("FusedBatchNormV3")
1619                             .Device(DEVICE_GPU)
1620                             .TypeConstraint<float>("T")
1621                             .TypeConstraint<float>("U"),
1622                         FusedBatchNormOpV3<GPUDevice, float, float>);
1623 
1624 REGISTER_KERNEL_BUILDER(Name("_FusedBatchNormEx")
1625                             .Device(DEVICE_GPU)
1626                             .TypeConstraint<float>("T")
1627                             .TypeConstraint<float>("U"),
1628                         FusedBatchNormOpEx<GPUDevice, float, float>);
1629 
1630 REGISTER_KERNEL_BUILDER(Name("FusedBatchNormGradV3")
1631                             .Device(DEVICE_GPU)
1632                             .TypeConstraint<float>("T")
1633                             .TypeConstraint<float>("U"),
1634                         FusedBatchNormGradOpV3<GPUDevice, float, float>);
1635 
1636 REGISTER_KERNEL_BUILDER(Name("FusedBatchNormV3")
1637                             .Device(DEVICE_GPU)
1638                             .TypeConstraint<Eigen::half>("T")
1639                             .TypeConstraint<float>("U"),
1640                         FusedBatchNormOpV3<GPUDevice, Eigen::half, float>);
1641 
1642 REGISTER_KERNEL_BUILDER(Name("_FusedBatchNormEx")
1643                             .Device(DEVICE_GPU)
1644                             .TypeConstraint<Eigen::half>("T")
1645                             .TypeConstraint<float>("U"),
1646                         FusedBatchNormOpEx<GPUDevice, Eigen::half, float>);
1647 
1648 REGISTER_KERNEL_BUILDER(Name("FusedBatchNormGradV3")
1649                             .Device(DEVICE_GPU)
1650                             .TypeConstraint<Eigen::half>("T")
1651                             .TypeConstraint<float>("U"),
1652                         FusedBatchNormGradOpV3<GPUDevice, Eigen::half, float>);
1653 
1654 #endif
1655 
1656 }  // namespace tensorflow
1657