/* Copyright 2017 The TensorFlow Authors. All Rights Reserved.

Licensed under the Apache License, Version 2.0 (the "License");
you may not use this file except in compliance with the License.
You may obtain a copy of the License at

    http://www.apache.org/licenses/LICENSE-2.0

Unless required by applicable law or agreed to in writing, software
distributed under the License is distributed on an "AS IS" BASIS,
WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
See the License for the specific language governing permissions and
limitations under the License.
==============================================================================*/

#include "tensorflow/compiler/xla/service/gpu/gemm_thunk.h"

#include <functional>

#include "absl/strings/str_cat.h"
#include "tensorflow/compiler/xla/service/gpu/ir_emission_utils.h"
#include "tensorflow/compiler/xla/util.h"
#include "tensorflow/core/platform/logging.h"
#include "tensorflow/core/platform/stream_executor_no_cuda.h"
#include "tensorflow/core/platform/types.h"

namespace xla {
namespace gpu {

namespace {

// This struct contains the metadata of a matrix, e.g., its base address and
// dimensions.
struct MatrixDescriptor {
  MatrixDescriptor(se::DeviceMemoryBase matrix_data, bool needs_transpose,
                   int64 matrix_num_rows, int64 matrix_num_cols,
                   int64 matrix_batch_size)
      : data(matrix_data),
        transpose(needs_transpose),
        num_rows(matrix_num_rows),
        num_cols(matrix_num_cols),
        batch_size(matrix_batch_size) {}

  se::DeviceMemoryBase data;
  bool transpose;  // Whether this matrix needs to be transposed.
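  // Dimensions of the matrix as seen by the BLAS call, i.e. describing a
  // column-major view of the buffer. num_rows also serves as the leading
  // dimension in the gemm calls below, so the matrix is assumed to be
  // densely packed.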
  int64 num_rows;
  int64 num_cols;
  int64 batch_size;
};

// Performs a gemm call without an explicit algorithm on lhs_matrix and
// rhs_matrix, and stores the result to output_matrix.
template <typename Element>
bool DoGemm(MatrixDescriptor lhs_matrix, MatrixDescriptor rhs_matrix,
            MatrixDescriptor output_matrix, double alpha, double beta,
            se::Stream* stream) {
  DCHECK(!output_matrix.transpose);

  const int64 batch_size = lhs_matrix.batch_size;
  CHECK_EQ(batch_size, rhs_matrix.batch_size);
  CHECK_EQ(batch_size, output_matrix.batch_size);
  se::DeviceMemory<Element> lhs_data(lhs_matrix.data);
  se::DeviceMemory<Element> rhs_data(rhs_matrix.data);
  se::DeviceMemory<Element> output_data(output_matrix.data);

  auto lhs_transpose = lhs_matrix.transpose ? se::blas::Transpose::kTranspose
                                            : se::blas::Transpose::kNoTranspose;
  auto rhs_transpose = rhs_matrix.transpose ? se::blas::Transpose::kTranspose
                                            : se::blas::Transpose::kNoTranspose;
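  // k is the extent of the contracted dimension: if BLAS will transpose the
  // LHS, its stored rows are contracted; otherwise its stored columns are.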
  auto k = lhs_matrix.transpose ? lhs_matrix.num_rows : lhs_matrix.num_cols;

  if (batch_size == 1) {
    return stream
        ->ThenBlasGemm(
            lhs_transpose, rhs_transpose, output_matrix.num_rows,
            output_matrix.num_cols, /*size of reduce dim=*/k, /*alpha=*/alpha,
            lhs_data, /*leading dim of LHS=*/lhs_matrix.num_rows, rhs_data,
            /*leading dim of RHS=*/rhs_matrix.num_rows, /*beta=*/beta,
            &output_data, /*leading dim of output=*/output_matrix.num_rows)
        .ok();
  }

  int64 lhs_stride = lhs_matrix.num_rows * lhs_matrix.num_cols;
  int64 rhs_stride = rhs_matrix.num_rows * rhs_matrix.num_cols;
  int64 output_stride = output_matrix.num_rows * output_matrix.num_cols;
  return stream
      ->ThenBlasGemmStridedBatched(
          lhs_transpose, rhs_transpose, output_matrix.num_rows,
          output_matrix.num_cols, /*size of reduce dim=*/k,
          /*alpha=*/alpha, lhs_data,
          /*leading dim of LHS=*/lhs_matrix.num_rows, lhs_stride, rhs_data,
          /*leading dim of RHS=*/rhs_matrix.num_rows, rhs_stride,
          /*beta=*/beta, &output_data,
          /*leading dim of output=*/output_matrix.num_rows, output_stride,
          batch_size)
      .ok();
}

// Like DoGemm, but takes an explicit computation type and algorithm.
// computation_type specifies the type of intermediate values generated during
// the matmul (e.g. your input/output matrices could be f16s but you could do
// computations with f32s). algorithm is an opaque identifier which functions
// as a hint to cublas.
//
// Not all algorithms are valid for all matrix sizes, and not all CUDA versions
// and GPUs even support gemm-with-algorithm. So expect that this may fail
// unless you've already checked that it works for this particular GPU + input
// size.
//
// If you pass a non-null ProfileResult, this will always return true (assuming
// the Stream was valid to begin with); check the is_valid property of the
// ProfileResult to see whether the call actually succeeded.
template <typename Element>
bool DoGemmWithAlgorithm(MatrixDescriptor lhs_matrix,
                         MatrixDescriptor rhs_matrix,
                         MatrixDescriptor output_matrix, double alpha,
                         double beta,
                         se::blas::ComputationType computation_type,
                         se::blas::AlgorithmType algorithm, se::Stream* stream,
                         se::blas::ProfileResult* output_profile_result) {
  DCHECK(!output_matrix.transpose);

  CHECK_EQ(1, lhs_matrix.batch_size);
  CHECK_EQ(1, rhs_matrix.batch_size);
  CHECK_EQ(1, output_matrix.batch_size);

  se::DeviceMemory<Element> lhs_data(lhs_matrix.data);
  se::DeviceMemory<Element> rhs_data(rhs_matrix.data);
  se::DeviceMemory<Element> output_data(output_matrix.data);

  auto lhs_transpose = lhs_matrix.transpose ? se::blas::Transpose::kTranspose
                                            : se::blas::Transpose::kNoTranspose;
  auto rhs_transpose = rhs_matrix.transpose ? se::blas::Transpose::kTranspose
                                            : se::blas::Transpose::kNoTranspose;
  auto k = lhs_matrix.transpose ? lhs_matrix.num_rows : lhs_matrix.num_cols;

  return stream
      ->ThenBlasGemmWithAlgorithm(
          lhs_transpose, rhs_transpose, output_matrix.num_rows,
          output_matrix.num_cols, /*size of reduce dim=*/k,
          /*alpha=*/static_cast<Element>(alpha), lhs_data,
          /*leading dim of LHS=*/lhs_matrix.num_rows, rhs_data,
          /*leading dim of RHS=*/rhs_matrix.num_rows,
          /*beta=*/static_cast<Element>(beta), &output_data,
          /*leading dim of output=*/output_matrix.num_rows, computation_type,
          algorithm, output_profile_result)
      .ok();
}

// Experimentally tries to pick the best algorithm for the given gemm.
//
// This may fail under perfectly normal circumstances. In particular, it will
// fail if the program was built with < CUDA 8 or if we're using a gpu older
// than sm_50 -- in both cases, cublas doesn't support gemm-with-algorithm at
// all.
template <typename Element>
StatusOr<se::blas::AlgorithmType> DoGemmAutotune(
    MatrixDescriptor lhs_matrix, MatrixDescriptor rhs_matrix,
    MatrixDescriptor output_matrix, double alpha, double beta,
    se::blas::ComputationType computation_type, se::Stream* stream) {
  std::vector<se::blas::AlgorithmType> algorithms;
  CHECK(stream->parent()->GetBlasGemmAlgorithms(&algorithms));

  se::blas::ProfileResult best_result;
  for (auto algorithm : algorithms) {
    se::blas::ProfileResult profile_result;
    // We expect GemmWithAlgorithm to fail sometimes -- in fact, it will fail
    // for all algorithms if we're targeting < sm_50. But because we pass a
    // non-null ProfileResult, DoGemmWithAlgorithm should always return true,
    // and the actual success-ness is returned in ProfileResult::is_valid.
    CHECK(DoGemmWithAlgorithm<Element>(lhs_matrix, rhs_matrix, output_matrix,
                                       alpha, beta, computation_type, algorithm,
                                       stream, &profile_result));

    if (profile_result.is_valid()) {
      VLOG(3) << "cublas gemm algorithm " << algorithm << " took "
              << profile_result.elapsed_time_in_ms() << "ms";
      if (profile_result.elapsed_time_in_ms() <
          best_result.elapsed_time_in_ms()) {
        best_result = profile_result;
      }
    } else {
      VLOG(4) << "cublas gemm algorithm " << algorithm << " failed.";
    }
  }

  if (best_result.is_valid()) {
    return best_result.algorithm();
  }

  return InternalError(
      "Unable to autotune cuBLAS gemm on stream %p; none of the %u algorithms "
      "ran successfully",
      stream, algorithms.size());
}

// Helper functions to go from a PrimitiveType to a templated version of
// DoGemm/DoGemmWithAlgorithm/DoGemmAutotune.
auto GetGemmFn(PrimitiveType type) -> decltype(&DoGemm<float>) {
  switch (type) {
    case F16:
      return &DoGemm<Eigen::half>;
    case F32:
      return &DoGemm<float>;
    case F64:
      return &DoGemm<double>;
    case C64:
      return &DoGemm<std::complex<float>>;
    case C128:
      return &DoGemm<std::complex<double>>;
    default:
      LOG(FATAL) << "Unsupported type.";
  }
}
auto GetGemmWithAlgorithmFn(PrimitiveType type)
    -> decltype(&DoGemmWithAlgorithm<float>) {
  switch (type) {
    case F16:
      return &DoGemmWithAlgorithm<Eigen::half>;
    case F32:
      return &DoGemmWithAlgorithm<float>;
    case F64:
      return &DoGemmWithAlgorithm<double>;
    case C64:
      return &DoGemmWithAlgorithm<std::complex<float>>;
    case C128:
      return &DoGemmWithAlgorithm<std::complex<double>>;
    default:
      LOG(FATAL) << "Unsupported type.";
  }
}
auto GetGemmAutotuneFn(PrimitiveType type) -> decltype(&DoGemmAutotune<float>) {
  switch (type) {
    case F16:
      return &DoGemmAutotune<Eigen::half>;
    case F32:
      return &DoGemmAutotune<float>;
    case F64:
      return &DoGemmAutotune<double>;
    case C64:
      return &DoGemmAutotune<std::complex<float>>;
    case C128:
      return &DoGemmAutotune<std::complex<double>>;
    default:
      LOG(FATAL) << "Unsupported type.";
  }
}

// Converts from an XLA PrimitiveType to a blas::ComputationType, which is used
// to specify the precision with which matmul computations should be performed,
// separately from the precision of the inputs and result.
se::blas::ComputationType GetBlasComputationType(PrimitiveType type) {
  switch (type) {
    case F16:
      // Use F32 as computation type for F16 as we currently only implement
      // the cuDNN pseudo half configuration for half precision.
      return se::blas::ComputationType::kF32;
    case F32:
      return se::blas::ComputationType::kF32;
    case F64:
      return se::blas::ComputationType::kF64;
    case C64:
      return se::blas::ComputationType::kComplexF32;
    case C128:
      return se::blas::ComputationType::kComplexF64;
    default:
      LOG(FATAL) << "Unsupported type.";
  }
}

DotDimensionNumbers GetDimensionNumbers(const HloInstruction& hlo_instruction) {
  if (hlo_instruction.opcode() == HloOpcode::kDot) {
    return hlo_instruction.dot_dimension_numbers();
  }
  CHECK_EQ(hlo_instruction.opcode(), HloOpcode::kFusion);
  CHECK_EQ(hlo_instruction.fusion_kind(), HloInstruction::FusionKind::kOutput);
  CHECK(hlo_instruction.fused_expression_root()->opcode() == HloOpcode::kAdd ||
        hlo_instruction.fused_expression_root()->opcode() ==
            HloOpcode::kMultiply);
  // Try to find the dot inside the output fusion node.
  const HloInstruction* dot =
      hlo_instruction.fused_expression_root()->operand(0);
  if (dot->opcode() != HloOpcode::kDot) {
    dot = hlo_instruction.fused_expression_root()->operand(1);
  }
  CHECK_EQ(dot->opcode(), HloOpcode::kDot);

  return dot->dot_dimension_numbers();
}

}  // namespace

GemmThunk::GemmThunk(const BufferAllocation::Slice& lhs_buffer,
                     const BufferAllocation::Slice& rhs_buffer,
                     const BufferAllocation::Slice& output_buffer,
                     const Shape& lhs_shape, const Shape& rhs_shape,
                     const Shape& output_shape, double alpha, double beta,
                     const HloInstruction* hlo_instruction,
                     bool implements_whole_instruction)
    : Thunk(Kind::kGemm, hlo_instruction),
      lhs_buffer_(lhs_buffer),
      rhs_buffer_(rhs_buffer),
      output_buffer_(output_buffer),
      lhs_shape_(lhs_shape),
      rhs_shape_(rhs_shape),
      output_shape_(output_shape),
      alpha_(alpha),
      beta_(beta),
      implements_whole_instruction_(implements_whole_instruction) {}

Status GemmThunk::ExecuteOnStream(const BufferAllocations& buffer_allocations,
                                  se::Stream* stream,
                                  HloExecutionProfiler* profiler) {
  VLOG(2) << "Executing a GemmThunk";

  se::DeviceMemoryBase lhs_data =
      buffer_allocations.GetDeviceAddress(lhs_buffer_);
  se::DeviceMemoryBase rhs_data =
      buffer_allocations.GetDeviceAddress(rhs_buffer_);
  se::DeviceMemoryBase output_data =
      buffer_allocations.GetDeviceAddress(output_buffer_);

  DotDimensionNumbers dim_nums = GetDimensionNumbers(*hlo_instruction());
  CHECK_EQ(dim_nums.lhs_batch_dimensions_size(),
           dim_nums.rhs_batch_dimensions_size());
  CHECK_EQ(dim_nums.lhs_batch_dimensions_size() + 2, output_shape_.rank());

  int64 row_dim = dim_nums.lhs_batch_dimensions_size();
  int64 col_dim = dim_nums.lhs_batch_dimensions_size() + 1;
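  // The batch size is the product of all output dimensions other than the
  // last two (the row and column dimensions).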
  int64 batch_size = std::accumulate(output_shape_.dimensions().begin(),
                                     output_shape_.dimensions().end() - 2, 1,
                                     std::multiplies<int64>());

  // Check that the batch dims don't cover the last two dims.
  for (int64 batch_dim : dim_nums.lhs_batch_dimensions()) {
    CHECK_NE(row_dim, batch_dim);
    CHECK_NE(col_dim, batch_dim);
  }

  // Verify that the non-batch dimensions are minor-most. This is required for
  // efficient access.
  for (const auto* shape : {&lhs_shape_, &rhs_shape_, &output_shape_}) {
    CHECK_LT(shape->layout().minor_to_major(row_dim), 2);
    CHECK_LT(shape->layout().minor_to_major(col_dim), 2);
  }

  // BLAS gemm reduces rows of LHS and columns of RHS. The Dot operator between
  // matrices reduces dimension 1 of LHS and dimension 0 of RHS regardless of
  // their layout. Therefore, we should treat dimension 0 as row and dimension
  // 1 as column when mapping a matrix Dot to BLAS gemm.
  int64 output_num_rows = output_shape_.dimensions(row_dim);
  int64 output_num_cols = output_shape_.dimensions(col_dim);

  // BLAS gemm expects the inputs and the output to be in column-major order.
  // Therefore, we need to convert dot between row-major matrices to that
  // between column-major matrices. The key insight for the conversion is that,
  // in linear storage, matrix M in column-major order is identical to the
  // transpose of M in row-major order. In other words,
  //
  //   column-major(M) = row-major(M^T).
  //
  // Leveraging this insight, we can perform dot between row-major matrices as
  // follows.
  //
  //   row-major(C)
  //     = row-major(A x B) = column-major((A x B)^T) = column-major(B^T x A^T)
  //     = gemm(column-major(B^T), column-major(A^T))
  //     = gemm(row-major(B), row-major(A))
  //
  // Although we do not modify the content of A and B in linear memory, we
  // should use the dimensions of B^T and A^T when calling gemm. For example,
  // the leading dimension of the LHS matrix of gemm is the number of rows in
  // B^T and thus the number of columns in B.
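  //
  // For example, if A is m x k and B is k x n (all row-major), the gemm call
  // uses B's buffer as its LHS, described as the n x k column-major matrix
  // B^T, and A's buffer as its RHS, described as the k x m column-major
  // matrix A^T. The result buffer then holds the n x m column-major matrix
  // C^T, which occupies the same linear storage as the m x n row-major C.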

  auto make_descriptor = [&](se::DeviceMemoryBase data, const Shape& shape,
                             bool transpose) -> MatrixDescriptor {
    bool is_row_major = LayoutUtil::Minor(shape.layout(), row_dim) != 0;
    bool layout_mismatch = LayoutUtil::Minor(shape.layout(), row_dim) !=
                           LayoutUtil::Minor(output_shape_.layout(), row_dim);
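    // A mismatch between this operand's layout and the output's layout is
    // folded into the descriptor as an extra logical transpose (hence the
    // XOR below). The row/column extents are likewise swapped for row-major
    // operands, so the descriptor always describes the column-major view of
    // the buffer.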
    return MatrixDescriptor(
        data, transpose ^ layout_mismatch,
        shape.dimensions(row_dim + static_cast<int64>(is_row_major)),
        shape.dimensions(row_dim + static_cast<int64>(!is_row_major)),
        batch_size);
  };

  const MatrixDescriptor lhs_descriptor = make_descriptor(
      lhs_data, lhs_shape_, dim_nums.lhs_contracting_dimensions(0) == row_dim);
  const MatrixDescriptor rhs_descriptor = make_descriptor(
      rhs_data, rhs_shape_, dim_nums.rhs_contracting_dimensions(0) == col_dim);

  // Dispatches to a regular cublas gemm, a gemm-with-algorithm, or attempts to
  // autotune this gemm to figure out the best algorithm.
  auto launch = [&](MatrixDescriptor lhs_matrix, MatrixDescriptor rhs_matrix,
                    MatrixDescriptor output_matrix, se::Stream* stream) {
    PrimitiveType element_type = output_shape_.element_type();
    se::blas::ComputationType computation_type =
        GetBlasComputationType(element_type);

    // TODO(b/112111608): Implement auto tune for batched gemm.
    if (batch_size != 1) {
      return GetGemmFn(element_type)(lhs_matrix, rhs_matrix, output_matrix,
                                     alpha_, beta_, stream);
    }

    auto thunk_name = [&] {
      return hlo_instruction() != nullptr ? hlo_instruction()->ToString()
                                          : "<null>";
    };

    const string& device_name =
        stream->parent()->GetDeviceDescription().name();
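    // Autotune results are cached in autotune_results_, keyed by device name,
    // so the (potentially expensive) search below runs at most once per
    // device for this thunk.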
    auto autotune_it = autotune_results_.find(device_name);
    if (autotune_it == autotune_results_.end()) {
      VLOG(3) << "Starting autotune of GemmThunk " << thunk_name();

      // If the output buffer already contains a bias then autotune into a
      // scratch buffer. This avoids overwriting the bias buffer. The scratch
      // buffer may contain arbitrary garbage values.
      se::DeviceMemoryBase scratch_data = output_data;
      std::unique_ptr<se::TemporaryDeviceMemory<char>> scratch_mem;
      if (beta_ != 0.0) {
        auto temp_status = stream->AllocateTemporaryArray<char>(
            ShapeUtil::ByteSizeOf(output_shape_));
        if (!temp_status.ok()) {
          return false;
        }
        scratch_mem = std::move(temp_status).ValueOrDie();
        scratch_data = scratch_mem->device_memory();
      }
      const MatrixDescriptor scratch_descriptor(
          scratch_data, false, output_matrix.num_rows, output_matrix.num_cols,
          batch_size);

      StatusOr<se::blas::AlgorithmType> best_algorithm = GetGemmAutotuneFn(
          element_type)(lhs_matrix, rhs_matrix, scratch_descriptor, alpha_,
                        beta_, computation_type, stream);
      autotune_it =
          autotune_results_.insert({device_name, best_algorithm}).first;

      if (autotune_it->second.ok()) {
        VLOG(2) << "Autotune on GemmThunk " << thunk_name()
                << " successful; best algorithm is "
                << best_algorithm.ValueOrDie();
      } else {
        VLOG(2) << "Autotune on GemmThunk " << thunk_name()
                << " unsuccessful. Will use generic gemm.";
      }
    }

    const StatusOr<se::blas::AlgorithmType>& best_algorithm =
        autotune_it->second;
    if (best_algorithm.ok()) {
      auto algorithm = best_algorithm.ValueOrDie();
      VLOG(2) << "Using algorithm " << algorithm
              << " chosen by autotuning on GemmThunk " << thunk_name();
      return GetGemmWithAlgorithmFn(element_type)(
          lhs_matrix, rhs_matrix, output_matrix, alpha_, beta_,
          computation_type, algorithm, stream,
          /*output_profile_result=*/nullptr);
    }

    // Autotune will fail when CUDA 8 and GPU sm_50 or older are used.
    // Use the older Gemm API in this case.
    return GetGemmFn(element_type)(lhs_matrix, rhs_matrix, output_matrix,
                                   alpha_, beta_, stream);
  };

  auto op_profiler = profiler->MakeScopedInstructionProfiler(
      implements_whole_instruction_ ? hlo_instruction() : nullptr);
  bool launch_ok;
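  // If the output is column-major, compute output = lhs * rhs directly.
  // Otherwise the output is row-major, so (per the column-major discussion
  // above) compute its transpose instead by swapping the lhs and rhs
  // operands and the row/column extents of the output.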
  if (LayoutUtil::Minor(output_shape_.layout(), row_dim) == 0) {
    launch_ok = launch(lhs_descriptor, rhs_descriptor,
                       MatrixDescriptor(output_data, false, output_num_rows,
                                        output_num_cols, batch_size),
                       stream);
  } else {
    launch_ok = launch(rhs_descriptor, lhs_descriptor,
                       MatrixDescriptor(output_data, false, output_num_cols,
                                        output_num_rows, batch_size),
                       stream);
  }

  if (!launch_ok) {
    return InternalError("Unable to launch cuBLAS gemm on stream %p", stream);
  }
  return Status::OK();
}

}  // namespace gpu
}  // namespace xla
491