/* Copyright 2017 The TensorFlow Authors. All Rights Reserved.

Licensed under the Apache License, Version 2.0 (the "License");
you may not use this file except in compliance with the License.
You may obtain a copy of the License at

    http://www.apache.org/licenses/LICENSE-2.0

Unless required by applicable law or agreed to in writing, software
distributed under the License is distributed on an "AS IS" BASIS,
WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
See the License for the specific language governing permissions and
limitations under the License.
==============================================================================*/

#include "tensorflow/compiler/xla/service/cpu/dot_op_emitter.h"

#include <memory>
#include <vector>

#include "absl/strings/str_cat.h"
#include "llvm/IR/BasicBlock.h"
#include "llvm/IR/Instructions.h"
#include "llvm/IR/Module.h"
#include "llvm/IR/Value.h"
#include "mlir/Dialect/Linalg/EDSC/Intrinsics.h"  // from @llvm-project
#include "mlir/Dialect/Linalg/Transforms/CodegenStrategy.h"  // from @llvm-project
#include "mlir/Dialect/StandardOps/EDSC/Intrinsics.h"  // from @llvm-project
#include "mlir/EDSC/Builders.h"  // from @llvm-project
#include "mlir/IR/Builders.h"  // from @llvm-project
#include "mlir/IR/BuiltinOps.h"  // from @llvm-project
#include "mlir/IR/MLIRContext.h"  // from @llvm-project
#include "mlir/IR/OperationSupport.h"  // from @llvm-project
#include "mlir/IR/Value.h"  // from @llvm-project
#include "tensorflow/compiler/xla/primitive_util.h"
#include "tensorflow/compiler/xla/service/cpu/cpu_options.h"
#include "tensorflow/compiler/xla/service/cpu/cpu_runtime.h"
#include "tensorflow/compiler/xla/service/cpu/ir_emission_utils.h"
#include "tensorflow/compiler/xla/service/cpu/mlir_emitter.h"
#include "tensorflow/compiler/xla/service/cpu/target_machine_features.h"
#include "tensorflow/compiler/xla/service/cpu/tiled_dot_emitter.h"
#include "tensorflow/compiler/xla/service/cpu/vector_support_library.h"
#include "tensorflow/compiler/xla/service/hlo_casting_utils.h"
#include "tensorflow/compiler/xla/service/hlo_instructions.h"
#include "tensorflow/compiler/xla/service/hlo_module.h"
#include "tensorflow/compiler/xla/service/llvm_ir/kernel_support_library.h"
#include "tensorflow/compiler/xla/service/llvm_ir/llvm_util.h"
#include "tensorflow/compiler/xla/shape_util.h"
#include "tensorflow/compiler/xla/status_macros.h"
#include "tensorflow/compiler/xla/util.h"
#include "tensorflow/compiler/xla/xla_data.pb.h"
#include "tensorflow/core/platform/logging.h"

namespace xla {

using llvm_ir::SetToFirstInsertPoint;

namespace cpu {
namespace {
// Returns true if we should call into multi-threaded Eigen routines.
bool ShouldUseMultiThreadedEigen(const HloModuleConfig& config) {
  return config.debug_options().xla_cpu_multi_thread_eigen();
}

// Represents a dot operation. We use this in lieu of an `HloInstruction`
// because we want to be able to create this for the "inner" dot operation in a
// batch dot, for which there is no separate HLO instruction.
struct DotInfo {
  Shape lhs_shape;
  Shape rhs_shape;
  Shape result_shape;
  DotDimensionNumbers dim_nums;

  DotInfo() = default;

  explicit DotInfo(const HloInstruction& instr) {
    CHECK_EQ(instr.opcode(), HloOpcode::kDot);
    lhs_shape = instr.operand(0)->shape();
    rhs_shape = instr.operand(1)->shape();
    result_shape = instr.shape();
    dim_nums = instr.dot_dimension_numbers();
  }
};

// Dictates how a dot operation is implemented.
enum class DotImplementationStrategy {
  // The dot operation is lowered into LLVM IR that implements a naive nested
  // loop that computes the result one element at a time. This is our
89 // "fallback"; we don't really want this to kick in for any non-trival dot
  // operation.
  kNaiveLlvmIr,

  // The dot operation is lowered into LLVM IR that implements a tiled
  // Matrix*Vector operation. This strategy also allows fusing in a bias add
  // into the dot. The matrix can be row major or column major; both are
  // supported.
  kTiledLlvmIrGemv,

  // The dot operation is lowered into LLVM IR that implements a tiled
  // Matrix*Matrix operation. No fusions are supported. The two inputs
  // and the output have to be row major.
  kTiledLlvmIrGemm,

  // The dot operation is lowered into a linalg.matmul op, which is then
  // lowered to LLVM IR.
  kLinalgMatmul,

  // The dot operation is lowered into a call into an Eigen routine. No fusions
  // are supported today. The two inputs and the output have to be row major.
  // However, we do allow transposing either the LHS or the RHS as part of the
  // GEMM -- we expose this flexibility as flexibility in the contraction
  // dimensions, but we can also see this as flexibility in the input layouts.
  kEigen,
};

// Returns the implementation strategy for a dot with the configuration
// `dot_info`.
DotImplementationStrategy GetDotImplementationStrategy(
    const HloModuleConfig& config, const DotInfo& dot_info,
    const TargetMachineFeatures& target_machine_features);

// Helper class for emitting LLVM IR to perform the dot operation.
class DotOpEmitter {
 public:
  explicit DotOpEmitter(DotInfo dot_info, string dot_hlo_name,
                        const llvm_ir::IrArray& target_array,
                        const llvm_ir::IrArray& lhs_array,
                        const llvm_ir::IrArray& rhs_array,
                        const llvm_ir::IrArray* addend_array,
                        llvm::Value* executable_run_options_value,
                        llvm::IRBuilder<>* b, mlir::MLIRContext* mlir_context,
                        const HloModuleConfig& hlo_module_config,
                        const TargetMachineFeatures& target_machine_features);

  // Emits the IR to perform the dot operation.
  Status Emit();

 private:
  // Emits instructions to perform a scalar dot product (a multiply of the
  // LHS and RHS) and store the results in the target.
  Status EmitScalarDot();

  // Emits a call to the CPU runtime to perform the matrix multiply.
  Status EmitCallToRuntime();

  // Represents the dimensions of a matrix-matrix multiply operation.
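  //
  // For example (illustrative): a f32[16,32] x f32[32,8] -> f32[16,8] dot
  // with row major layouts and contracting dimensions {1} (LHS) and {0} (RHS)
  // has m = 16, k = 32, n = 8, with both operands canonical and neither
  // column major.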
  struct MatMultDims {
    // The number of rows in the LHS.
    int64 m;

    // The number of columns in the LHS, which must also be equal to the
    // number of rows in the RHS.
    int64 k;

    // The number of columns on the RHS.
    int64 n;

    // True if the LHS matrix is column major.
    bool lhs_column_major;

    // True if the LHS contraction dimension is 1.
    bool lhs_canonical;

    // True if the RHS matrix is column major.
    bool rhs_column_major;

    // True if the RHS contraction dimension is 0.
    bool rhs_canonical;
  };

  // Get the MatMultDims instance for the dot product this DotOpEmitter
  // represents. Precondition: the dot is of rank 2 (and thus its operands are
  // of rank 2 as well).
  MatMultDims GetMatMultDims() const;

  // Lowers the dot operation as a tiled Matrix*Vector loop.
  void EmitTiledLlvmIrGemv();

  // Lowers the dot operation as a tiled Matrix*Matrix loop.
  void EmitTiledLlvmIrGemm();

  // Lowers the dot operation through MLIR's linalg.matmul.
  Status EmitLinalgMatmul();

  // Lowers the dot operation as a naive nested loop that computes the result
  // one element at a time.
  void EmitNaiveLlvmIrGemm();

  // When doing a tiled GEMV in LLVM IR, a "tile" consists of this many vector
  // registers.
  int64 GetGemvTilingFactor() const {
    const int64 kDefaultTilingFactor = 8;
    return options::LlvmIrGemvTilingFactor(hlo_module_config_)
        .value_or(kDefaultTilingFactor);
  }

  std::tuple<int64, int64, int64> GetGemmTileSize() const {
    // Tuned for broadwell - Intel(R) Xeon(R) CPU E5-2690 v4 @ 2.60GHz
    //
    // TODO(b/80093688): Tune for other architectures and centralize this
    // information in one place.
    const std::tuple<int64, int64, int64> kDefaultTileSize =
        std::tuple<int64, int64, int64>(11, 9, 1);
    return options::LlvmIrGemmTileSize(hlo_module_config_)
        .value_or(kDefaultTileSize);
  }

  std::array<int64_t, 3> GetMlirGemmTileSize() const {
    // Tile by 4 x registers x register size. This was picked by running
    // small matmuls on Haswell and Skylake. There's a lot of room for
    // improvement here.
    constexpr int64_t kDefaultTileSizeForM = 4;
    int64_t elements_per_register =
        target_machine_features_.vector_register_num_elements(
            *b_->GetInsertBlock()->getParent(),
            dot_info_.result_shape.element_type());
    int64_t num_registers = target_machine_features_.vector_register_count(
        *b_->GetInsertBlock()->getParent());
    return {{kDefaultTileSizeForM, num_registers, elements_per_register}};
  }

  DotInfo dot_info_;
  string dot_hlo_name_;
  const llvm_ir::IrArray& target_array_;
  const llvm_ir::IrArray& lhs_array_;
  const llvm_ir::IrArray& rhs_array_;
  const llvm_ir::IrArray* addend_array_;
  llvm::Value* executable_run_options_value_;
  llvm::IRBuilder<>* b_;
  mlir::MLIRContext* mlir_context_;
  const HloModuleConfig& hlo_module_config_;
  const TargetMachineFeatures& target_machine_features_;
};
}  // namespace

DotOpEmitter::DotOpEmitter(
    DotInfo dot_info, string dot_hlo_name, const llvm_ir::IrArray& target_array,
    const llvm_ir::IrArray& lhs_array, const llvm_ir::IrArray& rhs_array,
    const llvm_ir::IrArray* addend_array,
    llvm::Value* executable_run_options_value, llvm::IRBuilder<>* b,
    mlir::MLIRContext* mlir_context, const HloModuleConfig& hlo_module_config,
    const TargetMachineFeatures& target_machine_features)
    : dot_info_(std::move(dot_info)),
      dot_hlo_name_(std::move(dot_hlo_name)),
      target_array_(target_array),
      lhs_array_(lhs_array),
      rhs_array_(rhs_array),
      addend_array_(addend_array),
      executable_run_options_value_(executable_run_options_value),
      b_(b),
      mlir_context_(mlir_context),
      hlo_module_config_(hlo_module_config),
      target_machine_features_(target_machine_features) {}

Status DotOpEmitter::EmitLinalgMatmul() {
  Shape operand_shapes[] = {dot_info_.lhs_shape, dot_info_.rhs_shape};
  llvm::Value* operand_ptrs[] = {lhs_array_.GetBasePointer(),
                                 rhs_array_.GetBasePointer()};
  llvm::Value* target_ptr = target_array_.GetBasePointer();

  // Zero out the output buffer.
  int64 size_bytes = ShapeUtil::ByteSizeOf(dot_info_.result_shape);
  b_->CreateMemSet(target_ptr, b_->getInt8(0), /*Size=*/size_bytes,
                   /*Align=*/llvm::MaybeAlign(1));

  std::string name =
      absl::StrCat("linalgMatMul_", dot_info_.result_shape.ToString(true), "_",
                   dot_info_.lhs_shape.ToString(true), "_",
                   dot_info_.rhs_shape.ToString(true));

  return EmitMlirFuncAndCall(
      mlir_context_, b_, dot_info_.result_shape, operand_shapes, target_ptr,
      operand_ptrs, name, [&](mlir::OpBuilder* builder, mlir::FuncOp function) {
        CHECK_EQ(dot_info_.dim_nums.lhs_contracting_dimensions_size(), 1);
        CHECK_EQ(dot_info_.dim_nums.rhs_contracting_dimensions_size(), 1);
        mlir::MLIRContext* context = builder->getContext();
        mlir::edsc::ScopedContext scope(*builder, function.getLoc());
        mlir::Value a = function.getArgument(0), b = function.getArgument(1),
                    c = function.getArgument(2);

        llvm::SmallVector<mlir::AffineExpr, 2> b_exprs(
            dot_info_.lhs_shape.rank());
        llvm::SmallVector<mlir::AffineExpr, 2> c_exprs(
            dot_info_.rhs_shape.rank());

        llvm::SmallVector<mlir::AffineExpr, 2> parallel_exprs;
        mlir::AffineExpr reduce_expr;
        for (int i = 0; i != dot_info_.result_shape.rank(); ++i) {
          parallel_exprs.push_back(mlir::getAffineDimExpr(i, context));
        }
        reduce_expr =
            mlir::getAffineDimExpr(dot_info_.result_shape.rank(), context);

        // The reduction expr is shared for both inputs.
        b_exprs[dot_info_.dim_nums.lhs_contracting_dimensions(0)] = reduce_expr;
        c_exprs[dot_info_.dim_nums.rhs_contracting_dimensions(0)] = reduce_expr;

        // Fill in the remaining parallel exprs.
        int par_expr_num = 0;
        for (auto* v : {&b_exprs, &c_exprs}) {
          for (auto& e : *v) {
            if (!e) {
              e = parallel_exprs[par_expr_num++];
            }
          }
        }

        llvm::SmallVector<mlir::IteratorType, 4> iteratorTypes(
            parallel_exprs.size(), mlir::IteratorType::Parallel);
        iteratorTypes.push_back(mlir::IteratorType::Reduction);

        mlir::edsc::StructuredIndexed s_a(a), s_b(b), s_c(c);
        mlir::edsc::makeGenericLinalgOp(
            /*iteratorTypes=*/iteratorTypes,
            /*inputs=*/{s_b(b_exprs), s_c(c_exprs)},
            /*outputs=*/{s_a(parallel_exprs)},
            /*resultTensorTypes=*/{}, mlir::edsc::ops::macRegionBuilder);
        mlir::edsc::intrinsics::std_ret();

        mlir::linalg::LinalgTilingOptions tilingOptions;
        tilingOptions = tilingOptions.setTileSizes(GetMlirGemmTileSize());
        int64 alignment =
            target_machine_features_.minimum_alignment_for_allocation(
                ShapeUtil::ByteSizeOf(dot_info_.result_shape));
        mlir::linalg::CodegenStrategy strategy;
        strategy.tile<mlir::linalg::GenericOp>(tilingOptions)
            .promote<mlir::linalg::GenericOp>(
                mlir::linalg::LinalgPromotionOptions()
                    .setAlignment(alignment)
                    .setUseFullTileBuffersByDefault(true)
                    .setUseAlloca(true))
            .vectorize<mlir::linalg::GenericOp>()
            .setVectorTransformsOptions(
                mlir::vector::VectorTransformsOptions()
                    .setVectorTransformsOptions(
                        mlir::vector::VectorContractLowering::OuterProduct))
            .setVectorTransferToSCFOptions(
                mlir::VectorTransferToSCFOptions().setUnroll(true));
        strategy.transform(function);
      });
}

void DotOpEmitter::EmitTiledLlvmIrGemm() {
  PrimitiveType primitive_type = dot_info_.result_shape.element_type();
  MatMultDims mat_mult_dims = GetMatMultDims();

  llvm::Value* lhs = lhs_array_.GetBasePointer();
  llvm::Value* rhs = rhs_array_.GetBasePointer();
  llvm::Value* target = target_array_.GetBasePointer();
  int64 m = mat_mult_dims.m;
  int64 k = mat_mult_dims.k;
  int64 n = mat_mult_dims.n;

  if (mat_mult_dims.lhs_column_major) {
    std::swap(lhs, rhs);
    std::swap(m, n);
  }

  int64 size_bytes = m * n * ShapeUtil::ByteSizeOfPrimitiveType(primitive_type);
  b_->CreateMemSet(target, b_->getInt8(0), /*Size=*/size_bytes,
                   /*Align=*/llvm::MaybeAlign(1));

  int64 max_target_vector_width =
      target_machine_features_.vector_register_num_elements(
          *b_->GetInsertBlock()->getParent(), primitive_type);

  int64 tile_size_m, tile_size_k, tile_size_n_in_vector_width;
  std::tie(tile_size_m, tile_size_k, tile_size_n_in_vector_width) =
      GetGemmTileSize();

  EmitSmallGemm(
      /*scalar_type=*/primitive_type,
      /*m=*/m, /*k=*/k, /*n=*/n,
      /*max_vectorization_width=*/max_target_vector_width,
      /*max_vector_count=*/tile_size_n_in_vector_width,
      /*min_vectorization_width=*/std::min<int64>(4, max_target_vector_width),
      /*tile_size_m=*/tile_size_m, /*tile_size_k=*/tile_size_k, /*lhs=*/lhs,
      /*rhs=*/rhs, /*result=*/target, b_, hlo_module_config_);
}

void DotOpEmitter::EmitTiledLlvmIrGemv() {
  PrimitiveType primitive_type = dot_info_.result_shape.element_type();

  CHECK(primitive_util::IsFloatingPointType(primitive_type) ||
        primitive_util::IsIntegralType(primitive_type));

  MatMultDims mat_mult_dims = GetMatMultDims();
  bool is_column_major_matrix_vector_gemv = false;
  bool is_row_major_matrix_vector_gemv = false;

  int64 m, k;
  bool swap_operands;

  if (mat_mult_dims.m == 1) {
    // Our emitters can only do Matrix*Vector (abbreviated as M*V) but when M=1
    // we actually want V*M. We implement V*M as follows (Tr(X) = Transpose of
    // X):
    //
    //   V*M = Tr(Tr(V*M))          // Tr(Tr(X)) == X
    //       = Tr(Tr(M) * Tr(V))    // Tr(A * B) == Tr(B) * Tr(A)
    //
    // Since transposing a vector is physically a no-op, this is really
    // equivalent to `Tr(M) * V`. We further implement Tr(M) by pretending that
    // M is row major if it is actually column major and vice-versa.
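    //
    // Concretely (illustrative): for a [1 x k] * [k x n] product we emit a
    // GEMV over the RHS matrix with m = n, and swap the operand pointers
    // further below.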

    bool rhs_effectively_column_major = mat_mult_dims.rhs_canonical
                                            ? mat_mult_dims.rhs_column_major
                                            : !mat_mult_dims.rhs_column_major;

    if (rhs_effectively_column_major) {
      k = mat_mult_dims.k;
      m = mat_mult_dims.n;

      // We set is_row_major_matrix_vector_gemv and not
      // is_column_major_matrix_vector_gemv to implement the Transpose trick
      // mentioned above.
      is_row_major_matrix_vector_gemv = true;
      swap_operands = true;
    } else {
      k = mat_mult_dims.k;
      m = mat_mult_dims.n;

      // We set is_column_major_matrix_vector_gemv and not
      // is_row_major_matrix_vector_gemv to implement the Transpose trick
      // mentioned above.
      is_column_major_matrix_vector_gemv = true;
      swap_operands = true;
    }
  }

  if (mat_mult_dims.n == 1) {
    bool lhs_effectively_column_major = mat_mult_dims.lhs_canonical
                                            ? mat_mult_dims.lhs_column_major
                                            : !mat_mult_dims.lhs_column_major;

    if (lhs_effectively_column_major) {
      m = mat_mult_dims.m;
      k = mat_mult_dims.k;
      is_column_major_matrix_vector_gemv = true;
      swap_operands = false;
    } else {
      m = mat_mult_dims.m;
      k = mat_mult_dims.k;
      is_row_major_matrix_vector_gemv = true;
      swap_operands = false;
    }
  }

  CHECK(is_column_major_matrix_vector_gemv || is_row_major_matrix_vector_gemv);

  int64 tiling_factor = GetGemvTilingFactor();
  CHECK_GT(tiling_factor, 0);

  llvm::Value* result_op = target_array_.GetBasePointer();
  llvm::Value* lhs_op =
      swap_operands ? rhs_array_.GetBasePointer() : lhs_array_.GetBasePointer();
  llvm::Value* rhs_op =
      swap_operands ? lhs_array_.GetBasePointer() : rhs_array_.GetBasePointer();

  const int target_vector_register_element_size =
      target_machine_features_.vector_register_num_elements(
          *b_->GetInsertBlock()->getParent(), primitive_type);

  // We may not always know the vector register size for the target we're
  // compiling against, in which case target_vector_register_element_size is 0.
  // In these cases we choose a default LLVM IR register size.
  const int kUnknownTargetVectorRegisterSize = 4;
  const int vector_register_element_size =
      target_vector_register_element_size == 0
          ? kUnknownTargetVectorRegisterSize
          : target_vector_register_element_size;

  if (is_column_major_matrix_vector_gemv) {
    VLOG(2) << "Emitting column major matrix-vector multiply with m = " << m
            << " and k = " << k;
    EmitColumnMajorGemv(
        /*scalar_type=*/primitive_type,
        /*tile_rows=*/vector_register_element_size, /*tile_cols=*/tiling_factor,
        /*m=*/m, /*k=*/k, /*lhs=*/lhs_op, /*rhs=*/rhs_op,
        /*addend=*/addend_array_ ? addend_array_->GetBasePointer() : nullptr,
        /*result=*/result_op, b_, hlo_module_config_);
  } else {
    VLOG(2) << "Emitting row major matrix-vector multiply with m = " << m
            << " and k = " << k;
    EmitRowMajorGemv(
        /*scalar_type=*/primitive_type,
        /*tile_rows=*/tiling_factor,
        /*tile_cols=*/vector_register_element_size,
        /*m=*/m, /*k=*/k, /*lhs=*/lhs_op, /*rhs=*/rhs_op,
        /*addend=*/addend_array_ ? addend_array_->GetBasePointer() : nullptr,
        /*result=*/result_op, b_, hlo_module_config_);
  }
}

Status DotOpEmitter::Emit() {
  // The dot operation performs a sum of products over dimension 0 of the left
  // hand side operand and dimension 1 of the right hand side operand.
  //
  // Let the shapes of lhs and rhs be defined as below:
  //
  //   lhs = [L{n-1} x L{n-2} x ... L{0}]
  //   rhs = [R{m-1} x R{m-2} x ... R{0}]
  //
  // The sum-of-products dimension in the lhs has size L{0} and the dimension in
  // the rhs has size R{1}. Necessarily, then:
  //
  //   L{0} == R{1}
  //
  // The output of the operation has the following shape:
  //
  //   output = [L{n-1} x L{n-2} x ... L{1} x R{m-1} x R{m-2} x ... R{2} x R{0}]
  //
  // To perform the operation we construct a loop nest with one for-loop for
  // each dimension of the output. Inside this loop nest is another for-loop
  // which performs the sum-of-products (the reduction loop) before storing
  // the result in the output buffer.
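  //
  // For example (illustrative): a [2 x 3] dot [3 x 4] product yields a
  // [2 x 4] output; in the naive loop-nest strategy this becomes two output
  // loops wrapped around a three-iteration reduction loop.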

  const Shape& lhs_shape = lhs_array_.GetShape();
  const Shape& rhs_shape = rhs_array_.GetShape();

  if (ShapeUtil::IsScalar(lhs_shape) || ShapeUtil::IsScalar(rhs_shape)) {
    // If the operands are scalar, don't emit any loops.
    TF_RET_CHECK(ShapeUtil::IsScalar(lhs_shape) &&
                 ShapeUtil::IsScalar(rhs_shape));
    return EmitScalarDot();
  }

  switch (GetDotImplementationStrategy(hlo_module_config_, dot_info_,
                                       target_machine_features_)) {
    case DotImplementationStrategy::kNaiveLlvmIr:
      EmitNaiveLlvmIrGemm();
      return Status::OK();

    case DotImplementationStrategy::kTiledLlvmIrGemv:
      EmitTiledLlvmIrGemv();
      return Status::OK();

    case DotImplementationStrategy::kTiledLlvmIrGemm:
      EmitTiledLlvmIrGemm();
      return Status::OK();

    case DotImplementationStrategy::kLinalgMatmul:
      return EmitLinalgMatmul();

    case DotImplementationStrategy::kEigen:
      return EmitCallToRuntime();
  }
}

void DotOpEmitter::EmitNaiveLlvmIrGemm() {
  CHECK_EQ(addend_array_, nullptr);

  const Shape& lhs_shape = lhs_array_.GetShape();
  const Shape& rhs_shape = rhs_array_.GetShape();
  const DotDimensionNumbers& dim_nums = dot_info_.dim_nums;

  // Reduce along dimension 0 of the LHS and 1 of the RHS. Vectors are a special
  // case where the reduction dimension is 0 for both LHS and RHS. This results
  // in a vector dot product producing a scalar.
  int64 lhs_reduction_dimension = dim_nums.lhs_contracting_dimensions(0);
  int64 rhs_reduction_dimension = dim_nums.rhs_contracting_dimensions(0);

  // Verify the reduction dimensions in the two operands are the same size.
  CHECK_EQ(lhs_shape.dimensions(lhs_reduction_dimension),
           rhs_shape.dimensions(rhs_reduction_dimension));

  bool lhs_reduction_along_minor_dimension =
      lhs_reduction_dimension == LayoutUtil::Minor(lhs_shape.layout(), 0);
  bool rhs_reduction_along_minor_dimension =
      rhs_reduction_dimension == LayoutUtil::Minor(rhs_shape.layout(), 0);

  // Create loop nests which loop through the LHS operand dimensions and the RHS
  // operand dimensions. The reduction dimension of the LHS and RHS are handled
  // in a separate innermost loop which performs the sum of products.
  llvm_ir::ForLoopNest loop_nest(llvm_ir::IrName(dot_hlo_name_), b_);
  std::vector<llvm::Value*> lhs_multi_index =
      loop_nest.EmitOperandArrayLoopNest(
          lhs_array_, /*dimension_to_skip=*/lhs_reduction_dimension, "lhs");
  std::vector<llvm::Value*> rhs_multi_index =
      loop_nest.EmitOperandArrayLoopNest(
          rhs_array_, /*dimension_to_skip=*/rhs_reduction_dimension, "rhs");

  // Create the loop which does the sum of products reduction.
  //
  // The prevent_unrolling bit is working around a deficiency in LLVM's loop
  // vectorization pipeline, wherein in some cases unrolling a loop can prevent
  // effective vectorization. Since we know that the IR we generate when
  // reducing across the minor dimension in both LHS and RHS is vectorized well
  // by the loop vectorizer, we block unrolling in that case to stop loop unroll
  // from messing up the vectorization.
  std::unique_ptr<llvm_ir::ForLoop> reduction_loop = loop_nest.AddLoop(
      0, lhs_shape.dimensions(lhs_reduction_dimension), "reduction",
      /*unroll_mode=*/
      (lhs_reduction_along_minor_dimension &&
       rhs_reduction_along_minor_dimension)
          ? xla::llvm_ir::UnrollMode::kNoUnroll
          : xla::llvm_ir::UnrollMode::kDefaultUnroll);

  // The final entry in the rhs and lhs indexes is the indvar of the
  // reduction loop.
  lhs_multi_index[lhs_reduction_dimension] = reduction_loop->GetIndVarValue();
  llvm_ir::IrArray::Index lhs_index(lhs_multi_index, lhs_shape,
                                    b_->getInt64Ty());
  rhs_multi_index[rhs_reduction_dimension] = reduction_loop->GetIndVarValue();
  llvm_ir::IrArray::Index rhs_index(rhs_multi_index, rhs_shape,
                                    b_->getInt64Ty());

  // For computing the sum of products we alloca a single location to store the
  // dot product result as we accumulate it within the reduction loop. After the
  // reduction loop we load the result and store into the output array.

  // Function entry basic block.
  // - Emit alloca for accumulator
  llvm::Function* func = reduction_loop->GetPreheaderBasicBlock()->getParent();
  SetToFirstInsertPoint(&func->getEntryBlock(), b_);
  llvm::Type* accum_type = target_array_.GetElementLlvmType();
  llvm::Value* accum_address =
      b_->CreateAlloca(accum_type, /*ArraySize=*/nullptr, "accum_address");

  // Preheader basic block of reduction loop:
  // - Initialize accumulator to zero.
  llvm::BasicBlock* preheader_bb = reduction_loop->GetPreheaderBasicBlock();
  b_->SetInsertPoint(preheader_bb->getTerminator());

  b_->CreateStore(llvm::Constant::getNullValue(accum_type), accum_address);

  // Body basic block of reduction loop:
  // - Load elements from lhs and rhs array.
  // - Multiply lhs-element and rhs-element.
  // - Load accumulator and add to product.
  // - Store sum back into accumulator.
  SetToFirstInsertPoint(reduction_loop->GetBodyBasicBlock(), b_);

  llvm::Value* lhs_element = lhs_array_.EmitReadArrayElement(lhs_index, b_);
  llvm::Value* rhs_element = rhs_array_.EmitReadArrayElement(rhs_index, b_);

  llvm::Value* accum = b_->CreateLoad(accum_address);
  llvm::Value* updated_accum;
  if (ShapeUtil::ElementIsComplex(lhs_shape)) {
    auto real = [&](llvm::Value* x) { return b_->CreateExtractValue(x, {0}); };
    auto imag = [&](llvm::Value* x) { return b_->CreateExtractValue(x, {1}); };
    llvm::Value* product_real =
        b_->CreateFSub(b_->CreateFMul(real(lhs_element), real(rhs_element)),
                       b_->CreateFMul(imag(lhs_element), imag(rhs_element)));
    llvm::Value* product_imag =
        b_->CreateFAdd(b_->CreateFMul(real(lhs_element), imag(rhs_element)),
                       b_->CreateFMul(imag(lhs_element), real(rhs_element)));
    updated_accum = b_->CreateInsertValue(
        accum, b_->CreateFAdd(real(accum), product_real), {0});
    updated_accum = b_->CreateInsertValue(
        updated_accum, b_->CreateFAdd(imag(accum), product_imag), {1});
  } else if (ShapeUtil::ElementIsIntegral(lhs_shape)) {
    llvm::Value* product = b_->CreateMul(lhs_element, rhs_element);
    updated_accum = b_->CreateAdd(accum, product);
  } else if (lhs_shape.element_type() == PRED) {
    llvm::Value* product = b_->CreateAnd(lhs_element, rhs_element);
    updated_accum = b_->CreateOr(accum, product);
  } else {
    llvm::Value* product = b_->CreateFMul(lhs_element, rhs_element);
    updated_accum = b_->CreateFAdd(accum, product);
  }
  b_->CreateStore(updated_accum, accum_address);

  // Exit basic block of reduction loop.
  // - Load accumulator value (the result).
  // - Store into output array.
  SetToFirstInsertPoint(reduction_loop->GetExitBasicBlock(), b_);

  llvm::Value* result = b_->CreateLoad(accum_address);

  // Create index into target address. The target index is the concatenation of
  // the lhs and rhs indexes with the reduction dimensions removed. The terms
  // from the lhs index are the leading dimensions in the target index, so we
  // add them first.
674 std::vector<llvm::Value*> target_multi_index;
675 for (int dimension = 0; dimension < lhs_index.size(); ++dimension) {
676 if (dimension != lhs_reduction_dimension) {
677 target_multi_index.push_back(lhs_index[dimension]);
678 }
679 }
680 for (int dimension = 0; dimension < rhs_index.size(); ++dimension) {
681 if (dimension != rhs_reduction_dimension) {
682 target_multi_index.push_back(rhs_index[dimension]);
683 }
684 }
685
686 llvm_ir::IrArray::Index target_index(
687 target_multi_index, target_array_.GetShape(), lhs_index.GetType());
688 target_array_.EmitWriteArrayElement(target_index, result, b_);
689
690 // Set the IR builder insert point to the exit basic block of the outer most
691 // loop.
692 b_->SetInsertPoint(loop_nest.GetOuterLoopExitBasicBlock());
693 }
694
EmitScalarDot()695 Status DotOpEmitter::EmitScalarDot() {
696 // A scalar dot is just a scalar multiply.
697 llvm::Value* result;
698 // Use the same index_type for all tensor accesses in the same kernel.
699 llvm::Type* index_type = b_->getInt64Ty();
700 llvm_ir::IrArray::Index element_index(index_type);
701 llvm::Value* lhs_value =
702 lhs_array_.EmitReadArrayElement(/*index=*/element_index, b_);
703 llvm::Value* rhs_value =
704 rhs_array_.EmitReadArrayElement(/*index=*/element_index, b_);
705 if (ShapeUtil::ElementIsComplex(lhs_array_.GetShape())) {
706 auto get_real = [&](llvm::Value* x) {
707 return b_->CreateExtractValue(x, {0});
708 };
709
710 auto get_imag = [&](llvm::Value* x) {
711 return b_->CreateExtractValue(x, {1});
712 };
713
714 llvm::Value* real = b_->CreateFSub(
715 b_->CreateFMul(get_real(lhs_value), get_real(rhs_value)),
716 b_->CreateFMul(get_imag(lhs_value), get_imag(rhs_value)));
717 llvm::Value* imag = b_->CreateFAdd(
718 b_->CreateFMul(get_real(lhs_value), get_imag(rhs_value)),
719 b_->CreateFMul(get_imag(lhs_value), get_real(rhs_value)));
720 result = llvm::ConstantAggregateZero::get(lhs_array_.GetElementLlvmType());
721 result = b_->CreateInsertValue(result, real, {0});
722 result = b_->CreateInsertValue(result, imag, {1});
723 } else {
724 result = b_->CreateFMul(lhs_value, rhs_value);
725 }
726 target_array_.EmitWriteArrayElement(/*index=*/element_index, result, b_);
727 return Status::OK();
728 }
729
EmitCallToRuntime()730 Status DotOpEmitter::EmitCallToRuntime() {
731 // The signature of the Eigen runtime matmul function is:
732 //
733 // (void)(void* run_options, float* out, float* lhs, float* rhs,
734 // int64 m, int64 n, int64 k, int32 transpose_lhs,
735 // int32 transpose_rhs);
736 // The two transpose_... parameters are actually booleans, but we use int32
737 // to avoid target-dependent calling convention details.
738
739 bool multi_threaded = ShouldUseMultiThreadedEigen(hlo_module_config_);
740 bool use_mkl_dnn = hlo_module_config_.debug_options().xla_cpu_use_mkl_dnn();
741 PrimitiveType type = target_array_.GetShape().element_type();
742 llvm::Function* function = b_->GetInsertBlock()->getParent();
743 llvm::Module* module = function->getParent();
744 llvm::Type* float_type;
745 const char* fn_name;
746 switch (type) {
747 case F16:
748 fn_name = multi_threaded
749 ? runtime::kEigenMatMulF16SymbolName
750 : runtime::kEigenSingleThreadedMatMulF16SymbolName;
751 float_type = b_->getHalfTy();
752 break;
753 case F32:
754 fn_name = multi_threaded
755 ? (use_mkl_dnn ? runtime::kMKLMatMulF32SymbolName
756 : runtime::kEigenMatMulF32SymbolName)
757 : (use_mkl_dnn
758 ? runtime::kMKLSingleThreadedMatMulF32SymbolName
759 : runtime::kEigenSingleThreadedMatMulF32SymbolName);
760 float_type = b_->getFloatTy();
761 break;
762 case F64:
763 fn_name = multi_threaded
764 ? (use_mkl_dnn ? runtime::kMKLMatMulF64SymbolName
765 : runtime::kEigenMatMulF64SymbolName)
766 : (use_mkl_dnn
767 ? runtime::kMKLSingleThreadedMatMulF64SymbolName
768 : runtime::kEigenSingleThreadedMatMulF64SymbolName);
769 float_type = b_->getDoubleTy();
770 break;
771 case C64:
772 fn_name = multi_threaded
773 ? runtime::kEigenMatMulC64SymbolName
774 : runtime::kEigenSingleThreadedMatMulC64SymbolName;
775 float_type = llvm_ir::PrimitiveTypeToIrType(C64, module);
776 break;
777 case C128:
778 fn_name = multi_threaded
779 ? runtime::kEigenMatMulC128SymbolName
780 : runtime::kEigenSingleThreadedMatMulC128SymbolName;
781 float_type = llvm_ir::PrimitiveTypeToIrType(C128, module);
782 break;
783 case S32:
784 fn_name = multi_threaded
785 ? runtime::kEigenMatMulS32SymbolName
786 : runtime::kEigenSingleThreadedMatMulS32SymbolName;
787 float_type = b_->getInt32Ty();
788 break;
789 default:
790 return Unimplemented("Invalid type %s for dot operation",
791 PrimitiveType_Name(type));
792 }
793
794 llvm::Type* float_ptr_type = float_type->getPointerTo();
795 llvm::Type* int64_type = b_->getInt64Ty();
796 llvm::Type* int32_type = b_->getInt32Ty();
797 llvm::Type* int8_ptr_type = b_->getInt8Ty()->getPointerTo();
798 llvm::FunctionType* matmul_type = llvm::FunctionType::get(
799 b_->getVoidTy(),
800 {int8_ptr_type, float_ptr_type, float_ptr_type, float_ptr_type,
801 int64_type, int64_type, int64_type, int32_type, int32_type},
802 /*isVarArg=*/false);
803
804 llvm::FunctionCallee matmul_func =
805 module->getOrInsertFunction(fn_name, matmul_type);
806 if (auto* fn = llvm::dyn_cast<llvm::Function>(matmul_func.getCallee())) {
807 fn->setCallingConv(llvm::CallingConv::C);
808 fn->setDoesNotThrow();
809 fn->setOnlyAccessesArgMemory();
810 }
811
812 // The Eigen runtime function expects column-major layout. If the matrices are
813 // row major, then use the following identity to compute the product:
814 //
815 // (A x B)^T = B^T x A^T
816 //
817 // The connection between this identity and memory layout is that the
818 // transpose operation can also be considered as an operation that changes the
819 // memory layout of a matrix from row-major to column-major or vice versa.
820 //
821 // Effectively this involves swapping the 'lhs' with 'rhs' and 'm' with 'n'.
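  //
  // Concretely (illustrative): a row major C[m,n] = A[m,k] * B[k,n] is
  // computed as the column major product C^T[n,m] = B^T[n,k] * A^T[k,m],
  // which is why the operands, m and n, and the transpose flags are swapped
  // below when the layout is row major.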

  MatMultDims mat_mult_dims = GetMatMultDims();

  CHECK_EQ(mat_mult_dims.lhs_column_major, mat_mult_dims.rhs_column_major);

  const llvm_ir::IrArray* lhs = &lhs_array_;
  const llvm_ir::IrArray* rhs = &rhs_array_;
  bool transpose_lhs = !mat_mult_dims.lhs_canonical;
  bool transpose_rhs = !mat_mult_dims.rhs_canonical;

  if (!mat_mult_dims.lhs_column_major) {
    std::swap(mat_mult_dims.m, mat_mult_dims.n);
    std::swap(lhs, rhs);
    std::swap(transpose_lhs, transpose_rhs);
  }

  b_->CreateCall(
      matmul_func,
      {b_->CreateBitCast(executable_run_options_value_, int8_ptr_type),
       b_->CreateBitCast(target_array_.GetBasePointer(), float_ptr_type),
       b_->CreateBitCast(lhs->GetBasePointer(), float_ptr_type),
       b_->CreateBitCast(rhs->GetBasePointer(), float_ptr_type),
       b_->getInt64(mat_mult_dims.m), b_->getInt64(mat_mult_dims.n),
       b_->getInt64(mat_mult_dims.k), b_->getInt32(transpose_lhs),
       b_->getInt32(transpose_rhs)});
  return Status::OK();
}

DotOpEmitter::MatMultDims DotOpEmitter::GetMatMultDims() const {
  CHECK_LE(dot_info_.result_shape.dimensions_size(), 2);

  const Shape& lhs_shape = lhs_array_.GetShape();
  const Shape& rhs_shape = rhs_array_.GetShape();
  const DotDimensionNumbers& dim_nums = dot_info_.dim_nums;

  auto is_column_major = [](const Shape& shape) {
    return shape.rank() > 1 && LayoutUtil::Minor(shape.layout(), 0) == 0;
  };

  // Non-contracting dots should never make it here.
  CHECK_GE(dim_nums.lhs_contracting_dimensions_size(), 0);
  CHECK_GE(dim_nums.rhs_contracting_dimensions_size(), 0);

  return {
      /*m=*/lhs_shape.rank() <= 1
          ? 1LL
          : lhs_shape.dimensions(1LL - dim_nums.lhs_contracting_dimensions(0)),
      /*k=*/lhs_shape.dimensions(dim_nums.lhs_contracting_dimensions(0)),
      /*n=*/rhs_shape.rank() <= 1
          ? 1LL
          : rhs_shape.dimensions(1LL - dim_nums.rhs_contracting_dimensions(0)),
      /*lhs_column_major=*/is_column_major(lhs_shape),
      /*lhs_canonical=*/lhs_shape.rank() <= 1 ||
          dim_nums.lhs_contracting_dimensions(0) == 1,
      /*rhs_column_major=*/is_column_major(rhs_shape),
      /*rhs_canonical=*/dim_nums.rhs_contracting_dimensions(0) == 0};
}

// For vector-matrix dot products, it is always profitable to make the Rhs
// column major.
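// Returns the operand index of `hlo` whose layout should be made column
// major, or an empty optional if no operand would benefit.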
absl::optional<int64> ProfitableToMakeDotOperandColumnMajor(
    const HloInstruction& hlo) {
  if (hlo.opcode() == HloOpcode::kDot && hlo.shape().dimensions_size() <= 1) {
    if (hlo.operand(0)->shape().rank() != 1 ||
        hlo.dot_dimension_numbers().rhs_contracting_dimensions(0) != 0) {
      return {};
    }

    // Don't bother if the other operand is tiny, switching to column major
    // wouldn't use tiling.
    constexpr int kColumnMajorThresholdInBytes = 32;
    int64 lhs_size =
        ShapeUtil::ByteSizeOfPrimitiveType(hlo.shape().element_type()) *
        ShapeUtil::ElementsIn(hlo.operand(0)->shape());
    if (lhs_size < kColumnMajorThresholdInBytes) {
      return {};
    }

    return 1;
  }

  if (hlo.IsOutputFusion()) {
    auto* fusion_root =
        hlo.fused_instructions_computation()->root_instruction();
    if (fusion_root->opcode() != HloOpcode::kAdd) {
      return {};
    }

    for (auto* fusion_root_op : fusion_root->operands()) {
      if (fusion_root_op->opcode() != HloOpcode::kDot) {
        continue;
      }
      if (auto operand_num =
              ProfitableToMakeDotOperandColumnMajor(*fusion_root_op)) {
        auto* operand = fusion_root_op->operand(*operand_num);
        if (operand->opcode() == HloOpcode::kParameter &&
            operand->user_count() == 1) {
          return operand->parameter_number();
        }
      }
    }
  }

  return {};
}

namespace {
// Return whether the given shape is rank 2.
bool IsRank2(const Shape& shape) { return shape.rank() == 2; }

bool IsSimpleLayout(const Layout& layout) {
  return layout.tiles().empty() && layout.format() == DENSE;
}

// In a gemm operation where output = lhs * rhs, check whether the given shapes
// are valid for the operation.
bool AreGemmShapes(const Shape& lhs_shape, const Shape& rhs_shape,
                   const Shape& output_shape,
                   const TargetMachineFeatures& target_machine_features) {
  CHECK(!lhs_shape.has_layout() || IsSimpleLayout(lhs_shape.layout()))
      << lhs_shape.DebugString();
  CHECK(!rhs_shape.has_layout() || IsSimpleLayout(rhs_shape.layout()))
      << rhs_shape.DebugString();
  CHECK(!output_shape.has_layout() || IsSimpleLayout(output_shape.layout()))
      << output_shape.DebugString();

  switch (output_shape.element_type()) {
    case F16:
    case F32:
    case F64:
    case C64:
    case C128:
    case S32:
      return IsRank2(lhs_shape) && IsRank2(rhs_shape) && IsRank2(output_shape);
    default:
      return false;
  }
}

bool IsAlignedGemm(const DotInfo& dot_info,
                   const TargetMachineFeatures& target_machine_features) {
  if (ShapeUtil::IsZeroElementArray(dot_info.lhs_shape) ||
      ShapeUtil::IsZeroElementArray(dot_info.rhs_shape)) {
    return false;
  }

  return AreGemmShapes(dot_info.lhs_shape, dot_info.rhs_shape,
                       dot_info.result_shape, target_machine_features);
}

bool CanEmitTiledLlvmIrGemm(
    const HloModuleConfig& config, const DotInfo& dot_info,
    const TargetMachineFeatures& target_machine_features) {
  CHECK(IsAlignedGemm(dot_info, target_machine_features));

  if (ShouldUseMultiThreadedEigen(config)) {
    return false;
  }

  int m = dot_info.result_shape.dimensions(0);
  int k = dot_info.lhs_shape.dimensions(
      dot_info.dim_nums.lhs_contracting_dimensions(0));
  int n = dot_info.result_shape.dimensions(1);

  if (!options::ForceEnableExperimentalLlvmIrGemm(config)) {
    // TODO(sanjoy): We should make these numbers micro-arch specific.
    bool small_gemm =
        k <= 128 && ((m <= 32 && n <= 128) || (m <= 128 && n <= 32));
    if (!small_gemm) {
      return false;
    }
  }

  bool lhs_canonical = dot_info.dim_nums.lhs_contracting_dimensions(0) == 1;
  bool rhs_canonical = dot_info.dim_nums.rhs_contracting_dimensions(0) == 0;

  if (!(lhs_canonical && rhs_canonical)) {
    return false;
  }

  if (dot_info.result_shape.element_type() == F16 ||
      dot_info.result_shape.element_type() == C64 ||
      dot_info.result_shape.element_type() == C128) {
    // TODO(sanjoy): This is probably easy to fix, but I want to keep the CL
    // adding this comment NFC.
    return false;
  }

  return true;
}

DotImplementationStrategy GetDotImplementationStrategy(
    const HloModuleConfig& config, const DotInfo& dot_info,
    const TargetMachineFeatures& target_machine_features) {
  PrimitiveType element_type = dot_info.result_shape.element_type();
  // Any Matrix-Vector product of floating point or integral type, or
  // a transpose-dot fusion of the same can be lowered to a tiled LLVM
  // IR implementation.
  if ((dot_info.result_shape.dimensions_size() <= 1 ||
       (dot_info.result_shape.dimensions_size() == 2 &&
        (dot_info.result_shape.dimensions(0) == 1 ||
         dot_info.result_shape.dimensions(1) == 1))) &&
      (primitive_util::IsFloatingPointType(element_type) ||
       primitive_util::IsIntegralType(element_type))) {
    return DotImplementationStrategy::kTiledLlvmIrGemv;
  }

  if (IsAlignedGemm(dot_info, target_machine_features)) {
    if (CanEmitTiledLlvmIrGemm(config, dot_info, target_machine_features)) {
      return DotImplementationStrategy::kTiledLlvmIrGemm;
    }
    return DotImplementationStrategy::kEigen;
  }

  return DotImplementationStrategy::kNaiveLlvmIr;
}

Status EmitNonBatchDotOperation(
    DotInfo dot_info, string hlo_name, const llvm_ir::IrArray& target_array,
    const llvm_ir::IrArray& lhs_array, const llvm_ir::IrArray& rhs_array,
    const llvm_ir::IrArray* addend_array,
    llvm::Value* executable_run_options_value, llvm::IRBuilder<>* b,
    mlir::MLIRContext* mlir_context, const HloModuleConfig& hlo_module_config,
    const TargetMachineFeatures& target_machine_features) {
  PrimitiveType type = target_array.GetShape().element_type();
  TF_RET_CHECK(PRED == type || S8 == type || U8 == type || S16 == type ||
               U16 == type || S32 == type || U32 == type || S64 == type ||
               U64 == type || F16 == type || F32 == type || F64 == type ||
               C64 == type || C128 == type);
  DotOpEmitter dot_emitter(std::move(dot_info), std::move(hlo_name),
                           target_array, lhs_array, rhs_array, addend_array,
                           executable_run_options_value, b, mlir_context,
                           hlo_module_config, target_machine_features);
  return dot_emitter.Emit();
}

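// Drops the leading dimension of `shape`, producing a shape with a descending
// layout. For example (illustrative), f32[6,4,5] becomes f32[4,5].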
Shape DropFirstDim(const Shape& shape) {
  absl::Span<int64 const> array_shape_dims(shape.dimensions());
  array_shape_dims.remove_prefix(1);
  return ShapeUtil::MakeShapeWithDescendingLayout(shape.element_type(),
                                                  array_shape_dims);
}

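// Collapses the first `n` dimensions of `shape` into a single leading
// dimension, producing a shape with a descending layout. For example
// (illustrative), f32[2,3,4,5] with n = 2 becomes f32[6,4,5].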
Shape CollapseFirstNDims(const Shape& shape, int64 n) {
  absl::Span<int64 const> input_shape_dims(shape.dimensions());
  int64 prefix_dim =
      std::accumulate(input_shape_dims.begin(), input_shape_dims.begin() + n,
                      1ll, std::multiplies<int64>());
  DimensionVector result_dims;
  result_dims.push_back(prefix_dim);
  std::copy(input_shape_dims.begin() + n, input_shape_dims.end(),
            std::back_inserter(result_dims));
  return ShapeUtil::MakeShapeWithDescendingLayout(shape.element_type(),
                                                  result_dims);
}

llvm_ir::IrArray CollapseFirstNDims(llvm::IRBuilder<>* b,
                                    const llvm_ir::IrArray& array, int64 n) {
  llvm::Module* module = b->GetInsertBlock()->getParent()->getParent();
  const Shape& shape = array.GetShape();
  CHECK(shape.has_layout() &&
        LayoutUtil::IsMonotonicWithDim0Major(shape.layout()));
  CHECK_GE(shape.dimensions_size(), n);
  Shape new_shape = CollapseFirstNDims(shape, n);
  llvm::Value* new_value = b->CreateBitCast(
      array.GetBasePointer(),
      llvm_ir::ShapeToIrType(new_shape, module)->getPointerTo());
  return llvm_ir::IrArray(new_value, std::move(new_shape));
}

Status ValidateDotDimensionNumbers(const DotDimensionNumbers& dim_numbers) {
  // Checks some invariants that do not hold in general, but DotDecomposer
  // should have established for us. This is just a debugging aid.
  TF_RET_CHECK(dim_numbers.lhs_contracting_dimensions_size() == 1);
  std::vector<int64> batch_dim_numbers(dim_numbers.lhs_batch_dimensions_size());
  absl::c_iota(batch_dim_numbers, 0);
  TF_RET_CHECK(
      absl::c_equal(batch_dim_numbers, dim_numbers.lhs_batch_dimensions()));
  TF_RET_CHECK(
      absl::c_equal(batch_dim_numbers, dim_numbers.rhs_batch_dimensions()));
  return Status::OK();
}

// Slice out the inner array at batch index `batch_index` from `outer_array`.
llvm_ir::IrArray SliceOutInnerArray(llvm_ir::IrArray outer_array,
                                    llvm::Value* batch_index,
                                    llvm::IRBuilder<>* b) {
  llvm::Module* module = b->GetInsertBlock()->getParent()->getParent();

  Shape inner_shape = DropFirstDim(outer_array.GetShape());
  std::vector<llvm::Value*> multidim_index(inner_shape.rank() + 1,
                                           b->getInt64(0));
  multidim_index[0] = batch_index;
  llvm_ir::IrArray::Index slice_index(multidim_index, outer_array.GetShape(),
                                      batch_index->getType());
  llvm::Value* slice_ptr = outer_array.EmitArrayElementAddress(slice_index, b);
  llvm::Type* slice_ptr_type =
      llvm_ir::ShapeToIrType(inner_shape, module)->getPointerTo();
  return llvm_ir::IrArray(b->CreateBitCast(slice_ptr, slice_ptr_type),
                          std::move(inner_shape));
}

Status EmitBatchDotOperation(
    const HloInstruction& dot, const llvm_ir::IrArray& target_array,
    const llvm_ir::IrArray& lhs_array, const llvm_ir::IrArray& rhs_array,
    llvm::Value* executable_run_options_value, llvm::IRBuilder<>* b,
    mlir::MLIRContext* mlir_context, const HloModuleConfig& hlo_module_config,
    const TargetMachineFeatures& target_machine_features) {
  TF_RETURN_IF_ERROR(ValidateDotDimensionNumbers(dot.dot_dimension_numbers()));

  // Lower a batch dot into a sequence of non-batch dot operations.

  int64 num_batch_dims =
      dot.dot_dimension_numbers().lhs_batch_dimensions_size();

  // First reshape the inputs to make sure we only have one batch dimension.
  // This is a no-op bitcast because the operands have to be in row-major layout
  // (enforced in CpuLayoutAssignment), and the batch dimensions are the leading
  // dimensions (established by DotDecomposer and checked by
  // ValidateDotDimensionNumbers above).
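  //
  // For example (illustrative): with two batch dimensions, a f32[2,3,4,5]
  // operand is collapsed to f32[6,4,5] and each iteration of the loop below
  // slices out one f32[4,5] sub-array.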
  llvm_ir::IrArray lhs_array_reshaped =
      CollapseFirstNDims(b, lhs_array, num_batch_dims);
  llvm_ir::IrArray rhs_array_reshaped =
      CollapseFirstNDims(b, rhs_array, num_batch_dims);
  llvm_ir::IrArray target_array_reshaped =
      CollapseFirstNDims(b, target_array, num_batch_dims);

  int64 batch_count = lhs_array_reshaped.GetShape().dimensions(0);

  KernelSupportLibrary ksl(b);

  return ksl.ForWithStatus(
      llvm_ir::IrName(&dot, "bdot"), /*start=*/0, /*end=*/batch_count,
      /*step=*/1, [&](llvm::Value* indvar) {
        DotDimensionNumbers adjusted_dim_numbers = dot.dot_dimension_numbers();
        adjusted_dim_numbers.clear_lhs_batch_dimensions();
        adjusted_dim_numbers.clear_rhs_batch_dimensions();

        // Create a DotInfo representing the "inner" non-batch dot operation.
        DotInfo dot_info;
        dot_info.lhs_shape = DropFirstDim(lhs_array_reshaped.GetShape());
        dot_info.rhs_shape = DropFirstDim(rhs_array_reshaped.GetShape());
        dot_info.result_shape = DropFirstDim(target_array_reshaped.GetShape());
        dot_info.dim_nums = dot.dot_dimension_numbers();
        dot_info.dim_nums.clear_lhs_batch_dimensions();
        dot_info.dim_nums.clear_rhs_batch_dimensions();

        dot_info.dim_nums.set_lhs_contracting_dimensions(
            0,
            dot_info.dim_nums.lhs_contracting_dimensions(0) - num_batch_dims);
        dot_info.dim_nums.set_rhs_contracting_dimensions(
            0,
            dot_info.dim_nums.rhs_contracting_dimensions(0) - num_batch_dims);

        llvm_ir::IrArray lhs_slice =
            SliceOutInnerArray(lhs_array_reshaped, /*batch_index=*/indvar, b);
        llvm_ir::IrArray rhs_slice =
            SliceOutInnerArray(rhs_array_reshaped, /*batch_index=*/indvar, b);
        llvm_ir::IrArray target_slice = SliceOutInnerArray(
            target_array_reshaped, /*batch_index=*/indvar, b);

        // Emit the inner non-batch dot operation.
        return EmitNonBatchDotOperation(
            dot_info, dot.name(), target_slice, lhs_slice, rhs_slice, nullptr,
            executable_run_options_value, b, mlir_context, hlo_module_config,
            target_machine_features);
      });
}

bool IsBatchDot(const HloInstruction& instr) {
  if (auto* dot_instr = DynCast<HloDotInstruction>(&instr)) {
    return dot_instr->dot_dimension_numbers().lhs_batch_dimensions_size() > 0;
  }

  return false;
}
}  // namespace

bool DotImplementationCanHandleTranspose(
    const HloInstruction& dot_instr,
    const TargetMachineFeatures& target_machine_features) {
  DotImplementationStrategy impl_strategy =
      GetDotImplementationStrategy(dot_instr.parent()->parent()->config(),
                                   DotInfo(dot_instr), target_machine_features);

  return impl_strategy == DotImplementationStrategy::kNaiveLlvmIr ||
         impl_strategy == DotImplementationStrategy::kTiledLlvmIrGemv ||
         impl_strategy == DotImplementationStrategy::kEigen;
}

bool DotOperandsAndResultMustHaveRowMajorLayout(
    const HloInstruction& dot_instr,
    const TargetMachineFeatures& target_machine_features) {
  // Batched dots require the batch dimensions to be major. DotDecomposer always
  // moves batch dimensions to the front of the shape, so force a row-major
  // layout.
  if (IsBatchDot(dot_instr)) {
    return true;
  }

  DotImplementationStrategy impl_strategy =
      GetDotImplementationStrategy(dot_instr.parent()->parent()->config(),
                                   DotInfo(dot_instr), target_machine_features);

  return impl_strategy == DotImplementationStrategy::kTiledLlvmIrGemm ||
         impl_strategy == DotImplementationStrategy::kEigen;
}

Status EmitDotOperation(const HloInstruction& dot,
                        const llvm_ir::IrArray& target_array,
                        const llvm_ir::IrArray& lhs_array,
                        const llvm_ir::IrArray& rhs_array,
                        const llvm_ir::IrArray* addend_array,
                        llvm::Value* executable_run_options_value,
                        llvm::IRBuilder<>* b, mlir::MLIRContext* mlir_context,
                        const HloModuleConfig& hlo_module_config,
                        const TargetMachineFeatures& target_machine_features) {
  // This routine assumes that the dot operation is not in a parallelized
  // enclosing computation.
  CHECK(dot.parent()->root_instruction()->outer_dimension_partitions().empty());

  if (IsBatchDot(dot)) {
    TF_RET_CHECK(addend_array == nullptr);
    return EmitBatchDotOperation(dot, target_array, lhs_array, rhs_array,
                                 executable_run_options_value, b, mlir_context,
                                 hlo_module_config, target_machine_features);
  }

  return EmitNonBatchDotOperation(DotInfo(dot), dot.name(), target_array,
                                  lhs_array, rhs_array, addend_array,
                                  executable_run_options_value, b, mlir_context,
                                  hlo_module_config, target_machine_features);
}
}  // namespace cpu
}  // namespace xla