1 /* Copyright 2017 The TensorFlow Authors. All Rights Reserved.
2
3 Licensed under the Apache License, Version 2.0 (the "License");
4 you may not use this file except in compliance with the License.
5 You may obtain a copy of the License at
6
7 http://www.apache.org/licenses/LICENSE-2.0
8
9 Unless required by applicable law or agreed to in writing, software
10 distributed under the License is distributed on an "AS IS" BASIS,
11 WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
12 See the License for the specific language governing permissions and
13 limitations under the License.
14 ==============================================================================*/
15
16 #include "tensorflow/compiler/xla/service/cpu/dot_op_emitter.h"
17
18 #include <memory>
19 #include <vector>
20
21 #include "absl/strings/str_cat.h"
22 #include "llvm/IR/BasicBlock.h"
23 #include "llvm/IR/Instructions.h"
24 #include "llvm/IR/Module.h"
25 #include "llvm/IR/Value.h"
26 #include "tensorflow/compiler/xla/service/cpu/cpu_runtime.h"
27 #include "tensorflow/compiler/xla/service/cpu/ir_emission_utils.h"
28 #include "tensorflow/compiler/xla/service/cpu/target_machine_features.h"
29 #include "tensorflow/compiler/xla/service/cpu/tiled_dot_emitter.h"
30 #include "tensorflow/compiler/xla/service/cpu/vector_support_library.h"
31 #include "tensorflow/compiler/xla/service/hlo_casting_utils.h"
32 #include "tensorflow/compiler/xla/service/hlo_instructions.h"
33 #include "tensorflow/compiler/xla/service/hlo_module.h"
34 #include "tensorflow/compiler/xla/service/llvm_ir/kernel_support_library.h"
35 #include "tensorflow/compiler/xla/service/llvm_ir/llvm_util.h"
36 #include "tensorflow/compiler/xla/shape_util.h"
37 #include "tensorflow/compiler/xla/status_macros.h"
38 #include "tensorflow/compiler/xla/util.h"
39 #include "tensorflow/compiler/xla/xla_data.pb.h"
40 #include "tensorflow/core/platform/logging.h"
41
42 namespace xla {
43
44 using llvm_ir::SetToFirstInsertPoint;
45
46 namespace cpu {
47 namespace {
48 // Returns true if we should call into multi-threaded Eigen routines.
ShouldUseMultiThreadedEigen(const HloModuleConfig & config)49 bool ShouldUseMultiThreadedEigen(const HloModuleConfig& config) {
50 return config.debug_options().xla_cpu_multi_thread_eigen();
51 }
52
53 // Represents a dot operation. We use this in lieu of an `HloInstruction`
54 // because we want to be able to create this for the "inner" dot operation in a
55 // batch dot, for which there is no separate HLO instruction.
56 struct DotInfo {
57 Shape lhs_shape;
58 Shape rhs_shape;
59 Shape result_shape;
60 DotDimensionNumbers dim_nums;
61
62 DotInfo() = default;
63
DotInfoxla::cpu::__anonb70c17dc0111::DotInfo64 explicit DotInfo(const HloInstruction& instr) {
65 CHECK_EQ(instr.opcode(), HloOpcode::kDot);
66 lhs_shape = instr.operand(0)->shape();
67 rhs_shape = instr.operand(1)->shape();
68 result_shape = instr.shape();
69 dim_nums = instr.dot_dimension_numbers();
70 }
71 };
72
// Dictates how a dot operation is implemented.
enum class DotImplementationStrategy {
  // The dot operation is lowered into LLVM IR that implements a naive nested
  // loop that computes the result one element at a time. This is our
  // "fallback"; we don't really want this to kick in for any non-trivial dot
  // operation.
  kNaiveLlvmIr,

  // The dot operation is lowered into LLVM IR that implements a tiled
  // Matrix*Vector operation. This strategy also allows fusing in a bias add
  // into the dot. The matrix can be row major or column major, both are
  // supported.
  kTiledLlvmIrGemv,

  // The dot operation is lowered into LLVM IR that implements a tiled
  // Matrix*Matrix operation. No fusions are supported. The two inputs
  // and the output have to be row major.
  kTiledLlvmIrGemm,

  // The dot operation is lowered into a call into an Eigen routine. No fusions
  // are supported today. The two inputs and the output have to be row major.
  // However, we do allow transposing either the LHS or the RHS as part of the
  // GEMM -- we expose this flexibility as flexibility in the contraction
  // dimensions, but we can also see this as flexibility in the input layouts.
  kEigen,
};
99
100 // Returns the implementation strategy for a dot with the configuration
101 // `dot_info`.
102 DotImplementationStrategy GetDotImplementationStrategy(
103 const HloModuleConfig& config, const DotInfo& dot_info,
104 const TargetMachineFeatures& target_machine_features);
105
106 // Helper class for emitting LLVM IR to perform the dot operation.
107 class DotOpEmitter {
108 public:
109 explicit DotOpEmitter(DotInfo dot_info, string dot_hlo_name,
110 const llvm_ir::IrArray& target_array,
111 const llvm_ir::IrArray& lhs_array,
112 const llvm_ir::IrArray& rhs_array,
113 const llvm_ir::IrArray* addend_array,
114 llvm::Value* executable_run_options_value,
115 llvm::IRBuilder<>* b,
116 const HloModuleConfig& hlo_module_config,
117 const TargetMachineFeatures& target_machine_features);
118
119 // Emits the IR to perform the dot operation.
120 Status Emit();
121
122 private:
123 // Emits instructions to perform a scalar dot product (a multiply of the
124 // LHS and RHS) and store the results in the target.
125 Status EmitScalarDot();
126
127 // Emits a call to the CPU runtime to perform the matrix multiply.
128 Status EmitCallToRuntime();
129
130 // Represents the dimensions of a matrix-matrix multiply operation.
131 struct MatMultDims {
132 // The number of rows in the LHS.
133 int64 m;
134
135 // The number of columns in the LHS, which is also must be equal to the
136 // number of rows in the RHS.
137 int64 k;
138
139 // The number of columns on the RHS.
140 int64 n;
141
142 // True if the LHS matrix is column major.
143 bool lhs_column_major;
144
145 // True if the LHS contraction dimension is not 1.
146 bool lhs_non_canonical;
147
148 // True if the RHS matrix is column major.
149 bool rhs_column_major;
150
151 // True if the RHS contraction dimension is not 0.
152 bool rhs_non_canonical;
153
154 // True if the result matrix is column major.
155 bool target_column_major;
156 };
157
158 // Get the MatMultDims instance for the dot product this DotOpEmitter
159 // represents. Precondition: the dot is of rank 2 (and thus its operands are
160 // of rank 2 as well).
161 MatMultDims GetMatMultDims() const;
162
163 // Lowers the dot operation as a tiled Matrix*Vector loop.
164 void EmitTiledLlvmIrGemv();
165
166 // Lowers the dot operation as a tiled Matrix*Matrix loop.
167 void EmitTiledLlvmIrGemm();
168
169 // Lowers the dot operation as a naive nested loop that computes the result
170 // one element at a time.
171 void EmitNaiveLlvmIrGemm();
172
173 // When doing a tiled GEMV in LLVM IR, a "tile" consists of this many vector
174 // registers.
GetGemvTilingFactor() const175 int64 GetGemvTilingFactor() const {
176 const int64 kDefaultTilingFactor = 8;
177 return options::LlvmIrGemvTilingFactor(hlo_module_config_)
178 .value_or(kDefaultTilingFactor);
179 }
180
GetGemmTileSize() const181 std::tuple<int64, int64, int64> GetGemmTileSize() const {
182 // Tuned for broadwell - Intel(R) Xeon(R) CPU E5-2690 v4 @ 2.60GHz
183 //
184 // TODO(b/80093688): Tune for other architectures and centralize this
185 // information in one place.
186 const std::tuple<int64, int64, int64> kDefaultTileSize =
187 std::tuple<int64, int64, int64>(11, 9, 1);
188 return options::LlvmIrGemmTileSize(hlo_module_config_)
189 .value_or(kDefaultTileSize);
190 }
191
192 DotInfo dot_info_;
193 string dot_hlo_name_;
194 const llvm_ir::IrArray& target_array_;
195 const llvm_ir::IrArray& lhs_array_;
196 const llvm_ir::IrArray& rhs_array_;
197 const llvm_ir::IrArray* addend_array_;
198 llvm::Value* executable_run_options_value_;
199 llvm::IRBuilder<>* b_;
200 const HloModuleConfig& hlo_module_config_;
201 const TargetMachineFeatures& target_machine_features_;
202 };
203 } // namespace
204
DotOpEmitter(DotInfo dot_info,string dot_hlo_name,const llvm_ir::IrArray & target_array,const llvm_ir::IrArray & lhs_array,const llvm_ir::IrArray & rhs_array,const llvm_ir::IrArray * addend_array,llvm::Value * executable_run_options_value,llvm::IRBuilder<> * b,const HloModuleConfig & hlo_module_config,const TargetMachineFeatures & target_machine_features)205 DotOpEmitter::DotOpEmitter(DotInfo dot_info, string dot_hlo_name,
206 const llvm_ir::IrArray& target_array,
207 const llvm_ir::IrArray& lhs_array,
208 const llvm_ir::IrArray& rhs_array,
209 const llvm_ir::IrArray* addend_array,
210 llvm::Value* executable_run_options_value,
211 llvm::IRBuilder<>* b,
212 const HloModuleConfig& hlo_module_config,
213 const TargetMachineFeatures& target_machine_features)
214 : dot_info_(std::move(dot_info)),
215 dot_hlo_name_(std::move(dot_hlo_name)),
216 target_array_(target_array),
217 lhs_array_(lhs_array),
218 rhs_array_(rhs_array),
219 addend_array_(addend_array),
220 executable_run_options_value_(executable_run_options_value),
221 b_(b),
222 hlo_module_config_(hlo_module_config),
223 target_machine_features_(target_machine_features) {}
224
EmitTiledLlvmIrGemm()225 void DotOpEmitter::EmitTiledLlvmIrGemm() {
226 PrimitiveType primitive_type = dot_info_.result_shape.element_type();
227 MatMultDims mat_mult_dims = GetMatMultDims();
228
229 llvm::Value* lhs = lhs_array_.GetBasePointer();
230 llvm::Value* rhs = rhs_array_.GetBasePointer();
231 llvm::Value* target = target_array_.GetBasePointer();
232 int64 m = mat_mult_dims.m;
233 int64 k = mat_mult_dims.k;
234 int64 n = mat_mult_dims.n;
235
236 if (mat_mult_dims.lhs_column_major) {
237 std::swap(lhs, rhs);
238 std::swap(m, n);
239 }
240
241 int64 size_bytes = m * n * ShapeUtil::ByteSizeOfPrimitiveType(primitive_type);
242 b_->CreateMemSet(target, b_->getInt8(0), /*Size=*/size_bytes,
243 /*Align=*/1);
244
245 int64 max_target_vector_width =
246 target_machine_features_.vector_register_num_elements(
247 *b_->GetInsertBlock()->getParent(), primitive_type);
248
249 int64 tile_size_m, tile_size_k, tile_size_n_in_vector_width;
250 std::tie(tile_size_m, tile_size_k, tile_size_n_in_vector_width) =
251 GetGemmTileSize();
252
253 EmitSmallGemm(
254 /*scalar_type=*/primitive_type,
255 /*m=*/m, /*k=*/k, /*n=*/n,
256 /*max_vectorization_width=*/max_target_vector_width,
257 /*max_vector_count=*/tile_size_n_in_vector_width,
258 /*min_vectorization_width=*/std::min<int64>(4, max_target_vector_width),
259 /*tile_size_m=*/tile_size_m, /*tile_size_k=*/tile_size_k, /*lhs=*/lhs,
260 /*rhs=*/rhs, /*result=*/target, b_, hlo_module_config_);
261 }
262
EmitTiledLlvmIrGemv()263 void DotOpEmitter::EmitTiledLlvmIrGemv() {
264 PrimitiveType primitive_type = dot_info_.result_shape.element_type();
265
266 CHECK(primitive_util::IsFloatingPointType(primitive_type) ||
267 primitive_util::IsIntegralType(primitive_type));
268
269 MatMultDims mat_mult_dims = GetMatMultDims();
270 bool is_column_major_matrix_vector = false;
271 bool is_row_major_matrix_vector = false;
272
273 int64 m, k;
274 bool swap_operands;
275
276 if (mat_mult_dims.m == 1) {
277 bool rhs_effectively_row_major =
278 mat_mult_dims.rhs_non_canonical ^ !mat_mult_dims.rhs_column_major;
279 if (rhs_effectively_row_major) {
280 k = mat_mult_dims.k;
281 m = mat_mult_dims.n;
282 is_column_major_matrix_vector = true;
283 swap_operands = true;
284 } else {
285 k = mat_mult_dims.k;
286 m = mat_mult_dims.n;
287 is_row_major_matrix_vector = true;
288 swap_operands = true;
289 }
290 }
291
292 if (mat_mult_dims.n == 1) {
293 bool lhs_effectively_column_major =
294 mat_mult_dims.lhs_non_canonical ^ mat_mult_dims.lhs_column_major;
295 if (lhs_effectively_column_major) {
296 m = mat_mult_dims.m;
297 k = mat_mult_dims.k;
298 is_column_major_matrix_vector = true;
299 swap_operands = false;
300 } else {
301 m = mat_mult_dims.m;
302 k = mat_mult_dims.k;
303 is_row_major_matrix_vector = true;
304 swap_operands = false;
305 }
306 }
307
308 CHECK(is_column_major_matrix_vector || is_row_major_matrix_vector);
309
310 int64 tiling_factor = GetGemvTilingFactor();
311 CHECK_GT(tiling_factor, 0);
312
313 llvm::Value* result_op = target_array_.GetBasePointer();
314 llvm::Value* lhs_op =
315 swap_operands ? rhs_array_.GetBasePointer() : lhs_array_.GetBasePointer();
316 llvm::Value* rhs_op =
317 swap_operands ? lhs_array_.GetBasePointer() : rhs_array_.GetBasePointer();
318
319 const int target_vector_register_element_size =
320 target_machine_features_.vector_register_num_elements(
321 *b_->GetInsertBlock()->getParent(), primitive_type);
322
323 // We may not always know the vector register size for the target we're
324 // compiling against, in which case target_vector_register_element_size is 0.
325 // In these cases we choose a default LLVM IR register size.
326 const int kUnknownTargetVectorRegisterSize = 4;
327 const int vector_register_element_size =
328 target_vector_register_element_size == 0
329 ? kUnknownTargetVectorRegisterSize
330 : target_vector_register_element_size;
331
332 if (is_column_major_matrix_vector) {
333 VLOG(2) << "Emitting column major matrix-vector multiply with m = " << m
334 << " and k = " << k;
335 EmitColumnMajorGemv(
336 /*scalar_type=*/primitive_type,
337 /*tile_rows=*/vector_register_element_size, /*tile_cols=*/tiling_factor,
338 /*m=*/m, /*k=*/k, /*lhs=*/lhs_op, /*rhs=*/rhs_op,
339 /*addend=*/addend_array_ ? addend_array_->GetBasePointer() : nullptr,
340 /*result=*/result_op, b_, hlo_module_config_);
341 } else {
342 VLOG(2) << "Emitting row major matrix-vector multiply with m = " << m
343 << " and k = " << k;
344 EmitRowMajorGemv(
345 /*scalar_type=*/primitive_type,
346 /*tile_rows=*/tiling_factor,
347 /*tile_cols=*/vector_register_element_size,
348 /*m=*/m, /*k=*/k, /*lhs=*/lhs_op, /*rhs=*/rhs_op,
349 /*addend=*/addend_array_ ? addend_array_->GetBasePointer() : nullptr,
350 /*result=*/result_op, b_, hlo_module_config_);
351 }
352 }
353
Emit()354 Status DotOpEmitter::Emit() {
355 // The dot operation performs a sum of products over dimension 0 of the left
356 // hand side operand and dimension 1 of the right hand side operand.
357 //
358 // Let the shapes of lhs and rhs be defined as below:
359 //
360 // lhs = [L{n-1} x L{n-2} x ... L{0}]
361 // rhs = [R{m-1} x R{m-2} x ... R{0}]
362 //
363 // The sum-of-products dimension in the lhs has size L{0} and the dimension in
364 // the rhs has size R{1}. Necessarily, then:
365 //
366 // L{0} == R{1}
367 //
368 // The output of the operation has the following shape:
369 //
370 // output = [L{n-1} x L{n-2} x ... L{1} x R{m-1} x R{m-2} x ... R{2} x R{0}]
371 //
372 // To perform the operation we construct a loop nest with one for-loop for
373 // each dimension of the output. Inside this loop nest is another for-loop
374 // which performs the sum-of-products (the reduction loop) before storing
375 // the result in the output buffer.
376
377 const Shape& lhs_shape = lhs_array_.GetShape();
378 const Shape& rhs_shape = rhs_array_.GetShape();
379
380 if (ShapeUtil::IsScalar(lhs_shape) || ShapeUtil::IsScalar(rhs_shape)) {
381 // If the operands are scalar, don't emit any loops.
382 TF_RET_CHECK(ShapeUtil::IsScalar(lhs_shape) &&
383 ShapeUtil::IsScalar(rhs_shape));
384 return EmitScalarDot();
385 }
386
387 switch (GetDotImplementationStrategy(hlo_module_config_, dot_info_,
388 target_machine_features_)) {
389 case DotImplementationStrategy::kNaiveLlvmIr:
390 EmitNaiveLlvmIrGemm();
391 return Status::OK();
392
393 case DotImplementationStrategy::kTiledLlvmIrGemv:
394 EmitTiledLlvmIrGemv();
395 return Status::OK();
396
397 case DotImplementationStrategy::kTiledLlvmIrGemm:
398 EmitTiledLlvmIrGemm();
399 return Status::OK();
400
401 case DotImplementationStrategy::kEigen:
402 return EmitCallToRuntime();
403 }
404 }
405
EmitNaiveLlvmIrGemm()406 void DotOpEmitter::EmitNaiveLlvmIrGemm() {
407 CHECK_EQ(addend_array_, nullptr);
408
409 const Shape& lhs_shape = lhs_array_.GetShape();
410 const Shape& rhs_shape = rhs_array_.GetShape();
411 const DotDimensionNumbers& dim_nums = dot_info_.dim_nums;
412
413 // Reduce along dimension 0 of the LHS and 1 of the RHS. Vectors are a special
414 // case where the reduction dimension is 0 for both LHS and RHS. This results
415 // in a vector dot product producing a scalar.
416 int64 lhs_reduction_dimension = dim_nums.lhs_contracting_dimensions(0);
417 int64 rhs_reduction_dimension = dim_nums.rhs_contracting_dimensions(0);
418
419 // Verify the reduction dimension in the two operands are the same size.
420 CHECK_EQ(lhs_shape.dimensions(lhs_reduction_dimension),
421 rhs_shape.dimensions(rhs_reduction_dimension));
422
423 bool lhs_reduction_along_minor_dimension =
424 lhs_reduction_dimension == LayoutUtil::Minor(lhs_shape.layout(), 0);
425 bool rhs_reduction_along_minor_dimension =
426 rhs_reduction_dimension == LayoutUtil::Minor(rhs_shape.layout(), 0);
427
428 // Create loop nests which loop through the LHS operand dimensions and the RHS
429 // operand dimensions. The reduction dimension of the LHS and RHS are handled
430 // in a separate innermost loop which performs the sum of products.
431 llvm_ir::ForLoopNest loop_nest(llvm_ir::IrName(dot_hlo_name_), b_);
432 std::vector<llvm::Value*> lhs_multi_index =
433 loop_nest.EmitOperandArrayLoopNest(
434 lhs_array_, /*dimension_to_skip=*/lhs_reduction_dimension, "lhs");
435 std::vector<llvm::Value*> rhs_multi_index =
436 loop_nest.EmitOperandArrayLoopNest(
437 rhs_array_, /*dimension_to_skip=*/rhs_reduction_dimension, "rhs");
438
439 // Create the loop which does the sum of products reduction.
440 //
441 // The prevent_unrolling bit is working around a deficiency in LLVM's loop
442 // vectorization pipeline, wherein in some cases unrolling a loop can prevent
443 // effective vectorization. Since we know that the IR we generate when
444 // reducing across the minor dimension in both LHS and RHS is vectorized well
445 // by the loop vectorizer, we block unrolling in that case to stop loop unroll
446 // from messing up the vectorization.
447 std::unique_ptr<llvm_ir::ForLoop> reduction_loop = loop_nest.AddLoop(
448 0, lhs_shape.dimensions(lhs_reduction_dimension), "reduction",
449 /*unroll_mode=*/
450 (lhs_reduction_along_minor_dimension &&
451 rhs_reduction_along_minor_dimension)
452 ? xla::llvm_ir::UnrollMode::kNoUnroll
453 : xla::llvm_ir::UnrollMode::kDefaultUnroll);
454
455 // The final entry in the rhs and lhs indexes is the indvar of the
456 // reduction loop.
457 lhs_multi_index[lhs_reduction_dimension] = reduction_loop->GetIndVarValue();
458 llvm_ir::IrArray::Index lhs_index(lhs_multi_index, lhs_shape,
459 b_->getInt64Ty());
460 rhs_multi_index[rhs_reduction_dimension] = reduction_loop->GetIndVarValue();
461 llvm_ir::IrArray::Index rhs_index(rhs_multi_index, rhs_shape,
462 b_->getInt64Ty());
463
464 // For computing the sum of products we alloca a single location to store the
465 // dot product result as we accumulate it within the reduction loop. After the
466 // reduction loop we load the result and store into the output array.
467
468 // Function entry basic block.
469 // - Emit alloca for accumulator
470 llvm::Function* func = reduction_loop->GetPreheaderBasicBlock()->getParent();
471 SetToFirstInsertPoint(&func->getEntryBlock(), b_);
472 llvm::Type* accum_type = target_array_.GetElementLlvmType();
473 llvm::Value* accum_address =
474 b_->CreateAlloca(accum_type, /*ArraySize=*/nullptr, "accum_address");
475
476 // Preheader basic block of reduction loop:
477 // - Initialize accumulator to zero.
478 llvm::BasicBlock* preheader_bb = reduction_loop->GetPreheaderBasicBlock();
479 b_->SetInsertPoint(preheader_bb->getTerminator());
480
481 b_->CreateStore(llvm::Constant::getNullValue(accum_type), accum_address);
482
483 // Body basic block of reduction loop:
484 // - Load elements from lhs and rhs array.
485 // - Multiply lhs-element and rhs-element.
486 // - Load accumulator and add to product.
487 // - Store sum back into accumulator.
488 SetToFirstInsertPoint(reduction_loop->GetBodyBasicBlock(), b_);
489
490 llvm::Value* lhs_element = lhs_array_.EmitReadArrayElement(lhs_index, b_);
491 llvm::Value* rhs_element = rhs_array_.EmitReadArrayElement(rhs_index, b_);
492
493 llvm::Value* accum = b_->CreateLoad(accum_address);
494 llvm::Value* updated_accum;
495 if (ShapeUtil::ElementIsComplex(lhs_shape)) {
496 auto real = [&](llvm::Value* x) { return b_->CreateExtractValue(x, {0}); };
497 auto imag = [&](llvm::Value* x) { return b_->CreateExtractValue(x, {1}); };
498 llvm::Value* product_real =
499 b_->CreateFSub(b_->CreateFMul(real(lhs_element), real(rhs_element)),
500 b_->CreateFMul(imag(lhs_element), imag(rhs_element)));
501 llvm::Value* product_imag =
502 b_->CreateFAdd(b_->CreateFMul(real(lhs_element), imag(rhs_element)),
503 b_->CreateFMul(imag(lhs_element), real(rhs_element)));
504 updated_accum = b_->CreateInsertValue(
505 accum, b_->CreateFAdd(real(accum), product_real), {0});
506 updated_accum = b_->CreateInsertValue(
507 updated_accum, b_->CreateFAdd(imag(accum), product_imag), {1});
508 } else {
509 llvm::Value* product = b_->CreateFMul(lhs_element, rhs_element);
510 updated_accum = b_->CreateFAdd(accum, product);
511 }
512 b_->CreateStore(updated_accum, accum_address);
513
514 // Exit basic block of reduction loop.
515 // - Load accumulator value (the result).
516 // - Store into output array.
517 SetToFirstInsertPoint(reduction_loop->GetExitBasicBlock(), b_);
518
519 llvm::Value* result = b_->CreateLoad(accum_address);
520
521 // Create index into target address. The target index is the concatenation of
522 // the rhs and lhs indexes with the reduction dimensions removed. The terms
523 // from the rhs index are the lower dimensions in the index so we add them
524 // first.
525 std::vector<llvm::Value*> target_multi_index;
526 for (int dimension = 0; dimension < lhs_index.size(); ++dimension) {
527 if (dimension != lhs_reduction_dimension) {
528 target_multi_index.push_back(lhs_index[dimension]);
529 }
530 }
531 for (int dimension = 0; dimension < rhs_index.size(); ++dimension) {
532 if (dimension != rhs_reduction_dimension) {
533 target_multi_index.push_back(rhs_index[dimension]);
534 }
535 }
536
537 llvm_ir::IrArray::Index target_index(
538 target_multi_index, target_array_.GetShape(), lhs_index.GetType());
539 target_array_.EmitWriteArrayElement(target_index, result, b_);
540
541 // Set the IR builder insert point to the exit basic block of the outer most
542 // loop.
543 b_->SetInsertPoint(loop_nest.GetOuterLoopExitBasicBlock());
544 }
545
EmitScalarDot()546 Status DotOpEmitter::EmitScalarDot() {
547 // A scalar dot is just a scalar multiply.
548 llvm::Value* result;
549 // Use the same index_type for all tensor accesses in the same kernel.
550 llvm::Type* index_type = b_->getInt64Ty();
551 llvm_ir::IrArray::Index element_index(index_type);
552 llvm::Value* lhs_value =
553 lhs_array_.EmitReadArrayElement(/*index=*/element_index, b_);
554 llvm::Value* rhs_value =
555 rhs_array_.EmitReadArrayElement(/*index=*/element_index, b_);
556 if (ShapeUtil::ElementIsComplex(lhs_array_.GetShape())) {
557 auto get_real = [&](llvm::Value* x) {
558 return b_->CreateExtractValue(x, {0});
559 };
560
561 auto get_imag = [&](llvm::Value* x) {
562 return b_->CreateExtractValue(x, {1});
563 };
564
565 llvm::Value* real = b_->CreateFSub(
566 b_->CreateFMul(get_real(lhs_value), get_real(rhs_value)),
567 b_->CreateFMul(get_imag(lhs_value), get_imag(rhs_value)));
568 llvm::Value* imag = b_->CreateFAdd(
569 b_->CreateFMul(get_real(lhs_value), get_imag(rhs_value)),
570 b_->CreateFMul(get_imag(lhs_value), get_real(rhs_value)));
571 result = llvm::ConstantAggregateZero::get(lhs_array_.GetElementLlvmType());
572 result = b_->CreateInsertValue(result, real, {0});
573 result = b_->CreateInsertValue(result, imag, {1});
574 } else {
575 result = b_->CreateFMul(lhs_value, rhs_value);
576 }
577 target_array_.EmitWriteArrayElement(/*index=*/element_index, result, b_);
578 return Status::OK();
579 }
580
EmitCallToRuntime()581 Status DotOpEmitter::EmitCallToRuntime() {
582 // The signature of the Eigen runtime matmul function is:
583 //
584 // (void)(void* run_options, float* out, float* lhs, float* rhs,
585 // int64 m, int64 n, int64 k, int32 transpose_lhs,
586 // int32 transpose_rhs);
587 // The two transpose_... parameters are actually booleans, but we use int32
588 // to avoid target-dependent calling convention details.
589
590 bool multi_threaded = ShouldUseMultiThreadedEigen(hlo_module_config_);
591 bool use_mkl_dnn = hlo_module_config_.debug_options().xla_cpu_use_mkl_dnn();
592 PrimitiveType type = target_array_.GetShape().element_type();
593 llvm::Type* float_type;
594 const char* fn_name;
595 switch (type) {
596 case F16:
597 fn_name = multi_threaded
598 ? runtime::kEigenMatMulF16SymbolName
599 : runtime::kEigenSingleThreadedMatMulF16SymbolName;
600 float_type = b_->getHalfTy();
601 break;
602 case F32:
603 fn_name = multi_threaded
604 ? (use_mkl_dnn ? runtime::kMKLMatMulF32SymbolName
605 : runtime::kEigenMatMulF32SymbolName)
606 : (use_mkl_dnn
607 ? runtime::kMKLSingleThreadedMatMulF32SymbolName
608 : runtime::kEigenSingleThreadedMatMulF32SymbolName);
609 float_type = b_->getFloatTy();
610 break;
611 case F64:
612 fn_name = multi_threaded
613 ? (use_mkl_dnn ? runtime::kMKLMatMulF64SymbolName
614 : runtime::kEigenMatMulF64SymbolName)
615 : (use_mkl_dnn
616 ? runtime::kMKLSingleThreadedMatMulF64SymbolName
617 : runtime::kEigenSingleThreadedMatMulF64SymbolName);
618 float_type = b_->getDoubleTy();
619 break;
620 default:
621 return Unimplemented("Invalid type %s for dot operation",
622 PrimitiveType_Name(type));
623 }
624
625 llvm::Type* float_ptr_type = float_type->getPointerTo();
626 llvm::Type* int64_type = b_->getInt64Ty();
627 llvm::Type* int32_type = b_->getInt32Ty();
628 llvm::Type* int8_ptr_type = b_->getInt8Ty()->getPointerTo();
629 llvm::FunctionType* matmul_type = llvm::FunctionType::get(
630 b_->getVoidTy(),
631 {int8_ptr_type, float_ptr_type, float_ptr_type, float_ptr_type,
632 int64_type, int64_type, int64_type, int32_type, int32_type},
633 /*isVarArg=*/false);
634
635 llvm::Function* function = b_->GetInsertBlock()->getParent();
636 llvm::Module* module = function->getParent();
637
638 llvm::FunctionCallee matmul_func =
639 module->getOrInsertFunction(fn_name, matmul_type);
640 if (auto* fn = llvm::dyn_cast<llvm::Function>(matmul_func.getCallee())) {
641 fn->setCallingConv(llvm::CallingConv::C);
642 fn->setDoesNotThrow();
643 fn->setOnlyAccessesArgMemory();
644 }
645
646 // The Eigen runtime function expects column-major layout. If the matrices are
647 // row major, then use the following identity to compute the product:
648 //
649 // (A x B)^T = B^T x A^T
650 //
651 // The connection between this identity and memory layout is that the
652 // transpose operation can also be considered as an operation that changes the
653 // memory layout of a matrix from row-major to column-major or vice versa.
654 //
655 // Effectively this involves swapping the 'lhs' with 'rhs' and 'm' with 'n'.
656
657 MatMultDims mat_mult_dims = GetMatMultDims();
658
659 CHECK_EQ(mat_mult_dims.lhs_column_major, mat_mult_dims.rhs_column_major);
660
661 const llvm_ir::IrArray* lhs = &lhs_array_;
662 const llvm_ir::IrArray* rhs = &rhs_array_;
663 bool transpose_lhs = mat_mult_dims.lhs_non_canonical;
664 bool transpose_rhs = mat_mult_dims.rhs_non_canonical;
665
666 if (!mat_mult_dims.lhs_column_major) {
667 std::swap(mat_mult_dims.m, mat_mult_dims.n);
668 std::swap(lhs, rhs);
669 std::swap(transpose_lhs, transpose_rhs);
670 }
671
672 b_->CreateCall(
673 matmul_func,
674 {b_->CreateBitCast(executable_run_options_value_, int8_ptr_type),
675 b_->CreateBitCast(target_array_.GetBasePointer(), float_ptr_type),
676 b_->CreateBitCast(lhs->GetBasePointer(), float_ptr_type),
677 b_->CreateBitCast(rhs->GetBasePointer(), float_ptr_type),
678 b_->getInt64(mat_mult_dims.m), b_->getInt64(mat_mult_dims.n),
679 b_->getInt64(mat_mult_dims.k), b_->getInt32(transpose_lhs),
680 b_->getInt32(transpose_rhs)});
681 return Status::OK();
682 }
683
GetMatMultDims() const684 DotOpEmitter::MatMultDims DotOpEmitter::GetMatMultDims() const {
685 CHECK_EQ(dot_info_.result_shape.dimensions_size(), 2);
686
687 const Shape& lhs_shape = lhs_array_.GetShape();
688 const Shape& rhs_shape = rhs_array_.GetShape();
689 const DotDimensionNumbers& dim_nums = dot_info_.dim_nums;
690
691 return {
692 /*m=*/lhs_shape.dimensions(1 - dim_nums.lhs_contracting_dimensions(0)),
693 /*k=*/lhs_shape.dimensions(dim_nums.lhs_contracting_dimensions(0)),
694 /*n=*/rhs_shape.dimensions(1 - dim_nums.rhs_contracting_dimensions(0)),
695 /*lhs_column_major=*/LayoutUtil::Minor(lhs_shape.layout(), 0) == 0,
696 /*lhs_non_canonical=*/dim_nums.lhs_contracting_dimensions(0) == 0,
697 /*rhs_column_major=*/LayoutUtil::Minor(rhs_shape.layout(), 0) == 0,
698 /*rhs_non_canonical=*/dim_nums.rhs_contracting_dimensions(0) == 1,
699 /*target_column_major=*/
700 LayoutUtil::Minor(target_array_.GetShape().layout(), 0) == 0};
701 }
702
703 // For vector-matrix dot products, it is always profitable to make the Rhs
704 // column major.
ProfitableToMakeDotOperandColumnMajor(const HloInstruction & hlo)705 absl::optional<int64> ProfitableToMakeDotOperandColumnMajor(
706 const HloInstruction& hlo) {
707 if (hlo.opcode() == HloOpcode::kDot && hlo.shape().dimensions_size() == 2 &&
708 hlo.shape().dimensions(0) == 1) {
709 if (hlo.dot_dimension_numbers().rhs_contracting_dimensions(0) == 0) {
710 return 1;
711 }
712 return {};
713 }
714
715 if (hlo.opcode() == HloOpcode::kFusion &&
716 hlo.fusion_kind() == HloInstruction::FusionKind::kOutput) {
717 auto* fusion_root =
718 hlo.fused_instructions_computation()->root_instruction();
719 if (fusion_root->opcode() != HloOpcode::kAdd) {
720 return {};
721 }
722
723 for (auto* fusion_root_op : fusion_root->operands()) {
724 if (fusion_root_op->opcode() != HloOpcode::kDot) {
725 continue;
726 }
727 if (auto operand_num =
728 ProfitableToMakeDotOperandColumnMajor(*fusion_root_op)) {
729 auto* operand = fusion_root_op->operand(*operand_num);
730 if (operand->opcode() == HloOpcode::kParameter &&
731 operand->user_count() == 1) {
732 return operand->parameter_number();
733 }
734 }
735 }
736 }
737
738 return {};
739 }
740
741 namespace {
742 // Return whether the given shape is rank 2.
IsRank2(const Shape & shape)743 bool IsRank2(const Shape& shape) { return shape.rank() == 2; }
744
IsSimpleLayout(const Layout & layout)745 bool IsSimpleLayout(const Layout& layout) {
746 return layout.tiles().empty() && layout.format() == DENSE;
747 }
748
749 // In a gemm operation where output = lhs * rhs, check whether the given shapes
750 // are valid for the operation.
AreGemmShapes(const Shape & lhs_shape,const Shape & rhs_shape,const Shape & output_shape,const TargetMachineFeatures & target_machine_features)751 bool AreGemmShapes(const Shape& lhs_shape, const Shape& rhs_shape,
752 const Shape& output_shape,
753 const TargetMachineFeatures& target_machine_features) {
754 CHECK(!lhs_shape.has_layout() || IsSimpleLayout(lhs_shape.layout()))
755 << lhs_shape.DebugString();
756 CHECK(!rhs_shape.has_layout() || IsSimpleLayout(rhs_shape.layout()))
757 << rhs_shape.DebugString();
758 CHECK(!output_shape.has_layout() || IsSimpleLayout(output_shape.layout()))
759 << output_shape.DebugString();
760
761 switch (output_shape.element_type()) {
762 case F64:
763 case F32:
764 case F16:
765 return IsRank2(lhs_shape) && IsRank2(rhs_shape) && IsRank2(output_shape);
766 default:
767 return false;
768 }
769 }
770
IsAlignedGemm(const DotInfo & dot_info,const TargetMachineFeatures & target_machine_features)771 bool IsAlignedGemm(const DotInfo& dot_info,
772 const TargetMachineFeatures& target_machine_features) {
773 if (ShapeUtil::IsZeroElementArray(dot_info.lhs_shape) ||
774 ShapeUtil::IsZeroElementArray(dot_info.rhs_shape)) {
775 return false;
776 }
777
778 return AreGemmShapes(dot_info.lhs_shape, dot_info.rhs_shape,
779 dot_info.result_shape, target_machine_features);
780 }
781
CanEmitTiledLlvmIrGemm(const HloModuleConfig & config,const DotInfo & dot_info,const TargetMachineFeatures & target_machine_features)782 bool CanEmitTiledLlvmIrGemm(
783 const HloModuleConfig& config, const DotInfo& dot_info,
784 const TargetMachineFeatures& target_machine_features) {
785 CHECK(IsAlignedGemm(dot_info, target_machine_features));
786
787 if (ShouldUseMultiThreadedEigen(config)) {
788 return false;
789 }
790
791 int m = dot_info.result_shape.dimensions(0);
792 int k = dot_info.lhs_shape.dimensions(
793 dot_info.dim_nums.lhs_contracting_dimensions(0));
794 int n = dot_info.result_shape.dimensions(1);
795
796 if (!options::ForceEnableExperimentalLlvmIrGemm(config)) {
797 // TODO(sanjoy): We should make these numbers micro-arch specific.
798 bool small_gemm =
799 k <= 128 && ((m <= 32 && n <= 128) || (m <= 128 && n <= 32));
800 if (!small_gemm) {
801 return false;
802 }
803 }
804
805 bool lhs_non_canonical = dot_info.dim_nums.lhs_contracting_dimensions(0) == 0;
806 bool rhs_non_canonical = dot_info.dim_nums.rhs_contracting_dimensions(0) == 1;
807
808 if (lhs_non_canonical || rhs_non_canonical) {
809 return false;
810 }
811
812 if (dot_info.result_shape.element_type() == F16) {
813 // TODO(sanjoy): This is probably easy to fix, but I want to keep the CL
814 // adding this comment NFC.
815 return false;
816 }
817
818 return true;
819 }
820
GetDotImplementationStrategy(const HloModuleConfig & config,const DotInfo & dot_info,const TargetMachineFeatures & target_machine_features)821 DotImplementationStrategy GetDotImplementationStrategy(
822 const HloModuleConfig& config, const DotInfo& dot_info,
823 const TargetMachineFeatures& target_machine_features) {
824 PrimitiveType element_type = dot_info.result_shape.element_type();
825 // Any Matrix-Vector product of floating point or integral type, or
826 // a transpose-dot fusion of the same can be lowered to a tiled LLVM
827 // IR implementation.
828 if (dot_info.result_shape.dimensions_size() == 2 &&
829 (dot_info.result_shape.dimensions(0) == 1 ||
830 dot_info.result_shape.dimensions(1) == 1) &&
831 (primitive_util::IsFloatingPointType(element_type) ||
832 primitive_util::IsIntegralType(element_type))) {
833 return DotImplementationStrategy::kTiledLlvmIrGemv;
834 }
835
836 if (IsAlignedGemm(dot_info, target_machine_features)) {
837 return CanEmitTiledLlvmIrGemm(config, dot_info, target_machine_features)
838 ? DotImplementationStrategy::kTiledLlvmIrGemm
839 : DotImplementationStrategy::kEigen;
840 }
841
842 return DotImplementationStrategy::kNaiveLlvmIr;
843 }
844
EmitNonBatchDotOperation(DotInfo dot_info,string hlo_name,const llvm_ir::IrArray & target_array,const llvm_ir::IrArray & lhs_array,const llvm_ir::IrArray & rhs_array,const llvm_ir::IrArray * addend_array,llvm::Value * executable_run_options_value,llvm::IRBuilder<> * b,const HloModuleConfig & hlo_module_config,const TargetMachineFeatures & target_machine_features)845 Status EmitNonBatchDotOperation(
846 DotInfo dot_info, string hlo_name, const llvm_ir::IrArray& target_array,
847 const llvm_ir::IrArray& lhs_array, const llvm_ir::IrArray& rhs_array,
848 const llvm_ir::IrArray* addend_array,
849 llvm::Value* executable_run_options_value, llvm::IRBuilder<>* b,
850 const HloModuleConfig& hlo_module_config,
851 const TargetMachineFeatures& target_machine_features) {
852 PrimitiveType type = target_array.GetShape().element_type();
853 TF_RET_CHECK(F16 == type || F32 == type || F64 == type || C64 == type ||
854 C128 == type);
855 DotOpEmitter dot_emitter(std::move(dot_info), std::move(hlo_name),
856 target_array, lhs_array, rhs_array, addend_array,
857 executable_run_options_value, b, hlo_module_config,
858 target_machine_features);
859 return dot_emitter.Emit();
860 }
861
DropFirstDim(const Shape & shape)862 Shape DropFirstDim(const Shape& shape) {
863 absl::Span<int64 const> array_shape_dims(shape.dimensions());
864 array_shape_dims.remove_prefix(1);
865 return ShapeUtil::MakeShapeWithDescendingLayout(shape.element_type(),
866 array_shape_dims);
867 }
868
CollapseFirstNDims(const Shape & shape,int64 n)869 Shape CollapseFirstNDims(const Shape& shape, int64 n) {
870 absl::Span<int64 const> input_shape_dims(shape.dimensions());
871 int64 prefix_dim =
872 std::accumulate(input_shape_dims.begin(), input_shape_dims.begin() + n,
873 1ll, std::multiplies<int64>());
874 DimensionVector result_dims;
875 result_dims.push_back(prefix_dim);
876 std::copy(input_shape_dims.begin() + n, input_shape_dims.end(),
877 std::back_inserter(result_dims));
878 return ShapeUtil::MakeShapeWithDescendingLayout(shape.element_type(),
879 result_dims);
880 }
881
CollapseFirstNDims(llvm::IRBuilder<> * b,const llvm_ir::IrArray & array,int64 n)882 llvm_ir::IrArray CollapseFirstNDims(llvm::IRBuilder<>* b,
883 const llvm_ir::IrArray& array, int64 n) {
884 llvm::Module* module = b->GetInsertBlock()->getParent()->getParent();
885 const Shape& shape = array.GetShape();
886 CHECK(shape.has_layout() &&
887 LayoutUtil::IsMonotonicWithDim0Major(shape.layout()));
888 CHECK_GE(shape.dimensions_size(), n);
889 Shape new_shape = CollapseFirstNDims(shape, n);
890 llvm::Value* new_value = b->CreateBitCast(
891 array.GetBasePointer(),
892 llvm_ir::ShapeToIrType(new_shape, module)->getPointerTo());
893 return llvm_ir::IrArray(new_value, std::move(new_shape));
894 }
895
ValidateDotDimensionNumbers(const DotDimensionNumbers & dim_numbers)896 Status ValidateDotDimensionNumbers(const DotDimensionNumbers& dim_numbers) {
897 // Checks some invariants that do not hold in general, but DotDecomposer
898 // should have established for us. This is just a debugging aid.
899 TF_RET_CHECK(dim_numbers.lhs_contracting_dimensions_size() == 1);
900 std::vector<int64> batch_dim_numbers(dim_numbers.lhs_batch_dimensions_size());
901 absl::c_iota(batch_dim_numbers, 0);
902 TF_RET_CHECK(
903 absl::c_equal(batch_dim_numbers, dim_numbers.lhs_batch_dimensions()));
904 TF_RET_CHECK(
905 absl::c_equal(batch_dim_numbers, dim_numbers.rhs_batch_dimensions()));
906 return Status::OK();
907 }
908
909 // Slice out the inner array at batch index `batch_index` from `outer_array`.
SliceOutInnerArray(llvm_ir::IrArray outer_array,llvm::Value * batch_index,llvm::IRBuilder<> * b)910 llvm_ir::IrArray SliceOutInnerArray(llvm_ir::IrArray outer_array,
911 llvm::Value* batch_index,
912 llvm::IRBuilder<>* b) {
913 llvm::Module* module = b->GetInsertBlock()->getParent()->getParent();
914
915 Shape inner_shape = DropFirstDim(outer_array.GetShape());
916 std::vector<llvm::Value*> multidim_index(inner_shape.rank() + 1,
917 b->getInt64(0));
918 multidim_index[0] = batch_index;
919 llvm_ir::IrArray::Index slice_index(multidim_index, outer_array.GetShape(),
920 batch_index->getType());
921 llvm::Value* slice_ptr = outer_array.EmitArrayElementAddress(slice_index, b);
922 llvm::Type* slice_ptr_type =
923 llvm_ir::ShapeToIrType(inner_shape, module)->getPointerTo();
924 return llvm_ir::IrArray(b->CreateBitCast(slice_ptr, slice_ptr_type),
925 std::move(inner_shape));
926 }
927
EmitBatchDotOperation(const HloInstruction & dot,const llvm_ir::IrArray & target_array,const llvm_ir::IrArray & lhs_array,const llvm_ir::IrArray & rhs_array,llvm::Value * executable_run_options_value,llvm::IRBuilder<> * b,const HloModuleConfig & hlo_module_config,const TargetMachineFeatures & target_machine_features)928 Status EmitBatchDotOperation(
929 const HloInstruction& dot, const llvm_ir::IrArray& target_array,
930 const llvm_ir::IrArray& lhs_array, const llvm_ir::IrArray& rhs_array,
931 llvm::Value* executable_run_options_value, llvm::IRBuilder<>* b,
932 const HloModuleConfig& hlo_module_config,
933 const TargetMachineFeatures& target_machine_features) {
934 TF_RETURN_IF_ERROR(ValidateDotDimensionNumbers(dot.dot_dimension_numbers()));
935
936 // Lower a batch dot into a sequence of non-batch dot operations.
937
938 int64 num_batch_dims =
939 dot.dot_dimension_numbers().lhs_batch_dimensions_size();
940
941 // First reshape the inputs to make sure we only have one batch dimension.
942 // This is a no-op bitcast because the operands have to be in row-major layout
943 // (enforced in CpuLayoutAssignment), and the batch dimensions are the leading
944 // dimensions (established by DotDecomposer and checked by
945 // ValidateDotDimensionNumbers above).
946 llvm_ir::IrArray lhs_array_reshaped =
947 CollapseFirstNDims(b, lhs_array, num_batch_dims);
948 llvm_ir::IrArray rhs_array_reshaped =
949 CollapseFirstNDims(b, rhs_array, num_batch_dims);
950 llvm_ir::IrArray target_array_reshaped =
951 CollapseFirstNDims(b, target_array, num_batch_dims);
952
953 int64 batch_count = lhs_array_reshaped.GetShape().dimensions(0);
954
955 KernelSupportLibrary ksl(b);
956
957 return ksl.ForWithStatus(
958 llvm_ir::IrName(&dot, "bdot"), /*start=*/0, /*end=*/batch_count,
959 /*step=*/1, [&](llvm::Value* indvar) {
960 DotDimensionNumbers adjusted_dim_numbers = dot.dot_dimension_numbers();
961 adjusted_dim_numbers.clear_lhs_batch_dimensions();
962 adjusted_dim_numbers.clear_rhs_batch_dimensions();
963
964 // Create a DotInfo representing the "inner" non-batch dot operation.
965 DotInfo dot_info;
966 dot_info.lhs_shape = DropFirstDim(lhs_array_reshaped.GetShape());
967 dot_info.rhs_shape = DropFirstDim(rhs_array_reshaped.GetShape());
968 dot_info.result_shape = DropFirstDim(target_array_reshaped.GetShape());
969 dot_info.dim_nums = dot.dot_dimension_numbers();
970 dot_info.dim_nums.clear_lhs_batch_dimensions();
971 dot_info.dim_nums.clear_rhs_batch_dimensions();
972
973 dot_info.dim_nums.set_lhs_contracting_dimensions(
974 0,
975 dot_info.dim_nums.lhs_contracting_dimensions(0) - num_batch_dims);
976 dot_info.dim_nums.set_rhs_contracting_dimensions(
977 0,
978 dot_info.dim_nums.rhs_contracting_dimensions(0) - num_batch_dims);
979
980 llvm_ir::IrArray lhs_slice =
981 SliceOutInnerArray(lhs_array_reshaped, /*batch_index=*/indvar, b);
982 llvm_ir::IrArray rhs_slice =
983 SliceOutInnerArray(rhs_array_reshaped, /*batch_index=*/indvar, b);
984 llvm_ir::IrArray target_slice = SliceOutInnerArray(
985 target_array_reshaped, /*batch_index=*/indvar, b);
986
987 // Emit the inner non-batch dot operation.
988 return EmitNonBatchDotOperation(
989 dot_info, dot.name(), target_slice, lhs_slice, rhs_slice, nullptr,
990 executable_run_options_value, b, hlo_module_config,
991 target_machine_features);
992 });
993 }
994
IsBatchDot(const HloInstruction & instr)995 bool IsBatchDot(const HloInstruction& instr) {
996 if (auto* dot_instr = DynCast<HloDotInstruction>(&instr)) {
997 return dot_instr->dot_dimension_numbers().lhs_batch_dimensions_size() > 0;
998 }
999
1000 return false;
1001 }
1002 } // namespace
1003
DotImplementationCanHandleTranspose(const HloInstruction & dot_instr,const TargetMachineFeatures & target_machine_features)1004 bool DotImplementationCanHandleTranspose(
1005 const HloInstruction& dot_instr,
1006 const TargetMachineFeatures& target_machine_features) {
1007 DotImplementationStrategy impl_strategy =
1008 GetDotImplementationStrategy(dot_instr.parent()->parent()->config(),
1009 DotInfo(dot_instr), target_machine_features);
1010
1011 // TODO(sanjoy): This is not quite right, it should be `impl_strategy ==
1012 // kEigen || impl_strategy == kTiledLlvmIrGemv || impl_strategy ==
1013 // kNaiveLlvmIr` but I'll fix this in a later CL in the interest of keeping
1014 // the CL adding this comment NFC.
1015 return impl_strategy == DotImplementationStrategy::kTiledLlvmIrGemm ||
1016 impl_strategy == DotImplementationStrategy::kEigen;
1017 }
1018
DotOperandsAndResultMustHaveRowMajorLayout(const HloInstruction & dot_instr,const TargetMachineFeatures & target_machine_features)1019 bool DotOperandsAndResultMustHaveRowMajorLayout(
1020 const HloInstruction& dot_instr,
1021 const TargetMachineFeatures& target_machine_features) {
1022 DotImplementationStrategy impl_strategy =
1023 GetDotImplementationStrategy(dot_instr.parent()->parent()->config(),
1024 DotInfo(dot_instr), target_machine_features);
1025
1026 return impl_strategy == DotImplementationStrategy::kTiledLlvmIrGemm ||
1027 impl_strategy == DotImplementationStrategy::kEigen;
1028 }
1029
EmitDotOperation(const HloInstruction & dot,const llvm_ir::IrArray & target_array,const llvm_ir::IrArray & lhs_array,const llvm_ir::IrArray & rhs_array,const llvm_ir::IrArray * addend_array,llvm::Value * executable_run_options_value,llvm::IRBuilder<> * b,const HloModuleConfig & hlo_module_config,const TargetMachineFeatures & target_machine_features)1030 Status EmitDotOperation(const HloInstruction& dot,
1031 const llvm_ir::IrArray& target_array,
1032 const llvm_ir::IrArray& lhs_array,
1033 const llvm_ir::IrArray& rhs_array,
1034 const llvm_ir::IrArray* addend_array,
1035 llvm::Value* executable_run_options_value,
1036 llvm::IRBuilder<>* b,
1037 const HloModuleConfig& hlo_module_config,
1038 const TargetMachineFeatures& target_machine_features) {
1039 // This routine assumes that the dot operation is not in a parallelized
1040 // enclosing computation.
1041 CHECK(dot.parent()->root_instruction()->outer_dimension_partitions().empty());
1042
1043 if (IsBatchDot(dot)) {
1044 TF_RET_CHECK(addend_array == nullptr);
1045 return EmitBatchDotOperation(dot, target_array, lhs_array, rhs_array,
1046 executable_run_options_value, b,
1047 hlo_module_config, target_machine_features);
1048 }
1049
1050 return EmitNonBatchDotOperation(DotInfo(dot), dot.name(), target_array,
1051 lhs_array, rhs_array, addend_array,
1052 executable_run_options_value, b,
1053 hlo_module_config, target_machine_features);
1054 }
1055 } // namespace cpu
1056 } // namespace xla
1057