/* Copyright 2017 The TensorFlow Authors. All Rights Reserved.

Licensed under the Apache License, Version 2.0 (the "License");
you may not use this file except in compliance with the License.
You may obtain a copy of the License at

    http://www.apache.org/licenses/LICENSE-2.0

Unless required by applicable law or agreed to in writing, software
distributed under the License is distributed on an "AS IS" BASIS,
WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
See the License for the specific language governing permissions and
limitations under the License.
==============================================================================*/

#include "tensorflow/compiler/xla/service/gpu/cudnn_batchnorm_thunk.h"

#include <string>

#include "absl/strings/str_cat.h"
#include "tensorflow/compiler/xla/service/gpu/hlo_execution_profiler.h"
#include "tensorflow/compiler/xla/service/gpu/ir_emission_utils.h"
#include "tensorflow/compiler/xla/types.h"
#include "tensorflow/compiler/xla/util.h"
#include "tensorflow/core/platform/logging.h"
#include "tensorflow/core/platform/stream_executor_no_cuda.h"

namespace xla {
namespace gpu {

namespace dnn = se::dnn;

static std::pair<dnn::BatchDescriptor /*input_desc*/,
                 dnn::BatchDescriptor /*scale_offset_desc*/>
MakeDescriptors(const Shape& shape, int64 feature_index) {
  std::vector<int64> logical_to_physical =
      LayoutUtil::MakeLogicalToPhysical(shape.layout());

  auto physical_dim_size = [&](int64 physical_dim) {
    return shape.dimensions(LayoutUtil::Major(shape.layout(), physical_dim));
  };

  // Batchnorm only cares about the location of the depth (aka "feature") dim.
  // The other dims are all treated the same. Thus we can use the kBatchDepthYX
  // cudnn layout for any XLA shape+layout, even XLA shapes that don't have
  // exactly 4 dimensions: We put everything that comes before the feature dim
  // into "batch", and everything that comes after the feature dim into "Y".
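  //
  // Illustrative example (hypothetical shape, not taken from any particular
  // caller): an f32[2,3,4,5] operand with a row-major {3,2,1,0} layout and
  // feature_index=1 yields count(batch)=2, feature_map_count=3,
  // height=4*5=20, width=1; the scale/offset descriptor is then 1x3x1x1.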
  int64 batch_size = 1;
  int64 y_size = 1;
  int64 physical_dim;
  for (physical_dim = 0; physical_dim != logical_to_physical[feature_index];
       ++physical_dim) {
    CHECK_LT(physical_dim, shape.dimensions_size());
    batch_size *= physical_dim_size(physical_dim);
  }
  ++physical_dim;  // Skip the feature dimension.
  for (; physical_dim < shape.dimensions_size(); ++physical_dim) {
    y_size *= physical_dim_size(physical_dim);
  }

  dnn::BatchDescriptor input_desc;
  input_desc.set_layout(dnn::DataLayout::kBatchDepthYX)
      .set_count(batch_size)
      .set_feature_map_count(shape.dimensions(feature_index))
      .set_height(y_size)
      .set_width(1);

  dnn::BatchDescriptor scale_offset_desc;
  scale_offset_desc.set_layout(dnn::DataLayout::kBatchDepthYX)
      .set_feature_map_count(input_desc.feature_map_count())
      .set_height(1)
      .set_width(1)
      .set_count(1);

  return std::make_pair(input_desc, scale_offset_desc);
}

CudnnBatchNormForwardInferenceThunk::CudnnBatchNormForwardInferenceThunk(
    const BufferAllocation::Slice& operand,
    const BufferAllocation::Slice& scale, const BufferAllocation::Slice& offset,
    const BufferAllocation::Slice& mean,
    const BufferAllocation::Slice& variance, float epsilon, int64 feature_index,
    const BufferAllocation::Slice& output, const HloInstruction* hlo)
    : Thunk(Thunk::Kind::kCudnnBatchNormForwardInference, hlo),
      operand_(operand),
      scale_(scale),
      offset_(offset),
      mean_(mean),
      variance_(variance),
      epsilon_(epsilon),
      feature_index_(feature_index),
      output_(output) {
  CHECK_EQ(hlo->opcode(), HloOpcode::kCustomCall);
  CHECK_EQ(hlo->custom_call_target(),
           kCudnnBatchNormForwardInferenceCallTarget);
  CHECK(
      LayoutUtil::LayoutsInShapesEqual(hlo->shape(), hlo->operand(0)->shape()));
  CHECK_EQ(hlo->shape().element_type(), F32) << "Not yet implemented";
}

Status CudnnBatchNormForwardInferenceThunk::ExecuteOnStream(
    const BufferAllocations& buffer_allocations, se::Stream* stream,
    HloExecutionProfiler* profiler) {
  dnn::BatchDescriptor operand_desc;
  dnn::BatchDescriptor scale_offset_desc;
  std::tie(operand_desc, scale_offset_desc) =
      MakeDescriptors(hlo_instruction()->shape(), feature_index_);

  se::DeviceMemory<float> output(buffer_allocations.GetDeviceAddress(output_));
  auto op_profiler = profiler->MakeScopedInstructionProfiler(hlo_instruction());
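  // In inference mode cudnn normalizes with the running mean and variance
  // passed in as operands, so no batch statistics are computed or saved; the
  // corresponding output arguments below are null.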
  stream->ThenBatchNormalizationForward(
      se::DeviceMemory<float>(buffer_allocations.GetDeviceAddress(operand_)),
      se::DeviceMemory<float>(buffer_allocations.GetDeviceAddress(scale_)),
      se::DeviceMemory<float>(buffer_allocations.GetDeviceAddress(offset_)),
      se::DeviceMemory<float>(buffer_allocations.GetDeviceAddress(mean_)),
      se::DeviceMemory<float>(buffer_allocations.GetDeviceAddress(variance_)),
      operand_desc,                //
      scale_offset_desc,           //
      epsilon_,                    //
      &output,                     //
      /*batch_mean=*/nullptr,      //
      /*batch_var=*/nullptr,       //
      /*saved_mean=*/nullptr,      //
      /*saved_inv_var=*/nullptr,   //
      /*is_training=*/false,       //
      /*var_to_inv_var=*/nullptr,  //
      /*inv_var_to_var=*/nullptr);

  if (!stream->ok()) {
    return InternalError("BatchNormalizationForward call failed.");
  }
  return Status::OK();
}

CudnnBatchNormForwardTrainingThunk::CudnnBatchNormForwardTrainingThunk(
    const BufferAllocation::Slice& operand,
    const BufferAllocation::Slice& scale, const BufferAllocation::Slice& offset,
    float epsilon, int64 feature_index,
    const BufferAllocation::Slice& output_data,
    const BufferAllocation::Slice& output_mean,
    const BufferAllocation::Slice& output_inv_stddev,
    const BufferAllocation::Slice& output_tuple, const HloInstruction* hlo)
    : Thunk(Thunk::Kind::kCudnnBatchNormForwardTraining, hlo),
      operand_(operand),
      scale_(scale),
      offset_(offset),
      epsilon_(epsilon),
      feature_index_(feature_index),
      output_data_(output_data),
      output_mean_(output_mean),
      output_inv_stddev_(output_inv_stddev),
      output_tuple_(output_tuple) {
  CHECK_EQ(hlo->opcode(), HloOpcode::kCustomCall);
  CHECK_EQ(hlo->custom_call_target(), kCudnnBatchNormForwardTrainingCallTarget);
  CHECK_EQ(hlo->shape().tuple_shapes_size(), 3);
  CHECK(LayoutUtil::LayoutsInShapesEqual(hlo->shape().tuple_shapes(0),
                                         hlo->operand(0)->shape()));
  for (const auto& tuple_shape : hlo->shape().tuple_shapes()) {
    CHECK_EQ(tuple_shape.element_type(), F32) << "Not yet implemented";
  }
}

Status CudnnBatchNormForwardTrainingThunk::ExecuteOnStream(
    const BufferAllocations& buffer_allocations, se::Stream* stream,
    HloExecutionProfiler* profiler) {
  dnn::BatchDescriptor operand_desc;
  dnn::BatchDescriptor scale_offset_desc;
  // The BatchNormTraining HLO outputs a tuple of three elements: output data,
  // batch mean, and batch variance. We want to make our descriptors based on
  // the shape of the output data.
  std::tie(operand_desc, scale_offset_desc) = MakeDescriptors(
      hlo_instruction()->shape().tuple_shapes(0), feature_index_);

  se::DeviceMemory<float> output_data(
      buffer_allocations.GetDeviceAddress(output_data_));
  se::DeviceMemory<float> output_mean(
      buffer_allocations.GetDeviceAddress(output_mean_));
  se::DeviceMemory<float> output_inv_stddev(
      buffer_allocations.GetDeviceAddress(output_inv_stddev_));

  se::DeviceMemory<float> null_device_ptr(nullptr);
  auto op_profiler = profiler->MakeScopedInstructionProfiler(hlo_instruction());
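  // In training mode cudnn computes the batch statistics itself, so no
  // estimated mean/variance is supplied. The batch mean and inverse stddev
  // that XLA needs come back through the saved_mean/saved_inv_var outputs;
  // the batch_mean/batch_var outputs are not used and point at a null buffer.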
  stream->ThenBatchNormalizationForward(
      se::DeviceMemory<float>(buffer_allocations.GetDeviceAddress(operand_)),
      se::DeviceMemory<float>(buffer_allocations.GetDeviceAddress(scale_)),
      se::DeviceMemory<float>(buffer_allocations.GetDeviceAddress(offset_)),
      /*estimated_mean=*/null_device_ptr,
      /*estimated_variance=*/null_device_ptr,
      operand_desc,                          //
      scale_offset_desc,                     //
      epsilon_,                              //
      &output_data,                          //
      /*batch_mean=*/&null_device_ptr,       //
      /*batch_var=*/&null_device_ptr,        //
      /*saved_mean=*/&output_mean,           //
      /*saved_inv_var=*/&output_inv_stddev,  //
      /*is_training=*/true,                  //
      /*var_to_inv_var=*/nullptr,            //
      /*inv_var_to_var=*/nullptr);

  // Write the tuple.
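  // On the GPU backend a tuple is materialized as an array of pointers to its
  // element buffers, so copy the three output pointers into the tuple buffer.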
  void* ptrs[] = {output_data.opaque(), output_mean.opaque(),
                  output_inv_stddev.opaque()};
  se::DeviceMemory<void*> tuple_addr(
      buffer_allocations.GetDeviceAddress(output_tuple_));
  stream->ThenMemcpyH2D<void*>(ptrs, &tuple_addr);

  if (!stream->ok()) {
    return InternalError("BatchNormalizationTraining call failed.");
  }
  return Status::OK();
}

CudnnBatchNormBackwardThunk::CudnnBatchNormBackwardThunk(
    const BufferAllocation::Slice& operand,
    const BufferAllocation::Slice& scale, const BufferAllocation::Slice& mean,
    const BufferAllocation::Slice& inv_stddev,
    const BufferAllocation::Slice& grad_output, float epsilon,
    int64 feature_index, const BufferAllocation::Slice& output_grad_data,
    const BufferAllocation::Slice& output_grad_scale,
    const BufferAllocation::Slice& output_grad_offset,
    const BufferAllocation::Slice& output_tuple, const HloInstruction* hlo)
    : Thunk(Thunk::Kind::kCudnnBatchNormBackward, hlo),
      operand_(operand),
      scale_(scale),
      mean_(mean),
      inv_stddev_(inv_stddev),
      grad_output_(grad_output),
      epsilon_(epsilon),
      feature_index_(feature_index),
      output_grad_data_(output_grad_data),
      output_grad_scale_(output_grad_scale),
      output_grad_offset_(output_grad_offset),
      output_tuple_(output_tuple) {
  CHECK_EQ(hlo->opcode(), HloOpcode::kCustomCall);
  CHECK_EQ(hlo->custom_call_target(), kCudnnBatchNormBackwardCallTarget);
  CHECK_EQ(hlo->shape().tuple_shapes_size(), 3);
  CHECK(LayoutUtil::LayoutsInShapesEqual(hlo->shape().tuple_shapes(0),
                                         hlo->operand(0)->shape()));
  CHECK(LayoutUtil::LayoutsInShapesEqual(hlo->shape().tuple_shapes(0),
                                         hlo->operand(4)->shape()));
  for (const auto& tuple_shape : hlo->shape().tuple_shapes()) {
    CHECK_EQ(tuple_shape.element_type(), F32) << "Not yet implemented";
  }
}

Status CudnnBatchNormBackwardThunk::ExecuteOnStream(
    const BufferAllocations& buffer_allocations, se::Stream* stream,
    HloExecutionProfiler* profiler) {
  dnn::BatchDescriptor operand_desc;
  dnn::BatchDescriptor scale_offset_desc;

  // This call outputs a tuple of three elements: grad data, grad offset, and
  // grad scale. We want to make our descriptors based on the shape of the grad
  // data.
  std::tie(operand_desc, scale_offset_desc) = MakeDescriptors(
      hlo_instruction()->shape().tuple_shapes(0), feature_index_);

  se::DeviceMemory<float> output_grad_data(
      buffer_allocations.GetDeviceAddress(output_grad_data_));
  se::DeviceMemory<float> output_grad_scale(
      buffer_allocations.GetDeviceAddress(output_grad_scale_));
  se::DeviceMemory<float> output_grad_offset(
      buffer_allocations.GetDeviceAddress(output_grad_offset_));

  auto op_profiler = profiler->MakeScopedInstructionProfiler(hlo_instruction());
  stream->ThenBatchNormalizationBackward(
      se::DeviceMemory<float>(
          buffer_allocations.GetDeviceAddress(grad_output_)),
      se::DeviceMemory<float>(buffer_allocations.GetDeviceAddress(operand_)),
      se::DeviceMemory<float>(buffer_allocations.GetDeviceAddress(scale_)),
      se::DeviceMemory<float>(buffer_allocations.GetDeviceAddress(mean_)),
      se::DeviceMemory<float>(buffer_allocations.GetDeviceAddress(inv_stddev_)),
      operand_desc, scale_offset_desc, epsilon_, &output_grad_data,
      &output_grad_scale, &output_grad_offset);

  // Write the output tuple.
  void* ptrs[] = {output_grad_data.opaque(), output_grad_scale.opaque(),
                  output_grad_offset.opaque()};
  se::DeviceMemory<void*> tuple_addr(
      buffer_allocations.GetDeviceAddress(output_tuple_));
  stream->ThenMemcpyH2D<void*>(ptrs, &tuple_addr);

  if (!stream->ok()) {
    return InternalError("BatchNormalizationBackward call failed.");
  }
  return Status::OK();
}

}  // namespace gpu
}  // namespace xla