1 /* Copyright 2017 The TensorFlow Authors. All Rights Reserved.
2
3 Licensed under the Apache License, Version 2.0 (the "License");
4 you may not use this file except in compliance with the License.
5 You may obtain a copy of the License at
6
7 http://www.apache.org/licenses/LICENSE-2.0
8
9 Unless required by applicable law or agreed to in writing, software
10 distributed under the License is distributed on an "AS IS" BASIS,
11 WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
12 See the License for the specific language governing permissions and
13 limitations under the License.
14 ==============================================================================*/
15
16 #include "tensorflow/compiler/xla/service/gpu/ir_emitter.h"
17
18 #include <string>
19 #include <unordered_map>
20 #include <utility>
21
22 #include "tensorflow/core/platform/logging.h"
23 // IWYU pragma: no_include "llvm/IR/Intrinsics.gen.inc"
24 #include "absl/algorithm/container.h"
25 #include "llvm/IR/BasicBlock.h"
26 #include "llvm/IR/Constants.h"
27 #include "llvm/IR/Instructions.h"
28 #include "llvm/IR/Module.h"
29 #include "tensorflow/compiler/xla/primitive_util.h"
30 #include "tensorflow/compiler/xla/service/elemental_ir_emitter.h"
31 #include "tensorflow/compiler/xla/service/gpu/elemental_ir_emitter.h"
32 #include "tensorflow/compiler/xla/service/gpu/ir_emitter_nested.h"
33 #include "tensorflow/compiler/xla/service/gpu/ir_emitter_unnested.h"
34 #include "tensorflow/compiler/xla/service/gpu/partition_assignment.h"
35 #include "tensorflow/compiler/xla/service/hlo_computation.h"
36 #include "tensorflow/compiler/xla/service/llvm_ir/fused_ir_emitter.h"
37 #include "tensorflow/compiler/xla/service/llvm_ir/ir_array.h"
38 #include "tensorflow/compiler/xla/service/llvm_ir/llvm_loop.h"
39 #include "tensorflow/compiler/xla/service/llvm_ir/llvm_util.h"
40 #include "tensorflow/compiler/xla/service/llvm_ir/loop_emitter.h"
41 #include "tensorflow/compiler/xla/service/llvm_ir/tuple_ops.h"
42 #include "tensorflow/compiler/xla/service/name_uniquer.h"
43 #include "tensorflow/compiler/xla/shape_util.h"
44 #include "tensorflow/compiler/xla/status_macros.h"
45 #include "tensorflow/compiler/xla/types.h"
46 #include "tensorflow/compiler/xla/util.h"
47 #include "tensorflow/compiler/xla/window_util.h"
48 #include "tensorflow/core/lib/core/errors.h"
49
50 namespace xla {
51
52 using llvm_ir::IrName;
53 using llvm_ir::SetToFirstInsertPoint;
54
55 namespace gpu {
56
IrEmitter(const HloModuleConfig & hlo_module_config,IrEmitterContext * ir_emitter_context,bool is_nested)57 IrEmitter::IrEmitter(const HloModuleConfig& hlo_module_config,
58 IrEmitterContext* ir_emitter_context, bool is_nested)
59 : ir_emitter_context_(ir_emitter_context),
60 module_(ir_emitter_context->llvm_module()),
61 b_(module_->getContext()),
62 bindings_(ir_emitter_context->hlo_module(),
63 &ir_emitter_context->buffer_assignment(), &b_, module_,
64 is_nested),
65 hlo_module_config_(hlo_module_config) {
66 }
67
DefaultAction(HloInstruction * hlo)68 Status IrEmitter::DefaultAction(HloInstruction* hlo) {
69 ElementalIrEmitter::HloToElementGeneratorMap operand_to_generator;
70 for (const HloInstruction* operand : hlo->operands()) {
71 operand_to_generator[operand] = [=](const llvm_ir::IrArray::Index& index) {
72 return GetIrArray(*operand, *hlo).EmitReadArrayElement(index, &b_);
73 };
74 }
75 return EmitTargetElementLoop(
76 *hlo, GpuElementalIrEmitter(hlo_module_config_, module_, &b_,
77 GetNestedComputer())
78 .MakeElementGenerator(hlo, operand_to_generator));
79 }
80
HandleConstant(HloInstruction * constant)81 Status IrEmitter::HandleConstant(HloInstruction* constant) {
82 return Status::OK();
83 }
84
HandleBitcast(HloInstruction * bitcast)85 Status IrEmitter::HandleBitcast(HloInstruction* bitcast) {
86 VLOG(2) << "HandleBitcast: " << bitcast->ToString();
87 const HloInstruction* operand = bitcast->operand(0);
88 // Bitcast is a no-op, but we still want to bind it to an llvm::Value
89 // sometimes, e.g., when it's operand is a constant or a bitcast of a
90 // constant.
91 if (bindings_.BoundToIrValue(*operand)) {
92 bindings_.BindHloToIrValue(*bitcast, GetBasePointer(*operand));
93 }
94 return Status::OK();
95 }
96
HandleAddDependency(HloInstruction * add_dependency)97 Status IrEmitter::HandleAddDependency(HloInstruction* add_dependency) {
98 VLOG(2) << "HandleAddDependency: " << add_dependency->ToString();
99 const HloInstruction* operand = add_dependency->operand(0);
100 // Add_Dependency is a no-op, but we still want to bind it to an llvm::Value
101 // sometimes, e.g., when it's operand is a constant or a bitcast of a
102 // constant.
103 if (bindings_.BoundToIrValue(*operand)) {
104 bindings_.BindHloToIrValue(*add_dependency, GetBasePointer(*operand));
105 }
106 return Status::OK();
107 }
108
HandleGetTupleElement(HloInstruction * get_tuple_element)109 Status IrEmitter::HandleGetTupleElement(HloInstruction* get_tuple_element) {
110 auto operand = get_tuple_element->operand(0);
111 CHECK(bindings_.BoundToIrValue(*operand));
112 bindings_.BindHloToIrValue(
113 *get_tuple_element,
114 llvm_ir::EmitGetTupleElement(
115 get_tuple_element->shape(), get_tuple_element->tuple_index(),
116 // TODO(b/26344050): tighten the alignment here
117 // based on the real element type.
118 /*alignment=*/1, GetBasePointer(*operand), &b_));
119 return Status::OK();
120 }
121
HandleSend(HloInstruction *)122 Status IrEmitter::HandleSend(HloInstruction*) {
123 return Unimplemented("Send is not implemented on GPU");
124 }
125
HandleSendDone(HloInstruction *)126 Status IrEmitter::HandleSendDone(HloInstruction*) {
127 return Unimplemented("Send-Done is not implemented on GPU");
128 }
129
HandleRecv(HloInstruction *)130 Status IrEmitter::HandleRecv(HloInstruction*) {
131 return Unimplemented("Recv is not implemented on GPU");
132 }
133
HandleRecvDone(HloInstruction *)134 Status IrEmitter::HandleRecvDone(HloInstruction*) {
135 return Unimplemented("Recv-done is not implemented on GPU");
136 }
137
HandleScatter(HloInstruction *)138 Status IrEmitter::HandleScatter(HloInstruction*) {
139 return Unimplemented("Scatter is not implemented on GPUs.");
140 }
141
HandleTuple(HloInstruction * tuple)142 Status IrEmitter::HandleTuple(HloInstruction* tuple) {
143 std::vector<llvm::Value*> base_ptrs;
144 for (const HloInstruction* operand : tuple->operands()) {
145 base_ptrs.push_back(GetBasePointer(*operand));
146 }
147 llvm_ir::EmitTuple(GetIrArray(*tuple, *tuple), base_ptrs, &b_);
148 return Status::OK();
149 }
150
EmitCallToNestedComputation(const HloComputation & nested_computation,absl::Span<llvm::Value * const> operands,llvm::Value * output)151 Status IrEmitter::EmitCallToNestedComputation(
152 const HloComputation& nested_computation,
153 absl::Span<llvm::Value* const> operands, llvm::Value* output) {
154 TF_RET_CHECK(nested_computation.num_parameters() > 0);
155 llvm::Function*& emitted_function =
156 computation_to_ir_function_[&nested_computation];
157 if (emitted_function == nullptr) {
158 IrEmitterNested ir_emitter_nested(hlo_module_config_, nested_computation,
159 ir_emitter_context_);
160 TF_RETURN_IF_ERROR(
161 nested_computation.root_instruction()->Accept(&ir_emitter_nested));
162 emitted_function = ir_emitter_nested.GetEmittedFunction();
163 }
164
165 std::vector<llvm::Value*> arguments(operands.begin(), operands.end());
166 arguments.push_back(output);
167 arguments.push_back(bindings_.GetTempBufferBase());
168 Call(emitted_function, arguments);
169
170 return Status::OK();
171 }
172
MaybeEmitDirectAtomicOperation(const HloComputation & computation,llvm::Value * output_address,llvm::Value * source_address)173 bool IrEmitter::MaybeEmitDirectAtomicOperation(
174 const HloComputation& computation, llvm::Value* output_address,
175 llvm::Value* source_address) {
176 CHECK_EQ(2, computation.num_parameters());
177
178 if (computation.instruction_count() != 3) {
179 // We special-case only computations with one computing instruction for now.
180 // Such computation has exactly three instructions given it has two
181 // parameters.
182 return false;
183 }
184
185 HloOpcode root_opcode = computation.root_instruction()->opcode();
186 PrimitiveType element_type =
187 computation.root_instruction()->shape().element_type();
188 bool is_atomic_integral = element_type == S32 || element_type == U32 ||
189 element_type == S64 || element_type == U64;
190 llvm::Value* source = Load(source_address, "source");
191
192 // kCopy of RHS -> atomic store.
193 if (root_opcode == HloOpcode::kCopy &&
194 (element_type == F32 || is_atomic_integral) &&
195 computation.root_instruction()->operand(0)->opcode() ==
196 HloOpcode::kParameter &&
197 computation.root_instruction()->operand(0)->parameter_number() == 1) {
198 llvm::StoreInst* store = Store(source, output_address);
199 store->setAtomic(llvm::AtomicOrdering::Unordered);
200 // Derive a minimum alignment from the type. The optimizer can increase it
201 // later.
202 store->setAlignment(ShapeUtil::ByteSizeOfPrimitiveType(element_type));
203 return true;
204 }
205
206 if (root_opcode == HloOpcode::kAdd) {
207 // NVPTX supports atomicAdd on F32 and integer types.
208 if (element_type == F32) {
209 // F32 + F32
210 llvm_ir::EmitCallToIntrinsic(llvm::Intrinsic::nvvm_atomic_load_add_f32,
211 {output_address, source},
212 {output_address->getType()}, &b_);
213 return true;
214 }
215 if (is_atomic_integral) {
216 // integral + integral
217 AtomicRMW(llvm::AtomicRMWInst::Add, output_address, source,
218 llvm::AtomicOrdering::SequentiallyConsistent);
219 return true;
220 }
221 }
222
223 // NVPTX supports atomicMax and atomicMin only on integer types.
224 if (root_opcode == HloOpcode::kMaximum && is_atomic_integral) {
225 // max(integral, integral)
226 auto opcode = primitive_util::IsSignedIntegralType(element_type)
227 ? llvm::AtomicRMWInst::Max
228 : llvm::AtomicRMWInst::UMax;
229 AtomicRMW(opcode, output_address, source,
230 llvm::AtomicOrdering::SequentiallyConsistent);
231 return true;
232 }
233
234 if (root_opcode == HloOpcode::kMinimum && is_atomic_integral) {
235 // min(integral, integral)
236 auto opcode = primitive_util::IsSignedIntegralType(element_type)
237 ? llvm::AtomicRMWInst::Min
238 : llvm::AtomicRMWInst::UMin;
239 AtomicRMW(opcode, output_address, source,
240 llvm::AtomicOrdering::SequentiallyConsistent);
241 return true;
242 }
243
244 return false;
245 }
246
247 // Implements atomic binary operations using atomic compare-and-swap
248 // (atomicCAS) as follows:
249 // 1. Reads the value from the memory pointed to by output_address and
250 // records it as old_output.
251 // 2. Uses old_output as one of the source operand to perform the binary
252 // operation and stores the result in new_output.
253 // 3. Calls atomicCAS which implements compare-and-swap as an atomic
254 // operation. In particular, atomicCAS reads the value from the memory
255 // pointed to by output_address, and compares the value with old_output. If
256 // the two values equal, new_output is written to the same memory location
257 // and true is returned to indicate that the atomic operation succeeds.
258 // Otherwise, the new value read from the memory is returned. In this case,
259 // the new value is copied to old_output, and steps 2. and 3. are repeated
260 // until atomicCAS succeeds.
261 //
262 // On Nvidia GPUs, atomicCAS can only operate on 32 bit and 64 bit integers. If
263 // the element type of the binary operation is 32 bits or 64 bits, the integer
264 // type of the same size is used for the atomicCAS operation. On the other hand,
265 // if the element type is smaller than 32 bits, int32 is used for the atomicCAS
266 // operation. In this case, atomicCAS reads and writes 32 bit values from
267 // the memory, which is larger than the memory size required by the original
268 // atomic binary operation. We mask off the last two bits of the output_address
269 // and use the result as an address to read the 32 bit values from the memory.
270 // This can avoid out of bound memory accesses if tensor buffers are 4 byte
271 // aligned and have a size of 4N, an assumption that the runtime can guarantee.
272 //
273 // The pseudo code is shown below. Variables *_address are pointers to a memory
274 // region with a size equal to the size of the atomicCAS operation, with the
275 // exception that new_output_address is a pointer to a memory region with a size
276 // equal to the element size of the binary operation.
277 //
278 // element_size = sizeof(element_type);
279 // atomic_size = max(32, element_size);
280 // cas_new_output_address = alloca(atomic_size);
281 // cas_old_output_address = alloca(atomic_size);
282 // if (atomic_size != element_size) {
283 // atomic_address = output_address & ((int64)(-4));
284 // new_output_address = cas_new_output_address + (output_address & 3);
285 // } else {
286 // atomic_address = output_address;
287 // new_output_address = cas_new_output_address;
288 // }
289 //
290 // *cas_old_output_address = *atomic_address;
291 // do {
292 // *cas_new_output_address = *cas_old_output_address;
293 // *new_output_address = operation(*new_output_address, *source_address);
294 // (*cas_old_output_address, success) =
295 // atomicCAS(atomic_address, *cas_old_output_address,
296 // *cas_new_output_address);
297 // } while (!success);
298 //
EmitAtomicOperationUsingCAS(const HloComputation & computation,llvm::Value * output_address,llvm::Value * source_address)299 Status IrEmitter::EmitAtomicOperationUsingCAS(const HloComputation& computation,
300 llvm::Value* output_address,
301 llvm::Value* source_address) {
302 llvm::PointerType* output_address_type =
303 llvm::dyn_cast<llvm::PointerType>(output_address->getType());
304 CHECK_NE(output_address_type, nullptr);
305
306 // element_type is the data type for the binary operation.
307 llvm::Type* element_type = output_address_type->getPointerElementType();
308 int element_size = llvm_ir::GetSizeInBits(element_type);
309 llvm::Type* element_address_type = element_type->getPointerTo();
310
311 int atomic_size = (element_size < 32) ? 32 : element_size;
312 llvm::Type* atomic_type = b_.getIntNTy(atomic_size);
313 llvm::Type* atomic_address_type =
314 atomic_type->getPointerTo(output_address_type->getPointerAddressSpace());
315
316 // cas_old_output_address and cas_new_output_address point to the scratch
317 // memory where we store the old and new values for the repeated atomicCAS
318 // operations.
319 llvm::Value* cas_old_output_address =
320 Alloca(atomic_type, /*ArraySize=*/nullptr, "cas_old_output_address");
321 llvm::Value* cas_new_output_address =
322 Alloca(atomic_type, /*ArraySize=*/nullptr, "cas_new_output_address");
323
324 // Emit preparation code to the preheader.
325 llvm::BasicBlock* loop_preheader_bb = b_.GetInsertBlock();
326
327 llvm::Value* atomic_memory_address;
328 // binop_output_address points to the scratch memory that stores the
329 // result of the binary operation.
330 llvm::Value* binop_output_address;
331 if (element_size < 32) {
332 // Assume the element size is an integer number of bytes.
333 CHECK_EQ((element_size % sizeof(char)), 0);
334 llvm::Type* address_int_type =
335 module_->getDataLayout().getIntPtrType(output_address_type);
336 atomic_memory_address = PtrToInt(output_address, address_int_type);
337 llvm::Value* mask = llvm::ConstantInt::get(address_int_type, 3);
338 llvm::Value* offset = And(atomic_memory_address, mask);
339 mask = llvm::ConstantInt::get(address_int_type, -4);
340 atomic_memory_address = And(atomic_memory_address, mask);
341 atomic_memory_address =
342 IntToPtr(atomic_memory_address, atomic_address_type);
343 binop_output_address =
344 Add(PtrToInt(cas_new_output_address, address_int_type), offset);
345 binop_output_address = IntToPtr(binop_output_address, element_address_type);
346 } else {
347 atomic_memory_address = BitCast(output_address, atomic_address_type);
348 binop_output_address =
349 BitCast(cas_new_output_address, element_address_type);
350 }
351
352 // Use the value from the memory that atomicCAS operates on to initialize
353 // cas_old_output.
354 llvm::Value* cas_old_output = Load(atomic_memory_address, "cas_old_output");
355 Store(cas_old_output, cas_old_output_address);
356
357 llvm::BasicBlock* loop_exit_bb = loop_preheader_bb->splitBasicBlock(
358 b_.GetInsertPoint(), "atomic_op_loop_exit");
359 llvm::BasicBlock* loop_body_bb = llvm::BasicBlock::Create(
360 b_.getContext(), "atomic_op_loop_body", b_.GetInsertBlock()->getParent());
361 b_.SetInsertPoint(loop_body_bb);
362 // Change preheader's successor from loop_exit_bb to loop_body_bb.
363 loop_preheader_bb->getTerminator()->setSuccessor(0, loop_body_bb);
364
365 // Emit the body of the loop that repeatedly invokes atomicCAS.
366 //
367 // Use cas_old_output to initialize cas_new_output.
368 cas_old_output = Load(cas_old_output_address, "cas_old_output");
369 Store(cas_old_output, cas_new_output_address);
370 // Emits code to calculate new_output = operation(old_output, source);
371 TF_RETURN_IF_ERROR(EmitCallToNestedComputation(
372 computation, {binop_output_address, source_address},
373 binop_output_address));
374
375 llvm::Value* cas_new_output = Load(cas_new_output_address, "cas_new_output");
376
377 // Emit code to perform the atomicCAS operation
378 // (cas_old_output, success) = atomicCAS(memory_address, cas_old_output,
379 // cas_new_output);
380 llvm::Value* ret_value =
381 AtomicCmpXchg(atomic_memory_address, cas_old_output, cas_new_output,
382 llvm::AtomicOrdering::SequentiallyConsistent,
383 llvm::AtomicOrdering::SequentiallyConsistent);
384
385 // Extract the memory value returned from atomicCAS and store it as
386 // cas_old_output.
387 Store(ExtractValue(ret_value, 0, "cas_old_output"), cas_old_output_address);
388 // Extract the success bit returned from atomicCAS and generate a
389 // conditional branch on the success bit.
390 CondBr(ExtractValue(ret_value, 1, "success"), loop_exit_bb, loop_body_bb);
391
392 // Set the insertion point to the exit basic block so that the caller of
393 // this method can continue emitting code to the right place.
394 SetToFirstInsertPoint(loop_exit_bb, &b_);
395 return Status::OK();
396 }
397
EmitAtomicOperationForNestedComputation(const HloComputation & computation,llvm::Value * output_address,llvm::Value * source_address)398 Status IrEmitter::EmitAtomicOperationForNestedComputation(
399 const HloComputation& computation, llvm::Value* output_address,
400 llvm::Value* source_address) {
401 if (computation.num_parameters() != 2) {
402 // TODO(b/30258929): We only accept binary computations so far.
403 return Unimplemented(
404 "We only support atomic functions with exactly two parameters, but "
405 "computation %s has %d.",
406 computation.name(), computation.num_parameters());
407 }
408
409 if (MaybeEmitDirectAtomicOperation(computation, output_address,
410 source_address)) {
411 return Status::OK();
412 }
413
414 return EmitAtomicOperationUsingCAS(computation, output_address,
415 source_address);
416 }
417
HandleSelect(HloInstruction * select)418 Status IrEmitter::HandleSelect(HloInstruction* select) {
419 auto pred = select->operand(0);
420 TF_RET_CHECK(pred->shape().element_type() == PRED);
421 // We must not call the subclass `DefaultAction` method, lest its
422 // `HandleSelect` call `IrEmitter::HandleSelect` and its `DefaultAction`
423 // assume no handler has already been called.
424 return IrEmitter::DefaultAction(select);
425 }
426
HandleTupleSelect(HloInstruction * tuple_select)427 Status IrEmitter::HandleTupleSelect(HloInstruction* tuple_select) {
428 auto pred = tuple_select->operand(0);
429 auto on_true = tuple_select->operand(1);
430 auto on_false = tuple_select->operand(2);
431 TF_RET_CHECK(pred->shape().element_type() == PRED);
432 TF_RET_CHECK(ShapeUtil::IsScalar(pred->shape()));
433 TF_RET_CHECK(tuple_select->shape().IsTuple());
434 llvm_ir::EmitTupleSelect(GetIrArray(*tuple_select, *tuple_select),
435 GetIrArray(*pred, *tuple_select),
436 GetBasePointer(*on_true), GetBasePointer(*on_false),
437 &b_);
438 return Status::OK();
439 }
440
441 namespace {
Real(llvm::Value * x,llvm::IRBuilder<> * b)442 llvm::Value* Real(llvm::Value* x, llvm::IRBuilder<>* b) {
443 return b->CreateExtractValue(x, {0});
444 }
445
Imag(llvm::Value * x,llvm::IRBuilder<> * b)446 llvm::Value* Imag(llvm::Value* x, llvm::IRBuilder<>* b) {
447 return b->CreateExtractValue(x, {1});
448 }
449
MultiplyComplex(llvm::Value * lhs_value,llvm::Value * rhs_value,llvm::IRBuilder<> * b)450 std::pair<llvm::Value*, llvm::Value*> MultiplyComplex(llvm::Value* lhs_value,
451 llvm::Value* rhs_value,
452 llvm::IRBuilder<>* b) {
453 llvm::Value* lhs_real = Real(lhs_value, b);
454 llvm::Value* lhs_imag = Imag(lhs_value, b);
455 llvm::Value* rhs_real = Real(rhs_value, b);
456 llvm::Value* rhs_imag = Imag(rhs_value, b);
457 llvm::Value* real_result1 = b->CreateFMul(lhs_real, rhs_real);
458 llvm::Value* real_result2 = b->CreateFMul(lhs_imag, rhs_imag);
459 llvm::Value* real_result = b->CreateFSub(real_result1, real_result2);
460 llvm::Value* imag_result1 = b->CreateFMul(lhs_real, rhs_imag);
461 llvm::Value* imag_result2 = b->CreateFMul(lhs_imag, rhs_real);
462 llvm::Value* imag_result = b->CreateFAdd(imag_result1, imag_result2);
463 return {real_result, imag_result};
464 }
465 } // namespace
466
HandleDot(HloInstruction * dot)467 Status IrEmitter::HandleDot(HloInstruction* dot) {
468 auto lhs_instruction = dot->operand(0);
469 auto rhs_instruction = dot->operand(1);
470 const llvm_ir::IrArray& target_array = GetIrArray(*dot, *dot);
471 const llvm_ir::IrArray& lhs_array = GetIrArray(*lhs_instruction, *dot);
472 const llvm_ir::IrArray& rhs_array = GetIrArray(*rhs_instruction, *dot);
473
474 const Shape& lhs_shape = lhs_instruction->shape();
475 const Shape& rhs_shape = rhs_instruction->shape();
476 const DotDimensionNumbers& dnums = dot->dot_dimension_numbers();
477 CHECK_EQ(dnums.lhs_batch_dimensions_size(),
478 dnums.rhs_batch_dimensions_size());
479
480 // TODO(b/110211620): Convert to use i32 index_type when it is possible.
481 llvm::Type* index_type = b_.getInt64Ty();
482 llvm_ir::IrArray::Index element_index(index_type);
483 if (ShapeUtil::IsScalar(lhs_shape) && ShapeUtil::IsScalar(rhs_shape)) {
484 // If the operands are scalar, don't emit any loops.
485 llvm::Value* lhs_value =
486 lhs_array.EmitReadArrayElement(/*index=*/element_index, &b_);
487 llvm::Value* rhs_value =
488 rhs_array.EmitReadArrayElement(/*index=*/element_index, &b_);
489 llvm::Value* result;
490 if (ShapeUtil::ElementIsComplex(lhs_shape)) {
491 auto value = MultiplyComplex(lhs_value, rhs_value, &b_);
492 result = llvm::ConstantAggregateZero::get(lhs_array.GetElementLlvmType());
493 result = InsertValue(result, value.first, {0});
494 result = InsertValue(result, value.second, {1});
495 } else if (ShapeUtil::ElementIsFloating(lhs_shape)) {
496 result = FMul(lhs_value, rhs_value);
497 } else {
498 TF_RET_CHECK(ShapeUtil::ElementIsIntegral(lhs_shape));
499 result = Mul(lhs_value, rhs_value);
500 }
501 target_array.EmitWriteArrayElement(/*index=*/element_index, result, &b_);
502 return Status::OK();
503 }
504
505 // "Scalar dot non-scalar" or "non-scalar dot scalar" is invalid. See
506 // the semantics of Dot in the XLA documentation for details.
507 TF_RET_CHECK(!ShapeUtil::IsScalar(lhs_shape) &&
508 !ShapeUtil::IsScalar(rhs_shape));
509
510 const int64 lhs_reduction_dimension = dnums.lhs_contracting_dimensions(0);
511 const int64 rhs_reduction_dimension = dnums.rhs_contracting_dimensions(0);
512
513 // Check that the batch dims don't cover the reduction dimensions.
514 for (int64 batch_dim : dnums.lhs_batch_dimensions()) {
515 CHECK_NE(lhs_reduction_dimension, batch_dim);
516 CHECK_NE(rhs_reduction_dimension, batch_dim);
517 }
518
519 // Verify the reduction dimension in the two operands are the same size.
520 TF_RET_CHECK(lhs_shape.dimensions(lhs_reduction_dimension) ==
521 rhs_shape.dimensions(rhs_reduction_dimension))
522 << "lhs_shape.dimensions(" << lhs_reduction_dimension
523 << ") = " << lhs_shape.dimensions(lhs_reduction_dimension)
524 << ", and rhs_shape.dimensions(" << rhs_reduction_dimension
525 << ") = " << rhs_shape.dimensions(rhs_reduction_dimension);
526
527 // Create loop nests which loop through the LHS operand dimensions and the RHS
528 // operand dimensions. The reduction dimension of the LHS and RHS are handled
529 // in a separate innermost loop which performs the sum of products.
530 llvm_ir::ForLoopNest loop_nest(IrName(dot), &b_);
531 std::vector<llvm::Value*> lhs_multi_index =
532 loop_nest.EmitOperandArrayLoopNest(
533 lhs_array, /*dimension_to_skip=*/lhs_reduction_dimension, "lhs");
534 std::vector<llvm::Value*> rhs_multi_index =
535 loop_nest.EmitOperandArrayLoopNest(
536 rhs_array, /*dimension_to_skip=*/rhs_reduction_dimension, "rhs");
537
538 // We don't have to iterate over the batch dimensions in both arrays, simplify
539 // the loop nest of the rhs.
540 for (int i = 0; i != dnums.lhs_batch_dimensions_size(); ++i) {
541 DCHECK(absl::c_linear_search(dnums.lhs_batch_dimensions(), i));
542 rhs_multi_index[i] = lhs_multi_index[i];
543 }
544
545 // Create the reduction loop which does the sum of products reduction.
546 std::unique_ptr<llvm_ir::ForLoop> reduction_loop = loop_nest.AddLoop(
547 /*start_index=*/0,
548 /*end_index=*/lhs_shape.dimensions(lhs_reduction_dimension),
549 /*suffix=*/"reduction");
550
551 // The final entry in the rhs and lhs indexes is the indvar of the reduction
552 // loop.
553 lhs_multi_index[lhs_reduction_dimension] = reduction_loop->GetIndVarValue();
554 rhs_multi_index[rhs_reduction_dimension] = reduction_loop->GetIndVarValue();
555
556 // For computing the sum of products we alloca a single location to store the
557 // dot product result as we accumulate it within the reduction loop. After the
558 // reduction loop we load the result and store into the output array.
559 llvm::Type* accum_type = target_array.GetElementLlvmType();
560 llvm::Value* accum_address = llvm_ir::EmitAllocaAtFunctionEntry(
561 accum_type, // The pointee type of the alloca instruction.
562 "accum_address", // The name of the alloca instruction.
563 &b_);
564
565 // Initialize the accumulator in the preheader to zero.
566 new llvm::StoreInst(
567 llvm::Constant::getNullValue(lhs_array.GetElementLlvmType()), // init 0
568 accum_address, // The address.
569 reduction_loop->GetPreheaderBasicBlock()
570 ->getTerminator()); // The instruction this store is inserted before.
571
572 // Emit the body of the reduction loop:
573 // accum = *accum_address
574 // updated_accum = accum + lhs_element * rhs_element
575 // *accum_address = updated_accum
576 TF_RET_CHECK(!reduction_loop->GetBodyBasicBlock()->empty());
577 b_.SetInsertPoint(
578 &*reduction_loop->GetBodyBasicBlock()->getFirstInsertionPt());
579 llvm_ir::IrArray::Index lhs_index(lhs_multi_index, lhs_array.GetShape(),
580 b_.getInt64Ty());
581 llvm::Value* lhs_element = lhs_array.EmitReadArrayElement(lhs_index, &b_);
582 llvm_ir::IrArray::Index rhs_index(rhs_multi_index, rhs_array.GetShape(),
583 b_.getInt64Ty());
584 llvm::Value* rhs_element = rhs_array.EmitReadArrayElement(rhs_index, &b_);
585 llvm::Value* accum = Load(accum_address);
586 llvm::Value* updated_accum;
587 if (ShapeUtil::ElementIsComplex(lhs_shape)) {
588 auto value = MultiplyComplex(lhs_element, rhs_element, &b_);
589 llvm::Value* accum_real = Real(accum, &b_);
590 llvm::Value* real_sum = FAdd(accum_real, value.first);
591 updated_accum = InsertValue(accum, real_sum, {0});
592 llvm::Value* accum_imag = Imag(accum, &b_);
593 llvm::Value* imag_sum = FAdd(accum_imag, value.second);
594 updated_accum = InsertValue(updated_accum, imag_sum, {1});
595 } else if (ShapeUtil::ElementIsFloating(lhs_shape)) {
596 llvm::Value* product = FMul(lhs_element, rhs_element);
597 updated_accum = FAdd(accum, product);
598 } else {
599 TF_RET_CHECK(ShapeUtil::ElementIsIntegral(lhs_shape));
600 llvm::Value* product = Mul(lhs_element, rhs_element);
601 updated_accum = Add(accum, product);
602 }
603 Store(updated_accum, accum_address);
604
605 // After the reduction loop exits, store the accumulator into the target
606 // address. The index into the target address is the concatenation of the rhs
607 // and lhs indexes with the reduction dimensions removed. The terms from the
608 // rhs index are the lower dimensions in the index so we add them first.
609 std::vector<llvm::Value*> target_multi_index;
610 for (size_t dimension = 0; dimension < lhs_index.size(); ++dimension) {
611 if (dimension != lhs_reduction_dimension) {
612 target_multi_index.push_back(lhs_index[dimension]);
613 }
614 }
615 // Skip over the batch dimensions to not have them in the index twice.
616 for (size_t dimension = dnums.lhs_batch_dimensions_size();
617 dimension < rhs_index.size(); ++dimension) {
618 if (dimension != rhs_reduction_dimension) {
619 target_multi_index.push_back(rhs_index[dimension]);
620 }
621 }
622 SetToFirstInsertPoint(reduction_loop->GetExitBasicBlock(), &b_);
623 llvm_ir::IrArray::Index target_index(target_multi_index,
624 target_array.GetShape(), index_type);
625 target_array.EmitWriteArrayElement(
626 target_index,
627 Load(accum_address), // The value written to the target array.
628 &b_);
629
630 // Set the IR builder insert point to the exit basic block of the outer most
631 // loop. This ensures later instructions are inserted after this loop nest.
632 b_.SetInsertPoint(loop_nest.GetOuterLoopExitBasicBlock());
633
634 return Status::OK();
635 }
636
HandleConvolution(HloInstruction * convolution)637 Status IrEmitter::HandleConvolution(HloInstruction* convolution) {
638 if (ShapeUtil::IsZeroElementArray(convolution->shape())) {
639 // Emit no code for an empty output.
640 return Status::OK();
641 }
642 // TODO(b/31409998): Support convolution with dilation.
643 return Unimplemented(
644 "Hit a case for convolution that is not implemented on GPU.");
645 }
646
HandleFft(HloInstruction * fft)647 Status IrEmitter::HandleFft(HloInstruction* fft) {
648 if (ShapeUtil::IsZeroElementArray(fft->shape())) {
649 // Emit no code for an empty output.
650 return Status::OK();
651 }
652 return Unimplemented("Hit a case for fft that is not implemented on GPU.");
653 }
654
HandleAllReduce(HloInstruction * crs)655 Status IrEmitter::HandleAllReduce(HloInstruction* crs) {
656 // TODO(b/33011107): Support cross replica sum on GPU.
657 return Unimplemented("AllReduce is not implemented on GPU.");
658 }
659
HandleParameter(HloInstruction * parameter)660 Status IrEmitter::HandleParameter(HloInstruction* parameter) {
661 return Status::OK();
662 }
663
HandleReduce(HloInstruction * reduce)664 Status IrEmitter::HandleReduce(HloInstruction* reduce) {
665 // TODO(b/118332391): Support variadic reduce.
666 if (!reduce->shape().IsArray()) {
667 return Unimplemented("Variadic reduce is not supported on GPU");
668 }
669 auto arg = reduce->operand(0);
670 auto init_value = reduce->operand(1);
671 absl::Span<const int64> dimensions(reduce->dimensions());
672 HloComputation* function = reduce->to_apply();
673 return EmitTargetElementLoop(
674 *reduce,
675 [=](const llvm_ir::IrArray::Index& index) -> StatusOr<llvm::Value*> {
676 // Initialize an accumulator with init_value.
677 llvm::AllocaInst* accumulator_addr =
678 Alloca(llvm_ir::PrimitiveTypeToIrType(
679 reduce->shape().element_type(), module_));
680 Store(Load(GetBasePointer(*init_value)), accumulator_addr);
681
682 // The enclosing loops go over all the target elements. Now we have to
683 // compute the actual target element. For this, we build a new loop nest
684 // to iterate over all the reduction dimensions in the argument.
685 // AddLoopsForShapeOnDimensions will return an Index where induction
686 // Value*s are placed for each dimension in dimensions, and all the rest
687 // are nullptrs.
688 llvm_ir::ForLoopNest loops(IrName(reduce, "inner"), &b_);
689 std::vector<llvm::Value*> input_multi_index =
690 loops.AddLoopsForShapeOnDimensions(arg->shape(), dimensions,
691 "reduction_dim");
692
693 SetToFirstInsertPoint(loops.GetInnerLoopBodyBasicBlock(), &b_);
694
695 // Build a full index for the input argument, using reduced_dims_index
696 // as the base. In reduced_dims_index only the reduction dimensions are
697 // filled in. We fill in the rest of the dimensions with induction
698 // Value*s taken from 'index' which iterates over the target array.
699 // See the high-level description in the XLA documentation for details.
700 llvm_ir::IrArray::Index::const_iterator it = index.begin();
701
702 for (auto& i : input_multi_index) {
703 if (i == nullptr) {
704 i = *it++;
705 }
706 }
707 CHECK(index.end() == it);
708
709 // Apply the reduction function to the loaded value.
710 llvm_ir::IrArray::Index input_index(input_multi_index, arg->shape(),
711 b_.getInt64Ty());
712 llvm::Value* input_address =
713 GetIrArray(*arg, *reduce).EmitArrayElementAddress(input_index, &b_);
714 TF_RETURN_IF_ERROR(EmitCallToNestedComputation(
715 *function, {accumulator_addr, input_address}, accumulator_addr));
716
717 SetToFirstInsertPoint(loops.GetOuterLoopExitBasicBlock(), &b_);
718 return Load(accumulator_addr);
719 });
720 }
721
HandleFusion(HloInstruction * fusion)722 Status IrEmitter::HandleFusion(HloInstruction* fusion) {
723 // kFusion for library calls should be handled by
724 // IrEmitterUnnested::HandleFusion.
725 CHECK_EQ(HloInstruction::FusionKind::kLoop, fusion->fusion_kind());
726 GpuElementalIrEmitter elemental_emitter(hlo_module_config_, module_, &b_,
727 GetNestedComputer());
728 FusedIrEmitter fused_emitter(GetGeneratorForOperandIrArrays(fusion),
729 &elemental_emitter);
730 TF_RETURN_IF_ERROR(fusion->fused_expression_root()->Accept(&fused_emitter));
731
732 return EmitTargetElementLoop(*fusion, fused_emitter.GetRootGenerator());
733 }
734
HandleCall(HloInstruction * call)735 Status IrEmitter::HandleCall(HloInstruction* call) {
736 std::vector<llvm::Value*> operand_addresses;
737 for (HloInstruction* operand : call->operands()) {
738 operand_addresses.push_back(GetBasePointer(*operand));
739 }
740 return EmitCallToNestedComputation(*call->to_apply(), operand_addresses,
741 GetBasePointer(*call));
742 }
743
HandleCustomCall(HloInstruction *)744 Status IrEmitter::HandleCustomCall(HloInstruction*) {
745 return Unimplemented("custom-call");
746 }
747
HandleInfeed(HloInstruction *)748 Status IrEmitter::HandleInfeed(HloInstruction*) {
749 // TODO(b/30467474): Implement infeed on GPU.
750 return Unimplemented("Infeed is not supported on GPU.");
751 }
752
HandleOutfeed(HloInstruction *)753 Status IrEmitter::HandleOutfeed(HloInstruction*) {
754 // TODO(b/34359662): Implement outfeed on GPU.
755 return Unimplemented("Outfeed is not supported on GPU.");
756 }
757
HandleBatchNormInference(HloInstruction *)758 Status IrEmitter::HandleBatchNormInference(HloInstruction*) {
759 return Unimplemented(
760 "The GPU backend does not implement BatchNormInference directly. It "
761 "should be lowered before IR emission to HLO-soup using "
762 "BatchNormRewriter or to a cudnn CustomCall using "
763 "CudnnBatchNormRewriter.");
764 }
765
HandleBatchNormTraining(HloInstruction *)766 Status IrEmitter::HandleBatchNormTraining(HloInstruction*) {
767 return Unimplemented(
768 "The GPU backend does not implement BatchNormTraining directly. It "
769 "should be lowered before IR emission to HLO-soup using "
770 "BatchNormRewriter or to a cudnn CustomCall using "
771 "CudnnBatchNormRewriter.");
772 }
773
HandleBatchNormGrad(HloInstruction *)774 Status IrEmitter::HandleBatchNormGrad(HloInstruction*) {
775 return Unimplemented(
776 "The GPU backend does not implement BatchNormGrad directly. It should "
777 "be lowered before IR emission to HLO-soup (using BatchNormRewriter) or "
778 "to a cudnn CustomCall using CudnnBatchNormRewriter.");
779 }
780
ComputeNestedElement(const HloComputation & computation,absl::Span<llvm::Value * const> parameter_elements)781 StatusOr<llvm::Value*> IrEmitter::ComputeNestedElement(
782 const HloComputation& computation,
783 absl::Span<llvm::Value* const> parameter_elements) {
784 llvm::Value* return_buffer = llvm_ir::EmitAllocaAtFunctionEntry(
785 llvm_ir::PrimitiveTypeToIrType(
786 computation.root_instruction()->shape().element_type(), module_),
787 "return_buffer", &b_);
788 std::vector<llvm::Value*> parameter_buffers;
789 for (llvm::Value* parameter_element : parameter_elements) {
790 parameter_buffers.push_back(llvm_ir::EmitAllocaAtFunctionEntry(
791 parameter_element->getType(), "parameter_buffer", &b_));
792 Store(parameter_element, parameter_buffers.back());
793 }
794 TF_RETURN_IF_ERROR(EmitCallToNestedComputation(computation, parameter_buffers,
795 return_buffer));
796 return Load(return_buffer);
797 }
798
ConstructIrArrayForOutputs(const HloInstruction & hlo)799 std::vector<llvm_ir::IrArray> IrEmitter::ConstructIrArrayForOutputs(
800 const HloInstruction& hlo) {
801 std::vector<llvm_ir::IrArray> output_arrays;
802 if (hlo.shape().IsTuple()) {
803 int64 num_outputs = ShapeUtil::TupleElementCount(hlo.shape());
804 output_arrays.reserve(num_outputs);
805 for (int64 i = 0; i < num_outputs; ++i) {
806 output_arrays.push_back(GetIrArray(hlo, hlo, {i}));
807 }
808 } else {
809 output_arrays.push_back(GetIrArray(hlo, hlo));
810 }
811 return output_arrays;
812 }
813
814 } // namespace gpu
815 } // namespace xla
816