/* Copyright 2017 The TensorFlow Authors. All Rights Reserved.

Licensed under the Apache License, Version 2.0 (the "License");
you may not use this file except in compliance with the License.
You may obtain a copy of the License at

    http://www.apache.org/licenses/LICENSE-2.0

Unless required by applicable law or agreed to in writing, software
distributed under the License is distributed on an "AS IS" BASIS,
WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
See the License for the specific language governing permissions and
limitations under the License.
==============================================================================*/

#include "tensorflow/compiler/xla/service/gpu/cudnn_batchnorm_thunk.h"

#include <string>

#include "absl/strings/str_cat.h"
#include "tensorflow/compiler/xla/service/gpu/hlo_execution_profiler.h"
#include "tensorflow/compiler/xla/service/gpu/ir_emission_utils.h"
#include "tensorflow/compiler/xla/types.h"
#include "tensorflow/compiler/xla/util.h"
#include "tensorflow/core/platform/logging.h"
#include "tensorflow/core/platform/stream_executor_no_cuda.h"

namespace xla {
namespace gpu {

namespace dnn = se::dnn;

static std::pair<dnn::BatchDescriptor /*input_desc*/,
                 dnn::BatchDescriptor /*scale_offset_desc*/>
MakeDescriptors(const Shape& shape, int64 feature_index) {
  std::vector<int64> logical_to_physical =
      LayoutUtil::MakeLogicalToPhysical(shape.layout());

  auto physical_dim_size = [&](int64 physical_dim) {
    return shape.dimensions(LayoutUtil::Major(shape.layout(), physical_dim));
  };

  // Batchnorm only cares about the location of the depth (aka "feature") dim.
  // The other dims are all treated the same.  Thus we can use the kBatchDepthYX
  // cudnn layout for any XLA shape+layout, even XLA shapes that don't have
  // exactly 4 dimensions: We put everything that comes before the feature dim
  // into "batch", and everything that comes after the feature dim into "Y".
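  //
  // As a worked example (values chosen purely for illustration): for a shape
  // f32[2,3,4,5] with a row-major layout and feature_index=1, the feature dim
  // has size 3; the one dim physically major to it collapses into "batch"
  // (size 2), and the two physically minor dims collapse into "Y"
  // (size 4*5 = 20).  The resulting kBatchDepthYX descriptor has count=2,
  // feature_map_count=3, height=20, width=1.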
  int64 batch_size = 1;
  int64 y_size = 1;
  int64 physical_dim;
  for (physical_dim = 0; physical_dim != logical_to_physical[feature_index];
       ++physical_dim) {
    CHECK_LT(physical_dim, shape.dimensions_size());
    batch_size *= physical_dim_size(physical_dim);
  }
  ++physical_dim;  // Skip the feature dimension.
  for (; physical_dim < shape.dimensions_size(); ++physical_dim) {
    y_size *= physical_dim_size(physical_dim);
  }

  dnn::BatchDescriptor input_desc;
  input_desc.set_layout(dnn::DataLayout::kBatchDepthYX)
      .set_count(batch_size)
      .set_feature_map_count(shape.dimensions(feature_index))
      .set_height(y_size)
      .set_width(1);

  dnn::BatchDescriptor scale_offset_desc;
  scale_offset_desc.set_layout(dnn::DataLayout::kBatchDepthYX)
      .set_feature_map_count(input_desc.feature_map_count())
      .set_height(1)
      .set_width(1)
      .set_count(1);

  return std::make_pair(input_desc, scale_offset_desc);
}

CudnnBatchNormForwardInferenceThunk::CudnnBatchNormForwardInferenceThunk(
    const BufferAllocation::Slice& operand,
    const BufferAllocation::Slice& scale, const BufferAllocation::Slice& offset,
    const BufferAllocation::Slice& mean,
    const BufferAllocation::Slice& variance, float epsilon, int64 feature_index,
    const BufferAllocation::Slice& output, const HloInstruction* hlo)
    : Thunk(Thunk::Kind::kCudnnBatchNormForwardInference, hlo),
      operand_(operand),
      scale_(scale),
      offset_(offset),
      mean_(mean),
      variance_(variance),
      epsilon_(epsilon),
      feature_index_(feature_index),
      output_(output) {
  CHECK_EQ(hlo->opcode(), HloOpcode::kCustomCall);
  CHECK_EQ(hlo->custom_call_target(),
           kCudnnBatchNormForwardInferenceCallTarget);
  CHECK(
      LayoutUtil::LayoutsInShapesEqual(hlo->shape(), hlo->operand(0)->shape()));
  CHECK_EQ(hlo->shape().element_type(), F32) << "Not yet implemented";
}

Status CudnnBatchNormForwardInferenceThunk::ExecuteOnStream(
    const BufferAllocations& buffer_allocations, se::Stream* stream,
    HloExecutionProfiler* profiler) {
  dnn::BatchDescriptor operand_desc;
  dnn::BatchDescriptor scale_offset_desc;
  std::tie(operand_desc, scale_offset_desc) =
      MakeDescriptors(hlo_instruction()->shape(), feature_index_);

  se::DeviceMemory<float> output(buffer_allocations.GetDeviceAddress(output_));
  auto op_profiler = profiler->MakeScopedInstructionProfiler(hlo_instruction());
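  // In inference mode, cudnn applies the stored statistics directly:
  //   output = scale * (operand - mean) / sqrt(variance + epsilon) + offset,
  // broadcast along the feature dimension, so all the batch/saved outputs
  // below are null.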
  stream->ThenBatchNormalizationForward(
      se::DeviceMemory<float>(buffer_allocations.GetDeviceAddress(operand_)),
      se::DeviceMemory<float>(buffer_allocations.GetDeviceAddress(scale_)),
      se::DeviceMemory<float>(buffer_allocations.GetDeviceAddress(offset_)),
      se::DeviceMemory<float>(buffer_allocations.GetDeviceAddress(mean_)),
      se::DeviceMemory<float>(buffer_allocations.GetDeviceAddress(variance_)),
      operand_desc,                //
      scale_offset_desc,           //
      epsilon_,                    //
      &output,                     //
      /*batch_mean=*/nullptr,      //
      /*batch_var=*/nullptr,       //
      /*saved_mean=*/nullptr,      //
      /*saved_inv_var=*/nullptr,   //
      /*is_training=*/false,       //
      /*var_to_inv_var=*/nullptr,  //
      /*inv_var_to_var=*/nullptr);

  if (!stream->ok()) {
    return InternalError("BatchNormalizationForward call failed.");
  }
  return Status::OK();
}

CudnnBatchNormForwardTrainingThunk::CudnnBatchNormForwardTrainingThunk(
    const BufferAllocation::Slice& operand,
    const BufferAllocation::Slice& scale, const BufferAllocation::Slice& offset,
    float epsilon, int64 feature_index,
    const BufferAllocation::Slice& output_data,
    const BufferAllocation::Slice& output_mean,
    const BufferAllocation::Slice& output_inv_stddev,
    const BufferAllocation::Slice& output_tuple, const HloInstruction* hlo)
    : Thunk(Thunk::Kind::kCudnnBatchNormForwardTraining, hlo),
      operand_(operand),
      scale_(scale),
      offset_(offset),
      epsilon_(epsilon),
      feature_index_(feature_index),
      output_data_(output_data),
      output_mean_(output_mean),
      output_inv_stddev_(output_inv_stddev),
      output_tuple_(output_tuple) {
  CHECK_EQ(hlo->opcode(), HloOpcode::kCustomCall);
  CHECK_EQ(hlo->custom_call_target(), kCudnnBatchNormForwardTrainingCallTarget);
  CHECK_EQ(hlo->shape().tuple_shapes_size(), 3);
  CHECK(LayoutUtil::LayoutsInShapesEqual(hlo->shape().tuple_shapes(0),
                                         hlo->operand(0)->shape()));
  for (const auto& tuple_shape : hlo->shape().tuple_shapes()) {
    CHECK_EQ(tuple_shape.element_type(), F32) << "Not yet implemented";
  }
}

Status CudnnBatchNormForwardTrainingThunk::ExecuteOnStream(
    const BufferAllocations& buffer_allocations, se::Stream* stream,
    HloExecutionProfiler* profiler) {
  dnn::BatchDescriptor operand_desc;
  dnn::BatchDescriptor scale_offset_desc;
  // The BatchNormTraining HLO outputs a tuple of three elements: output data,
  // batch mean, and batch variance.  We want to make our descriptors based on
  // the shape of the output data.
  std::tie(operand_desc, scale_offset_desc) = MakeDescriptors(
      hlo_instruction()->shape().tuple_shapes(0), feature_index_);

  se::DeviceMemory<float> output_data(
      buffer_allocations.GetDeviceAddress(output_data_));
  se::DeviceMemory<float> output_mean(
      buffer_allocations.GetDeviceAddress(output_mean_));
  se::DeviceMemory<float> output_inv_stddev(
      buffer_allocations.GetDeviceAddress(output_inv_stddev_));

  se::DeviceMemory<float> null_device_ptr(nullptr);
  auto op_profiler = profiler->MakeScopedInstructionProfiler(hlo_instruction());
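  // In training mode, cudnn computes the statistics over the current batch
  // itself, so no estimated mean/variance are passed in.  The batch mean and
  // the inverse standard deviation (1 / sqrt(variance + epsilon)) are captured
  // via the saved_mean/saved_inv_var outputs for reuse by the backward pass.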
  stream->ThenBatchNormalizationForward(
      se::DeviceMemory<float>(buffer_allocations.GetDeviceAddress(operand_)),
      se::DeviceMemory<float>(buffer_allocations.GetDeviceAddress(scale_)),
      se::DeviceMemory<float>(buffer_allocations.GetDeviceAddress(offset_)),
      /*estimated_mean=*/null_device_ptr,
      /*estimated_variance=*/null_device_ptr,
      operand_desc,                          //
      scale_offset_desc,                     //
      epsilon_,                              //
      &output_data,                          //
      /*batch_mean=*/&null_device_ptr,       //
      /*batch_var=*/&null_device_ptr,        //
      /*saved_mean=*/&output_mean,           //
      /*saved_inv_var=*/&output_inv_stddev,  //
      /*is_training=*/true,                  //
      /*var_to_inv_var=*/nullptr,            //
      /*inv_var_to_var=*/nullptr);

  // Write the output tuple.  An XLA tuple is represented on device as an
  // array of pointers to its element buffers, so we assemble the three
  // element addresses on the host and copy them into the tuple buffer.
  void* ptrs[] = {output_data.opaque(), output_mean.opaque(),
                  output_inv_stddev.opaque()};
  se::DeviceMemory<void*> tuple_addr(
      buffer_allocations.GetDeviceAddress(output_tuple_));
  stream->ThenMemcpyH2D<void*>(ptrs, &tuple_addr);

  if (!stream->ok()) {
    return InternalError("BatchNormalizationTraining call failed.");
  }
  return Status::OK();
}

CudnnBatchNormBackwardThunk::CudnnBatchNormBackwardThunk(
    const BufferAllocation::Slice& operand,
    const BufferAllocation::Slice& scale, const BufferAllocation::Slice& mean,
    const BufferAllocation::Slice& inv_stddev,
    const BufferAllocation::Slice& grad_output, float epsilon,
    int64 feature_index, const BufferAllocation::Slice& output_grad_data,
    const BufferAllocation::Slice& output_grad_scale,
    const BufferAllocation::Slice& output_grad_offset,
    const BufferAllocation::Slice& output_tuple, const HloInstruction* hlo)
    : Thunk(Thunk::Kind::kCudnnBatchNormBackward, hlo),
      operand_(operand),
      scale_(scale),
      mean_(mean),
      inv_stddev_(inv_stddev),
      grad_output_(grad_output),
      epsilon_(epsilon),
      feature_index_(feature_index),
      output_grad_data_(output_grad_data),
      output_grad_scale_(output_grad_scale),
      output_grad_offset_(output_grad_offset),
      output_tuple_(output_tuple) {
  CHECK_EQ(hlo->opcode(), HloOpcode::kCustomCall);
  CHECK_EQ(hlo->custom_call_target(), kCudnnBatchNormBackwardCallTarget);
  CHECK_EQ(hlo->shape().tuple_shapes_size(), 3);
  CHECK(LayoutUtil::LayoutsInShapesEqual(hlo->shape().tuple_shapes(0),
                                         hlo->operand(0)->shape()));
  CHECK(LayoutUtil::LayoutsInShapesEqual(hlo->shape().tuple_shapes(0),
                                         hlo->operand(4)->shape()));
  for (const auto& tuple_shape : hlo->shape().tuple_shapes()) {
    CHECK_EQ(tuple_shape.element_type(), F32) << "Not yet implemented";
  }
}

Status CudnnBatchNormBackwardThunk::ExecuteOnStream(
    const BufferAllocations& buffer_allocations, se::Stream* stream,
    HloExecutionProfiler* profiler) {
  dnn::BatchDescriptor operand_desc;
  dnn::BatchDescriptor scale_offset_desc;

  // This call outputs a tuple of three elements: grad data, grad scale, and
  // grad offset.  We want to make our descriptors based on the shape of the
  // grad data.
  std::tie(operand_desc, scale_offset_desc) = MakeDescriptors(
      hlo_instruction()->shape().tuple_shapes(0), feature_index_);

  se::DeviceMemory<float> output_grad_data(
      buffer_allocations.GetDeviceAddress(output_grad_data_));
  se::DeviceMemory<float> output_grad_scale(
      buffer_allocations.GetDeviceAddress(output_grad_scale_));
  se::DeviceMemory<float> output_grad_offset(
      buffer_allocations.GetDeviceAddress(output_grad_offset_));

  auto op_profiler = profiler->MakeScopedInstructionProfiler(hlo_instruction());
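  // cudnn's fused batchnorm backward consumes the mean and inverse stddev
  // saved by the forward-training pass and produces the gradients with
  // respect to the input data, scale, and offset in a single call.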
  stream->ThenBatchNormalizationBackward(
      se::DeviceMemory<float>(
          buffer_allocations.GetDeviceAddress(grad_output_)),
      se::DeviceMemory<float>(buffer_allocations.GetDeviceAddress(operand_)),
      se::DeviceMemory<float>(buffer_allocations.GetDeviceAddress(scale_)),
      se::DeviceMemory<float>(buffer_allocations.GetDeviceAddress(mean_)),
      se::DeviceMemory<float>(buffer_allocations.GetDeviceAddress(inv_stddev_)),
      operand_desc, scale_offset_desc, epsilon_, &output_grad_data,
      &output_grad_scale, &output_grad_offset);

  // Write the output tuple, again as an array of device pointers to the
  // three element buffers.
  void* ptrs[] = {output_grad_data.opaque(), output_grad_scale.opaque(),
                  output_grad_offset.opaque()};
  se::DeviceMemory<void*> tuple_addr(
      buffer_allocations.GetDeviceAddress(output_tuple_));
  stream->ThenMemcpyH2D<void*>(ptrs, &tuple_addr);

  if (!stream->ok()) {
    return InternalError("BatchNormalizationBackward call failed.");
  }
  return Status::OK();
}

}  // namespace gpu
}  // namespace xla