1 /* Copyright 2017 The TensorFlow Authors. All Rights Reserved.
2
3 Licensed under the Apache License, Version 2.0 (the "License");
4 you may not use this file except in compliance with the License.
5 You may obtain a copy of the License at
6
7 http://www.apache.org/licenses/LICENSE-2.0
8
9 Unless required by applicable law or agreed to in writing, software
10 distributed under the License is distributed on an "AS IS" BASIS,
11 WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
12 See the License for the specific language governing permissions and
13 limitations under the License.
14 ==============================================================================*/
15
16 #include <algorithm>
17 #include <cstring>
18 #include <iterator>
19 #include <memory>
20 #include <string>
21 #include <vector>
22
23 #include "tensorflow/compiler/xla/service/gpu/ir_emitter_unnested.h"
24
25 #include "absl/algorithm/container.h"
26 #include "absl/memory/memory.h"
27 #include "absl/strings/str_cat.h"
28 #include "absl/types/optional.h"
29 #include "absl/types/span.h"
30 #include "llvm/ADT/StringRef.h"
31 #include "llvm/IR/BasicBlock.h"
32 #include "llvm/IR/Function.h"
33 #include "llvm/IR/IRBuilder.h"
34 #include "llvm/IR/Instructions.h"
35 #include "llvm/IR/LLVMContext.h"
36 #include "llvm/IR/Module.h"
37 #include "tensorflow/compiler/xla/layout_util.h"
38 #include "tensorflow/compiler/xla/literal.h"
39 #include "tensorflow/compiler/xla/service/buffer_assignment.h"
40 #include "tensorflow/compiler/xla/service/dfs_hlo_visitor.h"
41 #include "tensorflow/compiler/xla/service/gpu/buffer_allocations.h"
42 #include "tensorflow/compiler/xla/service/gpu/cholesky_thunk.h"
43 #include "tensorflow/compiler/xla/service/gpu/conditional_thunk.h"
44 #include "tensorflow/compiler/xla/service/gpu/convolution_thunk.h"
45 #include "tensorflow/compiler/xla/service/gpu/copy_thunk.h"
46 #include "tensorflow/compiler/xla/service/gpu/cudnn_batchnorm_thunk.h"
47 #include "tensorflow/compiler/xla/service/gpu/cudnn_conv_runner.h"
48 #include "tensorflow/compiler/xla/service/gpu/fft_thunk.h"
49 #include "tensorflow/compiler/xla/service/gpu/for_thunk.h"
50 #include "tensorflow/compiler/xla/service/gpu/gemm_thunk.h"
51 #include "tensorflow/compiler/xla/service/gpu/gpu_constants.h"
52 #include "tensorflow/compiler/xla/service/gpu/hlo_to_ir_bindings.h"
53 #include "tensorflow/compiler/xla/service/gpu/infeed_thunk.h"
54 #include "tensorflow/compiler/xla/service/gpu/ir_emission_utils.h"
55 #include "tensorflow/compiler/xla/service/gpu/ir_emitter_context.h"
56 #include "tensorflow/compiler/xla/service/gpu/kernel_thunk.h"
57 #include "tensorflow/compiler/xla/service/gpu/memset_thunk.h"
58 #include "tensorflow/compiler/xla/service/gpu/nccl_all_reduce_thunk.h"
59 #include "tensorflow/compiler/xla/service/gpu/outfeed_thunk.h"
60 #include "tensorflow/compiler/xla/service/gpu/parallel_loop_emitter.h"
61 #include "tensorflow/compiler/xla/service/gpu/partition_assignment.h"
62 #include "tensorflow/compiler/xla/service/gpu/sequential_thunk.h"
63 #include "tensorflow/compiler/xla/service/gpu/thunk.h"
64 #include "tensorflow/compiler/xla/service/gpu/triangular_solve_thunk.h"
65 #include "tensorflow/compiler/xla/service/gpu/tuple_thunk.h"
66 #include "tensorflow/compiler/xla/service/gpu/while_thunk.h"
67 #include "tensorflow/compiler/xla/service/hlo_casting_utils.h"
68 #include "tensorflow/compiler/xla/service/hlo_computation.h"
69 #include "tensorflow/compiler/xla/service/hlo_instruction.h"
70 #include "tensorflow/compiler/xla/service/hlo_instructions.h"
71 #include "tensorflow/compiler/xla/service/hlo_opcode.h"
72 #include "tensorflow/compiler/xla/service/llvm_ir/buffer_assignment_util.h"
73 #include "tensorflow/compiler/xla/service/llvm_ir/dynamic_update_slice_util.h"
74 #include "tensorflow/compiler/xla/service/llvm_ir/fused_ir_emitter.h"
75 #include "tensorflow/compiler/xla/service/llvm_ir/llvm_util.h"
76 #include "tensorflow/compiler/xla/service/llvm_ir/sort_util.h"
77 #include "tensorflow/compiler/xla/service/llvm_ir/tuple_ops.h"
78 #include "tensorflow/compiler/xla/service/name_uniquer.h"
79 #include "tensorflow/compiler/xla/service/pattern_matcher.h"
80 #include "tensorflow/compiler/xla/service/while_loop_analysis.h"
81 #include "tensorflow/compiler/xla/shape_util.h"
82 #include "tensorflow/compiler/xla/status_macros.h"
83 #include "tensorflow/compiler/xla/types.h"
84 #include "tensorflow/compiler/xla/util.h"
85 #include "tensorflow/compiler/xla/window_util.h"
86 #include "tensorflow/compiler/xla/xla_data.pb.h"
87 #include "tensorflow/core/lib/core/bits.h"
88 #include "tensorflow/core/lib/core/status.h"
89 #include "tensorflow/core/platform/logging.h"
90
91 namespace xla {
92 namespace gpu {
93
94 using llvm_ir::KernelMappingScheme;
95 using EmitElementFunction =
96 std::function<void(const llvm_ir::IrArray::Index& index, llvm::Value* y_loc,
97 llvm::Value* x_loc, int64 x_iter_num)>;
98
99 namespace {
100
101 using absl::InlinedVector;
102 using absl::nullopt;
103 using absl::optional;
104 using absl::StrCat;
105 using llvm_ir::IrArray;
106 using llvm_ir::IrName;
107
108 namespace m = match;
109
110 // If a dimension is smaller than this, untiled transposition may be more
111 // efficient.
112 const int64 kMinDimensionToTransposeTiled = 16;
113
114 // Returns true if all paths from `hlo` to `root` contain only tuples. The
115 // result of such an HloInstruction does not need to be materialized when the
116 // computation can have a hybrid result.
117 bool ReachRootViaOnlyTuples(const HloInstruction& hlo,
118 const HloInstruction& root) {
119 if (hlo.opcode() != HloOpcode::kTuple) {
120 return false;
121 }
122
123 if (&hlo == &root) {
124 return true;
125 }
126
127 for (HloInstruction* user : hlo.users()) {
128 if (!ReachRootViaOnlyTuples(*user, root)) {
129 return false;
130 }
131 }
132
133 return true;
134 }
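// For illustration, in a hypothetical computation
//   t0 = tuple(a, b)
//   root = tuple(t0, c)
// ReachRootViaOnlyTuples(t0, root) is true, while it is false for `a` (not a
// tuple) and for any tuple that also feeds a non-tuple user.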
135
136 // If `hlo` is a rank-2 transpose, returns its operand; otherwise returns `hlo` itself.
137 const HloInstruction* StripTranspose(const HloInstruction& hlo) {
138 if (hlo.IsRank2Transpose()) {
139 return hlo.operand(0);
140 }
141 return &hlo;
142 }
143
144 // Updates the launch dimensions in "thunk" and annotates the launch dimensions
145 // of the corresponding IR kernel in "llvm_module".
146 // Precondition: "thunk" must be a KernelThunk.
147 void UpdateLaunchDimensions(const LaunchDimensions& launch_dims, Thunk* thunk,
148 llvm::Module* llvm_module) {
149 CHECK(Thunk::Kind::kKernel == thunk->kind());
150 KernelThunk* kernel_thunk = static_cast<KernelThunk*>(thunk);
151 kernel_thunk->SetLaunchDimensions(launch_dims);
152
153 // Add __launch_bounds__ to metadata. This limits registers per thread to
154 // avoid out-of-resources launching errors.
155 llvm::NamedMDNode* nvvm_annotations_node =
156 llvm_module->getOrInsertNamedMetadata("nvvm.annotations");
157 llvm::Function* ir_kernel =
158 llvm_module->getFunction(kernel_thunk->kernel_name().c_str());
159 llvm::LLVMContext& llvm_context = llvm_module->getContext();
160 llvm::ConstantInt* threads_per_block_ir_value = llvm::ConstantInt::get(
161 llvm::IntegerType::get(llvm_context, /*NumBits=*/32),
162 launch_dims.threads_per_block());
163 // Our launch bounds are exact, so we can specify them as reqntidx rather than
164 // maxntidx.
165 nvvm_annotations_node->addOperand(llvm::MDNode::get(
166 llvm_context,
167 {llvm::ConstantAsMetadata::get(ir_kernel),
168 llvm::MDString::get(llvm_context, "reqntidx"),
169 llvm::ConstantAsMetadata::get(threads_per_block_ir_value)}));
170 }
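// For illustration only (kernel name and thread count are hypothetical), the
// metadata added above corresponds roughly to the following LLVM IR:
//
//   !nvvm.annotations = !{!0}
//   !0 = !{void (i8*)* @fusion_3, !"reqntidx", i32 128}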
171
172 } // namespace
173
174 IrEmitterUnnested::IrEmitterUnnested(const HloModuleConfig& hlo_module_config,
175 const HloComputation* hlo_computation,
176 IrEmitterContext* ir_emitter_context)
177 : IrEmitter(hlo_module_config, ir_emitter_context, /*is_nested=*/false),
178 hlo_computation_(hlo_computation) {
179 // Initialize thunk_sequence_ to an empty list of thunks.
180 thunk_sequence_.reset(new ThunkSequence());
181 }
182
183 Status IrEmitterUnnested::Postprocess(HloInstruction* hlo) {
184 bindings_.UnbindAllLocalIrValues();
185 return DfsHloVisitor::Postprocess(hlo);
186 }
187
188 llvm::Function* IrEmitterUnnested::BuildKernelPrototype(
189 const HloInstruction& inst,
190 absl::Span<const BufferAllocation* const> args) {
191 // Compute the kernel name. The opcode string may contain "-" which cannot be
192 // in a PTX function name, so sanitize the name before uniquifying it.
193 string kernel_name = ir_emitter_context_->name_uniquer()->GetUniqueName(
194 llvm_ir::SanitizeFunctionName(inst.name()));
195
196 // Create the kernel and add it to the module.
197 llvm::Module* module = ir_emitter_context_->llvm_module();
198 llvm::LLVMContext& context = module->getContext();
199 llvm::FunctionType* kernel_type = llvm::FunctionType::get(
200 /*Result=*/llvm::Type::getVoidTy(context),
201 std::vector<llvm::Type*>(args.size(), b_.getInt8PtrTy()),
202 /*isVarArg=*/false);
203 llvm::Function* kernel =
204 llvm::Function::Create(kernel_type, llvm::GlobalValue::ExternalLinkage,
205 kernel_name.c_str(), module);
206
207 // Add dereferenceable and alignment information to each of the kernel's
208 // parameters.
209 auto arg_it = kernel->arg_begin();
210 for (size_t arg_no = 0; arg_no < args.size(); ++arg_no) {
211 const BufferAllocation* alloc = args[arg_no];
212 llvm::Argument* fn_arg = &*arg_it;
213 ++arg_it;
214
215 kernel->addDereferenceableAttr(arg_no + 1, alloc->size());
216
217 const int64 alignment = [&] {
218 if (alloc->is_entry_computation_parameter()) {
219 return kEntryParameterAlignBytes;
220 } else if (alloc->is_constant()) {
221 return kConstantBufferAlignBytes;
222 } else {
223 return kXlaAllocatedBufferAlignBytes;
224 }
225 }();
226
227 kernel->addParamAttr(
228 arg_no,
229 llvm::Attribute::get(context, llvm::Attribute::Alignment, alignment));
230
231 if (alloc->IsPreallocatedTempBuffer()) {
232 fn_arg->setName("temp_buf");
233 } else {
234 fn_arg->setName(StrCat("alloc", alloc->index()));
235 }
236 }
237
238 // TODO(b/65380986): Investigate if adding fast math flags for generated
239 // kernels makes sense.
240
241 // Add the declaration of this kernel to llvm.nvvm.annotations so that NVPTX
242 // treats it as a CUDA kernel.
243 llvm::NamedMDNode* nvvm_annotations_node =
244 module->getOrInsertNamedMetadata("nvvm.annotations");
245 nvvm_annotations_node->addOperand(llvm::MDNode::get(
246 context, {llvm::ConstantAsMetadata::get(kernel),
247 llvm::MDString::get(context, "kernel"),
248 llvm::ConstantAsMetadata::get(b_.getInt32(1))}));
249
250 // Update the insert point to the entry basic block.
251 llvm::BasicBlock* entry_bb =
252 llvm::BasicBlock::Create(context, /*Name=*/"entry", /*Parent=*/kernel);
253
254 // Emit a "return void" at entry_bb's end, and set the insert point before
255 // that return instruction.
256 b_.SetInsertPoint(llvm::ReturnInst::Create(context, entry_bb));
257
258 return kernel;
259 }
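// As a rough sketch (argument names, sizes, and alignments are hypothetical),
// the prototype emitted above for a kernel with two buffer arguments looks
// like:
//
//   define void @fusion_3(i8* align 16 dereferenceable(256) %alloc0,
//                         i8* align 64 dereferenceable(4096) %temp_buf) {
//   entry:
//     ret void
//   }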
260
261 namespace {
262 // Computes the maximum valid unroll factor for a given instruction.
263 int ComputeMaxUnrollFactor(const HloInstruction* hlo) {
264 int max_unroll_factor = hlo->GetModule()
265 ->config()
266 .debug_options()
267 .xla_gpu_max_kernel_unroll_factor();
268
269 // Find the largest possible power of two to unroll by.
270 // TODO(kramerb): Make this smarter.
271 const Shape& element_shape = hlo->IsMultiOutputFusion()
272 ? ShapeUtil::GetSubshape(hlo->shape(), {0})
273 : hlo->shape();
274 int64 num_elements = ShapeUtil::ElementsIn(element_shape);
275 for (int i = max_unroll_factor; i > 1; i /= 2) {
276 if (num_elements % i == 0) {
277 return i;
278 }
279 }
280
281 // Cannot unroll.
282 return 1;
283 }
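// For example, with xla_gpu_max_kernel_unroll_factor == 4 and an output of 6
// elements, the loop above tries 4 (6 % 4 != 0) and then 2 (6 % 2 == 0), so
// the kernel is unrolled by a factor of 2; an output of 7 elements cannot be
// unrolled at all and the function returns 1.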
284
285 // Returns the llvm type for the indices used in the kernel that contains the
286 // hlo instruction. Such indices include the index for the parallel loop and
287 // the indices for the tensors accessed by the kernel. The return type is i32
288 // iff the following conditions are met:
289 // . The launch_size of the kernel is within the range of i32.
290 // . The sizes of all the tensors accessed within the kernel are within the
291 // range of i32.
292 // Otherwise, the return type is i64.
293 llvm::Type* GetIndexTypeForKernel(const HloInstruction* hlo, int64 launch_size,
294 llvm::IRBuilder<>* b) {
295 // Find the unnested hlo instruction for which the kernel is generated.
296 const HloInstruction* unnested_hlo = hlo;
297 const HloComputation* computation = hlo->parent();
298 if (computation->IsFusionComputation()) {
299 unnested_hlo = computation->FusionInstruction();
300 }
301
302 auto shape_in_range = [&](const Shape& s) {
303 bool in_range = true;
304 ShapeUtil::ForEachSubshape(s, [&](const Shape& sub_shape,
305 const ShapeIndex& /*index*/) {
306 if (sub_shape.IsArray() && !IsInt32(ShapeUtil::ElementsIn(sub_shape))) {
307 in_range = false;
308 }
309 });
310
311 return in_range;
312 };
313
314 llvm::Type* i64_ty = b->getInt64Ty();
315 // Check launch dimension
316 if (!IsInt32(launch_size)) {
317 return i64_ty;
318 }
319
320 // Check the size of result tensors
321 if (!shape_in_range(unnested_hlo->shape())) {
322 return i64_ty;
323 }
324
325 auto hlo_shape_in_range = [&](const HloInstruction* operand) -> bool {
326 return shape_in_range(operand->shape());
327 };
328
329 // Check the size of input tensors
330 if (!absl::c_all_of(unnested_hlo->operands(), hlo_shape_in_range)) {
331 return i64_ty;
332 }
333
334 // Check the size of the internal result tensors
335 if (unnested_hlo->opcode() == HloOpcode::kFusion) {
336 if (!absl::c_all_of(
337 unnested_hlo->fused_instructions_computation()->instructions(),
338 hlo_shape_in_range)) {
339 return i64_ty;
340 }
341 }
342
343 return b->getInt32Ty();
344 }
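// For example (shapes are hypothetical): a fusion launched over 2^20 elements
// whose operands and intermediate values all hold fewer than 2^31 elements is
// indexed with i32 above, while a single operand holding three billion
// elements forces the whole kernel to use i64 indices.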
345
346 } // namespace
347
348 Status IrEmitterUnnested::DefaultAction(HloInstruction* hlo) {
349 return IrEmitter::DefaultAction(hlo);
350 }
351
352 Status IrEmitterUnnested::HandleDot(HloInstruction* dot) {
353 if (ImplementedAsGemm(*dot)) {
354 AddThunkToThunkSequence(BuildGemmThunk(dot));
355 return Status::OK();
356 }
357 AddThunkToThunkSequence(
358 BuildKernelThunk(dot, /*implements_whole_instruction=*/true));
359 return IrEmitter::HandleDot(dot);
360 }
361
362 Status IrEmitterUnnested::HandleConditional(HloInstruction* conditional) {
363 AddThunkToThunkSequence(BuildConditionalThunk(conditional));
364 return Status::OK();
365 }
366
367 Status IrEmitterUnnested::HandleConvolution(HloInstruction* convolution) {
368 AddThunkToThunkSequence(
369 BuildKernelThunk(convolution, /*implements_whole_instruction=*/true));
370 return IrEmitter::HandleConvolution(convolution);
371 }
372
373 Status IrEmitterUnnested::HandleCustomCall(HloInstruction* custom_call) {
374 // A CustomCall on the GPU backend can either be a custom-call to a
375 // user-supplied kernel, or a call into a library like cudnn.
376
377 // Lower custom-calls to cudnn batchnorm ops to specialized thunks. It's part
378 // of the contract of these cudnn batchnorm calls that the epsilon and
379 // feature_index operands be constants.
380 if (custom_call->custom_call_target() ==
381 kCudnnBatchNormForwardInferenceCallTarget) {
382 const HloInstruction* epsilon = custom_call->operand(5);
383 CHECK(epsilon->IsConstant());
384 float epsilon_value = epsilon->literal().Get<float>({});
385
386 const HloInstruction* feature_index = custom_call->operand(6);
387 CHECK(feature_index->IsConstant());
388 int64 feature_index_value = feature_index->literal().Get<int64>({});
389
390 AddThunkToThunkSequence(
391 absl::make_unique<CudnnBatchNormForwardInferenceThunk>(
392 /*operand=*/GetAllocationSlice(*custom_call->operand(0)),
393 /*scale=*/GetAllocationSlice(*custom_call->operand(1)),
394 /*offset=*/GetAllocationSlice(*custom_call->operand(2)),
395 /*mean=*/GetAllocationSlice(*custom_call->operand(3)),
396 /*variance=*/GetAllocationSlice(*custom_call->operand(4)),
397 /*epsilon=*/epsilon_value,
398 /*feature_index=*/feature_index_value,
399 /*output=*/GetAllocationSlice(*custom_call),
400 /*hlo=*/custom_call));
401 return Status::OK();
402 }
403
404 if (custom_call->custom_call_target() ==
405 kCudnnBatchNormForwardTrainingCallTarget) {
406 const HloInstruction* epsilon = custom_call->operand(3);
407 CHECK(epsilon->IsConstant());
408 float epsilon_value = epsilon->literal().Get<float>({});
409
410 const HloInstruction* feature_index = custom_call->operand(4);
411 CHECK(feature_index->IsConstant());
412 int64 feature_index_value = feature_index->literal().Get<int64>({});
413
414 // BatchNormTraining returns a tuple of three elements: data, calculated
415 // mean, and calculated 1/sqrt(variance + epsilon).
416 const auto& assn = ir_emitter_context_->buffer_assignment();
417 auto output_data = assn.GetUniqueSlice(custom_call, {0}).ValueOrDie();
418 auto output_mean = assn.GetUniqueSlice(custom_call, {1}).ValueOrDie();
419 auto output_inv_stddev = assn.GetUniqueSlice(custom_call, {2}).ValueOrDie();
420 AddThunkToThunkSequence(
421 absl::make_unique<CudnnBatchNormForwardTrainingThunk>(
422 /*operand=*/GetAllocationSlice(*custom_call->operand(0)),
423 /*scale=*/GetAllocationSlice(*custom_call->operand(1)),
424 /*offset=*/GetAllocationSlice(*custom_call->operand(2)),
425 /*epsilon=*/epsilon_value,
426 /*feature_index=*/feature_index_value,
427 /*output_data=*/output_data,
428 /*output_mean=*/output_mean,
429 /*output_inv_stddev=*/output_inv_stddev,
430 /*output_tuple=*/GetAllocationSlice(*custom_call),
431 /*hlo=*/custom_call));
432 return Status::OK();
433 }
434
435 if (custom_call->custom_call_target() == kCudnnBatchNormBackwardCallTarget) {
436 const HloInstruction* epsilon = custom_call->operand(5);
437 CHECK(epsilon->IsConstant());
438 float epsilon_value = epsilon->literal().Get<float>({});
439
440 const HloInstruction* feature_index = custom_call->operand(6);
441 CHECK(feature_index->IsConstant());
442 int64 feature_index_value = feature_index->literal().Get<int64>({});
443
444 // BatchNormGrad returns a tuple of three elements: grad_data, grad_scale,
445 // grad_offset.
446 const auto& assn = ir_emitter_context_->buffer_assignment();
447 auto output_grad_data = assn.GetUniqueSlice(custom_call, {0}).ValueOrDie();
448 auto output_grad_scale = assn.GetUniqueSlice(custom_call, {1}).ValueOrDie();
449 auto output_grad_offset =
450 assn.GetUniqueSlice(custom_call, {2}).ValueOrDie();
451 AddThunkToThunkSequence(absl::make_unique<CudnnBatchNormBackwardThunk>(
452 /*operand=*/GetAllocationSlice(*custom_call->operand(0)),
453 /*scale=*/GetAllocationSlice(*custom_call->operand(1)),
454 /*mean=*/GetAllocationSlice(*custom_call->operand(2)),
455 /*inv_stddev=*/GetAllocationSlice(*custom_call->operand(3)),
456 /*grad_output=*/GetAllocationSlice(*custom_call->operand(4)),
457 /*epsilon=*/epsilon_value,
458 /*feature_index=*/feature_index_value,
459 /*output_grad_data=*/output_grad_data,
460 /*output_grad_scale=*/output_grad_scale,
461 /*output_grad_offset=*/output_grad_offset,
462 /*output_tuple=*/GetAllocationSlice(*custom_call),
463 /*hlo=*/custom_call));
464 return Status::OK();
465 }
466
467 if (IsCustomCallToDnnConvolution(*custom_call)) {
468 const auto& assn = ir_emitter_context_->buffer_assignment();
469 std::vector<BufferAllocation::Slice> operand_slices;
470 operand_slices.reserve(custom_call->operand_count());
471 for (const auto* operand : custom_call->operands()) {
472 operand_slices.push_back(GetAllocationSlice(*operand));
473 }
474 auto tuple_result_slice = GetAllocationSlice(*custom_call);
475 auto conv_result_slice = assn.GetUniqueSlice(custom_call, {0}).ValueOrDie();
476 auto scratch_slice = assn.GetUniqueSlice(custom_call, {1}).ValueOrDie();
477
478 AddThunkToThunkSequence(absl::make_unique<ConvolutionThunk>(
479 Cast<HloCustomCallInstruction>(custom_call), std::move(operand_slices),
480 conv_result_slice, scratch_slice, tuple_result_slice));
481 return Status::OK();
482 }
483
484 if (custom_call->custom_call_target() == kCusolverCholeskyCallTarget) {
485 TF_ASSIGN_OR_RETURN(CholeskyOptions options,
486 custom_call->backend_config<CholeskyOptions>());
487
488 const Shape& shape = custom_call->operand(0)->shape();
489 int ndim = shape.dimensions_size();
490 CHECK_GE(ndim, 2);
491 int64 n = shape.dimensions(ndim - 1);
492
493 const auto& dims = shape.dimensions();
494 int64 batch_size = std::accumulate(dims.begin(), dims.end() - 2, int64{1},
495 [](int64 a, int64 b) { return a * b; });
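// For example, a hypothetical f32[5, 2, 8, 8] operand yields n = 8 and
// batch_size = 5 * 2 = 10.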
496
497 auto operand_buffer = GetAllocationSlice(*custom_call->operand(0));
498
499 const auto& assn = ir_emitter_context_->buffer_assignment();
500 auto a_buffer = assn.GetUniqueSlice(custom_call, {0}).ValueOrDie();
501 auto workspace_buffer = assn.GetUniqueSlice(custom_call, {1}).ValueOrDie();
502 auto info_buffer = assn.GetUniqueSlice(custom_call, {2}).ValueOrDie();
503
504 std::vector<std::unique_ptr<Thunk>> thunks;
505
506 if (operand_buffer != a_buffer) {
507 thunks.push_back(absl::make_unique<DeviceToDeviceCopyThunk>(
508 /*source_address=*/operand_buffer,
509 /*destination_buffer=*/a_buffer,
510 /*mem_size=*/ShapeUtil::ByteSizeOf(shape), custom_call));
511 }
512
513 thunks.push_back(absl::make_unique<CholeskyThunk>(
514 options, a_buffer, workspace_buffer, info_buffer,
515 custom_call->operand(0)->shape().element_type(), batch_size, n,
516 custom_call));
517
518 // Elide the sequential thunk if there's no copy.
519 if (thunks.size() == 1) {
520 AddThunkToThunkSequence(std::move(thunks[0]));
521 } else {
522 AddThunkToThunkSequence(
523 absl::make_unique<SequentialThunk>(std::move(thunks), custom_call));
524 }
525
526 return Status::OK();
527 }
528
529 return IrEmitter::HandleCustomCall(custom_call);
530 }
531
532 Status IrEmitterUnnested::HandleFft(HloInstruction* fft) {
533 TF_RET_CHECK(
534 LayoutUtil::IsMonotonicWithDim0Major(fft->operand(0)->shape().layout()));
535 TF_RET_CHECK(LayoutUtil::IsMonotonicWithDim0Major(fft->shape().layout()));
536 AddThunkToThunkSequence(BuildFftThunk(fft));
537 return Status::OK();
538 }
539
540 Status IrEmitterUnnested::HandleTriangularSolve(HloInstruction* hlo) {
541 auto has_fortran_layout = [](const Layout& layout) {
542 int n = layout.minor_to_major_size();
543 return layout.minor_to_major(0) == n - 2 &&
544 layout.minor_to_major(1) == n - 1;
545 };
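// For example, a rank-2 operand with layout minor_to_major = {0, 1}
// (column-major) passes this check, while the default row-major layout
// {1, 0} does not.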
546 TF_RET_CHECK(has_fortran_layout(hlo->operand(0)->shape().layout()));
547 TF_RET_CHECK(has_fortran_layout(hlo->operand(1)->shape().layout()));
548 TF_RET_CHECK(has_fortran_layout(hlo->shape().layout()));
549
550 std::vector<std::unique_ptr<Thunk>> thunks;
551
552 // Triangular solve is in-place on 'b', so copy 'b' to the output if they
553 // aren't the same buffer.
554 auto operand_buffer = GetAllocationSlice(*hlo->operand(1));
555 auto destination_buffer = GetAllocationSlice(*hlo);
556 if (operand_buffer != destination_buffer) {
557 thunks.push_back(absl::make_unique<DeviceToDeviceCopyThunk>(
558 /*source_address=*/operand_buffer,
559 /*destination_buffer=*/destination_buffer,
560 /*mem_size=*/ShapeUtil::ByteSizeOf(hlo->operand(1)->shape()), hlo));
561 }
562
563 thunks.push_back(BuildTriangularSolveThunk(hlo));
564
565 // Elide the sequential thunk if there's no copy.
566 if (thunks.size() == 1) {
567 AddThunkToThunkSequence(std::move(thunks[0]));
568 } else {
569 AddThunkToThunkSequence(
570 absl::make_unique<SequentialThunk>(std::move(thunks), hlo));
571 }
572 return Status::OK();
573 }
574
575 Status IrEmitterUnnested::HandleFusion(HloInstruction* fusion) {
576 HloInstruction* root = fusion->fused_expression_root();
577 if (HloInstruction::FusionKind::kInput == fusion->fusion_kind()) {
578 switch (root->opcode()) {
579 case HloOpcode::kScatter: {
580 std::vector<std::unique_ptr<Thunk>> thunks;
581 // The initialization from 'operand' uses different loop bounds, so
582 // emit it in a separate kernel. Treat it like a loop fusion, writing to
583 // the output buffer.
584 {
585 int unroll_factor = ComputeMaxUnrollFactor(fusion);
586 thunks.push_back(BuildKernelThunk(
587 fusion, /*implements_whole_instruction=*/false, unroll_factor));
588
589 GpuElementalIrEmitter operand_elemental_emitter(
590 hlo_module_config_, ir_emitter_context_->llvm_module(), &b_,
591 GetNestedComputer());
592 FusedIrEmitter operand_fused_emitter(
593 GetGeneratorForOperandIrArrays(fusion),
594 &operand_elemental_emitter);
595 TF_RETURN_IF_ERROR(
596 root->mutable_operand(0)->Accept(&operand_fused_emitter));
597
598 TF_RETURN_IF_ERROR(EmitTargetElementLoopInThunk(
599 *fusion, operand_fused_emitter.GetGenerator(root->operand(0)),
600 static_cast<KernelThunk*>(thunks.back().get())));
601 }
602
603 // Now build the actual scatter, reading and writing to the freshly
604 // filled output buffer.
605 {
606 thunks.push_back(
607 BuildKernelThunk(fusion,
608 /*implements_whole_instruction=*/false));
609 // Spin up a new fused emitter for the scatter kernel and emit it.
610 GpuElementalIrEmitter scatter_elemental_emitter(
611 hlo_module_config_, ir_emitter_context_->llvm_module(), &b_,
612 GetNestedComputer());
613 FusedIrEmitter scatter_fused_emitter(
614 GetGeneratorForOperandIrArrays(fusion),
615 &scatter_elemental_emitter);
616 TF_RETURN_IF_ERROR(root->Accept(&scatter_fused_emitter));
617 TF_RETURN_IF_ERROR(EmitScatter(
618 thunks.back().get(), root,
619 /*scatter_indices_gen=*/
620 scatter_fused_emitter.GetGenerator(root->operand(1)),
621 /*updates_gen=*/
622 scatter_fused_emitter.GetGenerator(root->operand(2))));
623 }
624 AddThunkToThunkSequence(
625 absl::make_unique<SequentialThunk>(std::move(thunks), fusion));
626 return Status::OK();
627 }
628 case HloOpcode::kTuple:
629 case HloOpcode::kReduce: {
630 // HandleFusion specializes reduction from a multi-dimensional array to
631 // a 1D array. The specialized version requires an initializer thunk that
632 // initializes the output array to the initial value of the reduce.
633 if (root->opcode() == HloOpcode::kReduce && root->shape().IsTuple()) {
634 // TODO(b/118332391): Support variadic reduce.
635 return Unimplemented("Variadic reduce is not supported on GPU");
636 }
637 return EmitReductionToVector(fusion);
638 }
639 default:
640 LOG(FATAL) << "Bad opcode for input fusion: "
641 << fusion->fused_expression_root()->opcode();
642 }
643 } else if (llvm_ir::CanEmitFusedDynamicUpdateSliceInPlace(
644 fusion, ir_emitter_context_->buffer_assignment())) {
645 // Fusion node with dynamic-update-slice as the root where the op's input
646 // (i.e. array to update) shares the same slice as its output. In this case
647 // we have a special algorithm that modifies the output in place without
648 // touching the un-updated elements.
649
650 // Set up kernel thunk and fused ir emitter.
651 std::unique_ptr<KernelThunk> fusion_thunk =
652 BuildKernelThunk(fusion, /*implements_whole_instruction=*/true);
653 GpuElementalIrEmitter elemental_emitter(hlo_module_config_,
654 ir_emitter_context_->llvm_module(),
655 &b_, GetNestedComputer());
656
657 // Shape of the dynamic-update-slice's "update" operand.
658 Shape update_shape = root->operand(1)->shape();
659
660 // Array to write into. Because this is an in-place operation, this is the
661 // same as operand 0's array.
662 IrArray output_array = GetIrArray(*fusion, *fusion);
663
664 LaunchDimensions launch_dimensions = CalculateLaunchDimensions(
665 update_shape, ir_emitter_context_->device_description());
666 UpdateLaunchDimensions(launch_dimensions, fusion_thunk.get(),
667 ir_emitter_context_->llvm_module());
668 AddThunkToThunkSequence(std::move(fusion_thunk));
669
670 return llvm_ir::EmitParallelFusedDynamicUpdateSliceInPlace(
671 fusion, GetGeneratorForOperandIrArrays(fusion), output_array,
672 &elemental_emitter, launch_dimensions, &b_);
673 }
674
675 if (ImplementedAsGemm(*fusion)) {
676 AddThunkToThunkSequence(BuildGemmThunk(fusion));
677 return Status::OK();
678 }
679
680 CHECK_EQ(fusion->fusion_kind(), HloInstruction::FusionKind::kLoop);
681
682 if (CheckAndEmitHloWithTile021(fusion)) {
683 return Status::OK();
684 }
685
686 return IrEmitter::HandleFusion(fusion);
687 }
688
689 Status IrEmitterUnnested::HandleCopy(HloInstruction* copy) {
690 CHECK(ShapeUtil::Compatible(copy->operand(0)->shape(), copy->shape()));
691 const BufferAssignment& buffer_assignment =
692 ir_emitter_context_->buffer_assignment();
693 if (LayoutUtil::Equal(copy->operand(0)->shape().layout(),
694 copy->shape().layout()) &&
695 buffer_assignment.GetUniqueTopLevelSlice(copy->operand(0)).ok()) {
696 AddThunkToThunkSequence(BuildDeviceToDeviceCopyThunk(copy));
697 return Status::OK();
698 }
699 if (CheckAndEmitHloWithTile021(copy)) {
700 return Status::OK();
701 }
702
703 return IrEmitter::HandleCopy(copy);
704 }
705
706 Status IrEmitterUnnested::EmitExtraOutputsForReduce(
707 const HloInstruction* unnested_hlo, const IrArray::Index& index,
708 absl::Span<const std::pair<llvm_ir::ElementGenerator, ShapeIndex>>
709 extra_output_gens) {
710 for (int i = 0; i != extra_output_gens.size(); ++i) {
711 llvm::Value* extra_output_address =
712 GetIrArray(*unnested_hlo, *unnested_hlo, extra_output_gens[i].second)
713 .EmitArrayElementAddress(index, &b_,
714 "extra_output_element_address");
715 TF_ASSIGN_OR_RETURN(llvm::Value* const extra_output_ir_value,
716 extra_output_gens[i].first(index));
717 Store(extra_output_ir_value, extra_output_address);
718 }
719 return Status::OK();
720 }
721
722 Status IrEmitterUnnested::HandleReduce(HloInstruction* reduce) {
723 // TODO(b/118332391): Support multi-output reduce.
724 if (!reduce->shape().IsArray()) {
725 return Unimplemented("Multi-output reduce is not supported on GPU");
726 }
727 if (IsReductionToVector(*reduce)) {
728 return EmitReductionToVector(reduce);
729 }
730
731 return IrEmitter::HandleReduce(reduce);
732 }
733
734 Status IrEmitterUnnested::HandleTuple(HloInstruction* tuple) {
735 // For the root node of the entry computation we can elide writing the tuple
736 // buffer. We can always figure out the contents of the tuples from buffer
737 // assignment because we insert copies to ensure non-ambiguous output buffers.
738 // GpuExecutable never reads the tuple buffer.
739 if (tuple ==
740 tuple->parent()->parent()->entry_computation()->root_instruction()) {
741 return Status::OK();
742 }
743 bool all_tuple_elements_have_buffer =
744 absl::c_all_of(tuple->operands(), [&](HloInstruction* tuple_element) {
745 return ir_emitter_context_->buffer_assignment()
746 .GetUniqueTopLevelSlice(tuple_element)
747 .ok();
748 });
749 // TODO(b/111689850): This logic isn't quite correct.
750 //
751 // Tuples (especially tuples that are the final result of a computation) can
752 // be so huge that if we were to emit a kernel that took each tuple element as
753 // a parameter, we would exceed the max allowable number of parameters to a
754 // GPU kernel, b/31336476. As an optimization, if all tuple elements have a
755 // buffer, we collect their buffer addresses in a host array, and then copy
756 // that array to the tuple's buffer.
757 //
758 // Some tuple elements might not have an unambiguous buffer (like the result
759 // of a select-tuple). In that case, we fall back to emitting kernels which
760 // have access to their buffer addresses in code.
761 if (all_tuple_elements_have_buffer) {
762 std::vector<BufferAllocation::Slice> tuple_element_buffers;
763 for (const HloInstruction* tuple_element : tuple->operands()) {
764 tuple_element_buffers.push_back(GetAllocationSlice(*tuple_element));
765 }
766 AddThunkToThunkSequence(absl::make_unique<TupleThunk>(
767 tuple_element_buffers, GetAllocationSlice(*tuple), tuple));
768 return Status::OK();
769 }
770 AddThunkToThunkSequence(
771 BuildKernelThunk(tuple, /*implements_whole_instruction=*/true));
772 return IrEmitter::HandleTuple(tuple);
773 }
774
775 Status IrEmitterUnnested::HandleGetTupleElement(HloInstruction*) {
776 // GetTupleElement IR is emitted in the IR context of the user instruction,
777 // and so we do not build a kernel for GetTupleElement instructions.
778 return Status::OK();
779 }
780
781 Status IrEmitterUnnested::HandleSelectAndScatter(
782 HloInstruction* select_and_scatter) {
783 CHECK_EQ(select_and_scatter->operand_count(), 3);
784 const auto* operand = select_and_scatter->operand(0);
785 const auto* source = select_and_scatter->operand(1);
786 const Window& window = select_and_scatter->window();
787 PrimitiveType operand_element_type = operand->shape().element_type();
788 const int64 rank = operand->shape().rank();
789 CHECK_EQ(rank, source->shape().rank());
790 CHECK_EQ(rank, window.dimensions_size());
791
792 TF_ASSIGN_OR_RETURN(std::unique_ptr<Thunk> initializer_thunk,
793 BuildInitializerThunk(select_and_scatter));
794 std::vector<std::unique_ptr<Thunk>> thunks;
795 thunks.push_back(std::move(initializer_thunk));
796 thunks.push_back(BuildKernelThunk(select_and_scatter,
797 /*implements_whole_instruction=*/false));
798 std::unique_ptr<SequentialThunk> select_and_scatter_thunk =
799 absl::make_unique<SequentialThunk>(std::move(thunks), select_and_scatter);
800
801 // TODO(b/31410564): Implement dilation rate for select-and-scatter.
802 if (window_util::HasDilation(window)) {
803 return Unimplemented(
804 "Dilation for SelectAndScatter not implemented on GPU.");
805 }
806
807 LaunchDimensions launch_dimensions = CalculateLaunchDimensions(
808 source->shape(), ir_emitter_context_->device_description());
809 llvm::Type* index_type = GetIndexTypeForKernel(
810 select_and_scatter, launch_dimensions.launch_bound(), &b_);
811 auto index_typed_constant = [&](uint64 c) -> llvm::Constant* {
812 return llvm::ConstantInt::get(index_type, c);
813 };
814
815 // kSelectAndScatter is implemented as two kernel launches: the first launch
816 // initializes the output array to the given initial value,
817 // and the second accumulates the "source" matrix to the
818 // selected elements in the output array. The first launch is already
819 // implemented by the initializer thunk generated earlier, so this function
820 // only needs to take care of the select-and-scatter part.
821 //
822 // Pseudo code for select-and-scatter:
823 //
824 // for (coordinates S in the source): # This loop is parallel.
825 // initialized_flag = false
826 // for (coordinates W in the window):
827 // I = S * stride + W - pad_low
828 // if I within bounds of operand:
829 // if !(initialized_flag and select(selected_value, operand(I))):
830 // selected_value = operand(I)
831 // selected_index = I
832 // initialized_flag = true
833 // output(selected_index) = scatter(output(selected_index), source(S))
834 auto loop_body_emitter = [=](const IrArray::Index& source_index) -> Status {
835 // Allocate space to keep the currently selected value, its index, and a
836 // boolean flag indicating whether the value is initialized. The
837 // initialized_flag is set to false.
838 llvm::Value* selected_value_address = llvm_ir::EmitAllocaAtFunctionEntry(
839 llvm_ir::PrimitiveTypeToIrType(operand_element_type,
840 ir_emitter_context_->llvm_module()),
841 "selected_value_address", &b_);
842 llvm::Value* selected_index_address =
843 llvm_ir::EmitAllocaAtFunctionEntryWithCount(
844 index_type, index_typed_constant(rank), "selected_index_address",
845 &b_);
846 llvm::Value* initialized_flag_address = llvm_ir::EmitAllocaAtFunctionEntry(
847 b_.getInt1Ty(), "initialized_flag_address", &b_);
848 Store(b_.getInt1(false), initialized_flag_address);
849
850 // Create the inner loop to iterate over the window.
851 llvm_ir::ForLoopNest window_loops(IrName(select_and_scatter, "inner"), &b_,
852 index_type);
853 DimensionVector window_size;
854 for (const auto& dim : window.dimensions()) {
855 window_size.push_back(dim.size());
856 CHECK_GT(dim.size(), 0);
857 }
858 const IrArray::Index window_index = window_loops.AddLoopsForShape(
859 ShapeUtil::MakeShape(operand_element_type, window_size), "window");
860 llvm_ir::SetToFirstInsertPoint(window_loops.GetInnerLoopBodyBasicBlock(),
861 &b_);
862
863 // Compute the operand index to visit and evaluate whether the
864 // operand index is within the bounds. The unsigned comparison includes
865 // checking whether the operand index >= 0.
866 std::vector<llvm::Value*> operand_multi_index(source_index.size());
867 llvm::Value* in_bounds_condition = b_.getInt1(true);
868 for (int64 i = 0; i < rank; ++i) {
869 llvm::Value* strided_index = NSWMul(
870 source_index[i], index_typed_constant(window.dimensions(i).stride()));
871 operand_multi_index[i] =
872 NSWSub(NSWAdd(strided_index, window_index[i]),
873 index_typed_constant(window.dimensions(i).padding_low()));
874 llvm::Value* index_condition = ICmpULT(
875 operand_multi_index[i],
876 index_typed_constant(ShapeUtil::GetDimension(operand->shape(), i)));
877 in_bounds_condition = And(in_bounds_condition, index_condition);
878 }
879 CHECK(in_bounds_condition != nullptr);
880
881 // Only need to do something if the operand index is within the bounds.
882 // First check if the initialized_flag is set.
883 llvm_ir::LlvmIfData if_in_bounds =
884 llvm_ir::EmitIfThenElse(in_bounds_condition, "in-bounds", &b_);
885 llvm_ir::SetToFirstInsertPoint(if_in_bounds.true_block, &b_);
886 llvm_ir::LlvmIfData if_initialized = llvm_ir::EmitIfThenElse(
887 Load(initialized_flag_address), "initialized", &b_);
888
889 // If the initialized_flag is false, initialize the selected value and index
890 // with the currently visited operand element.
891 llvm_ir::SetToFirstInsertPoint(if_initialized.false_block, &b_);
892 const auto save_operand_index = [&](const IrArray::Index& operand_index) {
893 for (int64 i = 0; i < rank; ++i) {
894 llvm::Value* selected_index_address_slot =
895 InBoundsGEP(selected_index_address, {b_.getInt32(i)});
896 Store(operand_index[i], selected_index_address_slot);
897 }
898 };
899 IrArray operand_array = GetIrArray(*operand, *select_and_scatter);
900 IrArray::Index operand_index(operand_multi_index, operand->shape(),
901 index_type);
902 llvm::Value* operand_data =
903 operand_array.EmitReadArrayElement(operand_index, &b_);
904 Store(operand_data, selected_value_address);
905 save_operand_index(operand_index);
906 Store(b_.getInt1(true), initialized_flag_address);
907
908 // If the initialized_flag is true, call the `select` function to
909 // potentially update the selected value and index with the currently
910 // visited operand element.
911 llvm_ir::SetToFirstInsertPoint(if_initialized.true_block, &b_);
912 llvm::Value* operand_address =
913 operand_array.EmitArrayElementAddress(operand_index, &b_);
914 llvm::Value* select_return_buffer = llvm_ir::EmitAllocaAtFunctionEntry(
915 llvm_ir::PrimitiveTypeToIrType(PRED,
916 ir_emitter_context_->llvm_module()),
917 "select_return_buffer", &b_);
918 TF_RETURN_IF_ERROR(EmitCallToNestedComputation(
919 *select_and_scatter->select(),
920 {selected_value_address, operand_address}, select_return_buffer));
921 llvm::Value* result = Load(select_return_buffer);
922
923 // If the 'select' function returns false, update the selected value and the
924 // index to the currently visited operand element.
925 llvm::Value* cond = ICmpNE(
926 result,
927 llvm::ConstantInt::get(llvm_ir::PrimitiveTypeToIrType(
928 PRED, ir_emitter_context_->llvm_module()),
929 0),
930 "boolean_predicate");
931 llvm_ir::LlvmIfData if_select_lhs =
932 llvm_ir::EmitIfThenElse(cond, "if-select-lhs", &b_);
933 llvm_ir::SetToFirstInsertPoint(if_select_lhs.false_block, &b_);
934 Store(Load(operand_address), selected_value_address);
935 save_operand_index(operand_index);
936
937 // After iterating over the window elements, scatter the source element to
938 // the selected index of the output. The value we store at the output
939 // location is computed by calling the `scatter` function with the source
940 // value and the current output value.
941 llvm_ir::SetToFirstInsertPoint(window_loops.GetOuterLoopExitBasicBlock(),
942 &b_);
943 std::vector<llvm::Value*> selected_multi_index;
944 for (int64 i = 0; i < rank; ++i) {
945 llvm::Value* selected_index_address_slot =
946 InBoundsGEP(selected_index_address, {b_.getInt32(i)});
947 selected_multi_index.push_back(Load(selected_index_address_slot));
948 }
949 llvm::Value* source_value_address =
950 GetIrArray(*source, *select_and_scatter)
951 .EmitArrayElementAddress(source_index, &b_);
952 IrArray::Index selected_index(selected_multi_index,
953 select_and_scatter->shape(),
954 operand_index.GetType());
955 llvm::Value* output_value_address =
956 GetIrArray(*select_and_scatter, *select_and_scatter)
957 .EmitArrayElementAddress(selected_index, &b_);
958 return EmitAtomicOperationForNestedComputation(
959 *select_and_scatter->scatter(), output_value_address,
960 source_value_address);
961 };
962
963 UpdateLaunchDimensions(
964 launch_dimensions,
965 // IrEmitterUnnested implements kSelectAndScatter as a SequentialThunk
966 // consisting of two thunks, an initializer KernelThunk that initializes
967 // the output and another KernelThunk that accumulates the scattered
968 // elements.
969 select_and_scatter_thunk->thunks().back().get(),
970 ir_emitter_context_->llvm_module());
971 AddThunkToThunkSequence(std::move(select_and_scatter_thunk));
972 return ParallelLoopEmitter(loop_body_emitter, source->shape(),
973 launch_dimensions, &b_)
974 .EmitLoop(IrName(select_and_scatter), index_type);
975 }
976
977 Status IrEmitterUnnested::HandleWhile(HloInstruction* xla_while) {
978 HloComputation* condition = xla_while->while_condition();
979 TF_RET_CHECK(ShapeUtil::IsScalar(condition->root_instruction()->shape()) &&
980 condition->root_instruction()->shape().element_type() == PRED)
981 << "While condition computation must return bool";
982 // Build ForThunk for conformant while loops, otherwise build WhileThunk.
983 auto config = xla_while->backend_config<WhileLoopBackendConfig>();
984 if (config.ok() && config.ValueOrDie().has_known_trip_count()) {
985 AddThunkToThunkSequence(
986 BuildForThunk(xla_while, config.ValueOrDie().known_trip_count().n()));
987 } else {
988 AddThunkToThunkSequence(BuildWhileThunk(xla_while));
989 }
990 return Status::OK();
991 }
992
993 Status IrEmitterUnnested::HandleRng(HloInstruction* rng) {
994 // Build the kernel to generate the random numbers.
995 //
996 // Unroll the kernel so that the duplicated computation that calculates the
997 // 128 bit sample can be optimized away by LLVM.
998 std::unique_ptr<KernelThunk> rng_thunk = BuildKernelThunk(
999 rng, /*implements_whole_instruction=*/false, ComputeMaxUnrollFactor(rng));
1000 ElementalIrEmitter::HloToElementGeneratorMap operand_to_generator;
1001 for (const HloInstruction* operand : rng->operands()) {
1002 operand_to_generator[operand] = [=](const llvm_ir::IrArray::Index& index) {
1003 return GetIrArray(*operand, *rng).EmitReadArrayElement(index, &b_);
1004 };
1005 }
1006 TF_RETURN_IF_ERROR(EmitTargetElementLoopInThunk(
1007 *rng,
1008 GpuElementalIrEmitter(hlo_module_config_, module_, &b_,
1009 GetNestedComputer())
1010 .MakeElementGenerator(rng, operand_to_generator),
1011 rng_thunk.get()));
1012
1013 // Emit a kernel to increment the global state for Philox RNG algorithm.
1014 std::unique_ptr<Thunk> increment_seed_thunk =
1015 BuildKernelThunk(rng, /*implements_whole_instruction=*/false);
1016 llvm_ir::IncrementVariableForPhiloxRngState(1, module_, &b_);
1017
1018 // Build the SequentialThunk for the RNG hlo.
1019 std::vector<std::unique_ptr<Thunk>> thunks;
1020 thunks.reserve(2);
1021 thunks.push_back(std::move(rng_thunk));
1022 thunks.push_back(std::move(increment_seed_thunk));
1023 AddThunkToThunkSequence(
1024 absl::make_unique<SequentialThunk>(std::move(thunks), rng));
1025
1026 return Status::OK();
1027 }
1028
1029 Status IrEmitterUnnested::HandleScatter(HloInstruction* scatter) {
1030 const HloInstruction* operand = scatter->operand(0);
1031 const HloInstruction* scatter_indices = scatter->operand(1);
1032 const HloInstruction* updates = scatter->operand(2);
1033
1034 std::vector<std::unique_ptr<Thunk>> thunks;
1035
1036 // Copy the operand into the output if it's not the same buffer already.
1037 auto operand_buffer = GetAllocationSlice(*operand);
1038 auto destination_buffer = GetAllocationSlice(*scatter);
1039 if (operand_buffer != destination_buffer) {
1040 thunks.push_back(absl::make_unique<DeviceToDeviceCopyThunk>(
1041 /*source_address=*/operand_buffer,
1042 /*destination_buffer=*/destination_buffer,
1043 /*mem_size=*/ShapeUtil::ByteSizeOf(operand->shape()), scatter));
1044 }
1045
1046 thunks.push_back(
1047 BuildKernelThunk(scatter,
1048 /*implements_whole_instruction=*/thunks.empty()));
1049
1050 TF_RETURN_IF_ERROR(EmitScatter(
1051 thunks.back().get(), scatter,
1052 /*scatter_indices_gen=*/
1053 [=](const IrArray::Index& index) {
1054 return GetIrArray(*scatter_indices, *scatter)
1055 .EmitReadArrayElement(index, &b_, "scatter_index");
1056 },
1057 /*updates_gen=*/
1058 [=](const IrArray::Index& index) {
1059 return GetIrArray(*updates, *scatter)
1060 .EmitReadArrayElement(index, &b_, "update");
1061 }));
1062
1063 // Elide the sequential thunk if there's no copy.
1064 if (thunks.size() == 1) {
1065 AddThunkToThunkSequence(std::move(thunks[0]));
1066 } else {
1067 AddThunkToThunkSequence(
1068 absl::make_unique<SequentialThunk>(std::move(thunks), scatter));
1069 }
1070
1071 return Status::OK();
1072 }
1073
1074 Status IrEmitterUnnested::EmitScatter(
1075 Thunk* thunk, HloInstruction* scatter,
1076 const llvm_ir::ElementGenerator& scatter_indices_gen,
1077 const llvm_ir::ElementGenerator& updates_gen) {
1078 const HloInstruction* operand = scatter->operand(0);
1079 const HloInstruction* scatter_indices = scatter->operand(1);
1080 const HloInstruction* updates = scatter->operand(2);
1081 const ScatterDimensionNumbers& dim_numbers =
1082 scatter->scatter_dimension_numbers();
1083 CHECK(ShapeUtil::Equal(scatter->shape(), operand->shape()));
1084
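// As a sketch of the index bookkeeping below, consider a hypothetical scatter
// with an f32[4, 3] operand, f32[2, 3] updates, update_window_dims = {1},
// inserted_window_dims = {0}, scatter_dims_to_operand_dims = {0}, and
// index_vector_dim = 1. For an update index (i, j), dimension 1 is a window
// dimension, so j maps directly to operand dimension 1; dimension 0 is a
// scatter dimension, so scatter_indices[i] supplies the offset added to
// operand dimension 0, and the store is skipped when that offset falls
// outside [0, 4).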
1085 auto loop_body_emitter = [&](const IrArray::Index& index) -> Status {
1086 std::vector<llvm::Value*> raw_window_multidim;
1087 std::vector<llvm::Value*> input_scatter_multidim;
1088 std::vector<int64> raw_window_bounds;
1089
1090 // Partition the index into window indices and scatter indices.
1091 for (int64 i = 0, e = index.size(); i != e; ++i) {
1092 // For window indices, also remember the window size; this comes in handy
1093 // later.
1094 if (absl::c_binary_search(dim_numbers.update_window_dims(), i)) {
1095 raw_window_multidim.push_back(index[i]);
1096 raw_window_bounds.push_back(updates->shape().dimensions(i));
1097 } else {
1098 input_scatter_multidim.push_back(index[i]);
1099 }
1100 }
1101 DCHECK_EQ(raw_window_multidim.size(),
1102 dim_numbers.update_window_dims_size());
1103
1104 // Apply inserted_window_dims to the window dimensions.
1105 int64 raw_window_multidim_idx = 0;
1106 std::vector<llvm::Value*> input_window_multidim;
1107 std::vector<int64> input_window_bounds;
1108 for (int64 i = 0, e = operand->shape().rank(); i != e; ++i) {
1109 if (absl::c_binary_search(dim_numbers.inserted_window_dims(), i)) {
1110 input_window_bounds.push_back(1); // Trivial dimension.
1111 input_window_multidim.push_back(index.GetConstantWithIndexType(0));
1112 } else {
1113 input_window_bounds.push_back(
1114 raw_window_bounds[raw_window_multidim_idx]);
1115 input_window_multidim.push_back(
1116 raw_window_multidim[raw_window_multidim_idx]);
1117 ++raw_window_multidim_idx;
1118 }
1119 }
1120 DCHECK_EQ(input_window_multidim.size(), operand->shape().rank());
1121
1122 // Insert a 1 dimension at the end if index_vector_dim requests one.
1123 Shape scatter_indices_shape = scatter_indices->shape();
1124 if (dim_numbers.index_vector_dim() == scatter_indices_shape.rank()) {
1125 scatter_indices_shape.add_dimensions(1);
1126 scatter_indices_shape.mutable_layout()->add_minor_to_major(
1127 dim_numbers.index_vector_dim());
1128 }
1129
1130 // Now load the indices corresponding to the current window from
1131 // scatter_indices.
1132 std::vector<llvm::Value*> raw_scatter_index_multidim =
1133 input_scatter_multidim;
1134 raw_scatter_index_multidim.insert(
1135 raw_scatter_index_multidim.begin() + dim_numbers.index_vector_dim(),
1136 nullptr);
1137 llvm::Value* is_in_bounds = b_.getTrue();
1138 for (int64 i = 0, e = dim_numbers.scatter_dims_to_operand_dims_size();
1139 i != e; ++i) {
1140 // Our index is stored along index_vector_dim; insert that into the lookup
1141 // index into scatter_indices.
1142 raw_scatter_index_multidim[dim_numbers.index_vector_dim()] =
1143 index.GetConstantWithIndexType(i);
1144 llvm_ir::IrArray::Index raw_scatter_index_index(
1145 raw_scatter_index_multidim, scatter_indices_shape, index.GetType());
1146
1147 int64 operand_dim = dim_numbers.scatter_dims_to_operand_dims(i);
1148 TF_ASSIGN_OR_RETURN(
1149 llvm::Value* const loaded_scatter_index,
1150 scatter_indices_gen(raw_scatter_index_index.SourceIndexOfReshape(
1151 scatter_indices_shape, scatter_indices->shape(), &b_)));
1152 // And add the index to our window index. This yields the output index.
1153 llvm::Value* casted_scatter_index =
1154 IntCast(loaded_scatter_index, index.GetType(),
1155 /*isSigned=*/true);
1156 llvm::Value* dim_offset =
1157 Add(input_window_multidim[operand_dim], casted_scatter_index);
1158 input_window_multidim[operand_dim] = dim_offset;
1159
1160 // Also do the bounds check now.
1161 int64 max_index = operand->shape().dimensions(operand_dim) -
1162 input_window_bounds[operand_dim] + 1;
1163 // is_in_bounds = index >= 0 && index < dim_size-window_size+1
1164 // --> index u< dim_size-window_size+1
1165 is_in_bounds =
1166 And(is_in_bounds, ICmpULT(casted_scatter_index,
1167 index.GetConstantWithIndexType(max_index)));
1168 }
1169
1170 llvm_ir::LlvmIfData if_window_in_bounds_data = llvm_ir::EmitIfThenElse(
1171 is_in_bounds, "scatter.in_bounds", &b_, /*emit_else=*/false);
1172 llvm_ir::SetToFirstInsertPoint(if_window_in_bounds_data.true_block, &b_);
1173 // All done; now read the update value for the current index and do an
1174 // atomic store to the computed location in the output.
1175 llvm_ir::IrArray::Index input_window_index(input_window_multidim,
1176 index.GetType());
1177 HloInstruction* output_hlo =
1178 scatter->IsFused() ? scatter->parent()->FusionInstruction() : scatter;
1179 llvm::Value* output_address =
1180 GetIrArray(*output_hlo, *output_hlo)
1181 .EmitArrayElementAddress(input_window_index, &b_);
1182 llvm::Value* input_address = Alloca(llvm_ir::PrimitiveTypeToIrType(
1183 updates->shape().element_type(), module_));
1184 TF_ASSIGN_OR_RETURN(llvm::Value* const input_ir_value, updates_gen(index));
1185 Store(input_ir_value, input_address);
1186 return EmitAtomicOperationForNestedComputation(
1187 *scatter->to_apply(), output_address, input_address);
1188 };
1189
1190 // Launch a kernel that reads every element in the updates tensor. We could
1191 // also do one kernel per window instead if bounds checks turn out to be a
1192 // bottleneck.
1193 LaunchDimensions launch_dimensions = CalculateLaunchDimensions(
1194 updates->shape(), ir_emitter_context_->device_description());
1195 UpdateLaunchDimensions(launch_dimensions, thunk,
1196 ir_emitter_context_->llvm_module());
1197
1198 return ParallelLoopEmitter(loop_body_emitter, updates->shape(),
1199 launch_dimensions, &b_)
1200 .EmitLoop(IrName(scatter),
1201 GetIndexTypeForKernel(scatter, launch_dimensions.launch_bound(),
1202 &b_));
1203 }
1204
1205 Status IrEmitterUnnested::HandleSelect(HloInstruction* select) {
1206 return IrEmitter::HandleSelect(select);
1207 }
1208
1209 Status IrEmitterUnnested::HandleSort(HloInstruction* sort) {
1210 std::vector<std::unique_ptr<Thunk>> thunks;
1211 Shape keys_shape = sort->operand(0)->shape();
1212 int64 dimension_to_sort = sort->dimensions(0);
1213 for (int64 i = 0; i < sort->operand_count(); ++i) {
1214 ShapeIndex shape_index =
1215 sort->operand_count() > 1 ? ShapeIndex({i}) : ShapeIndex({});
1216 // We assume that the layout of all involved operands and outputs is the
1217 // same.
1218 TF_RET_CHECK(LayoutUtil::LayoutsInShapesEqual(keys_shape,
1219 sort->operand(i)->shape()));
1220 TF_RET_CHECK(LayoutUtil::LayoutsInShapesEqual(
1221 keys_shape, ShapeUtil::GetSubshape(sort->shape(), shape_index)));
1222
1223 // If possible, we share buffers. If that is not possible, we need to copy
1224 // the values, because the emitter does the sorting in-place.
1225 auto destination_buffer = GetAllocationSlice(*sort, shape_index);
1226 auto source_address = GetAllocationSlice(*sort->operand(i));
1227 if (destination_buffer != source_address) {
1228 // TODO(b/26783907): Figure out why we never seem to share buffers for
1229 // key/value sort.
1230 thunks.push_back(absl::make_unique<DeviceToDeviceCopyThunk>(
1231 /*source_address=*/source_address,
1232 /*destination_buffer=*/destination_buffer,
1233 /*mem_size=*/ShapeUtil::ByteSizeOf(sort->operand(i)->shape()),
1234 nullptr));
1235 }
1236 }
1237
1238 uint64 dimension_to_sort_bound = keys_shape.dimensions(dimension_to_sort);
1239 int64 num_stages = tensorflow::Log2Ceiling(dimension_to_sort_bound);
1240 CHECK_GE(1ULL << num_stages, dimension_to_sort_bound);
1241 CHECK_LT(1ULL << (num_stages - 1), dimension_to_sort_bound);
1242
1243 // Naive C++ code for the outer loops:
1244 //
1245 // for (int64 stage = 0; stage < Log2Ceiling(dimension_to_sort_bound);
1246 // ++stage) {
1247 // int64 first_xor_mask = (1LL << (stage + 1)) - 1;
1248 // SortInPlace(first_xor_mask);
1249 // for (int64 mask = stage - 1; mask >= 0; --mask) {
1250 // int64 later_xor_mask = 1LL << mask;
1251 // SortInPlace(later_xor_mask);
1252 // }
1253 // }
1254 //
1255 // This follows the alternative representation of the algorithm described on
1256 // Wikipedia: https://en.wikipedia.org/wiki/Bitonic_sorter
1257 //
1258 // Each mask specifies how to derive from one position in the array the
1259 // position with which it should be compared (we calculate the xor of the
1260 // position with the mask).
1261 // As an optimization, we can move the 'mask' loop to inside the
1262 // sorting/comparison loop if the comparisons happen within a small block of
1263 // the array. To make this work, we collect all consecutive masks that are
1264 // smaller than our chosen power of 2 tile size, and pass them to SortInPlace.
1265 // Each thread then processes one tile of data.
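//
// For example, with dimension_to_sort_bound == 8 (so num_stages == 3) the
// xor masks generated below are 1, 3, 1, 7, 2, 1. In that tiny case
// kTileSize would be 8 (< 128), so tiling is disabled and each mask gets its
// own kernel launch. For a large bound, kTileSize is 2048 and, device limits
// permitting, every run of consecutive masks smaller than 2048 is combined
// into a single tiled launch.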
1266
1267 const uint64 kTileSize = std::min(2048ULL, 1ULL << num_stages);
1268
1269 // If we cannot combine several xor masks together, we don't use tiling, so we
1270 // calculate the standard launch dimensions for the shape. However, we only
1271 // need to iterate through ~half of the dimension to sort (rounded up to the
1272 // next highest power of 2), because each iteration compares one pair of
1273 // elements.
1274 Shape standard_iteration_shape = keys_shape;
1275 uint64 standard_num_iterations_in_sort_dim = 1ULL << (num_stages - 1);
1276 standard_iteration_shape.set_dimensions(dimension_to_sort,
1277 standard_num_iterations_in_sort_dim);
1278 LaunchDimensions standard_launch_dimensions = CalculateLaunchDimensions(
1279 standard_iteration_shape, ir_emitter_context_->device_description());
1280
1281 // Calculate the launch dimensions for the case where we use tiling. We split
1282 // the dimension that should be sorted into tiles of size 'kTileSize'. This
1283 // means we first need to round 'dimension_to_sort_bound' up to be a multiple
1284 // of the tile size.
1285 int64 rounded_bound = RoundUpToNearest(dimension_to_sort_bound, kTileSize);
1286 Shape iteration_shape = keys_shape;
1287
1288 // We iterate through the element pairs that should be compared.
1289 uint64 num_iterations_in_sort_dim = rounded_bound / 2;
1290 iteration_shape.set_dimensions(dimension_to_sort, num_iterations_in_sort_dim);
1291 uint64 num_iterations = ShapeUtil::ElementsIn(iteration_shape);
1292
1293 // For correctness we need exactly 'kTileSize' / 2 threads per block. Each
1294 // thread is responsible for copying exactly two adjacent elements into
1295 // shared memory, and then compares two possibly different elements taken
1296 // from shared memory.
1297 const uint64 kThreadsPerBlock = kTileSize / 2;
1298
1299 // Check whether we should use tiling at all. We might not be able to use it
1300 // if there are not enough threads per block or not enough shared memory.
1301 // Tiling also does not give a speedup if the tile size is < 128.
1302 int64 total_shared_memory_needed = 0;
1303 for (int64 i = 0; i < sort->operand_count(); ++i) {
1304 total_shared_memory_needed +=
1305 kTileSize * ShapeUtil::ByteSizeOfPrimitiveType(
1306 sort->operand(i)->shape().element_type());
1307 }
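  // For example (illustrative numbers only): sorting two F32 operands with
  // kTileSize = 2048 needs
  //
  //   total_shared_memory_needed = 2 * 2048 * 4 bytes = 16 KiB,
  //
  // which is well below the shared-memory-per-block limit of current GPUs
  // (commonly 48 KiB or more).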
1308 bool no_tiling =
1309 kTileSize < 128 ||
1310 kThreadsPerBlock >
1311 ir_emitter_context_->device_description().threads_per_block_limit() ||
1312 total_shared_memory_needed >
1313 ir_emitter_context_->device_description().shared_memory_per_block();
1314
1315 uint64 num_blocks = CeilOfRatio(num_iterations, kThreadsPerBlock);
1316 LaunchDimensions tiled_launch_dimensions(num_blocks, kThreadsPerBlock);
1317
1318 auto emit_kernel = [&](absl::Span<const int64> xor_masks) {
1319 thunks.push_back(
1320 BuildKernelThunk(sort, /*implements_whole_instruction=*/false));
1321 LaunchDimensions launch_dimensions = xor_masks.size() > 1
1322 ? tiled_launch_dimensions
1323 : standard_launch_dimensions;
1324 UpdateLaunchDimensions(launch_dimensions, thunks.back().get(),
1325 ir_emitter_context_->llvm_module());
1326 std::vector<IrArray> values_arrays;
1327 values_arrays.reserve(sort->operand_count());
1328 for (int64 i = 0; i < sort->operand_count(); ++i) {
1329 ShapeIndex shape_index =
1330 sort->operand_count() > 1 ? ShapeIndex({i}) : ShapeIndex({});
1331 values_arrays.push_back(GetIrArray(*sort, *sort, shape_index));
1332 }
1333 return llvm_ir::EmitSortInPlace(
1334 dimension_to_sort, values_arrays, IrName(sort), xor_masks, &b_,
1335 launch_dimensions,
1336 xor_masks.size() > 1 ? num_iterations_in_sort_dim
1337 : standard_num_iterations_in_sort_dim,
1338 kTileSize,
1339 [&](absl::Span<llvm::Value* const> operands, llvm::Value* output) {
1340 return EmitCallToNestedComputation(*sort->to_apply(), operands,
1341 output);
1342 });
1343 };
1344 std::vector<int64> xor_masks;
1345 for (int64 stage = 0; stage < num_stages; ++stage) {
1346 for (int64 mask = stage; mask >= 0; --mask) {
1347 int64 xor_mask;
1348 if (mask == stage) {
1349 xor_mask = (1LL << (stage + 1)) - 1;
1350 } else {
1351 xor_mask = 1LL << mask;
1352 }
1353 if (xor_mask >= kTileSize || no_tiling) {
1354 if (!xor_masks.empty()) {
1355 TF_RETURN_IF_ERROR(emit_kernel(xor_masks));
1356 xor_masks.clear();
1357 }
1358 TF_RETURN_IF_ERROR(emit_kernel({xor_mask}));
1359 } else {
1360 xor_masks.push_back(xor_mask);
1361 }
1362 }
1363 }
1364 if (!xor_masks.empty()) {
1365 TF_RETURN_IF_ERROR(emit_kernel(xor_masks));
1366 }
1367
1368 AddThunkToThunkSequence(
1369 absl::make_unique<SequentialThunk>(std::move(thunks), sort));
1370 return Status::OK();
1371 }
1372
1373 Status IrEmitterUnnested::HandleTupleSelect(HloInstruction* tuple_select) {
1374 AddThunkToThunkSequence(
1375 BuildKernelThunk(tuple_select, /*implements_whole_instruction=*/true));
1376 return IrEmitter::HandleTupleSelect(tuple_select);
1377 }
1378
1379 namespace {
1380
1381 bool IsScalarAddComputation(HloComputation* computation) {
1382 return Match(computation->root_instruction(),
1383 m::AddAnyOrder(m::Parameter(0), m::Parameter(1))
1384 .WithShape(m::Shape().IsEffectiveScalar()));
1385 }
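// For exposition, a hypothetical computation of the form
//
//   add {
//     p0 = f32[] parameter(0)
//     p1 = f32[] parameter(1)
//     ROOT sum = f32[] add(p0, p1)
//   }
//
// matches this pattern (in either operand order), because its root is an
// effective-scalar add of the two parameters.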
1386
1387 } // namespace
1388
1389 Status IrEmitterUnnested::HandleAllReduce(HloInstruction* crs) {
1390 VLOG(2) << "AllReduce; replica count: " << hlo_module_config_.replica_count()
1391 << "; operand count: " << crs->operand_count()
1392 << "; NCCL is enabled: " << NcclAllReduceThunk::NcclIsEnabled();
1393
1394 // Note the replica_count == 1 case is handled via device-to-device copy
1395 // below.
1396 bool should_use_nccl_thunk =
1397 hlo_module_config_.replica_count() > 1 &&
1398 crs->IsCrossReplicaAllReduce() &&
1399 crs->operand_count() == 1 && // One array to reduce.
1400 crs->operand(0)->shape().element_type() == F32 &&
1401 // Check the computation is a summation.
1402 IsScalarAddComputation(crs->to_apply());
1403
1404 if (should_use_nccl_thunk) {
1405 CHECK(crs->operand(0)->shape().IsArray())
1406 << "Operands to all-reduce must be arrays: " << crs->ToString();
1407 AddThunkToThunkSequence(absl::make_unique<NcclAllReduceThunk>(
1408 /*replica_count=*/hlo_module_config_.replica_count(),
1409 /*elements=*/ShapeUtil::ElementsIn(crs->operand(0)->shape()),
1410 /*source_address=*/GetAllocationSlice(*crs->operand(0)),
1411 /*destination_buffer=*/GetAllocationSlice(*crs), crs));
1412 return Status::OK();
1413 }
1414
1415 if (hlo_module_config_.replica_count() != 1) {
1416 // TODO(b/33011107): Support more AllReduce configurations on GPU.
1417 string message = absl::StrFormat(
1418 "Requested AllReduce not implemented on GPU; replica_count: %d; "
1419 "operand_count: %d; IsCrossReplicaAllReduce: %d; NCCL support: %d",
1420 hlo_module_config_.replica_count(), crs->operand_count(),
1421 crs->IsCrossReplicaAllReduce(), NcclAllReduceThunk::NcclIsEnabled());
1422 if (crs->operand_count() > 0) {
1423 absl::StrAppendFormat(
1424 &message, "; first operand array element-type: %s",
1425 PrimitiveType_Name(crs->operand(0)->shape().element_type()));
1426 }
1427 return Unimplemented("%s", message);
1428 }
1429
1430 // CRS with one operand and one replica is simply the identity function.
1431 // Buffer assignment expects a copy, so that's what we do.
1432 //
1433 // TODO(b/80100934): We would like to eliminate one-replica CRS nodes entirely
1434 // in algebraic-simplifier, but currently on some platforms
1435 // HloModuleConfig::num_replicas changes between when the module is compiled
1436 // and when it's run.
1437 if (crs->operand_count() == 1) {
1438 CHECK(crs->operand(0)->shape().IsArray())
1439 << "Operands to all-reduce must be arrays: " << crs->ToString();
1440 AddThunkToThunkSequence(absl::make_unique<DeviceToDeviceCopyThunk>(
1441 /*source_address=*/GetAllocationSlice(*crs->operand(0)),
1442 /*destination_buffer=*/GetAllocationSlice(*crs),
1443 /*mem_size=*/ShapeUtil::ByteSizeOf(crs->shape()), crs));
1444 return Status::OK();
1445 }
1446
1447 // One-replica CRS with multiple operands produces a tuple of the inputs.
1448 // Again, buffer assignment expects us to copy each.
1449 std::vector<std::unique_ptr<Thunk>> thunks;
1450 std::vector<BufferAllocation::Slice> tuple_element_buffers;
1451 for (int64 i = 0; i < crs->operand_count(); ++i) {
1452 tuple_element_buffers.push_back(ir_emitter_context_->buffer_assignment()
1453 .GetUniqueSlice(crs, {i})
1454 .ValueOrDie());
1455 thunks.push_back(absl::make_unique<DeviceToDeviceCopyThunk>(
1456 /*source_address=*/GetAllocationSlice(*crs->operand(i)),
1457 /*destination_buffer=*/tuple_element_buffers.back(),
1458 /*mem_size=*/ShapeUtil::ByteSizeOf(crs->operand(i)->shape()), nullptr));
1459 }
1460
1461 // Output a tuple of the buffers above.
1462 thunks.push_back(absl::make_unique<TupleThunk>(
1463 tuple_element_buffers, GetAllocationSlice(*crs), nullptr));
1464 AddThunkToThunkSequence(
1465 absl::make_unique<SequentialThunk>(std::move(thunks), crs));
1466 return Status::OK();
1467 }
1468
1469 Status IrEmitterUnnested::HandleAfterAll(HloInstruction* after_all) {
1470 return Status::OK();
1471 }
1472
1473 Status IrEmitterUnnested::HandleInfeed(HloInstruction* infeed) {
1474 AddThunkToThunkSequence(BuildInfeedThunk(infeed));
1475 return Status::OK();
1476 }
1477
1478 Status IrEmitterUnnested::HandleOutfeed(HloInstruction* outfeed) {
1479 AddThunkToThunkSequence(BuildOutfeedThunk(outfeed));
1480 return Status::OK();
1481 }
1482
1483 // Figures out how to access the buffers for all subshapes of hlo's operands and
1484 // for hlo itself (i.e. all the buffers produced by HLO).
1485 //
1486 // Returns a map keyed on the pair {HloInstruction, ShapeIndex}. The value for
1487 // this key is a pair {Slice, ShapeIndex}, where the slice tells you the root
1488 // buffer to look in, and the ShapeIndex describes how to dereference starting
1489 // at that buffer to get to the buffer in question.
1490 //
1491 // For example, if {hlo, {1}} is mapped to {slice, {3, 4}}, then the buffer for
1492 // hlo at ShapeIndex {1} (i.e. the buffer for the second tuple element of hlo)
1493 // is found at slice[3][4]. That is, slice is a void***, which we dereference
1494 // twice -- first at index 3, and then at index 4 -- to get the address of our
1495 // buffer.
1496 //
1497 // This function conservatively assumes that we'll touch all sub-buffers of
1498 // every operand and of the output.
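// As an illustrative sketch (restating the example above, not additional API):
// once the base pointer for `slice` is materialized, the sub-buffer for
// gte_index {3, 4} is reached with one load per index:
//
//   void* loc = slice_base;                 // slice_base points at a void**
//   for (int64 idx : {3, 4}) {
//     loc = static_cast<void**>(loc)[idx];  // dereference at index 3, then 4
//   }
//   // 'loc' now addresses the buffer for hlo at ShapeIndex {1}.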
1499 static std::map<std::pair<const HloInstruction*, ShapeIndex>,
1500 std::pair<BufferAllocation::Slice, ShapeIndex>>
1501 GetHloBufferSlices(const HloInstruction* hlo,
1502 const BufferAssignment& buffer_assn) {
1503 std::map<std::pair<const HloInstruction*, ShapeIndex>,
1504 std::pair<BufferAllocation::Slice, ShapeIndex>>
1505 slices;
1506
1507 // Tries to find a slice plus an array of indices i1, ..., iN such that the
1508 // sub-buffer for instr at index can be found at slice[i1]...[iN].
1509 auto find_slice_for = [&](const HloInstruction* instr,
1510 const ShapeIndex& index)
1511 -> optional<std::pair<BufferAllocation::Slice, ShapeIndex>> {
1512 // Simple, common case: Is the buffer for instr known at runtime? If so,
1513 // we're done.
1514 auto slice = buffer_assn.GetUniqueSlice(instr, index);
1515 if (slice.ok()) {
1516 return {{slice.ValueOrDie(), ShapeIndex()}};
1517 }
1518
1519 // If that didn't work, walk up any bitcasts that we might see. These must
1520 // appear before any GTE instructions, because it's illegal to bitcast to a
1521 // tuple type.
1522 const HloInstruction* parent = instr;
1523 while (parent->opcode() == HloOpcode::kBitcast) {
1524 parent = parent->operand(0);
1525
1526 auto slice = buffer_assn.GetUniqueSlice(parent, {});
1527 if (slice.ok()) {
1528 return {{slice.ValueOrDie(), ShapeIndex()}};
1529 }
1530 }
1531
1532 // Check whether instr is a GTE instruction. If it is, see if we can get a
1533 // buffer for its parent, and continue walking up parents until we find a
1534 // defined buffer or we hit something that's not a GTE.
1535 ShapeIndex gte_indices;
1536 while (parent->opcode() == HloOpcode::kGetTupleElement) {
1537 gte_indices.push_front(parent->tuple_index());
1538 parent = parent->operand(0);
1539
1540 auto slice = buffer_assn.GetUniqueSlice(parent, {});
1541 if (slice.ok()) {
1542 return {{slice.ValueOrDie(), gte_indices}};
1543 }
1544 }
1545
1546 // Finally, if we don't know the buffer for instr at index, see if we know
1547 // the buffer for instr at index without its last element. If so, we can
1548 // dynamically find the buffer for instr by dereferencing a pointer in that
1549 // buffer. Continue looking this way until we run out of elements in
1550 // 'index'.
1551 //
1552 // We can almost always get a buffer without resorting to this. The only
1553 // exception is for cases where the relevant sub-buffer is truly unknowable,
1554 // for example the sub-buffer of a tuple-shaped select.
1555 ShapeIndex new_index = index;
1556 while (!new_index.empty()) {
1557 gte_indices.push_front(new_index.back());
1558 new_index.pop_back();
1559 auto slice = buffer_assn.GetUniqueSlice(instr, new_index);
1560 if (slice.ok()) {
1561 return {{slice.ValueOrDie(), gte_indices}};
1562 }
1563 }
1564
1565 return nullopt;
1566 };
1567
1568 // Adds entries for all subshapes of instr to `slices`.
1569 auto add_slices_for = [&](const HloInstruction* instr) {
1570 ShapeUtil::ForEachSubshape(
1571 instr->shape(), [&](const Shape& /*shape*/, const ShapeIndex& index) {
1572 if (slices.count({instr, index})) {
1573 // HLOs can have duplicate operands; don't bother redoing work.
1574 return;
1575 }
1576 auto maybe_slice = find_slice_for(instr, index);
1577 if (maybe_slice.has_value()) {
1578 slices[{instr, index}] = *maybe_slice;
1579 } else {
1580 VLOG(1) << "Couldn't find buffer for " << instr->ToString()
1581 << " at index " << index.ToString();
1582 }
1583 });
1584 };
1585
1586 add_slices_for(hlo);
1587 for (const HloInstruction* operand : hlo->operands()) {
1588 // Conservatively assume we'll need the buffers for all subshapes of the
1589 // operand.
1590 add_slices_for(operand);
1591 }
1592
1593 return slices;
1594 }
1595
1596 std::unique_ptr<KernelThunk> IrEmitterUnnested::BuildKernelThunk(
1597 const HloInstruction* inst, bool implements_whole_instruction,
1598 int unroll_factor) {
1599 const BufferAssignment& buffer_assn =
1600 ir_emitter_context_->buffer_assignment();
1601
1602 std::map<std::pair<const HloInstruction*, ShapeIndex>,
1603 std::pair<BufferAllocation::Slice, ShapeIndex>>
1604 hlo_slices = GetHloBufferSlices(inst, buffer_assn);
1605
1606 // Figure out which buffer allocations need to be passed as arguments to our
1607 // kernel. This is simply all of the allocations referenced in hlo_slices,
1608 // plus the XLA temp buffer (if we have it). We always include the temp
1609 // buffer because even if the kernel itself doesn't use it, a nested
1610 // subcomputation within the kernel (e.g. a kMap's computation) might.
1611 std::unordered_set<const BufferAllocation*> buffers_needed;
1612 for (const auto& kv : hlo_slices) {
1613 buffers_needed.insert(kv.second.first.allocation());
1614 }
1615 absl::optional<const BufferAllocation*> temp_buffer;
1616 for (const BufferAllocation& alloc : buffer_assn.Allocations()) {
1617 if (alloc.IsPreallocatedTempBuffer()) {
1618 if (!temp_buffer.has_value()) {
1619 temp_buffer = &alloc;
1620 } else {
1621 LOG(FATAL) << "Multiple temp buffers found, but only one is allowed!";
1622 }
1623 }
1624 }
1625 if (temp_buffer.has_value()) {
1626 buffers_needed.insert(*temp_buffer);
1627 }
1628
1629 // We'll pass a pointer to each of the elements of 'non_constant_buffers' to
1630 // our kernel, in this order.
1631 std::vector<const BufferAllocation*> non_constant_buffers;
1632 absl::c_copy_if(buffers_needed, std::back_inserter(non_constant_buffers),
1633 [](const BufferAllocation* allocation) {
1634 return !allocation->is_constant();
1635 });
1636
1637 absl::c_sort(non_constant_buffers,
1638 [](const BufferAllocation* a, const BufferAllocation* b) {
1639 return a->index() < b->index();
1640 });
1641
1642 llvm::Function* kernel = BuildKernelPrototype(*inst, non_constant_buffers);
1643
1644 // Build a map from a BufferAllocation to the corresponding argument in our
1645 // kernel.
1646 std::unordered_map<const BufferAllocation*, llvm::Value*> kernel_args;
1647 {
1648 auto arg_it = kernel->arg_begin();
1649 auto buffers_it = non_constant_buffers.begin();
1650 for (; arg_it != kernel->arg_end(); ++arg_it, ++buffers_it) {
1651 kernel_args[*buffers_it] = arg_it;
1652 }
1653 }
1654
1655 // For each buffer our kernel might want to touch, bind it to a value derived
1656 // from our kernel args.
1657 for (const auto& kv : hlo_slices) {
1658 const HloInstruction* instr = kv.first.first;
1659 const ShapeIndex& index = kv.first.second;
1660 const BufferAllocation::Slice& slice = kv.second.first;
1661 const ShapeIndex& gte_index = kv.second.second;
1662
1663 VLOG(3) << "Buffer for " << instr->ToString() << " at " << index.ToString()
1664 << " is found in slice " << slice.ToString() << " at GTE index "
1665 << gte_index.ToString();
1666
1667 llvm::Value* loc;
1668 if (slice.allocation()->is_constant()) {
1669 loc = ir_emitter_context_->llvm_module()->getGlobalVariable(
1670 llvm_ir::ConstantBufferAllocationToGlobalName(*slice.allocation()));
1671 CHECK_NE(loc, nullptr);
1672 } else {
1673 loc = InBoundsGEP(kernel_args.at(slice.allocation()),
1674 {b_.getInt64(slice.offset())});
1675 }
1676
1677 // If gte_index is nonempty, we have to dereference `loc` to get to the
1678 // value we're ultimately interested in.
1679 llvm::Type* int8_double_pointer =
1680 llvm::PointerType::get(b_.getInt8PtrTy(), /*AddressSpace=*/0);
1681 for (int64 idx : gte_index) {
1682 loc = BitCast(loc, int8_double_pointer);
1683 loc = Load(InBoundsGEP(loc, {b_.getInt64(idx)}));
1684 }
1685
1686 bindings_.BindHloToIrValue(*instr, loc, index);
1687 }
1688
1689 // Bind the temp buffer so that nested subcomputations can find it if they
1690 // need it.
1691 if (temp_buffer.has_value()) {
1692 bindings_.SetTempBufferBase(kernel_args.at(*temp_buffer));
1693 } else {
1694 bindings_.SetTempBufferBase(
1695 llvm::ConstantPointerNull::get(b_.getInt8PtrTy()));
1696 }
1697
1698 return absl::make_unique<KernelThunk>(
1699 non_constant_buffers, kernel->getName(),
1700 implements_whole_instruction ? inst : nullptr, unroll_factor);
1701 }
1702
1703 std::unique_ptr<Thunk> IrEmitterUnnested::BuildHostToDeviceCopyThunk(
1704 const HloInstruction* inst) {
1705 const HloInstruction* operand = inst->operand(0);
1706 CHECK_EQ(HloOpcode::kConstant, operand->opcode());
1707 return absl::make_unique<HostToDeviceCopyThunk>(
1708 /*source_address=*/operand->literal().untyped_data(),
1709 /*destination_buffer=*/GetAllocationSlice(*inst),
1710 /*mem_size=*/
1711 llvm_ir::ByteSizeOf(operand->shape(),
1712 ir_emitter_context_->llvm_module()->getDataLayout()),
1713 inst);
1714 }
1715
1716 std::unique_ptr<Thunk> IrEmitterUnnested::BuildDeviceToDeviceCopyThunk(
1717 const HloInstruction* inst) {
1718 const HloInstruction* operand = inst->operand(0);
1719 return absl::make_unique<DeviceToDeviceCopyThunk>(
1720 /*source_address=*/GetAllocationSlice(*operand),
1721 /*destination_buffer=*/GetAllocationSlice(*inst),
1722 /*mem_size=*/
1723 llvm_ir::ByteSizeOf(operand->shape(),
1724 ir_emitter_context_->llvm_module()->getDataLayout()),
1725 inst);
1726 }
1727
1728 std::unique_ptr<Thunk> IrEmitterUnnested::BuildInfeedThunk(
1729 const HloInstruction* inst) {
1730 CHECK_EQ(HloOpcode::kInfeed, inst->opcode());
1731
1732 ShapeTree<BufferAllocation::Slice> slices(inst->shape());
1733 slices.ForEachMutableElement(
1734 [&](const ShapeIndex& index, BufferAllocation::Slice* slice) {
1735 *slice = ir_emitter_context_->buffer_assignment()
1736 .GetUniqueSlice(inst, index)
1737 .ConsumeValueOrDie();
1738 });
1739 return absl::make_unique<InfeedThunk>(slices, inst);
1740 }
1741
1742 std::unique_ptr<Thunk> IrEmitterUnnested::BuildOutfeedThunk(
1743 const HloInstruction* inst) {
1744 CHECK_EQ(HloOpcode::kOutfeed, inst->opcode());
1745
1746 ShapeTree<BufferAllocation::Slice> slices(inst->operand(0)->shape());
1747 slices.ForEachMutableElement(
1748 [&](const ShapeIndex& index, BufferAllocation::Slice* slice) {
1749 auto status_or_slice =
1750 ir_emitter_context_->buffer_assignment().GetUniqueSlice(
1751 inst->operand(0), index);
1752 if (status_or_slice.ok()) {
1753 *slice = status_or_slice.ConsumeValueOrDie();
1754 }
1755 });
1756 return absl::make_unique<OutfeedThunk>(std::move(slices), inst);
1757 }
1758
1759 namespace {
1760 double GetScalarConstantAsDouble(const Literal& literal) {
1761 switch (literal.shape().element_type()) {
1762 case F16:
1763 return static_cast<double>(literal.Get<Eigen::half>({}));
1764 case F32:
1765 return literal.Get<float>({});
1766 case F64:
1767 return literal.Get<double>({});
1768 default:
1769 LOG(FATAL) << "Unsupported type.";
1770 }
1771 }
1772 } // namespace
1773
1774 std::unique_ptr<Thunk> IrEmitterUnnested::BuildGemmThunk(
1775 const HloInstruction* inst) {
1776 if (inst->opcode() == HloOpcode::kDot) {
1777 const HloInstruction* lhs = inst->operand(0);
1778 const HloInstruction* rhs = inst->operand(1);
1779 return absl::make_unique<GemmThunk>(
1780 GetAllocationSlice(*lhs), // The buffer assigned to LHS.
1781 GetAllocationSlice(*rhs), // The buffer assigned to RHS.
1782 GetAllocationSlice(*inst), // The output buffer.
1783 lhs->shape(), // The shape of LHS.
1784 rhs->shape(), // The shape of RHS.
1785 inst->shape(), // The shape of the output.
1786 1.0, // alpha.
1787 0.0, // beta.
1788 inst, /*implements_whole_instruction=*/true);
1789 }
1790
1791 if (inst->opcode() == HloOpcode::kFusion) {
1792 CHECK_EQ(inst->fusion_kind(), HloInstruction::FusionKind::kOutput);
1793 const HloInstruction* output_fused_op = inst->fused_expression_root();
1794
1795 double alpha_value = 1.0;
1796 const HloInstruction* bias = nullptr;
1797 const HloInstruction* dot = output_fused_op->operand(0);
1798 if (output_fused_op->opcode() == HloOpcode::kMultiply) {
1799 const HloInstruction* alpha = output_fused_op->operand(1);
1800 if (dot->opcode() != HloOpcode::kDot) {
1801 std::swap(dot, alpha);
1802 }
1803 if (alpha->opcode() == HloOpcode::kBroadcast) {
1804 alpha = alpha->operand(0);
1805 }
1806 if (alpha->opcode() == HloOpcode::kParameter) {
1807 alpha = inst->operand(alpha->parameter_number());
1808 }
1809 // TODO(b/74185543): Remove the following if block once we support fusion
1810 // with a non-constant as well. Then we will just always use the constant
1811 // on the device.
1812 if (alpha->opcode() == HloOpcode::kCopy) {
1813 alpha = alpha->operand(0);
1814 }
1815 alpha_value = GetScalarConstantAsDouble(alpha->literal());
1816 } else {
1817 // Fused bias add.
1818 CHECK_EQ(output_fused_op->opcode(), HloOpcode::kAdd);
1819 bias = output_fused_op->operand(1);
1820 if (dot->opcode() != HloOpcode::kDot) {
1821 std::swap(dot, bias);
1822 }
1823 bias = inst->operand(bias->parameter_number());
1824 }
1825
1826 DCHECK(dot->opcode() == HloOpcode::kDot);
1827 const HloInstruction* lhs_parameter = StripTranspose(*dot->operand(0));
1828 const HloInstruction* rhs_parameter = StripTranspose(*dot->operand(1));
1829 DCHECK(lhs_parameter->opcode() == HloOpcode::kParameter &&
1830 rhs_parameter->opcode() == HloOpcode::kParameter);
1831 const HloInstruction* lhs =
1832 inst->operand(lhs_parameter->parameter_number());
1833 const HloInstruction* rhs =
1834 inst->operand(rhs_parameter->parameter_number());
1835
1836 // The bias is passed inside the output buffer. If the two buffers share an
1837 // allocation, we can use the output buffer directly; otherwise we copy the
1838 // bias values into the output buffer first.
1839 if (bias != nullptr &&
1840 GetAllocationSlice(*bias) != GetAllocationSlice(*inst)) {
1841 std::vector<std::unique_ptr<Thunk>> thunks;
1842 thunks.push_back(absl::make_unique<DeviceToDeviceCopyThunk>(
1843 /*source_address=*/GetAllocationSlice(*bias),
1844 /*destination_buffer=*/GetAllocationSlice(*inst),
1845 /*mem_size=*/ShapeUtil::ByteSizeOf(inst->shape()), nullptr));
1846 thunks.push_back(absl::make_unique<GemmThunk>(
1847 GetAllocationSlice(*lhs), // The buffer assigned to LHS.
1848 GetAllocationSlice(*rhs), // The buffer assigned to RHS.
1849 GetAllocationSlice(*inst), // The output buffer.
1850 lhs->shape(), // The shape of LHS.
1851 rhs->shape(), // The shape of RHS.
1852 inst->shape(), // The shape of the output.
1853 alpha_value, // alpha.
1854 1.0, // beta.
1855 inst, /*implements_whole_instruction=*/false));
1856 return absl::make_unique<SequentialThunk>(std::move(thunks), inst);
1857 }
1858 return absl::make_unique<GemmThunk>(
1859 GetAllocationSlice(*lhs), // The buffer assigned to LHS.
1860 GetAllocationSlice(*rhs), // The buffer assigned to RHS.
1861 GetAllocationSlice(*inst), // The output buffer.
1862 lhs->shape(), // The shape of LHS.
1863 rhs->shape(), // The shape of RHS.
1864 inst->shape(), // The shape of the output.
1865 alpha_value, // alpha.
1866 bias != nullptr ? 1.0 : 0.0, // beta.
1867 inst, /*implements_whole_instruction=*/true);
1868 }
1869
1870 LOG(FATAL) << "Cannot build a GemmThunk for " << inst->ToString();
1871 }
1872
1873 std::unique_ptr<Thunk> IrEmitterUnnested::BuildFftThunk(
1874 const HloInstruction* inst) {
1875 const HloInstruction* operand = inst->operand(0);
1876 return absl::make_unique<FftThunk>(
1877 inst->fft_type(), inst->fft_length(),
1878 /*input_buffer=*/GetAllocationSlice(*operand),
1879 /*output_buffer=*/GetAllocationSlice(*inst),
1880 /*input_shape=*/operand->shape(),
1881 /*output_shape=*/inst->shape(), inst);
1882 }
1883
1884 std::unique_ptr<Thunk> IrEmitterUnnested::BuildTriangularSolveThunk(
1885 const HloInstruction* inst) {
1886 const HloInstruction* a = inst->operand(0);
1887 const HloInstruction* b = inst->operand(1);
1888 int64 m = b->shape().dimensions(b->shape().rank() - 2);
1889 int64 n = b->shape().dimensions(b->shape().rank() - 1);
1890 int64 batch_size = std::accumulate(
1891 b->shape().dimensions().begin(), b->shape().dimensions().end() - 2,
1892 int64{1}, [](int64 a, int64 b) { return a * b; });
1893 int64 elem_size =
1894 ShapeUtil::ByteSizeOfPrimitiveType(inst->shape().element_type());
1895 int64 a_batch_stride = inst->triangular_solve_options().left_side()
1896 ? m * m * elem_size
1897 : n * n * elem_size;
1898 int64 b_batch_stride = m * n * elem_size;
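  // Worked example (illustrative shapes only): for b of shape f32[3,4,128,64]
  // with left_side() == true, we get m = 128, n = 64, batch_size = 3 * 4 = 12,
  // elem_size = 4, a_batch_stride = 128 * 128 * 4 bytes and
  // b_batch_stride = 128 * 64 * 4 bytes.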
1899 return absl::make_unique<TriangularSolveThunk>(
1900 inst->triangular_solve_options(),
1901 /*a_input_buffer=*/GetAllocationSlice(*a),
1902 /*b_input_buffer=*/GetAllocationSlice(*inst),
1903 inst->shape().element_type(), batch_size, m, n, a_batch_stride,
1904 b_batch_stride, inst);
1905 }
1906
1907 StatusOr<std::unique_ptr<Thunk>> IrEmitterUnnested::BuildInitializerThunk(
1908 HloInstruction* hlo, const ShapeIndex& index) {
1909 bool fused = HloOpcode::kFusion == hlo->opcode();
1910 HloInstruction* inst = fused ? hlo->fused_expression_root() : hlo;
1911 HloInstruction* init_value_operand = [&] {
1912 switch (inst->opcode()) {
1913 case HloOpcode::kSelectAndScatter:
1914 return inst->mutable_operand(2);
1915 case HloOpcode::kReduce:
1916 return inst->mutable_operand(1);
1917 case HloOpcode::kTuple:
1918 CHECK(hlo->IsMultiOutputFusion())
1919 << ": " << hlo->ToString() << " is not a multi-output fusion.";
1920 CHECK(inst->operand(index.back())->opcode() == HloOpcode::kReduce)
1921 << ": Found '" << inst->operand(index.back())->opcode() << "' in "
1922 << inst->ToString() << " but expected 'reduce'.";
1923 // For multi-output fusion look through the tuple.
1924 return inst->mutable_operand(index.back())->mutable_operand(1);
1925 default:
1926 LOG(FATAL) << "Opcode " << inst->opcode()
1927 << " should not need an initializer.";
1928 }
1929 }();
1930
1931 const HloInstruction* init_value = init_value_operand;
1932 if (fused && init_value->opcode() == HloOpcode::kParameter) {
1933 init_value = hlo->operand(init_value->parameter_number());
1934 }
1935
1936 // Initializer thunks don't implement a whole instruction, and we want to
1937 // profile the whole instruction instead of the individual thunks it consists
1938 // of. Therefore we pass nullptr as the HloInstruction* to the thunks we
1939 // generate below.
1940 //
1941 // In the common case, the initializer is a constant. In this case, emit a
1942 // device-memset call if we can. Currently StreamExecutor only supports
1943 // zeroing and 32-bit memsets.
1944 if (init_value->IsConstant()) {
1945 CHECK(ShapeUtil::IsScalar(init_value->shape()));
1946 int64 num_bytes = ShapeUtil::ByteSizeOfElements(init_value->shape());
1947 const auto& literal = init_value->literal();
1948
1949 // Are all the bytes of this scalar equal to 0? If so, we can create a
1950 // MemzeroThunk.
1951 absl::Span<const uint8> literal_bytes(
1952 reinterpret_cast<const uint8*>(literal.untyped_data()), num_bytes);
1953 if (absl::c_all_of(literal_bytes, [](uint8 byte) { return byte == 0; })) {
1954 return {absl::make_unique<MemzeroThunk>(GetAllocationSlice(*hlo, index),
1955 nullptr)};
1956 }
1957
1958 // If the literal is 8 or 16 bits wide, we can emit a 32-bit memset by
1959 // repeating the literal 4 or 2 times, so long as the destination buffer is
1960 // an even multiple of 32 bits long.
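    // Worked example (illustrative value only): a uint8 literal 0xAB yields
    //
    //   pattern16 = 0xABAB  and  pattern32 = 0xABABABAB,
    //
    // which the 32-bit memset replicates across the destination buffer.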
1961 const Shape& output_shape = ShapeUtil::GetSubshape(hlo->shape(), index);
1962 if ((num_bytes == 1 || num_bytes == 2) &&
1963 ShapeUtil::ByteSizeOf(output_shape) % 4 == 0) {
1964 uint16 pattern16;
1965 if (num_bytes == 1) {
1966 uint8 b = literal_bytes.front();
1967 pattern16 = uint16{b} | (uint16{b} << 8);
1968 } else {
1969 memcpy(&pattern16, literal_bytes.data(), sizeof(pattern16));
1970 }
1971 uint32 pattern32 = uint32{pattern16} | (uint32{pattern16} << 16);
1972 return {absl::make_unique<Memset32BitValueThunk>(
1973 pattern32, GetAllocationSlice(*hlo, index), nullptr)};
1974 }
1975
1976 // If the literal is an even multiple of 32 bits wide, we can emit a 32-bit
1977 // memset so long as all 32-bit words of the scalar are equal to each other.
1978 if (num_bytes >= 4 && num_bytes % 4 == 0 &&
1979 memcmp(literal_bytes.data(), literal_bytes.data() + 4,
1980 literal_bytes.size() - 4) == 0) {
1981 uint32 word;
1982 memcpy(&word, literal_bytes.data(), sizeof(word));
1983 return {absl::make_unique<Memset32BitValueThunk>(
1984 word, GetAllocationSlice(*hlo, index), nullptr)};
1985 }
1986 }
1987
1988 // Otherwise fall back to our slow initializer code.
1989 std::unique_ptr<KernelThunk> kernel_thunk =
1990 BuildKernelThunk(hlo, /*implements_whole_instruction=*/false);
1991 LaunchDimensions launch_dimensions =
1992 CalculateLaunchDimensions(ShapeUtil::GetSubshape(hlo->shape(), index),
1993 ir_emitter_context_->device_description());
1994 UpdateLaunchDimensions(launch_dimensions, kernel_thunk.get(),
1995 ir_emitter_context_->llvm_module());
1996
1997 if (fused) {
1998 // If init_value was fused into this reduce we have to generate it first.
1999 GpuElementalIrEmitter elemental_emitter(hlo_module_config_,
2000 ir_emitter_context_->llvm_module(),
2001 &b_, GetNestedComputer());
2002
2003 FusedIrEmitter fused_emitter(GetGeneratorForOperandIrArrays(hlo),
2004 &elemental_emitter);
2005 TF_RETURN_IF_ERROR(init_value_operand->Accept(&fused_emitter));
2006 TF_RETURN_IF_ERROR(
2007 ParallelLoopEmitter(fused_emitter.GetGenerator(init_value_operand),
2008 GetIrArray(*hlo, *hlo, index), launch_dimensions,
2009 &b_)
2010 .EmitLoop(IrName(hlo)));
2011 } else {
2012 // In the unfused case the element is already there, just read from it.
2013 TF_RETURN_IF_ERROR(ParallelLoopEmitter(
2014 [=](const IrArray::Index& index) {
2015 return GetIrArray(*init_value, *hlo)
2016 .EmitReadArrayElement(index, &b_);
2017 },
2018 GetIrArray(*hlo, *hlo, index), launch_dimensions,
2019 &b_)
2020 .EmitLoop(IrName(hlo)));
2021 }
2022
2023 // Clean up state left behind by emitting the loop above. (This is normally
2024 // done in IrEmitterUnnested::Postprocess().)
2025 bindings_.UnbindAllLocalIrValues();
2026
2027 // Convert unique_ptr<KernelThunk> to StatusOr<unique_ptr<Thunk>>.
2028 return {std::move(kernel_thunk)};
2029 }
2030
2031 namespace {
2032
2033 // Checks that the buffers corresponding to the given two HLOs share the same
2034 // allocation.
2035 Status CheckHloBuffersShareAllocation(
2036 const HloInstruction* a, const HloInstruction* b, const ShapeIndex& index,
2037 const BufferAssignment& buffer_assignment) {
2038 const BufferAllocation::Slice slice_a =
2039 buffer_assignment.GetUniqueSlice(a, index).ConsumeValueOrDie();
2040 const BufferAllocation::Slice slice_b =
2041 buffer_assignment.GetUniqueSlice(b, index).ConsumeValueOrDie();
2042 if (slice_a != slice_b) {
2043 return InternalError(
2044 "instruction %s %s does not share allocation with instruction %s %s",
2045 a->ToString(), slice_a.ToString(), b->ToString(), slice_b.ToString());
2046 }
2047 return Status::OK();
2048 }
2049
2050 // Checks that all buffers used during while loop iteration share the same
2051 // buffer allocation. This includes buffers for while result, while init
2052 // operand, condition parameter, body parameter and body result.
2053 // Returns OK on success, error status otherwise.
2054 Status CheckWhileBuffersShareAllocation(
2055 const HloInstruction* xla_while,
2056 const BufferAssignment& buffer_assignment) {
2057 return ShapeUtil::ForEachSubshapeWithStatus(
2058 xla_while->shape(),
2059 [&](const Shape& /*subshape*/, const ShapeIndex& index) -> Status {
2060 const HloInstruction* condition_parameter =
2061 xla_while->while_condition()->parameter_instruction(0);
2062 const HloComputation* body = xla_while->while_body();
2063 const HloInstruction* body_parameter = body->parameter_instruction(0);
2064 const HloInstruction* body_result = body->root_instruction();
2065 TF_RETURN_IF_ERROR(CheckHloBuffersShareAllocation(
2066 xla_while, xla_while->operand(0), index, buffer_assignment));
2067 TF_RETURN_IF_ERROR(CheckHloBuffersShareAllocation(
2068 xla_while, condition_parameter, index, buffer_assignment));
2069 TF_RETURN_IF_ERROR(CheckHloBuffersShareAllocation(
2070 xla_while, body_parameter, index, buffer_assignment));
2071 TF_RETURN_IF_ERROR(CheckHloBuffersShareAllocation(
2072 xla_while, body_result, index, buffer_assignment));
2073 return Status::OK();
2074 });
2075 }
2076
2077 // Checks that the buffers used in a conditional instruction are shared with the
2078 // operands and result as follows:
2079 // * The result buffer of the conditional should share the allocation with the
2080 // result buffers of each branch computation.
2081 // * The buffer of operand b+1 should share the allocation with the buffer of
2082 // the parameter 0 instruction of the b'th computation.
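// For example (hypothetical instruction, for exposition only): for a
// predicated conditional
//
//   c = conditional(pred, true_arg, false_arg),
//       branches {true_comp, false_comp}
//
// the buffer of 'true_arg' (operand 1) must share its allocation with
// parameter 0 of 'true_comp' (branch 0), the buffer of 'false_arg'
// (operand 2) with parameter 0 of 'false_comp' (branch 1), and 'c' itself
// shares its allocation with the roots of both branch computations.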
2083 Status CheckConditionalBuffersShareAllocation(
2084 const HloInstruction* conditional,
2085 const BufferAssignment& buffer_assignment) {
2086 TF_RETURN_IF_ERROR(ShapeUtil::ForEachSubshapeWithStatus(
2087 conditional->shape(),
2088 [&](const Shape& /*subshape*/, const ShapeIndex& index) -> Status {
2089 for (auto branch_computation : conditional->branch_computations()) {
2090 TF_RETURN_IF_ERROR(CheckHloBuffersShareAllocation(
2091 conditional, branch_computation->root_instruction(), index,
2092 buffer_assignment));
2093 }
2094 return Status::OK();
2095 }));
2096 for (int j = 0; j < conditional->branch_count(); ++j) {
2097 TF_RETURN_IF_ERROR(ShapeUtil::ForEachSubshapeWithStatus(
2098 conditional->operand(j + 1)->shape(),
2099 [&](const Shape& /*subshape*/, const ShapeIndex& index) -> Status {
2100 return CheckHloBuffersShareAllocation(
2101 conditional->operand(j + 1),
2102 conditional->branch_computation(j)->parameter_instruction(0),
2103 index, buffer_assignment);
2104 }));
2105 }
2106 return Status::OK();
2107 }
2108
2109 } // namespace
2110
2111 std::unique_ptr<Thunk> IrEmitterUnnested::BuildWhileThunk(
2112 const HloInstruction* hlo) {
2113 // Check that all while-related buffers share an allocation.
2114 TF_CHECK_OK(CheckWhileBuffersShareAllocation(
2115 hlo, ir_emitter_context_->buffer_assignment()));
2116
2117 // Generate thunk sequence for while 'condition'.
2118 HloComputation* condition = hlo->while_condition();
2119 IrEmitterUnnested ir_emitter_condition(hlo_module_config_, condition,
2120 ir_emitter_context_);
2121 TF_CHECK_OK(condition->Accept(&ir_emitter_condition));
2122
2123 // Generate thunk sequence for while 'body'.
2124 HloComputation* body = hlo->while_body();
2125 IrEmitterUnnested ir_emitter_body(hlo_module_config_, body,
2126 ir_emitter_context_);
2127 TF_CHECK_OK(body->Accept(&ir_emitter_body));
2128
2129 return absl::make_unique<WhileThunk>(
2130 GetAllocationSlice(*condition->root_instruction()), // cond result
2131 ir_emitter_condition.ConsumeThunkSequence(),
2132 ir_emitter_body.ConsumeThunkSequence(), hlo);
2133 }
2134
2135 std::unique_ptr<Thunk> IrEmitterUnnested::BuildForThunk(
2136 const HloInstruction* hlo, const int64 loop_limit) {
2137 // Check that all while-related buffers share an allocation.
2138 TF_CHECK_OK(CheckWhileBuffersShareAllocation(
2139 hlo, ir_emitter_context_->buffer_assignment()));
2140
2141 // Generate thunk sequence for while 'body' (will be used as a For loop body).
2142 HloComputation* body = hlo->while_body();
2143 IrEmitterUnnested ir_emitter_body(hlo_module_config_, body,
2144 ir_emitter_context_);
2145 TF_CHECK_OK(body->Accept(&ir_emitter_body));
2146
2147 return absl::make_unique<ForThunk>(
2148 loop_limit, ir_emitter_body.ConsumeThunkSequence(), hlo);
2149 }
2150
2151 std::unique_ptr<Thunk> IrEmitterUnnested::BuildConditionalThunk(
2152 const HloInstruction* hlo) {
2153 // Check that the buffers used in conditional are shared with the operands and
2154 // result appropriately.
2155 TF_CHECK_OK(CheckConditionalBuffersShareAllocation(
2156 hlo, ir_emitter_context_->buffer_assignment()));
2157
2158 std::vector<BufferAllocation::Slice> branch_operands;
2159 std::vector<ThunkSequence> branch_thunks;
2160 for (int j = 0; j < hlo->branch_count(); ++j) {
2161 branch_operands.emplace_back(GetAllocationSlice(*hlo->operand(j + 1)));
2162 HloComputation* branch_computation = hlo->branch_computation(j);
2163 IrEmitterUnnested ir_emitter(hlo_module_config_, branch_computation,
2164 ir_emitter_context_);
2165 TF_CHECK_OK(branch_computation->Accept(&ir_emitter));
2166 branch_thunks.push_back(std::move(*ir_emitter.ConsumeThunkSequence()));
2167 }
2168
2169 return absl::make_unique<ConditionalThunk>(
2170 GetAllocationSlice(*hlo->operand(0)), branch_operands,
2171 std::move(branch_thunks), hlo);
2172 }
2173
2174 Status IrEmitterUnnested::EmitTargetElementLoopInThunk(
2175 const HloInstruction& hlo,
2176 const llvm_ir::ElementGenerator& element_generator, KernelThunk* thunk) {
2177 int unroll_factor = thunk->unroll_factor();
2178 VLOG(3) << bindings_.ToString();
2179
2180 const Shape& element_shape = hlo.IsMultiOutputFusion()
2181 ? ShapeUtil::GetSubshape(hlo.shape(), {0})
2182 : hlo.shape();
2183 VLOG(3) << "EmitTargetElementLoopInThunk "
2184 << ShapeUtil::HumanStringWithLayout(hlo.shape())
2185 << " for unroll_factor " << unroll_factor;
2186 LaunchDimensions launch_dimensions = CalculateLaunchDimensions(
2187 element_shape, ir_emitter_context_->device_description(), unroll_factor);
2188 UpdateLaunchDimensions(launch_dimensions, thunk,
2189 ir_emitter_context_->llvm_module());
2190 if (!hlo.IsMultiOutputFusion()) {
2191 return ParallelLoopEmitter(element_generator, GetIrArray(hlo, hlo),
2192 launch_dimensions, &b_, unroll_factor)
2193 .EmitLoop(
2194 IrName(&hlo),
2195 GetIndexTypeForKernel(&hlo, launch_dimensions.launch_bound(), &b_));
2196 }
2197
2198 // Emit the tuple pointers in one thread. We could do this at any point in
2199 // the kernel, but we do it at the beginning in the hopes of reducing register
2200 // pressure, since we touch threadIdx.x and blockIdx.x at the beginning of the
2201 // kernel *anyway*.
2202 std::vector<IrArray> output_arrays = ConstructIrArrayForOutputs(hlo);
2203 KernelSupportLibrary{&b_}.If("emit_mof_tuple", IsBlock0Thread0(&b_), [&] {
2204 llvm_ir::EmitTuple(GetIrArray(hlo, hlo), output_arrays, &b_);
2205 });
2206
2207 // For multioutput fusion, we need to emit each operand and the root.
2208 TF_RETURN_IF_ERROR(
2209 ParallelLoopEmitter(element_generator, output_arrays, launch_dimensions,
2210 &b_, unroll_factor)
2211 .EmitLoop(IrName(&hlo),
2212 GetIndexTypeForKernel(
2213 &hlo, launch_dimensions.launch_bound(), &b_)));
2214
2215 b_.SetInsertPoint(b_.GetInsertBlock()->getTerminator());
2216 return Status::OK();
2217 }
2218
2219 namespace {
2220
2221 // Returns true if the fusion contains any instruction that is likely to be
2222 // translated to complex LLVM IR (such as loops) and may prevent vectorization.
2223 bool MayPreventVectorization(const HloInstruction& fusion_hlo) {
2224 CHECK_EQ(fusion_hlo.opcode(), HloOpcode::kFusion);
2225 return absl::c_any_of(
2226 fusion_hlo.fused_instructions_computation()->instructions(),
2227 [&](const HloInstruction* instr) {
2228 switch (instr->opcode()) {
2229 case HloOpcode::kReduce:
2230 case HloOpcode::kReduceWindow:
2231 case HloOpcode::kSort:
2232 case HloOpcode::kDot:
2233 return true;
2234 default:
2235 return false;
2236 }
2237 });
2238 }
2239
2240 } // namespace
2241
2242 Status IrEmitterUnnested::EmitTargetElementLoop(
2243 const HloInstruction& hlo,
2244 const llvm_ir::ElementGenerator& element_generator) {
2245 int unroll_factor = 1;
2246 // Unfused elementwise operations are usually memory bound, so we unroll them.
2247 if (hlo.IsElementwise() ||
2248 (hlo.opcode() == HloOpcode::kFusion && !MayPreventVectorization(hlo))) {
2249 unroll_factor = ComputeMaxUnrollFactor(&hlo);
2250 }
2251
2252 std::unique_ptr<KernelThunk> kernel_thunk = BuildKernelThunk(
2253 &hlo, /*implements_whole_instruction=*/true, unroll_factor);
2254 Status emit_status =
2255 EmitTargetElementLoopInThunk(hlo, element_generator, kernel_thunk.get());
2256 thunk_sequence_->emplace_back(std::move(kernel_thunk));
2257
2258 return emit_status;
2259 }
2260
2261 std::vector<IrArray> IrEmitterUnnested::ConstructIrArrayForInputs(
2262 const HloInstruction& hlo) {
2263 std::vector<IrArray> param_arrays;
2264 param_arrays.reserve(hlo.operands().size());
2265 for (const HloInstruction* param : hlo.operands()) {
2266 param_arrays.push_back(GetIrArray(*param, hlo));
2267 }
2268 return param_arrays;
2269 }
2270
2271 int IrEmitterUnnested::ConstructInputReducedShapeAndCastInputIrArrayToShape(
2272 const HloInstruction& hlo, const std::vector<IrArray>& param_arrays,
2273 const std::vector<llvm::Value*>& param_buffers,
2274 absl::Span<const int64> reduced_output_dims,
2275 std::vector<Shape>* param_reduced_shapes,
2276 std::vector<IrArray>* param_in_reduced_shape_arrays) {
2277 int64 num_params = hlo.operands().size();
2278 param_in_reduced_shape_arrays->reserve(num_params);
2279 param_reduced_shapes->reserve(num_params);
2280 for (int64 id = 0; id < num_params; ++id) {
2281 if (param_buffers[id] == nullptr) {
2282 param_reduced_shapes->push_back(Shape());
2283 param_in_reduced_shape_arrays->push_back(IrArray());
2284 continue;
2285 }
2286 const HloInstruction* param = hlo.operand(id);
2287 param_reduced_shapes->push_back(ShapeUtil::MakeShapeWithDescendingLayout(
2288 param->shape().element_type(),
2289 Permute({0, 2, 1}, reduced_output_dims)));
2290 param_in_reduced_shape_arrays->push_back(
2291 param_arrays[id].CastToShape((*param_reduced_shapes)[id], &b_));
2292 }
2293 return num_params;
2294 }
2295
2296 namespace {
2297
2298 std::tuple<llvm::Value*, int64> GetStartOffsetAndStepForX(
2299 int64 tile_size_x, int64 num_threads_x,
2300 const KernelMappingScheme* mapping_scheme, llvm::IRBuilder<>* builder,
2301 llvm::Value* x, llvm::Type* index_ty) {
2302 llvm::Value* start_offset_x;
2303 int64 step_x;
2304 if (mapping_scheme->DilatedX()) {
2305 start_offset_x = x;
2306 step_x = num_threads_x;
2307 } else {
2308 start_offset_x = builder->CreateMul(
2309 x, llvm::ConstantInt::get(index_ty, tile_size_x / num_threads_x));
2310 step_x = 1;
2311 }
2312 return std::make_tuple(start_offset_x, step_x);
2313 }
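// Illustrative example (parameter values chosen for exposition): with
// tile_size_x = 64 and num_threads_x = 32, each thread x handles two elements
// of the tile:
//
//   DilatedX() == true :  start_offset_x = x,      step_x = 32
//                         -> elements x and x + 32 (strided accesses)
//   DilatedX() == false:  start_offset_x = 2 * x,  step_x = 1
//                         -> elements 2 * x and 2 * x + 1 (contiguous accesses)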
2314
2315 void EmitFullElementalTile(const KernelMappingScheme* mapping_scheme,
2316 const IrArray::Index& tile_origin_index,
2317 const string& loop_name, KernelSupportLibrary* ksl,
2318 llvm::IRBuilder<>* builder, llvm::Value* y,
2319 llvm::Value* x, llvm::Type* index_ty,
2320 const EmitElementFunction& emit_elem_function) {
2321 int64 num_threads_x = mapping_scheme->GetNumberOfThreadsForDimensionX();
2322 int64 num_threads_y = mapping_scheme->GetNumberOfThreadsForDimensionY();
2323 int64 tile_size_x = mapping_scheme->GetTileSizeForDimensionX();
2324 int64 tile_size_y = mapping_scheme->GetTileSizeForDimensionY();
2325
2326 llvm::Value* start_offset_x;
2327 int64 step_x;
2328 std::tie(start_offset_x, step_x) = GetStartOffsetAndStepForX(
2329 tile_size_x, num_threads_x, mapping_scheme, builder, x, index_ty);
2330 IrArray::Index source_idx =
2331 tile_origin_index.AddOffsetToDim(y, KernelMappingScheme::DimY, builder)
2332 .AddOffsetToDim(start_offset_x, KernelMappingScheme::DimX, builder);
2333 ksl->For(loop_name + "_y", /*start=*/llvm::ConstantInt::get(index_ty, 0),
2334 /*end=*/llvm::ConstantInt::get(index_ty, tile_size_y),
2335 /*step=*/llvm::ConstantInt::get(index_ty, num_threads_y),
2336 [&](llvm::Value* y_indvar) {
2337 IrArray::Index source_idx_y = source_idx.AddOffsetToDim(
2338 y_indvar, KernelMappingScheme::DimY, builder);
2339 llvm::Value* y_loc = builder->CreateAdd(y_indvar, y);
2340
2341 for (int64 j = 0; j < tile_size_x / num_threads_x; j++) {
2342 IrArray::Index source_idx_y_x = source_idx_y.AddOffsetToDim(
2343 llvm::ConstantInt::get(index_ty, j * step_x),
2344 KernelMappingScheme::DimX, builder);
2345 llvm::Value* x_loc = builder->CreateAdd(
2346 llvm::ConstantInt::get(index_ty, j * step_x),
2347 start_offset_x);
2348 emit_elem_function(source_idx_y_x, y_loc, x_loc, j);
2349 }
2350 });
2351 }
2352
2353 void EmitPartialElementalTile(const KernelMappingScheme* mapping_scheme,
2354 const IrArray::Index& tile_origin_index,
2355 const string& loop_name,
2356 KernelSupportLibrary* ksl,
2357 llvm::IRBuilder<>* builder, llvm::Value* y,
2358 llvm::Value* x, llvm::Value* tile_height,
2359 llvm::Value* tile_width, llvm::Type* index_ty,
2360 const EmitElementFunction& emit_elem_function) {
2361 int64 num_threads_x = mapping_scheme->GetNumberOfThreadsForDimensionX();
2362 int64 num_threads_y = mapping_scheme->GetNumberOfThreadsForDimensionY();
2363 int64 tile_size_x = mapping_scheme->GetTileSizeForDimensionX();
2364
2365 llvm::Value* start_offset_x;
2366 int64 step_x;
2367 std::tie(start_offset_x, step_x) = GetStartOffsetAndStepForX(
2368 tile_size_x, num_threads_x, mapping_scheme, builder, x, index_ty);
2369 IrArray::Index source_idx =
2370 tile_origin_index.AddOffsetToDim(y, KernelMappingScheme::DimY, builder)
2371 .AddOffsetToDim(start_offset_x, KernelMappingScheme::DimX, builder);
2372 for (int64 j = 0; j < tile_size_x / num_threads_x; j++) {
2373 IrArray::Index source_idx_x =
2374 source_idx.AddOffsetToDim(llvm::ConstantInt::get(index_ty, j * step_x),
2375 KernelMappingScheme::DimX, builder);
2376 llvm::Value* x_loc = builder->CreateAdd(
2377 llvm::ConstantInt::get(index_ty, j * step_x), start_offset_x);
2378
2379 ksl->If(
2380 loop_name + "_x_in_tile", builder->CreateICmpULT(x_loc, tile_width),
2381 [&] {
2382 // tile_height_bound =
2383 // ceil(tile_height / num_threads_y) * num_threads_y
2384 llvm::Value* ceiling_of_ratio = builder->CreateUDiv(
2385 builder->CreateAdd(tile_height, llvm::ConstantInt::get(
2386 index_ty, num_threads_y - 1)),
2387 llvm::ConstantInt::get(index_ty, num_threads_y));
2388 llvm::Value* tile_height_bound = builder->CreateMul(
2389 ceiling_of_ratio,
2390 llvm::ConstantInt::get(index_ty, num_threads_y));
2391 ksl->For(
2392 loop_name, /*start=*/llvm::ConstantInt::get(index_ty, 0),
2393 /*end=*/tile_height_bound,
2394 /*step=*/llvm::ConstantInt::get(index_ty, num_threads_y),
2395 [&](llvm::Value* y_indvar) {
2396 llvm::Value* y_loc = builder->CreateAdd(y_indvar, y);
2397 ksl->If(loop_name + "_y_in_tile",
2398 builder->CreateICmpULT(y_loc, tile_height), [&] {
2399 emit_elem_function(
2400 source_idx_x.AddOffsetToDim(
2401 y_indvar, KernelMappingScheme::DimY, builder),
2402 y_loc, x_loc, j);
2403 });
2404 });
2405 });
2406 }
2407 }
2408
2409 // Emits code to process up to
2410 // (tile_size_x/num_threads_x * tile_size_y/num_threads_y) elements in a tile,
2411 // where `emit_elem_function` is the function that emits code to process one
2412 // element, `y` and `x` are the intra-tile coordinates of the first element
2413 // to process, and `index` is the index of the origin of the tile. Information
2414 // about tile_size_x/y and num_threads_x/y is stored in `mapping_scheme`. Emits
2415 // bounds checks to ensure that each processed element lies within the boundary
2416 // defined by `tile_width` and `tile_height`.
2417 void EmitTiledElementalCodeWithBoundsCheck(
2418 const KernelMappingScheme* mapping_scheme,
2419 const IrArray::Index& tile_origin_index, const string& loop_name,
2420 KernelSupportLibrary* ksl, llvm::IRBuilder<>* builder, llvm::Value* y,
2421 llvm::Value* x, llvm::Value* tile_height, llvm::Value* tile_width,
2422 const EmitElementFunction& emit_elem_function) {
2423 int64 tile_size_x = mapping_scheme->GetTileSizeForDimensionX();
2424 int64 tile_size_y = mapping_scheme->GetTileSizeForDimensionY();
2425 llvm::Type* index_ty = tile_width->getType();
2426
2427 ksl->If(
2428 loop_name + "_full_tile",
2429 builder->CreateAnd(
2430 builder->CreateICmpEQ(llvm::ConstantInt::get(index_ty, tile_size_x),
2431 tile_width),
2432 builder->CreateICmpEQ(llvm::ConstantInt::get(index_ty, tile_size_y),
2433 tile_height)),
2434 [&] {
2435 EmitFullElementalTile(mapping_scheme, tile_origin_index, loop_name, ksl,
2436 builder, y, x, index_ty, emit_elem_function);
2437 },
2438 [&] {
2439 EmitPartialElementalTile(mapping_scheme, tile_origin_index, loop_name,
2440 ksl, builder, y, x, tile_height, tile_width,
2441 index_ty, emit_elem_function);
2442 });
2443 }
2444 } // namespace
2445
2446 // Emits code to process a tensor element in a tile for the given kCopy HLO that
2447 // performs a 0-2-1 transpose.
2448 //
2449 // index: The index for the first output element in the normalized tensor. The
2450 // normalized tensor is the resulting tensor after collapsing contiguous
2451 // dimensions that play the same role in the transpose.
2452 // y_loc: The y coordinate within a tile.
2453 // x_loc: The x coordinate within a tile.
2454 // kernel_info: Other information to support the kernel code generation.
2455 void IrEmitterUnnested::EmitTileElementForCopy(
2456 HloInstruction* hlo, const llvm_ir::IrArray::Index& index,
2457 const KernelCodegenInfo* kernel_info, llvm::Value* y_loc,
2458 llvm::Value* x_loc, int64 /*x_iter_num*/) {
2459 llvm_ir::TiledParameterInfo* tiled_param_info =
2460 kernel_info->GetTiledParameterInfo();
2461 // TODO(jlebar): Add AA metadata to this load.
2462 llvm::Instruction* load_from_shmem_buffer =
2463 Load(GEP(tiled_param_info->GetBufferForParameter(0),
2464 {b_.getInt64(0), x_loc, y_loc}),
2465 "output_element");
2466 llvm_ir::IrArray output_array = GetIrArray(*hlo, *hlo);
2467 Shape output_reduced_shape = ShapeUtil::MakeShapeWithDescendingLayout(
2468 hlo->shape().element_type(),
2469 kernel_info->GetKernelMappingScheme()->GetDimensionsInElements());
2470 // When the output_reduced_shape is a 0-2-1 transpose of the input shape,
2471 // the 0-2-1 transpose is achieved through EmitWriteArrayElement.
2472 output_array.CastToShape(output_reduced_shape, &b_)
2473 .EmitWriteArrayElement(index, load_from_shmem_buffer, &b_);
2474 }
2475
2476 // Emits code to process a tensor element in a tile for the given kLoop fusion
2477 // HLO containing parameters that are 0-2-1 transpose of its outputs.
2478 //
2479 // index: The index for the first output element in the normalized tensor, that
2480 // is the resulting tensor after collapsing contiguous dimensions that play
2481 // the same role in the transpose.
2482 // kernel_info: Other information to support the kernel code generation.
2483 // y_loc: The y coordinate within a tile.
2484 // x_loc: The x coordinate within a tile.
2485 void IrEmitterUnnested::EmitTileElementForFusion(
2486 HloInstruction* hlo, const llvm_ir::IrArray::Index& index,
2487 const KernelCodegenInfo* kernel_info, llvm::Value* y_loc,
2488 llvm::Value* x_loc, int64 /*x_iter_num*/) {
2489 llvm_ir::TiledParameterInfo* tiled_param_info =
2490 kernel_info->GetTiledParameterInfo();
2491 std::vector<IrArray> output_arrays = ConstructIrArrayForOutputs(*hlo);
2492 GpuElementalIrEmitter elem_emitter(hlo_module_config_, module_, &b_,
2493 GetNestedComputer());
2494 FusedIrEmitter fused_emitter(GetGeneratorForOperandIrArrays(hlo),
2495 &elem_emitter);
2496 tiled_param_info->set_y(y_loc);
2497 tiled_param_info->set_x(x_loc);
2498 fused_emitter.SetTiledParameterInfo(tiled_param_info);
2499 TF_CHECK_OK(hlo->fused_expression_root()->Accept(&fused_emitter));
2500 IrArray::Index untiled_index =
2501 kernel_info->GetKernelMappingScheme()->GetUnnormalizedIndex(
2502 index, output_arrays[0].GetShape());
2503 const llvm_ir::ElementGenerator& output_generator =
2504 fused_emitter.GetRootGenerator();
2505 llvm::Value* output_value = output_generator(untiled_index).ValueOrDie();
2506 if (hlo->IsMultiOutputFusion()) {
2507 DCHECK(output_value->getType()->isStructTy());
2508 DCHECK_EQ(output_value->getType()->getStructNumElements(),
2509 output_arrays.size());
2510 for (int64 i = 0; i < output_arrays.size(); ++i) {
2511 output_arrays[i].EmitWriteArrayElement(
2512 untiled_index, ExtractValue(output_value, i), &b_);
2513 }
2514 } else {
2515 output_arrays[0].EmitWriteArrayElement(untiled_index, output_value, &b_);
2516 }
2517 }
2518
2519 // Information to support the code generation for a tiled reduction kernel.
2520 using AddressVector = InlinedVector<llvm::AllocaInst*, 1>;
2521 class ReductionCodegenInfo : public IrEmitterUnnested::KernelCodegenInfo {
2522 public:
2523 explicit ReductionCodegenInfo(llvm_ir::KernelMappingScheme* mapping_scheme,
2524 bool is_row_reduction)
2525 : KernelCodegenInfo(mapping_scheme),
2526 current_output_linear_index_address_(nullptr),
2527 current_output_inbound_address_(nullptr),
2528 is_row_reduction_(is_row_reduction) {}
2529
2530 void SetCurrentOutputLinearIndexAddress(llvm::AllocaInst* a) {
2531 current_output_linear_index_address_ = a;
2532 }
2533 // Returns the address of the memory that stores the linear index of the
2534 // current output. Since we are processing reduction to contiguous physical
2535 // dimensions, this linear index is the linear index of the 1D output array.
2536 llvm::AllocaInst* GetCurrentOutputLinearIndexAddress() const {
2537 return current_output_linear_index_address_;
2538 }
2539
2540 void SetCurrentOutputInboundAddress(llvm::AllocaInst* a) {
2541 current_output_inbound_address_ = a;
2542 }
2543
2544 llvm::AllocaInst* GetCurrentOutputInboundAddress() const {
2545 return current_output_inbound_address_;
2546 }
2547
2548 AddressVector* GetMutablePartialResultAddresses() {
2549 return &partial_result_addresses_;
2550 }
2551 absl::Span<llvm::AllocaInst* const> GetPartialResultAddresses() const {
2552 return partial_result_addresses_;
2553 }
2554
2555 AddressVector* GetMutableReductionInputAddresses() {
2556 return &reduction_input_addresses_;
2557 }
2558 absl::Span<llvm::AllocaInst* const> GetReductionInputAddresses() const {
2559 return reduction_input_addresses_;
2560 }
2561
2562 InlinedVector<HloComputation*, 1>* GetMutableReducers() { return &reducers_; }
2563 const InlinedVector<HloComputation*, 1>& GetReducers() const {
2564 return reducers_;
2565 }
2566 int GetNumberOfReduces() const { return reducers_.size(); }
2567
2568 InlinedVector<ShapeIndex, 1>* GetMutableReductionOutputShapeIndices() {
2569 return &reduction_output_shape_indices_;
2570 }
2571 absl::Span<const ShapeIndex> GetReductionOutputShapeIndices() const {
2572 return reduction_output_shape_indices_;
2573 }
2574
2575 bool IsRowReduction() const { return is_row_reduction_; }
2576
2577 // Return the dimension that is being reduced between DimX and DimY.
2578 int GetReducedDimensionEnum() const {
2579 return IsRowReduction() ? llvm_ir::KernelMappingScheme::DimX
2580 : llvm_ir::KernelMappingScheme::DimY;
2581 }
2582
2583 // Return the dimension that is being kept between DimX and DimY.
2584 int GetKeptDimensionEnum() const {
2585 return IsRowReduction() ? llvm_ir::KernelMappingScheme::DimY
2586 : llvm_ir::KernelMappingScheme::DimX;
2587 }
2588
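  // Returns the number of partial results each thread accumulates for the
  // reduction. Illustrative example (assumed values, not taken from this
  // code): a column reduction with tile_size_x = 128 and num_threads_x = 64
  // gives each thread 128 / 64 = 2 partial results, while a row reduction
  // always keeps a single partial result per thread.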
2589 int GetNumberOfPartialResults() const {
2590 if (IsRowReduction()) {
2591 return 1;
2592 }
2593 int64 num_thread = mapping_scheme_->GetNumberOfThreadsForDimensionX();
2594 int64 tile_size = mapping_scheme_->GetTileSizeForDimensionX();
2595 CHECK_EQ(tile_size % num_thread, 0);
2596 return tile_size / num_thread;
2597 }
2598
2599 int GetPartialResultIndex(int64 x_iter_num) const {
2600 if (IsRowReduction()) {
2601 return 0;
2602 }
2603 return x_iter_num;
2604 }
2605
2606 private:
2607 AddressVector partial_result_addresses_;
2608 AddressVector reduction_input_addresses_;
2609 InlinedVector<HloComputation*, 1> reducers_;
2610 InlinedVector<ShapeIndex, 1> reduction_output_shape_indices_;
2611 llvm::AllocaInst* current_output_linear_index_address_;
2612 llvm::AllocaInst* current_output_inbound_address_;
2613 bool is_row_reduction_;
2614 };
2615
2616 namespace {
2617 // Returns a group of instructions that generate the output for the kernel
2618 // containing the given HLO instruction. The result may be an unnested kReduce
2619 // HLO, a nested kReduce HLO of a kInput fusion, or the operands of the tuple
2620 // for a multi-output fusion.
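// Illustrative example (assumed HLO, not taken from this code): for a fusion
// whose root is tuple(reduce.0, exp.1) this returns the span {reduce.0, exp.1},
// while for an unnested kReduce instruction R it returns the one-element span
// {R}.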
2621 absl::Span<HloInstruction* const> GetOutputInstructions(
2622 HloInstruction* const* reduce_or_tuple_pointer) {
2623 HloOpcode opcode = (*reduce_or_tuple_pointer)->opcode();
2624 CHECK(opcode == HloOpcode::kReduce || opcode == HloOpcode::kTuple);
2625 return opcode == HloOpcode::kTuple
2626 ? (*reduce_or_tuple_pointer)->operands()
2627 : absl::Span<HloInstruction* const>(reduce_or_tuple_pointer, 1);
2628 }
2629
2630 const HloInstruction* GetFirstReduceInstruction(
2631 absl::Span<HloInstruction* const> instructions) {
2632 auto first_reduce_iter =
2633 absl::c_find_if(instructions, [](const HloInstruction* inst) {
2634 return inst->opcode() == HloOpcode::kReduce;
2635 });
2636 CHECK_NE(first_reduce_iter, instructions.end());
2637 return *first_reduce_iter;
2638 }
2639
2640 } // namespace
2641
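// Emits the prologue for a single reduction: records the reducer computation
// and the output shape index, allocates stack slots for the reduction input
// and the per-thread partial results, and initializes every partial result
// with the reduction's init value.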
2642 void IrEmitterUnnested::EmitPrologueForOneReduction(
2643 HloInstruction* unnested_hlo, HloInstruction* reduce_inst, int reduce_idx,
2644 KernelCodegenInfo* kernel_info, GpuElementalIrEmitter* elemental_emitter,
2645 ShapeIndex output_shape_index) {
2646 ReductionCodegenInfo* reduction_info =
2647 static_cast<ReductionCodegenInfo*>(kernel_info);
2648
2649 InlinedVector<HloComputation*, 1>* reducers =
2650 reduction_info->GetMutableReducers();
2651 CHECK(IsReductionToVector(*reduce_inst));
2652 reducers->push_back(reduce_inst->to_apply());
2653
2654 InlinedVector<ShapeIndex, 1>* reduction_output_shape_indices =
2655 reduction_info->GetMutableReductionOutputShapeIndices();
2656 reduction_output_shape_indices->push_back(std::move(output_shape_index));
2657
2658 AddressVector* reduction_input_addresses =
2659 reduction_info->GetMutableReductionInputAddresses();
2660 llvm::Type* element_type = llvm_ir::PrimitiveTypeToIrType(
2661 reduce_inst->shape().element_type(), ir_emitter_context_->llvm_module());
2662 llvm::AllocaInst* reduction_input_address = Alloca(element_type);
2663 reduction_input_addresses->push_back(reduction_input_address);
2664
2665 int num_partial_results = reduction_info->GetNumberOfPartialResults();
2666 AddressVector* partial_result_addresses =
2667 reduction_info->GetMutablePartialResultAddresses();
2668 llvm::AllocaInst* partial_result_address =
2669 Alloca(element_type, /*ArraySize=*/b_.getInt32(num_partial_results),
2670 "partial_reduction_result." + llvm::Twine(reduce_idx));
2671 partial_result_addresses->push_back(partial_result_address);
2672
2673 // Initialize the partial result with the initial value of the reduction.
2674 llvm::Value* init_ir_value;
2675 if (unnested_hlo->opcode() == HloOpcode::kFusion) {
2676 HloInstruction* init_value_operand = reduce_inst->mutable_operand(1);
2677 FusedIrEmitter fused_emitter(GetGeneratorForOperandIrArrays(unnested_hlo),
2678 elemental_emitter);
2679
2680 TF_CHECK_OK(init_value_operand->Accept(&fused_emitter));
2681 init_ir_value =
2682 fused_emitter
2683 .GetGenerator(init_value_operand)(IrArray::Index(b_.getInt32Ty()))
2684 .ValueOrDie();
2685 } else {
2686 const HloInstruction* init_value = unnested_hlo->operand(1);
2687 init_ir_value =
2688 GetIrArray(*init_value, *unnested_hlo)
2689 .EmitReadArrayElement(IrArray::Index(b_.getInt32Ty()), &b_);
2690 }
2691
2692 for (int i = 0; i < num_partial_results; ++i) {
2693 Store(init_ir_value, InBoundsGEP(partial_result_address, {b_.getInt32(i)}));
2694 }
2695 }
2696
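// Emits the prologue shared by all reductions in the kernel: walks the output
// instructions, invokes EmitPrologueForOneReduction for each kReduce, and
// allocates storage for the linear index of the current output (plus an
// "output inbound" flag for column reductions).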
2697 void IrEmitterUnnested::EmitPrologueForReduction(
2698 HloInstruction* unnested_hlo, KernelCodegenInfo* kernel_info) {
2699 VLOG(10) << "Emit prologue for reduction " << unnested_hlo->ToString();
2700 // Find the unnested kReduce or the tuple that contains a list of kReduce.
2701 HloInstruction* reduce_or_tuple = unnested_hlo->opcode() == HloOpcode::kFusion
2702 ? unnested_hlo->fused_expression_root()
2703 : unnested_hlo;
2704 absl::Span<HloInstruction* const> output_instructions =
2705 GetOutputInstructions(&reduce_or_tuple);
2706 ReductionCodegenInfo* reduction_info =
2707 static_cast<ReductionCodegenInfo*>(kernel_info);
2708 GpuElementalIrEmitter elemental_emitter(hlo_module_config_,
2709 ir_emitter_context_->llvm_module(),
2710 &b_, GetNestedComputer());
2711 const HloInstruction* first_reduce = nullptr;
2712 for (int i = 0, e = output_instructions.size(); i != e; ++i) {
2713 if (output_instructions[i]->opcode() != HloOpcode::kReduce) {
2714 continue;
2715 }
2716 HloInstruction* reduce_inst = output_instructions[i];
2717 if (first_reduce == nullptr) {
2718 first_reduce = reduce_inst;
2719 } else {
2720 CHECK(first_reduce->dimensions() == reduce_inst->dimensions());
2721 }
2722 ShapeIndex output_shape_index;
2723 if (reduce_or_tuple->opcode() == HloOpcode::kTuple) {
2724 output_shape_index = {i};
2725 }
2726
2727 EmitPrologueForOneReduction(unnested_hlo, reduce_inst, i, kernel_info,
2728 &elemental_emitter,
2729 std::move(output_shape_index));
2730 }
2731
2732 int num_partial_results = reduction_info->GetNumberOfPartialResults();
2733
2734 // Allocate stack storage to store the linear indices for the current output,
2735 // and record the address of the storage.
2736 reduction_info->SetCurrentOutputLinearIndexAddress(
2737 Alloca(reduction_info->GetIndexType(),
2738 /*ArraySize=*/b_.getInt32(num_partial_results),
2739 "current_output_linear_index_address"));
2740
2741 if (!reduction_info->IsRowReduction()) {
2742 llvm::Type* bool_ty = b_.getInt1Ty();
2743 llvm::AllocaInst* output_inbound_addr = Alloca(bool_ty);
2744 Store(llvm::ConstantInt::get(bool_ty, 0), output_inbound_addr);
2745 reduction_info->SetCurrentOutputInboundAddress(output_inbound_addr);
2746 }
2747 }
2748
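// Emits a warp-level tree reduction: for strides 16, 8, 4, 2, 1 each lane
// combines its partial result with the value shuffled down from the lane
// `stride` positions away, so that lane 0 ends up holding the reduction of all
// 32 lanes. A rough CUDA analogue (illustrative sketch only, assuming a full
// warp and a commutative reducer) would be:
//   for (int d = 16; d >= 1; d /= 2)
//     partial = reduce(partial, __shfl_down_sync(0xffffffff, partial, d));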
2749 void IrEmitterUnnested::EmitFullWarpShuffleDownLoopForAllReduces(
2750 absl::Span<HloComputation* const> reducers,
2751 absl::Span<llvm::AllocaInst* const> partial_result_addresses) {
2752 for (int distance = 16; distance >= 1; distance /= 2) {
2753 for (int i = 0; i != reducers.size(); ++i) {
2754 llvm::Type* element_type =
2755 partial_result_addresses[i]->getType()->getElementType();
2756 int bit_width = llvm_ir::GetSizeInBits(element_type);
2757 llvm::Value* result_from_other_lane = Alloca(
2758 element_type, nullptr, "result_from_other_lane" + llvm::Twine(i));
2759 // Bitcast cannot be applied to aggregate types (even packed ones), so
2760 // we bitcast addresses of load/store to intN* of the same bit-width.
2761 llvm::Type* shuffled_value_type =
2762 element_type->isStructTy() ? b_.getIntNTy(bit_width) : element_type;
2763 auto convert_pointer_for_shuffle = [&](llvm::Value* ptr) {
2764 return BitCast(ptr, shuffled_value_type->getPointerTo());
2765 };
2766 llvm::Value* partial_result =
2767 Load(convert_pointer_for_shuffle(partial_result_addresses[i]),
2768 "partial_reduction_result");
2769 Store(EmitFullWarpShuffleDown(partial_result, b_.getInt32(distance), &b_),
2770 convert_pointer_for_shuffle(result_from_other_lane));
2771 TF_CHECK_OK(EmitCallToNestedComputation(
2772 *reducers[i], {partial_result_addresses[i], result_from_other_lane},
2773 partial_result_addresses[i]));
2774 }
2775 }
2776 }
2777
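// Emits the epilogue for the reductions in the kernel. For a row reduction,
// the per-lane partial results are first combined across the warp and only
// lane 0 writes out; for a column reduction, a thread writes out only when its
// "output inbound" flag is set. The write either calls the reducer directly
// (when a single block covers the reduced dimension) or accumulates into the
// output element atomically.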
2778 void IrEmitterUnnested::EmitEpilogueForReduction(
2779 HloInstruction* unnested_hlo, KernelCodegenInfo* kernel_info) {
2780 ReductionCodegenInfo* reduction_info =
2781 static_cast<ReductionCodegenInfo*>(kernel_info);
2782 int num_reduces = reduction_info->GetNumberOfReduces();
2783 absl::Span<llvm::AllocaInst* const> partial_result_addresses =
2784 reduction_info->GetPartialResultAddresses();
2785 const InlinedVector<HloComputation*, 1>& reducers =
2786 reduction_info->GetReducers();
2787 absl::Span<const ShapeIndex> reduction_output_shape_indices =
2788 reduction_info->GetReductionOutputShapeIndices();
2789
2790 if (reduction_info->IsRowReduction()) {
2791 EmitFullWarpShuffleDownLoopForAllReduces(reducers,
2792 partial_result_addresses);
2793 llvm::Value* lane_id = reduction_info->GetLaneId();
2794 llvm_ir::LlvmIfData if_lane_id_is_zero_data = llvm_ir::EmitIfThenElse(
2795 ICmpEQ(lane_id, llvm::ConstantInt::get(lane_id->getType(), 0)),
2796 "lane_id_is_zero", &b_);
2797 llvm_ir::SetToFirstInsertPoint(if_lane_id_is_zero_data.true_block, &b_);
2798 } else {
2799 llvm::Value* output_inbound_addr =
2800 reduction_info->GetCurrentOutputInboundAddress();
2801 llvm::Value* output_inbound = Load(output_inbound_addr);
2802 llvm_ir::LlvmIfData if_output_inbound_data = llvm_ir::EmitIfThenElse(
2803 ICmpEQ(output_inbound,
2804 llvm::ConstantInt::get(output_inbound->getType(), 1)),
2805 "output_inbound", &b_);
2806 llvm_ir::SetToFirstInsertPoint(if_output_inbound_data.true_block, &b_);
2807 }
2808
2809 int num_partial_results = reduction_info->GetNumberOfPartialResults();
2810
2811 // Emit an atomic operation that accumulates the partial reduction to the
2812 // output element. For row reduction, this is only for lane 0 due to the
2813 // if-statement emitted above.
2814 for (int i = 0; i != num_reduces; ++i) {
2815 for (int j = 0; j < num_partial_results; ++j) {
2816 IrArray::Index element_index(
2817 /*linear=*/Load(
2818 InBoundsGEP(reduction_info->GetCurrentOutputLinearIndexAddress(),
2819 {b_.getInt32(j)}),
2820 "output_linear_addr"),
2821 ShapeUtil::GetSubshape(unnested_hlo->shape(),
2822 reduction_output_shape_indices[i]),
2823 &b_);
2824 llvm::Value* output_address =
2825 GetIrArray(*unnested_hlo, *unnested_hlo,
2826 reduction_output_shape_indices[i])
2827 .EmitArrayElementAddress(element_index, &b_,
2828 "output_element_address");
2829 // Do not emit atomic operations if each element in the reduction result
2830 // is computed by one block, that is, the dimension being reduced spans
2831 // only one block.
2832 const llvm_ir::KernelMappingScheme* mapping_scheme =
2833 reduction_info->GetKernelMappingScheme();
2834 if (mapping_scheme->GetTileBlockSizeForDimension(
2835 llvm_ir::KernelMappingScheme::DimZ) == 1 &&
2836 mapping_scheme->GetTileBlockSizeForDimension(
2837 reduction_info->GetReducedDimensionEnum()) == 1) {
2838 TF_CHECK_OK(EmitCallToNestedComputation(
2839 *reducers[i],
2840 {output_address,
2841 InBoundsGEP(partial_result_addresses[i], {b_.getInt32(j)})},
2842 output_address));
2843 } else {
2844 TF_CHECK_OK(EmitAtomicOperationForNestedComputation(
2845 *reducers[i], output_address,
2846 InBoundsGEP(partial_result_addresses[i], {b_.getInt32(j)})));
2847 }
2848 }
2849 }
2850 }
2851
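// Emits the per-element body of a reduction tile: records the linear index of
// the output element this thread contributes to, generates the reduction input
// value, folds it into the thread's partial result via the reducer, and
// finally emits any extra (non-reduce) outputs of the fusion.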
2852 void IrEmitterUnnested::EmitTileElementForReduction(
2853 HloInstruction* unnested_hlo, const llvm_ir::IrArray::Index& index,
2854 const KernelCodegenInfo* kernel_info, llvm::Value* y_loc,
2855 llvm::Value* x_loc, int64 x_iter_num) {
2856 VLOG(10) << "Emit tile element for reduce " << unnested_hlo->ToString();
2857 HloInstruction* reduce_or_tuple = unnested_hlo->opcode() == HloOpcode::kFusion
2858 ? unnested_hlo->fused_expression_root()
2859 : unnested_hlo;
2860 llvm_ir::TiledParameterInfo* tiled_param_info =
2861 kernel_info->GetTiledParameterInfo();
2862 tiled_param_info->set_y(y_loc);
2863 tiled_param_info->set_x(x_loc);
2864
2865 // Record the linear address for the current reduction.
2866 const ReductionCodegenInfo* reduction_info =
2867 dynamic_cast<const ReductionCodegenInfo*>(kernel_info);
2868 int partial_result_index = reduction_info->IsRowReduction() ? 0 : x_iter_num;
2869
2870 Store(index[reduction_info->GetKeptDimensionEnum()],
2871 InBoundsGEP(reduction_info->GetCurrentOutputLinearIndexAddress(),
2872 {b_.getInt32(partial_result_index)}));
2873 if (!reduction_info->IsRowReduction()) {
2874 llvm::Type* bool_ty = b_.getInt1Ty();
2875 llvm::AllocaInst* output_inbound_addr =
2876 reduction_info->GetCurrentOutputInboundAddress();
2877 Store(llvm::ConstantInt::get(bool_ty, 1), output_inbound_addr);
2878 }
2879
2880 InlinedVector<llvm_ir::ElementGenerator, 1> input_gens;
2881 std::vector<std::pair<llvm_ir::ElementGenerator, ShapeIndex>>
2882 extra_output_gens;
2883 GpuElementalIrEmitter elem_emitter(hlo_module_config_, module_, &b_,
2884 GetNestedComputer());
2885 FusedIrEmitter fused_emitter(GetGeneratorForOperandIrArrays(unnested_hlo),
2886 &elem_emitter);
2887 absl::Span<HloInstruction* const> output_instructions =
2888 GetOutputInstructions(&reduce_or_tuple);
2889 // Construct the ElementGenerator for each reduction and extra output in
2890 // the group of output instructions.
2891 if (unnested_hlo->opcode() == HloOpcode::kFusion) {
2892 fused_emitter.SetTiledParameterInfo(tiled_param_info);
2893 TF_CHECK_OK(unnested_hlo->fused_expression_root()->Accept(&fused_emitter));
2894
2895 for (int i = 0, e = output_instructions.size(); i != e; ++i) {
2896 const HloInstruction* inst = output_instructions[i];
2897 ShapeIndex output_shape_index;
2898 if (reduce_or_tuple->opcode() == HloOpcode::kTuple) {
2899 output_shape_index = {i};
2900 }
2901 if (inst->opcode() == HloOpcode::kReduce) {
2902 input_gens.push_back(fused_emitter.GetGenerator(inst->operand(0)));
2903 } else {
2904 extra_output_gens.emplace_back(fused_emitter.GetGenerator(inst),
2905 std::move(output_shape_index));
2906 }
2907 }
2908 } else {
2909 input_gens.push_back([&](const IrArray::Index& index) {
2910 return GetIrArray(*unnested_hlo->operand(0), *unnested_hlo)
2911 .EmitReadArrayElement(index, &b_);
2912 });
2913 }
2914
2915 IrArray::Index input_index =
2916 reduction_info->GetKernelMappingScheme()->GetUnnormalizedIndex(
2917 index,
2918 GetFirstReduceInstruction(output_instructions)->operand(0)->shape());
2919 absl::Span<llvm::AllocaInst* const> partial_reduction_result_addresses =
2920 reduction_info->GetPartialResultAddresses();
2921 absl::Span<llvm::AllocaInst* const> reduction_input_addresses =
2922 reduction_info->GetReductionInputAddresses();
2923 const InlinedVector<HloComputation*, 1>& reducers =
2924 reduction_info->GetReducers();
2925
2926 // Emit code to generate the input and perform the reduction computation for
2927 // each reduction instruction.
2928 for (int i = 0; i != reducers.size(); ++i) {
2929 llvm::Value* const input_ir_value = input_gens[i](input_index).ValueOrDie();
2930 Store(input_ir_value, reduction_input_addresses[i]);
2931 llvm::Value* partial_result_address =
2932 InBoundsGEP(partial_reduction_result_addresses[i],
2933 {b_.getInt32(partial_result_index)});
2934 TF_CHECK_OK(EmitCallToNestedComputation(
2935 *reducers[i], {partial_result_address, reduction_input_addresses[i]},
2936 partial_result_address));
2937 }
2938
2939 // Emit code to generate the output for the non-reduction instructions in the
2940 // fusion, if any.
2941 TF_CHECK_OK(
2942 EmitExtraOutputsForReduce(unnested_hlo, input_index, extra_output_gens));
2943 }
2944
2945 // Emits all the tiles in one tile block for the hlo instruction, using the given tiling scheme.
2946 void IrEmitterUnnested::EmitBlock(const TileGenerator& emit_one_tile,
2947 KernelCodegenInfo* kernel_info,
2948 KernelSupportLibrary* ksl,
2949 llvm::Type* index_ty) {
2950 KernelMappingScheme* mapping_scheme = kernel_info->GetKernelMappingScheme();
2951 absl::Span<const int64> dims_in_tile = mapping_scheme->GetDimensionsInTiles();
2952 absl::Span<const int64> dims_in_block =
2953 mapping_scheme->GetDimensionsInBlocks();
2954 absl::Span<const int64> block_sizes = mapping_scheme->GetBlockSizes();
2955 auto index_typed_constant = [&](uint64 c) -> llvm::Constant* {
2956 return llvm::ConstantInt::get(index_ty, c);
2957 };
2958
2959 // Emit all the tiles for a given dimension in a tile block.
2960 auto emit_tiles_for_block_dim =
2961 [&](const string& loop_name, const IrArray::Index& starting_tile,
2962 int dim_id,
2963 const std::function<void(const IrArray::Index& tile_index)>
2964 emit_next_block_dim) {
2965 if (block_sizes[dim_id] == 1) {
2966 emit_next_block_dim(starting_tile);
2967 } else {
2968 llvm::Value* starting_tile_index_for_dim = starting_tile[dim_id];
2969 llvm::Value* block_size_for_dim =
2970 index_typed_constant(block_sizes[dim_id]);
2971 llvm::Value* block_id_for_dim =
2972 b_.CreateUDiv(starting_tile_index_for_dim, block_size_for_dim);
2973 llvm::Value* last_block_for_dim =
2974 index_typed_constant(dims_in_block[dim_id] - 1);
2975 llvm::Value* last_block_size_for_dim = index_typed_constant(
2976 dims_in_tile[dim_id] -
2977 (dims_in_block[dim_id] - 1) * block_sizes[dim_id]);
2978 llvm::Value* num_tiles_in_block =
2979 Select(ICmpEQ(last_block_for_dim, block_id_for_dim),
2980 last_block_size_for_dim, block_size_for_dim);
2981 ksl->For(loop_name,
2982 /*start=*/index_typed_constant(0),
2983 /*end=*/num_tiles_in_block,
2984 /*step=*/1, [&](llvm::Value* block_dim_induction_var) {
2985 IrArray::Index tile_index = starting_tile.AddOffsetToDim(
2986 block_dim_induction_var, dim_id, &b_);
2987 emit_next_block_dim(tile_index);
2988 });
2989 }
2990 };
2991
2992 absl::Span<const int64> reduced_dims =
2993 mapping_scheme->GetDimensionsInElements();
2994 const bool block_contains_multi_tiles =
2995 mapping_scheme->GetNumberOfTilesInOneBlock() > 1;
2996
2997 // Emit the tile with a given tile_index, by calculating the tight bounds for
2998 // each dimension of the tile and then calling emit_one_tile.
2999 auto emit_one_tile_for_tile_index = [&](const IrArray::Index& tile_index) {
3000 std::vector<llvm::Value*> output_tile_bounds(3);
3001 for (int i = KernelMappingScheme::DimY; i < KernelMappingScheme::DimTot;
3002 ++i) {
3003 int64 tile_size_for_dim = mapping_scheme->GetTileSizeForDimension(i);
3004 // Only last row or column may not have full size.
3005 llvm::Value* is_last_row =
3006 ICmpEQ(tile_index[i], index_typed_constant(dims_in_tile[i] - 1));
3007 int64 partial_row_size =
3008 reduced_dims[i] - (dims_in_tile[i] - 1) * tile_size_for_dim;
3009 output_tile_bounds[i] =
3010 Select(is_last_row, index_typed_constant(partial_row_size),
3011 index_typed_constant(tile_size_for_dim), "tile_bound");
3012 }
3013
3014 IrArray::Index tile_origin =
3015 mapping_scheme->GetElementIndexForTileOrigin(tile_index);
3016 emit_one_tile(tile_origin, output_tile_bounds, block_contains_multi_tiles);
3017 };
3018
3019 const IrArray::Index starting_block =
3020 mapping_scheme->EmitBlockIndex(index_ty);
3021 const IrArray::Index starting_tile_for_dim_z =
3022 mapping_scheme->GetTileIndexForBlockOrigin(starting_block);
3023
3024 // Emit the three dimensional block of tiles.
3025 emit_tiles_for_block_dim(
3026 "block_dim_z", starting_tile_for_dim_z, KernelMappingScheme::DimZ,
3027 [&](const IrArray::Index& starting_tile_for_dim_y) {
3028 emit_tiles_for_block_dim(
3029 "block_dim_y", starting_tile_for_dim_y, KernelMappingScheme::DimY,
3030 [&](const IrArray::Index& starting_tile_for_dim_x) {
3031 emit_tiles_for_block_dim("block_dim_x", starting_tile_for_dim_x,
3032 KernelMappingScheme::DimX,
3033 emit_one_tile_for_tile_index);
3034 });
3035 });
3036 }
3037
3038 // Emits a kernel for the hlo instruction using the given kernel mapping scheme.
3039 //
3040 // unnested_hlo: The unnested hlo instruction for which the kernel is generated.
3041 // Currently, these hlo instructions are supported: kLoop fusion, kCopy.
3042 // tiled_param_ids: The IDs for the parameters that are 0-2-1 transpose of
3043 // other tensors with the same dimensions and are safe to be transposed via
3044 // the shared memory transpose implementation.
3045 // mapping_scheme: The tiling scheme to use.
3046 // kernel_generator: Contains function objects for code generation, such as
3047 // element generator, block prologue and epilogue generators.
3048 // kernel_info: Represents other information to support the code generation
3049 // of the tiled kernel for the hlo.
3050 LaunchDimensions IrEmitterUnnested::EmitKernel(
3051 HloInstruction* unnested_hlo, absl::Span<const int64> tiled_param_ids,
3052 const KernelCodeGenerator& kernel_generator,
3053 KernelCodegenInfo* kernel_info) {
3054 KernelMappingScheme* mapping_scheme = kernel_info->GetKernelMappingScheme();
3055
3056 std::vector<IrArray> param_arrays = ConstructIrArrayForInputs(*unnested_hlo);
3057 int64 num_params = param_arrays.size();
3058 // Allocate shared memory buffers to store the tiled inputs.
3059 std::vector<llvm::Value*> param_shmem_buffers(num_params, nullptr);
3060 for (int64 id : tiled_param_ids) {
3061 const HloInstruction* param = unnested_hlo->operand(id);
3062 param_shmem_buffers[id] =
3063 mapping_scheme->GetSharedMemoryBufferForElementType(
3064 llvm_ir::PrimitiveTypeToIrType(param->shape().element_type(),
3065 module_),
3066 IrName(unnested_hlo, StrCat("tile", id)));
3067 VLOG(3) << "Added shmem buffer for parameter " << id << ": "
3068 << llvm_ir::DumpToString(*param_shmem_buffers[id]);
3069 }
3070
3071 const ReductionCodegenInfo* reduction_info =
3072 dynamic_cast<const ReductionCodegenInfo*>(kernel_info);
3073 bool is_column_reduction =
3074 (reduction_info && !reduction_info->IsRowReduction());
3075
3076 LaunchDimensions launch_dimensions =
3077 LaunchDimensions(mapping_scheme->GetNumberOfBlocks(),
3078 mapping_scheme->GetThreadsPerBlock());
3079
3080 // TODO(b/110211620): Enable int32 index type for column reduction.
3081 llvm::Type* index_ty =
3082 is_column_reduction
3083 ? b_.getInt64Ty()
3084 : GetIndexTypeForKernel(unnested_hlo,
3085 launch_dimensions.launch_bound(), &b_);
3086
3087 auto index_typed_constant = [&](uint64 c) -> llvm::Constant* {
3088 return llvm::ConstantInt::get(index_ty, c);
3089 };
3090
3091 // For multioutput fusion, one thread needs to output a tuple with pointers to
3092 // all the individual outputs. We could do this at any point in the kernel,
3093 // but we do it at the beginning in the hopes of reducing register pressure,
3094 // since we touch threadIdx.x and blockIdx.x at the beginning of the kernel
3095 // *anyway*.
3096 if (!reduction_info && unnested_hlo->IsMultiOutputFusion()) {
3097 KernelSupportLibrary{&b_}.If("emit_mof_tuple", IsBlock0Thread0(&b_), [&] {
3098 llvm_ir::EmitTuple(GetIrArray(*unnested_hlo, *unnested_hlo),
3099 ConstructIrArrayForOutputs(*unnested_hlo), &b_);
3100 });
3101 }
3102
3103 // For each tiled parameter, cast its input IrArray to the corresponding
3104 // reduced shape and keep the reduced shape live during IR emission.
3105 std::vector<IrArray> param_in_reduced_shape_arrays;
3106 std::vector<Shape> param_reduced_shapes;
3107 absl::Span<const int64> reduced_dims =
3108 mapping_scheme->GetDimensionsInElements();
3109 int num_shapes = ConstructInputReducedShapeAndCastInputIrArrayToShape(
3110 *unnested_hlo, param_arrays, param_shmem_buffers, reduced_dims,
3111 ¶m_reduced_shapes, ¶m_in_reduced_shape_arrays);
3112 DCHECK_EQ(num_shapes, num_params);
3113
3114 // Calculate the starting element coordinate within a tile for the current
3115 // thread, (y, x) from thread_id.
3116 llvm::Value* x;
3117 llvm::Value* y;
3118 std::tie(y, x) = mapping_scheme->EmitThreadYXCoordinate(index_ty);
3119
3120 kernel_info->SetLaneId(
3121 mapping_scheme->GetNumberOfThreadsForDimensionX() == kWarpSize ? x
3122 : nullptr);
3123 kernel_info->SetIndexType(index_ty);
3124
3125 KernelSupportLibrary ksl(&b_, llvm_ir::UnrollMode::kDefaultUnroll);
3126 // Curry a few parameters to EmitTiledElementalCodeWithBoundsCheck.
3127 auto emit_tiled_elemental_code_with_bounds_check =
3128 [&](const IrArray::Index& index, const string& loop_name,
3129 llvm::Value* tile_height, llvm::Value* tile_width,
3130 const EmitElementFunction& emit_elem_function) {
3131 EmitTiledElementalCodeWithBoundsCheck(mapping_scheme, index, loop_name,
3132 &ksl, &b_, y, x, tile_height,
3133 tile_width, emit_elem_function);
3134 };
3135
3136 auto emit_one_tile = [&](const IrArray::Index& output_tile_origin,
3137 absl::Span<llvm::Value* const> output_tile_bounds,
3138 bool block_contains_multi_tiles) {
3139 // Calculate the input tile origin from the output tile origin.
3140 const IrArray::Index input_tile_origin(
3141 Permute({0, 2, 1}, output_tile_origin.multidim()));
3142
3143 // If shared memory transpose is needed, wait for all threads to reach this
3144 // point, lest we copy a value from tile to output before the other thread
3145 // copies it from input to tile. This is `__syncthreads` in CUDA.
3146 if (!tiled_param_ids.empty()) {
3147 // Copy input parameter values to shared memory buffers:
3148 // tile[y, x] = input[index]
3149 // Note that tile_width and tile_height are flipped here because we are
3150 // reading a transposed tile.
3151 emit_tiled_elemental_code_with_bounds_check(
3152 input_tile_origin, "input", output_tile_bounds[2],
3153 output_tile_bounds[1],
3154 [&](const IrArray::Index& index, llvm::Value* y_loc,
3155 llvm::Value* x_loc, int64 /*x_iter_num*/) {
3156 for (int64 id : tiled_param_ids) {
3157 IrArray& input_in_logical_shape =
3158 param_in_reduced_shape_arrays[id];
3159 llvm::Value* shmem_buffer = param_shmem_buffers[id];
3160 // TODO(jlebar): Add AA metadata to this store. Tile buffers are
3161 // global variables, so LLVM can't infer much about them.
3162 Store(input_in_logical_shape.EmitReadArrayElement(
3163 index, &b_, "input_element"),
3164 GEP(shmem_buffer, {index_typed_constant(0), y_loc, x_loc}));
3165 }
3166 });
3167
3168 // Wait for all threads to reach this point using `__syncthreads` in CUDA.
3169 llvm_ir::EmitCallToIntrinsic(llvm::Intrinsic::nvvm_barrier0, {}, {}, &b_);
3170 }
3171
3172 llvm_ir::TiledParameterInfo tiled_param_info(param_shmem_buffers, y, x);
3173 kernel_info->SetTiledParamInfo(&tiled_param_info);
3174
3175 // Write to output[index] by emitting code like normal, except that values
3176 // for the tiled parameters are read from the shmem buffers.
3177 emit_tiled_elemental_code_with_bounds_check(
3178 output_tile_origin, "output", output_tile_bounds[1],
3179 output_tile_bounds[2],
3180 [&](const IrArray::Index& index, llvm::Value* y_loc, llvm::Value* x_loc,
3181 int64 x_iter_num) {
3182 kernel_generator.GetTileElementGenerator()(
3183 unnested_hlo, index, kernel_info, y_loc, x_loc, x_iter_num);
3184 });
3185
3186 // If a tile block contains multiple tiles and shared memory buffers are
3187 // used, we need to wait for all threads to finish using the shared memory
3188 // buffer for the current tile before we move on to process the next tile
3189 // and overwrite the shared memory buffers.
3190 if (block_contains_multi_tiles && !tiled_param_ids.empty()) {
3191 llvm_ir::EmitCallToIntrinsic(llvm::Intrinsic::nvvm_barrier0, {}, {}, &b_);
3192 }
3193 };
3194
3195 const BlockPrologueGenerator& block_prologue_generator =
3196 kernel_generator.GetBlockPrologueGenerator();
3197 if (block_prologue_generator) {
3198 block_prologue_generator(unnested_hlo, kernel_info);
3199 }
3200
3201 EmitBlock(std::move(emit_one_tile), kernel_info, &ksl, index_ty);
3202
3203 const BlockEpilogueGenerator& block_epilogue_generator =
3204 kernel_generator.GetBlockEpilogueGenerator();
3205 if (block_epilogue_generator) {
3206 block_epilogue_generator(unnested_hlo, kernel_info);
3207 }
3208
3209 return launch_dimensions;
3210 }
3211
3212 // Emits a kernel for the given hlo instruction using a tiled 0-2-1 transpose
3213 // algorithm to improve the memory access patterns for the input parameters
3214 // with a shape that is a 0-2-1 transpose of the output tensor shape. The caller
3215 // is responsible for making sure that it is safe to apply the shared memory
3216 // transpose on the input parameters.
3217 //
3218 //
3219 // For the purpose of tiling, the output tensors have a logical shape of three
3220 // components 0-2-1 while the relevant input parameters have a logical shape
3221 // of three components 0-1-2 in the order major to minor. The x- and y-
3222 // dimensions of the tensors are tiled in square tiles with an edge length
3223 // `kTileSize`. Each thread block of `kTileSize` x `kNumRows` threads
3224 // transposes one tile: each thread copies kTileSize/kNumRows elements from
3225 // the input to a shared memory tile, then the otherwise "regular HLO kernel"
3226 // reads from the shared memory instead of the original input.
3227 //
3228 // This is similar to the following CUDA algorithm in TensorFlow:
3229 // https://goo.gl/MStRV6.
3230 //
3231 // `kTileSize` should usually be same as warp size. We currently choose 32 for
3232 // `kTileSize` and 4 for `kNumRows`. The CUDA algorithm uses 8 for `kNumRows`.
3233 //
3234 // TODO(b/33320379): Here each block transposes 1 tile. It may be more
3235 // efficient to launch fewer blocks so each transposes many tiles.
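// Illustrative walk-through (assumed shape, not taken from this code): for a
// normalized output of shape [1, 1024, 2048] that is a 0-2-1 transpose of a
// [1, 2048, 1024] input, each 32x4 thread block stages one 32x32 input tile in
// shared memory, with every thread copying kTileSize/kNumRows = 8 elements,
// and then writes the transposed tile to the output.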
3236 LaunchDimensions IrEmitterUnnested::EmitHlo021Tile(
3237 HloInstruction* hlo, absl::Span<const int64> reduced_output_dims,
3238 absl::Span<const int64> tiled_param_ids) {
3239 constexpr int kNumRows = 4;
3240 KernelMappingScheme mapping_scheme(
3241 reduced_output_dims, /*tile_size_y=*/kWarpSize,
3242 /*tile_size_x=*/kWarpSize, /*req_block_sizes=*/{1, 1, 1},
3243 /*num_threads_y=*/kNumRows,
3244 /*num_threads_x=*/kWarpSize, &b_);
3245 TileElementGenerator element_generator;
3246 if (hlo->opcode() == HloOpcode::kCopy) {
3247 element_generator = [&](HloInstruction* hlo,
3248 const llvm_ir::IrArray::Index& index,
3249 const KernelCodegenInfo* kernel_info,
3250 llvm::Value* y_loc, llvm::Value* x_loc,
3251 int64 x_iter_num) {
3252 EmitTileElementForCopy(hlo, index, kernel_info, y_loc, x_loc, x_iter_num);
3253 };
3254 } else {
3255 DCHECK_EQ(hlo->opcode(), HloOpcode::kFusion);
3256 element_generator =
3257 [&](HloInstruction* hlo, const llvm_ir::IrArray::Index& index,
3258 const KernelCodegenInfo* kernel_info, llvm::Value* y_loc,
3259 llvm::Value* x_loc, int64 x_iter_num) {
3260 EmitTileElementForFusion(hlo, index, kernel_info, y_loc, x_loc,
3261 x_iter_num);
3262 };
3263 }
3264 KernelCodegenInfo kernel_info(&mapping_scheme);
3265 KernelCodeGenerator kernel_generator(std::move(element_generator));
3266 return EmitKernel(hlo, tiled_param_ids, kernel_generator, &kernel_info);
3267 }
3268
3269 namespace {
3270 // A recursive function to inspect the users of a parameter to determine
3271 // whether it's safe for a parameter to participate in a shared-memory
3272 // transpose.
3273 //
3274 // Consider a fusion parameter P for which we might want to use a shmem
3275 // transpose. If we do, we use a GPU thread block to preload a tile of P with
3276 // indices [z, y..y+31, x..x+31] to compute an output tile with the same indices
3277 // cooperatively, where z, y, x are the indices for the normalized input/output
3278 // tensor (see the document for FindTranspose021 for the definition of
3279 // normalized tensor for 0-2-1 transpose). This shmem transpose implementation
3280 // requires that the computation of the output tile only read elements within
3281 // the preload tile. If this is not true, we can't use a shmem transpose for P.
3282 //
3283 // If the computation of output element [z, y, x] only requires the element of
3284 // P with the same indices, the shmem transpose implementation can be applied
3285 // to P safely. This is a sufficient but not necessary condition. We check all
3286 // the transitive users of P to see if we can find a user that may violate this
3287 // condition. If such a user is not found, we conclude that P
3288 // is safe for shmem transpose.
3289 //
3290 // This is trivially true for elementwise operations and some "data-movement"
3291 // ops like kTuple. However, it's not true for operations that can change the
3292 // dimensions of the inputs (e.g. pad, slice) and bitcast operation.
3293 // For example:
3294 //
3295 // fused_computation {
3296 // param_0 = f32[64,64]{1,0} parameter(0)
3297 // ROOT bitcast = f32[64,64]{0,1} bitcast(param_0)
3298 // }
3299 // The output element at logical address [0, 63] depends on the input element
3300 // at logical address [63, 0], which would not be within the shared-memory
3301 // block.
3302 //
3303 // TODO(bixia): In order to extend this for kInput fusion, that is, reduction
3304 // with transpose, we only need to end the use-chain checking with the input of
3305 // a reduce operation. In this case, the above description of "output" applies
3306 // to the result of such a use-chain, which provides the input to the reduce
3307 // operation.
3308 bool IsInstructionSafeForShmemTranspose(const HloInstruction* hlo) {
3309 if (hlo->IsElementwise()) {
3310 return absl::c_all_of(hlo->users(), [&](const HloInstruction* user) {
3311 return IsInstructionSafeForShmemTranspose(user);
3312 });
3313 }
3314
3315 switch (hlo->opcode()) {
3316 // Non-elementwise instructions that don't cause the shmem transpose
3317 // to be unsafe, including the instructions that don't currently fuse.
3318 case HloOpcode::kGetDimensionSize:
3319 // The result of the operation doesn't rely on the content of the
3320 // tensor. As such, there is no need to further inspect its users.
3321 return true;
3322 case HloOpcode::kGetTupleElement:
3323 case HloOpcode::kMap:
3324 case HloOpcode::kParameter:
3325 case HloOpcode::kTuple:
3326 case HloOpcode::kTupleSelect:
3327 return absl::c_all_of(hlo->users(), [&](const HloInstruction* user) {
3328 return IsInstructionSafeForShmemTranspose(user);
3329 });
3330
3331 default:
3332 return false;
3333 }
3334 }
3335
3336 // Given a group of input parameters that are a 0-2-1 transpose of the outputs
3337 // of a fusion kernel, returns the input parameters that are safe for the
3338 // shared memory transpose implementation.
3339 //
3340 // When a tile based shared memory transpose is used to implement an input with
3341 // 0-2-1 transpose, we preload a tile of the input elements
3342 // [z, y..y+31, x..x+31] to compute the output tile elements of the same
3343 // indices. Preloading the input tile this way is only safe when the computation
3344 // of the output tile elements does not need any input element outside the
3345 // preloaded tile. We inspect all the transitive users of the input parameter
3346 // up to the fusion root instruction to see if we can find any instruction
3347 // that can make preloading the input tile unsafe.
3348 std::vector<int64> FilterInputsForShmemTranspose(const HloInstruction* fusion,
3349 std::vector<int64> input_ids) {
3350 std::vector<int64> filtered_input_ids;
3351 for (int64 i = 0; i < input_ids.size(); ++i) {
3352 const HloInstruction* input = fusion->fused_parameter(input_ids[i]);
3353 if (IsInstructionSafeForShmemTranspose(input)) {
3354 filtered_input_ids.push_back(input_ids[i]);
3355 } else {
3356 VLOG(10) << "Input not safe for shmem transpose " << input->ToString()
3357 << "\n";
3358 }
3359 }
3360 return filtered_input_ids;
3361 }
3362
3363 } // namespace
3364
3365 bool IrEmitterUnnested::CheckAndEmitHloWithTile021(HloInstruction* hlo) {
3366 HloOpcode opcode = hlo->opcode();
3367 CHECK(opcode == HloOpcode::kFusion || opcode == HloOpcode::kCopy);
3368 CHECK(opcode != HloOpcode::kFusion ||
3369 hlo->fusion_kind() == HloInstruction::FusionKind::kLoop)
3370 << "Only loop fusions are supported.";
3371
3372 const Shape& output_shape = hlo->IsMultiOutputFusion()
3373 ? ShapeUtil::GetSubshape(hlo->shape(), {0})
3374 : hlo->shape();
3375
3376 // If the output_shape is reduced to 021 shape, find all the parameters of
3377 // the HLO that are in the corresponding 012 shape.
3378 std::vector<int64> params_012;
3379 optional<std::vector<int64>> reduced_dims_021;
3380 for (int64 operand_idx = 0; operand_idx < hlo->operand_count();
3381 ++operand_idx) {
3382 HloInstruction* operand = hlo->mutable_operand(operand_idx);
3383 auto find_transpose_result =
3384 llvm_ir::FindTranspose021(operand->shape(), output_shape);
3385 if (!find_transpose_result.has_value()) {
3386 continue;
3387 }
3388 const std::vector<int64>& curr_reduced_dims_021 = *find_transpose_result;
3389 if (!reduced_dims_021.has_value()) {
3390 reduced_dims_021 = curr_reduced_dims_021;
3391 }
3392 if (!absl::c_equal(*reduced_dims_021, curr_reduced_dims_021)) {
3393 // There is more than one possible transpose. Instead of picking one
3394 // transpose, we simply give up here.
3395 return false;
3396 }
3397 params_012.push_back(operand_idx);
3398 }
3399
3400 if (!reduced_dims_021.has_value()) {
3401 return false;
3402 }
3403
3404 if ((*reduced_dims_021)[1] < kMinDimensionToTransposeTiled ||
3405 (*reduced_dims_021)[2] < kMinDimensionToTransposeTiled) {
3406 return false;
3407 }
3408
3409 if (opcode == HloOpcode::kFusion) {
3410 params_012 = FilterInputsForShmemTranspose(hlo, params_012);
3411 if (params_012.empty()) {
3412 return false;
3413 }
3414 }
3415
3416 // Each of our shared memory tiles has 32*33 elements (so ~4kb, if the
3417 // elements are of size 4 bytes), and CUDA has an architectural limit of
3418 // 48kb shared memory per SM. (This is increased to 96kb in Volta, but we
3419 // don't use this, in part because it eats into our L1 cache space.)
3420 //
3421 // For correctness we need to ensure that we don't make more than 48kb worth
3422 // of shmem tiles per block. And for performance, we'd probably like to use
3423 // significantly less, so that we can fit more than one block at a time on a
3424 // gpu core.
3425 //
3426 // We say without benchmarks that we want at least 3 blocks per core
3427 // (kMinBlocksPerCore), which corresponds to roughly 3 shmem tiles per block
3428 // if the elements are 32 bits wide. We choose which params get the shmem
3429 // transpose treatment arbitrarily; it's not clear if there's a Right Choice.
3430 //
3431 // This is only sound if tiled transposes are the only place where we use
3432 // shared memory in fusions. If in the future other fusible ops use shared
3433 // memory, we'll have to adjust this heuristic.
3434 constexpr int kMinBlocksPerCore = 3;
3435 constexpr int64 kShmemPerCore = 48 * 1024;
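  // Illustrative budget check (assuming f32 parameters; values not taken from
  // this code): one f32 tile costs 32 * 33 * 4 = 4224 bytes, so with
  // kMinBlocksPerCore = 3 the loop below keeps at most
  // floor(48 KiB / 3 / 4224) = 3 tiled parameters.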
3436 int64 shmem_used = 0;
3437 for (int64 i = 0; i < params_012.size(); ++i) {
3438 const HloInstruction* operand = hlo->operand(params_012[i]);
3439 shmem_used +=
3440 32 * 33 *
3441 ShapeUtil::ByteSizeOfPrimitiveType(operand->shape().element_type());
3442
3443 if (kMinBlocksPerCore * shmem_used > kShmemPerCore) {
3444 // Erase this element and everything after it from params_012.
3445 params_012.resize(i);
3446 break;
3447 }
3448 }
3449
3450 VLOG(3) << "EmitHlo021Tile Emitting hlo tile 0-2-1: " << hlo->ToString();
3451 std::unique_ptr<KernelThunk> kernel_thunk =
3452 BuildKernelThunk(hlo, /*implements_whole_instruction=*/true);
3453 const LaunchDimensions launch_dimensions =
3454 EmitHlo021Tile(hlo, *reduced_dims_021, params_012);
3455 UpdateLaunchDimensions(launch_dimensions, kernel_thunk.get(),
3456 ir_emitter_context_->llvm_module());
3457 AddThunkToThunkSequence(std::move(kernel_thunk));
3458
3459 return true;
3460 }
3461
3462 namespace {
3463 // Checks that the outputs of a fusion with reduction are consistent.
3464 Status AreFusedReductionOutputsConsistent(
3465 absl::Span<HloInstruction* const> output_instructions,
3466 const HloInstruction* first_reduce) {
3467 for (const HloInstruction* inst : output_instructions) {
3468 if (inst->opcode() == HloOpcode::kReduce) {
3469 // Shapes, layouts and dimensions must be the same for all reduces
3470 // inside of this fusion.
3471 TF_RET_CHECK(ShapeUtil::Equal(first_reduce->shape(), inst->shape()));
3472 TF_RET_CHECK(ShapeUtil::Equal(first_reduce->operand(0)->shape(),
3473 inst->operand(0)->shape()));
3474 TF_RET_CHECK(ShapeUtil::Equal(first_reduce->operand(1)->shape(),
3475 inst->operand(1)->shape()));
3476 TF_RET_CHECK(first_reduce->dimensions() == inst->dimensions());
3477 } else {
3478 // For extra outputs we can relax shape equality to allow different
3479 // types (with the same number of elements). Layouts still have to
3480 // match.
3481 TF_RET_CHECK(ShapeUtil::CompatibleIgnoringElementType(
3482 first_reduce->operand(0)->shape(), inst->shape()));
3483 TF_RET_CHECK(LayoutUtil::Equal(first_reduce->operand(0)->shape().layout(),
3484 inst->shape().layout()));
3485 }
3486 }
3487 return Status::OK();
3488 }
3489
3490 // Finds the dimensions to keep for the reduction, sorts and returns the
3491 // dimensions from minor to major.
3492 DimensionVector GetDimensionsToKeepMinorToMajor(
3493 const Shape& input_shape, absl::Span<const int64> dims_to_reduce) {
3494 DimensionVector input_dims(input_shape.rank(), 0);
3495 absl::c_iota(input_dims, 0);
3496 DimensionVector input_dims_to_keep;
3497 for (int input_dim : input_dims) {
3498 auto it = absl::c_find_if(dims_to_reduce, [&](int64 dim_to_reduce) {
3499 return dim_to_reduce == input_dim;
3500 });
3501 if (it == dims_to_reduce.end()) {
3502 input_dims_to_keep.push_back(input_dim);
3503 }
3504 }
3505
3506 // Sort the dimensions to keep from minor to major.
3507 absl::c_sort(input_dims_to_keep, [&input_shape](int64 dim_a, int64 dim_b) {
3508 return PositionInContainer(LayoutUtil::MinorToMajor(input_shape), dim_a) <
3509 PositionInContainer(LayoutUtil::MinorToMajor(input_shape), dim_b);
3510 });
3511
3512 VLOG(10) << "dims to keep minor to major"
3513 << absl::StrJoin(input_dims_to_keep, ",");
3514 return input_dims_to_keep;
3515 }
3516
3517 // Given the input shape and dimensions to reduce for the reduction to vector,
3518 // returns <num_reduced_major, num_kept, num_reduced_minor>:
3519 // num_kept: the number of elements in the contiguous dimensions to keep.
3520 // num_reduced_major: the number of elements in the dimensions to reduce that
3521 // are more major than the dimensions to keep.
3522 // num_reduced_minor: the number of elements in the dimensions to reduce that
3523 // are more minor than the dimensions to keep.
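// Illustrative example (assumed shape, not taken from this code): reducing a
// f32[16,1024,4]{2,1,0} input over dimensions {0, 2} keeps dimension 1, so the
// result is <num_reduced_major=16, num_kept=1024, num_reduced_minor=4>.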
3524 std::tuple<int64, int64, int64> GetReductionToVectorDimensions(
3525 const Shape& input_shape, absl::Span<const int64> dims_to_reduce) {
3526 DimensionVector input_dims_to_keep_minor_to_major =
3527 GetDimensionsToKeepMinorToMajor(input_shape, dims_to_reduce);
3528 CHECK(LayoutUtil::AreDimensionsConsecutive(
3529 input_shape.layout(), input_dims_to_keep_minor_to_major));
3530 int num_reduced_major = 1, num_kept = 1, num_reduced_minor = 1;
3531 if (input_dims_to_keep_minor_to_major.empty()) {
3532 return std::make_tuple(num_reduced_major, num_kept, num_reduced_minor);
3533 }
3534 DimensionVector input_dims(input_shape.rank(), 0);
3535 absl::c_iota(input_dims, 0);
3536 absl::Span<const int64> minor_to_major =
3537 LayoutUtil::MinorToMajor(input_shape);
3538 for (int input_dim : input_dims) {
3539 int64 curr_dim_size = input_shape.dimensions(input_dim);
3540 if (PositionInContainer(minor_to_major, input_dim) >
3541 PositionInContainer(minor_to_major,
3542 input_dims_to_keep_minor_to_major.back())) {
3543 num_reduced_major *= curr_dim_size;
3544 } else if (PositionInContainer(minor_to_major, input_dim) <
3545 PositionInContainer(minor_to_major,
3546 input_dims_to_keep_minor_to_major.front())) {
3547 num_reduced_minor *= curr_dim_size;
3548 } else {
3549 num_kept *= curr_dim_size;
3550 }
3551 }
3552
3553 return std::make_tuple(num_reduced_major, num_kept, num_reduced_minor);
3554 }
3555
3556 // Returns true if all the transitive users of hlo before hitting users in
3557 // use_chain_endings are elementwise operations.
3558 bool AreUsersElementwise(const HloInstruction* hlo,
3559 const ConstHloInstructionSet& use_chain_endings) {
3560 return absl::c_all_of(hlo->users(), [&](const HloInstruction* user) {
3561 return use_chain_endings.count(user) ||
3562 (user->IsElementwise() &&
3563 AreUsersElementwise(user, use_chain_endings));
3564 });
3565 }
3566
3567 // Returns the number of fusion inputs that have the same dimensions as the
3568 // given shape and are involved in only elementwise operations.
3569 int64 NumInputsInvolveInOnlyElementwiseOps(
3570 const HloInstruction* unnested_hlo, const Shape& op_shape,
3571 const ConstHloInstructionSet& use_chain_endings) {
3572 return absl::c_count_if(
3573 unnested_hlo->fused_parameters(), [&](const HloInstruction* parameter) {
3574 const Shape& parameter_shape = parameter->shape();
3575 return ShapeUtil::SameDimensions(op_shape, parameter_shape) &&
3576 AreUsersElementwise(parameter, use_chain_endings);
3577 });
3578 }
3579
3580 // Returns the number of fusion inputs that have more elements than the given
3581 // shape.
3582 int64 NumInputsWithMoreElementsThan(const HloInstruction* unnested_hlo,
3583 const Shape& shape) {
3584 int64 num_elements = ShapeUtil::ElementsIn(shape);
3585 return absl::c_count_if(
3586 unnested_hlo->fused_parameters(), [&](const HloInstruction* parameter) {
3587 return ShapeUtil::ElementsIn(parameter->shape()) > num_elements;
3588 });
3589 }
3590
3591 // The benefit of unrolling a kInput fusion that is a column reduction comes
3592 // from the vectorization of non-reduction fusion outputs and fusion inputs.
3593 // On the other hand, unrolling can also introduce factors that can cause
3594 // the kernel to run slower. This routine uses a simple heuristic to estimate
3595 // the benefit as well as the overhead of unrolling in order to decide whether
3596 // unrolling is beneficial for the given kInput fusion.
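// Illustrative tally (assumed fusion, not taken from this code): for a fusion
// rooted at tuple(reduce, multiply) with one extra input that matches the
// reduce input shape and is used only elementwise, can_be_vectorized = 2 (the
// multiply output and that input) vs. cannot_be_vectorized = 1 (the atomically
// accumulated reduce output), so unrolling is judged beneficial, provided
// num_kept is a power of two.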
3597 bool IsUnrollingColumnReductionBeneficial(const HloInstruction* unnested_hlo,
3598 const Shape& input_shape,
3599 int64 num_kept) {
3600 // TODO(b/122468062): Need to investigate further whether we can
3601 // remove the constraint on IsPowerOfTwo.
3602 if (!IsPowerOfTwo(static_cast<uint64>(num_kept))) {
3603 return false;
3604 }
3605
3606 if (unnested_hlo->opcode() == HloOpcode::kReduce) {
3607 return true;
3608 }
3609
3610 CHECK_EQ(unnested_hlo->opcode(), HloOpcode::kFusion);
3611 int64 can_be_vectorized = 0;
3612 int64 cannot_be_vectorized = 0;
3613 const HloInstruction* fused_root = unnested_hlo->fused_expression_root();
3614 ConstHloInstructionSet use_chain_endings;
3615 if (fused_root->opcode() == HloOpcode::kReduce) {
3616 use_chain_endings.insert(fused_root);
3617 // Atomic.add of the reduction result can't be vectorized.
3618 cannot_be_vectorized++;
3619 } else {
3620 CHECK_EQ(fused_root->opcode(), HloOpcode::kTuple);
3621 for (const HloInstruction* instr : fused_root->operands()) {
3622 if (instr->opcode() == HloOpcode::kReduce) {
3623 // Atomic.add of the reduction result can't be vectorized.
3624 cannot_be_vectorized++;
3625 } else {
3626 // Write of the non-reduction result can be vectorized.
3627 can_be_vectorized++;
3628 }
3629 use_chain_endings.insert(instr);
3630 }
3631 }
3632 // Fusion inputs that have the same dimensions as the reduce input and
3633 // are only involved in elementwise operations can be vectorized.
3634 can_be_vectorized += NumInputsInvolveInOnlyElementwiseOps(
3635 unnested_hlo, input_shape, use_chain_endings);
3636 // Fusion inputs with more elements than the reduce op input must participate
3637 // in non-elementwise operations and we assume that they are not vectorizable
3638 // for the purpose of estimating the benefit of unrolling. If the kernel is
3639 // unrolled even with such an assumption, and the accesses to those inputs
3640 // turn out to be vectorizable, the compiler will still vectorize them.
3641 cannot_be_vectorized +=
3642 NumInputsWithMoreElementsThan(unnested_hlo, input_shape);
3643 return can_be_vectorized >= cannot_be_vectorized;
3644 }
3645
3646 } // namespace
3647
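// Computes the kernel mapping scheme for a reduction-to-vector kernel, i.e.
// the normalized [depth, height, width] dimensions together with tile sizes,
// thread counts and block sizes, and reports whether the reduction is emitted
// as a row reduction or as a column reduction.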
3648 std::tuple<KernelMappingScheme, bool>
3649 IrEmitterUnnested::ComputeMappingSchemeAndReductionKind(
3650 const HloInstruction* unnested_hlo, const HloInstruction* first_reduce) {
3651 int64 depth = 1;
3652 int64 height = 1;
3653 int64 width = 1;
3654 bool is_row_reduction = true;
3655 int64 tile_size_x = 1;
3656 int64 tile_size_y = 1;
3657 int64 block_size_z = 1;
3658 int64 num_threads_x = 1;
3659 int64 num_threads_y = 1;
3660 const Shape& input_shape = first_reduce->operand(0)->shape();
3661 int64 num_input_elems = ShapeUtil::ElementsIn(input_shape);
3662 int64 num_output_elems = ShapeUtil::ElementsIn(first_reduce->shape());
3663 int64 num_reduced_major, num_kept, num_reduced_minor;
3664 std::tie(num_reduced_major, num_kept, num_reduced_minor) =
3665 GetReductionToVectorDimensions(input_shape, first_reduce->dimensions());
3666 CHECK_EQ(num_output_elems, num_kept);
3667 bool dilated_x = true;
3668
3669 if (num_kept == 1) {
3670 // Scalar reduction is a special row reduction with depth = height = 1.
3671 width = num_input_elems;
3672 tile_size_x = kWarpSize * 16;
3673 num_threads_x = kWarpSize;
3674 } else if (num_reduced_minor == 1) {
3675 // Column reduction reduces inputs with dimension [height, width], where
3676 // width is the minor dimension, to dimension [width].
3677 height = num_reduced_major;
3678 width = num_kept;
3679 is_row_reduction = false;
3680 // Column reduction without transpose doesn't require communication among
3681 // threads processing elements in the same tile. The current implementation
3682 // only supports the use of one hardware thread block to process one block of
3683 // tiles in the KernelMappingScheme. We try to use one thread to compute
3684 // the partial results for two tensor elements and to maximize the values of
3685 // num_threads_x and tile_size_x to allow a bigger hardware thread block.
3686 int64 hw_threads_per_block_limit =
3687 ThreadsPerBlockLimit(ir_emitter_context_->device_description());
3688 if (IsUnrollingColumnReductionBeneficial(unnested_hlo, input_shape,
3689 num_kept)) {
3690 tile_size_x = std::min(2 * hw_threads_per_block_limit, num_kept);
3691 num_threads_x = tile_size_x / 2;
3692 dilated_x = false;
3693 } else {
3694 tile_size_x = std::min(hw_threads_per_block_limit, num_kept);
3695 num_threads_x = tile_size_x;
3696 }
3697 int64 kNumElementsPerPartialSum = 128;
3698 tile_size_y = kNumElementsPerPartialSum;
3699 } else {
3700 // Row reduction reduces inputs with dimension [depth, height, width],
3701 // where width is the most minor dimension, to dimension [height].
3702 depth = num_reduced_major;
3703 height = num_kept;
3704 width = num_reduced_minor;
3705 num_threads_x = kWarpSize;
3706 if (width % (kWarpSize * 64) == 0) {
3707 tile_size_x = kWarpSize * 64;
3708 } else {
3709 tile_size_x = kWarpSize * 8;
3710 block_size_z = 8;
3711 while (depth % block_size_z != 0) {
3712 block_size_z -= 1;
3713 }
3714 }
3715 }
3716 DCHECK_EQ(depth * height * width, num_input_elems);
3717 VLOG(10) << "is_row_reduction " << is_row_reduction << " " << depth << " " << height
3718 << " " << width;
3719
3720 DimensionVector dims_in_elem{depth, height, width};
3721 DimensionVector req_block_sizes{block_size_z, 1, 1};
3722 llvm_ir::KernelMappingScheme mapping_scheme(
3723 dims_in_elem, tile_size_y, tile_size_x, req_block_sizes, num_threads_y,
3724 num_threads_x, &b_);
3725 mapping_scheme.SetDilatedX(dilated_x);
3726 return std::make_tuple(mapping_scheme, is_row_reduction);
3727 }
3728
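// Emits a reduction-to-vector kernel: builds one initializer thunk per
// reduction output, derives the kernel mapping scheme from the first reduce,
// and emits a tiled kernel whose block prologue and epilogue set up the
// partial results and perform the final (possibly atomic) accumulation into
// the outputs.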
3729 Status IrEmitterUnnested::EmitReductionToVector(HloInstruction* unnested_hlo) {
3730 VLOG(10) << "Emitting reduction to vector " << unnested_hlo->ToString();
3731
3732 HloInstruction* reduce_or_tuple = unnested_hlo->opcode() == HloOpcode::kFusion
3733 ? unnested_hlo->fused_expression_root()
3734 : unnested_hlo;
3735 absl::Span<HloInstruction* const> output_instructions =
3736 GetOutputInstructions(&reduce_or_tuple);
3737 const HloInstruction* first_reduce =
3738 GetFirstReduceInstruction(output_instructions);
3739
3740 if (output_instructions.size() > 1) {
3741 TF_RETURN_IF_ERROR(
3742 AreFusedReductionOutputsConsistent(output_instructions, first_reduce));
3743 }
3744
3745 // Build an initializer thunk to initialize each reduction output.
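  // For a multi-output fusion each reduce output is addressed by its tuple
  // index; a lone reduce output is addressed by the empty ShapeIndex.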
  std::vector<std::unique_ptr<Thunk>> thunks;
  for (int i = 0, e = output_instructions.size(); i != e; ++i) {
    if (output_instructions[i]->opcode() != HloOpcode::kReduce) {
      continue;
    }
    TF_ASSIGN_OR_RETURN(
        std::unique_ptr<Thunk> initializer_thunk,
        BuildInitializerThunk(unnested_hlo,
                              (output_instructions[i] == reduce_or_tuple)
                                  ? ShapeIndex()
                                  : ShapeIndex({i})));
    thunks.push_back(std::move(initializer_thunk));
  }

  // Build a kernel thunk to compute all the outputs.
  std::unique_ptr<KernelThunk> kernel_thunk =
      BuildKernelThunk(unnested_hlo, /*implements_whole_instruction=*/false);

  const Shape& input_shape = first_reduce->operand(0)->shape();
  // The layout of a reduction input is either set by LayoutAssignment for
  // unnested kReduce or by InstructionFusion for fused kReduce.
  CHECK(input_shape.has_layout()) << "LayoutAssignment or InstructionFusion "
                                     "doesn't set the input layout of "
                                  << first_reduce->ToString();

  bool is_row_reduction;
  llvm_ir::KernelMappingScheme mapping_scheme;
  std::tie(mapping_scheme, is_row_reduction) =
      ComputeMappingSchemeAndReductionKind(unnested_hlo, first_reduce);
  ReductionCodegenInfo reduction_info(&mapping_scheme, is_row_reduction);
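  // The three generators below split up the kernel body: the block prologue
  // sets up the per-thread accumulators for the partial results, the tile
  // element generator folds each input element of a tile into them, and the
  // block epilogue combines the partial results and writes them to the
  // reduction outputs.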
  KernelCodeGenerator kernel_generator(
      /*tile_element_generator=*/
      [&](HloInstruction* hlo, const llvm_ir::IrArray::Index& index,
          const KernelCodegenInfo* kernel_info, llvm::Value* y_loc,
          llvm::Value* x_loc, int64 x_iter_num) {
        EmitTileElementForReduction(hlo, index, kernel_info, y_loc, x_loc,
                                    x_iter_num);
      },
      /*block_prologue_generator=*/
      [&](HloInstruction* hlo, KernelCodegenInfo* kernel_info) {
        EmitPrologueForReduction(hlo, kernel_info);
      },
      /*block_epilogue_generator=*/
      [&](HloInstruction* hlo, KernelCodegenInfo* kernel_info) {
        EmitEpilogueForReduction(hlo, kernel_info);
      });

  LaunchDimensions launch_dimensions =
      EmitKernel(unnested_hlo, {}, kernel_generator, &reduction_info);
  UpdateLaunchDimensions(launch_dimensions, kernel_thunk.get(),
                         ir_emitter_context_->llvm_module());

  thunks.push_back(std::move(kernel_thunk));
  std::unique_ptr<SequentialThunk> sequential_thunk =
      absl::make_unique<SequentialThunk>(std::move(thunks), unnested_hlo);
  AddThunkToThunkSequence(std::move(sequential_thunk));

  return Status::OK();
}

Status IrEmitterUnnested::EmitConstantGlobals() {
  for (const BufferAllocation& allocation :
       ir_emitter_context_->buffer_assignment().Allocations()) {
    if (!allocation.is_constant()) {
      continue;
    }

    const Literal& literal = llvm_ir::LiteralForConstantAllocation(allocation);
    const bool should_emit_initializer = ShouldEmitLiteralInLlvmIr(literal);
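    // Embed the literal's contents in the IR only when ShouldEmitLiteralInLlvmIr
    // deems it worthwhile; otherwise the global gets a zero initializer of the
    // allocation's size.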
    llvm::ArrayType* global_type =
        llvm::ArrayType::get(b_.getInt8Ty(), allocation.size());
    llvm::Constant* initializer =
        should_emit_initializer
            ? llvm_ir::ConvertLiteralToIrConstant(literal, module_)
            : llvm::ConstantAggregateZero::get(global_type);
    if (should_emit_initializer) {
      VLOG(3) << "Emitted initializer for constant with shape "
              << ShapeUtil::HumanString(literal.shape());
    }

    // These globals will be looked up by name by GpuExecutable so we need to
    // give them an external linkage. Not all of their uses are visible in
    // the LLVM IR (e.g. TupleThunk) so we can't give them a linkage that
    // merely preserves their names (like available_externally); we also need
    // to ensure that they stick around even if they're "unused".
    //
    // We may have to be more clever here in the future if we notice that
    // we're keeping around too many globals because of their linkage.
    llvm::GlobalVariable* global_for_const = new llvm::GlobalVariable(
        global_type, /*isConstant=*/should_emit_initializer,
        llvm::GlobalValue::ExternalLinkage,
        /*Initializer=*/initializer,
        llvm_ir::ConstantBufferAllocationToGlobalName(allocation));
    global_for_const->setAlignment(kConstantBufferAlignBytes);
    ir_emitter_context_->llvm_module()->getGlobalList().push_back(
        global_for_const);
  }

  return Status::OK();
}

}  // namespace gpu
}  // namespace xla