1 /* Copyright 2019 The TensorFlow Authors. All Rights Reserved.
2 
3 Licensed under the Apache License, Version 2.0 (the "License");
4 you may not use this file except in compliance with the License.
5 You may obtain a copy of the License at
6 
7     http://www.apache.org/licenses/LICENSE-2.0
8 
9 Unless required by applicable law or agreed to in writing, software
10 distributed under the License is distributed on an "AS IS" BASIS,
11 WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
12 See the License for the specific language governing permissions and
13 limitations under the License.
14 ==============================================================================*/
15 
16 #include "tensorflow/compiler/xla/service/gpu/cudnn_batchnorm_runner.h"
17 
18 #include "absl/strings/str_cat.h"
19 #include "tensorflow/compiler/xla/layout_util.h"
20 #include "tensorflow/compiler/xla/service/gpu/backend_configs.pb.h"
21 #include "tensorflow/compiler/xla/service/gpu/stream_executor_util.h"
22 #include "tensorflow/compiler/xla/shape_util.h"
23 #include "tensorflow/compiler/xla/status_macros.h"
24 #include "tensorflow/compiler/xla/util.h"
25 #include "tensorflow/compiler/xla/xla_data.pb.h"
26 
27 namespace xla {
28 namespace gpu {
29 namespace {
30 
31 struct CudnnBatchNormParamsCommon {
32   se::DeviceMemoryBase operand;
33   se::dnn::BatchDescriptor operand_desc;
34   se::dnn::BatchDescriptor scale_offset_desc;
35   se::DeviceMemory<float> scale;
36   float epsilon;
37 };
38 
39 struct CudnnBatchNormForwardInferenceParams {
40   CudnnBatchNormParamsCommon common;
41   se::DeviceMemoryBase output;
42   se::DeviceMemory<float> offset;
43   se::DeviceMemory<float> mean;
44   se::DeviceMemory<float> variance;
45 };
46 
47 struct CudnnBatchNormForwardTrainingParams {
48   CudnnBatchNormParamsCommon common;
49   se::DeviceMemoryBase output_data;
50   se::DeviceMemory<float> offset;
51   se::DeviceMemory<float> output_mean;
52   se::DeviceMemory<float> output_inv_stddev;
53 };
54 
55 struct CudnnBatchNormBackwardParams {
56   CudnnBatchNormParamsCommon common;
57   se::DeviceMemoryBase output_grad_data;
58   se::DeviceMemoryBase grad_output;
59   se::DeviceMemory<float> output_grad_scale;
60   se::DeviceMemory<float> output_grad_offset;
61   se::DeviceMemory<float> mean;
62   se::DeviceMemory<float> inv_stddev;
63 };
64 
65 struct DnnBatchDescriptors {
66   se::dnn::BatchDescriptor input_desc;
67   se::dnn::BatchDescriptor scale_offset_desc;
68 };
69 
MakeBatchNormDescriptors(const Shape & shape,int64 feature_index)70 DnnBatchDescriptors MakeBatchNormDescriptors(const Shape& shape,
71                                              int64 feature_index) {
72   std::vector<int64> logical_to_physical =
73       LayoutUtil::MakeLogicalToPhysical(shape.layout());
74 
75   auto physical_dim_size = [&](int64 physical_dim) {
76     return shape.dimensions(LayoutUtil::Major(shape.layout(), physical_dim));
77   };
78 
79   // Batchnorm only cares about the location of the depth (aka "feature") dim.
80   // The other dims are all treated the same.  Thus we can use the kBatchDepthYX
81   // cudnn layout for any XLA shape+layout, even XLA shapes that don't have
82   // exactly 4 dimensions: We put everything that comes before the feature dim
83   // into "batch", and everything that comes after the feature dim into "Y".
84   int64 batch_size = 1;
85   int64 y_size = 1;
86   int64 physical_dim;
87   for (physical_dim = 0; physical_dim != logical_to_physical[feature_index];
88        ++physical_dim) {
89     CHECK_LT(physical_dim, shape.dimensions_size());
90     batch_size *= physical_dim_size(physical_dim);
91   }
92   ++physical_dim;  // Skip the feature dimension.
93   for (; physical_dim < shape.dimensions_size(); ++physical_dim) {
94     y_size *= physical_dim_size(physical_dim);
95   }
96 
97   DnnBatchDescriptors batch_descs;
98   batch_descs.input_desc.set_layout(se::dnn::DataLayout::kBatchDepthYX)
99       .set_count(batch_size)
100       .set_feature_map_count(shape.dimensions(feature_index))
101       .set_height(y_size)
102       .set_width(1);
103 
104   batch_descs.scale_offset_desc.set_layout(se::dnn::DataLayout::kBatchDepthYX)
105       .set_feature_map_count(batch_descs.input_desc.feature_map_count())
106       .set_height(1)
107       .set_width(1)
108       .set_count(1);
109 
110   return batch_descs;
111 }
112 
AssignCommonParams(const CudnnBatchNormConfig & config,CudnnBatchNormParamsCommon * params,const se::DeviceMemoryBase & operand,const se::DeviceMemory<float> & scale)113 void AssignCommonParams(const CudnnBatchNormConfig& config,
114                         CudnnBatchNormParamsCommon* params,
115                         const se::DeviceMemoryBase& operand,
116                         const se::DeviceMemory<float>& scale) {
117   // The BatchNormTraining HLO outputs a tuple of three elements: output data,
118   // batch mean, and batch variance.  We want to make our descriptors based on
119   // the shape of the output data. Batchnorm backward call outputs a tuple of
120   // three elements: grad data, grad offset, and grad scale.  We want to make
121   // our descriptors based on the shape of the grad data.
122   const Shape& shape = config.output_shape;
123   DnnBatchDescriptors batch_descs =
124       MakeBatchNormDescriptors(shape, config.feature_index);
125   params->operand_desc = batch_descs.input_desc;
126   params->scale_offset_desc = batch_descs.scale_offset_desc;
127   params->operand = operand;
128   params->scale = scale;
129   params->epsilon = config.epsilon;
130 }
131 
132 template <typename ElemType>
RunCudnnBatchNormForwardInferenceImpl(CudnnBatchNormForwardInferenceParams * params,se::Stream * stream)133 void RunCudnnBatchNormForwardInferenceImpl(
134     CudnnBatchNormForwardInferenceParams* params, se::Stream* stream) {
135   se::DeviceMemory<float> null_device_ptr(nullptr);
136   auto output_buf = se::DeviceMemory<ElemType>(params->output);
137   stream->ThenBatchNormalizationForward(
138       se::DeviceMemory<ElemType>(params->common.operand),
139       params->common.scale,                                         //
140       params->offset,                                               //
141       params->mean,                                                 //
142       params->variance,                                             //
143       /*side_input=*/null_device_ptr, params->common.operand_desc,  //
144       params->common.scale_offset_desc,                             //
145       static_cast<double>(params->common.epsilon),                  //
146       // TODO(b/137108598): Extend method to allow use of non-trivial
147       // exponential averaging.
148       /*exponential_average_factor=*/1.0,
149       se::dnn::ActivationMode::kNone,       //
150       &output_buf,                          //
151       /*batch_mean=*/nullptr,               //
152       /*batch_var=*/nullptr,                //
153       /*saved_mean=*/nullptr,               //
154       /*saved_inv_var=*/nullptr,            //
155       /*is_training=*/false,                //
156       /*reserve_space_allocator=*/nullptr,  //
157       /*workspace_allocator=*/nullptr);
158 }
159 
160 template <typename ElemType>
RunCudnnBatchNormForwardTrainingImpl(CudnnBatchNormForwardTrainingParams * params,se::Stream * stream)161 void RunCudnnBatchNormForwardTrainingImpl(
162     CudnnBatchNormForwardTrainingParams* params, se::Stream* stream) {
163   se::DeviceMemory<float> null_device_ptr(nullptr);
164   auto output_data = se::DeviceMemory<ElemType>(params->output_data);
165   stream->ThenBatchNormalizationForward(
166       se::DeviceMemory<ElemType>(params->common.operand),
167       params->common.scale,                    //
168       params->offset,                          //
169       /*estimated_mean=*/null_device_ptr,      //
170       /*estimated_variance=*/null_device_ptr,  //
171       /*side_input=*/null_device_ptr,          //
172       params->common.operand_desc,             //
173       params->common.scale_offset_desc,        //
174       params->common.epsilon,                  //
175       // TODO(b/137108598): Extend method to allow use of non-trivial
176       // exponential averaging.
177       /*exponential_average_factor=*/1.0,
178       se::dnn::ActivationMode::kNone,                //
179       &output_data,                                  //
180       /*batch_mean=*/&null_device_ptr,               //
181       /*batch_var=*/&null_device_ptr,                //
182       /*saved_mean=*/&params->output_mean,           //
183       /*saved_inv_var=*/&params->output_inv_stddev,  //
184       /*is_training=*/true,                          //
185       /*reserve_space_allocator=*/nullptr,           //
186       /*workspace_allocator=*/nullptr);
187 }
188 
189 template <typename ElemType>
RunCudnnBatchNormBackwardImpl(CudnnBatchNormBackwardParams * params,se::Stream * stream)190 void RunCudnnBatchNormBackwardImpl(CudnnBatchNormBackwardParams* params,
191                                    se::Stream* stream) {
192   se::DeviceMemory<float> null_device_ptr(nullptr);
193   auto output_grad_data = se::DeviceMemory<ElemType>(params->output_grad_data);
194   stream->ThenBatchNormalizationBackward(
195       se::DeviceMemory<ElemType>(params->grad_output),     //
196       se::DeviceMemory<ElemType>(params->common.operand),  //
197       params->common.scale,                                //
198       params->mean,                                        //
199       params->inv_stddev,                                  //
200       params->common.operand_desc,                         //
201       params->common.scale_offset_desc,                    //
202       params->common.epsilon,                              //
203       &output_grad_data,                                   //
204       &params->output_grad_scale,                          //
205       &params->output_grad_offset,                         //
206       /*reserve_space_allocator=*/nullptr,                 //
207       /*workspace_allocator=*/nullptr);
208 }
209 
210 }  // namespace
211 
GetCudnnBatchNormConfig(const HloInstruction * instr,float epsilon,int64 feature_index)212 CudnnBatchNormConfig GetCudnnBatchNormConfig(const HloInstruction* instr,
213                                              float epsilon,
214                                              int64 feature_index) {
215   CudnnBatchNormConfig config;
216 
217   config.output_shape = instr->shape().IsTuple()
218                             ? instr->shape().tuple_shapes(0)
219                             : instr->shape();
220   config.output_type = config.output_shape.element_type();
221   config.epsilon = epsilon;
222   config.feature_index = feature_index;
223   return config;
224 }
225 
RunCudnnBatchNormForwardInference(const CudnnBatchNormConfig & config,se::DeviceMemoryBase operand,se::DeviceMemoryBase output,se::DeviceMemory<float> scale,se::DeviceMemory<float> offset,se::DeviceMemory<float> mean,se::DeviceMemory<float> variance,se::Stream * stream)226 Status RunCudnnBatchNormForwardInference(
227     const CudnnBatchNormConfig& config, se::DeviceMemoryBase operand,
228     se::DeviceMemoryBase output, se::DeviceMemory<float> scale,
229     se::DeviceMemory<float> offset, se::DeviceMemory<float> mean,
230     se::DeviceMemory<float> variance, se::Stream* stream) {
231   CudnnBatchNormForwardInferenceParams inference_params;
232   AssignCommonParams(config, &inference_params.common, operand, scale);
233   inference_params.offset = offset;
234   inference_params.mean = mean;
235   inference_params.variance = variance;
236   inference_params.output = output;
237 
238   switch (config.output_type) {
239     case F16:
240       RunCudnnBatchNormForwardInferenceImpl<Eigen::half>(&inference_params,
241                                                          stream);
242       break;
243     case F32:
244       RunCudnnBatchNormForwardInferenceImpl<float>(&inference_params, stream);
245       break;
246     default:
247       return Unimplemented(
248           "Primitive type %s not implemented for batchnorm forward inference",
249           primitive_util::LowercasePrimitiveTypeName(config.output_type)
250               .c_str());
251   }
252   return Status::OK();
253 }
254 
RunCudnnBatchNormForwardTraining(const CudnnBatchNormConfig & config,se::DeviceMemoryBase operand,se::DeviceMemoryBase output_data,se::DeviceMemory<float> output_mean,se::DeviceMemory<float> output_inv_stddev,se::DeviceMemory<float> scale,se::DeviceMemory<float> offset,se::Stream * stream)255 Status RunCudnnBatchNormForwardTraining(
256     const CudnnBatchNormConfig& config, se::DeviceMemoryBase operand,
257     se::DeviceMemoryBase output_data, se::DeviceMemory<float> output_mean,
258     se::DeviceMemory<float> output_inv_stddev, se::DeviceMemory<float> scale,
259     se::DeviceMemory<float> offset, se::Stream* stream) {
260   CudnnBatchNormForwardTrainingParams forward_params;
261   AssignCommonParams(config, &forward_params.common, operand, scale);
262   forward_params.offset = offset;
263   forward_params.output_data = output_data;
264   forward_params.output_mean = output_mean;
265   forward_params.output_inv_stddev = output_inv_stddev;
266 
267   switch (config.output_type) {
268     case F16:
269       RunCudnnBatchNormForwardTrainingImpl<Eigen::half>(&forward_params,
270                                                         stream);
271       break;
272     case F32:
273       RunCudnnBatchNormForwardTrainingImpl<float>(&forward_params, stream);
274       break;
275     default:
276       return Unimplemented(
277           "Primitive type %s not implemented for batchnorm forward training",
278           primitive_util::LowercasePrimitiveTypeName(config.output_type)
279               .c_str());
280   }
281   return Status::OK();
282 }
283 
RunCudnnBatchNormBackward(const CudnnBatchNormConfig & config,se::DeviceMemoryBase operand,se::DeviceMemoryBase output_grad_data,se::DeviceMemoryBase grad_output,se::DeviceMemory<float> output_grad_scale,se::DeviceMemory<float> output_grad_offset,se::DeviceMemory<float> scale,se::DeviceMemory<float> mean,se::DeviceMemory<float> inv_stddev,se::Stream * stream)284 Status RunCudnnBatchNormBackward(
285     const CudnnBatchNormConfig& config, se::DeviceMemoryBase operand,
286     se::DeviceMemoryBase output_grad_data, se::DeviceMemoryBase grad_output,
287     se::DeviceMemory<float> output_grad_scale,
288     se::DeviceMemory<float> output_grad_offset, se::DeviceMemory<float> scale,
289     se::DeviceMemory<float> mean, se::DeviceMemory<float> inv_stddev,
290     se::Stream* stream) {
291   CudnnBatchNormBackwardParams backward_params;
292   AssignCommonParams(config, &backward_params.common, operand, scale);
293   backward_params.output_grad_data = output_grad_data;
294   backward_params.grad_output = grad_output;
295   backward_params.output_grad_scale = output_grad_scale;
296   backward_params.output_grad_offset = output_grad_offset;
297   backward_params.mean = mean;
298   backward_params.inv_stddev = inv_stddev;
299 
300   switch (config.output_type) {
301     case F16:
302       RunCudnnBatchNormBackwardImpl<Eigen::half>(&backward_params, stream);
303       break;
304     case F32:
305       RunCudnnBatchNormBackwardImpl<float>(&backward_params, stream);
306       break;
307     default:
308       return Unimplemented(
309           "Primitive type %s not implemented for batchnorm backward",
310           primitive_util::LowercasePrimitiveTypeName(config.output_type)
311               .c_str());
312   }
313   return Status::OK();
314 }
315 
316 }  // namespace gpu
317 }  // namespace xla
318