1 /* Copyright 2017 The TensorFlow Authors. All Rights Reserved.
2 
3 Licensed under the Apache License, Version 2.0 (the "License");
4 you may not use this file except in compliance with the License.
5 You may obtain a copy of the License at
6 
7     http://www.apache.org/licenses/LICENSE-2.0
8 
9 Unless required by applicable law or agreed to in writing, software
10 distributed under the License is distributed on an "AS IS" BASIS,
11 WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
12 See the License for the specific language governing permissions and
13 limitations under the License.
14 ==============================================================================*/
15 
#include "tensorflow/compiler/xla/service/gpu/gemm_thunk.h"

#include <functional>
#include <numeric>

#include "absl/strings/str_cat.h"
#include "tensorflow/compiler/xla/service/gpu/ir_emission_utils.h"
#include "tensorflow/compiler/xla/util.h"
#include "tensorflow/core/platform/logging.h"
#include "tensorflow/core/platform/stream_executor_no_cuda.h"
#include "tensorflow/core/platform/types.h"
26 
27 namespace xla {
28 namespace gpu {
29 
30 namespace {
31 
32 // This struct contains the metadata of a matrix, e.g., its base address and
33 // dimensions.
34 struct MatrixDescriptor {
MatrixDescriptorxla::gpu::__anon934738310111::MatrixDescriptor35   MatrixDescriptor(se::DeviceMemoryBase matrix_data, bool needs_transpose,
36                    int64 matrix_num_rows, int64 matrix_num_cols,
37                    int64 matrix_batch_size)
38       : data(matrix_data),
39         transpose(needs_transpose),
40         num_rows(matrix_num_rows),
41         num_cols(matrix_num_cols),
42         batch_size(matrix_batch_size) {}
43 
44   se::DeviceMemoryBase data;
45   bool transpose;  // Whether this matrix needs to be transposed.
46   int64 num_rows;
47   int64 num_cols;
48   int64 batch_size;
49 };
50 
51 // Performs a gemm call without an explicit algorithm on lhs_matrix and
52 // rhs_matrix, and stores the result to output_matrix.
53 template <typename Element>
DoGemm(MatrixDescriptor lhs_matrix,MatrixDescriptor rhs_matrix,MatrixDescriptor output_matrix,double alpha,double beta,se::Stream * stream)54 bool DoGemm(MatrixDescriptor lhs_matrix, MatrixDescriptor rhs_matrix,
55             MatrixDescriptor output_matrix, double alpha, double beta,
56             se::Stream* stream) {
57   DCHECK(!output_matrix.transpose);
58 
59   const int64 batch_size = lhs_matrix.batch_size;
60   CHECK_EQ(batch_size, rhs_matrix.batch_size);
61   CHECK_EQ(batch_size, output_matrix.batch_size);
62   se::DeviceMemory<Element> lhs_data(lhs_matrix.data);
63   se::DeviceMemory<Element> rhs_data(rhs_matrix.data);
64   se::DeviceMemory<Element> output_data(output_matrix.data);
65 
66   auto lhs_transpose = lhs_matrix.transpose ? se::blas::Transpose::kTranspose
67                                             : se::blas::Transpose::kNoTranspose;
68   auto rhs_transpose = rhs_matrix.transpose ? se::blas::Transpose::kTranspose
69                                             : se::blas::Transpose::kNoTranspose;
70   auto k = lhs_matrix.transpose ? lhs_matrix.num_rows : lhs_matrix.num_cols;
71 
72   if (batch_size == 1) {
73     return stream
74         ->ThenBlasGemm(
75             lhs_transpose, rhs_transpose, output_matrix.num_rows,
76             output_matrix.num_cols, /*size of reduce dim=*/k, /*alpha=*/alpha,
77             lhs_data, /*leading dim of LHS=*/lhs_matrix.num_rows, rhs_data,
78             /*leading dim of RHS=*/rhs_matrix.num_rows, /*beta=*/beta,
79             &output_data, /*leading dim of output=*/output_matrix.num_rows)
80         .ok();
81   }
82 
83   int64 lhs_stride = lhs_matrix.num_rows * lhs_matrix.num_cols;
84   int64 rhs_stride = rhs_matrix.num_rows * rhs_matrix.num_cols;
85   int64 output_stride = output_matrix.num_rows * output_matrix.num_cols;
86   return stream
87       ->ThenBlasGemmStridedBatched(
88           lhs_transpose, rhs_transpose, output_matrix.num_rows,
89           output_matrix.num_cols, /*size of reduce dim=*/k,
90           /*alpha=*/alpha, lhs_data,
91           /*leading dim of LHS=*/lhs_matrix.num_rows, lhs_stride, rhs_data,
92           /*leading dim of RHS=*/rhs_matrix.num_rows, rhs_stride,
93           /*beta=*/beta, &output_data,
94           /*leading dim of output=*/output_matrix.num_rows, output_stride,
95           batch_size)
96       .ok();
97 }
98 
99 // Like DoGemm, but takes an explicit computation type and algorithm.
100 // computation_type specifies the type of intermediate values generated during
101 // the matmul (e.g. your input/output matricies could be f16s but you could do
102 // computations with f32s).  algorithm is an opaque identifier which functions
103 // as a hint to cublas.
104 //
105 // Not all algorithms are valid for all matrix sizes, and not all CUDA versions
106 // and GPUs even support gemm-with-algorithm.  So expect that this may fail
107 // unless you've already checked that it works for this particular GPU + input
108 // size.
109 //
110 // If you pass a non-null ProfileResult, this will always return true (assuming
111 // the Stream was valid to begin with); check the is_valid property of the
112 // ProfileResult to see whether the call actually succeeded.
113 template <typename Element>
DoGemmWithAlgorithm(MatrixDescriptor lhs_matrix,MatrixDescriptor rhs_matrix,MatrixDescriptor output_matrix,double alpha,double beta,se::blas::ComputationType computation_type,se::blas::AlgorithmType algorithm,se::Stream * stream,se::blas::ProfileResult * output_profile_result)114 bool DoGemmWithAlgorithm(MatrixDescriptor lhs_matrix,
115                          MatrixDescriptor rhs_matrix,
116                          MatrixDescriptor output_matrix, double alpha,
117                          double beta,
118                          se::blas::ComputationType computation_type,
119                          se::blas::AlgorithmType algorithm, se::Stream* stream,
120                          se::blas::ProfileResult* output_profile_result) {
121   DCHECK(!output_matrix.transpose);
122 
123   CHECK_EQ(1, lhs_matrix.batch_size);
124   CHECK_EQ(1, rhs_matrix.batch_size);
125   CHECK_EQ(1, output_matrix.batch_size);
126 
127   se::DeviceMemory<Element> lhs_data(lhs_matrix.data);
128   se::DeviceMemory<Element> rhs_data(rhs_matrix.data);
129   se::DeviceMemory<Element> output_data(output_matrix.data);
130 
131   auto lhs_transpose = lhs_matrix.transpose ? se::blas::Transpose::kTranspose
132                                             : se::blas::Transpose::kNoTranspose;
133   auto rhs_transpose = rhs_matrix.transpose ? se::blas::Transpose::kTranspose
134                                             : se::blas::Transpose::kNoTranspose;
135   auto k = lhs_matrix.transpose ? lhs_matrix.num_rows : lhs_matrix.num_cols;
136 
137   return stream
138       ->ThenBlasGemmWithAlgorithm(
139           lhs_transpose, rhs_transpose, output_matrix.num_rows,
140           output_matrix.num_cols, /*size of reduce dim=*/k,
141           /*alpha=*/static_cast<Element>(alpha), lhs_data,
142           /*leading dim of LHS=*/lhs_matrix.num_rows, rhs_data,
143           /*leading dim of RHS=*/rhs_matrix.num_rows,
144           /*beta=*/static_cast<Element>(beta), &output_data,
145           /*leading dim of output=*/output_matrix.num_rows, computation_type,
146           algorithm, output_profile_result)
147       .ok();
148 }
149 
150 // Experimentally tries to pick the best algorithm for the given gemm.
151 //
152 // This may fail under perfectly normal circumstances.  In particular, it will
153 // fail if the program was built with < CUDA 8 or if we're using a gpu older
154 // than sm_50 -- in both cases, cublas doesn't support gemm-with-algorithm at
155 // all.
156 template <typename Element>
DoGemmAutotune(MatrixDescriptor lhs_matrix,MatrixDescriptor rhs_matrix,MatrixDescriptor output_matrix,double alpha,double beta,se::blas::ComputationType computation_type,se::Stream * stream)157 StatusOr<se::blas::AlgorithmType> DoGemmAutotune(
158     MatrixDescriptor lhs_matrix, MatrixDescriptor rhs_matrix,
159     MatrixDescriptor output_matrix, double alpha, double beta,
160     se::blas::ComputationType computation_type, se::Stream* stream) {
161   std::vector<se::blas::AlgorithmType> algorithms;
162   CHECK(stream->parent()->GetBlasGemmAlgorithms(&algorithms));
163 
164   se::blas::ProfileResult best_result;
165   for (auto algorithm : algorithms) {
166     se::blas::ProfileResult profile_result;
167     // We expect GemmWithAlgorithm to fail sometimes -- in fact, it will fail
168     // for all algorithms if we're targeting < sm_50.  But because we pass a
169     // non-null ProfileResult, DoGemmWithAlgorithm should always return true,
170     // and the actual success-ness is returned in ProfileResult::is_valid.
171     CHECK(DoGemmWithAlgorithm<Element>(lhs_matrix, rhs_matrix, output_matrix,
172                                        alpha, beta, computation_type, algorithm,
173                                        stream, &profile_result));
174 
175     if (profile_result.is_valid()) {
176       VLOG(3) << "cublas gemm algorithm " << algorithm << " took "
177               << profile_result.elapsed_time_in_ms() << "ms";
178       if (profile_result.elapsed_time_in_ms() <
179           best_result.elapsed_time_in_ms()) {
180         best_result = profile_result;
181       }
182     } else {
183       VLOG(4) << "cublas gemm algorithm " << algorithm << " failed.";
184     }
185   }
186 
187   if (best_result.is_valid()) {
188     return best_result.algorithm();
189   }
190 
191   return InternalError(
192       "Unable to autotune cuBLAS gemm on stream %p; none of the %u algorithms "
193       "ran successfully",
194       stream, algorithms.size());
195 }
196 
197 // Helper functions to go from a PrimitiveType to a templated version of
198 // DoGemm/DoGemmWithAlgorithm/DoGemmAutotune.
GetGemmFn(PrimitiveType type)199 auto GetGemmFn(PrimitiveType type) -> decltype(&DoGemm<float>) {
200   switch (type) {
201     case F16:
202       return &DoGemm<Eigen::half>;
203     case F32:
204       return &DoGemm<float>;
205     case F64:
206       return &DoGemm<double>;
207     case C64:
208       return &DoGemm<std::complex<float>>;
209     case C128:
210       return &DoGemm<std::complex<double>>;
211     default:
212       LOG(FATAL) << "Unsupported type.";
213   }
214 }
GetGemmWithAlgorithmFn(PrimitiveType type)215 auto GetGemmWithAlgorithmFn(PrimitiveType type)
216     -> decltype(&DoGemmWithAlgorithm<float>) {
217   switch (type) {
218     case F16:
219       return &DoGemmWithAlgorithm<Eigen::half>;
220     case F32:
221       return &DoGemmWithAlgorithm<float>;
222     case F64:
223       return &DoGemmWithAlgorithm<double>;
224     case C64:
225       return &DoGemmWithAlgorithm<std::complex<float>>;
226     case C128:
227       return &DoGemmWithAlgorithm<std::complex<double>>;
228     default:
229       LOG(FATAL) << "Unsupported type.";
230   }
231 }
GetGemmAutotuneFn(PrimitiveType type)232 auto GetGemmAutotuneFn(PrimitiveType type) -> decltype(&DoGemmAutotune<float>) {
233   switch (type) {
234     case F16:
235       return &DoGemmAutotune<Eigen::half>;
236     case F32:
237       return &DoGemmAutotune<float>;
238     case F64:
239       return &DoGemmAutotune<double>;
240     case C64:
241       return &DoGemmAutotune<std::complex<float>>;
242     case C128:
243       return &DoGemmAutotune<std::complex<double>>;
244     default:
245       LOG(FATAL) << "Unsupported type.";
246   }
247 }
248 
249 // Converts from an XLA PrimitiveType to a blas::ComputationType, which is used
250 // to specify the precision with which matmul computations should be performed,
251 // separately from the precision of the inputs and result.
GetBlasComputationType(PrimitiveType type)252 se::blas::ComputationType GetBlasComputationType(PrimitiveType type) {
253   switch (type) {
254     case F16:
255       // Use F32 as computation type for F16 as we currently only implement the
256       // cuDNN pseudo half configuration for half precision.
257       return se::blas::ComputationType::kF32;
258     case F32:
259       return se::blas::ComputationType::kF32;
260     case F64:
261       return se::blas::ComputationType::kF64;
262     case C64:
263       return se::blas::ComputationType::kComplexF32;
264     case C128:
265       return se::blas::ComputationType::kComplexF64;
266     default:
267       LOG(FATAL) << "Unsupported type.";
268   }
269 }
270 
GetDimensionNumbers(const HloInstruction & hlo_instruction)271 DotDimensionNumbers GetDimensionNumbers(const HloInstruction& hlo_instruction) {
272   if (hlo_instruction.opcode() == HloOpcode::kDot) {
273     return hlo_instruction.dot_dimension_numbers();
274   }
275   CHECK_EQ(hlo_instruction.opcode(), HloOpcode::kFusion);
276   CHECK_EQ(hlo_instruction.fusion_kind(), HloInstruction::FusionKind::kOutput);
277   CHECK(hlo_instruction.fused_expression_root()->opcode() == HloOpcode::kAdd ||
278         hlo_instruction.fused_expression_root()->opcode() ==
279             HloOpcode::kMultiply);
280   // Try to find the dot inside the output fusion node.
281   const HloInstruction* dot =
282       hlo_instruction.fused_expression_root()->operand(0);
283   if (dot->opcode() != HloOpcode::kDot) {
284     dot = hlo_instruction.fused_expression_root()->operand(1);
285   }
286   CHECK_EQ(dot->opcode(), HloOpcode::kDot);
287 
288   return dot->dot_dimension_numbers();
289 }
290 
291 }  // namespace
292 
GemmThunk(const BufferAllocation::Slice & lhs_buffer,const BufferAllocation::Slice & rhs_buffer,const BufferAllocation::Slice & output_buffer,const Shape & lhs_shape,const Shape & rhs_shape,const Shape & output_shape,double alpha,double beta,const HloInstruction * hlo_instruction,bool implements_whole_instruction)293 GemmThunk::GemmThunk(const BufferAllocation::Slice& lhs_buffer,
294                      const BufferAllocation::Slice& rhs_buffer,
295                      const BufferAllocation::Slice& output_buffer,
296                      const Shape& lhs_shape, const Shape& rhs_shape,
297                      const Shape& output_shape, double alpha, double beta,
298                      const HloInstruction* hlo_instruction,
299                      bool implements_whole_instruction)
300     : Thunk(Kind::kGemm, hlo_instruction),
301       lhs_buffer_(lhs_buffer),
302       rhs_buffer_(rhs_buffer),
303       output_buffer_(output_buffer),
304       lhs_shape_(lhs_shape),
305       rhs_shape_(rhs_shape),
306       output_shape_(output_shape),
307       alpha_(alpha),
308       beta_(beta),
309       implements_whole_instruction_(implements_whole_instruction) {}
310 
ExecuteOnStream(const BufferAllocations & buffer_allocations,se::Stream * stream,HloExecutionProfiler * profiler)311 Status GemmThunk::ExecuteOnStream(const BufferAllocations& buffer_allocations,
312                                   se::Stream* stream,
313                                   HloExecutionProfiler* profiler) {
314   VLOG(2) << "Executing a GemmThunk";
315 
316   se::DeviceMemoryBase lhs_data =
317       buffer_allocations.GetDeviceAddress(lhs_buffer_);
318   se::DeviceMemoryBase rhs_data =
319       buffer_allocations.GetDeviceAddress(rhs_buffer_);
320   se::DeviceMemoryBase output_data =
321       buffer_allocations.GetDeviceAddress(output_buffer_);
322 
323   DotDimensionNumbers dim_nums = GetDimensionNumbers(*hlo_instruction());
324   CHECK_EQ(dim_nums.lhs_batch_dimensions_size(),
325            dim_nums.rhs_batch_dimensions_size());
326   CHECK_EQ(dim_nums.lhs_batch_dimensions_size() + 2, output_shape_.rank());
327 
328   int64 row_dim = dim_nums.lhs_batch_dimensions_size();
329   int64 col_dim = dim_nums.lhs_batch_dimensions_size() + 1;
330   int64 batch_size = std::accumulate(output_shape_.dimensions().begin(),
331                                      output_shape_.dimensions().end() - 2, 1,
332                                      std::multiplies<int64>());
333 
334   // Check that the batch dims don't cover the last two dims.
335   for (int64 batch_dim : dim_nums.lhs_batch_dimensions()) {
336     CHECK_NE(row_dim, batch_dim);
337     CHECK_NE(col_dim, batch_dim);
338   }
339 
340   // Verify that the non-batch dimensions are minor-most. This is required for
341   // efficient access.
342   for (const auto* shape : {&lhs_shape_, &rhs_shape_, &output_shape_}) {
343     CHECK_LT(shape->layout().minor_to_major(row_dim), 2);
344     CHECK_LT(shape->layout().minor_to_major(col_dim), 2);
345   }
346 
347   // BLAS gemm reduces rows of LHS and columns of RHS. The Dot operator between
348   // matrices reduces dimension 1 of LHS and dimension 0 of RHS regardless of
349   // their layout. Therefore, we should treat dimension 0 as row and dimension 1
350   // as column when mapping a matrix Dot to BLAS gemm.
351   int64 output_num_rows = output_shape_.dimensions(row_dim);
352   int64 output_num_cols = output_shape_.dimensions(col_dim);
353 
354   // BLAS gemm expects the inputs and the output are in column-major order.
355   // Therefore, we need to convert dot between row-major matrices to that
356   // between column-major matrices. The key insight for the conversion is that,
357   // in linear storage, matrix M in column-major order is identical to the
358   // transpose of M in row-major order. In other words,
359   //
360   //   column-major(M) = row-major(M^T).
361   //
362   // Leveraging this insight, we can perform dot between row-major matrices as
363   // follows.
364   //
365   // row-major(C)
366   //   = row-major(A x B) = column-major((A x B)^T) = column-major(B^T x A^T)
367   //   = gemm(column-major(B^T), column-major(A^T))
368   //   = gemm(row-major(B), row-major(A))
369   //
370   // Although we do not modify the content of A and B in linear memory, we
371   // should use the dimensions of B^T and A^T when calling gemm. For example,
372   // the leading dimension of the LHS matrix of gemm is the number of rows in
373   // B^T and thus the number of columns in B.
374 
375   auto make_descriptor = [&](se::DeviceMemoryBase data, const Shape& shape,
376                              bool transpose) -> MatrixDescriptor {
377     bool is_row_major = LayoutUtil::Minor(shape.layout(), row_dim) != 0;
378     bool layout_mismatch = LayoutUtil::Minor(shape.layout(), row_dim) !=
379                            LayoutUtil::Minor(output_shape_.layout(), row_dim);
380     return MatrixDescriptor(
381         data, transpose ^ layout_mismatch,
382         shape.dimensions(row_dim + static_cast<int64>(is_row_major)),
383         shape.dimensions(row_dim + static_cast<int64>(!is_row_major)),
384         batch_size);
385   };
386 
387   const MatrixDescriptor lhs_descriptor = make_descriptor(
388       lhs_data, lhs_shape_, dim_nums.lhs_contracting_dimensions(0) == row_dim);
389   const MatrixDescriptor rhs_descriptor = make_descriptor(
390       rhs_data, rhs_shape_, dim_nums.rhs_contracting_dimensions(0) == col_dim);
391 
392   // Dispatches to a regular cublas gemm, a gemm-with-algorithm, or attempts to
393   // autotune this gemm to figure out the best algorithm.
394   auto launch = [&](MatrixDescriptor lhs_matrix, MatrixDescriptor rhs_matrix,
395                     MatrixDescriptor output_matrix, se::Stream* stream) {
396     PrimitiveType element_type = output_shape_.element_type();
397     se::blas::ComputationType computation_type =
398         GetBlasComputationType(element_type);
399 
400     // TODO(b/112111608): Implement auto tune for batched gemm.
401     if (batch_size != 1) {
402       return GetGemmFn(element_type)(lhs_matrix, rhs_matrix, output_matrix,
403                                      alpha_, beta_, stream);
404     }
405 
406     auto thunk_name = [&] {
407       return hlo_instruction() != nullptr ? hlo_instruction()->ToString()
408                                           : "<null>";
409     };
410 
411     const string& device_name = stream->parent()->GetDeviceDescription().name();
412     auto autotune_it = autotune_results_.find(device_name);
413     if (autotune_it == autotune_results_.end()) {
414       VLOG(3) << "Starting autotune of GemmThunk " << thunk_name();
415 
416       // If the output buffer already contains a bias then autotune into a
417       // scratch buffer. This avoids overwriting the bias buffer. The scratch
418       // buffer may contain arbitrary garbage values.
419       se::DeviceMemoryBase scratch_data = output_data;
420       std::unique_ptr<se::TemporaryDeviceMemory<char>> scratch_mem;
421       if (beta_ != 0.0) {
422         auto temp_status = stream->AllocateTemporaryArray<char>(
423             ShapeUtil::ByteSizeOf(output_shape_));
424         if (!temp_status.ok()) {
425           return false;
426         }
427         scratch_mem = std::move(temp_status).ValueOrDie();
428         scratch_data = scratch_mem->device_memory();
429       }
430       const MatrixDescriptor scratch_descriptor(
431           scratch_data, false, output_matrix.num_rows, output_matrix.num_cols,
432           batch_size);
433 
434       StatusOr<se::blas::AlgorithmType> best_algorithm = GetGemmAutotuneFn(
435           element_type)(lhs_matrix, rhs_matrix, scratch_descriptor, alpha_,
436                         beta_, computation_type, stream);
437       autotune_it =
438           autotune_results_.insert({device_name, best_algorithm}).first;
439 
440       if (autotune_it->second.ok()) {
441         VLOG(2) << "Autotune on GemmThunk " << thunk_name()
442                 << " successful; best algorithm is "
443                 << best_algorithm.ValueOrDie();
444       } else {
445         VLOG(2) << "Autotune on GemmThunk " << thunk_name()
446                 << " unsuccessful.  Will use generic gemm.";
447       }
448     }
449 
450     const StatusOr<se::blas::AlgorithmType>& best_algorithm =
451         autotune_it->second;
452     if (best_algorithm.ok()) {
453       auto algorithm = best_algorithm.ValueOrDie();
454       VLOG(2) << "Using algorithm " << algorithm
455               << " chosen by autotuning on GemmThunk " << thunk_name();
456       return GetGemmWithAlgorithmFn(element_type)(
457           lhs_matrix, rhs_matrix, output_matrix, alpha_, beta_,
458           computation_type, algorithm, stream,
459           /*output_profile_result=*/nullptr);
460     }
461 
462     // Autotune will fail when CUDA 8 and GPU sm_50 or older are used.
463     // Use the older Gemm API in this case.
464     return GetGemmFn(element_type)(lhs_matrix, rhs_matrix, output_matrix,
465                                    alpha_, beta_, stream);
466   };
467 
468   auto op_profiler = profiler->MakeScopedInstructionProfiler(
469       implements_whole_instruction_ ? hlo_instruction() : nullptr);
470   bool launch_ok;
471   if (LayoutUtil::Minor(output_shape_.layout(), row_dim) == 0) {
472     launch_ok = launch(lhs_descriptor, rhs_descriptor,
473                        MatrixDescriptor(output_data, false, output_num_rows,
474                                         output_num_cols, batch_size),
475                        stream);
476   } else {
477     launch_ok = launch(rhs_descriptor, lhs_descriptor,
478                        MatrixDescriptor(output_data, false, output_num_cols,
479                                         output_num_rows, batch_size),
480                        stream);
481   }
482 
483   if (!launch_ok) {
484     return InternalError("Unable to launch cuBLAS gemm on stream %p", stream);
485   }
486   return Status::OK();
487 }
488 
489 }  // namespace gpu
490 }  // namespace xla
491