/* Copyright 2017 The TensorFlow Authors. All Rights Reserved.

Licensed under the Apache License, Version 2.0 (the "License");
you may not use this file except in compliance with the License.
You may obtain a copy of the License at

    http://www.apache.org/licenses/LICENSE-2.0

Unless required by applicable law or agreed to in writing, software
distributed under the License is distributed on an "AS IS" BASIS,
WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
See the License for the specific language governing permissions and
limitations under the License.
==============================================================================*/

#include "tensorflow/compiler/xla/service/cpu/dot_op_emitter.h"

#include <memory>
#include <vector>

#include "absl/strings/str_cat.h"
#include "llvm/IR/BasicBlock.h"
#include "llvm/IR/Instructions.h"
#include "llvm/IR/Module.h"
#include "llvm/IR/Value.h"
#include "mlir/Dialect/Linalg/EDSC/Intrinsics.h"  // from @llvm-project
#include "mlir/Dialect/Linalg/Transforms/CodegenStrategy.h"  // from @llvm-project
#include "mlir/Dialect/StandardOps/EDSC/Intrinsics.h"  // from @llvm-project
#include "mlir/EDSC/Builders.h"  // from @llvm-project
#include "mlir/IR/Builders.h"  // from @llvm-project
#include "mlir/IR/BuiltinOps.h"  // from @llvm-project
#include "mlir/IR/MLIRContext.h"  // from @llvm-project
#include "mlir/IR/OperationSupport.h"  // from @llvm-project
#include "mlir/IR/Value.h"  // from @llvm-project
#include "tensorflow/compiler/xla/primitive_util.h"
#include "tensorflow/compiler/xla/service/cpu/cpu_options.h"
#include "tensorflow/compiler/xla/service/cpu/cpu_runtime.h"
#include "tensorflow/compiler/xla/service/cpu/ir_emission_utils.h"
#include "tensorflow/compiler/xla/service/cpu/mlir_emitter.h"
#include "tensorflow/compiler/xla/service/cpu/target_machine_features.h"
#include "tensorflow/compiler/xla/service/cpu/tiled_dot_emitter.h"
#include "tensorflow/compiler/xla/service/cpu/vector_support_library.h"
#include "tensorflow/compiler/xla/service/hlo_casting_utils.h"
#include "tensorflow/compiler/xla/service/hlo_instructions.h"
#include "tensorflow/compiler/xla/service/hlo_module.h"
#include "tensorflow/compiler/xla/service/llvm_ir/kernel_support_library.h"
#include "tensorflow/compiler/xla/service/llvm_ir/llvm_util.h"
#include "tensorflow/compiler/xla/shape_util.h"
#include "tensorflow/compiler/xla/status_macros.h"
#include "tensorflow/compiler/xla/util.h"
#include "tensorflow/compiler/xla/xla_data.pb.h"
#include "tensorflow/core/platform/logging.h"

namespace xla {

using llvm_ir::SetToFirstInsertPoint;

namespace cpu {
namespace {
// Returns true if we should call into multi-threaded Eigen routines.
bool ShouldUseMultiThreadedEigen(const HloModuleConfig& config) {
  return config.debug_options().xla_cpu_multi_thread_eigen();
}

// Represents a dot operation.  We use this in lieu of an `HloInstruction`
// because we want to be able to create this for the "inner" dot operation in a
// batch dot, for which there is no separate HLO instruction.
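//
// For example (illustrative): for a dot computing f32[16,8] from f32[16,4]
// and f32[4,8] with default dimension numbers, lhs_shape = f32[16,4],
// rhs_shape = f32[4,8], result_shape = f32[16,8], and dim_nums has
// lhs_contracting_dimensions = {1} and rhs_contracting_dimensions = {0}.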
struct DotInfo {
  Shape lhs_shape;
  Shape rhs_shape;
  Shape result_shape;
  DotDimensionNumbers dim_nums;

  DotInfo() = default;

  explicit DotInfo(const HloInstruction& instr) {
    CHECK_EQ(instr.opcode(), HloOpcode::kDot);
    lhs_shape = instr.operand(0)->shape();
    rhs_shape = instr.operand(1)->shape();
    result_shape = instr.shape();
    dim_nums = instr.dot_dimension_numbers();
  }
};

// Dictates how a dot operation is implemented.
enum class DotImplementationStrategy {
  // The dot operation is lowered into LLVM IR that implements a naive nested
  // loop that computes the result one element at a time.  This is our
  // "fallback"; we don't really want this to kick in for any non-trivial dot
  // operation.
  kNaiveLlvmIr,

  // The dot operation is lowered into LLVM IR that implements a tiled
  // Matrix*Vector operation.  This strategy also allows fusing in a bias add
  // into the dot.  The matrix can be row major or column major; both are
  // supported.
  kTiledLlvmIrGemv,

  // The dot operation is lowered into LLVM IR that implements a tiled
  // Matrix*Matrix operation.  No fusions are supported.  The two inputs
  // and the output have to be row major.
  kTiledLlvmIrGemm,

  // The dot operation is lowered into a linalg.matmul op and then to LLVM IR.
  kLinalgMatmul,

  // The dot operation is lowered into a call into an Eigen routine.  No fusions
  // are supported today.  The two inputs and the output have to be row major.
  // However, we do allow transposing either the LHS or the RHS as part of the
  // GEMM -- we expose this flexibility as flexibility in the contraction
  // dimensions, but we can also see this as flexibility in the input layouts.
  kEigen,
};

// Returns the implementation strategy for a dot with the configuration
// `dot_info`.
DotImplementationStrategy GetDotImplementationStrategy(
    const HloModuleConfig& config, const DotInfo& dot_info,
    const TargetMachineFeatures& target_machine_features);

// Helper class for emitting LLVM IR to perform the dot operation.
class DotOpEmitter {
 public:
  explicit DotOpEmitter(DotInfo dot_info, string dot_hlo_name,
                        const llvm_ir::IrArray& target_array,
                        const llvm_ir::IrArray& lhs_array,
                        const llvm_ir::IrArray& rhs_array,
                        const llvm_ir::IrArray* addend_array,
                        llvm::Value* executable_run_options_value,
                        llvm::IRBuilder<>* b, mlir::MLIRContext* mlir_context,
                        const HloModuleConfig& hlo_module_config,
                        const TargetMachineFeatures& target_machine_features);

  // Emits the IR to perform the dot operation.
  Status Emit();

 private:
  // Emits instructions to perform a scalar dot product (a multiply of the
  // LHS and RHS) and store the results in the target.
  Status EmitScalarDot();

  // Emits a call to the CPU runtime to perform the matrix multiply.
  Status EmitCallToRuntime();

  // Represents the dimensions of a matrix-matrix multiply operation.
  struct MatMultDims {
    // The number of rows in the LHS.
    int64 m;

    // The number of columns in the LHS, which must also equal the number of
    // rows in the RHS.
    int64 k;

    // The number of columns in the RHS.
    int64 n;

    // True if the LHS matrix is column major.
    bool lhs_column_major;

    // True if the LHS contraction dimension is 1.
    bool lhs_canonical;

    // True if the RHS matrix is column major.
    bool rhs_column_major;

    // True if the RHS contraction dimension is 0.
    bool rhs_canonical;
  };

  // Get the MatMultDims instance for the dot product this DotOpEmitter
  // represents.  Precondition: the dot is of rank 2 (and thus its operands are
  // of rank 2 as well).
  MatMultDims GetMatMultDims() const;

  // Lowers the dot operation as a tiled Matrix*Vector loop.
  void EmitTiledLlvmIrGemv();

  // Lowers the dot operation as a tiled Matrix*Matrix loop.
  void EmitTiledLlvmIrGemm();

  // Lowers the dot operation through MLIR's linalg.matmul.
  Status EmitLinalgMatmul();

  // Lowers the dot operation as a naive nested loop that computes the result
  // one element at a time.
  void EmitNaiveLlvmIrGemm();

  // When doing a tiled GEMV in LLVM IR, a "tile" consists of this many vector
  // registers.
  int64 GetGemvTilingFactor() const {
    const int64 kDefaultTilingFactor = 8;
    return options::LlvmIrGemvTilingFactor(hlo_module_config_)
        .value_or(kDefaultTilingFactor);
  }

  std::tuple<int64, int64, int64> GetGemmTileSize() const {
    // Tuned for broadwell - Intel(R) Xeon(R) CPU E5-2690 v4 @ 2.60GHz
    //
    // TODO(b/80093688): Tune for other architectures and centralize this
    // information in one place.
    const std::tuple<int64, int64, int64> kDefaultTileSize =
        std::tuple<int64, int64, int64>(11, 9, 1);
    return options::LlvmIrGemmTileSize(hlo_module_config_)
        .value_or(kDefaultTileSize);
  }

  std::array<int64_t, 3> GetMlirGemmTileSize() const {
    // Tile by 4 x registers x register size. This was picked by running
    // small matmuls on Haswell and Skylake. There's a lot of room for
    // improvement here.
    constexpr int64_t kDefaultTileSizeForM = 4;
    int64_t elements_per_register =
        target_machine_features_.vector_register_num_elements(
            *b_->GetInsertBlock()->getParent(),
            dot_info_.result_shape.element_type());
    int64_t num_registers = target_machine_features_.vector_register_count(
        *b_->GetInsertBlock()->getParent());
    return {{kDefaultTileSizeForM, num_registers, elements_per_register}};
  }

  DotInfo dot_info_;
  string dot_hlo_name_;
  const llvm_ir::IrArray& target_array_;
  const llvm_ir::IrArray& lhs_array_;
  const llvm_ir::IrArray& rhs_array_;
  const llvm_ir::IrArray* addend_array_;
  llvm::Value* executable_run_options_value_;
  llvm::IRBuilder<>* b_;
  mlir::MLIRContext* mlir_context_;
  const HloModuleConfig& hlo_module_config_;
  const TargetMachineFeatures& target_machine_features_;
};
}  // namespace

DotOpEmitter::DotOpEmitter(
    DotInfo dot_info, string dot_hlo_name, const llvm_ir::IrArray& target_array,
    const llvm_ir::IrArray& lhs_array, const llvm_ir::IrArray& rhs_array,
    const llvm_ir::IrArray* addend_array,
    llvm::Value* executable_run_options_value, llvm::IRBuilder<>* b,
    mlir::MLIRContext* mlir_context, const HloModuleConfig& hlo_module_config,
    const TargetMachineFeatures& target_machine_features)
    : dot_info_(std::move(dot_info)),
      dot_hlo_name_(std::move(dot_hlo_name)),
      target_array_(target_array),
      lhs_array_(lhs_array),
      rhs_array_(rhs_array),
      addend_array_(addend_array),
      executable_run_options_value_(executable_run_options_value),
      b_(b),
      mlir_context_(mlir_context),
      hlo_module_config_(hlo_module_config),
      target_machine_features_(target_machine_features) {}
Status DotOpEmitter::EmitLinalgMatmul() {
  Shape operand_shapes[] = {dot_info_.lhs_shape, dot_info_.rhs_shape};
  llvm::Value* operand_ptrs[] = {lhs_array_.GetBasePointer(),
                                 rhs_array_.GetBasePointer()};
  llvm::Value* target_ptr = target_array_.GetBasePointer();

  // Zero out the output buffer.
  int64 size_bytes = ShapeUtil::ByteSizeOf(dot_info_.result_shape);
  b_->CreateMemSet(target_ptr, b_->getInt8(0), /*Size=*/size_bytes,
                   /*Align=*/llvm::MaybeAlign(1));

  std::string name =
      absl::StrCat("linalgMatMul_", dot_info_.result_shape.ToString(true), "_",
                   dot_info_.lhs_shape.ToString(true), "_",
                   dot_info_.rhs_shape.ToString(true));

  return EmitMlirFuncAndCall(
      mlir_context_, b_, dot_info_.result_shape, operand_shapes, target_ptr,
      operand_ptrs, name, [&](mlir::OpBuilder* builder, mlir::FuncOp function) {
        CHECK_EQ(dot_info_.dim_nums.lhs_contracting_dimensions_size(), 1);
        CHECK_EQ(dot_info_.dim_nums.rhs_contracting_dimensions_size(), 1);
        mlir::MLIRContext* context = builder->getContext();
        mlir::edsc::ScopedContext scope(*builder, function.getLoc());
        mlir::Value a = function.getArgument(0), b = function.getArgument(1),
                    c = function.getArgument(2);

        llvm::SmallVector<mlir::AffineExpr, 2> b_exprs(
            dot_info_.lhs_shape.rank());
        llvm::SmallVector<mlir::AffineExpr, 2> c_exprs(
            dot_info_.rhs_shape.rank());

        llvm::SmallVector<mlir::AffineExpr, 2> parallel_exprs;
        mlir::AffineExpr reduce_expr;
        for (int i = 0; i != dot_info_.result_shape.rank(); ++i) {
          parallel_exprs.push_back(mlir::getAffineDimExpr(i, context));
        }
        reduce_expr =
            mlir::getAffineDimExpr(dot_info_.result_shape.rank(), context);

        // The reduction expr is shared for both inputs.
        b_exprs[dot_info_.dim_nums.lhs_contracting_dimensions(0)] = reduce_expr;
        c_exprs[dot_info_.dim_nums.rhs_contracting_dimensions(0)] = reduce_expr;

        // Fill in the remaining parallel exprs.
        int par_expr_num = 0;
        for (auto* v : {&b_exprs, &c_exprs}) {
          for (auto& e : *v) {
            if (!e) {
              e = parallel_exprs[par_expr_num++];
            }
          }
        }

        llvm::SmallVector<mlir::IteratorType, 4> iteratorTypes(
            parallel_exprs.size(), mlir::IteratorType::Parallel);
        iteratorTypes.push_back(mlir::IteratorType::Reduction);

        mlir::edsc::StructuredIndexed s_a(a), s_b(b), s_c(c);
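        // Illustrative note: for a canonical rank-2 dot, the maps built above
        // work out to roughly
        //   lhs:    (m, n, k) -> (m, k)
        //   rhs:    (m, n, k) -> (k, n)
        //   output: (m, n, k) -> (m, n)
        // with iterator types [parallel, parallel, reduction].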
        mlir::edsc::makeGenericLinalgOp(
            /*iteratorTypes=*/iteratorTypes,
            /*inputs=*/{s_b(b_exprs), s_c(c_exprs)},
            /*outputs=*/{s_a(parallel_exprs)},
            /*resultTensorTypes=*/{}, mlir::edsc::ops::macRegionBuilder);
        mlir::edsc::intrinsics::std_ret();

        mlir::linalg::LinalgTilingOptions tilingOptions;
        tilingOptions = tilingOptions.setTileSizes(GetMlirGemmTileSize());
        int64 alignment =
            target_machine_features_.minimum_alignment_for_allocation(
                ShapeUtil::ByteSizeOf(dot_info_.result_shape));
        mlir::linalg::CodegenStrategy strategy;
        strategy.tile<mlir::linalg::GenericOp>(tilingOptions)
            .promote<mlir::linalg::GenericOp>(
                mlir::linalg::LinalgPromotionOptions()
                    .setAlignment(alignment)
                    .setUseFullTileBuffersByDefault(true)
                    .setUseAlloca(true))
            .vectorize<mlir::linalg::GenericOp>()
            .setVectorTransformsOptions(
                mlir::vector::VectorTransformsOptions()
                    .setVectorTransformsOptions(
                        mlir::vector::VectorContractLowering::OuterProduct))
            .setVectorTransferToSCFOptions(
                mlir::VectorTransferToSCFOptions().setUnroll(true));
        strategy.transform(function);
      });
}

void DotOpEmitter::EmitTiledLlvmIrGemm() {
  PrimitiveType primitive_type = dot_info_.result_shape.element_type();
  MatMultDims mat_mult_dims = GetMatMultDims();

  llvm::Value* lhs = lhs_array_.GetBasePointer();
  llvm::Value* rhs = rhs_array_.GetBasePointer();
  llvm::Value* target = target_array_.GetBasePointer();
  int64 m = mat_mult_dims.m;
  int64 k = mat_mult_dims.k;
  int64 n = mat_mult_dims.n;

  if (mat_mult_dims.lhs_column_major) {
    std::swap(lhs, rhs);
    std::swap(m, n);
  }

  int64 size_bytes = m * n * ShapeUtil::ByteSizeOfPrimitiveType(primitive_type);
  b_->CreateMemSet(target, b_->getInt8(0), /*Size=*/size_bytes,
                   /*Align=*/llvm::MaybeAlign(1));

  int64 max_target_vector_width =
      target_machine_features_.vector_register_num_elements(
          *b_->GetInsertBlock()->getParent(), primitive_type);

  int64 tile_size_m, tile_size_k, tile_size_n_in_vector_width;
  std::tie(tile_size_m, tile_size_k, tile_size_n_in_vector_width) =
      GetGemmTileSize();

  EmitSmallGemm(
      /*scalar_type=*/primitive_type,
      /*m=*/m, /*k=*/k, /*n=*/n,
      /*max_vectorization_width=*/max_target_vector_width,
      /*max_vector_count=*/tile_size_n_in_vector_width,
      /*min_vectorization_width=*/std::min<int64>(4, max_target_vector_width),
      /*tile_size_m=*/tile_size_m, /*tile_size_k=*/tile_size_k, /*lhs=*/lhs,
      /*rhs=*/rhs, /*result=*/target, b_, hlo_module_config_);
}

void DotOpEmitter::EmitTiledLlvmIrGemv() {
  PrimitiveType primitive_type = dot_info_.result_shape.element_type();

  CHECK(primitive_util::IsFloatingPointType(primitive_type) ||
        primitive_util::IsIntegralType(primitive_type));

  MatMultDims mat_mult_dims = GetMatMultDims();
  bool is_column_major_matrix_vector_gemv = false;
  bool is_row_major_matrix_vector_gemv = false;

  int64 m, k;
  bool swap_operands;

  if (mat_mult_dims.m == 1) {
    // Our emitters can only do Matrix*Vector (abbreviated as M*V) but when M=1
    // we actually want V*M.  We implement V*M as follows (Tr(X) = Transpose of
    // X):
    //
    //   V*M = Tr(Tr(V*M))  // Tr(Tr(X)) == X
    //       = Tr(Tr(M) * Tr(V))  // Tr(A * B) == Tr(B) * Tr(A)
    //
    // Since transposing a vector is physically a no-op, this is really
    // equivalent to `Tr(M) * V`.  We further implement Tr(M) by pretending that
    // M is row major if it is actually column major and vice-versa.
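    //
    // For example (illustrative): for V = [1 x 8] and M = [8 x 16], we emit
    // a [16 x 8] matrix times [8]-vector product, where Tr(M) comes "for
    // free" by flipping our interpretation of M's buffer between row major
    // and column major.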

    bool rhs_effectively_column_major = mat_mult_dims.rhs_canonical
                                            ? mat_mult_dims.rhs_column_major
                                            : !mat_mult_dims.rhs_column_major;

    if (rhs_effectively_column_major) {
      k = mat_mult_dims.k;
      m = mat_mult_dims.n;

      // We set is_row_major_matrix_vector_gemv and not
      // is_column_major_matrix_vector_gemv to implement the Transpose trick
      // mentioned above.
      is_row_major_matrix_vector_gemv = true;
      swap_operands = true;
    } else {
      k = mat_mult_dims.k;
      m = mat_mult_dims.n;

      // We set is_column_major_matrix_vector_gemv and not
      // is_row_major_matrix_vector_gemv to implement the Transpose trick
      // mentioned above.
      is_column_major_matrix_vector_gemv = true;
      swap_operands = true;
    }
  }

  if (mat_mult_dims.n == 1) {
    bool lhs_effectively_column_major = mat_mult_dims.lhs_canonical
                                            ? mat_mult_dims.lhs_column_major
                                            : !mat_mult_dims.lhs_column_major;

    if (lhs_effectively_column_major) {
      m = mat_mult_dims.m;
      k = mat_mult_dims.k;
      is_column_major_matrix_vector_gemv = true;
      swap_operands = false;
    } else {
      m = mat_mult_dims.m;
      k = mat_mult_dims.k;
      is_row_major_matrix_vector_gemv = true;
      swap_operands = false;
    }
  }

  CHECK(is_column_major_matrix_vector_gemv || is_row_major_matrix_vector_gemv);

  int64 tiling_factor = GetGemvTilingFactor();
  CHECK_GT(tiling_factor, 0);

  llvm::Value* result_op = target_array_.GetBasePointer();
  llvm::Value* lhs_op =
      swap_operands ? rhs_array_.GetBasePointer() : lhs_array_.GetBasePointer();
  llvm::Value* rhs_op =
      swap_operands ? lhs_array_.GetBasePointer() : rhs_array_.GetBasePointer();

  const int target_vector_register_element_size =
      target_machine_features_.vector_register_num_elements(
          *b_->GetInsertBlock()->getParent(), primitive_type);

  // We may not always know the vector register size for the target we're
  // compiling against, in which case target_vector_register_element_size is 0.
  // In these cases we choose a default LLVM IR register size.
  const int kUnknownTargetVectorRegisterSize = 4;
  const int vector_register_element_size =
      target_vector_register_element_size == 0
          ? kUnknownTargetVectorRegisterSize
          : target_vector_register_element_size;

  if (is_column_major_matrix_vector_gemv) {
    VLOG(2) << "Emitting column major matrix-vector multiply with m = " << m
            << " and k = " << k;
    EmitColumnMajorGemv(
        /*scalar_type=*/primitive_type,
        /*tile_rows=*/vector_register_element_size, /*tile_cols=*/tiling_factor,
        /*m=*/m, /*k=*/k, /*lhs=*/lhs_op, /*rhs=*/rhs_op,
        /*addend=*/addend_array_ ? addend_array_->GetBasePointer() : nullptr,
        /*result=*/result_op, b_, hlo_module_config_);
  } else {
    VLOG(2) << "Emitting row major matrix-vector multiply with m = " << m
            << " and k = " << k;
    EmitRowMajorGemv(
        /*scalar_type=*/primitive_type,
        /*tile_rows=*/tiling_factor,
        /*tile_cols=*/vector_register_element_size,
        /*m=*/m, /*k=*/k, /*lhs=*/lhs_op, /*rhs=*/rhs_op,
        /*addend=*/addend_array_ ? addend_array_->GetBasePointer() : nullptr,
        /*result=*/result_op, b_, hlo_module_config_);
  }
}

Status DotOpEmitter::Emit() {
  // The dot operation performs a sum of products over dimension 0 of the left
  // hand side operand and dimension 1 of the right hand side operand.
  //
  // Let the shapes of lhs and rhs be defined as below:
  //
  //   lhs = [L{n-1} x L{n-2} x ... L{0}]
  //   rhs = [R{m-1} x R{m-2} x ... R{0}]
  //
  // The sum-of-products dimension in the lhs has size L{0} and the dimension in
  // the rhs has size R{1}. Necessarily, then:
  //
  //   L{0} == R{1}
  //
  // The output of the operation has the following shape:
  //
  //   output = [L{n-1} x L{n-2} x ... L{1} x R{m-1} x R{m-2} x ... R{2} x R{0}]
  //
  // To perform the operation we construct a loop nest with one for-loop for
  // each dimension of the output. Inside this loop nest is another for-loop
  // which performs the sum-of-products (the reduction loop) before storing
  // the result in the output buffer.
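  //
  // For example (illustrative): for a rank-2 dot, lhs = [L{1} x L{0}] =
  // [m x k] and rhs = [R{1} x R{0}] = [k x n] yield
  // output = [L{1} x R{0}] = [m x n], with the reduction loop running over
  // the shared k dimension.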

  const Shape& lhs_shape = lhs_array_.GetShape();
  const Shape& rhs_shape = rhs_array_.GetShape();

  if (ShapeUtil::IsScalar(lhs_shape) || ShapeUtil::IsScalar(rhs_shape)) {
    // If the operands are scalar, don't emit any loops.
    TF_RET_CHECK(ShapeUtil::IsScalar(lhs_shape) &&
                 ShapeUtil::IsScalar(rhs_shape));
    return EmitScalarDot();
  }

  switch (GetDotImplementationStrategy(hlo_module_config_, dot_info_,
                                       target_machine_features_)) {
    case DotImplementationStrategy::kNaiveLlvmIr:
      EmitNaiveLlvmIrGemm();
      return Status::OK();

    case DotImplementationStrategy::kTiledLlvmIrGemv:
      EmitTiledLlvmIrGemv();
      return Status::OK();

    case DotImplementationStrategy::kTiledLlvmIrGemm:
      EmitTiledLlvmIrGemm();
      return Status::OK();

    case DotImplementationStrategy::kLinalgMatmul:
      return EmitLinalgMatmul();

    case DotImplementationStrategy::kEigen:
      return EmitCallToRuntime();
  }
}

void DotOpEmitter::EmitNaiveLlvmIrGemm() {
  CHECK_EQ(addend_array_, nullptr);

  const Shape& lhs_shape = lhs_array_.GetShape();
  const Shape& rhs_shape = rhs_array_.GetShape();
  const DotDimensionNumbers& dim_nums = dot_info_.dim_nums;

  // Reduce along dimension 0 of the LHS and 1 of the RHS. Vectors are a special
  // case where the reduction dimension is 0 for both LHS and RHS. This results
  // in a vector dot product producing a scalar.
  int64 lhs_reduction_dimension = dim_nums.lhs_contracting_dimensions(0);
  int64 rhs_reduction_dimension = dim_nums.rhs_contracting_dimensions(0);

  // Verify that the reduction dimensions in the two operands are the same
  // size.
  CHECK_EQ(lhs_shape.dimensions(lhs_reduction_dimension),
           rhs_shape.dimensions(rhs_reduction_dimension));

  bool lhs_reduction_along_minor_dimension =
      lhs_reduction_dimension == LayoutUtil::Minor(lhs_shape.layout(), 0);
  bool rhs_reduction_along_minor_dimension =
      rhs_reduction_dimension == LayoutUtil::Minor(rhs_shape.layout(), 0);

  // Create loop nests which loop through the LHS operand dimensions and the RHS
  // operand dimensions. The reduction dimensions of the LHS and RHS are handled
  // in a separate innermost loop which performs the sum of products.
  llvm_ir::ForLoopNest loop_nest(llvm_ir::IrName(dot_hlo_name_), b_);
  std::vector<llvm::Value*> lhs_multi_index =
      loop_nest.EmitOperandArrayLoopNest(
          lhs_array_, /*dimension_to_skip=*/lhs_reduction_dimension, "lhs");
  std::vector<llvm::Value*> rhs_multi_index =
      loop_nest.EmitOperandArrayLoopNest(
          rhs_array_, /*dimension_to_skip=*/rhs_reduction_dimension, "rhs");

  // Create the loop which does the sum of products reduction.
  //
  // The prevent_unrolling bit works around a deficiency in LLVM's loop
  // vectorization pipeline, wherein unrolling a loop can sometimes prevent
  // effective vectorization.  Since we know that the IR we generate when
  // reducing across the minor dimension in both LHS and RHS is vectorized well
  // by the loop vectorizer, we block unrolling in that case to stop loop unroll
  // from messing up the vectorization.
  std::unique_ptr<llvm_ir::ForLoop> reduction_loop = loop_nest.AddLoop(
      0, lhs_shape.dimensions(lhs_reduction_dimension), "reduction",
      /*unroll_mode=*/
      (lhs_reduction_along_minor_dimension &&
       rhs_reduction_along_minor_dimension)
          ? xla::llvm_ir::UnrollMode::kNoUnroll
          : xla::llvm_ir::UnrollMode::kDefaultUnroll);

  // The final entry in the rhs and lhs indexes is the indvar of the
  // reduction loop.
  lhs_multi_index[lhs_reduction_dimension] = reduction_loop->GetIndVarValue();
  llvm_ir::IrArray::Index lhs_index(lhs_multi_index, lhs_shape,
                                    b_->getInt64Ty());
  rhs_multi_index[rhs_reduction_dimension] = reduction_loop->GetIndVarValue();
  llvm_ir::IrArray::Index rhs_index(rhs_multi_index, rhs_shape,
                                    b_->getInt64Ty());

  // For computing the sum of products we alloca a single location to store the
  // dot product result as we accumulate it within the reduction loop. After the
  // reduction loop we load the result and store into the output array.

  // Function entry basic block.
  // - Emit alloca for accumulator
  llvm::Function* func = reduction_loop->GetPreheaderBasicBlock()->getParent();
  SetToFirstInsertPoint(&func->getEntryBlock(), b_);
  llvm::Type* accum_type = target_array_.GetElementLlvmType();
  llvm::Value* accum_address =
      b_->CreateAlloca(accum_type, /*ArraySize=*/nullptr, "accum_address");

  // Preheader basic block of reduction loop:
  // - Initialize accumulator to zero.
  llvm::BasicBlock* preheader_bb = reduction_loop->GetPreheaderBasicBlock();
  b_->SetInsertPoint(preheader_bb->getTerminator());

  b_->CreateStore(llvm::Constant::getNullValue(accum_type), accum_address);

  // Body basic block of reduction loop:
  // - Load elements from lhs and rhs array.
  // - Multiply lhs-element and rhs-element.
  // - Load accumulator and add to product.
  // - Store sum back into accumulator.
  SetToFirstInsertPoint(reduction_loop->GetBodyBasicBlock(), b_);

  llvm::Value* lhs_element = lhs_array_.EmitReadArrayElement(lhs_index, b_);
  llvm::Value* rhs_element = rhs_array_.EmitReadArrayElement(rhs_index, b_);

  llvm::Value* accum = b_->CreateLoad(accum_address);
  llvm::Value* updated_accum;
  if (ShapeUtil::ElementIsComplex(lhs_shape)) {
    auto real = [&](llvm::Value* x) { return b_->CreateExtractValue(x, {0}); };
    auto imag = [&](llvm::Value* x) { return b_->CreateExtractValue(x, {1}); };
    llvm::Value* product_real =
        b_->CreateFSub(b_->CreateFMul(real(lhs_element), real(rhs_element)),
                       b_->CreateFMul(imag(lhs_element), imag(rhs_element)));
    llvm::Value* product_imag =
        b_->CreateFAdd(b_->CreateFMul(real(lhs_element), imag(rhs_element)),
                       b_->CreateFMul(imag(lhs_element), real(rhs_element)));
    updated_accum = b_->CreateInsertValue(
        accum, b_->CreateFAdd(real(accum), product_real), {0});
    updated_accum = b_->CreateInsertValue(
        updated_accum, b_->CreateFAdd(imag(accum), product_imag), {1});
  } else if (ShapeUtil::ElementIsIntegral(lhs_shape)) {
    llvm::Value* product = b_->CreateMul(lhs_element, rhs_element);
    updated_accum = b_->CreateAdd(accum, product);
  } else if (lhs_shape.element_type() == PRED) {
    llvm::Value* product = b_->CreateAnd(lhs_element, rhs_element);
    updated_accum = b_->CreateOr(accum, product);
  } else {
    llvm::Value* product = b_->CreateFMul(lhs_element, rhs_element);
    updated_accum = b_->CreateFAdd(accum, product);
  }
  b_->CreateStore(updated_accum, accum_address);

  // Exit basic block of reduction loop.
  // - Load accumulator value (the result).
  // - Store into output array.
  SetToFirstInsertPoint(reduction_loop->GetExitBasicBlock(), b_);

  llvm::Value* result = b_->CreateLoad(accum_address);

  // Create index into target address. The target index is the concatenation of
  // the rhs and lhs indexes with the reduction dimensions removed. The terms
  // from the rhs index are the lower dimensions in the index so we add them
  // first.
  std::vector<llvm::Value*> target_multi_index;
  for (int dimension = 0; dimension < lhs_index.size(); ++dimension) {
    if (dimension != lhs_reduction_dimension) {
      target_multi_index.push_back(lhs_index[dimension]);
    }
  }
  for (int dimension = 0; dimension < rhs_index.size(); ++dimension) {
    if (dimension != rhs_reduction_dimension) {
      target_multi_index.push_back(rhs_index[dimension]);
    }
  }

  llvm_ir::IrArray::Index target_index(
      target_multi_index, target_array_.GetShape(), lhs_index.GetType());
  target_array_.EmitWriteArrayElement(target_index, result, b_);

  // Set the IR builder insert point to the exit basic block of the outer most
  // loop.
  b_->SetInsertPoint(loop_nest.GetOuterLoopExitBasicBlock());
}

Status DotOpEmitter::EmitScalarDot() {
  // A scalar dot is just a scalar multiply.
  llvm::Value* result;
  // Use the same index_type for all tensor accesses in the same kernel.
  llvm::Type* index_type = b_->getInt64Ty();
  llvm_ir::IrArray::Index element_index(index_type);
  llvm::Value* lhs_value =
      lhs_array_.EmitReadArrayElement(/*index=*/element_index, b_);
  llvm::Value* rhs_value =
      rhs_array_.EmitReadArrayElement(/*index=*/element_index, b_);
  if (ShapeUtil::ElementIsComplex(lhs_array_.GetShape())) {
    auto get_real = [&](llvm::Value* x) {
      return b_->CreateExtractValue(x, {0});
    };

    auto get_imag = [&](llvm::Value* x) {
      return b_->CreateExtractValue(x, {1});
    };

    llvm::Value* real = b_->CreateFSub(
        b_->CreateFMul(get_real(lhs_value), get_real(rhs_value)),
        b_->CreateFMul(get_imag(lhs_value), get_imag(rhs_value)));
    llvm::Value* imag = b_->CreateFAdd(
        b_->CreateFMul(get_real(lhs_value), get_imag(rhs_value)),
        b_->CreateFMul(get_imag(lhs_value), get_real(rhs_value)));
    result = llvm::ConstantAggregateZero::get(lhs_array_.GetElementLlvmType());
    result = b_->CreateInsertValue(result, real, {0});
    result = b_->CreateInsertValue(result, imag, {1});
  } else {
    result = b_->CreateFMul(lhs_value, rhs_value);
  }
  target_array_.EmitWriteArrayElement(/*index=*/element_index, result, b_);
  return Status::OK();
}

Status DotOpEmitter::EmitCallToRuntime() {
  // The signature of the Eigen runtime matmul function is:
  //
  //   (void)(void* run_options, float* out, float* lhs, float* rhs,
  //          int64 m, int64 n, int64 k, int32 transpose_lhs,
  //          int32 transpose_rhs);
  //
  // The two transpose_... parameters are actually booleans, but we use int32
  // to avoid target-dependent calling convention details.

  bool multi_threaded = ShouldUseMultiThreadedEigen(hlo_module_config_);
  bool use_mkl_dnn = hlo_module_config_.debug_options().xla_cpu_use_mkl_dnn();
  PrimitiveType type = target_array_.GetShape().element_type();
  llvm::Function* function = b_->GetInsertBlock()->getParent();
  llvm::Module* module = function->getParent();
  llvm::Type* float_type;
  const char* fn_name;
  switch (type) {
    case F16:
      fn_name = multi_threaded
                    ? runtime::kEigenMatMulF16SymbolName
                    : runtime::kEigenSingleThreadedMatMulF16SymbolName;
      float_type = b_->getHalfTy();
      break;
    case F32:
      fn_name = multi_threaded
                    ? (use_mkl_dnn ? runtime::kMKLMatMulF32SymbolName
                                   : runtime::kEigenMatMulF32SymbolName)
                    : (use_mkl_dnn
                           ? runtime::kMKLSingleThreadedMatMulF32SymbolName
                           : runtime::kEigenSingleThreadedMatMulF32SymbolName);
      float_type = b_->getFloatTy();
      break;
    case F64:
      fn_name = multi_threaded
                    ? (use_mkl_dnn ? runtime::kMKLMatMulF64SymbolName
                                   : runtime::kEigenMatMulF64SymbolName)
                    : (use_mkl_dnn
                           ? runtime::kMKLSingleThreadedMatMulF64SymbolName
                           : runtime::kEigenSingleThreadedMatMulF64SymbolName);
      float_type = b_->getDoubleTy();
      break;
    case C64:
      fn_name = multi_threaded
                    ? runtime::kEigenMatMulC64SymbolName
                    : runtime::kEigenSingleThreadedMatMulC64SymbolName;
      float_type = llvm_ir::PrimitiveTypeToIrType(C64, module);
      break;
    case C128:
      fn_name = multi_threaded
                    ? runtime::kEigenMatMulC128SymbolName
                    : runtime::kEigenSingleThreadedMatMulC128SymbolName;
      float_type = llvm_ir::PrimitiveTypeToIrType(C128, module);
      break;
    case S32:
      fn_name = multi_threaded
                    ? runtime::kEigenMatMulS32SymbolName
                    : runtime::kEigenSingleThreadedMatMulS32SymbolName;
      float_type = b_->getInt32Ty();
      break;
    default:
      return Unimplemented("Invalid type %s for dot operation",
                           PrimitiveType_Name(type));
  }

  llvm::Type* float_ptr_type = float_type->getPointerTo();
  llvm::Type* int64_type = b_->getInt64Ty();
  llvm::Type* int32_type = b_->getInt32Ty();
  llvm::Type* int8_ptr_type = b_->getInt8Ty()->getPointerTo();
  llvm::FunctionType* matmul_type = llvm::FunctionType::get(
      b_->getVoidTy(),
      {int8_ptr_type, float_ptr_type, float_ptr_type, float_ptr_type,
       int64_type, int64_type, int64_type, int32_type, int32_type},
      /*isVarArg=*/false);

  llvm::FunctionCallee matmul_func =
      module->getOrInsertFunction(fn_name, matmul_type);
  if (auto* fn = llvm::dyn_cast<llvm::Function>(matmul_func.getCallee())) {
    fn->setCallingConv(llvm::CallingConv::C);
    fn->setDoesNotThrow();
    fn->setOnlyAccessesArgMemory();
  }

  // The Eigen runtime function expects column-major layout. If the matrices are
  // row major, then use the following identity to compute the product:
  //
  //   (A x B)^T = B^T x A^T
  //
  // The connection between this identity and memory layout is that the
  // transpose operation can also be considered as an operation that changes the
  // memory layout of a matrix from row-major to column-major or vice versa.
  //
  // Effectively this involves swapping the 'lhs' with 'rhs' and 'm' with 'n'.
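  //
  // For example (illustrative): a row-major [m x k] buffer read as
  // column-major is the [k x m] transpose, so for row-major inputs we pass
  // rhs as the left operand and lhs as the right one, with m and n swapped;
  // the column-major [n x m] result the runtime writes is bit-for-bit the
  // row-major [m x n] product we want.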

  MatMultDims mat_mult_dims = GetMatMultDims();

  CHECK_EQ(mat_mult_dims.lhs_column_major, mat_mult_dims.rhs_column_major);

  const llvm_ir::IrArray* lhs = &lhs_array_;
  const llvm_ir::IrArray* rhs = &rhs_array_;
  bool transpose_lhs = !mat_mult_dims.lhs_canonical;
  bool transpose_rhs = !mat_mult_dims.rhs_canonical;

  if (!mat_mult_dims.lhs_column_major) {
    std::swap(mat_mult_dims.m, mat_mult_dims.n);
    std::swap(lhs, rhs);
    std::swap(transpose_lhs, transpose_rhs);
  }

  b_->CreateCall(
      matmul_func,
      {b_->CreateBitCast(executable_run_options_value_, int8_ptr_type),
       b_->CreateBitCast(target_array_.GetBasePointer(), float_ptr_type),
       b_->CreateBitCast(lhs->GetBasePointer(), float_ptr_type),
       b_->CreateBitCast(rhs->GetBasePointer(), float_ptr_type),
       b_->getInt64(mat_mult_dims.m), b_->getInt64(mat_mult_dims.n),
       b_->getInt64(mat_mult_dims.k), b_->getInt32(transpose_lhs),
       b_->getInt32(transpose_rhs)});
  return Status::OK();
}

DotOpEmitter::MatMultDims DotOpEmitter::GetMatMultDims() const {
  CHECK_LE(dot_info_.result_shape.dimensions_size(), 2);

  const Shape& lhs_shape = lhs_array_.GetShape();
  const Shape& rhs_shape = rhs_array_.GetShape();
  const DotDimensionNumbers& dim_nums = dot_info_.dim_nums;

  auto is_column_major = [](const Shape& shape) {
    return shape.rank() > 1 && LayoutUtil::Minor(shape.layout(), 0) == 0;
  };

  // Non-contracting dots should never make it here.
  CHECK_GE(dim_nums.lhs_contracting_dimensions_size(), 0);
  CHECK_GE(dim_nums.rhs_contracting_dimensions_size(), 0);

  return {
      /*m=*/lhs_shape.rank() <= 1
          ? 1LL
          : lhs_shape.dimensions(1LL - dim_nums.lhs_contracting_dimensions(0)),
      /*k=*/lhs_shape.dimensions(dim_nums.lhs_contracting_dimensions(0)),
      /*n=*/rhs_shape.rank() <= 1
          ? 1LL
          : rhs_shape.dimensions(1LL - dim_nums.rhs_contracting_dimensions(0)),
      /*lhs_column_major=*/is_column_major(lhs_shape),
      /*lhs_canonical=*/lhs_shape.rank() <= 1 ||
          dim_nums.lhs_contracting_dimensions(0) == 1,
      /*rhs_column_major=*/is_column_major(rhs_shape),
      /*rhs_canonical=*/dim_nums.rhs_contracting_dimensions(0) == 0};
}

// For vector-matrix dot products, it is always profitable to make the RHS
// column major.
absl::optional<int64> ProfitableToMakeDotOperandColumnMajor(
    const HloInstruction& hlo) {
  if (hlo.opcode() == HloOpcode::kDot && hlo.shape().dimensions_size() <= 1) {
    if (hlo.operand(0)->shape().rank() != 1 ||
        hlo.dot_dimension_numbers().rhs_contracting_dimensions(0) != 0) {
      return {};
    }

    // Don't bother if the other operand is tiny; switching to column major
    // wouldn't use tiling.
    constexpr int kColumnMajorThresholdInBytes = 32;
    int64 lhs_size =
        ShapeUtil::ByteSizeOfPrimitiveType(hlo.shape().element_type()) *
        ShapeUtil::ElementsIn(hlo.operand(0)->shape());
    if (lhs_size < kColumnMajorThresholdInBytes) {
      return {};
    }

    return 1;
  }

  if (hlo.IsOutputFusion()) {
    auto* fusion_root =
        hlo.fused_instructions_computation()->root_instruction();
    if (fusion_root->opcode() != HloOpcode::kAdd) {
      return {};
    }

    for (auto* fusion_root_op : fusion_root->operands()) {
      if (fusion_root_op->opcode() != HloOpcode::kDot) {
        continue;
      }
      if (auto operand_num =
              ProfitableToMakeDotOperandColumnMajor(*fusion_root_op)) {
        auto* operand = fusion_root_op->operand(*operand_num);
        if (operand->opcode() == HloOpcode::kParameter &&
            operand->user_count() == 1) {
          return operand->parameter_number();
        }
      }
    }
  }

  return {};
}

namespace {
// Return whether the given shape is rank 2.
bool IsRank2(const Shape& shape) { return shape.rank() == 2; }

bool IsSimpleLayout(const Layout& layout) {
  return layout.tiles().empty() && layout.format() == DENSE;
}

// In a gemm operation where output = lhs * rhs, check whether the given shapes
// are valid for the operation.
bool AreGemmShapes(const Shape& lhs_shape, const Shape& rhs_shape,
                   const Shape& output_shape,
                   const TargetMachineFeatures& target_machine_features) {
  CHECK(!lhs_shape.has_layout() || IsSimpleLayout(lhs_shape.layout()))
      << lhs_shape.DebugString();
  CHECK(!rhs_shape.has_layout() || IsSimpleLayout(rhs_shape.layout()))
      << rhs_shape.DebugString();
  CHECK(!output_shape.has_layout() || IsSimpleLayout(output_shape.layout()))
      << output_shape.DebugString();

  switch (output_shape.element_type()) {
    case F16:
    case F32:
    case F64:
    case C64:
    case C128:
    case S32:
      return IsRank2(lhs_shape) && IsRank2(rhs_shape) && IsRank2(output_shape);
    default:
      return false;
  }
}

bool IsAlignedGemm(const DotInfo& dot_info,
                   const TargetMachineFeatures& target_machine_features) {
  if (ShapeUtil::IsZeroElementArray(dot_info.lhs_shape) ||
      ShapeUtil::IsZeroElementArray(dot_info.rhs_shape)) {
    return false;
  }

  return AreGemmShapes(dot_info.lhs_shape, dot_info.rhs_shape,
                       dot_info.result_shape, target_machine_features);
}

bool CanEmitTiledLlvmIrGemm(
    const HloModuleConfig& config, const DotInfo& dot_info,
    const TargetMachineFeatures& target_machine_features) {
  CHECK(IsAlignedGemm(dot_info, target_machine_features));

  if (ShouldUseMultiThreadedEigen(config)) {
    return false;
  }

  int m = dot_info.result_shape.dimensions(0);
  int k = dot_info.lhs_shape.dimensions(
      dot_info.dim_nums.lhs_contracting_dimensions(0));
  int n = dot_info.result_shape.dimensions(1);

  if (!options::ForceEnableExperimentalLlvmIrGemm(config)) {
    // TODO(sanjoy):  We should make these numbers micro-arch specific.
    bool small_gemm =
        k <= 128 && ((m <= 32 && n <= 128) || (m <= 128 && n <= 32));
    if (!small_gemm) {
      return false;
    }
  }

  bool lhs_canonical = dot_info.dim_nums.lhs_contracting_dimensions(0) == 1;
  bool rhs_canonical = dot_info.dim_nums.rhs_contracting_dimensions(0) == 0;

  if (!(lhs_canonical && rhs_canonical)) {
    return false;
  }

  if (dot_info.result_shape.element_type() == F16 ||
      dot_info.result_shape.element_type() == C64 ||
      dot_info.result_shape.element_type() == C128) {
    // TODO(sanjoy): This is probably easy to fix, but I want to keep the CL
    // adding this comment NFC.
    return false;
  }

  return true;
}

DotImplementationStrategy GetDotImplementationStrategy(
    const HloModuleConfig& config, const DotInfo& dot_info,
    const TargetMachineFeatures& target_machine_features) {
  PrimitiveType element_type = dot_info.result_shape.element_type();
  // Any Matrix-Vector product of floating point or integral type, or a
  // transpose-dot fusion of the same, can be lowered to a tiled LLVM IR
  // implementation.
  if ((dot_info.result_shape.dimensions_size() <= 1 ||
       (dot_info.result_shape.dimensions_size() == 2 &&
        (dot_info.result_shape.dimensions(0) == 1 ||
         dot_info.result_shape.dimensions(1) == 1))) &&
      (primitive_util::IsFloatingPointType(element_type) ||
       primitive_util::IsIntegralType(element_type))) {
    return DotImplementationStrategy::kTiledLlvmIrGemv;
  }

  if (IsAlignedGemm(dot_info, target_machine_features)) {
    if (CanEmitTiledLlvmIrGemm(config, dot_info, target_machine_features)) {
      return DotImplementationStrategy::kTiledLlvmIrGemm;
    }
    return DotImplementationStrategy::kEigen;
  }

  return DotImplementationStrategy::kNaiveLlvmIr;
}

Status EmitNonBatchDotOperation(
    DotInfo dot_info, string hlo_name, const llvm_ir::IrArray& target_array,
    const llvm_ir::IrArray& lhs_array, const llvm_ir::IrArray& rhs_array,
    const llvm_ir::IrArray* addend_array,
    llvm::Value* executable_run_options_value, llvm::IRBuilder<>* b,
    mlir::MLIRContext* mlir_context, const HloModuleConfig& hlo_module_config,
    const TargetMachineFeatures& target_machine_features) {
  PrimitiveType type = target_array.GetShape().element_type();
  TF_RET_CHECK(PRED == type || S8 == type || U8 == type || S16 == type ||
               U16 == type || S32 == type || U32 == type || S64 == type ||
               U64 == type || F16 == type || F32 == type || F64 == type ||
               C64 == type || C128 == type);
  DotOpEmitter dot_emitter(std::move(dot_info), std::move(hlo_name),
                           target_array, lhs_array, rhs_array, addend_array,
                           executable_run_options_value, b, mlir_context,
                           hlo_module_config, target_machine_features);
  return dot_emitter.Emit();
}

Shape DropFirstDim(const Shape& shape) {
  absl::Span<int64 const> array_shape_dims(shape.dimensions());
  array_shape_dims.remove_prefix(1);
  return ShapeUtil::MakeShapeWithDescendingLayout(shape.element_type(),
                                                  array_shape_dims);
}

Shape CollapseFirstNDims(const Shape& shape, int64 n) {
  absl::Span<int64 const> input_shape_dims(shape.dimensions());
  int64 prefix_dim =
      std::accumulate(input_shape_dims.begin(), input_shape_dims.begin() + n,
                      1ll, std::multiplies<int64>());
  DimensionVector result_dims;
  result_dims.push_back(prefix_dim);
  std::copy(input_shape_dims.begin() + n, input_shape_dims.end(),
            std::back_inserter(result_dims));
  return ShapeUtil::MakeShapeWithDescendingLayout(shape.element_type(),
                                                  result_dims);
}

llvm_ir::IrArray CollapseFirstNDims(llvm::IRBuilder<>* b,
                                    const llvm_ir::IrArray& array, int64 n) {
  llvm::Module* module = b->GetInsertBlock()->getParent()->getParent();
  const Shape& shape = array.GetShape();
  CHECK(shape.has_layout() &&
        LayoutUtil::IsMonotonicWithDim0Major(shape.layout()));
  CHECK_GE(shape.dimensions_size(), n);
  Shape new_shape = CollapseFirstNDims(shape, n);
  llvm::Value* new_value = b->CreateBitCast(
      array.GetBasePointer(),
      llvm_ir::ShapeToIrType(new_shape, module)->getPointerTo());
  return llvm_ir::IrArray(new_value, std::move(new_shape));
}

Status ValidateDotDimensionNumbers(const DotDimensionNumbers& dim_numbers) {
  // Checks some invariants that do not hold in general, but DotDecomposer
  // should have established for us.  This is just a debugging aid.
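  //
  // For example (illustrative): after DotDecomposer runs, a batch dot over
  // f32[B,M,K] and f32[B,K,N] has lhs_batch_dimensions =
  // rhs_batch_dimensions = {0}, lhs_contracting_dimensions = {2} and
  // rhs_contracting_dimensions = {1}, which passes the checks below.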
  TF_RET_CHECK(dim_numbers.lhs_contracting_dimensions_size() == 1);
  std::vector<int64> batch_dim_numbers(dim_numbers.lhs_batch_dimensions_size());
  absl::c_iota(batch_dim_numbers, 0);
  TF_RET_CHECK(
      absl::c_equal(batch_dim_numbers, dim_numbers.lhs_batch_dimensions()));
  TF_RET_CHECK(
      absl::c_equal(batch_dim_numbers, dim_numbers.rhs_batch_dimensions()));
  return Status::OK();
}

// Slice out the inner array at batch index `batch_index` from `outer_array`.
llvm_ir::IrArray SliceOutInnerArray(llvm_ir::IrArray outer_array,
                                    llvm::Value* batch_index,
                                    llvm::IRBuilder<>* b) {
  llvm::Module* module = b->GetInsertBlock()->getParent()->getParent();

  Shape inner_shape = DropFirstDim(outer_array.GetShape());
  std::vector<llvm::Value*> multidim_index(inner_shape.rank() + 1,
                                           b->getInt64(0));
  multidim_index[0] = batch_index;
  llvm_ir::IrArray::Index slice_index(multidim_index, outer_array.GetShape(),
                                      batch_index->getType());
  llvm::Value* slice_ptr = outer_array.EmitArrayElementAddress(slice_index, b);
  llvm::Type* slice_ptr_type =
      llvm_ir::ShapeToIrType(inner_shape, module)->getPointerTo();
  return llvm_ir::IrArray(b->CreateBitCast(slice_ptr, slice_ptr_type),
                          std::move(inner_shape));
}

Status EmitBatchDotOperation(
    const HloInstruction& dot, const llvm_ir::IrArray& target_array,
    const llvm_ir::IrArray& lhs_array, const llvm_ir::IrArray& rhs_array,
    llvm::Value* executable_run_options_value, llvm::IRBuilder<>* b,
    mlir::MLIRContext* mlir_context, const HloModuleConfig& hlo_module_config,
    const TargetMachineFeatures& target_machine_features) {
  TF_RETURN_IF_ERROR(ValidateDotDimensionNumbers(dot.dot_dimension_numbers()));

  // Lower a batch dot into a sequence of non-batch dot operations.

  int64 num_batch_dims =
      dot.dot_dimension_numbers().lhs_batch_dimensions_size();

  // First reshape the inputs to make sure we only have one batch dimension.
  // This is a no-op bitcast because the operands have to be in row-major layout
  // (enforced in CpuLayoutAssignment), and the batch dimensions are the leading
  // dimensions (established by DotDecomposer and checked by
  // ValidateDotDimensionNumbers above).
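  //
  // For example (illustrative): an f32[2,3,4,5] operand with two batch
  // dimensions is bitcast to f32[6,4,5], and each loop iteration below then
  // slices out one f32[4,5] sub-array.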
  llvm_ir::IrArray lhs_array_reshaped =
      CollapseFirstNDims(b, lhs_array, num_batch_dims);
  llvm_ir::IrArray rhs_array_reshaped =
      CollapseFirstNDims(b, rhs_array, num_batch_dims);
  llvm_ir::IrArray target_array_reshaped =
      CollapseFirstNDims(b, target_array, num_batch_dims);

  int64 batch_count = lhs_array_reshaped.GetShape().dimensions(0);

  KernelSupportLibrary ksl(b);

  return ksl.ForWithStatus(
      llvm_ir::IrName(&dot, "bdot"), /*start=*/0, /*end=*/batch_count,
      /*step=*/1, [&](llvm::Value* indvar) {
        DotDimensionNumbers adjusted_dim_numbers = dot.dot_dimension_numbers();
        adjusted_dim_numbers.clear_lhs_batch_dimensions();
        adjusted_dim_numbers.clear_rhs_batch_dimensions();

        // Create a DotInfo representing the "inner" non-batch dot operation.
        DotInfo dot_info;
        dot_info.lhs_shape = DropFirstDim(lhs_array_reshaped.GetShape());
        dot_info.rhs_shape = DropFirstDim(rhs_array_reshaped.GetShape());
        dot_info.result_shape = DropFirstDim(target_array_reshaped.GetShape());
        dot_info.dim_nums = dot.dot_dimension_numbers();
        dot_info.dim_nums.clear_lhs_batch_dimensions();
        dot_info.dim_nums.clear_rhs_batch_dimensions();

        dot_info.dim_nums.set_lhs_contracting_dimensions(
            0,
            dot_info.dim_nums.lhs_contracting_dimensions(0) - num_batch_dims);
        dot_info.dim_nums.set_rhs_contracting_dimensions(
            0,
            dot_info.dim_nums.rhs_contracting_dimensions(0) - num_batch_dims);

        llvm_ir::IrArray lhs_slice =
            SliceOutInnerArray(lhs_array_reshaped, /*batch_index=*/indvar, b);
        llvm_ir::IrArray rhs_slice =
            SliceOutInnerArray(rhs_array_reshaped, /*batch_index=*/indvar, b);
        llvm_ir::IrArray target_slice = SliceOutInnerArray(
            target_array_reshaped, /*batch_index=*/indvar, b);

        // Emit the inner non-batch dot operation.
        return EmitNonBatchDotOperation(
            dot_info, dot.name(), target_slice, lhs_slice, rhs_slice, nullptr,
            executable_run_options_value, b, mlir_context, hlo_module_config,
            target_machine_features);
      });
}

bool IsBatchDot(const HloInstruction& instr) {
  if (auto* dot_instr = DynCast<HloDotInstruction>(&instr)) {
    return dot_instr->dot_dimension_numbers().lhs_batch_dimensions_size() > 0;
  }

  return false;
}
}  // namespace

bool DotImplementationCanHandleTranspose(
    const HloInstruction& dot_instr,
    const TargetMachineFeatures& target_machine_features) {
  DotImplementationStrategy impl_strategy =
      GetDotImplementationStrategy(dot_instr.parent()->parent()->config(),
                                   DotInfo(dot_instr), target_machine_features);

  return impl_strategy == DotImplementationStrategy::kNaiveLlvmIr ||
         impl_strategy == DotImplementationStrategy::kTiledLlvmIrGemv ||
         impl_strategy == DotImplementationStrategy::kEigen;
}

bool DotOperandsAndResultMustHaveRowMajorLayout(
    const HloInstruction& dot_instr,
    const TargetMachineFeatures& target_machine_features) {
  // Batched dots require the batch dimensions to be major. DotDecomposer always
  // moves batch dimensions to the front of the shape, so force a row-major
  // layout.
  if (IsBatchDot(dot_instr)) {
    return true;
  }

  DotImplementationStrategy impl_strategy =
      GetDotImplementationStrategy(dot_instr.parent()->parent()->config(),
                                   DotInfo(dot_instr), target_machine_features);

  return impl_strategy == DotImplementationStrategy::kTiledLlvmIrGemm ||
         impl_strategy == DotImplementationStrategy::kEigen;
}

Status EmitDotOperation(const HloInstruction& dot,
                        const llvm_ir::IrArray& target_array,
                        const llvm_ir::IrArray& lhs_array,
                        const llvm_ir::IrArray& rhs_array,
                        const llvm_ir::IrArray* addend_array,
                        llvm::Value* executable_run_options_value,
                        llvm::IRBuilder<>* b, mlir::MLIRContext* mlir_context,
                        const HloModuleConfig& hlo_module_config,
                        const TargetMachineFeatures& target_machine_features) {
  // This routine assumes that the dot operation is not in a parallelized
  // enclosing computation.
  CHECK(dot.parent()->root_instruction()->outer_dimension_partitions().empty());

  if (IsBatchDot(dot)) {
    TF_RET_CHECK(addend_array == nullptr);
    return EmitBatchDotOperation(dot, target_array, lhs_array, rhs_array,
                                 executable_run_options_value, b, mlir_context,
                                 hlo_module_config, target_machine_features);
  }

  return EmitNonBatchDotOperation(DotInfo(dot), dot.name(), target_array,
                                  lhs_array, rhs_array, addend_array,
                                  executable_run_options_value, b, mlir_context,
                                  hlo_module_config, target_machine_features);
}
}  // namespace cpu
}  // namespace xla