1 /* Copyright 2017 The TensorFlow Authors. All Rights Reserved.
2 
3 Licensed under the Apache License, Version 2.0 (the "License");
4 you may not use this file except in compliance with the License.
5 You may obtain a copy of the License at
6 
7     http://www.apache.org/licenses/LICENSE-2.0
8 
9 Unless required by applicable law or agreed to in writing, software
10 distributed under the License is distributed on an "AS IS" BASIS,
11 WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
12 See the License for the specific language governing permissions and
13 limitations under the License.
14 ==============================================================================*/
15 
16 #include "tensorflow/compiler/xla/service/gpu/ir_emitter_unnested.h"
17 
18 #include <algorithm>
19 #include <cstring>
20 #include <iterator>
21 #include <memory>
22 #include <string>
23 #include <type_traits>
24 #include <vector>
25 
26 #include "absl/algorithm/container.h"
27 #include "absl/container/inlined_vector.h"
28 #include "absl/memory/memory.h"
29 #include "absl/strings/str_cat.h"
30 #include "absl/strings/str_format.h"
31 #include "absl/types/optional.h"
32 #include "absl/types/span.h"
33 #include "llvm/ADT/APInt.h"
34 #include "llvm/ADT/SetVector.h"
35 #include "llvm/ADT/StringRef.h"
36 #include "llvm/IR/BasicBlock.h"
37 #include "llvm/IR/Function.h"
38 #include "llvm/IR/IRBuilder.h"
39 #include "llvm/IR/Instructions.h"
40 #include "llvm/IR/LLVMContext.h"
41 #include "llvm/IR/Module.h"
42 #include "mlir/Dialect/StandardOps/IR/Ops.h"  // from @llvm-project
43 #include "mlir/IR/Attributes.h"  // from @llvm-project
44 #include "mlir/IR/BlockAndValueMapping.h"  // from @llvm-project
45 #include "mlir/IR/Builders.h"  // from @llvm-project
46 #include "mlir/IR/BuiltinAttributes.h"  // from @llvm-project
47 #include "mlir/IR/BuiltinOps.h"  // from @llvm-project
48 #include "mlir/IR/BuiltinTypes.h"  // from @llvm-project
49 #include "mlir/IR/Verifier.h"  // from @llvm-project
50 #include "tensorflow/compiler/mlir/hlo/include/mlir-hlo/Dialect/mhlo/IR/lhlo_gpu_ops.h"
51 #include "tensorflow/compiler/mlir/hlo/include/mlir-hlo/Dialect/mhlo/IR/lhlo_ops.h"
52 #include "tensorflow/compiler/mlir/hlo/include/mlir-hlo/utils/hlo_utils.h"
53 #include "tensorflow/compiler/mlir/utils/name_utils.h"
54 #include "tensorflow/compiler/mlir/xla/attribute_exporter.h"
55 #include "tensorflow/compiler/mlir/xla/hlo_function_importer.h"
56 #include "tensorflow/compiler/mlir/xla/hlo_utils.h"
57 #include "tensorflow/compiler/mlir/xla/mlir_hlo_to_hlo.h"
58 #include "tensorflow/compiler/mlir/xla/type_to_shape.h"
59 #include "tensorflow/compiler/xla/layout_util.h"
60 #include "tensorflow/compiler/xla/literal.h"
61 #include "tensorflow/compiler/xla/permutation_util.h"
62 #include "tensorflow/compiler/xla/service/buffer_assignment.h"
63 #include "tensorflow/compiler/xla/service/custom_call_target_registry.h"
64 #include "tensorflow/compiler/xla/service/dfs_hlo_visitor.h"
65 #include "tensorflow/compiler/xla/service/gpu/backend_configs.pb.h"
66 #include "tensorflow/compiler/xla/service/gpu/buffer_allocations.h"
67 #include "tensorflow/compiler/xla/service/gpu/collective_permute_thunk.h"
68 #include "tensorflow/compiler/xla/service/gpu/conditional_thunk.h"
69 #include "tensorflow/compiler/xla/service/gpu/convolution_thunk.h"
70 #include "tensorflow/compiler/xla/service/gpu/copy_thunk.h"
71 #include "tensorflow/compiler/xla/service/gpu/cudnn_batchnorm_thunk.h"
72 #include "tensorflow/compiler/xla/service/gpu/custom_call_thunk.h"
73 #include "tensorflow/compiler/xla/service/gpu/fft_thunk.h"
74 #include "tensorflow/compiler/xla/service/gpu/for_thunk.h"
75 #include "tensorflow/compiler/xla/service/gpu/gemm_thunk.h"
76 #include "tensorflow/compiler/xla/service/gpu/gpu_constants.h"
77 #include "tensorflow/compiler/xla/service/gpu/gpu_conv_runner.h"
78 #include "tensorflow/compiler/xla/service/gpu/hlo_to_ir_bindings.h"
79 #include "tensorflow/compiler/xla/service/gpu/infeed_thunk.h"
80 #include "tensorflow/compiler/xla/service/gpu/ir_emission_utils.h"
81 #include "tensorflow/compiler/xla/service/gpu/ir_emitter_context.h"
82 #include "tensorflow/compiler/xla/service/gpu/kernel_mapping_scheme.h"
83 #include "tensorflow/compiler/xla/service/gpu/kernel_thunk.h"
84 #include "tensorflow/compiler/xla/service/gpu/launch_dimensions.h"
85 #include "tensorflow/compiler/xla/service/gpu/memset_thunk.h"
86 #include "tensorflow/compiler/xla/service/gpu/nccl_all_gather_thunk.h"
87 #include "tensorflow/compiler/xla/service/gpu/nccl_all_reduce_thunk.h"
88 #include "tensorflow/compiler/xla/service/gpu/nccl_all_to_all_thunk.h"
89 #include "tensorflow/compiler/xla/service/gpu/outfeed_thunk.h"
90 #include "tensorflow/compiler/xla/service/gpu/parallel_loop_emitter.h"
91 #include "tensorflow/compiler/xla/service/gpu/replica_id_thunk.h"
92 #include "tensorflow/compiler/xla/service/gpu/sequential_thunk.h"
93 #include "tensorflow/compiler/xla/service/gpu/target_util.h"
94 #include "tensorflow/compiler/xla/service/gpu/thunk.h"
95 #include "tensorflow/compiler/xla/service/gpu/triangular_solve_thunk.h"
96 #include "tensorflow/compiler/xla/service/gpu/while_thunk.h"
97 #include "tensorflow/compiler/xla/service/hlo_casting_utils.h"
98 #include "tensorflow/compiler/xla/service/hlo_computation.h"
99 #include "tensorflow/compiler/xla/service/hlo_instruction.h"
100 #include "tensorflow/compiler/xla/service/hlo_instructions.h"
101 #include "tensorflow/compiler/xla/service/hlo_opcode.h"
102 #include "tensorflow/compiler/xla/service/llvm_ir/buffer_assignment_util.h"
103 #include "tensorflow/compiler/xla/service/llvm_ir/dynamic_update_slice_util.h"
104 #include "tensorflow/compiler/xla/service/llvm_ir/fused_ir_emitter.h"
105 #include "tensorflow/compiler/xla/service/llvm_ir/ir_array.h"
106 #include "tensorflow/compiler/xla/service/llvm_ir/llvm_util.h"
107 #include "tensorflow/compiler/xla/service/llvm_ir/sort_util.h"
108 #include "tensorflow/compiler/xla/service/llvm_ir/tuple_ops.h"
109 #include "tensorflow/compiler/xla/service/name_uniquer.h"
110 #include "tensorflow/compiler/xla/service/pattern_matcher.h"
111 #include "tensorflow/compiler/xla/service/shape_inference.h"
112 #include "tensorflow/compiler/xla/service/while_loop_analysis.h"
113 #include "tensorflow/compiler/xla/shape_util.h"
114 #include "tensorflow/compiler/xla/status_macros.h"
115 #include "tensorflow/compiler/xla/types.h"
116 #include "tensorflow/compiler/xla/union_find.h"
117 #include "tensorflow/compiler/xla/util.h"
118 #include "tensorflow/compiler/xla/window_util.h"
119 #include "tensorflow/compiler/xla/xla_data.pb.h"
120 #include "tensorflow/core/lib/core/bits.h"
121 #include "tensorflow/core/lib/core/status.h"
122 #include "tensorflow/core/platform/errors.h"
123 #include "tensorflow/core/platform/logging.h"
124 
125 #if GOOGLE_CUDA
126 #include "tensorflow/compiler/xla/service/gpu/cholesky_thunk.h"
127 #endif  // GOOGLE_CUDA
128 
129 namespace xla {
130 namespace gpu {
131 
132 namespace {
133 
134 using absl::InlinedVector;
135 using absl::nullopt;
136 using absl::optional;
137 using absl::StrCat;
138 using llvm_ir::IrArray;
139 using llvm_ir::IrName;
140 
141 const auto kDimX = KernelMappingScheme::DimX;
142 const auto kDimY = KernelMappingScheme::DimY;
143 const auto kDimZ = KernelMappingScheme::DimZ;
144 const auto kDimTot = KernelMappingScheme::DimTot;
145 
146 const auto kLinearIndexingX = KernelMappingScheme::LinearIndexingX;
147 const auto kStridedIndexingX = KernelMappingScheme::StridedIndexingX;
148 const auto kStridedLinearIndexingX =
149     KernelMappingScheme::StridedLinearIndexingX;
150 
151 // If a dimensions is smaller than this, untiled transposition may be more
152 // efficient.
153 const int64 kMinDimensionToTransposeTiled = 16;
154 
155 // Updates the launch dimensions in "thunk" and annotate the launch dimensions
156 // of the corresponding IR kernel in "llvm_module".
157 // Precondition: "thunk" must be a KernelThunk.
UpdateLaunchDimensions(const LaunchDimensions & launch_dims,Thunk * thunk,llvm::Module * llvm_module)158 void UpdateLaunchDimensions(const LaunchDimensions& launch_dims, Thunk* thunk,
159                             llvm::Module* llvm_module) {
160   CHECK(Thunk::Kind::kKernel == thunk->kind());
161   KernelThunk* kernel_thunk = static_cast<KernelThunk*>(thunk);
162   kernel_thunk->SetLaunchDimensions(launch_dims);
163 
164   // Add __launch_bounds__ to metadata. This limits registers per thread to
165   // avoid out-of-resources launching errors.
166   llvm::NamedMDNode* nvvm_annotations_node =
167       llvm_module->getOrInsertNamedMetadata("nvvm.annotations");
168   llvm::Function* ir_kernel =
169       llvm_module->getFunction(kernel_thunk->kernel_name().c_str());
170   llvm::LLVMContext& llvm_context = llvm_module->getContext();
171   llvm::ConstantInt* threads_per_block_ir_value = llvm::ConstantInt::get(
172       llvm::IntegerType::get(llvm_context, /*NumBits=*/32),
173       launch_dims.thread_counts_per_block().x);
174   // Our launch bounds are exact, so we can specify them as reqntidx rather than
175   // maxntidx.
176   nvvm_annotations_node->addOperand(llvm::MDNode::get(
177       llvm_context,
178       {llvm::ConstantAsMetadata::get(ir_kernel),
179        llvm::MDString::get(llvm_context, "reqntidx"),
180        llvm::ConstantAsMetadata::get(threads_per_block_ir_value)}));
181 }
182 
BinarySearchDenseElementsAttr(mlir::DenseIntElementsAttr elements,int64 v)183 bool BinarySearchDenseElementsAttr(mlir::DenseIntElementsAttr elements,
184                                    int64 v) {
185   mlir::APInt value(sizeof(int64) * 8, v, /*isSigned=*/true);
186   return std::binary_search(
187       elements.begin(), elements.end(), value,
188       [](const mlir::APInt& x, const mlir::APInt& y) { return x.slt(y); });
189 }
190 
191 // Returns true if the fusion contains any instruction that is likely
192 // translated to complex LLVM IR, such as loops, and prevent vectorization.
MayPreventVectorization(const HloInstruction & hlo)193 bool MayPreventVectorization(const HloInstruction& hlo) {
194   if (hlo.opcode() == HloOpcode::kFusion) {
195     return absl::c_any_of(hlo.fused_instructions_computation()->instructions(),
196                           [](const HloInstruction* instr) {
197                             switch (instr->opcode()) {
198                               case HloOpcode::kReduceWindow:
199                               case HloOpcode::kSort:
200                               case HloOpcode::kDot:
201                               case HloOpcode::kSin:
202                               case HloOpcode::kCos:
203                               case HloOpcode::kPower:
204                               case HloOpcode::kAtan2:
205                                 return true;
206                               case HloOpcode::kReduce:
207                                 return !instr->shape().IsArray();
208                               default:
209                                 return false;
210                             }
211                           });
212   } else if (hlo.IsElementwise()) {
213     // Unfused elementwise operations are usually memory bound, unroll them.
214     switch (hlo.opcode()) {
215         // The following elementwise operation implementations contain branches.
216         // LLVM vectorizer doesn't work in that case.
217         // The unrolled code is faster when it isn't vectorized.
218       case HloOpcode::kSin:
219       case HloOpcode::kCos:
220       case HloOpcode::kPower:
221       case HloOpcode::kAtan2:
222         return true;
223       default:
224         return false;
225     }
226   } else if (hlo.opcode() == HloOpcode::kReduce && hlo.shape().IsArray()) {
227     // TODO(timshen): check if the to_apply() attribute contains instructions
228     // that break LLVM vectorization.
229     return false;
230   }
231   return true;
232 }
233 
LmhloOpIsElementwise(mlir::Operation * op)234 bool LmhloOpIsElementwise(mlir::Operation* op) {
235   CHECK(op->getDialect() == op->getContext()->getLoadedDialect("lmhlo"));
236   auto opcode = *MhloToHloOpcode(op);
237   if (HloInstruction::IsOpElementwise(opcode)) {
238     return true;
239   }
240   if (opcode == HloOpcode::kMap) {
241     int iota = 0;
242     for (const llvm::APInt& i :
243          mlir::cast<mlir::lmhlo::MapOp>(op).dimensions()) {
244       if (i.getZExtValue() != iota) {
245         return false;
246       }
247       iota++;
248     }
249     return true;
250   }
251   // TODO(timshen): not sure about whether porting
252   // HloFusionInstruction::IsElementwiseImpl() is necessary. HandleFusion()
253   // doesn't use such information.
254   return false;
255 }
256 
MayPreventVectorization(mlir::Operation * op)257 bool MayPreventVectorization(mlir::Operation* op) {
258   CHECK(op->getDialect() == op->getContext()->getLoadedDialect("lmhlo"));
259   auto opcode = *MhloToHloOpcode(op);
260 
261   if (auto fusion = mlir::dyn_cast<mlir::lmhlo::FusionOp>(op)) {
262     for (mlir::Operation& instr : fusion.region().front()) {
263       if (mlir::isa<mlir::lmhlo::TerminatorOp, mlir::mhlo::ReturnOp,
264                     mlir::TensorLoadOp, mlir::TensorStoreOp>(&instr)) {
265         continue;
266       }
267       CHECK(instr.getDialect() == instr.getContext()->getLoadedDialect("mhlo"))
268           << MlirToString(op);
269       switch (*MhloToHloOpcode(&instr)) {
270         case HloOpcode::kReduceWindow:
271         case HloOpcode::kSort:
272         case HloOpcode::kDot:
273         case HloOpcode::kSin:
274         case HloOpcode::kCos:
275         case HloOpcode::kPower:
276         case HloOpcode::kAtan2:
277           return true;
278         case HloOpcode::kReduce:
279           if (instr.getNumResults() > 1) {
280             return true;
281           }
282           break;
283         default:
284           break;
285       }
286     }
287     return false;
288   } else if (LmhloOpIsElementwise(op)) {
289     // Unfused elementwise operations are usually memory bound, unroll them.
290     switch (opcode) {
291         // The following elementwise operation implementations contain branches.
292         // LLVM vectorizer doesn't work in that case.
293         // The unrolled code is faster when it isn't vectorized.
294       case HloOpcode::kSin:
295       case HloOpcode::kCos:
296       case HloOpcode::kPower:
297       case HloOpcode::kAtan2:
298         return true;
299       default:
300         return false;
301     }
302   } else if (opcode == HloOpcode::kReduce && GetHloOutputs(op).size() == 1) {
303     // TODO(timshen): check if the to_apply() attribute contains instructions
304     // that break LLVM vectorization.
305     return false;
306   }
307   return true;
308 }
309 
GetOutputOps(mlir::lmhlo::FusionOp fusion)310 std::vector<mlir::Operation*> GetOutputOps(mlir::lmhlo::FusionOp fusion) {
311   llvm::SetVector<mlir::Operation*> ops;
312   for (mlir::Value output_value : fusion.getFusionResults()) {
313     ops.insert(output_value.getDefiningOp());
314   }
315   return std::vector<mlir::Operation*>(ops.begin(), ops.end());
316 }
317 
318 // Computes the maximum valid unroll factor for a given instruction.
ComputeMaxUnrollFactor(const Shape & shape,const HloModuleConfig & hlo_module_config)319 int ComputeMaxUnrollFactor(const Shape& shape,
320                            const HloModuleConfig& hlo_module_config) {
321   int max_unroll_factor =
322       hlo_module_config.debug_options().xla_gpu_max_kernel_unroll_factor();
323 
324   // Find the largest possible power of two to unroll by.
325   // TODO(kramerb): Make this smarter.
326   int64 num_elements = ShapeUtil::ElementsIn(shape);
327   for (int i = max_unroll_factor; i > 1; i /= 2) {
328     if (num_elements % i == 0) {
329       return i;
330     }
331   }
332 
333   // Cannot unroll.
334   return 1;
335 }
336 
337 // Computes the maximum valid unroll factor for a given instruction.
ComputeMaxUnrollFactor(const HloInstruction * hlo)338 int ComputeMaxUnrollFactor(const HloInstruction* hlo) {
339   const Shape& element_shape = hlo->IsMultiOutputFusion()
340                                    ? ShapeUtil::GetSubshape(hlo->shape(), {0})
341                                    : hlo->shape();
342   return ComputeMaxUnrollFactor(element_shape, hlo->GetModule()->config());
343 }
344 
345 // Computes the maximum valid unroll factor for a given instruction.
ComputeMaxUnrollFactor(mlir::Operation * op,const HloModuleConfig & hlo_module_config)346 int ComputeMaxUnrollFactor(mlir::Operation* op,
347                            const HloModuleConfig& hlo_module_config) {
348   Shape element_shape = [&] {
349     std::vector<Shape> shapes;
350     // Detect multi-output fusion. Notice that for a reduce in the fusion that
351     // returns a tuple, we don't want to treat it as multi-output fusion. We
352     // want to pass that tuple into ComputeMaxUnrollFactor below. For an actual
353     // MOF, just pass the first element of the root tuple.
354     if (auto fusion = mlir::dyn_cast<mlir::lmhlo::FusionOp>(op)) {
355       std::vector<mlir::Operation*> fusion_outputs = GetOutputOps(fusion);
356       for (mlir::Value result : fusion_outputs[0]->getResults()) {
357         shapes.push_back(TypeToShape(result.getType()));
358       }
359     } else {
360       for (mlir::Value result : GetHloOutputs(op)) {
361         shapes.push_back(TypeToShape(result.getType()));
362       }
363     }
364     if (shapes.size() > 1) {
365       return ShapeUtil::MakeTupleShape(shapes);
366     }
367     return shapes[0];
368   }();
369   return ComputeMaxUnrollFactor(element_shape, hlo_module_config);
370 }
371 
372 // Returns the llvm type for the indices used in the kernel that contains the
373 // hlo instruction. Such indices include the index for the parallel loop and
374 // the indices for the tensors accessed by the kernel. The return type is i32
375 // iff the following conditions are met:
376 //  . The launch_size of the kernel is within the range of i32.
377 //  . The sizes of all the tensors accessed within the kernel are within the
378 //    range of i32.
379 // Otherwise, the return type is i64.
GetIndexTypeForKernel(const HloInstruction * hlo,int64 launch_size,llvm::IRBuilder<> * b)380 llvm::Type* GetIndexTypeForKernel(const HloInstruction* hlo, int64 launch_size,
381                                   llvm::IRBuilder<>* b) {
382   // Find the unnested hlo instruction for which the kernel is generated for.
383   const HloInstruction* unnested_hlo = hlo;
384   const HloComputation* computation = hlo->parent();
385   if (computation->IsFusionComputation()) {
386     unnested_hlo = computation->FusionInstruction();
387   }
388 
389   auto shape_in_range = [&](const Shape& s) {
390     bool in_range = true;
391     ShapeUtil::ForEachSubshape(s, [&](const Shape& sub_shape,
392                                       const ShapeIndex& /*index*/) {
393       if (sub_shape.IsArray() && !IsInt32(ShapeUtil::ElementsIn(sub_shape))) {
394         in_range = false;
395       }
396     });
397 
398     return in_range;
399   };
400 
401   llvm::Type* i64_ty = b->getInt64Ty();
402   // Check launch dimension
403   if (!IsInt32(launch_size)) {
404     return i64_ty;
405   }
406 
407   // Check the size of result tensors
408   if (!shape_in_range(unnested_hlo->shape())) {
409     return i64_ty;
410   }
411 
412   auto hlo_shape_in_range = [&](const HloInstruction* operand) -> bool {
413     return shape_in_range(operand->shape());
414   };
415 
416   // Check the size of input tensors
417   if (!absl::c_all_of(unnested_hlo->operands(), hlo_shape_in_range)) {
418     return i64_ty;
419   }
420 
421   // Check the size of the internal result tensors
422   if (unnested_hlo->opcode() == HloOpcode::kFusion) {
423     if (!absl::c_all_of(
424             unnested_hlo->fused_instructions_computation()->instructions(),
425             hlo_shape_in_range)) {
426       return i64_ty;
427     }
428   }
429 
430   return b->getInt32Ty();
431 }
432 
433 // The same as GetIndexTypeForKernel, but works with MLIR ops.
GetIndexTypeForKernelFromMlir(mlir::Operation * op,int64 launch_size,llvm::IRBuilder<> * b)434 llvm::Type* GetIndexTypeForKernelFromMlir(mlir::Operation* op,
435                                           int64 launch_size,
436                                           llvm::IRBuilder<>* b) {
437   auto shape_in_range = [&](const Shape& s) {
438     bool in_range = true;
439     ShapeUtil::ForEachSubshape(s, [&](const Shape& sub_shape,
440                                       const ShapeIndex& /*index*/) {
441       if (sub_shape.IsArray() && !IsInt32(ShapeUtil::ElementsIn(sub_shape))) {
442         in_range = false;
443       }
444     });
445 
446     return in_range;
447   };
448 
449   llvm::Type* i64_ty = b->getInt64Ty();
450   // Check launch dimension
451   if (!IsInt32(launch_size)) {
452     return i64_ty;
453   }
454 
455   // Check the size of result tensors
456   for (auto result : GetHloOutputs(op)) {
457     if (!shape_in_range(TypeToShape(result.getType()))) {
458       return i64_ty;
459     }
460   }
461 
462   auto hlo_shape_in_range = [&](mlir::Value operand) -> bool {
463     return shape_in_range(TypeToShape(operand.getType()));
464   };
465 
466   // Check the size of input tensors
467   if (!absl::c_all_of(op->getOperands(), hlo_shape_in_range)) {
468     return i64_ty;
469   }
470 
471   // Check the size of the internal result tensors
472   if (auto fusion = mlir::dyn_cast<mlir::lmhlo::FusionOp>(op)) {
473     auto result = fusion.region().walk([&](mlir::Operation* op) {
474       for (mlir::Value result : op->getResults()) {
475         if (!hlo_shape_in_range(result)) {
476           return mlir::WalkResult::interrupt();
477         }
478       }
479       return mlir::WalkResult::advance();
480     });
481     if (result.wasInterrupted()) {
482       return i64_ty;
483     }
484   }
485 
486   return b->getInt32Ty();
487 }
488 
489 // Gets the input shape of the ROOT slices, which will be used as the kernel
490 // launch dims. The slice input fusion requires the input shapes of the ROOT
491 // slices to be the same although the (slice) output shapes can be different.
492 //
493 // Returns the input shape of the ROOT slices if all the input shapes of ROOT
494 // slices are the same and the slices are non-strided. Otherwise, returns
495 // FailedPrecondition.
GetConsistentInputShapeForRootSlices(mlir::lmhlo::FusionOp fusion)496 StatusOr<Shape> GetConsistentInputShapeForRootSlices(
497     mlir::lmhlo::FusionOp fusion) {
498   if (!IsInputFusibleSlices(fusion, /*verify_no_strides=*/true)) {
499     return FailedPrecondition(
500         "Unsupported root for slice input fusion. "
501         "Only non-strided slices are supported.");
502   }
503 
504   absl::optional<Shape> first_slice_operand_shape;
505   for (mlir::Value result : fusion.getFusionResults()) {
506     auto slice =
507         mlir::dyn_cast_or_null<mlir::mhlo::SliceOp>(result.getDefiningOp());
508     if (!slice) {
509       return FailedPrecondition("Expected a slice op");
510     }
511     if (first_slice_operand_shape.has_value()) {
512       Shape operand_shape = TypeToShape(slice.operand().getType());
513       if (!ShapeUtil::EqualIgnoringElementType(*first_slice_operand_shape,
514                                                operand_shape)) {
515         return FailedPrecondition(
516             "Fused slices do not have the same input shape, instruction is %s",
517             MlirToString(fusion));
518       }
519     } else {
520       first_slice_operand_shape = TypeToShape(slice.operand().getType());
521     }
522   }
523   if (!first_slice_operand_shape.has_value()) {
524     return InvalidArgument("Fusion has no roots");
525   }
526   return *first_slice_operand_shape;
527 }
528 
529 }  // namespace
530 
IrEmitterUnnested(const HloModuleConfig & hlo_module_config,const HloComputation * hlo_computation,IrEmitterContext * ir_emitter_context)531 IrEmitterUnnested::IrEmitterUnnested(const HloModuleConfig& hlo_module_config,
532                                      const HloComputation* hlo_computation,
533                                      IrEmitterContext* ir_emitter_context)
534     : IrEmitter(hlo_module_config, ir_emitter_context, /*is_nested=*/false) {}
535 
Create(const HloModuleConfig & hlo_module_config,const HloComputation * hlo_computation,IrEmitterContext * ir_emitter_context)536 StatusOr<std::unique_ptr<IrEmitterUnnested>> IrEmitterUnnested::Create(
537     const HloModuleConfig& hlo_module_config,
538     const HloComputation* hlo_computation,
539     IrEmitterContext* ir_emitter_context) {
540   auto emitter = std::unique_ptr<IrEmitterUnnested>(new IrEmitterUnnested(
541       hlo_module_config, hlo_computation, ir_emitter_context));
542   if (hlo_computation) {
543     emitter->mlir_scratch_module_.emplace(mlir::ModuleOp::create(
544         mlir::Builder(ir_emitter_context->mlir_context()).getUnknownLoc()));
545     emitter->lhlo_scratch_emitter_.emplace(
546         emitter->ir_emitter_context_->buffer_assignment(), *hlo_computation,
547         emitter->mlir_scratch_module_->get());
548     TF_RETURN_IF_ERROR(emitter->lhlo_scratch_emitter_->Initialize());
549   }
550   return std::move(emitter);
551 }
552 
Postprocess(HloInstruction * hlo)553 Status IrEmitterUnnested::Postprocess(HloInstruction* hlo) {
554   bindings_.UnbindAllLocalIrValues();
555   return DfsHloVisitor::Postprocess(hlo);
556 }
557 
BuildKernelPrototype(absl::string_view name,absl::Span<const BufferAllocation * const> args)558 llvm::Function* IrEmitterUnnested::BuildKernelPrototype(
559     absl::string_view name, absl::Span<const BufferAllocation* const> args) {
560   // Compute the kernel name. The opcode string may contain "-" which cannot be
561   // in a PTX function name, so sanitize the name before uniquifying it.
562   string kernel_name = ir_emitter_context_->name_uniquer()->GetUniqueName(
563       llvm_ir::SanitizeFunctionName(std::string(name)));
564 
565   // Create the kernel and add it to the module.
566   llvm::Module* module = ir_emitter_context_->llvm_module();
567   llvm::LLVMContext& context = module->getContext();
568   llvm::FunctionType* kernel_type = llvm::FunctionType::get(
569       /*Result=*/llvm::Type::getVoidTy(context),
570       std::vector<llvm::Type*>(args.size(), b_.getInt8PtrTy()),
571       /*isVarArg=*/false);
572   llvm::Function* kernel =
573       llvm::Function::Create(kernel_type, llvm::GlobalValue::ExternalLinkage,
574                              kernel_name.c_str(), module);
575 
576   // Add dereferenceable and alignment information to each of the kernel's
577   // parameters.
578   auto arg_it = kernel->arg_begin();
579   for (size_t arg_no = 0; arg_no < args.size(); ++arg_no) {
580     const BufferAllocation* alloc = args[arg_no];
581     llvm::Argument* fn_arg = &*arg_it;
582     ++arg_it;
583 
584     kernel->addDereferenceableAttr(arg_no + 1, alloc->size());
585 
586     const int64 alignment = [&] {
587       if (alloc->is_entry_computation_parameter()) {
588         return kEntryParameterAlignBytes;
589       } else if (alloc->is_constant()) {
590         return kConstantBufferAlignBytes;
591       } else {
592         return kXlaAllocatedBufferAlignBytes;
593       }
594     }();
595 
596     kernel->addParamAttr(
597         arg_no,
598         llvm::Attribute::get(context, llvm::Attribute::Alignment, alignment));
599 
600     if (alloc->IsPreallocatedTempBuffer()) {
601       fn_arg->setName("temp_buf");
602     } else {
603       fn_arg->setName(StrCat("alloc", alloc->index()));
604     }
605   }
606 
607   AnnotateFunctionAsGpuKernel(module, kernel, &b_);
608 
609   // TODO(b/65380986): Investigate if adding fast math flags for generated
610   // kernels makes sense.
611 
612   // Update the insert point to the entry basic block.
613   llvm::BasicBlock* entry_bb =
614       llvm::BasicBlock::Create(context, /*Name=*/"entry", /*Parent=*/kernel);
615 
616   // Emit a "return void" at entry_bb's end, and set the insert point before
617   // that return instruction.
618   b_.SetInsertPoint(llvm::ReturnInst::Create(context, entry_bb));
619 
620   return kernel;
621 }
622 
GetAllocationSliceForMlir(mlir::Value v)623 StatusOr<BufferAllocation::Slice> IrEmitterUnnested::GetAllocationSliceForMlir(
624     mlir::Value v) {
625   return xla::gpu::GetAllocationSliceForMlir(
626       v, ir_emitter_context_->allocations());
627 }
628 
DefaultAction(HloInstruction * hlo)629 Status IrEmitterUnnested::DefaultAction(HloInstruction* hlo) {
630   TF_ASSIGN_OR_RETURN(auto input, GetMlirEmitterInput(hlo));
631   return EmitUsingElementalIrEmitter(input);
632 }
633 
EmitUsingElementalIrEmitter(MlirEmitterInput input)634 Status IrEmitterUnnested::EmitUsingElementalIrEmitter(MlirEmitterInput input) {
635   // Replace unnested op with a fused nested op.
636   //
637   // TODO(timshen): Ultimately this should be a pass. It's currently not a pass,
638   // because we don't have a fully functioning LMHLO graph yet.
639 
640   mlir::Location loc = input.op->getLoc();
641   mlir::lmhlo::FusionOp fusion =
642       mlir::OpBuilder(input.op).create<mlir::lmhlo::FusionOp>(loc);
643   Shape output_shape;
644   mlir::OpBuilder b(&fusion.region());
645 
646   const auto load_memrefs = [loc, &b](mlir::ValueRange range) {
647     std::vector<mlir::Value> operands;
648     for (mlir::Value memref : range) {
649       auto load = b.create<mlir::TensorLoadOp>(loc, memref);
650       HloFunctionImporter::SetLayoutForMlir(load,
651                                             TypeToShape(memref.getType()));
652       operands.push_back(load);
653     }
654     return operands;
655   };
656 
657   if (auto copy = mlir::dyn_cast<mlir::lmhlo::CopyOp>(input.op)) {
658     auto operand = b.create<mlir::TensorLoadOp>(loc, copy.operand());
659     HloFunctionImporter::SetLayoutForMlir(
660         operand, TypeToShape(copy.operand().getType()));
661     auto fused_copy = b.create<mlir::mhlo::CopyOp>(loc, operand);
662     output_shape = TypeToShape(copy.output().getType());
663     HloFunctionImporter::SetLayoutForMlir(fused_copy, output_shape);
664     b.create<mlir::TensorStoreOp>(loc, fused_copy, copy.output());
665   } else if (auto reduce = mlir::dyn_cast<mlir::lmhlo::ReduceOp>(input.op)) {
666     std::vector<mlir::Value> operands = load_memrefs(reduce.operands());
667     std::vector<mlir::Value> init_values = load_memrefs(reduce.init_values());
668     auto fused_reduce = b.create<mlir::mhlo::ReduceOp>(
669         loc, operands, init_values, reduce.dimensions());
670     fused_reduce.body().takeBody(reduce.body());
671     CHECK_EQ(fused_reduce.getNumResults(), reduce.out().size());
672     std::vector<Shape> output_shapes;
673     for (int i = 0; i < reduce.out().size(); i++) {
674       b.create<mlir::TensorStoreOp>(loc, fused_reduce.getResult(i),
675                                     reduce.out()[i]);
676       auto shape = TypeToShape(reduce.out()[i].getType());
677       if (i == 0) {
678         HloFunctionImporter::SetLayoutForMlir(fused_reduce, shape);
679       }
680       output_shapes.push_back(shape);
681     }
682     if (output_shapes.size() == 1) {
683       output_shape = output_shapes[0];
684     } else {
685       output_shape = ShapeUtil::MakeTupleShape(output_shapes);
686     }
687   } else {
688     // Try to generically convert any LMHLO ops to LMHLO fusion + the
689     // corresponding MHLO op. Currently we've only looked at elementwise ops and
690     // they seem to be well covered.
691     //
692     // TODO(timshen): Moving forward, we should make it cover all ops if
693     // possible, and only special-case the ones it can't.
694     std::vector<mlir::Value> outputs;
695     mlir::Operation* new_op;
696     {
697       auto operands = GetHloOperands(input.op);
698       outputs = GetHloOutputs(input.op);
699       TF_RET_CHECK(outputs.size() == 1) << MlirToString(input.op);
700 
701       std::vector<mlir::Value> loads = load_memrefs(operands);
702       std::string mhlo_op_name = mlir::hlo::LmhloToMhloOpName(
703           input.op->getName().getStringRef(), input.op->getContext());
704       TF_RET_CHECK(!mhlo_op_name.empty())
705           << "No corresponding MHLO op for given LMHLO op: "
706           << MlirToString(input.op);
707       mlir::OperationState op_state(loc, mhlo_op_name);
708 
709       mlir::BlockAndValueMapping mapper;
710       for (mlir::Region& region : input.op->getRegions()) {
711         mlir::Region* new_region = op_state.addRegion();
712         region.cloneInto(new_region, mapper);
713       }
714 
715       op_state.addOperands(loads);
716       op_state.addAttributes(input.op->getAttrs());
717       op_state.addTypes({mlir::RankedTensorType::get(
718           outputs[0].getType().cast<mlir::MemRefType>().getShape(),
719           outputs[0].getType().cast<mlir::MemRefType>().getElementType())});
720       new_op = b.createOperation(op_state);
721     }
722     TF_RET_CHECK(mlir::succeeded(mlir::verify(new_op)));
723     output_shape = TypeToShape(outputs[0].getType());
724     HloFunctionImporter::SetLayoutForMlir(new_op, output_shape);
725     b.create<mlir::TensorStoreOp>(loc, new_op->getResult(0), outputs[0]);
726   }
727   int unroll_factor = 1;
728   if (!MayPreventVectorization(input.op)) {
729     unroll_factor = ComputeMaxUnrollFactor(input.op, hlo_module_config_);
730   }
731   input.op->erase();
732   input.op = fusion;
733   return EmitLoopFusionFromMlir(input, output_shape, unroll_factor);
734 }
735 
HandleConstant(HloInstruction * constant)736 Status IrEmitterUnnested::HandleConstant(HloInstruction* constant) {
737   TF_ASSIGN_OR_RETURN(auto input, GetMlirEmitterInput(constant));
738   return EmitConstant(input);
739 }
740 
EmitConstant(MlirEmitterInput mlir_input)741 Status IrEmitterUnnested::EmitConstant(MlirEmitterInput mlir_input) {
742   auto get_global = mlir::cast<mlir::GetGlobalMemrefOp>(mlir_input.op);
743   auto module = get_global->getParentOfType<mlir::ModuleOp>();
744   auto global =
745       mlir::cast<mlir::GlobalMemrefOp>(module.lookupSymbol(get_global.name()));
746 
747   auto literal = global.initial_value()->dyn_cast<mlir::DenseElementsAttr>();
748   TF_RET_CHECK(literal);
749 
750   const bool should_emit_initializer = literal.getType().getNumElements() <= 1;
751 
752   TF_ASSIGN_OR_RETURN(int element_bytes,
753                       GetElementTypeBytes(literal.getType().getElementType()));
754   llvm::ArrayType* global_type = llvm::ArrayType::get(
755       b_.getInt8Ty(), literal.getType().getNumElements() * element_bytes);
756 
757   GpuExecutable::ConstantInfo info;
758   llvm::Constant* initializer;
759   if (should_emit_initializer) {
760     std::vector<uint8> content;
761     TF_RETURN_IF_ERROR(CopyDenseElementsDataToXlaFormat(literal, &content));
762     initializer = llvm::ConstantDataArray::get<uint8>(
763         ir_emitter_context_->llvm_module()->getContext(), content);
764   } else {
765     TF_RETURN_IF_ERROR(
766         CopyDenseElementsDataToXlaFormat(literal, &info.content));
767     initializer = llvm::ConstantAggregateZero::get(global_type);
768   }
769 
770   // These globals will be looked up by name by GpuExecutable so we need to
771   // give them an external linkage.  Not all of their uses are visible in
772   // the LLVM IR so we can't give then a linkage that merely preserves their
773   // names (like available_externally), we also need to ensure that they stick
774   // around even if they're "unused".
775   //
776   // We may have to be more clever here in the future if we notice that we're
777   // keeping around too many globals because of their linkage.
778   unsigned global_address_space =
779       llvm_ir::GetGlobalMemoryAddressSpace(*ir_emitter_context_->llvm_module());
780 
781   llvm::GlobalVariable* global_for_const = new llvm::GlobalVariable(
782       global_type, /*isConstant=*/should_emit_initializer,
783       llvm::GlobalValue::ExternalLinkage,
784       /*Initializer=*/initializer, global.sym_name(),
785       /*TLMode=*/llvm::GlobalValue::NotThreadLocal,
786       /*AddressSpace=*/global_address_space,
787       /*isExternallyInitialized=*/false);
788   global_for_const->setAlignment(llvm::Align(kConstantBufferAlignBytes));
789   ir_emitter_context_->llvm_module()->getGlobalList().push_back(
790       global_for_const);
791 
792   info.symbol_name.assign(global.sym_name().begin(), global.sym_name().end());
793 
794   info.allocation_index =
795       global->getAttrOfType<mlir::IntegerAttr>("lmhlo.alloc").getInt();
796   ir_emitter_context_->constants().push_back(std::move(info));
797   return Status::OK();
798 }
799 
HandleConditional(HloInstruction * conditional)800 Status IrEmitterUnnested::HandleConditional(HloInstruction* conditional) {
801   TF_ASSIGN_OR_RETURN(auto thunk, BuildConditionalThunk(conditional));
802   AddThunkToThunkSequence(std::move(thunk));
803   return Status::OK();
804 }
805 
HandleConvolution(HloInstruction * convolution)806 Status IrEmitterUnnested::HandleConvolution(HloInstruction* convolution) {
807   AddThunkToThunkSequence(
808       BuildKernelThunk(convolution, /*implements_whole_instruction=*/true));
809   return IrEmitter::HandleConvolution(convolution);
810 }
811 
812 // Input = {dynamic array(with dynamic dimension meta data at the end)}
813 // Output = {static array, dynamic_dim0, dynamic_dim1}
EmitPadToStaticFromMlir(MlirEmitterInput mlir_input)814 Status IrEmitterUnnested::EmitPadToStaticFromMlir(MlirEmitterInput mlir_input) {
815   // TODO(jurahul): Create an op to represent PadToStatic.
816   auto pad_to_static = mlir::cast<mlir::lmhlo::CustomCallOp>(mlir_input.op);
817   int unroll_factor = 1;
818   std::string ir_name = mlir::GetNameFromLoc(pad_to_static.getLoc());
819 
820   std::vector<llvm_ir::IrArray> ir_arrays;
821   TF_ASSIGN_OR_RETURN(
822       auto kernel_thunk,
823       BuildKernelThunkForMlir(pad_to_static, mlir_input.thunk_info,
824                               mlir_input.extra_slice, &ir_arrays));
825 
826   const llvm_ir::IrArray source_array = ir_arrays[0];
827   const llvm_ir::IrArray output_array = ir_arrays[1];
828   auto output_dim_arrays =
829       absl::Span<const llvm_ir::IrArray>(ir_arrays).subspan(2);
830 
831   // pseudo code for PadToStatic on a 2d array
832   //   int* source_array = input[0];
833   //   int* dest_array = output[0];
834   const Shape& data_shape =
835       TypeToShape(pad_to_static.output().front().getType());
836   const Shape& input_shape =
837       TypeToShape(pad_to_static.args().front().getType());
838   llvm::Value* source_buffer = source_array.GetBasePointer();
839   llvm::Value* raw_buffer =
840       b_.CreateBitCast(source_buffer, b_.getInt8Ty()->getPointerTo());
841 
842   // TODO(jurahul): input_shape here is the static shape of the input (which has
843   // a dynamic shape in XLA). Currently, we are mapping that to a static shaped
844   // memref. When we change that to a more appropriate representation in MLIR,
845   // fix this code to correctly deduce the static shape backing the dynamically
846   // shaped memref.
847   int64 raw_data_size = ShapeUtil::ByteSizeOf(input_shape);
848 
849   //   int* dyn_dim0_size = source_array + meta_data_offset;
850   //   int* dyn_dim1_size = source_array + meta_data_offset + sizeof(int);
851   std::vector<llvm::Value*> dynamic_dims;
852   for (int64 i = 1; i < pad_to_static.output().size(); ++i) {
853     // Dynamic size of each dimension is attached at the end of the source
854     // array(operand(0)). We need to extract these value.
855     const Shape& dim_shape = TypeToShape(pad_to_static.output()[i].getType());
856     TF_RET_CHECK(Shape::Equal()(dim_shape, ShapeUtil::MakeScalarShape(S32)));
857 
858     const int64 dim_index = i - 1;
859     llvm::Value* metadata = b_.CreateConstInBoundsGEP1_32(
860         b_.getInt8Ty(), raw_buffer, raw_data_size + dim_index * sizeof(int32));
861     llvm::Value* dyn_dim_size = b_.CreateLoad(
862         b_.CreateBitCast(metadata, b_.getInt32Ty()->getPointerTo()),
863         "dyn_dim_size");
864     dynamic_dims.push_back(dyn_dim_size);
865   }
866 
867   // only one thread need to store the dynamic index
868   //   int thread_id = GetThreadId();
869   //   int block_id = GetBlockId();
870   //   if (thread_id == 0 && block_id == 0) {
871   //     *output[1] = *dyn_dim0_size;
872   //     *output[2] = *dyn_dim1_size;
873   //   }
874   KernelSupportLibrary{&b_}.If("is_thred_0", IsBlock0Thread0(&b_), [&] {
875     for (int64 i = 1; i < pad_to_static.output().size(); ++i) {
876       const int64 dim_index = i - 1;
877       llvm::Value* dest_dim_size_address =
878           output_dim_arrays[dim_index].GetBasePointer();
879       // output[i] stores dynamic_dim_(i-1)
880       b_.CreateStore(dynamic_dims[i - 1],
881                      b_.CreateBitCast(dest_dim_size_address,
882                                       b_.getInt32Ty()->getPointerTo()));
883     }
884   });
885 
886   //     int dyn_element_total = 1;
887   //     dyn_element_total *= *dyn_dim0_size;
888   //     dyn_element_total *= *dyn_dim1_size;
889   llvm::Value* dyn_element_total = llvm::ConstantInt::get(b_.getInt32Ty(), 1);
890   for (llvm::Value* dynamic_dim : dynamic_dims) {
891     dyn_element_total = b_.CreateMul(dyn_element_total, dynamic_dim,
892                                      /*Name=*/"dyn_element_total");
893   }
894 
895   //   linear_index = block_id * threads_per_block + thread_id;
896   //   if (linear_index < max_num_element) {
897   //     Index static_index =
898   //         delinerized(linerized_index, static_dim0_size, static_dim1_size);
899   //     if (linerized_index < dyn_element_total) {
900   //       Index dyn_index =
901   //           delinerized(linerized_index, *dyn_dim0_size, *dyn_dim1_size);
902   //       dest_array[dyn_index.dim0][dyn_index.dim1] =
903   //           source_array[static_index.dim0][static_index.dim1];
904   //     }
905   //   }
906   llvm_ir::LoopEmitter::BodyEmitter body_generator =
907       [&](const llvm_ir::IrArray::Index& array_index) -> Status {
908     llvm::Value* linearIndex =
909         array_index.Linearize(input_shape.dimensions(), &b_);
910     auto if_in_dyn_bounds = llvm_ir::EmitIfThenElse(
911         b_.CreateICmpULT(linearIndex, dyn_element_total),
912         llvm_ir::IrName(ir_name, "in_dyn_bounds"), &b_, false);
913     // Set IR builder insertion point to the body of the if structure.
914     llvm_ir::SetToFirstInsertPoint(if_in_dyn_bounds.true_block, &b_);
915     llvm_ir::IrArray::Index dyn_index(linearIndex, input_shape,
916                                       absl::MakeSpan(dynamic_dims), &b_);
917     output_array.EmitWriteArrayElement(
918         dyn_index,
919         source_array.EmitReadArrayElement(array_index, &b_, /*name=*/""), &b_,
920         /*use_linear_index=*/false);
921     return Status::OK();
922   };
923 
924   LaunchDimensions launch_dimensions = CalculateLaunchDimensions(
925       input_shape, ir_emitter_context_->gpu_device_info(), unroll_factor);
926   UpdateLaunchDimensions(launch_dimensions, kernel_thunk.get(),
927                          ir_emitter_context_->llvm_module());
928   TF_RETURN_IF_ERROR(
929       ParallelLoopEmitter(body_generator, data_shape, launch_dimensions, &b_,
930                           unroll_factor)
931           .EmitLoop(ir_name,
932                     GetIndexTypeForKernelFromMlir(
933                         pad_to_static, launch_dimensions.launch_bound(), &b_)));
934   thunk_sequence_.emplace_back(std::move(kernel_thunk));
935   return Status::OK();
936 }
937 
938 // Input = {dynamic array(with dynamic dimension meta data at the end)}
939 // Output = {static array, dynamic_dim0, dynamic_dim1}
EmitSliceToDynamicFromMlir(MlirEmitterInput mlir_input)940 Status IrEmitterUnnested::EmitSliceToDynamicFromMlir(
941     MlirEmitterInput mlir_input) {
942   // TODO(jurahul): Create an op to represent SliceToDynamic.
943   auto slice_to_dynamic = mlir::cast<mlir::lmhlo::CustomCallOp>(mlir_input.op);
944   int unroll_factor = 1;
945   std::string ir_name = mlir::GetNameFromLoc(slice_to_dynamic.getLoc());
946 
947   std::vector<llvm_ir::IrArray> ir_arrays;
948   TF_ASSIGN_OR_RETURN(
949       auto kernel_thunk,
950       BuildKernelThunkForMlir(slice_to_dynamic, mlir_input.thunk_info,
951                               mlir_input.extra_slice, &ir_arrays));
952 
953   const Shape& input_shape =
954       TypeToShape(slice_to_dynamic.args().front().getType());
955   TF_RET_CHECK(slice_to_dynamic.output().size() == 1);
956   const Shape& data_shape =
957       TypeToShape(slice_to_dynamic.output().front().getType());
958 
959   // TODO(jurahul): data_shape here is the static shape of the output (which has
960   // a dynamic shape in XLA). Currently, we are mapping that to a static shaped
961   // memref. When we change that to a more appropriate representation in MLIR,
962   // fix this code to correctly deduce the static shape backing the dynamically
963   // shaped memref.
964 
965   // calculate the location where metadata needs to be inserted
966   //   int* dyn_dim0_size = dest_array + meta_data_offset;
967   //   int* dyn_dim1_size = dest_array + meta_data_offset + sizeof(int);
968   int32 raw_data_size = ShapeUtil::ByteSizeOf(data_shape);
969 
970   // pseudo code for sliceToDynamic on a 2d array
971   //   int* source_array = input[0];
972   //   int* dest_array = output[0];
973   const llvm_ir::IrArray data_array = ir_arrays.back();
974   llvm::Value* dest_buffer = data_array.GetBasePointer();
975   llvm::Value* raw_buffer =
976       b_.CreateBitCast(dest_buffer, b_.getInt8Ty()->getPointerTo());
977 
978   // Load dynamic dimensions from memory.
979   std::vector<llvm::Value*> dynamic_dims;
980   for (int64 i = 1; i < slice_to_dynamic.args().size(); ++i) {
981     // const int64 dim_index = i - 1;
982     llvm::Value* source_buffer = ir_arrays[i].GetBasePointer();
983     llvm::LoadInst* dyn_dim_size = b_.CreateLoad(source_buffer, "dyn_dim_size");
984     dynamic_dims.push_back(dyn_dim_size);
985   }
986 
987   // only one thread need to store the dynamic index
988   //   int thread_id = GetThreadId();
989   //   int block_id = GetBlockId();
990   //   if (thread_id == 0 && block_id == 0) {
991   //     *dyn_dim0_size = *output[1];
992   //     *dyn_dim1_size = *output[2];
993   //   }
994   KernelSupportLibrary{&b_}.If("is_thred_0", IsBlock0Thread0(&b_), [&] {
995     for (int64 i = 1; i < slice_to_dynamic.args().size(); ++i) {
996       const int64 dim_index = i - 1;
997       llvm::Value* metadata = b_.CreateConstInBoundsGEP1_32(
998           b_.getInt8Ty(), raw_buffer,
999           raw_data_size + dim_index * sizeof(int32));
1000       // output[i] stores dynamic_dim_(i-1)
1001       b_.CreateStore(
1002           dynamic_dims[dim_index],
1003           b_.CreateBitCast(metadata, b_.getInt32Ty()->getPointerTo()));
1004     }
1005   });
1006 
1007   //     int dyn_element_total = 1;
1008   //     dyn_element_total *= dyn_dim0_size;
1009   //     dyn_element_total *= dyn_dim1_size;
1010   llvm::Value* dyn_element_total = llvm::ConstantInt::get(b_.getInt32Ty(), 1);
1011   for (llvm::Value* dynamic_dim : dynamic_dims) {
1012     dyn_element_total = b_.CreateMul(dyn_element_total, dynamic_dim,
1013                                      /*Name=*/"dyn_element_total");
1014   }
1015 
1016   //   linear_index = block_id * threads_per_block + thread_id;
1017   //   if (linear_index < max_num_element) {
1018   //     Index static_index =
1019   //         delinerized(linerized_index, static_dim0_size, static_dim1_size);
1020   //     if (linerized_index < dyn_element_total) {
1021   //       Index dyn_index =
1022   //           delinerized(linerized_index, *dyn_dim0_size, *dyn_dim1_size);
1023   //       dest_array[static_index.dim0][static_index.di] =
1024   //           source_array[dyn_index.dim0][dyn_index.dim1];
1025   //     }
1026   //   }
1027   llvm_ir::LoopEmitter::BodyEmitter body_generator =
1028       [&](const llvm_ir::IrArray::Index& array_index) -> Status {
1029     llvm::Value* linearIndex =
1030         array_index.Linearize(input_shape.dimensions(), &b_);
1031     auto if_in_dyn_bounds = llvm_ir::EmitIfThenElse(
1032         b_.CreateICmpULT(linearIndex, dyn_element_total),
1033         llvm_ir::IrName(ir_name, "in_dyn_bounds"), &b_, false);
1034     // Set IR builder insertion point to the body of the if structure.
1035     llvm_ir::SetToFirstInsertPoint(if_in_dyn_bounds.true_block, &b_);
1036     llvm_ir::IrArray::Index dyn_index(linearIndex, input_shape,
1037                                       absl::MakeSpan(dynamic_dims), &b_);
1038 
1039     data_array.EmitWriteArrayElement(
1040         array_index,
1041         ir_arrays[0].EmitReadArrayElement(dyn_index, &b_, /*name=*/"",
1042                                           /*use_linear_index=*/false),
1043         &b_);
1044     return Status::OK();
1045   };
1046 
1047   LaunchDimensions launch_dimensions = CalculateLaunchDimensions(
1048       input_shape, ir_emitter_context_->gpu_device_info(), unroll_factor);
1049   UpdateLaunchDimensions(launch_dimensions, kernel_thunk.get(),
1050                          ir_emitter_context_->llvm_module());
1051 
1052   TF_RETURN_IF_ERROR(
1053       ParallelLoopEmitter(body_generator, data_shape, launch_dimensions, &b_,
1054                           unroll_factor)
1055           .EmitLoop(ir_name, GetIndexTypeForKernelFromMlir(
1056                                  slice_to_dynamic,
1057                                  launch_dimensions.launch_bound(), &b_)));
1058   thunk_sequence_.emplace_back(std::move(kernel_thunk));
1059   return Status::OK();
1060 }
1061 
HandleCustomCall(HloInstruction * custom_call)1062 Status IrEmitterUnnested::HandleCustomCall(HloInstruction* custom_call) {
1063   TF_ASSIGN_OR_RETURN(auto input, GetMlirEmitterInput(custom_call));
1064   return EmitCustomCallFromMlir(input);
1065 }
1066 
EmitCustomCallFromMlir(MlirEmitterInput input)1067 Status IrEmitterUnnested::EmitCustomCallFromMlir(MlirEmitterInput input) {
1068   using mlir::dyn_cast;
1069   using mlir::isa;
1070 
1071   if (auto call = dyn_cast<mlir::lmhlo::CustomCallOp>(input.op)) {
1072     if (call.call_target_name() == "PadToStatic") {
1073       return EmitPadToStaticFromMlir(input);
1074     }
1075     if (call.call_target_name() == "SliceToDynamic") {
1076       return EmitSliceToDynamicFromMlir(input);
1077     }
1078     return EmitCustomCallThunkFromMlir(input);
1079   }
1080 
1081   if (isa<mlir::lmhlo_gpu::GEMMOp, mlir::lmhlo_gpu::GEMM_BiasOp>(input.op)) {
1082     return EmitGemmThunkFromMlir(input);
1083   }
1084 
1085   if (mlir::isa<mlir::lmhlo_gpu::ConvForwardOp,
1086                 mlir::lmhlo_gpu::ConvForwardFusedOp,
1087                 mlir::lmhlo_gpu::ConvForwardFusedSideInputOp,
1088                 mlir::lmhlo_gpu::ConvBackwardFilterOp,
1089                 mlir::lmhlo_gpu::ConvBackwardInputOp>(input.op)) {
1090     return EmitConvolutionThunkFromMlir(input);
1091   }
1092 
1093   if (isa<mlir::lmhlo_gpu::BatchNormTrainingOp,
1094           mlir::lmhlo_gpu::BatchNormInferenceOp,
1095           mlir::lmhlo_gpu::BatchNormGradOp>(input.op)) {
1096     return EmitBatchNormThunkFromMlir(input);
1097   }
1098 
1099 #if GOOGLE_CUDA
1100   if (mlir::isa<mlir::lmhlo_gpu::CholeskyOp>(input.op)) {
1101     return EmitCholeskyThunkFromMlir(input);
1102   }
1103 #endif  // GOOGLE_CUDA
1104 
1105   return Unimplemented("No registered implementation for custom call to \"%s\"",
1106                        MlirToString(input.op));
1107 }
1108 
// Emits a ConvolutionThunk for one of the lmhlo_gpu convolution ops
// (ConvForward, ConvBackwardInput, ConvBackwardFilter, ConvForwardFused,
// ConvForwardFusedSideInput).
//
// The op's operands are {inputs..., conv result, scratch}. A GpuConvDescriptor
// is filled from the input/result shapes (re-laid-out using the layouts stored
// in the op's backend_config), the window attributes and the backend config.
// It is then turned into a GpuConvConfig via GetGpuConvConfig and wrapped in a
// ConvolutionThunk. Any other op yields an InternalError.
Status IrEmitterUnnested::EmitConvolutionThunkFromMlir(MlirEmitterInput input) {
  using mlir::dyn_cast;
  using mlir::lmhlo_gpu::Activation;
  using mlir::lmhlo_gpu::ConvBackwardFilterOp;
  using mlir::lmhlo_gpu::ConvBackwardInputOp;
  using mlir::lmhlo_gpu::ConvForwardFusedOp;
  using mlir::lmhlo_gpu::ConvForwardFusedSideInputOp;
  using mlir::lmhlo_gpu::ConvForwardOp;

  // Last 2 operands of the convolution operation are the result and scratch.
  std::vector<BufferAllocation::Slice> operand_slices;
  int64 num_operands = input.op->getNumOperands();
  operand_slices.reserve(num_operands - 2);
  for (mlir::Value operand : input.op->getOperands().drop_back(2)) {
    TF_ASSIGN_OR_RETURN(auto slice, GetAllocationSliceForMlir(operand));
    operand_slices.push_back(slice);
  }

  mlir::Value conv_result = input.op->getOperand(num_operands - 2);
  mlir::Value scratch_result = input.op->getOperand(num_operands - 1);
  TF_ASSIGN_OR_RETURN(auto conv_result_slice,
                      GetAllocationSliceForMlir(conv_result));
  TF_ASSIGN_OR_RETURN(auto scratch_slice,
                      GetAllocationSliceForMlir(scratch_result));

  // Rebuilds `shape` with the layout given as a minor-to-major list of
  // integer attributes.
  auto apply_layout = [](const Shape& shape, mlir::ArrayAttr layout_attrib) {
    mlir::SmallVector<int64, 4> minor_to_major = llvm::to_vector<4>(
        llvm::map_range(layout_attrib, [](mlir::Attribute a) -> int64 {
          return static_cast<int64>(a.cast<mlir::IntegerAttr>().getInt());
        }));
    return ShapeUtil::MakeShapeWithLayout(shape.element_type(),
                                          shape.dimensions(), minor_to_major);
  };

  GpuConvDescriptor descriptor;

  // Fills the descriptor fields common to every convolution kind. Generic so
  // it can be applied to each concrete conv op type.
  auto fill_conv_descriptor = [&](auto op) {
    descriptor.operand0_shape =
        apply_layout(TypeToShape(input.op->getOperand(0).getType()),
                     op.backend_config().operand_0_layout());
    descriptor.operand1_shape =
        apply_layout(TypeToShape(input.op->getOperand(1).getType()),
                     op.backend_config().operand_1_layout());
    descriptor.result_shape = apply_layout(TypeToShape(conv_result.getType()),
                                           op.backend_config().result_layout());
    descriptor.dnums = ConvertConvDimensionNumbers(op.dimension_numbers());
    // NOTE(review): assumes input.extra_slice is set and that element 1 of its
    // tuple shape is the 1-D scratch buffer — confirm.
    descriptor.scratch_size =
        input.extra_slice->shape.tuple_shapes(1).dimensions(0);
    mlir::DenseIntElementsAttr window_strides = op.window_strides().getValue();
    mlir::DenseIntElementsAttr padding = op.padding().getValue();
    mlir::DenseIntElementsAttr lhs_dilation = op.lhs_dilation().getValue();
    mlir::DenseIntElementsAttr rhs_dilation = op.rhs_dilation().getValue();
    mlir::DenseElementsAttr window_reversal = op.window_reversal().getValue();
    // One WindowDimension per spatial dimension (one window_strides entry each).
    for (auto index : llvm::seq<int>(0, window_strides.getNumElements())) {
      WindowDimension* dim = descriptor.window.add_dimensions();
      // Window size for a convolution is the same as the kernel size.
      // Kernel size of the convolution is operand1_shape. We need to look at
      // the convolution dimension numbers kernel spatial dimensions to get
      // the window size.
      int kernel_dim = descriptor.dnums.kernel_spatial_dimensions(index);
      // NOTE(review): the comment above names operand1_shape as the kernel,
      // but the size below is read from operand0_shape. Confirm which shape
      // kernel_spatial_dimensions is meant to index here.
      dim->set_size(descriptor.operand0_shape.dimensions(kernel_dim));
      dim->set_stride(window_strides.getValue<int64>(index));
      // Low and high padding both read the same per-dimension entry;
      // presumably the attribute stores one symmetric padding value per
      // spatial dimension — TODO confirm.
      dim->set_padding_low(padding.getValue<int64>(index));
      dim->set_padding_high(padding.getValue<int64>(index));
      dim->set_base_dilation(lhs_dilation.getValue<int64>(index));
      dim->set_window_dilation(rhs_dilation.getValue<int64>(index));
      dim->set_window_reversal(window_reversal.getValue<bool>(index));
    }
    descriptor.feature_group_count = op.feature_group_count();
    descriptor.backend_config.set_algorithm(
        op.backend_config().algorithm().getInt());
    descriptor.backend_config.set_tensor_ops_enabled(
        op.backend_config().tensor_ops_enabled().getValue());
    descriptor.backend_config.set_conv_result_scale(
        op.result_scale().convertToDouble());
  };

  // Converts the fused op's activation_mode attribute and stores it in the
  // backend config; fails if the conversion fails.
  auto set_activation_mode = [&](auto op) -> Status {
    TF_ASSIGN_OR_RETURN(stream_executor::dnn::ActivationMode activation_mode,
                        ConvertConvActivationMode(op.activation_mode()));
    descriptor.backend_config.set_activation_mode(
        static_cast<int64>(activation_mode));
    return Status::OK();
  };

  // Both fused variants map to kForwardActivation; the side-input variant
  // additionally records side_input_scale.
  if (auto op = dyn_cast<ConvForwardOp>(input.op)) {
    descriptor.kind = CudnnConvKind::kForward;
    fill_conv_descriptor(op);
  } else if (auto op = dyn_cast<ConvBackwardInputOp>(input.op)) {
    descriptor.kind = CudnnConvKind::kBackwardInput;
    fill_conv_descriptor(op);
  } else if (auto op = dyn_cast<ConvBackwardFilterOp>(input.op)) {
    descriptor.kind = CudnnConvKind::kBackwardFilter;
    fill_conv_descriptor(op);
  } else if (auto op = dyn_cast<ConvForwardFusedOp>(input.op)) {
    descriptor.kind = CudnnConvKind::kForwardActivation;
    fill_conv_descriptor(op);
    TF_RETURN_IF_ERROR(set_activation_mode(op));
  } else if (auto op = dyn_cast<ConvForwardFusedSideInputOp>(input.op)) {
    descriptor.kind = CudnnConvKind::kForwardActivation;
    fill_conv_descriptor(op);
    TF_RETURN_IF_ERROR(set_activation_mode(op));
    descriptor.backend_config.set_side_input_scale(
        op.side_input_scale().convertToDouble());
  } else {
    return InternalError("Unexpected operation");
  }
  TF_ASSIGN_OR_RETURN(GpuConvConfig config, GetGpuConvConfig(descriptor, ""));
  AddThunkToThunkSequence(absl::make_unique<ConvolutionThunk>(
      input.thunk_info, std::move(config), std::move(operand_slices),
      conv_result_slice, scratch_slice));
  return Status::OK();
}
1222 
EmitGemmThunkFromMlir(MlirEmitterInput input)1223 Status IrEmitterUnnested::EmitGemmThunkFromMlir(MlirEmitterInput input) {
1224   auto build_gemm_config = [](auto op) {
1225     GpuGemmConfig config;
1226     GemmBackendConfig& backend = config.backend_config;
1227     config.output_shape = TypeToShape(op.output().getType());
1228     config.lhs_shape = TypeToShape(op.lhs().getType());
1229     config.rhs_shape = TypeToShape(op.rhs().getType());
1230     backend.Clear();
1231     if (op.algorithm()) {
1232       backend.set_selected_algorithm(*op.algorithm());
1233     }
1234     backend.set_alpha_real(op.alpha_real().convertToDouble());
1235     backend.set_alpha_imag(op.alpha_imag().convertToDouble());
1236     backend.set_batch_size(op.batch_size());
1237 
1238     auto& dims = *backend.mutable_dot_dimension_numbers();
1239     auto mlir_dims = op.dot_dimension_numbers();
1240 
1241     auto fill_dims = [](mlir::DenseElementsAttr mlir_dim, auto* config_attrs) {
1242       for (llvm::APInt e : mlir_dim.getIntValues())
1243         config_attrs->Add(e.getSExtValue());
1244     };
1245     fill_dims(mlir_dims.lhs_batching_dimensions(),
1246               dims.mutable_lhs_batch_dimensions());
1247     fill_dims(mlir_dims.rhs_batching_dimensions(),
1248               dims.mutable_rhs_batch_dimensions());
1249     fill_dims(mlir_dims.lhs_contracting_dimensions(),
1250               dims.mutable_lhs_contracting_dimensions());
1251     fill_dims(mlir_dims.rhs_contracting_dimensions(),
1252               dims.mutable_rhs_contracting_dimensions());
1253     return config;
1254   };
1255 
1256   GpuGemmConfig config;
1257   BufferAllocation::Slice lhs, rhs, bias, output;
1258 
1259   if (auto gemm = mlir::dyn_cast<mlir::lmhlo_gpu::GEMMOp>(input.op)) {
1260     config = build_gemm_config(gemm);
1261     TF_ASSIGN_OR_RETURN(lhs, GetAllocationSliceForMlir(gemm.lhs()));
1262     TF_ASSIGN_OR_RETURN(rhs, GetAllocationSliceForMlir(gemm.rhs()));
1263     TF_ASSIGN_OR_RETURN(output, GetAllocationSliceForMlir(gemm.output()));
1264   } else if (auto gemm_bias =
1265                  mlir::dyn_cast<mlir::lmhlo_gpu::GEMM_BiasOp>(input.op)) {
1266     config = build_gemm_config(gemm_bias);
1267     config.backend_config.set_beta(gemm_bias.beta().convertToDouble());
1268     TF_ASSIGN_OR_RETURN(lhs, GetAllocationSliceForMlir(gemm_bias.lhs()));
1269     TF_ASSIGN_OR_RETURN(rhs, GetAllocationSliceForMlir(gemm_bias.rhs()));
1270     TF_ASSIGN_OR_RETURN(bias, GetAllocationSliceForMlir(gemm_bias.bias()));
1271     TF_ASSIGN_OR_RETURN(output, GetAllocationSliceForMlir(gemm_bias.output()));
1272 
1273     // The bias is passed inside the output buffer. If those buffers are shared
1274     // we can just use it, otherwise copy the bias values into the output buffer
1275     // first.
1276     if (bias != output) {
1277       std::vector<std::unique_ptr<Thunk>> thunks;
1278 
1279       thunks.push_back(absl::make_unique<DeviceToDeviceCopyThunk>(
1280           Thunk::ThunkInfo(),
1281           /*source_buffer=*/bias,
1282           /*destination_buffer=*/output,
1283           /*mem_size=*/ShapeUtil::ByteSizeOf(config.output_shape)));
1284       thunks.push_back(absl::make_unique<GemmThunk>(
1285           input.thunk_info, std::move(config), lhs, rhs, output,
1286           /*implements_whole_instruction=*/false));
1287       AddThunkToThunkSequence(absl::make_unique<SequentialThunk>(
1288           input.thunk_info, std::move(thunks)));
1289       return Status::OK();
1290     }
1291   }
1292 
1293   AddThunkToThunkSequence(absl::make_unique<GemmThunk>(
1294       input.thunk_info, std::move(config), lhs, rhs, output,
1295       /*implements_whole_instruction=*/true));
1296   return Status::OK();
1297 }
1298 
1299 namespace {
1300 // An MLIR value and its name as defined in the ODS spec.
1301 struct NamedValue {
1302   mlir::Value value;
1303   absl::string_view name;
1304 };
1305 
1306 // Verifies that the given batch norm is well formed for thunk emission. This
1307 // requires that all statistics operands (mean, stddev etc) are F32 types and
1308 // all the non-statistics operands need to match in shape, element type, and
1309 // layout (which maps to them having the same memref type).
VerifyBatchNormForThunkEmission(mlir::ArrayRef<NamedValue> statistics_operands,mlir::ArrayRef<NamedValue> other_operands)1310 Status VerifyBatchNormForThunkEmission(
1311     mlir::ArrayRef<NamedValue> statistics_operands,
1312     mlir::ArrayRef<NamedValue> other_operands) {
1313   for (const NamedValue& v : statistics_operands) {
1314     // Note: MLIR verification will ensure that the operands of the batchnorm
1315     // LHLO are valid memref types.
1316     if (!v.value.getType().cast<mlir::MemRefType>().getElementType().isF32()) {
1317       return Unimplemented("Operand %s of batch norm should have F32 type",
1318                            v.name);
1319     }
1320   }
1321   if (other_operands.empty()) {
1322     return Status::OK();
1323   }
1324 
1325   mlir::Type first_type = other_operands.front().value.getType();
1326   absl::string_view first_name = other_operands.front().name;
1327 
1328   for (const NamedValue& v : other_operands.drop_front(1)) {
1329     if (v.value.getType() != first_type) {
1330       return Unimplemented("%s and %s for batch norm should have same types",
1331                            v.name, first_name);
1332     }
1333   }
1334 
1335   return Status::OK();
1336 }
1337 }  // namespace
1338 
EmitBatchNormThunkFromMlir(MlirEmitterInput input)1339 Status IrEmitterUnnested::EmitBatchNormThunkFromMlir(MlirEmitterInput input) {
1340   auto get_batch_norm_config = [](auto op, mlir::Value output) {
1341     CudnnBatchNormConfig config;
1342     config.output_shape = TypeToShape(output.getType());
1343     config.output_type = config.output_shape.element_type();
1344     config.epsilon = op.epsilon().convertToFloat();
1345     config.feature_index = op.feature_index();
1346     return config;
1347   };
1348 
1349   // The statistics operands for batch norm operations need to be FP32 type.
1350   // And the rest of the operands need match in shape, layout, and element type
1351   // to match.
1352   if (auto bn_train =
1353           mlir::dyn_cast<mlir::lmhlo_gpu::BatchNormTrainingOp>(input.op)) {
1354     TF_RETURN_IF_ERROR(VerifyBatchNormForThunkEmission(
1355         /*statistics_operands=*/
1356         {{bn_train.scale(), "scale"},
1357          {bn_train.offset(), "offset"},
1358          {bn_train.batch_mean(), "batch_mean"},
1359          {bn_train.batch_stddev(), "batch_stddev"}},
1360         /*other_operands=*/
1361         {{bn_train.operand(), "operand"}, {bn_train.output(), "output"}}));
1362     TF_ASSIGN_OR_RETURN(auto operand,
1363                         GetAllocationSliceForMlir(bn_train.operand()));
1364     TF_ASSIGN_OR_RETURN(auto scale,
1365                         GetAllocationSliceForMlir(bn_train.scale()));
1366     TF_ASSIGN_OR_RETURN(auto offset,
1367                         GetAllocationSliceForMlir(bn_train.offset()));
1368 
1369     // BatchNormTraining returns a tuple of three elements: data, calculated
1370     // mean, and calculated 1/sqrt(variance + epsilon).
1371     TF_ASSIGN_OR_RETURN(auto output_data,
1372                         GetAllocationSliceForMlir(bn_train.output()));
1373     TF_ASSIGN_OR_RETURN(auto output_mean,
1374                         GetAllocationSliceForMlir(bn_train.batch_mean()));
1375     TF_ASSIGN_OR_RETURN(auto output_inv_stddev,
1376                         GetAllocationSliceForMlir(bn_train.batch_stddev()));
1377 
1378     AddThunkToThunkSequence(
1379         absl::make_unique<CudnnBatchNormForwardTrainingThunk>(
1380             input.thunk_info,
1381             /*config=*/get_batch_norm_config(bn_train, bn_train.output()),
1382             /*operand=*/operand,
1383             /*scale=*/scale,
1384             /*offset=*/offset,
1385             /*output_data=*/output_data,
1386             /*output_mean=*/output_mean,
1387             /*output_inv_stddev=*/output_inv_stddev));
1388     return Status::OK();
1389   }
1390 
1391   if (auto bn_grad =
1392           mlir::dyn_cast<mlir::lmhlo_gpu::BatchNormGradOp>(input.op)) {
1393     TF_RETURN_IF_ERROR(VerifyBatchNormForThunkEmission(
1394         /*statistics_operands=*/
1395         {{bn_grad.scale(), "scale"},
1396          {bn_grad.mean(), "mean"},
1397          {bn_grad.stddev(), "stddev"},
1398          {bn_grad.grad_scale(), "grad_scale"},
1399          {bn_grad.grad_offset(), "grad_offset"}},
1400         /*other_operands=*/
1401         {{bn_grad.operand(), "operand"},
1402          {bn_grad.grad_output(), "grad_output"},
1403          {bn_grad.grad_operand(), "grad_operand"}}));
1404 
1405     TF_ASSIGN_OR_RETURN(auto operand,
1406                         GetAllocationSliceForMlir(bn_grad.operand()));
1407     TF_ASSIGN_OR_RETURN(auto scale, GetAllocationSliceForMlir(bn_grad.scale()));
1408     TF_ASSIGN_OR_RETURN(auto mean, GetAllocationSliceForMlir(bn_grad.mean()));
1409     TF_ASSIGN_OR_RETURN(auto inv_stddev,
1410                         GetAllocationSliceForMlir(bn_grad.stddev()));
1411     TF_ASSIGN_OR_RETURN(auto grad_output,
1412                         GetAllocationSliceForMlir(bn_grad.grad_output()));
1413 
1414     // BatchNormGrad returns a tuple of three elements: grad_data, grad_scale,
1415     // grad_offset.
1416     TF_ASSIGN_OR_RETURN(auto output_grad_data,
1417                         GetAllocationSliceForMlir(bn_grad.grad_operand()));
1418     TF_ASSIGN_OR_RETURN(auto output_grad_scale,
1419                         GetAllocationSliceForMlir(bn_grad.grad_scale()));
1420     TF_ASSIGN_OR_RETURN(auto output_grad_offset,
1421                         GetAllocationSliceForMlir(bn_grad.grad_offset()));
1422 
1423     CudnnBatchNormConfig config;
1424     config.output_shape = TypeToShape(bn_grad.grad_output().getType());
1425     config.output_type = config.output_shape.element_type();
1426     config.epsilon = bn_grad.epsilon().convertToFloat();
1427     config.feature_index = bn_grad.feature_index();
1428 
1429     AddThunkToThunkSequence(absl::make_unique<CudnnBatchNormBackwardThunk>(
1430         input.thunk_info,
1431         /*config=*/get_batch_norm_config(bn_grad, bn_grad.grad_output()),
1432         /*operand=*/operand,
1433         /*scale=*/scale,
1434         /*mean=*/mean,
1435         /*inv_stddev=*/inv_stddev,
1436         /*grad_output=*/grad_output,
1437         /*output_grad_data=*/output_grad_data,
1438         /*output_grad_scale=*/output_grad_scale,
1439         /*output_grad_offset=*/output_grad_offset));
1440     return Status::OK();
1441   }
1442 
1443   if (auto bn_inference =
1444           mlir::dyn_cast<mlir::lmhlo_gpu::BatchNormInferenceOp>(input.op)) {
1445     TF_RETURN_IF_ERROR(
1446         VerifyBatchNormForThunkEmission(/*statistics_operands=*/
1447                                         {{bn_inference.scale(), "scale"},
1448                                          {bn_inference.offset(), "offset"},
1449                                          {bn_inference.mean(), "mean"},
1450                                          {bn_inference.stddev(), "stddev"}},
1451                                         /*other_operands=*/
1452                                         {{bn_inference.operand(), "operand"},
1453                                          {bn_inference.output(), "output"}}));
1454 
1455     TF_ASSIGN_OR_RETURN(auto operand,
1456                         GetAllocationSliceForMlir(bn_inference.operand()));
1457     TF_ASSIGN_OR_RETURN(auto scale,
1458                         GetAllocationSliceForMlir(bn_inference.scale()));
1459     TF_ASSIGN_OR_RETURN(auto offset,
1460                         GetAllocationSliceForMlir(bn_inference.offset()));
1461     TF_ASSIGN_OR_RETURN(auto mean,
1462                         GetAllocationSliceForMlir(bn_inference.mean()));
1463     TF_ASSIGN_OR_RETURN(auto variance,
1464                         GetAllocationSliceForMlir(bn_inference.stddev()));
1465     TF_ASSIGN_OR_RETURN(auto output,
1466                         GetAllocationSliceForMlir(bn_inference.output()));
1467 
1468     AddThunkToThunkSequence(absl::make_unique<
1469                             CudnnBatchNormForwardInferenceThunk>(
1470         input.thunk_info,
1471         /*config=*/get_batch_norm_config(bn_inference, bn_inference.output()),
1472         /*operand=*/operand,
1473         /*scale=*/scale,
1474         /*offset=*/offset,
1475         /*mean=*/mean,
1476         /*variance=*/variance,
1477         /*output=*/output));
1478     return Status::OK();
1479   }
1480 
1481   return Unimplemented("Unsupported batch norm operation");
1482 }
1483 
1484 #if GOOGLE_CUDA
EmitCholeskyThunkFromMlir(MlirEmitterInput input)1485 Status IrEmitterUnnested::EmitCholeskyThunkFromMlir(MlirEmitterInput input) {
1486   auto cholesky_op = mlir::cast<mlir::lmhlo_gpu::CholeskyOp>(input.op);
1487 
1488   const Shape shape = TypeToShape(cholesky_op.input().getType());
1489   int ndim = shape.dimensions_size();
1490   CHECK_GE(ndim, 2);
1491   int64 n = shape.dimensions(ndim - 1);
1492 
1493   const auto& dims = shape.dimensions();
1494   int64 batch_size = std::accumulate(dims.begin(), dims.end() - 2, int64{1},
1495                                      [](int64 a, int64 b) { return a * b; });
1496 
1497   TF_ASSIGN_OR_RETURN(auto operand_buffer,
1498                       GetAllocationSliceForMlir(cholesky_op.input()));
1499   TF_ASSIGN_OR_RETURN(auto a_buffer,
1500                       GetAllocationSliceForMlir(cholesky_op.output()));
1501   TF_ASSIGN_OR_RETURN(auto workspace_buffer,
1502                       GetAllocationSliceForMlir(cholesky_op.scratch()));
1503   TF_ASSIGN_OR_RETURN(auto info_buffer,
1504                       GetAllocationSliceForMlir(cholesky_op.info()));
1505 
1506   std::vector<std::unique_ptr<Thunk>> thunks;
1507 
1508   if (operand_buffer != a_buffer) {
1509     thunks.push_back(absl::make_unique<DeviceToDeviceCopyThunk>(
1510         input.thunk_info,
1511         /*source_address=*/operand_buffer,
1512         /*destination_buffer=*/a_buffer,
1513         /*mem_size=*/ShapeUtil::ByteSizeOf(shape)));
1514   }
1515 
1516   CholeskyOptions options;
1517   options.set_lower(cholesky_op.is_lower());
1518   thunks.push_back(absl::make_unique<CholeskyThunk>(
1519       input.thunk_info, options, a_buffer, workspace_buffer, info_buffer,
1520       shape.element_type(), batch_size, n));
1521 
1522   // Elide the sequential thunk if there's no copy.
1523   if (thunks.size() == 1) {
1524     AddThunkToThunkSequence(std::move(thunks[0]));
1525   } else {
1526     AddThunkToThunkSequence(absl::make_unique<SequentialThunk>(
1527         input.thunk_info, std::move(thunks)));
1528   }
1529 
1530   return Status::OK();
1531 }
1532 #endif  // GOOGLE_CUDA
1533 
EmitCustomCallThunkFromMlir(MlirEmitterInput input)1534 Status IrEmitterUnnested::EmitCustomCallThunkFromMlir(MlirEmitterInput input) {
1535   auto custom_call = mlir::cast<mlir::lmhlo::CustomCallOp>(input.op);
1536   const std::string call_target_name = custom_call.call_target_name().str();
1537 
1538   void* call_target = CustomCallTargetRegistry::Global()->Lookup(
1539       call_target_name, std::string(platform_name()));
1540   if (call_target) {
1541     std::vector<BufferAllocation::Slice> operands;
1542     for (mlir::Value arg : custom_call.args()) {
1543       TF_ASSIGN_OR_RETURN(BufferAllocation::Slice arg_slice,
1544                           GetAllocationSliceForMlir(arg));
1545       operands.push_back(arg_slice);
1546     }
1547 
1548     std::vector<BufferAllocation::Slice> results;
1549     for (mlir::Value output : custom_call.output()) {
1550       TF_ASSIGN_OR_RETURN(BufferAllocation::Slice output_slice,
1551                           GetAllocationSliceForMlir(output));
1552       results.push_back(output_slice);
1553     }
1554 
1555     AddThunkToThunkSequence(absl::make_unique<CustomCallThunk>(
1556         input.thunk_info, call_target, std::move(operands), std::move(results),
1557         custom_call.backend_config().str()));
1558     return Status::OK();
1559   }
1560   return Unimplemented("No registered implementation for custom call to \"%s\"",
1561                        call_target_name);
1562 }
1563 
HandleFft(HloInstruction * fft)1564 Status IrEmitterUnnested::HandleFft(HloInstruction* fft) {
1565   TF_ASSIGN_OR_RETURN(auto input, GetMlirEmitterInput(fft));
1566   return EmitFftThunkFromMlir(input);
1567 }
1568 
EmitFftThunkFromMlir(MlirEmitterInput input)1569 Status IrEmitterUnnested::EmitFftThunkFromMlir(MlirEmitterInput input) {
1570   auto fft_op = mlir::cast<mlir::lmhlo::FftOp>(input.op);
1571   const Shape operand_shape = TypeToShape(fft_op.operand().getType());
1572   const Shape output_shape = TypeToShape(fft_op.output().getType());
1573   TF_RET_CHECK(LayoutUtil::IsMonotonicWithDim0Major(operand_shape.layout()));
1574   TF_RET_CHECK(LayoutUtil::IsMonotonicWithDim0Major(output_shape.layout()));
1575 
1576   TF_ASSIGN_OR_RETURN(BufferAllocation::Slice arg_slice,
1577                       GetAllocationSliceForMlir(fft_op.operand()));
1578   TF_ASSIGN_OR_RETURN(BufferAllocation::Slice dest_slice,
1579                       GetAllocationSliceForMlir(fft_op.output()));
1580   TF_ASSIGN_OR_RETURN(xla::FftType fft_type, ConvertFftType(fft_op.fft_type()));
1581   auto fft_length_values = fft_op.fft_length().getValues<int64>();
1582   std::vector<int64> fft_length(fft_length_values.begin(),
1583                                 fft_length_values.end());
1584   AddThunkToThunkSequence(
1585       absl::make_unique<FftThunk>(input.thunk_info, fft_type, fft_length,
1586                                   /*input_buffer=*/arg_slice,
1587                                   /*output_buffer=*/dest_slice,
1588                                   /*input_shape=*/operand_shape,
1589                                   /*output_shape=*/output_shape));
1590   return Status::OK();
1591 }
1592 
HandleTriangularSolve(HloInstruction * hlo)1593 Status IrEmitterUnnested::HandleTriangularSolve(HloInstruction* hlo) {
1594   TF_ASSIGN_OR_RETURN(auto input, GetMlirEmitterInput(hlo));
1595   return EmitTriangularSolveFromMlir(input);
1596 }
1597 
EmitTriangularSolveFromMlir(MlirEmitterInput input)1598 Status IrEmitterUnnested::EmitTriangularSolveFromMlir(MlirEmitterInput input) {
1599   auto triangular_solve_op =
1600       mlir::cast<mlir::lmhlo::TriangularSolveOp>(input.op);
1601   auto has_fortran_layout = [](mlir::DenseIntElementsAttr layout_attr) {
1602     int64_t n = layout_attr.getNumElements();
1603     return layout_attr.getValue<int64_t>({0}) == n - 2 &&
1604            layout_attr.getValue<int64_t>({1}) == n - 1;
1605   };
1606   TF_RET_CHECK(has_fortran_layout(triangular_solve_op.layout_a()));
1607   TF_RET_CHECK(has_fortran_layout(triangular_solve_op.layout_b()));
1608   TF_RET_CHECK(has_fortran_layout(triangular_solve_op.layout_output()));
1609 
1610   const Shape b_shape = TypeToShape(triangular_solve_op.b().getType());
1611 
1612   const Shape output_shape =
1613       TypeToShape(triangular_solve_op.output().getType());
1614 
1615   TF_ASSIGN_OR_RETURN(BufferAllocation::Slice a_slice,
1616                       GetAllocationSliceForMlir(triangular_solve_op.a()));
1617   TF_ASSIGN_OR_RETURN(BufferAllocation::Slice b_slice,
1618                       GetAllocationSliceForMlir(triangular_solve_op.b()));
1619   TF_ASSIGN_OR_RETURN(BufferAllocation::Slice output_slice,
1620                       GetAllocationSliceForMlir(triangular_solve_op.output()));
1621   TF_ASSIGN_OR_RETURN(TriangularSolveOptions_Transpose transpose_a,
1622                       ConvertTranspose(triangular_solve_op.transpose_a()));
1623 
1624   std::vector<std::unique_ptr<Thunk>> thunks;
1625 
1626   // Triangular solve is in-place on 'b', so copy 'b' to the output if they
1627   // aren't the same buffer.
1628   if (b_slice != output_slice) {
1629     thunks.push_back(absl::make_unique<DeviceToDeviceCopyThunk>(
1630         Thunk::ThunkInfo(),
1631         /*source_address=*/b_slice,
1632         /*destination_buffer=*/output_slice,
1633         /*mem_size=*/ShapeUtil::ByteSizeOf(b_shape)));
1634   }
1635 
1636   int64 m = b_shape.dimensions(b_shape.rank() - 2);
1637   int64 n = b_shape.dimensions(b_shape.rank() - 1);
1638   int64 batch_size = std::accumulate(b_shape.dimensions().begin(),
1639                                      b_shape.dimensions().end() - 2, int64{1},
1640                                      [](int64 a, int64 b) { return a * b; });
1641   int64 elem_size =
1642       ShapeUtil::ByteSizeOfPrimitiveType(output_shape.element_type());
1643   int64 a_batch_stride =
1644       triangular_solve_op.left_side() ? m * m * elem_size : n * n * elem_size;
1645   int64 b_batch_stride = m * n * elem_size;
1646   TriangularSolveOptions options;
1647   options.set_left_side(triangular_solve_op.left_side());
1648   options.set_lower(triangular_solve_op.lower());
1649   options.set_unit_diagonal(triangular_solve_op.unit_diagonal());
1650   options.set_transpose_a(transpose_a);
1651   thunks.push_back(absl::make_unique<TriangularSolveThunk>(
1652       input.thunk_info, options,
1653       /*a_input_buffer=*/a_slice,
1654       /*b_input_buffer=*/output_slice, output_shape.element_type(), batch_size,
1655       m, n, a_batch_stride, b_batch_stride));
1656 
1657   // Elide the sequential thunk if there's no copy.
1658   if (thunks.size() == 1) {
1659     AddThunkToThunkSequence(std::move(thunks[0]));
1660   } else {
1661     AddThunkToThunkSequence(absl::make_unique<SequentialThunk>(
1662         input.thunk_info, std::move(thunks)));
1663   }
1664   return Status::OK();
1665 }
1666 
1667 // Convert the following form of fusion region:
1668 //   fusion() {
1669 //     %0 = tensor_load %external_memref0
1670 //     %1 = tensor_load %external_memref1
1671 //     ...
1672 //     tensor_store %ret, %external_memref2
1673 //   }
1674 // to
1675 //   fusion(%external_memref0, %external_memref1) (^bb(%0, %1) {
1676 //     ...
1677 //     mhlo.return %ret
1678 //   })
1679 //
1680 // So that it's suitable for MHLO -> XLA HLO conversion.
1681 // This function won't be needed once ElementalIrEmitter migrates to take MHLO
1682 // instead.
ProcessFusionForConversion(mlir::Region * region,std::vector<Shape> * operand_shapes,std::vector<Shape> * output_shapes)1683 static Status ProcessFusionForConversion(mlir::Region* region,
1684                                          std::vector<Shape>* operand_shapes,
1685                                          std::vector<Shape>* output_shapes) {
1686   std::vector<mlir::TensorLoadOp> loads;
1687   std::vector<mlir::TensorStoreOp> stores;
1688 
1689   region->walk([&](mlir::TensorLoadOp load) {
1690     if (load.memref().getParentRegion() != region) {
1691       loads.push_back(load);
1692     }
1693   });
1694 
1695   region->walk([&](mlir::TensorStoreOp store) {
1696     if (store.memref().getParentRegion() != region) {
1697       stores.push_back(store);
1698     }
1699   });
1700 
1701   for (auto load : loads) {
1702     auto arg = region->addArgument(load.getType());
1703     load.replaceAllUsesWith(arg);
1704     Shape shape = TypeToShape(load.getType());
1705     if (auto attr = mlir::GetLayoutFromMlirHlo(load)) {
1706       std::vector<int64> minor_to_major;
1707       absl::c_transform(
1708           attr, std::back_inserter(minor_to_major),
1709           std::function<int64(const llvm::APInt&)>(&llvm::APInt::getZExtValue));
1710       *shape.mutable_layout() = LayoutUtil::MakeLayout(minor_to_major);
1711     } else {
1712       *shape.mutable_layout() =
1713           LayoutUtil::MakeDescendingLayout(load.getType().getShape().size());
1714     }
1715     operand_shapes->push_back(std::move(shape));
1716     load.erase();
1717   }
1718 
1719   std::vector<mlir::Value> returned_values;
1720   for (auto store : stores) {
1721     Shape shape = TypeToShape(store.memref().getType());
1722     if (auto attr = mlir::GetLayoutFromMlirHlo(store)) {
1723       std::vector<int64> minor_to_major;
1724       absl::c_transform(
1725           attr, std::back_inserter(minor_to_major),
1726           std::function<int64(const llvm::APInt&)>(&llvm::APInt::getZExtValue));
1727       *shape.mutable_layout() = LayoutUtil::MakeLayout(minor_to_major);
1728     }
1729     output_shapes->push_back(shape);
1730 
1731     returned_values.push_back(store.tensor());
1732     store.erase();
1733   }
1734 
1735   region->back().back().erase();
1736   auto b = mlir::OpBuilder::atBlockEnd(&region->back());
1737   auto loc = returned_values[0].getLoc();
1738   b.create<mlir::mhlo::ReturnOp>(loc, returned_values);
1739   return Status::OK();
1740 }
1741 
GetMlirEmitterInput(HloInstruction * hlo)1742 StatusOr<MlirEmitterInput> IrEmitterUnnested::GetMlirEmitterInput(
1743     HloInstruction* hlo) {
1744   MlirEmitterInput input;
1745   TF_ASSIGN_OR_RETURN(input.op, lhlo_scratch_emitter_->EmitOp(hlo));
1746   input.thunk_info = GetThunkInfo(hlo);
1747   if (hlo->shape().IsTuple()) {
1748     const auto& buffer_assignment = ir_emitter_context_->buffer_assignment();
1749     auto& slice = input.extra_slice.emplace();
1750     TF_ASSIGN_OR_RETURN(slice.buffer_slice,
1751                         buffer_assignment.GetUniqueSlice(hlo, {}));
1752     slice.written = true;
1753     slice.shape = hlo->shape();
1754   }
1755   return input;
1756 }
1757 
1758 // TODO(timshen): update the comment once the HandleFusion code path deleted.
1759 //
1760 // This is migrated from IrEmitter::HandleFusion() with IrEmitterUnnested as the
1761 // subclass. The logic is de-virtualized and less scattered.
EmitLoopFusionFromMlir(MlirEmitterInput input,const Shape & output_shape,absl::optional<int> unroll_factor_override)1762 Status IrEmitterUnnested::EmitLoopFusionFromMlir(
1763     MlirEmitterInput input, const Shape& output_shape,
1764     absl::optional<int> unroll_factor_override) {
1765   auto fusion = mlir::cast<mlir::lmhlo::FusionOp>(input.op);
1766   MlirEmitterContext context;
1767   context.SetOperation(fusion);
1768 
1769   std::vector<llvm_ir::IrArray> ir_arrays;
1770   Thunk* kernel_thunk;
1771   {
1772     TF_ASSIGN_OR_RETURN(std::unique_ptr<KernelThunk> kernel_thunk_ptr,
1773                         BuildKernelThunkForMlir(fusion, input.thunk_info,
1774                                                 input.extra_slice, &ir_arrays));
1775     kernel_thunk = kernel_thunk_ptr.get();
1776     thunk_sequence_.emplace_back(std::move(kernel_thunk_ptr));
1777   }
1778 
1779   auto operand_arrays =
1780       absl::MakeSpan(ir_arrays).subspan(0, context.operand_shapes.size());
1781   auto output_element_arrays = absl::MakeSpan(ir_arrays).subspan(
1782       context.operand_shapes.size(), context.output_shapes.size());
1783   const llvm_ir::IrArray* tuple_output_array = nullptr;
1784   if (ir_arrays.size() ==
1785       context.operand_shapes.size() + context.output_shapes.size() + 1) {
1786     tuple_output_array = &ir_arrays[context.operand_shapes.size() +
1787                                     context.output_shapes.size()];
1788   }
1789 
1790   TF_ASSIGN_OR_RETURN(const HloComputation* fused_computation,
1791                       GetOrCreateSubComputationFromRegion(&fusion.region(),
1792                                                           /*is_fusion=*/true));
1793 
1794   GpuElementalIrEmitter elemental_emitter(hlo_module_config_, module_, &b_,
1795                                           GetNestedComputer());
1796   FusedIrEmitter fused_emitter(&elemental_emitter);
1797 
1798   for (int i = 0; i < context.operand_shapes.size(); i++) {
1799     auto* builder = &b_;
1800     auto ir_array = operand_arrays[i];
1801     fused_emitter.BindGenerator(
1802         fused_computation->parameter_instruction(i),
1803         [builder, ir_array](llvm_ir::IrArray::Index index) {
1804           return ir_array.EmitReadArrayElement(index, builder);
1805         });
1806   }
1807   TF_ASSIGN_OR_RETURN(
1808       auto element_generator,
1809       fused_emitter.GetGenerator(fused_computation->root_instruction()));
1810 
1811   int unroll_factor;
1812   if (unroll_factor_override.has_value()) {
1813     unroll_factor = *unroll_factor_override;
1814   } else if (!MayPreventVectorization(fusion)) {
1815     unroll_factor = ComputeMaxUnrollFactor(fusion, hlo_module_config_);
1816   } else {
1817     unroll_factor = 1;
1818   }
1819 
1820   bool few_waves = [fusion]() mutable {
1821     for (mlir::Operation& op : fusion.region().front()) {
1822       if (mlir::isa<mlir::TensorLoadOp, mlir::TensorStoreOp,
1823                     mlir::lmhlo::TerminatorOp, mlir::mhlo::ReturnOp>(op)) {
1824         continue;
1825       }
1826       HloOpcode opcode = *MhloToHloOpcode(&op);
1827       if (HloInstruction::IsOpElementwise(opcode)) {
1828         continue;
1829       }
1830       if (auto broadcast = mlir::dyn_cast<mlir::mhlo::BroadcastOp>(op)) {
1831         if (broadcast.broadcast_sizes().size() == 0) {
1832           continue;
1833         }
1834       }
1835       return false;
1836     }
1837     return true;
1838   }();
1839 
1840   Shape element_shape = context.output_shapes[0];
1841   LaunchDimensions launch_dimensions = CalculateLaunchDimensions(
1842       element_shape, ir_emitter_context_->gpu_device_info(), unroll_factor,
1843       few_waves);
1844   UpdateLaunchDimensions(launch_dimensions, kernel_thunk,
1845                          ir_emitter_context_->llvm_module());
1846   llvm::Type* index_type = GetIndexTypeForKernelFromMlir(
1847       fusion, launch_dimensions.launch_bound(), &b_);
1848 
1849   if (context.output_shapes.size() > 1) {
1850     // Emit the tuple pointers in one thread.  We could do this at any point in
1851     // the kernel, but we do it at the beginning in the hopes of reducing
1852     // register pressure, since we touch threadIdx.x and blockIdx.x at the
1853     // beginning of the kernel *anyway*.
1854     KernelSupportLibrary{&b_}.If("emit_mof_tuple", IsBlock0Thread0(&b_), [&] {
1855       llvm_ir::EmitTuple(*tuple_output_array, output_element_arrays, &b_);
1856     });
1857     // For multioutput fusion, we need to emit each operand and the root.
1858     TF_RETURN_IF_ERROR(
1859         ParallelLoopEmitter(element_generator, output_element_arrays,
1860                             launch_dimensions, &b_, unroll_factor)
1861             .EmitLoop(context.name, index_type));
1862   } else {
1863     TF_RETURN_IF_ERROR(
1864         ParallelLoopEmitter(element_generator, output_element_arrays[0],
1865                             launch_dimensions, &b_, unroll_factor)
1866             .EmitLoop(context.name, index_type));
1867   }
1868 
1869   b_.SetInsertPoint(b_.GetInsertBlock()->getTerminator());
1870   return Status::OK();
1871 }
1872 
HandleFusion(HloInstruction * fusion)1873 Status IrEmitterUnnested::HandleFusion(HloInstruction* fusion) {
1874   TF_ASSIGN_OR_RETURN(auto mlir_input, GetMlirEmitterInput(fusion));
1875   auto fusion_op = mlir::cast<mlir::lmhlo::FusionOp>(mlir_input.op);
1876 
1877   if (fusion->IsInputFusion()) {
1878     switch (fusion->fused_expression_root()->opcode()) {
1879       case HloOpcode::kScatter: {
1880         TF_ASSIGN_OR_RETURN(
1881             const HloComputation* fused_computation,
1882             GetOrCreateSubComputationFromRegion(&fusion_op.region(),
1883                                                 /*is_fusion=*/true));
1884         auto* root = fused_computation->root_instruction();
1885 
1886         std::vector<std::unique_ptr<Thunk>> thunks;
1887         // The initialization from 'operand' is using different loop bounds, so
1888         // emit it in a separate kernel. Treat it like a loop fusion, writing to
1889         // the output buffer.
1890         {
1891           std::vector<llvm_ir::IrArray> ir_arrays;
1892           TF_ASSIGN_OR_RETURN(
1893               auto operand_thunk,
1894               BuildKernelThunkForMlir(mlir_input.op, Thunk::ThunkInfo(),
1895                                       mlir_input.extra_slice, &ir_arrays));
1896           thunks.push_back(std::move(operand_thunk));
1897 
1898           GpuElementalIrEmitter operand_elemental_emitter(
1899               hlo_module_config_, ir_emitter_context_->llvm_module(), &b_,
1900               GetNestedComputer());
1901           FusedIrEmitter operand_fused_emitter(&operand_elemental_emitter);
1902           for (int i = 0; i < fused_computation->num_parameters(); i++) {
1903             auto fused_operand = fused_computation->parameter_instruction(i);
1904             operand_fused_emitter.BindGenerator(
1905                 fused_operand, [this, &ir_arrays, i, fused_operand](
1906                                    const llvm_ir::IrArray::Index& index) {
1907                   return ir_arrays[i].EmitReadArrayElement(
1908                       index, &b_, fused_operand->name());
1909                 });
1910           }
1911           TF_ASSIGN_OR_RETURN(
1912               auto generator,
1913               operand_fused_emitter.GetGenerator(root->operand(0)));
1914 
1915           auto unroll_factor =
1916               ComputeMaxUnrollFactor(fusion_op, hlo_module_config_);
1917           const Shape& element_shape = root->shape();
1918           LaunchDimensions launch_dimensions = CalculateLaunchDimensions(
1919               element_shape, ir_emitter_context_->gpu_device_info(),
1920               unroll_factor, /*few_waves=*/false);
1921           UpdateLaunchDimensions(launch_dimensions, thunks.back().get(),
1922                                  ir_emitter_context_->llvm_module());
1923           TF_RETURN_IF_ERROR(
1924               ParallelLoopEmitter(generator, ir_arrays.back(),
1925                                   launch_dimensions, &b_, unroll_factor)
1926                   .EmitLoop(
1927                       IrName(mlir::GetNameFromLoc(fusion_op.getLoc())),
1928                       GetIndexTypeForKernelFromMlir(
1929                           fusion_op, launch_dimensions.launch_bound(), &b_)));
1930         }
1931 
1932         // Now build the actual scatter, reading and writing to the freshly
1933         // filled output buffer.
1934         {
1935           std::vector<llvm_ir::IrArray> ir_arrays;
1936           TF_ASSIGN_OR_RETURN(
1937               auto scatter_thunk,
1938               BuildKernelThunkForMlir(mlir_input.op, Thunk::ThunkInfo(),
1939                                       mlir_input.extra_slice, &ir_arrays));
1940           thunks.push_back(std::move(scatter_thunk));
1941           // Spin up a new fused emitter for the scatter kernel and emit it.
1942           GpuElementalIrEmitter scatter_elemental_emitter(
1943               hlo_module_config_, ir_emitter_context_->llvm_module(), &b_,
1944               GetNestedComputer());
1945           FusedIrEmitter scatter_fused_emitter(&scatter_elemental_emitter);
1946           for (int i = 0; i < fused_computation->num_parameters(); i++) {
1947             auto fused_operand = fused_computation->parameter_instruction(i);
1948             scatter_fused_emitter.BindGenerator(
1949                 fused_operand, [this, &ir_arrays, i, fused_operand](
1950                                    const llvm_ir::IrArray::Index& index) {
1951                   return ir_arrays[i].EmitReadArrayElement(
1952                       index, &b_, fused_operand->name());
1953                 });
1954           }
1955 
1956           TF_ASSIGN_OR_RETURN(
1957               const auto dim_numbers,
1958               lhlo_scratch_emitter_->GetScatterDimensionNumbers(root));
1959 
1960           ScatterDescriptor desc;
1961           desc.name = IrName(root);
1962           desc.operand_shape = root->operand(0)->shape();
1963           desc.scatter_indices_shape = root->operand(1)->shape();
1964           desc.updates_shape = root->operand(2)->shape();
1965           desc.dim_numbers = dim_numbers;
1966           desc.unique_indices = root->unique_indices();
1967           desc.update_computation = root->called_computations()[0];
1968           desc.output = ir_arrays.back();
1969           TF_ASSIGN_OR_RETURN(
1970               desc.scatter_indices_gen,
1971               scatter_fused_emitter.GetGenerator(root->operand(1)));
1972           TF_ASSIGN_OR_RETURN(
1973               desc.updates_gen,
1974               scatter_fused_emitter.GetGenerator(root->operand(2)));
1975           desc.get_index_type = [&](int64 launch_size) {
1976             return GetIndexTypeForKernel(root, launch_size, &b_);
1977           };
1978 
1979           TF_RETURN_IF_ERROR(EmitScatter(desc, thunks.back().get()));
1980         }
1981         AddThunkToThunkSequence(absl::make_unique<SequentialThunk>(
1982             mlir_input.thunk_info, std::move(thunks)));
1983         return Status::OK();
1984       }
1985       // In the case of root tuple, it can be either reduce or slice input
1986       // fusion.
1987       case HloOpcode::kTuple: {
1988         if (IsInputFusibleSlices(mlir_input.op, /*verify_no_strides=*/false)) {
1989           return EmitInputFusibleNonStridedSlices(mlir_input);
1990         }
1991 
1992         CHECK_GE(mlir::cast<mlir::lmhlo::FusionOp>(mlir_input.op)
1993                      .getFusionResults()
1994                      .size(),
1995                  1);
1996         return EmitReductionFromOrToContiguousDimensions(mlir_input);
1997       }
1998       case HloOpcode::kReduce: {
1999         // HandleFusion specializes reduction from a multi-dimensional array to
2000         // a 1D array. The specialized version requires a initializer thunk that
2001         // initializes the output array to the initial value of the reduce.
2002         if (mlir_input.op->getNumResults() > 1) {
2003           // TODO(b/129089333): Support tiled vectorized variadic reduce.
2004           return Unimplemented(
2005               "Vectorized variadic reduce is not supported on GPU");
2006         }
2007         return EmitReductionFromOrToContiguousDimensions(mlir_input);
2008       }
2009       case HloOpcode::kSlice: {
2010         return EmitInputFusibleNonStridedSlices(mlir_input);
2011       }
2012       default:
2013         LOG(FATAL) << "Bad opcode for input fusion: "
2014                    << fusion->fused_expression_root()->opcode();
2015     }
2016   } else if (CanEmitFusedDynamicUpdateSliceInPlaceForGpu(
2017                  fusion_op, ir_emitter_context_->allocations())) {
2018     // Fusion node with dynamic-update-slice as the root where the op's input
2019     // (i.e. array to update) shares the same slice as its output.  In this case
2020     // we have a special algorithm that modifies the output in place without
2021     // touching the un-updated elements.
2022     CHECK_EQ(1, GetHloOutputs(mlir_input.op).size());
2023 
2024     // Set up kernel thunk and fused ir emitter.
2025     std::vector<llvm_ir::IrArray> ir_arrays;
2026     TF_ASSIGN_OR_RETURN(
2027         auto fusion_thunk,
2028         BuildKernelThunkForMlir(fusion_op, mlir_input.thunk_info,
2029                                 mlir_input.extra_slice, &ir_arrays));
2030 
2031     GpuElementalIrEmitter elemental_emitter(hlo_module_config_,
2032                                             ir_emitter_context_->llvm_module(),
2033                                             &b_, GetNestedComputer());
2034 
2035     // Shape of the dynamic-update-slice's "update" operand.
2036     Shape update_shape = fusion->fused_expression_root()->operand(1)->shape();
2037 
2038     // Array to write into.  Because this is an in-place operation, this is the
2039     // same as operand 0's array.
2040     const IrArray& output_array = ir_arrays.back();
2041 
2042     LaunchDimensions launch_dimensions = CalculateLaunchDimensions(
2043         update_shape, ir_emitter_context_->gpu_device_info());
2044     UpdateLaunchDimensions(launch_dimensions, fusion_thunk.get(),
2045                            ir_emitter_context_->llvm_module());
2046     AddThunkToThunkSequence(std::move(fusion_thunk));
2047 
2048     FusedIrEmitter fused_emitter(&elemental_emitter);
2049 
2050     TF_ASSIGN_OR_RETURN(
2051         const HloComputation* fused_computation,
2052         GetOrCreateSubComputationFromRegion(&fusion_op.region(),
2053                                             /*is_fusion=*/true));
2054 
2055     for (int i = 0; i < fused_computation->num_parameters(); i++) {
2056       auto fused_operand = fused_computation->parameter_instruction(i);
2057       fused_emitter.BindGenerator(
2058           fused_operand, [this, &ir_arrays, i,
2059                           fused_operand](const llvm_ir::IrArray::Index& index) {
2060             return ir_arrays[i].EmitReadArrayElement(index, &b_,
2061                                                      fused_operand->name());
2062           });
2063     }
2064 
2065     return llvm_ir::EmitParallelFusedDynamicUpdateSliceInPlace(
2066         fused_computation, output_array, &fused_emitter, launch_dimensions,
2067         &b_);
2068   }
2069 
2070   CHECK_EQ(fusion->fusion_kind(), HloInstruction::FusionKind::kLoop)
2071       << ": " << fusion->ToString();
2072 
2073   TF_ASSIGN_OR_RETURN(const bool matched_021,
2074                       CheckAndEmitHloWithTile021(mlir_input));
2075   if (matched_021) {
2076     return Status::OK();
2077   }
2078 
2079   return EmitLoopFusionFromMlir(mlir_input, fusion->shape());
2080 }
2081 
HandleCopy(HloInstruction * copy)2082 Status IrEmitterUnnested::HandleCopy(HloInstruction* copy) {
2083   TF_ASSIGN_OR_RETURN(auto input, GetMlirEmitterInput(copy));
2084   return EmitCopyForMlir(input);
2085 }
2086 
EmitCopyForMlir(MlirEmitterInput input)2087 Status IrEmitterUnnested::EmitCopyForMlir(MlirEmitterInput input) {
2088   auto copy = mlir::cast<mlir::lmhlo::CopyOp>(input.op);
2089   auto operand_shape = TypeToShape(copy.operand().getType());
2090   auto output_shape = TypeToShape(copy.output().getType());
2091 
2092   CHECK(ShapeUtil::Compatible(operand_shape, output_shape));
2093   auto maybe_slice = GetAllocationSliceForMlir(copy.operand());
2094   if (LayoutUtil::Equal(operand_shape.layout(), output_shape.layout()) &&
2095       maybe_slice.ok()) {
2096     // Copy the operand into the output if it's not the same buffer already.
2097     auto operand_buffer = *maybe_slice;
2098     auto destination_buffer = *GetAllocationSliceForMlir(copy.output());
2099     if (operand_buffer != destination_buffer) {
2100       AddThunkToThunkSequence(absl::make_unique<DeviceToDeviceCopyThunk>(
2101           input.thunk_info,
2102           /*source_address=*/operand_buffer,
2103           /*destination_buffer=*/destination_buffer,
2104           /*mem_size=*/
2105           ByteSizeOf(operand_shape)));
2106     }
2107     return Status::OK();
2108   }
2109   TF_ASSIGN_OR_RETURN(bool matched_021, CheckAndEmitHloWithTile021(input));
2110   if (matched_021) {
2111     return Status::OK();
2112   }
2113 
2114   return EmitUsingElementalIrEmitter(input);
2115 }
2116 
EmitExtraOutputsForReduce(absl::Span<const llvm_ir::IrArray> result_ir_arrays,const IrArray::Index & index,bool use_linear_index,absl::Span<const std::pair<llvm_ir::ElementGenerator,int>> extra_output_gens)2117 Status IrEmitterUnnested::EmitExtraOutputsForReduce(
2118     absl::Span<const llvm_ir::IrArray> result_ir_arrays,
2119     const IrArray::Index& index, bool use_linear_index,
2120     absl::Span<const std::pair<llvm_ir::ElementGenerator, int>>
2121         extra_output_gens) {
2122   // Compute all extra output values before writing them. This avoids
2123   // overwriting aliased input/output buffers before all reads occured.
2124   absl::InlinedVector<llvm::Value*, 8> extra_output_ir_values;
2125   for (int i = 0; i < extra_output_gens.size(); ++i) {
2126     TF_ASSIGN_OR_RETURN(llvm::Value* const extra_output_ir_value,
2127                         extra_output_gens[i].first(index));
2128     extra_output_ir_values.push_back(extra_output_ir_value);
2129   }
2130   for (int i = 0; i < extra_output_gens.size(); ++i) {
2131     result_ir_arrays[extra_output_gens[i].second].EmitWriteArrayElement(
2132         index, extra_output_ir_values[i], &b_, use_linear_index);
2133   }
2134   return Status::OK();
2135 }
2136 
HandleReduce(HloInstruction * reduce)2137 Status IrEmitterUnnested::HandleReduce(HloInstruction* reduce) {
2138   TF_ASSIGN_OR_RETURN(auto mlir_input, GetMlirEmitterInput(reduce));
2139   return EmitReduceFromMlir(mlir_input);
2140 }
2141 
EmitReduceFromMlir(MlirEmitterInput mlir_input)2142 Status IrEmitterUnnested::EmitReduceFromMlir(MlirEmitterInput mlir_input) {
2143   if (GetHloOutputs(mlir_input.op).size() == 1 &&
2144       IsReductionFromOrToContiguousDimensions(mlir_input.op)) {
2145     return EmitReductionFromOrToContiguousDimensions(mlir_input);
2146   }
2147 
2148   return EmitUsingElementalIrEmitter(mlir_input);
2149 }
2150 
HandleTuple(HloInstruction * tuple)2151 Status IrEmitterUnnested::HandleTuple(HloInstruction* tuple) {
2152   // For all tuples, we expect the elements of the tuple to be directly consumed
2153   // by instructions that read from that tuple either directly, or through a
2154   // GTE instruction. This is possible we do not support "dynamic tuples" since
2155   // tuple-select is not supported. As a result, we never need to materialize a
2156   // tuple (which has a runtime representation of an array of pointers) in
2157   // memory at runtime. So there is no need to generate any code for tuples.
2158   return Status::OK();
2159 }
2160 
HandleGetTupleElement(HloInstruction *)2161 Status IrEmitterUnnested::HandleGetTupleElement(HloInstruction*) {
2162   // GetTupleElement IR is emitted in the IR context of the user instruction,
2163   // and so we do not build a kernel for GetTupleElement instructions.
2164   return Status::OK();
2165 }
2166 
AssertNonDeterminismIsOkay(const string & op_name)2167 Status IrEmitterUnnested::AssertNonDeterminismIsOkay(const string& op_name) {
2168   if (hlo_module_config_.debug_options().xla_gpu_deterministic_ops()) {
2169     return Unimplemented(
2170         "HLO instruction %s does not have a deterministic implementation, "
2171         "but run-to-run determinism is required by "
2172         "--xla_gpu_deterministic_ops.",
2173         op_name);
2174   }
2175   return Status::OK();
2176 }
2177 
HandleSelectAndScatter(HloInstruction * select_and_scatter)2178 Status IrEmitterUnnested::HandleSelectAndScatter(
2179     HloInstruction* select_and_scatter) {
2180   const Window& window = select_and_scatter->window();
2181   const auto* operand = select_and_scatter->operand(0);
2182   const auto* source = select_and_scatter->operand(1);
2183   const int64 rank = operand->shape().rank();
2184   CHECK_EQ(rank, source->shape().rank());
2185   CHECK_EQ(rank, window.dimensions_size());
2186 
2187   // TODO(b/31410564): Implement dilation rate for select-and-scatter.
2188   if (window_util::HasDilation(window)) {
2189     return Unimplemented(
2190         "Dilation for SelectAndScatter not implemented on GPU.");
2191   }
2192 
2193   TF_RETURN_IF_ERROR(AssertNonDeterminismIsOkay(select_and_scatter->name()));
2194 
2195   TF_ASSIGN_OR_RETURN(auto input, GetMlirEmitterInput(select_and_scatter));
2196   return EmitSelectAndScatterFromMlir(input);
2197 }
2198 
// Emits an lmhlo.select_and_scatter op.
//
// The op is lowered to one SequentialThunk holding two thunks:
//   1. an initializer thunk (BuildInitializerThunkForMlir) that fills `out`
//      with `init_value`, and
//   2. a kernel thunk bound to `operand`, `source` and `out`. Its body is the
//      select-and-scatter loop emitted below, run by ParallelLoopEmitter over
//      the shape of `source`.
// The SequentialThunk is appended to the thunk sequence before the kernel IR is
// emitted. Afterwards the kernel thunk is reached through
// `select_and_scatter_thunk->thunks().back()` to set its launch dimensions.
EmitSelectAndScatterFromMlir(MlirEmitterInput mlir_input)2199 Status IrEmitterUnnested::EmitSelectAndScatterFromMlir(
2200     MlirEmitterInput mlir_input) {
2201   auto select_and_scatter_op =
2202       mlir::cast<mlir::lmhlo::SelectAndScatterOp>(mlir_input.op);
2203
  // Prefix for the names of the emitted loops (the window loop nest and the
  // outer parallel loop).
2204   std::string name = mlir::GetNameFromLoc(select_and_scatter_op.getLoc());
2205
  // thunks[0]: initializer thunk that writes `init_value` into `out`.
2206   std::vector<std::unique_ptr<Thunk>> thunks;
2207   thunks.emplace_back();
2208   TF_ASSIGN_OR_RETURN(thunks.back(),
2209                       BuildInitializerThunkForMlir(
2210                           mlir_input.op, select_and_scatter_op.init_value(),
2211                           select_and_scatter_op.out()));
2212
2213   std::vector<llvm_ir::IrArray> ir_arrays;
2214   thunks.emplace_back();
2215   // thunks[1]: kernel thunk for the select-and-scatter loop.
  // Init value is not needed in IR emission; only the initializer thunk above
  // uses it. Therefore only operand, source and out are kernel arguments.
2216   TF_ASSIGN_OR_RETURN(
2217       thunks.back(),
2218       BuildKernelThunkForMlir(
2219           select_and_scatter_op,
2220           {select_and_scatter_op.operand(), select_and_scatter_op.source(),
2221            select_and_scatter_op.out()},
2222           Thunk::ThunkInfo(), mlir_input.extra_slice, &ir_arrays));
2223
  // One IrArray per kernel argument, in the order they were passed above.
2224   CHECK_EQ(ir_arrays.size(), 3);
2225   const IrArray& operand_array = ir_arrays[0];
2226   const IrArray& source_array = ir_arrays[1];
2227   const IrArray& out_array = ir_arrays[2];
2228
  // `thunks` is moved-from here. From now on the kernel thunk is only
  // reachable as select_and_scatter_thunk->thunks().back().
2229   auto select_and_scatter_thunk = absl::make_unique<SequentialThunk>(
2230       mlir_input.thunk_info, std::move(thunks));
2231
2232   const Shape source_shape =
2233       TypeToShape(select_and_scatter_op.source().getType());
2234   const Shape operand_shape =
2235       TypeToShape(select_and_scatter_op.operand().getType());
2236   const int64 rank = operand_shape.rank();
2237
  // The launch is sized from `source_shape`, the same shape the parallel loop
  // at the end of this function iterates over. The IR index type is picked by
  // GetIndexTypeForKernelFromMlir from the resulting launch bound.
2238   LaunchDimensions launch_dimensions = CalculateLaunchDimensions(
2239       source_shape, ir_emitter_context_->gpu_device_info());
2240   llvm::Type* index_type = GetIndexTypeForKernelFromMlir(
2241       select_and_scatter_op, launch_dimensions.launch_bound(), &b_);
2242   auto index_typed_constant = [&](uint64 c) -> llvm::Constant* {
2243     return llvm::ConstantInt::get(index_type, c);
2244   };
2245
2246   // kSelectAndScatter is implemented as two kernel launches: the first launch
2247   // initializes the output array to the given initial value,
2248   // and the second accumulates the "source" matrix to the
2249   // selected elements in the output array. The first launch is already
2250   // implemented by the initializer thunk generated earlier, so this function
2251   // only needs to take care of the select-and-scatter part.
2252   //
2253   // Pseudo code for select-and-scatter:
2254   //
2255   // for (coordinates S in the source):  # This loop is parallel.
2256   //   initialized_flag = false
2257   //   for (coordinates W in the window):
2258   //     I = S * stride + W - pad_low
2259   //     if I within bounds of operand:
2260   //       if !(initialized_flag and select(selected_value, operand(I))):
2261   //         selected_value = operand(I)
2262   //         selected_index = I
2263   //         initialized_flag = true
2264   //   output(selected_index) = scatter(output(selected_index), source(S))
2265   auto loop_body_emitter = [&](const IrArray::Index& source_index) -> Status {
2266     // Allocate space to keep the currently selected value, its index, and a
2267     // boolean flag if the value is initialized. The initialized_flag is set
2268     // false.
2269     llvm::Value* selected_value_address = llvm_ir::EmitAllocaAtFunctionEntry(
2270         llvm_ir::PrimitiveTypeToIrType(operand_shape.element_type(),
2271                                        ir_emitter_context_->llvm_module()),
2272         "selected_value_address", &b_);
2273
    // `rank` slots of `index_type`, one per operand dimension.
2274     llvm::Value* selected_index_address =
2275         llvm_ir::EmitAllocaAtFunctionEntryWithCount(
2276             index_type, index_typed_constant(rank), "selected_index_address",
2277             &b_);
2278
2279     llvm::Value* initialized_flag_address = llvm_ir::EmitAllocaAtFunctionEntry(
2280         b_.getInt1Ty(), "initialized_flag_address", &b_);
2281     Store(b_.getInt1(false), initialized_flag_address);
2282
2283     // Create the inner loop to iterate over the window.
2284     llvm_ir::ForLoopNest window_loops(absl::StrCat(name, "inner"), &b_,
2285                                       index_type);
2286
    // Window extents come from the op's window_dimensions attribute. Every
    // extent must be positive (CHECKed).
2287     DimensionVector window_size;
2288     mlir::DenseIntElementsAttr window_dimensions =
2289         select_and_scatter_op.window_dimensions().getValue();
2290     for (const auto& dim : window_dimensions) {
2291       window_size.push_back(dim.getSExtValue());
2292       CHECK_GT(dim.getSExtValue(), 0);
2293     }
2294
2295     const IrArray::Index window_index = window_loops.AddLoopsForShape(
2296         ShapeUtil::MakeShape(operand_shape.element_type(), window_size),
2297         "window");
2298     llvm_ir::SetToFirstInsertPoint(window_loops.GetInnerLoopBodyBasicBlock(),
2299                                    &b_);
2300
2301     // Compute the operand index to visit and evaluate the condition whether the
2302     // operand index is within the bounds. The unsigned comparison includes
2303     // checking whether the operand index >= 0.
2304     std::vector<llvm::Value*> operand_multi_index(source_index.size());
2305     llvm::Value* in_bounds_condition = b_.getInt1(true);
2306
    // NOTE(review): `paddings` is zipped element-wise with `strides`.
    // This assumes the padding attribute holds exactly one value per dimension,
    // the low padding (matching `pad_low` in the pseudo code above).
    // An attribute holding low/high pairs would be misread here; confirm
    // against the producer of this op.
2307     auto strides = *select_and_scatter_op.window_strides();
2308     auto paddings = *select_and_scatter_op.padding();
2309
2310     for (auto stride_and_padding :
2311          llvm::enumerate(llvm::zip(strides, paddings))) {
2312       const int i = stride_and_padding.index();
2313       int64 stride = std::get<0>(stride_and_padding.value()).getSExtValue();
2314       int64 padding = std::get<1>(stride_and_padding.value()).getSExtValue();
2315
      // I[i] = S[i] * stride + W[i] - pad_low, with no-signed-wrap arithmetic.
2316       llvm::Value* strided_index =
2317           NSWMul(source_index[i], index_typed_constant(stride));
2318       operand_multi_index[i] = NSWSub(NSWAdd(strided_index, window_index[i]),
2319                                       index_typed_constant(padding));
2320       llvm::Value* index_condition = ICmpULT(
2321           operand_multi_index[i],
2322           index_typed_constant(ShapeUtil::GetDimension(operand_shape, i)));
2323       in_bounds_condition = And(in_bounds_condition, index_condition);
2324     }
2325
2326     // Only need to do something if the operand index is within the bounds.
2327     // First check if the initialized_flag is set.
2328     llvm_ir::LlvmIfData if_in_bounds =
2329         llvm_ir::EmitIfThenElse(in_bounds_condition, "in-bounds", &b_);
2330     llvm_ir::SetToFirstInsertPoint(if_in_bounds.true_block, &b_);
2331     llvm_ir::LlvmIfData if_initialized = llvm_ir::EmitIfThenElse(
2332         Load(initialized_flag_address), "initialized", &b_);
2333
2334     // If the initialized_flag is false, initialize the selected value and index
2335     // with the currently visiting operand.
2336     llvm_ir::SetToFirstInsertPoint(if_initialized.false_block, &b_);
    // Stores each component of an operand index into the corresponding slot of
    // the `selected_index_address` alloca.
2337     const auto save_operand_index = [&](const IrArray::Index& operand_index) {
2338       for (int64 i = 0; i < rank; ++i) {
2339         llvm::Value* selected_index_address_slot =
2340             InBoundsGEP(selected_index_address, {b_.getInt32(i)});
2341         Store(operand_index[i], selected_index_address_slot);
2342       }
2343     };
2344     IrArray::Index operand_index(operand_multi_index, operand_shape,
2345                                  index_type);
2346     llvm::Value* operand_data =
2347         operand_array.EmitReadArrayElement(operand_index, &b_);
2348     Store(operand_data, selected_value_address);
2349     save_operand_index(operand_index);
2350     Store(b_.getInt1(true), initialized_flag_address);
2351
2352     // If the initialized_flag is true, call the `select` function to
2353     // potentially update the selected value and index with the currently
2354     // visiting operand.
2355     llvm_ir::SetToFirstInsertPoint(if_initialized.true_block, &b_);
2356     llvm::Value* operand_address =
2357         operand_array.EmitArrayElementAddress(operand_index, &b_);
    // Receives the PRED result of select(selected_value, operand(I)).
2358     llvm::Value* select_return_buffer = llvm_ir::EmitAllocaAtFunctionEntry(
2359         llvm_ir::PrimitiveTypeToIrType(PRED,
2360                                        ir_emitter_context_->llvm_module()),
2361         "select_return_buffer", &b_);
2362
2363     TF_ASSIGN_OR_RETURN(
2364         const HloComputation* select_computation,
2365         GetOrCreateSubComputationFromRegion(&select_and_scatter_op.select(),
2366                                             /*is_fusion=*/false));
2367
2368     TF_RETURN_IF_ERROR(EmitCallToNestedComputation(
2369         *select_computation, {selected_value_address, operand_address},
2370         select_return_buffer));
2371     llvm::Value* result = Load(select_return_buffer);
2372
2373     // If the 'select' function returns false, update the selected value and the
2374     // index to the currently visiting operand.
2375     llvm::Value* cond = ICmpNE(
2376         result,
2377         llvm::ConstantInt::get(llvm_ir::PrimitiveTypeToIrType(
2378                                    PRED, ir_emitter_context_->llvm_module()),
2379                                0),
2380         "boolean_predicate");
2381     llvm_ir::LlvmIfData if_select_lhs =
2382         llvm_ir::EmitIfThenElse(cond, "if-select-lhs", &b_);
2383     llvm_ir::SetToFirstInsertPoint(if_select_lhs.false_block, &b_);
2384     Store(Load(operand_address), selected_value_address);
2385     save_operand_index(operand_index);
2386
2387     // After iterating over the window elements, scatter the source element to
2388     // the selected index of the output. The value we store at the output
2389     // location is computed by calling the `scatter` function with the source
2390     // value and the current output value.
    // This code goes into the exit block of the window loop nest, so it runs
    // once per source element, after the whole window has been visited.
2391     llvm_ir::SetToFirstInsertPoint(window_loops.GetOuterLoopExitBasicBlock(),
2392                                    &b_);
2393     std::vector<llvm::Value*> selected_multi_index;
2394     for (int64 i = 0; i < rank; ++i) {
2395       llvm::Value* selected_index_address_slot =
2396           InBoundsGEP(selected_index_address, {b_.getInt32(i)});
2397       selected_multi_index.push_back(Load(selected_index_address_slot));
2398     }
2399     const Shape output_shape =
2400         TypeToShape(select_and_scatter_op.out().getType());
2401     llvm::Value* source_value_address =
2402         source_array.EmitArrayElementAddress(source_index, &b_);
2403     IrArray::Index selected_index(selected_multi_index, output_shape,
2404                                   operand_index.GetType());
2405     llvm::Value* output_value_address =
2406         out_array.EmitArrayElementAddress(selected_index, &b_);
2407
2408     TF_ASSIGN_OR_RETURN(
2409         const HloComputation* scatter_computation,
2410         GetOrCreateSubComputationFromRegion(&select_and_scatter_op.scatter(),
2411                                             /*is_fusion=*/false));
2412
    // The update is emitted as an atomic operation, presumably because several
    // source elements can select the same output element.
2413     return EmitAtomicOperationForNestedComputation(
2414         *scatter_computation, output_value_address, source_value_address);
2415   };
2416
2417   UpdateLaunchDimensions(
2418       launch_dimensions,
2419       // IrEmitterUnnested implements kSelectAndScatter as a SequentialThunk
2420       // consisting of two thunks, an initializer KernelThunk that initializes
2421       // the output and another KernelThunk that accumulates the scattered
2422       // elements.
2423       select_and_scatter_thunk->thunks().back().get(),
2424       ir_emitter_context_->llvm_module());
2425   AddThunkToThunkSequence(std::move(select_and_scatter_thunk));
  // Emit the parallel loop over `source_shape`; loop_body_emitter produces the
  // body for each source index.
2426   return ParallelLoopEmitter(loop_body_emitter, source_shape, launch_dimensions,
2427                              &b_)
2428       .EmitLoop(name, index_type);
2429 }
2430 
HandleWhile(HloInstruction * xla_while)2431 Status IrEmitterUnnested::HandleWhile(HloInstruction* xla_while) {
2432   HloComputation* condition = xla_while->while_condition();
2433   TF_RET_CHECK(ShapeUtil::IsScalar(condition->root_instruction()->shape()) &&
2434                condition->root_instruction()->shape().element_type() == PRED)
2435       << "While condition computation must return bool";
2436   // Build ForThunk for conformant while loops, otherwise build WhileThunk.
2437   auto config = xla_while->backend_config<WhileLoopBackendConfig>();
2438   if (config.ok() && config.ValueOrDie().has_known_trip_count()) {
2439     TF_ASSIGN_OR_RETURN(
2440         auto thunk,
2441         BuildForThunk(xla_while, config.ValueOrDie().known_trip_count().n()));
2442     AddThunkToThunkSequence(std::move(thunk));
2443   } else {
2444     TF_ASSIGN_OR_RETURN(auto thunk, BuildWhileThunk(xla_while));
2445     AddThunkToThunkSequence(std::move(thunk));
2446   }
2447   return Status::OK();
2448 }
2449 
HandleRng(HloInstruction * rng)2450 Status IrEmitterUnnested::HandleRng(HloInstruction* rng) {
2451   return Unimplemented("Rng should be expanded for GPU.");
2452 }
2453 
HandleRngGetAndUpdateState(HloInstruction * rng_state)2454 Status IrEmitterUnnested::HandleRngGetAndUpdateState(
2455     HloInstruction* rng_state) {
2456   TF_ASSIGN_OR_RETURN(auto input, GetMlirEmitterInput(rng_state));
2457   return EmitRngGetAndUpdateState(input);
2458 }
2459 
EmitRngGetAndUpdateState(MlirEmitterInput mlir_input)2460 Status IrEmitterUnnested::EmitRngGetAndUpdateState(
2461     MlirEmitterInput mlir_input) {
2462   auto rng_op =
2463       mlir::dyn_cast<mlir::lmhlo::RngGetAndUpdateStateOp>(mlir_input.op);
2464 
2465   // Emit a kernel to increment the global state for Philox RNG algorithm.
2466   std::vector<llvm_ir::IrArray> ir_arrays;
2467   TF_ASSIGN_OR_RETURN(
2468       auto kernel_thunk,
2469       BuildKernelThunkForMlir(rng_op, rng_op.state(), mlir_input.thunk_info,
2470                               mlir_input.extra_slice, &ir_arrays));
2471   AddThunkToThunkSequence(std::move(kernel_thunk));
2472 
2473   llvm::Value* old_state =
2474       llvm_ir::RngGetAndUpdateState(rng_op.delta(), module_, &b_);
2475 
2476   const Shape shape = TypeToShape(rng_op.state().getType());
2477 
2478   llvm::Value* output_address = ir_arrays[0].EmitArrayElementAddress(
2479       llvm_ir::IrArray::Index(
2480           /*linear=*/b_.getInt64(0), shape, &b_),
2481       &b_, "rng_state_address");
2482   output_address = BitCast(
2483       output_address, llvm::PointerType::get(
2484                           old_state->getType(),
2485                           output_address->getType()->getPointerAddressSpace()));
2486   Store(old_state, output_address);
2487 
2488   return Status::OK();
2489 }
2490 
HandleScatter(HloInstruction * scatter)2491 Status IrEmitterUnnested::HandleScatter(HloInstruction* scatter) {
2492   if (!scatter->unique_indices()) {
2493     TF_RETURN_IF_ERROR(AssertNonDeterminismIsOkay(scatter->name()));
2494   }
2495   TF_ASSIGN_OR_RETURN(auto input, GetMlirEmitterInput(scatter));
2496   return EmitScatterFromMlir(input);
2497 }
2498 
EmitScatterFromMlir(MlirEmitterInput mlir_input)2499 Status IrEmitterUnnested::EmitScatterFromMlir(MlirEmitterInput mlir_input) {
2500   std::vector<std::unique_ptr<Thunk>> thunks;
2501 
2502   auto scatter_op = mlir::cast<mlir::lmhlo::ScatterOp>(mlir_input.op);
2503 
2504   TF_ASSIGN_OR_RETURN(auto operand_buffer,
2505                       GetAllocationSliceForMlir(scatter_op.operand()));
2506   TF_ASSIGN_OR_RETURN(auto output_buffer,
2507                       GetAllocationSliceForMlir(scatter_op.output()));
2508 
2509   // Copy the operand into the output if it's not the same buffer already.
2510   if (operand_buffer != output_buffer) {
2511     thunks.push_back(absl::make_unique<DeviceToDeviceCopyThunk>(
2512         Thunk::ThunkInfo(),
2513         /*source_address=*/operand_buffer,
2514         /*destination_buffer=*/output_buffer,
2515         /*mem_size=*/
2516         ShapeUtil::ByteSizeOf(TypeToShape(scatter_op.output().getType()))));
2517   }
2518 
2519   // Create kernel thunk for all operands except the first one (`operand`). The
2520   // code generated for scatter below assumes that the input operand is already
2521   // copied into the output, so does not use it in codegen.
2522   std::vector<llvm_ir::IrArray> ir_arrays;
2523   thunks.emplace_back();
2524   TF_ASSIGN_OR_RETURN(
2525       thunks.back(),
2526       BuildKernelThunkForMlir(scatter_op, scatter_op.getOperands().drop_front(),
2527                               mlir_input.thunk_info, mlir_input.extra_slice,
2528                               &ir_arrays));
2529 
2530   CHECK_EQ(ir_arrays.size(), 3);
2531   const IrArray& scatter_indices = ir_arrays[0];
2532   const IrArray& updates = ir_arrays[1];
2533   const IrArray& output = ir_arrays[2];
2534 
2535   auto get_index_type = [&](int64 launch_size) {
2536     return GetIndexTypeForKernelFromMlir(scatter_op, launch_size, &b_);
2537   };
2538 
2539   TF_RETURN_IF_ERROR(EmitScatter(
2540       thunks.back().get(), scatter_op, output,
2541       /*scatter_indices_gen=*/
2542       [&](const IrArray::Index& index) {
2543         return scatter_indices.EmitReadArrayElement(index, &b_,
2544                                                     "scatter_index");
2545       },
2546       /*updates_gen=*/
2547       [&](const IrArray::Index& index) {
2548         return updates.EmitReadArrayElement(index, &b_, "update");
2549       },
2550       /* get_index_type=*/
2551       get_index_type));
2552 
2553   // Elide the sequential thunk if there's no copy.
2554   if (thunks.size() == 1) {
2555     AddThunkToThunkSequence(std::move(thunks[0]));
2556   } else {
2557     AddThunkToThunkSequence(absl::make_unique<SequentialThunk>(
2558         mlir_input.thunk_info, std::move(thunks)));
2559   }
2560 
2561   return Status::OK();
2562 }
2563 
EmitScatter(Thunk * thunk,mlir::lmhlo::ScatterOp scatter,const llvm_ir::IrArray & output,const llvm_ir::ElementGenerator & scatter_indices_gen,const llvm_ir::ElementGenerator & updates_gen,std::function<llvm::Type * (int64)> get_index_type)2564 Status IrEmitterUnnested::EmitScatter(
2565     Thunk* thunk, mlir::lmhlo::ScatterOp scatter,
2566     const llvm_ir::IrArray& output,
2567     const llvm_ir::ElementGenerator& scatter_indices_gen,
2568     const llvm_ir::ElementGenerator& updates_gen,
2569     std::function<llvm::Type*(int64)> get_index_type) {
2570   const Shape operand_shape = TypeToShape(scatter.operand().getType());
2571   CHECK(
2572       ShapeUtil::Equal(TypeToShape(scatter.output().getType()), operand_shape));
2573 
2574   TF_ASSIGN_OR_RETURN(
2575       const HloComputation* update_computation,
2576       GetOrCreateSubComputationFromRegion(&scatter.update_computation(),
2577                                           /*is_fusion=*/false));
2578 
2579   ScatterDescriptor desc;
2580   desc.name = mlir::GetNameFromLoc(scatter.getLoc());
2581   desc.operand_shape = operand_shape;
2582   desc.scatter_indices_shape = TypeToShape(scatter.scatter_indices().getType());
2583   desc.updates_shape = TypeToShape(scatter.updates().getType());
2584   desc.dim_numbers = scatter.scatter_dimension_numbers();
2585   desc.unique_indices = scatter.unique_indices();
2586   desc.update_computation = update_computation;
2587   desc.output = output;
2588   desc.scatter_indices_gen = scatter_indices_gen;
2589   desc.updates_gen = updates_gen;
2590   desc.get_index_type = get_index_type;
2591   return EmitScatter(desc, thunk);
2592 }
2593 
EmitScatter(const ScatterDescriptor & desc,Thunk * thunk)2594 Status IrEmitterUnnested::EmitScatter(const ScatterDescriptor& desc,
2595                                       Thunk* thunk) {
2596   if (!desc.unique_indices) {
2597     TF_RETURN_IF_ERROR(AssertNonDeterminismIsOkay(desc.name));
2598   }
2599   auto loop_body_emitter = [&](const IrArray::Index& index) -> Status {
2600     std::vector<llvm::Value*> raw_window_multidim;
2601     std::vector<llvm::Value*> input_scatter_multidim;
2602     std::vector<int64> raw_window_bounds;
2603 
2604     // Partition the index into window indices and scatter indices.
2605     for (int64 i = 0, e = index.size(); i != e; ++i) {
2606       // For window indices also remember the window size, this comes in handy
2607       // later.
2608       if (BinarySearchDenseElementsAttr(desc.dim_numbers.update_window_dims(),
2609                                         i)) {
2610         raw_window_multidim.push_back(index[i]);
2611         raw_window_bounds.push_back(desc.updates_shape.dimensions(i));
2612       } else {
2613         input_scatter_multidim.push_back(index[i]);
2614       }
2615     }
2616     DCHECK_EQ(raw_window_multidim.size(),
2617               desc.dim_numbers.update_window_dims().size());
2618 
2619     // Apply inserted_window_dims to the window dimensions.
2620     int64 raw_window_multidim_idx = 0;
2621     std::vector<llvm::Value*> input_window_multidim;
2622     std::vector<int64> input_window_bounds;
2623 
2624     for (int64 i = 0, e = desc.operand_shape.rank(); i != e; ++i) {
2625       if (BinarySearchDenseElementsAttr(desc.dim_numbers.inserted_window_dims(),
2626                                         i)) {
2627         input_window_bounds.push_back(1);  // Trivial dimension.
2628         input_window_multidim.push_back(index.GetConstantWithIndexType(0));
2629       } else {
2630         input_window_bounds.push_back(
2631             raw_window_bounds[raw_window_multidim_idx]);
2632         input_window_multidim.push_back(
2633             raw_window_multidim[raw_window_multidim_idx]);
2634         ++raw_window_multidim_idx;
2635       }
2636     }
2637     DCHECK_EQ(input_window_multidim.size(), desc.operand_shape.rank());
2638 
2639     // Insert a 1 dimension at the end if index_vector_dim requests one.
2640     Shape scatter_indices_shape_fixed = desc.scatter_indices_shape;
2641     if (desc.dim_numbers.index_vector_dim().getInt() ==
2642         desc.scatter_indices_shape.rank()) {
2643       scatter_indices_shape_fixed.add_dimensions(1);
2644       scatter_indices_shape_fixed.mutable_layout()->add_minor_to_major(
2645           desc.dim_numbers.index_vector_dim().getInt());
2646     }
2647 
2648     // Now load the indices corresponding to the current window from
2649     // scatter_indices.
2650     std::vector<llvm::Value*> raw_scatter_index_multidim =
2651         input_scatter_multidim;
2652     raw_scatter_index_multidim.insert(
2653         raw_scatter_index_multidim.begin() +
2654             desc.dim_numbers.index_vector_dim().getInt(),
2655         nullptr);
2656     llvm::Value* is_in_bounds = b_.getTrue();
2657     for (int64 i = 0,
2658                e = desc.dim_numbers.scatter_dims_to_operand_dims().size();
2659          i != e; ++i) {
2660       // Our index is stored along index_vector_dim, insert that into the lookup
2661       // index into scatter_indices.
2662       raw_scatter_index_multidim[desc.dim_numbers.index_vector_dim().getInt()] =
2663           index.GetConstantWithIndexType(i);
2664       llvm_ir::IrArray::Index raw_scatter_index_index(
2665           raw_scatter_index_multidim, scatter_indices_shape_fixed,
2666           index.GetType());
2667 
2668       int64 operand_dim =
2669           desc.dim_numbers.scatter_dims_to_operand_dims().getValue<int64>(i);
2670       TF_ASSIGN_OR_RETURN(
2671           llvm::Value* const loaded_scatter_index,
2672           desc.scatter_indices_gen(raw_scatter_index_index.SourceIndexOfReshape(
2673               scatter_indices_shape_fixed, desc.scatter_indices_shape, &b_)));
2674       // And add the index to our window index. This yields the output index.
2675       llvm::Value* casted_scatter_index =
2676           IntCast(loaded_scatter_index, index.GetType(),
2677                   /*isSigned=*/true);
2678       llvm::Value* dim_offset =
2679           Add(input_window_multidim[operand_dim], casted_scatter_index);
2680       input_window_multidim[operand_dim] = dim_offset;
2681 
2682       // Also do the bounds check now.
2683       int64 max_index = desc.operand_shape.dimensions(operand_dim) -
2684                         input_window_bounds[operand_dim] + 1;
2685       // is_in_bounds = index >= 0 && index < dim_size-window_size+1
2686       //   --> index u< dim_size-window_size+1
2687       is_in_bounds =
2688           And(is_in_bounds, ICmpULT(casted_scatter_index,
2689                                     index.GetConstantWithIndexType(max_index)));
2690     }
2691 
2692     llvm_ir::LlvmIfData if_window_in_bounds_data = llvm_ir::EmitIfThenElse(
2693         is_in_bounds, "scatter.in_bounds", &b_, /*emit_else=*/false);
2694     llvm_ir::SetToFirstInsertPoint(if_window_in_bounds_data.true_block, &b_);
2695     // All done, now just read from the calculated input from the window, and do
2696     // an atomic store to the calculated location in the output.
2697     llvm_ir::IrArray::Index input_window_index(
2698         input_window_multidim, desc.output.GetShape(), index.GetType());
2699     llvm::Value* output_address =
2700         desc.output.EmitArrayElementAddress(input_window_index, &b_);
2701     llvm::Value* input_address = llvm_ir::EmitAllocaAtFunctionEntry(
2702         llvm_ir::PrimitiveTypeToIrType(desc.updates_shape.element_type(),
2703                                        module_),
2704         "input_address", &b_);
2705     TF_ASSIGN_OR_RETURN(llvm::Value* const input_ir_value,
2706                         desc.updates_gen(index));
2707     Store(input_ir_value, input_address);
2708 
2709     if (!desc.unique_indices) {
2710       return EmitAtomicOperationForNestedComputation(
2711           *desc.update_computation, output_address, input_address);
2712     } else {
2713       return EmitCallToNestedComputation(*desc.update_computation,
2714                                          {output_address, input_address},
2715                                          output_address);
2716     }
2717   };
2718 
2719   // Launch a kernel that reads every element in the updates tensor. We could
2720   // also do one kernel per window instead if bounds checks turn out to be a
2721   // bottleneck.
2722   LaunchDimensions launch_dimensions = CalculateLaunchDimensions(
2723       desc.updates_shape, ir_emitter_context_->gpu_device_info());
2724   UpdateLaunchDimensions(launch_dimensions, thunk,
2725                          ir_emitter_context_->llvm_module());
2726 
2727   return ParallelLoopEmitter(loop_body_emitter, desc.updates_shape,
2728                              launch_dimensions, &b_)
2729       .EmitLoop(desc.name,
2730                 desc.get_index_type(launch_dimensions.launch_bound()));
2731 }
2732 
2733 // This transformation should be migrated off. See b/171334474.
2734 StatusOr<HloComputation*>
GetOrCreateSubComputationFromRegion(mlir::Region * region,bool is_fusion)2735 IrEmitterUnnested::GetOrCreateSubComputationFromRegion(mlir::Region* region,
2736                                                        bool is_fusion) {
2737   std::unique_ptr<HloModule>& module = scratch_nested_computations_[region];
2738   if (module == nullptr) {
2739     std::vector<Shape> operand_shapes, output_shapes;
2740     if (is_fusion) {
2741       mlir::Operation* clone = region->getParentOp()->clone();
2742       region = &mlir::cast<mlir::lmhlo::FusionOp>(clone).region();
2743       TF_RETURN_IF_ERROR(
2744           ProcessFusionForConversion(region, &operand_shapes, &output_shapes));
2745     }
2746 
2747     xla::XlaComputation xla_computation;
2748     mlir::MlirToHloConversionOptions options;
2749     options.propagate_layouts = true;
2750     TF_RETURN_IF_ERROR(
2751         ConvertRegionToComputation(region, &xla_computation, options));
2752 
2753     if (is_fusion) {
2754       region->getParentOp()->erase();
2755     }
2756 
2757     TF_ASSIGN_OR_RETURN(auto program_shape, xla_computation.GetProgramShape());
2758     TF_ASSIGN_OR_RETURN(
2759         module, HloModule::CreateFromProto(xla_computation.proto(),
2760                                            HloModuleConfig(program_shape)));
2761 
2762     if (is_fusion) {
2763       HloComputation* fused_computation = module->entry_computation();
2764 
2765       CHECK_EQ(operand_shapes.size(), fused_computation->num_parameters());
2766       for (int i = 0; i < fused_computation->num_parameters(); i++) {
2767         *fused_computation->parameter_instruction(i)
2768              ->mutable_shape()
2769              ->mutable_layout() = operand_shapes[i].layout();
2770       }
2771       HloInstruction* root = fused_computation->root_instruction();
2772       // Manually fold Tuple(GTE(a, 0), GTE(a, 1), GTE(a, 2), ...) to a.
2773       // FusedIrEmitter doesn't take GTE ops because we aim to elimiate tuples
2774       // as much as possible.
2775       if (root->opcode() == HloOpcode::kTuple) {
2776         [&] {
2777           HloInstruction* real_root = nullptr;
2778           int expected_tuple_index = 0;
2779           for (HloInstruction* operand : root->operands()) {
2780             if (operand->opcode() != HloOpcode::kGetTupleElement) {
2781               return;
2782             }
2783             if (real_root == nullptr) {
2784               real_root = operand->mutable_operand(0);
2785             } else if (real_root != operand->operand(0)) {
2786               return;
2787             }
2788             if (expected_tuple_index != operand->tuple_index()) {
2789               return;
2790             }
2791             expected_tuple_index++;
2792           }
2793           fused_computation->set_root_instruction(real_root);
2794           std::vector<HloInstruction*> to_be_removed;
2795           to_be_removed.push_back(root);
2796           for (HloInstruction* operand : root->operands()) {
2797             to_be_removed.push_back(operand);
2798           }
2799           for (auto instr : to_be_removed) {
2800             TF_CHECK_OK(fused_computation->RemoveInstruction(instr));
2801           }
2802 
2803           root = real_root;
2804         }();
2805       }
2806 
2807       if (output_shapes.size() > 1) {
2808         CHECK(root->shape().IsTuple());
2809         CHECK_EQ(root->shape().tuple_shapes_size(), output_shapes.size());
2810 
2811         for (int i = 0; i < output_shapes.size(); i++) {
2812           *root->mutable_shape()->mutable_tuple_shapes(i) = output_shapes.at(i);
2813         }
2814       } else {
2815         CHECK_EQ(1, output_shapes.size());
2816         *root->mutable_shape() = output_shapes[0];
2817       }
2818     }
2819     // Post-process the generated computation:
2820     // * Sanitize constant names, so that they can be used as LLVM global
2821     // symbols.
2822     // * Propagate layouts for tuple types.
2823     for (HloComputation* computation : module->computations()) {
2824       for (HloInstruction* instr : computation->MakeInstructionPostOrder()) {
2825         if (instr->opcode() == HloOpcode::kConstant) {
2826           // Notice that IR emitters use the name of constants as LLVM symbol
2827           // names, therefore it's important to not let these constants in the
2828           // new module collide with constants in the original module by names.
2829           // Unique them by prepending the module name.
2830           //
2831           // TODO(timshen): A better solution would be to plumb the exact
2832           // constant names through original HLO -> LHLO -> MHLO -> HLO. This is
2833           // hard because XLA builder doesn't support setting names. Revisit
2834           // this once we get rid of this function, or don't rely on the op name
2835           // (which shouldn't be the identity) to generate LLVM symbols.
2836           instr->SetAndSanitizeName(llvm_ir::SanitizeConstantName(
2837               module->name() + "_" + instr->name()));
2838         }
2839         if (instr->shape().IsTuple() &&
2840             computation == module->entry_computation() &&
2841             instr != computation->root_instruction()) {
2842           return InternalError("Non-root tuple types are not handled.");
2843         }
2844       }
2845     }
2846   }
2847   return module->entry_computation();
2848 }
2849 
HandleSort(HloInstruction * sort)2850 Status IrEmitterUnnested::HandleSort(HloInstruction* sort) {
2851   TF_ASSIGN_OR_RETURN(auto mlir_input, GetMlirEmitterInput(sort));
2852   return EmitSortFromMlir(mlir_input);
2853 }
2854 
EmitSortFromMlir(MlirEmitterInput mlir_input)2855 Status IrEmitterUnnested::EmitSortFromMlir(MlirEmitterInput mlir_input) {
2856   auto sort_op = mlir::cast<mlir::lmhlo::SortOp>(mlir_input.op);
2857   MlirEmitterContext context;
2858   context.SetOperation(sort_op);
2859 
2860   std::vector<std::unique_ptr<Thunk>> thunks;
2861 
2862   const Shape& keys_shape = context.operand_shapes[0];
2863   int64 dimension_to_sort = sort_op.dimension();
2864   for (int64 i = 0; i < context.operand_shapes.size(); ++i) {
2865     // We assume that the layout of all involved operands and outputs is the
2866     // same.
2867     TF_RET_CHECK(LayoutUtil::LayoutsInShapesEqual(keys_shape,
2868                                                   context.operand_shapes[i]));
2869     TF_RET_CHECK(
2870         LayoutUtil::LayoutsInShapesEqual(keys_shape, context.output_shapes[i]));
2871 
2872     // If possible, we share buffers. If that is not possible, we need to copy
2873     // the values, because the emitter does the sorting in-place.
2874     TF_ASSIGN_OR_RETURN(auto destination_buffer,
2875                         GetAllocationSliceForMlir(sort_op.output()[i]));
2876     TF_ASSIGN_OR_RETURN(auto source_address,
2877                         GetAllocationSliceForMlir(sort_op.operands()[i]));
2878     if (destination_buffer != source_address) {
2879       // TODO(b/26783907): Figure out why we never seem to share buffers for
2880       // key/value sort.
2881       VLOG(2) << context.name << " requires initial D2D copy for operand " << i;
2882       thunks.push_back(absl::make_unique<DeviceToDeviceCopyThunk>(
2883           Thunk::ThunkInfo(),
2884           /*source_address=*/source_address,
2885           /*destination_buffer=*/destination_buffer,
2886           /*mem_size=*/ShapeUtil::ByteSizeOf(context.operand_shapes[i])));
2887     }
2888   }
2889 
2890   uint64 dimension_to_sort_bound = keys_shape.dimensions(dimension_to_sort);
2891   int64 num_stages = tensorflow::Log2Ceiling(dimension_to_sort_bound);
2892   VLOG(2) << context.name << " requires " << num_stages << " stages.";
2893   CHECK_GE(1ULL << num_stages, dimension_to_sort_bound);
2894   CHECK_LT(1ULL << (num_stages - 1), dimension_to_sort_bound);
2895 
2896   // Naive C++ code for the outer loops:
2897   //
2898   // for (int64 stage = 0; stage < Log2Ceiling(dimension_to_sort_bound);
2899   //     ++stage) {
2900   //   int64 first_xor_mask = (1LL << (stage + 1)) - 1;
2901   //   SortInPlace(first_xor_mask);
2902   //   for (int64 mask = stage - 1; mask >= 0; --mask) {
2903   //     int64 later_xor_mask = 1LL << mask;
2904   //     SortInPlace(later_xor_mask);
2905   //   }
2906   // }
2907   //
2908   // This follows the alternative representation of the algorithm described on
2909   // Wikipedia: https://en.wikipedia.org/wiki/Bitonic_sorter
2910   //
2911   // Each mask specifies how to derive from one position in the array the
2912   // position with which it should be compared (we calculate the xor of the
2913   // position with the mask).
2914   // As an optimization, we can move the 'mask' loop to inside the
2915   // sorting/comparison loop if the comparisons happen within a small block of
2916   // the array. To make this work, we collect all consecutive masks that are
2917   // smaller than our chosen power of 2 tile size, and pass them to SortInPlace.
2918   // Each thread then processes one tile of data.
2919 
2920   const uint64 kTileSize = std::min(2048ULL, 1ULL << num_stages);
2921 
2922   // If we cannot combine several xor masks together, we don't use tiling, so we
2923   // calculate the standard launch dimensions for the shape. However we only
2924   // need to iterate through ~half of the dimension to sort (rounded up to the
2925   // next highest power of 2), because each iteration compares one pair of
2926   // elements.
2927   Shape standard_iteration_shape = keys_shape;
2928   uint64 standard_num_iterations_in_sort_dim = 1ULL << (num_stages - 1);
2929   standard_iteration_shape.set_dimensions(dimension_to_sort,
2930                                           standard_num_iterations_in_sort_dim);
2931   LaunchDimensions standard_launch_dimensions = CalculateLaunchDimensions(
2932       standard_iteration_shape, ir_emitter_context_->gpu_device_info());
2933 
2934   // Calculate the launch dimensions for the case where we use tiling. We split
2935   // the dimension that should be sorted into tiles of size 'kTileSize'. This
2936   // means we first need to round 'dimension_to_sort_bound' up to be a multiple
2937   // of the tile size.
2938   int64 rounded_bound = RoundUpToNearest(dimension_to_sort_bound, kTileSize);
2939   Shape iteration_shape = keys_shape;
2940 
2941   // We iterate through the element pairs that should be compared.
2942   uint64 num_iterations_in_sort_dim = rounded_bound / 2;
2943   iteration_shape.set_dimensions(dimension_to_sort, num_iterations_in_sort_dim);
2944   uint64 num_iterations = ShapeUtil::ElementsIn(iteration_shape);
2945 
2946   // For correctness reasons we need exactly 'kTileSize' / 2 many threads per
2947   // block. Each thread is responsible for copying exactly two adjacent elements
2948   // into shared memory, and then does a comparison of two possibly different
2949   // elements taken from shared memory.
2950   const uint64 kThreadsPerBlock = kTileSize / 2;
2951 
2952   // Check whether we should use any tiling. We might not be able to use it if
2953   // we have not enough threads, or not enough shared memory. Also it does not
2954   // give a speedup if the tile size is < 128.
2955   int64 total_shared_memory_needed = 0;
2956   for (int64 i = 0; i < context.operand_shapes.size(); ++i) {
2957     total_shared_memory_needed +=
2958         kTileSize * ShapeUtil::ByteSizeOfPrimitiveType(
2959                         context.operand_shapes[i].element_type());
2960   }
2961   bool no_tiling =
2962       kTileSize < 128 ||
2963       kThreadsPerBlock >
2964           ir_emitter_context_->gpu_device_info().threads_per_block_limit ||
2965       total_shared_memory_needed >
2966           ir_emitter_context_->gpu_device_info().shared_memory_per_block;
2967   VLOG(2) << absl::StreamFormat(
2968       "%s %s use tiling. No tiling if any of the following is true: "
2969       "kTileSize=%d < 128, "
2970       "kThreadsPerBlock=%d > threads_per_block_limit=%d, "
2971       "total_shared_memory_needed=%d > shared_memory_per_block=%d",
2972       context.name, (no_tiling ? "won't" : "will"), kTileSize, kThreadsPerBlock,
2973       ir_emitter_context_->gpu_device_info().threads_per_block_limit,
2974       total_shared_memory_needed,
2975       ir_emitter_context_->gpu_device_info().shared_memory_per_block);
2976 
2977   uint64 num_blocks = CeilOfRatio(num_iterations, kThreadsPerBlock);
2978   LaunchDimensions tiled_launch_dimensions(num_blocks, kThreadsPerBlock);
2979   VLOG(2) << absl::StreamFormat("%s launch dims: %d blocks, %d threads/block",
2980                                 context.name, num_blocks, kThreadsPerBlock);
2981 
2982   std::vector<llvm_ir::IrArray> ir_arrays;
2983   auto emit_kernel = [&](absl::Span<const int64> xor_masks) {
2984     VLOG(2) << absl::StreamFormat(
2985         "%s uses kernel for xor masks [%s]", context.name,
2986         absl::StrJoin(xor_masks, ", ", [](std::string* out, int64 xor_mask) {
2987           absl::StrAppendFormat(out, "0x%x", xor_mask);
2988         }));
2989     thunks.emplace_back();
2990     TF_ASSIGN_OR_RETURN(
2991         thunks.back(),
2992         BuildKernelThunkForMlir(sort_op, sort_op.output(), Thunk::ThunkInfo(),
2993                                 mlir_input.extra_slice, &ir_arrays));
2994     LaunchDimensions launch_dimensions = xor_masks.size() > 1
2995                                              ? tiled_launch_dimensions
2996                                              : standard_launch_dimensions;
2997     UpdateLaunchDimensions(launch_dimensions, thunks.back().get(),
2998                            ir_emitter_context_->llvm_module());
2999     std::vector<IrArray> values_arrays;
3000     values_arrays.reserve(context.operand_shapes.size());
3001     for (int64 i = 0; i < context.operand_shapes.size(); ++i) {
3002       values_arrays.push_back(ir_arrays[i]);
3003     }
3004     TF_ASSIGN_OR_RETURN(const HloComputation* comparator,
3005                         GetOrCreateSubComputationFromRegion(
3006                             &sort_op.comparator(), /*is_fusion=*/false));
3007     return llvm_ir::EmitSortInPlace(
3008         dimension_to_sort, values_arrays, IrName(context.name), xor_masks, &b_,
3009         launch_dimensions,
3010         xor_masks.size() > 1 ? num_iterations_in_sort_dim
3011                              : standard_num_iterations_in_sort_dim,
3012         kTileSize,
3013         [&](absl::Span<llvm::Value* const> operands, llvm::Value* output) {
3014           return EmitCallToNestedComputation(*comparator, operands, output);
3015         });
3016   };
3017   std::vector<int64> xor_masks;
3018   for (int64 stage = 0; stage < num_stages; ++stage) {
3019     for (int64 mask = stage; mask >= 0; --mask) {
3020       int64 xor_mask;
3021       if (mask == stage) {
3022         xor_mask = (1LL << (stage + 1)) - 1;
3023       } else {
3024         xor_mask = 1LL << mask;
3025       }
3026       if (xor_mask >= kTileSize || no_tiling) {
3027         if (!xor_masks.empty()) {
3028           TF_RETURN_IF_ERROR(emit_kernel(xor_masks));
3029           xor_masks.clear();
3030         }
3031         TF_RETURN_IF_ERROR(emit_kernel({xor_mask}));
3032       } else {
3033         xor_masks.push_back(xor_mask);
3034       }
3035     }
3036   }
3037   if (!xor_masks.empty()) {
3038     TF_RETURN_IF_ERROR(emit_kernel(xor_masks));
3039   }
3040   VLOG(2) << absl::StreamFormat(
3041       "%s requires %d thunks (including any D2D copies)", context.name,
3042       thunks.size());
3043 
3044   AddThunkToThunkSequence(absl::make_unique<SequentialThunk>(
3045       mlir_input.thunk_info, std::move(thunks)));
3046   return Status::OK();
3047 }
3048 
3049 template <typename ThunkType, typename OpT>
EmitReplicaOrPartitionIdFromMlir(MlirEmitterInput input)3050 Status IrEmitterUnnested::EmitReplicaOrPartitionIdFromMlir(
3051     MlirEmitterInput input) {
3052   auto op = mlir::cast<OpT>(input.op);
3053   TF_ASSIGN_OR_RETURN(BufferAllocation::Slice result_slice,
3054                       GetAllocationSliceForMlir(op.getOperand()));
3055   AddThunkToThunkSequence(
3056       absl::make_unique<ThunkType>(input.thunk_info, result_slice));
3057   return Status::OK();
3058 }
3059 
HandleReplicaId(HloInstruction * hlo)3060 Status IrEmitterUnnested::HandleReplicaId(HloInstruction* hlo) {
3061   TF_ASSIGN_OR_RETURN(auto input, GetMlirEmitterInput(hlo));
3062   return EmitReplicaOrPartitionIdFromMlir<ReplicaIdThunk,
3063                                           mlir::lmhlo::ReplicaIdOp>(input);
3064 }
3065 
HandlePartitionId(HloInstruction * hlo)3066 Status IrEmitterUnnested::HandlePartitionId(HloInstruction* hlo) {
3067   TF_ASSIGN_OR_RETURN(auto input, GetMlirEmitterInput(hlo));
3068   return EmitReplicaOrPartitionIdFromMlir<PartitionIdThunk,
3069                                           mlir::lmhlo::PartitionIdOp>(input);
3070 }
3071 
HandleCollectivePermute(HloInstruction * hlo)3072 Status IrEmitterUnnested::HandleCollectivePermute(HloInstruction* hlo) {
3073   TF_ASSIGN_OR_RETURN(auto input, GetMlirEmitterInput(hlo));
3074   return EmitCollectivePermuteFromMlir(input);
3075 }
3076 
EmitCollectivePermuteFromMlir(MlirEmitterInput input)3077 Status IrEmitterUnnested::EmitCollectivePermuteFromMlir(
3078     MlirEmitterInput input) {
3079   auto collective_permute_op =
3080       mlir::cast<mlir::lmhlo::CollectivePermuteOp>(input.op);
3081   if (collective_permute_op.channel_id())
3082     return Unimplemented("collective permute with channel_id not implemented");
3083   using source_dest_pairs_t = std::vector<std::pair<int64, int64>>;
3084   TF_ASSIGN_OR_RETURN(
3085       source_dest_pairs_t source_dest_pairs,
3086       ConvertNx2Attribute(collective_permute_op.source_target_pairs()));
3087 
3088   TF_ASSIGN_OR_RETURN(
3089       BufferAllocation::Slice source_slice,
3090       GetAllocationSliceForMlir(collective_permute_op.operand()));
3091   TF_ASSIGN_OR_RETURN(
3092       BufferAllocation::Slice result_slice,
3093       GetAllocationSliceForMlir(collective_permute_op.output()));
3094 
3095   AddThunkToThunkSequence(absl::make_unique<CollectivePermuteThunk>(
3096       input.thunk_info, std::move(source_dest_pairs), source_slice,
3097       result_slice));
3098   return Status::OK();
3099 }
3100 
3101 template <typename NcclThunkType, typename OpTy>
EmitNcclThunkFromMlir(MlirEmitterInput input)3102 Status IrEmitterUnnested::EmitNcclThunkFromMlir(MlirEmitterInput input) {
3103   OpTy op = mlir::cast<OpTy>(input.op);
3104   int64 replica_count = hlo_module_config_.replica_count();
3105   VLOG(2) << NcclThunkType::GetName() << "; replica count: " << replica_count
3106           << "; operand count: " << op.operands().size()
3107           << "; NCCL is enabled: " << NcclThunkType::NcclIsEnabled();
3108 
3109   // Note the replica_count == 1 case is handled via device-to-device copy
3110   // below.
3111   bool should_use_nccl_thunk =
3112       replica_count > 1 && NcclThunkType::CanImplement(op);
3113 
3114   // Stash relevant information in NcclAllGatherThunk::Buffer even if we may
3115   // not generate an NcclAllGatherThunk.
3116   std::vector<NcclCollectiveThunk::Buffer> buffers;
3117   buffers.reserve(op.operands().size());
3118   for (auto it : llvm::zip(op.operands(), op.results())) {
3119     mlir::Value operand = std::get<0>(it);
3120     mlir::Value result = std::get<1>(it);
3121     const Shape shape = TypeToShape(operand.getType());
3122     TF_ASSIGN_OR_RETURN(auto source_slice, GetAllocationSliceForMlir(operand));
3123     TF_ASSIGN_OR_RETURN(auto dest_slice, GetAllocationSliceForMlir(result));
3124     buffers.push_back(NcclCollectiveThunk::Buffer{
3125         /*element_count=*/ShapeUtil::ElementsIn(shape),
3126         /*source_buffer=*/source_slice,
3127         /*destination_buffer=*/dest_slice});
3128   }
3129 
3130   if (should_use_nccl_thunk) {
3131     auto nccl_thunk =
3132         absl::make_unique<NcclThunkType>(input.thunk_info, op, replica_count,
3133                                          /*buffers=*/std::move(buffers));
3134     AddThunkToThunkSequence(std::move(nccl_thunk));
3135     return Status::OK();
3136   }
3137 
3138   if (replica_count != 1) {
3139     string message = absl::StrFormat(
3140         "Requested %s not implemented on GPU; replica_count: %d; "
3141         "operand_count: %d; NCCL support: %d",
3142         NcclThunkType::GetName(), replica_count, op.operands().size(),
3143         NcclThunkType::NcclIsEnabled());
3144     if (!op.operands().empty()) {
3145       const Shape shape = TypeToShape(op.operands().front().getType());
3146       absl::StrAppendFormat(&message, "; first operand array element-type: %s",
3147                             PrimitiveType_Name(shape.element_type()));
3148     }
3149     return Unimplemented("%s", message);
3150   }
3151 
3152   // All-gather with one replica is simply the identity function. Buffer
3153   // assignment expects a copy, so that's what we do.
3154   std::vector<std::unique_ptr<Thunk>> thunks;
3155   for (int64 i = 0; i < buffers.size(); i++) {
3156     const Shape shape = TypeToShape(op.operands()[i].getType());
3157     thunks.push_back(absl::make_unique<DeviceToDeviceCopyThunk>(
3158         buffers.size() == 1 ? input.thunk_info : Thunk::ThunkInfo(),
3159         /*source_address=*/buffers[i].source_buffer,
3160         /*destination_buffer=*/buffers[i].destination_buffer,
3161         /*mem_size=*/ShapeUtil::ByteSizeOf(shape)));
3162   }
3163   if (thunks.size() == 1) {
3164     AddThunkToThunkSequence(std::move(thunks[0]));
3165   } else {
3166     AddThunkToThunkSequence(absl::make_unique<SequentialThunk>(
3167         input.thunk_info, std::move(thunks)));
3168   }
3169   return Status::OK();
3170 }
3171 
HandleAllGather(HloInstruction * hlo)3172 Status IrEmitterUnnested::HandleAllGather(HloInstruction* hlo) {
3173   TF_ASSIGN_OR_RETURN(auto input, GetMlirEmitterInput(hlo));
3174   return EmitNcclThunkFromMlir<NcclAllGatherThunk, mlir::lmhlo::AllGatherOp>(
3175       input);
3176 }
3177 
HandleAllReduce(HloInstruction * hlo)3178 Status IrEmitterUnnested::HandleAllReduce(HloInstruction* hlo) {
3179   TF_ASSIGN_OR_RETURN(auto input, GetMlirEmitterInput(hlo));
3180   return EmitNcclThunkFromMlir<NcclAllReduceThunk, mlir::lmhlo::AllReduceOp>(
3181       input);
3182 }
3183 
HandleAllToAll(HloInstruction * hlo)3184 Status IrEmitterUnnested::HandleAllToAll(HloInstruction* hlo) {
3185   TF_ASSIGN_OR_RETURN(auto input, GetMlirEmitterInput(hlo));
3186   return EmitNcclThunkFromMlir<NcclAllToAllThunk, mlir::lmhlo::AllToAllOp>(
3187       input);
3188 }
3189 
HandleInfeed(HloInstruction * xla_infeed)3190 Status IrEmitterUnnested::HandleInfeed(HloInstruction* xla_infeed) {
3191   TF_ASSIGN_OR_RETURN(auto input, GetMlirEmitterInput(xla_infeed));
3192 
3193   auto infeed_op = mlir::cast<mlir::lmhlo::InfeedOp>(input.op);
3194 
3195   std::vector<ShapedSlice> dest_slices;
3196   dest_slices.reserve(infeed_op.outputs().size());
3197 
3198   for (mlir::Value output : infeed_op.outputs()) {
3199     TF_ASSIGN_OR_RETURN(auto slice, GetAllocationSliceForMlir(output));
3200     const Shape& shape = TypeToShape(output.getType());
3201     dest_slices.push_back(ShapedSlice{slice, shape});
3202   }
3203 
3204   AddThunkToThunkSequence(
3205       absl::make_unique<InfeedThunk>(input.thunk_info, std::move(dest_slices)));
3206   return Status::OK();
3207 }
3208 
HandleOutfeed(HloInstruction * outfeed)3209 Status IrEmitterUnnested::HandleOutfeed(HloInstruction* outfeed) {
3210   TF_ASSIGN_OR_RETURN(auto input, GetMlirEmitterInput(outfeed));
3211 
3212   auto outfeed_op = mlir::cast<mlir::lmhlo::OutfeedOp>(input.op);
3213 
3214   std::vector<ShapedSlice> source_slices;
3215   source_slices.reserve(outfeed_op.operands().size());
3216 
3217   for (mlir::Value operand : outfeed_op.operands()) {
3218     TF_ASSIGN_OR_RETURN(auto slice, GetAllocationSliceForMlir(operand));
3219     const Shape& shape = TypeToShape(operand.getType());
3220     source_slices.push_back(ShapedSlice{slice, shape});
3221   }
3222 
3223   AddThunkToThunkSequence(absl::make_unique<OutfeedThunk>(
3224       input.thunk_info, std::move(source_slices)));
3225   return Status::OK();
3226 }
3227 
HandleAfterAll(HloInstruction * after_all)3228 Status IrEmitterUnnested::HandleAfterAll(HloInstruction* after_all) {
3229   return Status::OK();
3230 }
3231 
3232 // Figures out how to access the buffers for all subshapes of hlo's operands and
3233 // for hlo itself (i.e. all the buffers produced by HLO).
3234 //
3235 // Returns a vector of `HloBufferSlice`s, one for each HLO subshape `hlo` needs
3236 // to access (including one or more for itself).
3237 //
3238 // This function conservatively assumes that we'll touch all sub-buffers of
3239 // every operand and of the output.
GetHloBufferSlices(const HloInstruction * hlo,const BufferAssignment & buffer_assn)3240 static std::vector<HloBufferSlice> GetHloBufferSlices(
3241     const HloInstruction* hlo, const BufferAssignment& buffer_assn) {
3242   std::vector<HloBufferSlice> result;
3243   absl::flat_hash_set<std::pair<const HloInstruction*, ShapeIndex>>
3244       inserted_buffer_slices;
3245 
3246   // Tries to find a slice plus an array of indices i1, ..., iN such that the
3247   // sub-buffer for instr at index can be found at slice[i1]...[iN].
3248   auto find_slice_for = [&](const HloInstruction* instr,
3249                             const ShapeIndex& index)
3250       -> optional<std::pair<BufferAllocation::Slice, ShapeIndex>> {
3251     // Simple, common case: Is the buffer for instr known at runtime?  If so,
3252     // we're done.
3253     auto slice = buffer_assn.GetUniqueSlice(instr, index);
3254     if (slice.ok()) {
3255       return {{slice.ValueOrDie(), ShapeIndex()}};
3256     }
3257 
3258     // If that didn't work, walk up any bitcasts that we might see.  These must
3259     // appear before any GTE instructions, because it's illegal to bitcast to a
3260     // tuple type.
3261     const HloInstruction* parent = instr;
3262     while (parent->IsEffectiveBitcast()) {
3263       parent = parent->operand(0);
3264 
3265       auto slice = buffer_assn.GetUniqueSlice(parent, {});
3266       if (slice.ok()) {
3267         return {{slice.ValueOrDie(), ShapeIndex()}};
3268       }
3269     }
3270 
3271     // Check whether instr is a GTE instruction.  If it is, see if we can get a
3272     // buffer for its parent, and continue walking up parents until we find a
3273     // defined buffer or we hit something that's not a GTE.
3274     ShapeIndex gte_indices;
3275     while (parent->opcode() == HloOpcode::kGetTupleElement) {
3276       gte_indices.push_front(parent->tuple_index());
3277       parent = parent->operand(0);
3278 
3279       auto slice = buffer_assn.GetUniqueSlice(parent, {});
3280       if (slice.ok()) {
3281         return {{slice.ValueOrDie(), gte_indices}};
3282       }
3283     }
3284 
3285     // Finally, if we don't know the buffer for instr at index, see if we know
3286     // the buffer for instr at index without its last element.  If so, we can
3287     // dynamically find the buffer for instr by dereferencing a pointer in that
3288     // buffer.  Continue looking this way until we run out of elements in
3289     // 'index'.
3290     //
3291     // We can almost always get a buffer without resorting to this.  The only
3292     // exception is for cases where the relevant sub-buffer is truly unknowable,
3293     // for example the sub-buffer of a tuple-shaped select.
3294     ShapeIndex new_index = index;
3295     while (!new_index.empty()) {
3296       gte_indices.push_front(new_index.back());
3297       new_index.pop_back();
3298       auto slice = buffer_assn.GetUniqueSlice(instr, new_index);
3299       if (slice.ok()) {
3300         return {{slice.ValueOrDie(), gte_indices}};
3301       }
3302     }
3303 
3304     return nullopt;
3305   };
3306 
3307   // Adds entries for all subshapes of instr to `slices`.
3308   auto add_slices_for = [&](const HloInstruction* instr) {
3309     ShapeUtil::ForEachSubshape(
3310         instr->shape(), [&](const Shape& /*shape*/, const ShapeIndex& index) {
3311           if (!inserted_buffer_slices.insert({instr, index}).second) {
3312             // HLOs can have duplicate operands; don't bother redoing work.
3313             return;
3314           }
3315           auto maybe_slice = find_slice_for(instr, index);
3316           if (maybe_slice.has_value()) {
3317             HloBufferSlice hlo_buffer_slice;
3318             hlo_buffer_slice.instr = instr;
3319             hlo_buffer_slice.hlo_index = index;
3320             hlo_buffer_slice.buffer_slice = maybe_slice->first;
3321             hlo_buffer_slice.gte_index = maybe_slice->second;
3322             result.push_back(hlo_buffer_slice);
3323           } else {
3324             VLOG(1) << "Couldn't find buffer for " << instr->ToString()
3325                     << " at index " << index.ToString();
3326           }
3327         });
3328   };
3329 
3330   add_slices_for(hlo);
3331   for (const HloInstruction* operand : hlo->operands()) {
3332     // Conservatively assume we'll need the buffers for all subshapes of the
3333     // operand.
3334     add_slices_for(operand);
3335   }
3336 
3337   return result;
3338 }
3339 
3340 std::unique_ptr<KernelThunk>
BuildKernelThunkFromBufferSlices(absl::string_view name,Thunk::ThunkInfo thunk_info,absl::Span<const BufferSlice * const> slices,std::function<void (const BufferSlice *,llvm::Value *)> bind_slice_to_ir_value)3341 IrEmitterUnnested::BuildKernelThunkFromBufferSlices(
3342     absl::string_view name, Thunk::ThunkInfo thunk_info,
3343     absl::Span<const BufferSlice* const> slices,
3344     std::function<void(const BufferSlice*, llvm::Value*)>
3345         bind_slice_to_ir_value) {
3346   // Figure out which buffer allocations need to be passed as arguments to our
3347   // kernel.  This is simply all of the allocations referenced in slices,
3348   // plus the XLA temp buffer (if we have it).  We always include the temp
3349   // buffer because even if the kernel itself doesn't use it, a nested
3350   // subcomputation within the kernel (e.g. a kMap's computation) might.
3351   std::unordered_set<const BufferAllocation*> buffers_needed;
3352   for (auto* slice : slices) {
3353     buffers_needed.insert(slice->buffer_slice.allocation());
3354   }
3355   absl::optional<const BufferAllocation*> temp_buffer;
3356   for (const BufferAllocation& alloc : ir_emitter_context_->allocations()) {
3357     if (alloc.IsPreallocatedTempBuffer()) {
3358       if (!temp_buffer.has_value()) {
3359         // Retrieve the first seen temp buffer.
3360         temp_buffer = &alloc;
3361       }
3362     }
3363   }
3364   if (temp_buffer.has_value()) {
3365     buffers_needed.insert(*temp_buffer);
3366   }
3367 
3368   // We'll pass a pointer to each of the elements of `buffers` to our kernel, in
3369   // this order.
3370   std::vector<const BufferAllocation*> non_constant_buffers;
3371   absl::c_copy_if(buffers_needed, std::back_inserter(non_constant_buffers),
3372                   [](const BufferAllocation* allocation) {
3373                     return !allocation->is_constant();
3374                   });
3375 
3376   absl::c_sort(non_constant_buffers,
3377                [](const BufferAllocation* a, const BufferAllocation* b) {
3378                  return a->index() < b->index();
3379                });
3380 
3381   llvm::Function* kernel = BuildKernelPrototype(name, non_constant_buffers);
3382 
3383   // Build a map from a BufferAllocation to the corresponding argument in our
3384   // kernel.
3385   std::unordered_map<const BufferAllocation*, llvm::Value*> kernel_args;
3386   {
3387     auto arg_it = kernel->arg_begin();
3388     auto buffers_it = non_constant_buffers.begin();
3389     for (; arg_it != kernel->arg_end(); ++arg_it, ++buffers_it) {
3390       kernel_args[*buffers_it] = arg_it;
3391 
3392       // Annotate all allocations with LLVM's `noalias`.
3393       // There are three kinds of allocations:
3394       // * Read-only allocations, aka input parameters that are not aliased with
3395       // outputs.
3396       // * Read-write allocations, including all output buffers, some of which
3397       // may alias with input HLO parameters, but aliased HLO buffers are always
3398       // assigned with the same allocation.
3399       // * The temp buffer.
3400       //
3401       // Read-only allocations may overlap with each other, but since they are
3402       // not mutated, they can always be annotated with `noalias` per LLVM
3403       // semantics.
3404       //
3405       // Read-write allocations and the temp buffer don't overlap with any
3406       // allocations, therefore they can also be annotated with `noalias`.
3407       kernel->addParamAttr(
3408           arg_it->getArgNo(),
3409           llvm::Attribute::get(arg_it->getContext(), llvm::Attribute::NoAlias));
3410     }
3411   }
3412 
3413   // For each buffer our kernel might want to touch, bind it to a value derived
3414   // from our kernel args.
3415   for (auto* slice : slices) {
3416     const BufferAllocation::Slice& buffer_slice = slice->buffer_slice;
3417     const ShapeIndex& gte_index = slice->gte_index;
3418 
3419     llvm::Value* loc;
3420     if (buffer_slice.allocation()->is_constant()) {
3421       loc = ir_emitter_context_->llvm_module()->getGlobalVariable(
3422           llvm_ir::ConstantBufferAllocationToGlobalName(
3423               *buffer_slice.allocation()));
3424       CHECK_NE(loc, nullptr);
3425     } else {
3426       loc = InBoundsGEP(kernel_args.at(buffer_slice.allocation()),
3427                         {b_.getInt64(buffer_slice.offset())});
3428     }
3429 
3430     // If gte_index is nonempty, we have to dereference `loc` to get to the
3431     // value we're ultimately interested in.
3432     llvm::Type* int8_double_pointer =
3433         llvm::PointerType::get(b_.getInt8PtrTy(), /*AddressSpace=*/0);
3434     for (int64 idx : gte_index) {
3435       loc = b_.CreatePointerBitCastOrAddrSpaceCast(loc, int8_double_pointer);
3436       loc = Load(InBoundsGEP(loc, {b_.getInt64(idx)}));
3437     }
3438 
3439     bind_slice_to_ir_value(slice, loc);
3440   }
3441 
3442   // Bind the temp buffer so that nested subcomputations can find it if they
3443   // need.
3444   if (temp_buffer.has_value()) {
3445     bindings_.SetTempBufferBase(kernel_args.at(*temp_buffer));
3446   } else {
3447     bindings_.SetTempBufferBase(
3448         llvm::ConstantPointerNull::get(b_.getInt8PtrTy()));
3449   }
3450 
3451   return absl::make_unique<KernelThunk>(thunk_info, non_constant_buffers,
3452                                         std::string(kernel->getName()));
3453 }
3454 
BuildKernelThunk(const HloInstruction * inst,bool implements_whole_instruction)3455 std::unique_ptr<KernelThunk> IrEmitterUnnested::BuildKernelThunk(
3456     const HloInstruction* inst, bool implements_whole_instruction) {
3457   std::vector<HloBufferSlice> hlo_slices =
3458       GetHloBufferSlices(inst, ir_emitter_context_->buffer_assignment());
3459 
3460   std::vector<BufferSlice*> slice_ptrs;
3461   slice_ptrs.reserve(hlo_slices.size());
3462   for (auto& slice : hlo_slices) {
3463     slice_ptrs.push_back(&slice);
3464   }
3465 
3466   return BuildKernelThunkFromBufferSlices(
3467       inst->name(),
3468       implements_whole_instruction ? GetThunkInfo(inst) : Thunk::ThunkInfo(),
3469       slice_ptrs, [this](const BufferSlice* slice, llvm::Value* value) {
3470         const HloBufferSlice* hlo_buffer_slice =
3471             static_cast<const HloBufferSlice*>(slice);
3472         const HloInstruction* instr = hlo_buffer_slice->instr;
3473         const ShapeIndex& index = hlo_buffer_slice->hlo_index;
3474         VLOG(3) << "Buffer for " << instr->ToString() << " at "
3475                 << index.ToString() << " is found in slice "
3476                 << hlo_buffer_slice->buffer_slice.ToString() << " at GTE index "
3477                 << hlo_buffer_slice->gte_index.ToString();
3478 
3479         bindings_.BindHloToIrValue(*instr, value, index);
3480       });
3481 }
3482 
BuildKernelThunkForMlirImpl(absl::string_view name,Thunk::ThunkInfo thunk_info,absl::Span<const MlirBufferSlice> slices,std::vector<llvm_ir::IrArray> * ir_arrays)3483 std::unique_ptr<KernelThunk> IrEmitterUnnested::BuildKernelThunkForMlirImpl(
3484     absl::string_view name, Thunk::ThunkInfo thunk_info,
3485     absl::Span<const MlirBufferSlice> slices,
3486     std::vector<llvm_ir::IrArray>* ir_arrays) {
3487   absl::flat_hash_set<BufferAllocation::Slice> buffers_written;
3488   std::vector<const BufferSlice*> slice_ptrs;
3489   slice_ptrs.reserve(slices.size());
3490   for (auto& slice : slices) {
3491     slice_ptrs.push_back(&slice);
3492     if (slice.written) {
3493       buffers_written.insert(slice.buffer_slice);
3494     }
3495   }
3496 
3497   ir_arrays->clear();
3498   return BuildKernelThunkFromBufferSlices(
3499       name, thunk_info, slice_ptrs,
3500       [&](const BufferSlice* slice, llvm::Value* value) {
3501         const auto& mlir_slice = static_cast<const MlirBufferSlice&>(*slice);
3502 
3503         llvm_ir::IrArray ir_array(
3504             CastToTypedValue(mlir_slice.shape, value, &b_), mlir_slice.shape);
3505         if (!buffers_written.contains(slice->buffer_slice)) {
3506           ir_array.MarkInvariantOverWholeProgram(&value->getContext());
3507         }
3508 
3509         ir_arrays->push_back(ir_array);
3510       });
3511 }
3512 
3513 StatusOr<std::unique_ptr<KernelThunk>>
BuildKernelThunkForMlir(mlir::Operation * op,mlir::ValueRange operands,Thunk::ThunkInfo thunk_info,absl::optional<MlirBufferSlice> extra_slice,std::vector<llvm_ir::IrArray> * ir_arrays)3514 IrEmitterUnnested::BuildKernelThunkForMlir(
3515     mlir::Operation* op, mlir::ValueRange operands, Thunk::ThunkInfo thunk_info,
3516     absl::optional<MlirBufferSlice> extra_slice,
3517     std::vector<llvm_ir::IrArray>* ir_arrays) {
3518   TF_RET_CHECK(!mlir::isa<mlir::lmhlo::FusionOp>(op));
3519 
3520   std::vector<MlirBufferSlice> slices;
3521   for (mlir::Value operand : operands) {
3522     slices.emplace_back();
3523     auto& slice = slices.back();
3524     TF_ASSIGN_OR_RETURN(slice.buffer_slice, GetAllocationSliceForMlir(operand));
3525     slice.written = WritesMlirBuffer(op, operand);
3526     slice.shape = TypeToShape(operand.getType());
3527   }
3528   if (extra_slice) {
3529     slices.push_back(*extra_slice);
3530   }
3531   std::string name = mlir::GetNameFromLoc(op->getLoc());
3532   return BuildKernelThunkForMlirImpl(name, thunk_info, slices, ir_arrays);
3533 }
3534 
3535 StatusOr<std::unique_ptr<KernelThunk>>
BuildKernelThunkForMlir(mlir::Operation * op,Thunk::ThunkInfo thunk_info,absl::optional<MlirBufferSlice> extra_slice,std::vector<llvm_ir::IrArray> * ir_arrays)3536 IrEmitterUnnested::BuildKernelThunkForMlir(
3537     mlir::Operation* op, Thunk::ThunkInfo thunk_info,
3538     absl::optional<MlirBufferSlice> extra_slice,
3539     std::vector<llvm_ir::IrArray>* ir_arrays) {
3540   if (auto fusion = mlir::dyn_cast<mlir::lmhlo::FusionOp>(op)) {
3541     auto operands = GetHloOperands(op);
3542     auto outputs = GetHloOutputs(op);
3543 
3544     std::vector<MlirBufferSlice> slices;
3545     for (auto operand : operands) {
3546       slices.emplace_back();
3547       auto& slice = slices.back();
3548       TF_ASSIGN_OR_RETURN(slice.buffer_slice,
3549                           GetAllocationSliceForMlir(operand));
3550       slice.written = false;
3551       slice.shape = TypeToShape(operand.getType());
3552     }
3553     for (auto output : outputs) {
3554       slices.emplace_back();
3555       auto& slice = slices.back();
3556       TF_ASSIGN_OR_RETURN(slice.buffer_slice,
3557                           GetAllocationSliceForMlir(output));
3558       slice.written = true;
3559       slice.shape = TypeToShape(output.getType());
3560     }
3561     std::string name = mlir::GetNameFromLoc(op->getLoc());
3562     if (extra_slice) {
3563       slices.push_back(*extra_slice);
3564     }
3565     return BuildKernelThunkForMlirImpl(name, thunk_info, slices, ir_arrays);
3566   }
3567   return BuildKernelThunkForMlir(op, op->getOperands(), thunk_info, extra_slice,
3568                                  ir_arrays);
3569 }
3570 
BuildConstantInitializerThunk(absl::Span<const uint8> init_value,const BufferAllocation::Slice & dest,const Shape & output_shape)3571 std::unique_ptr<Thunk> IrEmitterUnnested::BuildConstantInitializerThunk(
3572     absl::Span<const uint8> init_value, const BufferAllocation::Slice& dest,
3573     const Shape& output_shape) {
3574   int64 num_bytes = init_value.size();
3575   if (absl::c_all_of(init_value, [](uint8 byte) { return byte == 0; })) {
3576     return absl::make_unique<MemzeroThunk>(Thunk::ThunkInfo(), dest);
3577   }
3578 
3579   // If the literal is 8 or 16 bits wide, we can emit a 32-bit memset by
3580   // repeating the literal 4 or 2 times, so long as the destination buffer is
3581   // an even multiple of 32 bits long.
3582   if ((num_bytes == 1 || num_bytes == 2) &&
3583       ShapeUtil::ByteSizeOf(output_shape) % 4 == 0) {
3584     uint16 pattern16;
3585     if (num_bytes == 1) {
3586       uint8 b = init_value.front();
3587       pattern16 = uint16{b} | (uint16{b} << 8);
3588     } else {
3589       memcpy(&pattern16, init_value.data(), sizeof(pattern16));
3590     }
3591     uint32 pattern32 = uint32{pattern16} | (uint32{pattern16} << 16);
3592     return absl::make_unique<Memset32BitValueThunk>(Thunk::ThunkInfo(),
3593                                                     pattern32, dest);
3594   }
3595 
3596   // If the literal is an even multiple of 32 bits wide, we can emit a 32-bit
3597   // memset so long as all 32-bit words of the scalar are equal to each other.
3598   if (num_bytes >= 4 && num_bytes % 4 == 0 &&
3599       memcmp(init_value.data(), init_value.data() + 4, init_value.size() - 4) ==
3600           0) {
3601     uint32 word;
3602     memcpy(&word, init_value.data(), sizeof(word));
3603     return absl::make_unique<Memset32BitValueThunk>(Thunk::ThunkInfo(), word,
3604                                                     dest);
3605   }
3606 
3607   return nullptr;
3608 }
3609 
3610 StatusOr<std::unique_ptr<Thunk>>
TryBuildConstantInitializerThunk(mlir::Value init_value,mlir::Value dest)3611 IrEmitterUnnested::TryBuildConstantInitializerThunk(mlir::Value init_value,
3612                                                     mlir::Value dest) {
3613   mlir::DenseElementsAttr const_init;
3614   if (auto get_global_memref = mlir::dyn_cast_or_null<mlir::GetGlobalMemrefOp>(
3615           init_value.getDefiningOp())) {
3616     auto global_memref =
3617         mlir::SymbolTable::lookupNearestSymbolFrom<mlir::GlobalMemrefOp>(
3618             get_global_memref, get_global_memref.name());
3619     if (global_memref.constant() && global_memref.initial_value()) {
3620       // If the initial value happens to be a constant, generate a specialized
3621       // thunk.
3622       const_init = global_memref.initial_value()
3623                        .getValue()
3624                        .cast<mlir::DenseElementsAttr>();
3625     }
3626   } else if (auto constant = mlir::dyn_cast_or_null<mlir::mhlo::ConstOp>(
3627                  init_value.getDefiningOp())) {
3628     const_init = constant.value().dyn_cast<mlir::DenseElementsAttr>();
3629   }
3630 
3631   if (const_init) {
3632     std::vector<uint8> literal_bytes;
3633     TF_RETURN_IF_ERROR(
3634         CopyDenseElementsDataToXlaFormat(const_init, &literal_bytes));
3635 
3636     TF_ASSIGN_OR_RETURN(auto dest_slice, GetAllocationSliceForMlir(dest));
3637 
3638     const Shape dest_shape = TypeToShape(dest.getType());
3639     auto thunk =
3640         BuildConstantInitializerThunk(literal_bytes, dest_slice, dest_shape);
3641     if (thunk) {
3642       return {std::move(thunk)};
3643     }
3644   }
3645   return std::unique_ptr<Thunk>();
3646 }
3647 
3648 StatusOr<std::unique_ptr<Thunk>>
BuildInitializerThunkForMlir(mlir::Operation * op,mlir::Value init_value,mlir::Value dest)3649 IrEmitterUnnested::BuildInitializerThunkForMlir(mlir::Operation* op,
3650                                                 mlir::Value init_value,
3651                                                 mlir::Value dest) {
3652   // initial value must be a scalar memref.
3653   auto init_type = init_value.getType().dyn_cast<mlir::MemRefType>();
3654   TF_RET_CHECK(init_type.getRank() == 0);
3655 
3656   TF_ASSIGN_OR_RETURN(std::unique_ptr<Thunk> constant_init_thunk,
3657                       TryBuildConstantInitializerThunk(init_value, dest));
3658   if (constant_init_thunk) {
3659     return {std::move(constant_init_thunk)};
3660   }
3661 
3662   // Otherwise fall back to our slow initializer code. The thunk in this case
3663   // will just need the IR arrays for the initial value and the destination.
3664   std::vector<llvm_ir::IrArray> ir_arrays;
3665   TF_ASSIGN_OR_RETURN(
3666       std::unique_ptr<KernelThunk> kernel_thunk,
3667       BuildKernelThunkForMlir(op, {init_value, dest}, Thunk::ThunkInfo(), {},
3668                               &ir_arrays));
3669   const llvm_ir::IrArray init_array = ir_arrays[0];
3670   const llvm_ir::IrArray dest_array = ir_arrays[1];
3671 
3672   const Shape dest_shape = TypeToShape(dest.getType());
3673   LaunchDimensions launch_dimensions = CalculateLaunchDimensions(
3674       dest_shape, ir_emitter_context_->gpu_device_info());
3675   UpdateLaunchDimensions(launch_dimensions, kernel_thunk.get(),
3676                          ir_emitter_context_->llvm_module());
3677 
3678   std::string name = mlir::GetNameFromLoc(op->getLoc());
3679   TF_RETURN_IF_ERROR(ParallelLoopEmitter(
3680                          [=](const IrArray::Index& index) {
3681                            return init_array.EmitReadArrayElement(index, &b_);
3682                          },
3683                          dest_array, launch_dimensions, &b_)
3684                          .EmitLoop(mlir::GetNameFromLoc(op->getLoc())));
3685 
3686   // Convert unique_ptr<KernelThunk> to StatusOr<unique_ptr<Thunk>>.
3687   return {std::move(kernel_thunk)};
3688 }
3689 
3690 StatusOr<std::unique_ptr<Thunk>>
BuildFusedInitializerThunkForMlir(mlir::lmhlo::FusionOp fusion,int output_index)3691 IrEmitterUnnested::BuildFusedInitializerThunkForMlir(
3692     mlir::lmhlo::FusionOp fusion, int output_index) {
3693   auto reduce = mlir::dyn_cast_or_null<mlir::mhlo::ReduceOp>(
3694       fusion.getFusionResults()[output_index].getDefiningOp());
3695 
3696   TF_RET_CHECK(reduce);
3697   TF_RET_CHECK(reduce.getNumResults() == 1);
3698 
3699   mlir::Value init_value = reduce.init_values()[0];
3700   mlir::Value dest = fusion.getOutputBuffers()[output_index];
3701   TF_ASSIGN_OR_RETURN(std::unique_ptr<Thunk> constant_init_thunk,
3702                       TryBuildConstantInitializerThunk(init_value, dest));
3703   if (constant_init_thunk) {
3704     return {std::move(constant_init_thunk)};
3705   }
3706 
3707   auto input_buffers = fusion.getInputBuffers();
3708 
3709   std::vector<llvm_ir::IrArray> ir_arrays;
3710   TF_ASSIGN_OR_RETURN(
3711       std::unique_ptr<KernelThunk> kernel_thunk,
3712       BuildKernelThunkForMlir(fusion, Thunk::ThunkInfo(), {}, &ir_arrays));
3713   const llvm_ir::IrArray dest_array =
3714       ir_arrays[input_buffers.size() + output_index];
3715 
3716   const Shape dest_shape = TypeToShape(dest.getType());
3717   LaunchDimensions launch_dimensions = CalculateLaunchDimensions(
3718       dest_shape, ir_emitter_context_->gpu_device_info());
3719   UpdateLaunchDimensions(launch_dimensions, kernel_thunk.get(),
3720                          ir_emitter_context_->llvm_module());
3721 
3722   const HloComputation* fused_computation =
3723       *GetOrCreateSubComputationFromRegion(&fusion.region(),
3724                                            /*is_fusion=*/true);
3725 
3726   // If init_value was fused into this reduce we have to generate it first.
3727   GpuElementalIrEmitter elemental_emitter(hlo_module_config_,
3728                                           ir_emitter_context_->llvm_module(),
3729                                           &b_, GetNestedComputer());
3730 
3731   FusedIrEmitter fused_emitter(&elemental_emitter);
3732   for (int i = 0; i < fused_computation->num_parameters(); i++) {
3733     fused_emitter.BindGenerator(
3734         fused_computation->parameter_instruction(i),
3735         [this, &ir_arrays, i](llvm_ir::IrArray::Index index) {
3736           return ir_arrays[i].EmitReadArrayElement(index, &b_);
3737         });
3738   }
3739   HloInstruction* instr = fused_computation->root_instruction();
3740   if (instr->opcode() != HloOpcode::kTuple) {
3741     CHECK_EQ(0, output_index);
3742   } else {
3743     instr = instr->mutable_operand(output_index);
3744   }
3745   TF_RET_CHECK(instr->shape().IsArray());
3746   TF_ASSIGN_OR_RETURN(auto generator,
3747                       fused_emitter.GetGenerator(instr->operand(1)));
3748   TF_RETURN_IF_ERROR(
3749       ParallelLoopEmitter(generator, dest_array, launch_dimensions, &b_)
3750           .EmitLoop(mlir::GetNameFromLoc(fusion.getLoc())));
3751   return {std::move(kernel_thunk)};
3752 }
3753 
3754 namespace {
3755 
3756 // Checks that the buffers corresponding to the given two HLOs share the same
3757 // allocation.
CheckHloBuffersShareAllocation(const HloInstruction * a,const HloInstruction * b,const ShapeIndex & index,const BufferAssignment & buffer_assignment)3758 Status CheckHloBuffersShareAllocation(
3759     const HloInstruction* a, const HloInstruction* b, const ShapeIndex& index,
3760     const BufferAssignment& buffer_assignment) {
3761   const BufferAllocation::Slice slice_a =
3762       buffer_assignment.GetUniqueSlice(a, index).ConsumeValueOrDie();
3763   const BufferAllocation::Slice slice_b =
3764       buffer_assignment.GetUniqueSlice(b, index).ConsumeValueOrDie();
3765   if (slice_a != slice_b) {
3766     return InternalError(
3767         "instruction %s %s does not share allocation with instruction %s %s",
3768         a->ToString(), slice_a.ToString(), b->ToString(), slice_b.ToString());
3769   }
3770   return Status::OK();
3771 }
3772 
3773 // Checks that all buffers used during while loop iteration share the same
3774 // buffer allocation. This includes buffers for while result, while init
3775 // operand, condition parameter, body parameter and body result.
3776 // Returns OK on success, error status otherwise.
CheckWhileBuffersShareAllocation(const HloInstruction * xla_while,const BufferAssignment & buffer_assignment)3777 Status CheckWhileBuffersShareAllocation(
3778     const HloInstruction* xla_while,
3779     const BufferAssignment& buffer_assignment) {
3780   return ShapeUtil::ForEachSubshapeWithStatus(
3781       xla_while->shape(),
3782       [&](const Shape& /*subshape*/, const ShapeIndex& index) -> Status {
3783         const HloInstruction* condition_parameter =
3784             xla_while->while_condition()->parameter_instruction(0);
3785         const HloComputation* body = xla_while->while_body();
3786         const HloInstruction* body_parameter = body->parameter_instruction(0);
3787         const HloInstruction* body_result = body->root_instruction();
3788         TF_RETURN_IF_ERROR(CheckHloBuffersShareAllocation(
3789             xla_while, xla_while->operand(0), index, buffer_assignment));
3790         TF_RETURN_IF_ERROR(CheckHloBuffersShareAllocation(
3791             xla_while, condition_parameter, index, buffer_assignment));
3792         TF_RETURN_IF_ERROR(CheckHloBuffersShareAllocation(
3793             xla_while, body_parameter, index, buffer_assignment));
3794         TF_RETURN_IF_ERROR(CheckHloBuffersShareAllocation(
3795             xla_while, body_result, index, buffer_assignment));
3796         return Status::OK();
3797       });
3798 }
3799 
3800 // Checks that the buffers used in a conditional instruction are shared with the
3801 // operands and result as follows:
3802 //   * The result buffer of the conditional should share the allocation with the
3803 //     result buffers of each branch computation.
3804 //   * The buffer of operand b+1 should share the allocation with the buffer of
3805 //     the parameter 0 instruction of the b'th computation.
CheckConditionalBuffersShareAllocation(const HloInstruction * conditional,const BufferAssignment & buffer_assignment)3806 Status CheckConditionalBuffersShareAllocation(
3807     const HloInstruction* conditional,
3808     const BufferAssignment& buffer_assignment) {
3809   TF_RETURN_IF_ERROR(ShapeUtil::ForEachSubshapeWithStatus(
3810       conditional->shape(),
3811       [&](const Shape& /*subshape*/, const ShapeIndex& index) -> Status {
3812         for (auto branch_computation : conditional->branch_computations()) {
3813           TF_RETURN_IF_ERROR(CheckHloBuffersShareAllocation(
3814               conditional, branch_computation->root_instruction(), index,
3815               buffer_assignment));
3816         }
3817         return Status::OK();
3818       }));
3819   for (int j = 0; j < conditional->branch_count(); ++j) {
3820     TF_RETURN_IF_ERROR(ShapeUtil::ForEachSubshapeWithStatus(
3821         conditional->operand(j + 1)->shape(),
3822         [&](const Shape& /*subshape*/, const ShapeIndex& index) -> Status {
3823           return CheckHloBuffersShareAllocation(
3824               conditional->operand(j + 1),
3825               conditional->branch_computation(j)->parameter_instruction(0),
3826               index, buffer_assignment);
3827         }));
3828   }
3829   return Status::OK();
3830 }
3831 
3832 }  // namespace
3833 
BuildWhileThunk(const HloInstruction * hlo)3834 StatusOr<std::unique_ptr<Thunk>> IrEmitterUnnested::BuildWhileThunk(
3835     const HloInstruction* hlo) {
3836   // Check that all while-related buffers share an allocation.
3837   TF_CHECK_OK(CheckWhileBuffersShareAllocation(
3838       hlo, ir_emitter_context_->buffer_assignment()));
3839 
3840   // Generate thunk sequence for while 'condition'.
3841   HloComputation* condition = hlo->while_condition();
3842   TF_ASSIGN_OR_RETURN(auto ir_emitter_condition,
3843                       IrEmitterUnnested::Create(hlo_module_config_, condition,
3844                                                 ir_emitter_context_));
3845   TF_RETURN_IF_ERROR(condition->Accept(ir_emitter_condition.get()));
3846 
3847   // Generate thunk sequence for while 'body'.
3848   HloComputation* body = hlo->while_body();
3849   TF_ASSIGN_OR_RETURN(
3850       auto ir_emitter_body,
3851       IrEmitterUnnested::Create(hlo_module_config_, body, ir_emitter_context_));
3852   TF_RETURN_IF_ERROR(body->Accept(ir_emitter_body.get()));
3853 
3854   const auto* index_map = ir_emitter_context_->profile_index_map();
3855   absl::optional<size_t> condition_profile_index, body_profile_index;
3856   if (index_map) {
3857     condition_profile_index = index_map->GetProfileIndexFor(*condition);
3858     body_profile_index = index_map->GetProfileIndexFor(*body);
3859   }
3860 
3861   return std::unique_ptr<Thunk>(new WhileThunk(
3862       GetThunkInfo(hlo),
3863       GetAllocationSlice(*condition->root_instruction()),  // cond result
3864       ir_emitter_condition->ConsumeThunkSequence(),
3865       ir_emitter_body->ConsumeThunkSequence(), condition_profile_index,
3866       body_profile_index));
3867 }
3868 
BuildForThunk(const HloInstruction * hlo,const int64 loop_limit)3869 StatusOr<std::unique_ptr<Thunk>> IrEmitterUnnested::BuildForThunk(
3870     const HloInstruction* hlo, const int64 loop_limit) {
3871   // Check that all while-related buffers share an allocation.
3872   TF_CHECK_OK(CheckWhileBuffersShareAllocation(
3873       hlo, ir_emitter_context_->buffer_assignment()));
3874 
3875   // Generate thunk sequence for while 'body' (will be used a For loop body).
3876   HloComputation* body = hlo->while_body();
3877   TF_ASSIGN_OR_RETURN(
3878       auto ir_emitter_body,
3879       IrEmitterUnnested::Create(hlo_module_config_, body, ir_emitter_context_));
3880   TF_RETURN_IF_ERROR(body->Accept(ir_emitter_body.get()));
3881 
3882   const auto* index_map = ir_emitter_context_->profile_index_map();
3883   absl::optional<size_t> body_profile_index;
3884   if (index_map) {
3885     body_profile_index = index_map->GetProfileIndexFor(*body);
3886   }
3887 
3888   return std::unique_ptr<Thunk>(new ForThunk(
3889       GetThunkInfo(hlo), loop_limit, ir_emitter_body->ConsumeThunkSequence(),
3890       body_profile_index));
3891 }
3892 
BuildConditionalThunk(const HloInstruction * hlo)3893 StatusOr<std::unique_ptr<Thunk>> IrEmitterUnnested::BuildConditionalThunk(
3894     const HloInstruction* hlo) {
3895   // Check that the buffers used in conditional are shared with the operands and
3896   // result appropriately.
3897   TF_CHECK_OK(CheckConditionalBuffersShareAllocation(
3898       hlo, ir_emitter_context_->buffer_assignment()));
3899 
3900   std::vector<BufferAllocation::Slice> branch_operands;
3901   std::vector<ThunkSequence> branch_thunks;
3902   std::vector<absl::optional<size_t>> branch_profile_indices;
3903 
3904   int branch_count = hlo->branch_count();
3905   branch_thunks.reserve(branch_count);
3906   branch_profile_indices.reserve(branch_count);
3907 
3908   const auto* index_map = ir_emitter_context_->profile_index_map();
3909 
3910   for (int j = 0; j < branch_count; ++j) {
3911     branch_operands.emplace_back(GetAllocationSlice(*hlo->operand(j + 1)));
3912     HloComputation* branch_computation = hlo->branch_computation(j);
3913     TF_ASSIGN_OR_RETURN(
3914         auto ir_emitter,
3915         IrEmitterUnnested::Create(hlo_module_config_, branch_computation,
3916                                   ir_emitter_context_));
3917     TF_CHECK_OK(branch_computation->Accept(ir_emitter.get()));
3918     branch_thunks.push_back(std::move(*ir_emitter->ConsumeThunkSequence()));
3919 
3920     absl::optional<size_t> profile_index;
3921     if (index_map) {
3922       profile_index = index_map->GetProfileIndexFor(*branch_computation);
3923     }
3924     branch_profile_indices.push_back(profile_index);
3925   }
3926 
3927   ConditionalThunkConfig config = GetConditionalThunkConfig(
3928       hlo, std::move(branch_thunks), std::move(branch_profile_indices));
3929   return std::unique_ptr<Thunk>(new ConditionalThunk(
3930       GetThunkInfo(hlo), std::move(config),
3931       GetAllocationSlice(*hlo->operand(0)), branch_operands));
3932 }
3933 
EmitTargetElementLoop(const HloInstruction & hlo,const llvm_ir::ElementGenerator & body_emitter)3934 Status IrEmitterUnnested::EmitTargetElementLoop(
3935     const HloInstruction& hlo, const llvm_ir::ElementGenerator& body_emitter) {
3936   return InternalError("This should be unreachable");
3937 }
3938 
3939 // Gets the output offset as calculated from thread_id.x (to be applied to the
3940 // offset calculated from block_id and thread_id.y).
GetStartOffsetX(const KernelMappingScheme & mapping_scheme,llvm::Value * thread_id_x,llvm::Type * index_ty,llvm::IRBuilder<> * b)3941 static llvm::Value* GetStartOffsetX(const KernelMappingScheme& mapping_scheme,
3942                                     llvm::Value* thread_id_x,
3943                                     llvm::Type* index_ty,
3944                                     llvm::IRBuilder<>* b) {
3945   auto constant = [&](int64 val) {
3946     return llvm::ConstantInt::get(index_ty, val);
3947   };
3948   if (mapping_scheme.GetIndexingOrder() == kStridedIndexingX) {
3949     return thread_id_x;
3950   } else if (mapping_scheme.GetIndexingOrder() == kStridedLinearIndexingX) {
3951     return b->CreateMul(thread_id_x, constant(mapping_scheme.GetVectorSize()));
3952   }
3953   CHECK_EQ(mapping_scheme.GetIndexingOrder(), kLinearIndexingX);
3954   int64 x_num_steps =
3955       mapping_scheme.GetTileSizeX() / mapping_scheme.GetNumThreadsX();
3956   return b->CreateMul(thread_id_x, constant(x_num_steps));
3957 }
3958 
3959 // Calls `emit_elem_function()` `x_num_steps` times.  If
3960 // `vector_size`==1, then each element index passed to
3961 // `emit_elem_function()` will be separated by `step_x`. If `vector_size`>1,
3962 // then it must be a multiple of `x_num_steps`.  In that case, it
3963 // triggers a different indexing order that is vectorizable by
3964 // LLVM. It generates many groups of calls to `emit_elem_function`. Each
3965 // group is separated by `step_x` elements.  Inside a group, elements
3966 // are consecutive. If `check_x_tile_bounds` is true, then it will check
3967 // if the element index is in bound compared to `tile_width` before
3968 // calling `emit_elem_function`.
UnrollInnerTileLoop(bool check_x_tile_bounds,int64 x_num_steps,int64 step_x,int64 vector_size,const string & loop_name,KernelSupportLibrary * ksl,llvm::Value * start_offset_x,llvm::Value * y_loc,llvm::Value * tile_width,const IrArray::Index & source_idx,llvm::IRBuilder<> * b,const IrEmitterUnnested::EmitElementFunction * emit_elem_function)3969 static void UnrollInnerTileLoop(
3970     bool check_x_tile_bounds, int64 x_num_steps, int64 step_x,
3971     int64 vector_size, const string& loop_name, KernelSupportLibrary* ksl,
3972     llvm::Value* start_offset_x, llvm::Value* y_loc, llvm::Value* tile_width,
3973     const IrArray::Index& source_idx, llvm::IRBuilder<>* b,
3974     const IrEmitterUnnested::EmitElementFunction* emit_elem_function) {
3975   llvm::Type* index_ty = tile_width->getType();
3976   auto constant = [&](int64 val) {
3977     return llvm::ConstantInt::get(index_ty, val);
3978   };
3979   IrArray::Index source_idx_x_base = source_idx.AddOffsetToDim(y_loc, kDimY, b);
3980   for (int64 j = 0; j < x_num_steps / vector_size; j++) {
3981     for (int64 i = 0; i < vector_size; i++) {
3982       int64 linear_index = j * vector_size + i;
3983       llvm::Value* x_loc = b->CreateAdd(constant(j * step_x * vector_size + i),
3984                                         start_offset_x, "x_loc");
3985       IrArray::Index source_idx_x = source_idx_x_base.AddOffsetToDim(
3986           constant(j * step_x * vector_size + i), kDimX, b);
3987       auto emit_element = [&] {
3988         return (*emit_elem_function)(source_idx_x, y_loc, x_loc, linear_index);
3989       };
3990       if (check_x_tile_bounds) {
3991         ksl->If(loop_name + "_x_in_tile", b->CreateICmpULT(x_loc, tile_width),
3992                 emit_element);
3993       } else {
3994         emit_element();
3995       }
3996     }
3997   }
3998 }
3999 
EmitTile(const KernelMappingScheme & mapping_scheme,const IrArray::Index & tile_origin_index,const string & loop_name,KernelSupportLibrary * ksl,const ThreadIdInfo & thread_id_info,llvm::Value * tile_height,llvm::Value * tile_width,const IrEmitterUnnested::EmitElementFunction & emit_elem_function)4000 void IrEmitterUnnested::EmitTile(
4001     const KernelMappingScheme& mapping_scheme,
4002     const IrArray::Index& tile_origin_index, const string& loop_name,
4003     KernelSupportLibrary* ksl, const ThreadIdInfo& thread_id_info,
4004     llvm::Value* tile_height, llvm::Value* tile_width,
4005     const IrEmitterUnnested::EmitElementFunction& emit_elem_function) {
4006   llvm::Type* index_ty = tile_width->getType();
4007   auto constant = [&](int64 val) {
4008     return llvm::ConstantInt::get(index_ty, val);
4009   };
4010   int64 num_threads_x = mapping_scheme.GetNumThreadsX();
4011   llvm::Value* num_threads_y = constant(mapping_scheme.GetNumThreadsY());
4012   int64 tile_size_x = mapping_scheme.GetTileSizeX();
4013 
4014   int64 x_num_steps = tile_size_x / num_threads_x;
4015   llvm::Value* start_offset_x = GetStartOffsetX(
4016       mapping_scheme, thread_id_info.thread_id_x, index_ty, &b_);
4017 
4018   // Using dilated mapping scheme, each thread steps with a stride of number
4019   // of threads.
4020   // Otherwise, the stride is one, but we multiply each offset by the limit of
4021   // number of steps which can be made.
4022   int64 step_x =
4023       mapping_scheme.GetIndexingOrder() == kLinearIndexingX ? 1 : num_threads_x;
4024   int64 vector_size = mapping_scheme.GetVectorSize();
4025 
4026   IrArray::Index source_idx =
4027       tile_origin_index.AddOffsetToDim(start_offset_x, kDimX, &b_);
4028 
4029   auto ceil_of_ratio = [&](llvm::Value* a, llvm::Value* b) {
4030     return b_.CreateUDiv(b_.CreateAdd(b_.CreateAdd(a, b), constant(-1)), b);
4031   };
4032 
4033   // True iff all threads always execute all instructions in the tiling
4034   // dimension X.
4035   bool x_tile_fits =
4036       mapping_scheme.GetDimsInElems()[kDimX] % tile_size_x == 0 &&
4037       mapping_scheme.GetRowContiguous();
4038 
4039   // The outer loop below is simply doing:
4040   //
4041   // for (int y_loc=thread_id_y; y_loc<tile_height; y_loc+=num_threads_y)
4042   //
4043   //
4044   // However, in order to avoid an LLVM optimization triggering the ptxas bug,
4045   // we write this loop in a convoluted way:
4046   //
4047   // y_bound = ceil_of_ratio(tile_height - thread_id_y, num_threads_y)
4048   // for (int y_indvar=0; y_indvar<y_bound; y_indvar+=1)
4049   //    y_loc = thread_id_y + y_indvar * num_threads_y
4050   //
4051   // TODO(cheshire): Once ptxas is fixed and TF switches to it, remove the
4052   // workaround.
4053   ksl->For(
4054       loop_name + "_y_in_tile",
4055       /*start=*/constant(0),
4056       /*end=*/
4057       ceil_of_ratio(b_.CreateSub(tile_height, thread_id_info.thread_id_y),
4058                     num_threads_y),
4059       /*step=*/constant(1), [&](llvm::Value* y_indvar) {
4060         llvm::Value* y_loc = b_.CreateAdd(
4061             thread_id_info.thread_id_y, b_.CreateMul(y_indvar, num_threads_y));
4062         auto unroll_inner_tile_loop = [&](bool check_x_tile_bounds) {
4063           return UnrollInnerTileLoop(check_x_tile_bounds, x_num_steps, step_x,
4064                                      vector_size, loop_name, ksl,
4065                                      start_offset_x, y_loc, tile_width,
4066                                      source_idx, &b_, &emit_elem_function);
4067         };
4068 
4069         // Only take this path when we unroll in a way vectorizable by
4070         // LLVM. Special case when the tile doesn't fit completely for even
4071         // row size. For odd row size every other row isn't aligned to the
4072         // vectorized size, so it can't be vectorized by LLVM.
4073         if (!x_tile_fits &&
4074             mapping_scheme.GetIndexingOrder() == kStridedLinearIndexingX) {
4075           ksl->If(
4076               loop_name + "_is_full_tile",
4077               // For the last block, tile_width will be the number of
4078               // elements left.
4079               b_.CreateICmpEQ(constant(mapping_scheme.GetTileSizeX()),
4080                               tile_width),
4081               [&] { unroll_inner_tile_loop(/*check_x_tile_bounds=*/false); },
4082               [&] { unroll_inner_tile_loop(/*check_x_tile_bounds=*/true); });
4083         } else {
4084           unroll_inner_tile_loop(/*check_x_tile_bounds=*/!x_tile_fits);
4085         }
4086       });
4087 }
4088 
4089 // Emits code to process a tensor element in a tile for the given kCopy HLO that
4090 // performs a 0-2-1 transpose.
4091 //
4092 // index: The index for the first output element in the normalized tensor. The
4093 //   normalized tensor is the resulting tensor after collapsing contiguous
4094 //   dimensions that play the same role in the transpose.
4095 // mapping_scheme: Kernel mapping scheme specifying the tiling
EmitTileElementForCopy(const Shape & output_shape,const llvm_ir::IrArray & output_array,const llvm_ir::IrArray::Index & index,const KernelMappingScheme & mapping_scheme,llvm::Value * y_loc,llvm::Value * x_loc,absl::Span<llvm::Value * const> param_shmem_buffers)4096 void IrEmitterUnnested::EmitTileElementForCopy(
4097     const Shape& output_shape, const llvm_ir::IrArray& output_array,
4098     const llvm_ir::IrArray::Index& index,
4099     const KernelMappingScheme& mapping_scheme, llvm::Value* y_loc,
4100     llvm::Value* x_loc, absl::Span<llvm::Value* const> param_shmem_buffers) {
4101   // TODO(jlebar): Add AA metadata to this load.
4102   llvm::Instruction* load_from_shmem_buffer =
4103       Load(GEP(param_shmem_buffers[0], {b_.getInt64(0), x_loc, y_loc}),
4104            "output_element");
4105   Shape output_reduced_shape = ShapeUtil::MakeShapeWithDescendingLayout(
4106       output_shape.element_type(), mapping_scheme.GetDimsInElems());
4107   // When the output_reduced_shape is a 0-2-1 transpose of the input shape,
4108   // the 0-2-1 transpose is achieved through EmitWriteArrayElement.
4109   output_array.CastToShape(output_reduced_shape, &b_)
4110       .EmitWriteArrayElement(index, load_from_shmem_buffer, &b_);
4111 }
4112 
// Maps an index into the normalized rank-3 tensor back to an index into
// `unnormalized_shape`.
//
// Fast path: when `unnormalized_shape` is rank 2 with a major-to-minor layout
// and matches the last two normalized dimensions (the normalization only
// added a leading dimension of size 1), the last two normalized coordinates
// are reused directly. LLVM does not always simplify the general form, and
// that can prevent it from vectorizing some cases.
//
// General path: linearize the index over
// kernel_mapping_scheme.GetDimsInElems() and delinearize it into
// `unnormalized_shape`.
static IrArray::Index GetUnnormalizedIndex(
    const IrArray::Index& normalized_shape_index,
    const Shape& unnormalized_shape, llvm::IRBuilder<>* b_,
    const KernelMappingScheme& kernel_mapping_scheme) {
  DCHECK_EQ(normalized_shape_index.size(), 3);
  const bool is_expanded_row_major_2d =
      unnormalized_shape.rank() == 2 && unnormalized_shape.has_layout() &&
      unnormalized_shape.dimensions()[0] == normalized_shape_index.dims()[1] &&
      unnormalized_shape.dimensions()[1] == normalized_shape_index.dims()[2] &&
      unnormalized_shape.layout().minor_to_major(1) == 0;

  if (!is_expanded_row_major_2d) {
    llvm::Value* linear = normalized_shape_index.Linearize(
        kernel_mapping_scheme.GetDimsInElems(), b_);
    return IrArray::Index(linear, unnormalized_shape, b_);
  }

  CHECK_EQ(normalized_shape_index.dims()[0], 1);
  const auto& coords = normalized_shape_index.multidim();
  return IrArray::Index({coords[1], coords[2]}, unnormalized_shape,
                        normalized_shape_index.GetType());
}
4135 
4136 // Emits code to process a tensor element in a tile for the given kLoop fusion
4137 // HLO containing parameters that are 0-2-1 transpose of its outputs.
4138 //
4139 // index: The index for the first output element in the normalized tensor, that
4140 //   is the resulting tensor after collapsing contiguous dimensions that play
4141 //   the same role in the transpose.
4142 // kernel_info: Other information to support the kernel code generation.
EmitTileElementForFusion(mlir::lmhlo::FusionOp fusion,absl::Span<const llvm_ir::IrArray> operand_arrays,absl::Span<const llvm_ir::IrArray> output_arrays,const llvm_ir::IrArray::Index & index,const KernelMappingScheme & mapping_scheme,llvm::Value * y_loc,llvm::Value * x_loc,absl::Span<llvm::Value * const> param_shmem_buffers)4143 void IrEmitterUnnested::EmitTileElementForFusion(
4144     mlir::lmhlo::FusionOp fusion,
4145     absl::Span<const llvm_ir::IrArray> operand_arrays,
4146     absl::Span<const llvm_ir::IrArray> output_arrays,
4147     const llvm_ir::IrArray::Index& index,
4148     const KernelMappingScheme& mapping_scheme, llvm::Value* y_loc,
4149     llvm::Value* x_loc, absl::Span<llvm::Value* const> param_shmem_buffers) {
4150   const HloComputation* fused_computation =
4151       *GetOrCreateSubComputationFromRegion(&fusion.region(),
4152                                            /*is_fusion=*/true);
4153   GpuElementalIrEmitter elem_emitter(hlo_module_config_, module_, &b_,
4154                                      GetNestedComputer());
4155   FusedIrEmitter fused_emitter(&elem_emitter);
4156   for (int i = 0; i < operand_arrays.size(); i++) {
4157     llvm_ir::ElementGenerator gen;
4158     if (llvm::Value* param_tile_buffer = param_shmem_buffers[i]) {
4159       gen = [this, param_tile_buffer, x_loc,
4160              y_loc](llvm_ir::IrArray::Index index) {
4161         // TODO(jlebar): Add AA metadata to this load.  Tile buffers are
4162         // global variables, so LLVM's points-to analysis doesn't help us
4163         // much.  And we want the AA info to be present before address
4164         // spaces are inferred (which is pretty late in the pipeline), so
4165         // even if we had address-space-based AA in LLVM, it wouldn't help
4166         // us much here.
4167         return b_.CreateLoad(
4168             b_.CreateGEP(param_tile_buffer,
4169                          {index.GetConstantWithIndexType(0), x_loc, y_loc}),
4170             "tiled_buffer");
4171       };
4172     } else {
4173       auto array = operand_arrays[i];
4174       auto name = fused_computation->parameter_instruction(i)->name();
4175       gen = [this, array, name](const llvm_ir::IrArray::Index& index) {
4176         return array.EmitReadArrayElement(index, &b_, name);
4177       };
4178     }
4179     fused_emitter.BindGenerator(fused_computation->parameter_instruction(i),
4180                                 std::move(gen));
4181   }
4182   IrArray::Index untiled_index = GetUnnormalizedIndex(
4183       index, output_arrays[0].GetShape(), &b_, mapping_scheme);
4184   llvm_ir::ElementGenerator output_generator =
4185       *fused_emitter.GetGenerator(fused_computation->root_instruction());
4186   llvm::Value* output_value = output_generator(untiled_index).ValueOrDie();
4187   if (output_arrays.size() > 1) {
4188     DCHECK(output_value->getType()->isStructTy());
4189     DCHECK_EQ(output_value->getType()->getStructNumElements(),
4190               output_arrays.size() - 1);
4191     for (int64 i = 0; i < output_arrays.size() - 1; ++i) {
4192       output_arrays[i].EmitWriteArrayElement(
4193           untiled_index, ExtractValue(output_value, i), &b_);
4194     }
4195   } else {
4196     output_arrays[0].EmitWriteArrayElement(untiled_index, output_value, &b_);
4197   }
4198 }
4199 
GetReduceFromUnnestedMlir(mlir::Operation * unnested_hlo,int index)4200 static mlir::Operation* GetReduceFromUnnestedMlir(mlir::Operation* unnested_hlo,
4201                                                   int index) {
4202   if (mlir::isa<mlir::lmhlo::ReduceOp>(unnested_hlo)) {
4203     CHECK_EQ(0, index);
4204     return unnested_hlo;
4205   }
4206   if (auto fusion = mlir::dyn_cast<mlir::lmhlo::FusionOp>(unnested_hlo)) {
4207     auto results = fusion.getFusionResults();
4208     CHECK(index < results.size())
4209         << MlirToString(unnested_hlo) << " vs " << index;
4210     return results[index].getDefiningOp();
4211   }
4212   return nullptr;
4213 }
4214 
// Emits the per-reduction setup for the reductions in `instr_index_group`.
//
// For every `index` in the group whose operation (from
// GetReduceFromUnnestedMlir) is a reduction from or to contiguous dimensions,
// this function:
//   - allocates at function entry a scalar "reduction_input_address" slot;
//   - allocates at function entry `num_partial_results` slots named
//     "partial_reduction_result.<index>" and stores the reduction's initial
//     value into each of them;
//   - records the initial value;
//   - allocates a shared-memory tile "shared_cache_<index>", shaped for a row
//     or a column reduction.
// Each of these is appended to the matching list in `reduction_info`.
// Operations that are not such reductions are skipped.
//
// CHECKs that all processed reductions share the same "dimensions" attribute
// and that at least one reduction was processed.
//
// fused_computation / fused_emitter: used to compute the initial value when
//   `unnested_hlo` is a fusion.
// operand_ir_arrays: when `unnested_hlo` is not a fusion, the initial value is
//   read from operand_ir_arrays[1].
// result_ir_arrays: not read by this function.
void IrEmitterUnnested::EmitPrologueForReduction(
    mlir::Operation* unnested_hlo, absl::Span<const int> instr_index_group,
    HloComputation* fused_computation, FusedIrEmitter* fused_emitter,
    absl::Span<const llvm_ir::IrArray> operand_ir_arrays,
    absl::Span<const llvm_ir::IrArray> result_ir_arrays,
    ReductionCodegenInfo* reduction_info) {
  VLOG(10) << "Emit prologue for reduction: " << MlirToString(unnested_hlo);
  // First processed reduction; later ones are checked against it.
  mlir::Operation* first_reduce = nullptr;
  for (int index : instr_index_group) {
    mlir::Operation* reduce_inst =
        GetReduceFromUnnestedMlir(unnested_hlo, index);

    // Only reductions from/to contiguous dimensions are set up here.
    if (!IsReductionFromOrToContiguousDimensions(reduce_inst)) {
      continue;
    }

    auto results = GetHloOutputs(reduce_inst);
    CHECK_EQ(1, results.size());
    Shape reduce_inst_shape = TypeToShape(results[0].getType());

    VLOG(10) << "Emit prologue for reduction: " << MlirToString(reduce_inst);
    // All reductions in the group must reduce the same dimensions.
    if (first_reduce == nullptr) {
      first_reduce = reduce_inst;
    } else {
      CHECK(absl::c_equal(
          first_reduce->getAttrOfType<mlir::DenseIntElementsAttr>("dimensions"),
          reduce_inst->getAttrOfType<mlir::DenseIntElementsAttr>(
              "dimensions")));
    }

    // Scalar slot of the reduction's element type.
    AddressVector* reduction_input_addresses =
        reduction_info->GetMutableReductionInputAddresses();
    llvm::Type* element_type = llvm_ir::PrimitiveTypeToIrType(
        reduce_inst_shape.element_type(), ir_emitter_context_->llvm_module());
    llvm::AllocaInst* reduction_input_address =
        llvm_ir::EmitAllocaAtFunctionEntry(element_type,
                                           "reduction_input_address", &b_);
    reduction_input_addresses->push_back(reduction_input_address);

    // Array of num_partial_results accumulators for this reduction.
    int num_partial_results = reduction_info->GetNumPartialResults();
    AddressVector* partial_result_addresses =
        reduction_info->GetMutablePartialResultAddresses();
    llvm::AllocaInst* partial_result_address =
        llvm_ir::EmitAllocaAtFunctionEntryWithCount(
            element_type, /*ArraySize=*/b_.getInt32(num_partial_results),
            ("partial_reduction_result." + llvm::Twine(index)).str(), &b_);
    partial_result_addresses->push_back(partial_result_address);

    // Initialize the partial result with the initial value of the reduction.
    // For a fusion the init value is operand 1 of the reduce in the fused
    // computation (picked out of a root tuple by `index`). It is generated
    // through the fused emitter. Otherwise it is read from operand 1 of the
    // op. In both cases the index is built from only an index type
    // (presumably a rank-0 index).
    llvm::Value* init_ir_value;
    if (auto fusion = mlir::dyn_cast<mlir::lmhlo::FusionOp>(unnested_hlo)) {
      const HloInstruction* reduce_hlo = fused_computation->root_instruction();
      if (reduce_hlo->opcode() == HloOpcode::kTuple) {
        reduce_hlo = reduce_hlo->operand(index);
      }
      const HloInstruction* init_value = reduce_hlo->operand(1);

      init_ir_value = (*fused_emitter->GetGenerator(
          init_value))(IrArray::Index(b_.getInt32Ty()))
                          .ValueOrDie();
    } else {
      init_ir_value = operand_ir_arrays[1].EmitReadArrayElement(
          IrArray::Index(b_.getInt32Ty()), &b_);
    }

    for (int i = 0; i < num_partial_results; ++i) {
      Store(init_ir_value,
            InBoundsGEP(partial_result_address, {b_.getInt32(i)}));
    }
    reduction_info->GetMutableInitialValues()->push_back(init_ir_value);

    // Shared-memory scratch buffer used by the epilogue to combine partial
    // results across threads.
    auto& mapping_scheme = reduction_info->GetKernelMappingScheme();
    int64 num_threads_x = mapping_scheme.GetNumThreadsX();
    llvm::Type* primitive_type = llvm_ir::PrimitiveTypeToIrType(
        reduce_inst_shape.element_type(), module_);
    llvm::Type* buffer_type = [&] {
      if (reduction_info->IsRowReduction()) {
        // Allocate __shared__ cache[num_partial_results][kWarpSize].
        return llvm::ArrayType::get(
            llvm::ArrayType::get(primitive_type, kWarpSize),
            num_partial_results);
      } else {
        // Allocate __shared__
        // cache[num_partial_results][num_threads][num_threads + 1], where
        // num_threads == num_threads_x == num_threads_y.  The "+1" is used to
        // avoid bank conflicts.
        CHECK_EQ(num_threads_x, mapping_scheme.GetNumThreadsY());
        return llvm::ArrayType::get(
            llvm::ArrayType::get(
                llvm::ArrayType::get(primitive_type, num_threads_x + 1),
                num_threads_x),
            num_partial_results);
      }
    }();
    llvm::GlobalVariable* shared_cache_per_reduce =
        llvm_ir::AllocateSharedMemoryTile(b_.GetInsertBlock()->getModule(),
                                          buffer_type,
                                          absl::StrCat("shared_cache_", index));
    reduction_info->GetMutableSharedCache()->push_back(shared_cache_per_reduce);
  }
  // The group must contain at least one contiguous-dimension reduction.
  CHECK(first_reduce);
}
4317 
EmitFullWarpShuffleDownLoopForAllReduces(absl::Span<HloComputation * const> reducers,absl::Span<llvm::AllocaInst * const> partial_result_addresses,int threads_per_block)4318 void IrEmitterUnnested::EmitFullWarpShuffleDownLoopForAllReduces(
4319     absl::Span<HloComputation* const> reducers,
4320     absl::Span<llvm::AllocaInst* const> partial_result_addresses,
4321     int threads_per_block) {
4322   CHECK_EQ(reducers.size(), partial_result_addresses.size());
4323   for (int i = 0; i != reducers.size(); i++) {
4324     EmitFullWarpShuffleDownLoopForReduce(
4325         reducers[i], partial_result_addresses[i]->getType()->getElementType(),
4326         partial_result_addresses[i], threads_per_block);
4327   }
4328 }
4329 
EmitFullWarpShuffleDownLoopForReduce(HloComputation * reducer,llvm::Type * element_type,llvm::Value * partial_result_address,int threads_per_block)4330 void IrEmitterUnnested::EmitFullWarpShuffleDownLoopForReduce(
4331     HloComputation* reducer, llvm::Type* element_type,
4332     llvm::Value* partial_result_address, int threads_per_block) {
4333   // This only works when the block size is a multiple of 32 threads.
4334   CHECK_EQ(threads_per_block % 32, 0);
4335   for (int distance = 16; distance >= 1; distance /= 2) {
4336     int bit_width = llvm_ir::GetSizeInBits(element_type);
4337     llvm::Value* result_from_other_lane = llvm_ir::EmitAllocaAtFunctionEntry(
4338         element_type, "result_from_other_lane", &b_);
4339     // Bitcast cannot be applied to aggregate types (even packed ones), so
4340     // we bitcast addresses of load/store to intN* of the same bit-width.
4341     llvm::Type* shuffled_value_type =
4342         element_type->isStructTy() ? b_.getIntNTy(bit_width) : element_type;
4343     auto convert_pointer_for_shuffle = [&](llvm::Value* ptr) {
4344       return b_.CreatePointerBitCastOrAddrSpaceCast(
4345           ptr, shuffled_value_type->getPointerTo());
4346     };
4347     llvm::Value* partial_result =
4348         Load(convert_pointer_for_shuffle(partial_result_address),
4349              "partial_reduction_result");
4350     Store(EmitFullWarpShuffleDown(partial_result, b_.getInt32(distance), &b_),
4351           convert_pointer_for_shuffle(result_from_other_lane));
4352     TF_CHECK_OK(EmitCallToNestedComputation(
4353         *reducer, {partial_result_address, result_from_other_lane},
4354         partial_result_address));
4355   }
4356 }
4357 
4358 // Given the IrArray index of a reduction input, returns the linear address of
4359 // the reduction output as if the reduction were going to keep the input shape
4360 // with the dimensions being reduced moved.
GetUntransposedOutputLinearAddress(llvm::IRBuilder<> * b,const llvm_ir::IrArray::Index & index,const ReductionCodegenInfo & reduction_info)4361 static llvm::Value* GetUntransposedOutputLinearAddress(
4362     llvm::IRBuilder<>* b, const llvm_ir::IrArray::Index& index,
4363     const ReductionCodegenInfo& reduction_info) {
4364   const KernelMappingScheme& kernel_mapping_scheme =
4365       reduction_info.GetKernelMappingScheme();
4366   if (reduction_info.IsRowReduction()) {
4367     // For row-reduction, y-coordinate determines which row we write into.
4368     return index[kDimY];
4369   }
4370   // For column reduction, we get the transposed address.
4371   absl::Span<const int64> dims_in_elem = kernel_mapping_scheme.GetDimsInElems();
4372   llvm::Value* x_dim_size = index.GetConstantWithIndexType(dims_in_elem[kDimX]);
4373   llvm::Value* x_block_offset = b->CreateMul(index[kDimZ], x_dim_size);
4374   return b->CreateAdd(x_block_offset, index[kDimX]);
4375 }
4376 
EmitEpilogueForReduction(llvm::Type * index_ty,mlir::Operation * unnested_hlo,absl::Span<const int> instr_index_group,absl::Span<const llvm_ir::IrArray> result_ir_arrays,absl::Span<HloComputation * const> reducers,const ReductionCodegenInfo & reduction_info,const TilingKernelInfo & tiling_kernel_info)4377 void IrEmitterUnnested::EmitEpilogueForReduction(
4378     llvm::Type* index_ty, mlir::Operation* unnested_hlo,
4379     absl::Span<const int> instr_index_group,
4380     absl::Span<const llvm_ir::IrArray> result_ir_arrays,
4381     absl::Span<HloComputation* const> reducers,
4382     const ReductionCodegenInfo& reduction_info,
4383     const TilingKernelInfo& tiling_kernel_info) {
4384   const KernelMappingScheme& mapping_scheme =
4385       reduction_info.GetKernelMappingScheme();
4386   auto constant = [&](uint64 c) -> llvm::Constant* {
4387     return llvm::ConstantInt::get(index_ty, c);
4388   };
4389 
4390   IrEmitterUnnested::ThreadIdInfo thread_id_info =
4391       EmitThreadIdInfo(mapping_scheme.GetThreadsPerBlock(), index_ty,
4392                        mapping_scheme.GetNumThreadsX());
4393 
4394   IrArray::Index start_offset = [&] {
4395     llvm::Value* x_loc = thread_id_info.thread_id_x;
4396     llvm::Value* y_loc = thread_id_info.thread_id_y;
4397     if (!reduction_info.IsRowReduction()) {
4398       std::swap(x_loc, y_loc);
4399     }
4400     llvm::Value* start_offset_x =
4401         GetStartOffsetX(mapping_scheme, x_loc, index_ty, &b_);
4402     return tiling_kernel_info.tile_origin.AddOffsetToDim(y_loc, kDimY, &b_)
4403         .AddOffsetToDim(start_offset_x, kDimX, &b_);
4404   }();
4405 
4406   absl::Span<llvm::AllocaInst* const> partial_result_addresses =
4407       reduction_info.GetPartialResultAddresses();
4408 
4409   int num_partial_results = reduction_info.GetNumPartialResults();
4410 
4411   // Emit an atomic operation that accumulates the partial reduction to the
4412   // output element. For row reduction, this is only for lane 0 due to the
4413   // if-statement emitted above.
4414   //
4415   // `i` is the compacted index for contiguous-dimension reductions. It's used
4416   // for accessing `reduction_info` and `reducers`, which are also compacted.
4417   int i = -1;
4418   for (int index : instr_index_group) {
4419     mlir::Operation* reduce_hlo =
4420         GetReduceFromUnnestedMlir(unnested_hlo, index);
4421     if (!IsReductionFromOrToContiguousDimensions(reduce_hlo)) {
4422       continue;
4423     }
4424     i++;
4425     auto operand_shape = TypeToShape(reduce_hlo->getOperand(0).getType());
4426     Shape reduction_kept_element_shape = ShapeUtil::FilterDimensions(
4427         [&](int64 dim) {
4428           return !absl::c_linear_search(
4429               reduce_hlo->getAttrOfType<mlir::DenseIntElementsAttr>(
4430                   "dimensions"),
4431               dim);
4432         },
4433         operand_shape);
4434     for (int j = 0; j < num_partial_results; ++j) {
4435       llvm::Value* untransposed_output_linear_address =
4436           GetUntransposedOutputLinearAddress(
4437               &b_, start_offset.AddOffsetToDim(constant(j), kDimX, &b_),
4438               reduction_info);
4439 
4440       // A reduction is allowed to transpose its output.  For example, suppose
4441       // we are reducing the second dimension of f32[10,20,30]{3,2,1}.  We are
4442       // allowed to produce as output either f32[10,30]{1,0} (no transpose) or
4443       // f32[10,30]{0,1} (transposing the two output dims).
4444       //
4445       // At this point in the function we have a "partial sum" of input elements
4446       // (stored in partial_result_addresses), and we need to accumulate it into
4447       // the correct output element.
4448       auto output_array = result_ir_arrays[index];
4449       IrArray::Index element_index(
4450           /*linear=*/untransposed_output_linear_address,
4451           reduction_kept_element_shape, &b_);
4452       IrArray::Index output_index(element_index.multidim(),
4453                                   output_array.GetShape(),
4454                                   element_index.GetType());
4455       llvm::Value* output_address = output_array.EmitArrayElementAddress(
4456           output_index, &b_, "output_element_address");
4457       llvm::Value* current_output = b_.CreateInBoundsGEP(
4458           partial_result_addresses[i], {constant(j)}, "current_output");
4459 
4460       llvm::GlobalVariable* shared_cache = reduction_info.GetSharedCache()[i];
4461 
4462       // __shared__ memory uses a different address space, so we cast it to
4463       // global address space before writing or reading.
4464       auto shared_to_global = [&](llvm::Value* input, llvm::Twine name = "") {
4465         return b_.CreateAddrSpaceCast(
4466             input,
4467             llvm::PointerType::get(input->getType()->getPointerElementType(),
4468                                    /*AddressSpace=*/0),
4469             name);
4470       };
4471 
4472       auto is_zero = [&](llvm::Value* value) {
4473         return b_.CreateICmpEQ(value, constant(0));
4474       };
4475 
4476       KernelSupportLibrary ksl(&b_);
4477       llvm::Type* element_type =
4478           partial_result_addresses[i]->getType()->getElementType();
4479       if (reduction_info.IsRowReduction()) {
4480         EmitFullWarpShuffleDownLoopForReduce(
4481             reducers[i], element_type, current_output,
4482             mapping_scheme.GetThreadsPerBlock());
4483         llvm::Value* warp_id =
4484             b_.CreateUDiv(thread_id_info.thread_id_x, constant(kWarpSize));
4485         ksl.If("intra_warp_reduce_write", is_zero(thread_id_info.lane_id), [&] {
4486           llvm::Value* shmem_output_addr =
4487               shared_to_global(b_.CreateInBoundsGEP(
4488                   shared_cache, {b_.getInt32(0), constant(j), warp_id}));
4489           b_.CreateStore(b_.CreateLoad(current_output), shmem_output_addr);
4490         });
4491 
4492         EmitSyncThreads();
4493         ksl.If("inter_warp_reduce", is_zero(warp_id), [&] {
4494           llvm::Value* block_accum_addr = shared_to_global(b_.CreateInBoundsGEP(
4495               shared_cache,
4496               {b_.getInt32(0), constant(j), thread_id_info.lane_id}));
4497           llvm::Value* initial_value = reduction_info.GetInitialValues()[i];
4498           llvm::Value* initial_value_addr =
4499               shared_to_global(llvm_ir::EmitAllocaAtFunctionEntry(
4500                   element_type, "initial_value_addr", &b_));
4501           b_.CreateStore(initial_value, initial_value_addr);
4502 
4503           llvm::Value* warp_exists = b_.CreateICmpULT(
4504               thread_id_info.thread_id_x,
4505               constant(mapping_scheme.GetNumThreadsX() / kWarpSize));
4506 
4507           llvm::Value* selected_value = b_.CreateSelect(
4508               warp_exists, block_accum_addr, initial_value_addr);
4509 
4510           EmitFullWarpShuffleDownLoopForReduce(
4511               reducers[i], element_type,
4512               /*block_accum_addr*/ selected_value,
4513               mapping_scheme.GetThreadsPerBlock());
4514           ksl.If("reduction_atomic_update", is_zero(thread_id_info.thread_id_x),
4515                  [&] {
4516                    TF_CHECK_OK(EmitAtomicOperationForNestedComputation(
4517                        *reducers[i], output_address, block_accum_addr));
4518                  });
4519         });
4520 
4521       } else {
4522         llvm::Value* shmem_output_addr = shared_to_global(
4523             b_.CreateInBoundsGEP(shared_cache, {b_.getInt32(0), constant(j),
4524                                                 thread_id_info.thread_id_x,
4525                                                 thread_id_info.thread_id_y}),
4526             "shmem_output_address");
4527         llvm::Value* current_output_value = b_.CreateLoad(current_output);
4528         b_.CreateStore(current_output_value, shmem_output_addr);
4529 
4530         EmitSyncThreads();
4531 
4532         // Get transposed element from shared memory.
4533         llvm::Value* shmem_transposed_addr =
4534             shared_to_global(b_.CreateInBoundsGEP(
4535                 shared_cache,
4536                 {b_.getInt32(0), constant(j), thread_id_info.thread_id_y,
4537                  thread_id_info.thread_id_x},
4538                 "shmem_transposed_addr"));
4539 
4540         EmitFullWarpShuffleDownLoopForReduce(
4541             reducers[i], element_type, shmem_transposed_addr,
4542             mapping_scheme.GetThreadsPerBlock());
4543 
4544         // Some threads in the block are completely outside of the bound of the
4545         // tensor, so they should not write any output at all.
4546         llvm::Value* has_output = b_.CreateAnd(
4547             b_.CreateICmpULT(
4548                 GetStartOffsetX(mapping_scheme, thread_id_info.thread_id_y,
4549                                 index_ty, &b_),
4550                 tiling_kernel_info.output_tile_bounds[kDimX]),
4551             b_.CreateICmpULT(thread_id_info.thread_id_x,
4552                              tiling_kernel_info.output_tile_bounds[kDimY]));
4553 
4554         ksl.If("reduction_atomic_update",
4555                b_.CreateAnd(has_output, is_zero(thread_id_info.lane_id)), [&] {
4556                  TF_CHECK_OK(EmitAtomicOperationForNestedComputation(
4557                      *reducers[i], output_address, shmem_transposed_addr));
4558                });
4559       }
4560     }
4561   }
4562 }
4563 
EmitBlockId()4564 llvm::Value* IrEmitterUnnested::EmitBlockId() {
4565   return gpu::EmitCallToTargetIntrinsic(gpu::TargetIntrinsicID::kBlockIdx, {},
4566                                         {}, &b_);
4567 }
4568 
EmitPrintfWithThreadId(absl::string_view fmt,absl::Span<llvm::Value * const> arguments,absl::optional<int64> thread_id_filter,absl::optional<int64> block_id_filter)4569 void IrEmitterUnnested::EmitPrintfWithThreadId(
4570     absl::string_view fmt, absl::Span<llvm::Value* const> arguments,
4571     absl::optional<int64> thread_id_filter,
4572     absl::optional<int64> block_id_filter) {
4573   llvm::Value* thread_id = EmitThreadId(1024, b_.getInt32Ty());
4574   llvm::Value* block_id = EmitBlockId();
4575   std::vector<llvm::Value*> updated_arguments = {thread_id, block_id};
4576   updated_arguments.insert(updated_arguments.end(), arguments.begin(),
4577                            arguments.end());
4578   llvm::Value* constraint = b_.getTrue();
4579   if (thread_id_filter) {
4580     constraint = b_.CreateAnd(
4581         constraint, b_.CreateICmpEQ(thread_id, b_.getInt32(*thread_id_filter)));
4582   }
4583   if (block_id_filter) {
4584     constraint = b_.CreateAnd(
4585         constraint, b_.CreateICmpEQ(block_id, b_.getInt32(*block_id_filter)));
4586   }
4587   KernelSupportLibrary ksl(&b_, llvm_ir::UnrollMode::kDefaultUnroll);
4588   ksl.If(constraint, [&] {
4589     xla::gpu::EmitPrintf(absl::StrCat("[TID=%d,BID=%d] ", fmt, "\n"),
4590                          updated_arguments, &b_);
4591   });
4592 }
4593 
EmitTileElementForReduction(mlir::Operation * unnested_hlo,const Shape & reduction_operand_shape,absl::Span<const int> instr_index_group,HloComputation * fused_computation,FusedIrEmitter * fused_emitter,absl::Span<const llvm_ir::IrArray> operand_ir_arrays,absl::Span<const llvm_ir::IrArray> result_ir_arrays,absl::Span<HloComputation * const> reducers,const llvm_ir::IrArray::Index & index,const ReductionCodegenInfo & reduction_info,int64 x_iter_num)4594 void IrEmitterUnnested::EmitTileElementForReduction(
4595     mlir::Operation* unnested_hlo, const Shape& reduction_operand_shape,
4596     absl::Span<const int> instr_index_group, HloComputation* fused_computation,
4597     FusedIrEmitter* fused_emitter,
4598     absl::Span<const llvm_ir::IrArray> operand_ir_arrays,
4599     absl::Span<const llvm_ir::IrArray> result_ir_arrays,
4600     absl::Span<HloComputation* const> reducers,
4601     const llvm_ir::IrArray::Index& index,
4602     const ReductionCodegenInfo& reduction_info, int64 x_iter_num) {
4603   VLOG(10) << "Emit tile element for reduce " << MlirToString(unnested_hlo);
4604   int partial_result_index = reduction_info.IsRowReduction() ? 0 : x_iter_num;
4605 
4606   InlinedVector<llvm_ir::ElementGenerator, 1> input_gens;
4607   std::vector<std::pair<llvm_ir::ElementGenerator, int>> extra_output_gens;
4608 
4609   // Construct the ElementGenerator for each reduction and extra output in the
4610   // the group of output instructions.
4611   if (auto fusion = mlir::dyn_cast<mlir::lmhlo::FusionOp>(unnested_hlo)) {
4612     for (int index : instr_index_group) {
4613       mlir::Operation* inst = GetReduceFromUnnestedMlir(unnested_hlo, index);
4614 
4615       const HloInstruction* hlo = fused_computation->root_instruction();
4616       if (hlo->opcode() == HloOpcode::kTuple) {
4617         hlo = hlo->operand(index);
4618       }
4619       if (IsReductionFromOrToContiguousDimensions(inst)) {
4620         input_gens.push_back(*fused_emitter->GetGenerator(hlo->operand(0)));
4621       } else {
4622         extra_output_gens.emplace_back(*fused_emitter->GetGenerator(hlo),
4623                                        index);
4624       }
4625     }
4626   } else {
4627     input_gens.push_back([&](const IrArray::Index& index) {
4628       return operand_ir_arrays[0].EmitReadArrayElement(index, &b_);
4629     });
4630   }
4631 
4632   IrArray::Index input_index =
4633       GetUnnormalizedIndex(index, reduction_operand_shape, &b_,
4634                            reduction_info.GetKernelMappingScheme());
4635   // Clear the linear index field of the IrArray::Index to enable the use of
4636   // GetElementPointer with array types. This enables the vectorization of
4637   // the computation for different partial results. Use this index if
4638   // 'num_partial_results > 1'.
4639   int num_partial_results = reduction_info.GetNumPartialResults();
4640   auto index_without_linear = IrArray::Index(
4641       input_index.multidim(), reduction_operand_shape, input_index.GetType());
4642 
4643   // Emit code to generate the input and perform the reduction computation for
4644   // each reduction instruction.
4645   for (int i = 0; i < reducers.size(); i++) {
4646     llvm::AllocaInst* input_address =
4647         reduction_info.GetReductionInputAddresses()[i];
4648     llvm::AllocaInst* partial_reduction_result_address =
4649         reduction_info.GetPartialResultAddresses()[i];
4650     llvm::Value* const input_ir_value =
4651         input_gens[i](num_partial_results > 1 ? index_without_linear
4652                                               : input_index)
4653             .ValueOrDie();
4654     Store(input_ir_value, input_address);
4655     llvm::Value* partial_result_address = InBoundsGEP(
4656         partial_reduction_result_address, {b_.getInt32(partial_result_index)});
4657     TF_CHECK_OK(EmitCallToNestedComputation(
4658         *reducers[i], {partial_result_address, input_address},
4659         partial_result_address));
4660   }
4661 
4662   // Emit code to generate the output for the non-reduction instructions in the
4663   // fusion, if any.
4664   TF_CHECK_OK(EmitExtraOutputsForReduce(
4665       result_ir_arrays, input_index,
4666       /*use_linear_index=*/num_partial_results == 1, extra_output_gens));
4667 }
4668 
EmitThreadId(int64 threads_per_block,llvm::Type * index_ty)4669 llvm::Value* IrEmitterUnnested::EmitThreadId(int64 threads_per_block,
4670                                              llvm::Type* index_ty) {
4671   // Calculate (y, x) coordinates respectively in the 2D view of thread block,
4672   // defined by (num_thread_y, num_thread_x) from thread_id.
4673   llvm::CallInst* thread_id_raw = gpu::EmitCallToTargetIntrinsic(
4674       gpu::TargetIntrinsicID::kThreadIdx, {}, {}, &b_);
4675   llvm_ir::AddRangeMetadata(0, threads_per_block, thread_id_raw);
4676   return b_.CreateIntCast(thread_id_raw, index_ty,
4677                           /*isSigned=*/true, "thread.id.x");
4678 }
4679 
EmitThreadIdInfo(int64 threads_per_block,llvm::Type * index_ty,int64 num_threads_x)4680 IrEmitterUnnested::ThreadIdInfo IrEmitterUnnested::EmitThreadIdInfo(
4681     int64 threads_per_block, llvm::Type* index_ty, int64 num_threads_x) {
4682   auto constant = [&](uint64 c) -> llvm::Constant* {
4683     return llvm::ConstantInt::get(index_ty, c);
4684   };
4685   llvm::Value* thread_id = EmitThreadId(threads_per_block, index_ty);
4686   llvm::Value* num_threads_x_v = constant(num_threads_x);
4687   return {
4688       /*thread_id=*/thread_id,
4689       /*thread_id_x=*/b_.CreateURem(thread_id, num_threads_x_v, "thread_id.x"),
4690       /*thread_id_y=*/b_.CreateUDiv(thread_id, num_threads_x_v, "thread_id.y"),
4691       /*lane_id=*/b_.CreateURem(thread_id, constant(kWarpSize), "lane_id")};
4692 }
4693 
EmitTilingKernel(const KernelMappingScheme & mapping_scheme,llvm::Type * index_ty,const TileElementGenerator & tile_element_generator)4694 IrEmitterUnnested::TilingKernelInfo IrEmitterUnnested::EmitTilingKernel(
4695     const KernelMappingScheme& mapping_scheme, llvm::Type* index_ty,
4696     const TileElementGenerator& tile_element_generator) {
4697   absl::Span<const int64> dims_in_elems = mapping_scheme.GetDimsInElems();
4698   std::vector<int64> dims_in_blocks = {
4699       CeilOfRatio(dims_in_elems[0], mapping_scheme.GetTileSizeZ()),
4700       CeilOfRatio(dims_in_elems[1], mapping_scheme.GetTileSizeY()),
4701       CeilOfRatio(dims_in_elems[2], mapping_scheme.GetTileSizeX())};
4702   auto constant = [&](uint64 c) -> llvm::Constant* {
4703     return llvm::ConstantInt::get(index_ty, c);
4704   };
4705 
4706   IrEmitterUnnested::ThreadIdInfo thread_id_info =
4707       EmitThreadIdInfo(mapping_scheme.GetThreadsPerBlock(), index_ty,
4708                        mapping_scheme.GetNumThreadsX());
4709 
4710   KernelSupportLibrary ksl(&b_, llvm_ir::UnrollMode::kDefaultUnroll);
4711 
4712   const IrArray::Index block_coords = [&] {
4713     llvm::Value* block_id = EmitBlockId();
4714     llvm_ir::AddRangeMetadata(0, mapping_scheme.GetNumberOfBlocks(),
4715                               llvm::cast<llvm::Instruction>(block_id));
4716     llvm::Value* linear_block_id =
4717         b_.CreateIntCast(block_id, index_ty, /*isSigned=*/true, "block.id.x");
4718     IrArray::Index starting_block(linear_block_id,
4719                                   ShapeUtil::MakeShapeWithDescendingLayout(
4720                                       PRED /*arbitrary*/, dims_in_blocks),
4721                                   &b_);
4722 
4723     std::vector<llvm::Value*> multidim = {
4724         b_.CreateMul(starting_block[0], constant(mapping_scheme.GetTileSizeZ()),
4725                      "block_origin.z"),
4726         starting_block[1], starting_block[2]};
4727     return IrArray::Index(multidim, dims_in_blocks, index_ty);
4728   }();
4729 
4730   std::array<llvm::Value*, 3> output_tile_bounds;
4731   for (int i = kDimY; i < kDimTot; ++i) {
4732     int64 tile_size_for_dim = mapping_scheme.GetTileSizeFor(i);
4733     // Only last row or column may not have full size.
4734     llvm::Value* is_last =
4735         b_.CreateICmpEQ(block_coords[i], constant(dims_in_blocks[i] - 1));
4736     int64 partial_row =
4737         dims_in_elems[i] - (dims_in_blocks[i] - 1) * tile_size_for_dim;
4738     output_tile_bounds[i] =
4739         b_.CreateSelect(is_last, constant(partial_row),
4740                         constant(tile_size_for_dim), "tile_bound");
4741   }
4742 
4743   IrArray::Index tile_origin = [&] {
4744     std::vector<llvm::Value*> elem_multi_index = block_coords.multidim();
4745     llvm::Type* index_ty = block_coords.GetType();
4746     for (int i = kDimY; i < kDimTot; ++i) {
4747       elem_multi_index[i] = b_.CreateMul(
4748           block_coords[i],
4749           llvm::ConstantInt::get(index_ty, mapping_scheme.GetTileSizeFor(i)),
4750           "tile_origin." + std::to_string(i));
4751     }
4752     return IrArray::Index(elem_multi_index, mapping_scheme.GetDimsInElems(),
4753                           index_ty);
4754   }();
4755 
4756   auto emit_tile = [&](const IrArray::Index& tile) {
4757     tile_element_generator(thread_id_info, tile, "output",
4758                            output_tile_bounds[1], output_tile_bounds[2], &ksl);
4759   };
4760 
4761   if (mapping_scheme.GetTileSizeZ() == 1) {
4762     emit_tile(tile_origin);
4763   } else {
4764     llvm::Value* starting_tile_index_for_dim = tile_origin[kDimZ];
4765     llvm::Value* block_size_for_dim = constant(mapping_scheme.GetTileSizeZ());
4766     llvm::Value* block_id_for_dim =
4767         b_.CreateUDiv(starting_tile_index_for_dim, block_size_for_dim);
4768     llvm::Value* last_block_for_dim = constant(dims_in_blocks[kDimZ] - 1);
4769     llvm::Value* last_block_size_for_dim =
4770         constant(dims_in_elems[kDimZ] -
4771                  (dims_in_blocks[kDimZ] - 1) * mapping_scheme.GetTileSizeZ());
4772 
4773     llvm::Value* num_tiles_in_block =
4774         b_.CreateSelect(b_.CreateICmpEQ(last_block_for_dim, block_id_for_dim),
4775                         last_block_size_for_dim, block_size_for_dim);
4776     ksl.For("loop_z",
4777             /*start=*/constant(0),
4778             /*end=*/num_tiles_in_block,
4779             /*step=*/1, [&](llvm::Value* block_dim_induction_var) {
4780               IrArray::Index tile_index = tile_origin.AddOffsetToDim(
4781                   block_dim_induction_var, kDimZ, &b_);
4782               emit_tile(tile_index);
4783             });
4784   }
4785   return {output_tile_bounds, tile_origin};
4786 }
4787 
EmitSyncThreads()4788 llvm::CallInst* IrEmitterUnnested::EmitSyncThreads() {
4789   return EmitCallToTargetIntrinsic(TargetIntrinsicID::kBarrierId, {}, {}, &b_);
4790 }
4791 
// Emits a kernel for the given hlo instruction using a tiled 0-2-1 transpose
// algorithm to improve the memory access patterns for the input parameters
// with a shape that is a 0-2-1 transpose of the output tensor shape. The caller
// is responsible for making sure that it is safe to apply the shared memory
// transpose on the input parameters.
//
// For the purpose of tiling, the output tensors have a logical shape of three
// components 0-2-1 while the relevant input parameters have a logical shape
// of three components 0-1-2 in the order major to minor. The x- and y-
// dimensions of the tensors are tiled in square tiles with an edge length
// `kTileSize`. Each thread block of `kTileSize` x `kNumRows` threads
// transposes one tile: each thread copies kTileSize/kNumRows elements from
// the input to a shared memory tile, then the otherwise "regular HLO kernel"
// reads from the shared memory instead of the original input.
//
// This is similar to the following CUDA algorithm in TensorFlow:
// https://goo.gl/MStRV6.
//
// `kTileSize` should usually be same as warp size. We currently choose 32 for
// `kTileSize` and 4 for `kNumRows`. The CUDA algorithm uses 8 for `kNumRows`.
// In the code below `kTileSize` is spelled `kWarpSize` (see the tile sizes
// passed to `mapping_scheme`).
//
// Parameters:
//   op: the op being emitted. Element emission handles lmhlo.copy and
//     lmhlo.fusion; any other op hits LOG(FATAL).
//   kernel_thunk: the thunk whose launch dimensions are updated at the end.
//   context: supplies the operand and output shapes of `op`.
//   operand_arrays, output_arrays: IR arrays for the operands and outputs of
//     `op`. When there is more than one output, the last one receives a tuple
//     of pointers to the others.
//   reduced_output_dims: the 3-D (normalized) output dimensions. Tiled
//     parameters are viewed through the {0, 2, 1} permutation of these dims.
//   tiled_param_ids: indices of the operands that are staged through a shared
//     memory tile.
//
// TODO(b/33320379): Here each block transposes 1 tile. It may be more
// efficient to launch fewer blocks so each transposes many tiles.
void IrEmitterUnnested::EmitHlo021Tile(
    mlir::Operation* op, Thunk* kernel_thunk, const MlirEmitterContext& context,
    absl::Span<const llvm_ir::IrArray> operand_arrays,
    absl::Span<const llvm_ir::IrArray> output_arrays,
    absl::Span<const int64> reduced_output_dims,
    absl::Span<const int64> tiled_param_ids) {
  // Number of thread rows per block (`kNumRows` in the comment above).
  constexpr int kNumRows = 4;

  std::string name = mlir::GetNameFromLoc(op->getLoc());

  KernelMappingScheme mapping_scheme(reduced_output_dims,
                                     /*tile_sizes=*/{1, kWarpSize, kWarpSize},
                                     /*num_threads_y=*/kNumRows,
                                     /*num_threads_x=*/kWarpSize,
                                     /*indexing_order=*/kLinearIndexingX,
                                     /*vector_size=*/1,
                                     /*is_row_contiguous=*/false);
  LaunchDimensions launch_dimensions(mapping_scheme.GetNumberOfBlocks(),
                                     mapping_scheme.GetThreadsPerBlock());

  llvm::Type* index_type =
      GetIndexTypeForKernelFromMlir(op, launch_dimensions.launch_bound(), &b_);
  std::vector<IrArray> param_arrays;

  // For each tiled parameter, cast its input IrArray to the corresponding
  // reduced shape and keep the reduced shape live during IR emission.
  // Non-tiled parameters get an empty IrArray placeholder, so both vectors
  // below stay indexable by operand id.
  std::vector<IrArray> param_in_reduced_shape_arrays;
  std::vector<llvm::Value*> param_shmem_buffers(context.operand_shapes.size(),
                                                nullptr);

  auto get_shared_memory_buffer = [&](llvm::Type* elem_ty,
                                      absl::string_view buffer_name) {
    // For Nvidia GPUs, the warp size is 32 threads and shared memory is
    // organized into 32 banks. We usually use the warp size, or a multiple of
    // the warp size, as the tile size. This may cause all elements in the same
    // column of a tile to use the same memory bank, and therefore cause shared
    // memory bank conflicts. Adding 1 to the minor dimension of the shared
    // memory buffer can reduce such bank conflicts.
    llvm::Type* buffer_type = llvm::ArrayType::get(
        llvm::ArrayType::get(elem_ty, mapping_scheme.GetTileSizeX() + 1),
        mapping_scheme.GetTileSizeY());
    return llvm_ir::AllocateSharedMemoryTile(b_.GetInsertBlock()->getModule(),
                                             buffer_type, buffer_name);
  };

  for (int64 id = 0; id < context.operand_shapes.size(); id++) {
    const Shape& param_shape = context.operand_shapes[id];
    param_arrays.push_back(operand_arrays[id]);

    if (absl::c_linear_search(tiled_param_ids, id)) {
      param_shmem_buffers[id] = get_shared_memory_buffer(
          llvm_ir::PrimitiveTypeToIrType(param_shape.element_type(), module_),
          IrName(name, StrCat("tile", id)));
      VLOG(3) << "Added shmem buffer for parameter " << id << ": "
              << llvm_ir::DumpToString(*param_shmem_buffers[id]);
      Shape reduced_shape = ShapeUtil::MakeShapeWithDescendingLayout(
          param_shape.element_type(), Permute(reduced_output_dims, {0, 2, 1}));
      param_in_reduced_shape_arrays.push_back(
          param_arrays[id].CastToShape(reduced_shape, &b_));
    } else {
      param_in_reduced_shape_arrays.push_back(IrArray());
    }
  }

  // Emits one output element. Tiled parameters are read from
  // `param_shmem_buffers` instead of global memory.
  EmitElementFunction element_generator =
      [&](const llvm_ir::IrArray::Index& index, llvm::Value* y_loc,
          llvm::Value* x_loc, int64 x_iter_num) {
        if (auto copy = mlir::dyn_cast<mlir::lmhlo::CopyOp>(op)) {
          CHECK_EQ(1, context.output_shapes.size());
          EmitTileElementForCopy(context.output_shapes[0], output_arrays[0],
                                 index, mapping_scheme, y_loc, x_loc,
                                 param_shmem_buffers);
        } else if (auto fusion = mlir::dyn_cast<mlir::lmhlo::FusionOp>(op)) {
          EmitTileElementForFusion(fusion, operand_arrays, output_arrays, index,
                                   mapping_scheme, y_loc, x_loc,
                                   param_shmem_buffers);
        } else {
          LOG(FATAL) << "Unexpected op: " << MlirToString(op);
        }
      };

  TileElementGenerator tile_generator =
      [&](const ThreadIdInfo& thread_id_info, const IrArray::Index& index,
          const string& loop_name, llvm::Value* tile_height,
          llvm::Value* tile_width, KernelSupportLibrary* ksl) {
        // If shared memory transpose is needed, wait for all threads to reach
        // this point, lest we copy a value from tile to output before the other
        // thread copies it from input to tile. This is `__syncthreads` in CUDA.
        if (!tiled_param_ids.empty()) {
          // Calculate the input tile origin from the output tile origin.
          const IrArray::Index input_tile_origin(
              Permute(index.multidim(), {0, 2, 1}),
              Permute(index.dims(), {0, 2, 1}), index.GetType());

          // Copy input parameter values to shared memory buffers:
          // tile[thread_id_y, thread_id_x] = input[index]
          // Note that tile_width and tile_height are flipped here because we
          // are reading a transposed tile.
          EmitTile(mapping_scheme, input_tile_origin, "input", ksl,
                   thread_id_info, tile_width, tile_height,
                   [&](const IrArray::Index& index, llvm::Value* y_loc,
                       llvm::Value* x_loc, int64 /*x_iter_num*/) {
                     for (int64 id : tiled_param_ids) {
                       IrArray& input_in_logical_shape =
                           param_in_reduced_shape_arrays.at(id);

                       llvm::Value* shmem_buffer = param_shmem_buffers.at(id);
                       llvm::Value* zero =
                           llvm::ConstantInt::get(index_type, 0);
                       // TODO(jlebar): Add AA metadata to this store.  Tile
                       // buffers are global variables, so LLVM can't infer much
                       // about it.
                       auto value = input_in_logical_shape.EmitReadArrayElement(
                           index, &b_, "input_element");
                       auto addr = GEP(shmem_buffer, {zero, y_loc, x_loc});
                       Store(value, addr);
                     }
                   });

          // Wait for all threads to reach this point using `__syncthreads` in
          // CUDA.
          EmitSyncThreads();
        }

        EmitTile(mapping_scheme, index, loop_name, ksl, thread_id_info,
                 tile_height, tile_width, element_generator);
        bool block_contains_multi_tiles = mapping_scheme.GetTileSizeZ() > 1;

        // If a tile block contains multiple tiles and shared memory buffers are
        // used, we need to wait for all threads to finish using the shared
        // memory buffer for the current tile before we move on to process the
        // next tile and overwrite the shared memory buffers.
        if (block_contains_multi_tiles && !tiled_param_ids.empty()) {
          EmitSyncThreads();
        }
      };

  // For multioutput fusion, one thread needs to output a tuple
  // with pointers to all the individual outputs.  We could do this
  // at any point in the kernel, but we do it at the beginning in
  // the hopes of reducing register pressure, since we touch
  // threadIdx.x and blockIdx.x at the beginning of the kernel
  // *anyway*.
  if (output_arrays.size() > 1) {
    KernelSupportLibrary{&b_}.If("emit_mof_tuple", IsBlock0Thread0(&b_), [&] {
      llvm_ir::EmitTuple(output_arrays.back(),
                         output_arrays.subspan(0, output_arrays.size() - 1),
                         &b_);
    });
  }

  EmitTilingKernel(mapping_scheme, index_type, tile_generator);
  UpdateLaunchDimensions(launch_dimensions, kernel_thunk,
                         ir_emitter_context_->llvm_module());
}
4971 
4972 namespace {
4973 
4974 // A recursive function to inspect the users of a parameter to determine
4975 // whether it's safe for a parameter to participate in a shared-memory
4976 // transpose.
4977 //
4978 // Consider a fusion parameter P for which we might want to use a shmem
4979 // transpose.  If we do, we use a GPU thread block to preload a tile of P with
4980 // indices [z, y..y+31, x..x+31] to compute an output tile with the same indices
4981 // cooperatively, where z, y, x are the indices for the normalized input/output
4982 // tensor (see the document for FindTranspose021 for the definition of
4983 // normalized tensor for 0-2-1 transpose). This shmem transpose implementation
4984 // requires that the computation of the output tile only read elements within
4985 // the preload tile. If this is not true, we can't use a shmem transpose for P.
4986 //
4987 // If the computation of output element [z, y, x] only requires the element of
4988 // P with the same indices, the shmem transpose implementation can be applied
4989 // to P safely. This is a sufficient but not necessary condition. We check all
4990 // the transitive users of P to see if we can find a user that may cause an
4991 // exception to the situation. If such a user is not found, we conclude that P
4992 // is safe for shmem transpose.
4993 //
4994 // This is trivially true for elementwise operations and some "data-movement"
4995 // ops like kTuple. However, it's not true for operations that can change the
4996 // dimensions of the inputs (e.g. pad, slice) and bitcast operation.
4997 // For example:
4998 //
4999 // fused_computation {
5000 //   param_0 = f32[64,64]{1,0} parameter(0)
5001 //   ROOT bitcast = f32[64,64]{0,1} bitcast(param_0)
5002 // }
5003 // The output element at logical address [0, 63] depends on the input element
5004 // at logical address [63, 0], which would not be within the shared-memory
5005 // block.
5006 //
5007 // TODO(bixia): In order to extend this for kInput fusion, that is reduction
5008 // with transpose, we only need to end the use-chain checking with the input of
5009 // a reduce operations. In this case, the above description on "output" apply
5010 // to the result of such a use-chain, which provides the input to the reduce
5011 // operation.
IsInstructionSafeForShmemTranspose(mlir::Operation * op)5012 bool IsInstructionSafeForShmemTranspose(mlir::Operation* op) {
5013   if (mlir::isa<mlir::TensorStoreOp>(op)) {
5014     return true;
5015   }
5016 
5017   HloOpcode opcode;
5018   if (mlir::isa<mlir::TensorLoadOp>(op)) {
5019     opcode = HloOpcode::kParameter;
5020   } else {
5021     opcode = *MhloToHloOpcode(op);
5022   }
5023   if (HloInstruction::IsOpElementwise(opcode)) {
5024     for (mlir::Value v : op->getResults()) {
5025       for (mlir::OpOperand use : v.getUsers()) {
5026         if (!IsInstructionSafeForShmemTranspose(use.getOwner())) {
5027           return false;
5028         }
5029       }
5030     }
5031     return true;
5032   }
5033 
5034   switch (opcode) {
5035     // Non-elementwise instructions that don't cause the shmem transpose
5036     // to be unsafe, including the instructions that don't currently fuse.
5037     case HloOpcode::kGetDimensionSize:
5038       // The result of the operation doesn't rely on the content of the
5039       // tensor. As such, there is no need to further inspect its users.
5040       return true;
5041     case HloOpcode::kGetTupleElement:
5042     case HloOpcode::kMap:
5043     case HloOpcode::kParameter:
5044     case HloOpcode::kTuple:
5045     case HloOpcode::kTupleSelect:
5046       for (mlir::Value v : op->getResults()) {
5047         for (mlir::OpOperand use : v.getUsers()) {
5048           if (!IsInstructionSafeForShmemTranspose(use.getOwner())) {
5049             return false;
5050           }
5051         }
5052       }
5053       return true;
5054 
5055     default:
5056       return false;
5057   }
5058 }
5059 
5060 // Given a group of input parameters that are 0-2-1 transpose of the outputs of
5061 // a fusion kernel, returns the input parameters that are safe for the shared
5062 // memory transpose implementation.
5063 //
5064 // When a tile based shared memory transpose is used to implement an input with
5065 // 0-2-1 transpose, we preload a tile of the input elements
5066 // [z, y..y+31, x..x+31] to compute the output tile elements of the same
5067 // indices. Preloading the input tile this way is only safe when the computation
5068 // of the output tile elements do not need any input element outside the
5069 // preloaded tile. We inspect all the transitive users of the input parameter
5070 // up to the fusion root instruction to see if we can find any instruction
5071 // that can make preloading the input tile unsafe.
FilterInputsForShmemTranspose(mlir::lmhlo::FusionOp fusion,std::vector<int64> input_ids)5072 std::vector<int64> FilterInputsForShmemTranspose(mlir::lmhlo::FusionOp fusion,
5073                                                  std::vector<int64> input_ids) {
5074   std::vector<mlir::Value> params = ToStdVector(fusion.getFusionParameters());
5075 
5076   std::vector<int64> filtered_input_ids;
5077   for (int64 input_id : input_ids) {
5078     mlir::Value input = params.at(input_id);
5079     if (IsInstructionSafeForShmemTranspose(input.getDefiningOp())) {
5080       filtered_input_ids.push_back(input_id);
5081     }
5082   }
5083   return filtered_input_ids;
5084 }
5085 
5086 }  // namespace
5087 
CheckAndEmitHloWithTile021(MlirEmitterInput input)5088 StatusOr<bool> IrEmitterUnnested::CheckAndEmitHloWithTile021(
5089     MlirEmitterInput input) {
5090   CHECK((mlir::isa<mlir::lmhlo::FusionOp, mlir::lmhlo::CopyOp>(input.op)));
5091 
5092   MlirEmitterContext context;
5093   context.SetOperation(input.op);
5094 
5095   // If the output_shape is reduced to 021 shape, find all the parameters of
5096   // the HLO that are in the corresponding 012 shape.
5097   std::vector<int64> params_012;
5098   optional<std::vector<int64>> reduced_dims_021;
5099   for (int64 operand_idx = 0; operand_idx < context.operand_shapes.size();
5100        ++operand_idx) {
5101     const Shape& operand_shape = context.operand_shapes[operand_idx];
5102     auto find_transpose_result =
5103         ShapeUtil::FindTranspose021(operand_shape, context.output_shapes[0]);
5104     if (!find_transpose_result.has_value()) {
5105       continue;
5106     }
5107     const std::vector<int64>& curr_reduced_dims_021 = *find_transpose_result;
5108     if (!reduced_dims_021.has_value()) {
5109       reduced_dims_021 = curr_reduced_dims_021;
5110     }
5111     if (!absl::c_equal(*reduced_dims_021, curr_reduced_dims_021)) {
5112       // There is more than one possible transpose. Instead of picking one
5113       // transpose, we simply give up here.
5114       return false;
5115     }
5116     params_012.push_back(operand_idx);
5117   }
5118 
5119   if (!reduced_dims_021.has_value()) {
5120     return false;
5121   }
5122 
5123   if ((*reduced_dims_021)[1] < kMinDimensionToTransposeTiled ||
5124       (*reduced_dims_021)[2] < kMinDimensionToTransposeTiled) {
5125     return false;
5126   }
5127 
5128   if (auto fusion_op = mlir::dyn_cast<mlir::lmhlo::FusionOp>(input.op)) {
5129     params_012 = FilterInputsForShmemTranspose(fusion_op, params_012);
5130     if (params_012.empty()) {
5131       return false;
5132     }
5133   }
5134 
5135   // Each of our shared memory tiles has 32*33 elements (so ~4kb, if the
5136   // elements are of size 4 bytes), and CUDA has an architectural limit of
5137   // 48kb shared memory per SM.  (This is increased to 96kb in Volta, but we
5138   // don't use this, in part because it eats into our L1 cache space.)
5139   //
5140   // For correctness we need to ensure that we don't make more than 48kb worth
5141   // of shmem tiles per block.  And for performance, we'd probably like to use
5142   // significantly less, so that we can fit more than one block at a time on a
5143   // gpu core.
5144   //
5145   // We say without benchmarks that we want at least 3 threads/block,
5146   // corresponding to 3 shmem tiles if the elements are 32 bits wide.  We
5147   // choose which params get the shmem transpose treatment arbitrarily; it's
5148   // not clear if there's a Right Choice.
5149   //
5150   // This is only sound if tiled transposes are the only place where we use
5151   // shared memory in fusions.  If in the future other fusible ops use shared
5152   // memory, we'll have to adjust this heuristic.
5153   constexpr int kMinBlocksPerCore = 3;
5154   constexpr int64 kShmemPerCore = 48 * 1024;
5155   int64 shmem_used = 0;
5156   for (int64 i = 0; i < params_012.size(); ++i) {
5157     const Shape& operand_shape = context.operand_shapes[params_012[i]];
5158     shmem_used +=
5159         32 * 33 *
5160         ShapeUtil::ByteSizeOfPrimitiveType(operand_shape.element_type());
5161 
5162     if (kMinBlocksPerCore * shmem_used > kShmemPerCore) {
5163       // Erase this element and everything after it from params_012.
5164       params_012.resize(i);
5165       break;
5166     }
5167   }
5168 
5169   if (params_012.empty()) {
5170     return false;
5171   }
5172 
5173   std::vector<llvm_ir::IrArray> ir_arrays;
5174   TF_ASSIGN_OR_RETURN(std::unique_ptr<KernelThunk> kernel_thunk,
5175                       BuildKernelThunkForMlir(input.op, input.thunk_info,
5176                                               input.extra_slice, &ir_arrays));
5177   EmitHlo021Tile(
5178       input.op, kernel_thunk.get(), context,
5179       absl::MakeSpan(ir_arrays).subspan(0, context.operand_shapes.size()),
5180       absl::MakeSpan(ir_arrays).subspan(context.operand_shapes.size()),
5181       *reduced_dims_021, params_012);
5182   AddThunkToThunkSequence(std::move(kernel_thunk));
5183   return true;
5184 }
5185 
5186 namespace {
5187 
5188 // Returns true if all the transitive users of hlo before hitting users in
5189 // use_chain_endings are elementwise operations.
AreUsersElementwise(mlir::Value value,const absl::flat_hash_set<mlir::Operation * > & use_chain_endings)5190 bool AreUsersElementwise(
5191     mlir::Value value,
5192     const absl::flat_hash_set<mlir::Operation*>& use_chain_endings) {
5193   return absl::c_all_of(value.getUsers(), [&](mlir::OpOperand use) {
5194     mlir::Operation* user = use.getOwner();
5195     CHECK_EQ(1, user->getNumResults());
5196     return use_chain_endings.count(user) ||
5197            (HloInstruction::IsOpElementwise(*MhloToHloOpcode(user)) &&
5198             AreUsersElementwise(user->getResult(0), use_chain_endings));
5199   });
5200 }
5201 
5202 // Returns the number of fusion inputs that have the same dimension as the
5203 // given shape, and involve in only elementwise operations.
NumInputsInvolveInOnlyElementwiseOps(mlir::lmhlo::FusionOp fusion,const Shape & op_shape,const absl::flat_hash_set<mlir::Operation * > & use_chain_endings)5204 int64 NumInputsInvolveInOnlyElementwiseOps(
5205     mlir::lmhlo::FusionOp fusion, const Shape& op_shape,
5206     const absl::flat_hash_set<mlir::Operation*>& use_chain_endings) {
5207   return absl::c_count_if(
5208       fusion.getFusionParameters(), [&](mlir::Value parameter) {
5209         Shape parameter_shape = TypeToShape(parameter.getType());
5210         return ShapeUtil::SameDimensions(op_shape, parameter_shape) &&
5211                AreUsersElementwise(parameter, use_chain_endings);
5212       });
5213 }
5214 
5215 // Returns the number of fusion inputs that have more elements than the given
5216 // shape.
NumInputsWithMoreElementsThan(mlir::lmhlo::FusionOp fusion,const Shape & shape)5217 int64 NumInputsWithMoreElementsThan(mlir::lmhlo::FusionOp fusion,
5218                                     const Shape& shape) {
5219   int64 num_elements = ShapeUtil::ElementsIn(shape);
5220   return absl::c_count_if(
5221       fusion.getFusionParameters(), [&](mlir::Value parameter) {
5222         Shape parameter_shape = TypeToShape(parameter.getType());
5223         return ShapeUtil::ElementsIn(parameter_shape) > num_elements;
5224       });
5225 }
5226 
5227 // The benefit of unrolling a kInput fusion that is a column reduction comes
5228 // from the vectorization of non-reduction fusion outputs and fusion inputs.
5229 // On the other hand, unrolling can also introduce factors that can cause
5230 // the kernel to run slower. This routine uses a simple heuristic to estimate
5231 // the benefit as well as the overhead of unrolling in order to decide whether
5232 // unrolling is beneficial for the given kInput fusion.
IsUnrollingColumnReductionBeneficial(mlir::Operation * unnested_hlo,const Shape & input_shape,int64 num_kept_minor)5233 bool IsUnrollingColumnReductionBeneficial(mlir::Operation* unnested_hlo,
5234                                           const Shape& input_shape,
5235                                           int64 num_kept_minor) {
5236   // TODO(b/122468062): Need further investigate to see whether we can
5237   // remove the constraint on IsPowerOfTwo.
5238   if (!IsPowerOfTwo(static_cast<uint64>(num_kept_minor))) {
5239     return false;
5240   }
5241 
5242   if (IsReductionFromOrToContiguousDimensions(unnested_hlo)) {
5243     return true;
5244   }
5245 
5246   auto fusion = mlir::cast<mlir::lmhlo::FusionOp>(unnested_hlo);
5247   int64 can_be_vectorized = 0;
5248   int64 cannot_be_vectorized = 0;
5249   auto fusion_results = ToStdVector(fusion.getFusionResults());
5250   absl::flat_hash_set<mlir::Operation*> use_chain_endings;
5251   if (fusion_results.size() == 1) {
5252     if (IsReductionFromOrToContiguousDimensions(
5253             fusion_results[0].getDefiningOp())) {
5254       use_chain_endings.insert(fusion_results[0].getDefiningOp());
5255       // Atomic.add of the reduction result can't be vectorized.
5256       cannot_be_vectorized++;
5257     }
5258   } else {
5259     for (mlir::Value result : fusion_results) {
5260       if (IsReductionFromOrToContiguousDimensions(result.getDefiningOp())) {
5261         // Atomic.add of the reduction result can't be vectorized.
5262         cannot_be_vectorized++;
5263       } else {
5264         // Write of the non-reduction result can be vectorized.
5265         can_be_vectorized++;
5266       }
5267       use_chain_endings.insert(result.getDefiningOp());
5268     }
5269   }
5270   // Fusion inputs that have the same dimension as the reduce input and
5271   // only involve in elementwise operations can be vectorized.
5272   can_be_vectorized += NumInputsInvolveInOnlyElementwiseOps(fusion, input_shape,
5273                                                             use_chain_endings);
5274   // Fusion inputs with more elements than the reduce op input must participate
5275   // in non-elementwise operations and we assume that they are not vectorizable
5276   // for the purpose of estimating the benefit of unrolling. If the kernel is
5277   // unrolled even with such an assumption,  and the accesses to those inputs
5278   // turn out to be vectorizable, the compiler will still vectorize them.
5279   cannot_be_vectorized += NumInputsWithMoreElementsThan(fusion, input_shape);
5280   return can_be_vectorized >= cannot_be_vectorized;
5281 }
5282 
NearestPowerOfTwo(int64 v)5283 int64 NearestPowerOfTwo(int64 v) {
5284   if (v < 0) {
5285     return 0;
5286   }
5287   int64 upper = tensorflow::NextPowerOfTwo64(v);
5288   int64 lower = upper >> 1;
5289   return upper - v < v - lower ? upper : lower;
5290 }
5291 
5292 }  // namespace
5293 
// Computes how the reduction kernel for `unnested_hlo` is tiled and mapped onto
// threads. `first_reduce` is the representative reduction: its input shape and
// the dimension split returned by GetReductionKindAndContiguousComponents drive
// every decision below.
//
// The returned ReductionCodegenInfo carries the KernelMappingScheme, the number
// of partial results accumulated per thread (2 when a column reduction is
// unrolled, otherwise 1), and whether this is a row reduction.
ReductionCodegenInfo IrEmitterUnnested::ComputeReductionCodegenInfo(
    mlir::Operation* unnested_hlo, mlir::Operation* first_reduce) {
  Shape input_shape = TypeToShape(first_reduce->getOperand(0).getType());
  ReductionDimensions reduction_dimensions =
      GetReductionKindAndContiguousComponents(first_reduce);
  VLOG(10) << "is_row_reduction " << reduction_dimensions.is_row_reduction
           << " " << reduction_dimensions.dimensions[0] << " "
           << reduction_dimensions.dimensions[1] << " "
           << reduction_dimensions.dimensions[2];
  // Returns the bit width of the element type of `i`.
  auto get_dtype_bits = [](mlir::Value i) {
    // TODO(timshen): may not be efficient.
    return primitive_util::BitWidth(TypeToShape(i.getType()).element_type());
  };

  // For fusion with multiple inputs, use the smallest input dtype to
  // select the reduction_tiling.
  int smallest_input_dtype_bits = get_dtype_bits(first_reduce->getOperand(0));

  for (mlir::Value operand : GetHloOperands(unnested_hlo)) {
    smallest_input_dtype_bits =
        std::min(get_dtype_bits(operand), smallest_input_dtype_bits);
  }
  // Per-thread tile sizes, one entry per reduction dimension.
  std::array<int64, 3> reduction_tiling =
      GetReductionTiling(reduction_dimensions, smallest_input_dtype_bits,
                         ir_emitter_context_->cuda_compute_capability());

  // Row reductions use a single row of threads along x; column reductions use
  // a kWarpSize x kWarpSize thread block (see num_threads_x below).
  int64 num_threads_y = reduction_dimensions.is_row_reduction ? 1 : kWarpSize;
  int64 num_threads_x = [&] {
    if (reduction_dimensions.is_row_reduction) {
      // Use 512 as default block size (threads per block) for row reductions.
      // For multi-output fusions, reduce the block size further to decrease
      // register pressure when multiple outputs are computed by each thread.
      int64 fan_out = 1;
      if (auto fusion = mlir::dyn_cast<mlir::lmhlo::FusionOp>(unnested_hlo)) {
        fan_out = fusion.getFusionResults().size();
      }

      // 64 is the general advice as the smallest block sizes.
      // Moreover, XLA:GPU emitters need at least 32 threads at some places.
      int64 max_block_size = std::max(64LL, 512LL / NearestPowerOfTwo(fan_out));
      // Threads needed to cover dimension 2 with reduction_tiling[2] elements
      // per thread, rounded up to a multiple of kWarpSize and capped at
      // max_block_size.
      return std::min(
          max_block_size,
          RoundUpToNearest(CeilOfRatio(reduction_dimensions.dimensions[2],
                                       reduction_tiling[2]),
                           kWarpSize));
    }
    return kWarpSize;
  }();

  // True when the x dimension is an exact multiple of the per-block x extent
  // (per-thread tile width times threads along x). Only consulted by the
  // cc 6.x vectorization condition below.
  bool tile_fit = reduction_dimensions.dimensions[kDimX] %
                      (reduction_tiling[2] * num_threads_x) ==
                  0;

  // Stays 0 when no CUDA compute capability is available; then neither of the
  // cc-specific branches below applies.
  int cc_major = 0;
  if (ir_emitter_context_->cuda_compute_capability()) {
    cc_major = ir_emitter_context_->cuda_compute_capability()->cc_major;
  }

  int num_partial_results = 1;
  // Note: besides choosing the indexing order, this lambda also doubles
  // num_partial_results and reduction_tiling[2] when it decides to unroll a
  // column reduction.
  KernelMappingScheme::IndexingOrder indexing_order = [&]() {
    if (reduction_dimensions.is_row_reduction &&
        // P100, only try to vectorize+coalesce memory access when the
        // tile size fits exactly and dtypes <= 32 bits
        ((cc_major == 6 && smallest_input_dtype_bits <= 32 && tile_fit) ||
         // On V100, only try to vectorize+coalesce memory access for
         // rows of even size.  For odd row sizes, every other row
         // isn't aligned, so it can't be vectorized.
         (cc_major >= 7 && reduction_dimensions.dimensions[2] % 2 == 0))) {
      return kStridedLinearIndexingX;
    } else if (!reduction_dimensions.is_row_reduction &&
               IsUnrollingColumnReductionBeneficial(
                   unnested_hlo, input_shape,
                   reduction_dimensions.dimensions[2])) {
      num_partial_results = 2;
      reduction_tiling[2] *= num_partial_results;
      return kLinearIndexingX;
    } else {
      return kStridedIndexingX;
    }
  }();

  int vector_size = 1;
  if (indexing_order == kStridedLinearIndexingX) {
    // Assuming XLA will perform the unrolling and LLVM will vectorize,
    // disable the unroll for the cases that LLVM doesn't vectorize.
    if (reduction_dimensions.dimensions[2] % 2 == 0 &&
        !MayPreventVectorization(unnested_hlo)) {
      vector_size = 2;
    } else {
      indexing_order = kStridedIndexingX;
    }
  }
  // Per-block tile: reduction_tiling[0] along dimension 0, and the per-thread
  // tile times the thread count along dimensions 1 (y) and 2 (x).
  KernelMappingScheme mapping_scheme(
      reduction_dimensions.dimensions,
      {reduction_tiling[0], reduction_tiling[1] * num_threads_y,
       reduction_tiling[2] * num_threads_x},
      num_threads_y, num_threads_x, indexing_order, vector_size);
  return ReductionCodegenInfo(mapping_scheme, num_partial_results,
                              reduction_dimensions.is_row_reduction);
}
5394 
EmitIRForReduction(mlir::Operation * unnested_hlo,absl::Span<const int> instr_index_group,HloComputation * fused_computation,FusedIrEmitter * fused_emitter,absl::Span<const llvm_ir::IrArray> operand_ir_arrays,absl::Span<const llvm_ir::IrArray> result_ir_arrays,ReductionCodegenInfo * reduction_info,const Shape & input_shape)5395 void IrEmitterUnnested::EmitIRForReduction(
5396     mlir::Operation* unnested_hlo, absl::Span<const int> instr_index_group,
5397     HloComputation* fused_computation, FusedIrEmitter* fused_emitter,
5398     absl::Span<const llvm_ir::IrArray> operand_ir_arrays,
5399     absl::Span<const llvm_ir::IrArray> result_ir_arrays,
5400     ReductionCodegenInfo* reduction_info, const Shape& input_shape) {
5401   std::vector<HloComputation*> reducers;
5402   for (auto index : instr_index_group) {
5403     auto reduce = GetReduceFromUnnestedMlir(unnested_hlo, index);
5404     if (!IsReductionFromOrToContiguousDimensions(reduce)) {
5405       continue;
5406     }
5407     if (auto unnested_reduce = mlir::dyn_cast<mlir::lmhlo::ReduceOp>(reduce)) {
5408       reducers.push_back(
5409           *GetOrCreateSubComputationFromRegion(&unnested_reduce.body(),
5410                                                /*is_fusion=*/false));
5411     } else if (auto nested_reduce =
5412                    mlir::dyn_cast<mlir::mhlo::ReduceOp>(reduce)) {
5413       HloInstruction* root = fused_computation->root_instruction();
5414       if (root->opcode() == HloOpcode::kTuple) {
5415         root = root->mutable_operand(index);
5416       } else {
5417         CHECK_EQ(0, index);
5418       }
5419       reducers.push_back(root->to_apply());
5420     } else {
5421       LOG(FATAL) << "Unexpected reduce op: " << MlirToString(reduce);
5422     }
5423   }
5424   CHECK(!reducers.empty()) << " expect at least one reduce instructions.";
5425 
5426   const KernelMappingScheme& mapping_scheme =
5427       reduction_info->GetKernelMappingScheme();
5428   LaunchDimensions launch_dimensions(mapping_scheme.GetNumberOfBlocks(),
5429                                      mapping_scheme.GetThreadsPerBlock());
5430   llvm::Type* index_ty = GetIndexTypeForKernelFromMlir(
5431       unnested_hlo, launch_dimensions.launch_bound(), &b_);
5432   EmitPrologueForReduction(unnested_hlo, instr_index_group, fused_computation,
5433                            fused_emitter, operand_ir_arrays, result_ir_arrays,
5434                            reduction_info);
5435 
5436   EmitElementFunction emit_reduction_tile =
5437       [&](const llvm_ir::IrArray::Index& index, llvm::Value* y_loc,
5438           llvm::Value* x_loc, int64 x_iter_num) {
5439         EmitTileElementForReduction(
5440             unnested_hlo, input_shape, instr_index_group, fused_computation,
5441             fused_emitter, operand_ir_arrays, result_ir_arrays, reducers, index,
5442             *reduction_info, x_iter_num);
5443       };
5444 
5445   TilingKernelInfo tiling_kernel_info = EmitTilingKernel(
5446       mapping_scheme, index_ty,
5447       [&](const ThreadIdInfo& thread_id_info, const IrArray::Index& index,
5448           const string& loop_name, llvm::Value* tile_height,
5449           llvm::Value* tile_width, KernelSupportLibrary* ksl) {
5450         EmitTile(reduction_info->GetKernelMappingScheme(), index, loop_name,
5451                  ksl, thread_id_info, tile_height, tile_width,
5452                  emit_reduction_tile);
5453       });
5454   EmitEpilogueForReduction(index_ty, unnested_hlo, instr_index_group,
5455                            result_ir_arrays, reducers, *reduction_info,
5456                            tiling_kernel_info);
5457 }
5458 
5459 namespace {
5460 
5461 // Returns whether the `instr` is either a constant, a scalar, or a
5462 // broadcasted constant/scalar.
IsBroadcastedConstantOrScalar(const HloInstruction & instr)5463 bool IsBroadcastedConstantOrScalar(const HloInstruction& instr) {
5464   return instr.IsConstant() || ShapeUtil::IsScalar(instr.shape()) ||
5465          (HloOpcode::kBroadcast == instr.opcode() &&
5466           (instr.operand(0)->IsConstant() ||
5467            ShapeUtil::IsScalar(instr.operand(0)->shape())));
5468 }
5469 
5470 // Divides `num_reduces` reduces into groups. Different groups will be executed
5471 // in parallel. Generally speaking, we'd like to run the reduce instructions
5472 // in parallel without incurring too much recomputation overhead. The current
5473 // heuristic is to place reduce instructions who share nothing or only
5474 // (broadcasted) scalars/constants into different groups; otherwise, they are
5475 // placed in the same group. Non-reduce instructions always go with the reduce
5476 // instructions into the same group so long as they share any predecessors.
DivideOutputInstructionsIntoGroups(HloComputation * fused_computation,int num_reduces)5477 std::vector<std::vector<int>> DivideOutputInstructionsIntoGroups(
5478     HloComputation* fused_computation, int num_reduces) {
5479   CHECK_NE(0, num_reduces);
5480   if (num_reduces == 1) {
5481     return {{0}};
5482   }
5483 
5484   std::vector<tensorflow::UnionFind<HloInstruction*>> disjoint_sets(
5485       num_reduces);
5486   for (size_t i = 0; i < num_reduces; ++i) {
5487     disjoint_sets[i].Get() =
5488         fused_computation->root_instruction()->mutable_operand(i);
5489   }
5490 
5491   std::unique_ptr<HloReachabilityMap> reachability_map =
5492       HloReachabilityMap::Build(fused_computation);
5493   for (auto* instr : fused_computation->instructions()) {
5494     std::vector<int64> reached_output_ids;
5495     for (size_t oid = 0; oid < num_reduces; ++oid) {
5496       auto reduce = fused_computation->root_instruction()->mutable_operand(oid);
5497       if (HloOpcode::kReduce == reduce->opcode() &&
5498           (IsBroadcastedConstantOrScalar(*instr))) {
5499         // Do not group output reduce instructions through broadcasted
5500         // constants or scalars, as the recomputation should be acceptable.
5501         VLOG(3) << "Skip broadcasted constant or scalar " << instr->ToString();
5502         continue;
5503       }
5504       // Now group output instructions if they have common predecessors.
5505       if (reachability_map->IsReachable(instr, reduce)) {
5506         VLOG(3) << "Reaching " << reduce->ToString() << " from "
5507                 << instr->ToString();
5508         reached_output_ids.push_back(oid);
5509       }
5510     }
5511     for (size_t j = 1; j < reached_output_ids.size(); ++j) {
5512       disjoint_sets[reached_output_ids[0]].Merge(
5513           &disjoint_sets[reached_output_ids[j]]);
5514     }
5515   }
5516   // Place output instructions in the same set into the same group.
5517   absl::flat_hash_map<HloInstruction*, std::vector<int>> groups;
5518   for (size_t oid = 0; oid < num_reduces; ++oid) {
5519     groups[disjoint_sets[oid].Get()].push_back(oid);
5520   }
5521 
5522   std::vector<std::vector<int>> ret;
5523   absl::c_for_each(
5524       groups, [&](auto& iter) { ret.emplace_back(std::move(iter.second)); });
5525   return ret;
5526 }
5527 
5528 }  // namespace
5529 
// Emits `mlir_input.op` -- either an unfused lmhlo reduce or an lmhlo fusion
// whose outputs include reductions from/to contiguous dimensions -- as one
// SequentialThunk appended via AddThunkToThunkSequence. The thunk contains:
//   1. one initializer thunk per output that is a reduction from/to
//      contiguous dimensions, then
//   2. a single kernel thunk computing all outputs, where the outputs are
//      split into groups and each group is selected by block_id_y.
//
// Returns InternalError("Inconsistent reduction fusion outputs") when a fused
// mhlo reduce is not consistent with the first reduce according to
// IsFusedReductionOutputConsistent. Errors from the thunk builders and from
// GetOrCreateSubComputationFromRegion are propagated.
Status IrEmitterUnnested::EmitReductionFromOrToContiguousDimensions(
    MlirEmitterInput mlir_input) {
  mlir::Operation* unnested_hlo = mlir_input.op;
  auto fusion = mlir::dyn_cast<mlir::lmhlo::FusionOp>(unnested_hlo);

  // An unfused reduce has one output; a fusion has one per fusion result.
  int num_reduces = 1;
  if (fusion) {
    num_reduces = fusion.getFusionResults().size();
  }

  bool returns_tuple = num_reduces > 1;
  VLOG(10) << "Emitting reduction to vector " << MlirToString(unnested_hlo);

  // Build an initializer thunk to initialize each reduction output.
  // Outputs that are not reductions from/to contiguous dimensions are skipped.
  // NOTE(review): presumably needed because the kernel accumulates into the
  // outputs (e.g. with atomic adds) -- confirm.
  std::vector<std::unique_ptr<Thunk>> thunks;
  for (int i = 0; i < num_reduces; ++i) {
    mlir::Operation* output_instruction =
        GetReduceFromUnnestedMlir(unnested_hlo, i);
    if (!IsReductionFromOrToContiguousDimensions(output_instruction)) {
      continue;
    }

    if (fusion) {
      TF_ASSIGN_OR_RETURN(std::unique_ptr<Thunk> initializer_thunk,
                          BuildFusedInitializerThunkForMlir(fusion, i));
      thunks.push_back(std::move(initializer_thunk));
    } else {
      auto reduce = mlir::cast<mlir::lmhlo::ReduceOp>(output_instruction);

      TF_RET_CHECK(!returns_tuple);
      TF_ASSIGN_OR_RETURN(
          std::unique_ptr<Thunk> initializer_thunk,
          BuildInitializerThunkForMlir(reduce, reduce.init_values()[0],
                                       reduce.out()[0]));
      thunks.push_back(std::move(initializer_thunk));
    }
  }

  // Build a kernel thunk to compute all the outputs.
  // first_reduce is the first output that reduces from/to contiguous
  // dimensions; it serves as the representative for tiling decisions.
  mlir::Operation* first_reduce = nullptr;
  for (int i = 0; i < num_reduces; ++i) {
    if (IsReductionFromOrToContiguousDimensions(
            GetReduceFromUnnestedMlir(unnested_hlo, i))) {
      first_reduce = GetReduceFromUnnestedMlir(unnested_hlo, i);
      break;
    }
  }
  CHECK(first_reduce) << MlirToString(unnested_hlo);
  // Every fused mhlo reduce must be consistent with first_reduce, since one
  // ReductionCodegenInfo derived from first_reduce is used for all of them.
  if (num_reduces > 1) {
    for (int i = 0; i < num_reduces; i++) {
      auto candidate = mlir::dyn_cast<mlir::mhlo::ReduceOp>(
          GetReduceFromUnnestedMlir(unnested_hlo, i));
      if (candidate &&
          !IsFusedReductionOutputConsistent(
              candidate, mlir::cast<mlir::mhlo::ReduceOp>(first_reduce))) {
        return InternalError("Inconsistent reduction fusion outputs");
      }
    }
  }
  Shape input_shape = TypeToShape(first_reduce->getOperand(0).getType());
  // The layout of a reduction input is either set by LayoutAssignment for
  // unnested kReduce or by InstructionFusion for fused kReduce.
  CHECK(input_shape.has_layout()) << "LayoutAssignment or InstructionFusion "
                                     "doesn't set the input layout of "
                                  << MlirToString(first_reduce);

  // Kernel argument arrays. For fusions they are indexed below as the fused
  // computation's parameters followed by its outputs; for an unfused reduce
  // there are exactly three (two operands, then the result).
  std::vector<llvm_ir::IrArray> ir_arrays;
  TF_ASSIGN_OR_RETURN(std::unique_ptr<KernelThunk> kernel_thunk,
                      BuildKernelThunkForMlir(unnested_hlo, Thunk::ThunkInfo(),
                                              {}, &ir_arrays));

  HloComputation* fused_computation = nullptr;
  if (fusion) {
    TF_ASSIGN_OR_RETURN(fused_computation, GetOrCreateSubComputationFromRegion(
                                               &fusion.region(),
                                               /*is_fusion=*/true));
  }

  // Group output instructions. Each group will be executed in parallel.
  // (fused_computation is null for an unfused reduce, which has one output.)
  std::vector<std::vector<int>> instr_index_groups =
      DivideOutputInstructionsIntoGroups(fused_computation, num_reduces);

  VLOG(2) << StrCat("Generate in ", instr_index_groups.size(), " groups for ",
                    MlirToString(unnested_hlo));

  absl::optional<GpuElementalIrEmitter> elemental_emitter;
  absl::optional<FusedIrEmitter> optional_fused_emitter;
  FusedIrEmitter* fused_emitter = nullptr;

  // For fusions, operands are read through fused_emitter's parameter
  // generators, so operand_ir_arrays stays empty in that case.
  absl::Span<const llvm_ir::IrArray> operand_ir_arrays;
  absl::Span<const llvm_ir::IrArray> result_ir_arrays;
  if (fusion) {
    elemental_emitter.emplace(hlo_module_config_,
                              ir_emitter_context_->llvm_module(), &b_,
                              GetNestedComputer());
    optional_fused_emitter.emplace(&*elemental_emitter);
    fused_emitter = &*optional_fused_emitter;

    // Each fused parameter i is read directly from kernel argument i.
    CHECK_LT(fused_computation->num_parameters(), ir_arrays.size());
    for (int i = 0; i < fused_computation->num_parameters(); i++) {
      auto ir_array = ir_arrays[i];
      auto fused_operand = fused_computation->parameter_instruction(i);
      fused_emitter->BindGenerator(
          fused_operand, [this, ir_array,
                          fused_operand](const llvm_ir::IrArray::Index& index) {
            return ir_array.EmitReadArrayElement(index, &b_,
                                                 fused_operand->name());
          });
    }
    result_ir_arrays = absl::MakeSpan(ir_arrays).subspan(
        fused_computation->num_parameters(), num_reduces);
  } else {
    CHECK_EQ(3, ir_arrays.size());
    operand_ir_arrays = absl::MakeSpan(ir_arrays).subspan(0, 2);
    result_ir_arrays = absl::MakeSpan(ir_arrays).subspan(2);
  }

  KernelSupportLibrary ksl(&b_, llvm_ir::UnrollMode::kDefaultUnroll);
  for (size_t i = 0; i < instr_index_groups.size(); ++i) {
    // Create a new ReductionCodegenInfo instance as it contains states for
    // code generation per reduction group. For now, let's always use the very
    // first reduce as representative to construct ReductionCodegenInfo, since
    // all the reductions are required to have the same shape and layout as
    // verified by `IsFusedReductionOutputConsistent()`. We can loosen the
    // constraint later when the needs arise.
    ReductionCodegenInfo reduction_info =
        ComputeReductionCodegenInfo(unnested_hlo, first_reduce);
    auto emit_reduction_func = [&] {
      EmitIRForReduction(unnested_hlo, instr_index_groups[i], fused_computation,
                         fused_emitter, operand_ir_arrays, result_ir_arrays,
                         &reduction_info, input_shape);
    };
    // Use raw block_id_y to select the i-th parallel reduction to run. Using
    // block_id_y instead of block_id_x simplifies the index calculation
    // for reduction code generation as the block_id_y is orthogonal to
    // the indices used within the reductions.
    llvm::CallInst* raw_block_id_y = gpu::EmitCallToTargetIntrinsic(
        gpu::TargetIntrinsicID::kBlockIdy, {}, {}, &b_);
    llvm_ir::AddRangeMetadata(0, instr_index_groups.size(),
                              llvm::cast<llvm::Instruction>(raw_block_id_y));
    llvm::Value* guarding_cond =
        b_.CreateICmpEQ(raw_block_id_y, b_.getInt32(i));
    ksl.If(StrCat("reduce-group-", i), guarding_cond, emit_reduction_func);
  }
  // Recomputed from the same representative reduce only to read the mapping
  // scheme for the launch dimensions; it is built from the same inputs as the
  // per-group instances above.
  ReductionCodegenInfo reduction_info =
      ComputeReductionCodegenInfo(unnested_hlo, first_reduce);
  const KernelMappingScheme& mapping_scheme =
      reduction_info.GetKernelMappingScheme();
  // block_y_count is set to instr_index_groups.size(), so that each reduction
  // group can be run in parallel by a different BlockIdy.
  LaunchDimensions launch_dimensions(
      {/*x=*/mapping_scheme.GetNumberOfBlocks(),
       /*y=*/static_cast<int64>(instr_index_groups.size()),
       /*z=*/1},
      {/*x=*/mapping_scheme.GetThreadsPerBlock(), /*y=*/1, /*z=*/1});
  VLOG(3) << "Launch dimensions of "
          << mlir::GetNameFromLoc(unnested_hlo->getLoc())
          << ": number of blocks: " << mapping_scheme.GetNumberOfBlocks()
          << " - threads per block: " << mapping_scheme.GetThreadsPerBlock();
  UpdateLaunchDimensions(launch_dimensions, kernel_thunk.get(),
                         ir_emitter_context_->llvm_module());

  // The kernel runs after all initializer thunks.
  thunks.push_back(std::move(kernel_thunk));
  std::unique_ptr<SequentialThunk> sequential_thunk =
      absl::make_unique<SequentialThunk>(mlir_input.thunk_info,
                                         std::move(thunks));
  AddThunkToThunkSequence(std::move(sequential_thunk));

  return Status::OK();
}
5700 
5701 // Emits code for slices based on the below structure. An if statement with
5702 // a guarding condition is generated for each ROOT slice.
5703 //
5704 // Pseudo code:
5705 //
5706 // Compute values of slice input operands
5707 //
5708 // Compute guarding_cond0
5709 // if (guarding_cond0) {
5710 //   Write to output of slice0
5711 // }
5712 //
5713 // Compute guarding_cond1
5714 // if (guarding_cond1) {
5715 //   Write to output of slice1
5716 // }
5717 //
EmitElementForInputFusibleSlices(mlir::lmhlo::FusionOp fusion,absl::Span<const llvm_ir::IrArray> ir_arrays,const llvm_ir::IrArray::Index & index)5718 Status IrEmitterUnnested::EmitElementForInputFusibleSlices(
5719     mlir::lmhlo::FusionOp fusion, absl::Span<const llvm_ir::IrArray> ir_arrays,
5720     const llvm_ir::IrArray::Index& index) {
5721   VLOG(10) << "Emitting slice input fusion for " << MlirToString(fusion);
5722 
5723   TF_ASSIGN_OR_RETURN(const HloComputation* fused_computation,
5724                       GetOrCreateSubComputationFromRegion(&fusion.region(),
5725                                                           /*is_fusion=*/true));
5726 
5727   HloInstruction* slice_or_tuple = fused_computation->root_instruction();
5728   auto slice_instructions = [&]() -> absl::Span<HloInstruction* const> {
5729     if (slice_or_tuple->opcode() == HloOpcode::kSlice) {
5730       return absl::Span<HloInstruction* const>(&slice_or_tuple, 1);
5731     }
5732     CHECK_EQ(slice_or_tuple->opcode(), HloOpcode::kTuple);
5733     return slice_or_tuple->operands();
5734   }();
5735 
5736   // Emit input operand values of slices.
5737   std::vector<llvm::Value*> input_ir_values;
5738   GpuElementalIrEmitter elem_emitter(hlo_module_config_, module_, &b_,
5739                                      GetNestedComputer());
5740   FusedIrEmitter fused_emitter(&elem_emitter);
5741   for (int i = 0; i < fused_computation->num_parameters(); i++) {
5742     fused_emitter.BindGenerator(
5743         fused_computation->parameter_instruction(i),
5744         [this, &ir_arrays, i](llvm_ir::IrArray::Index index) {
5745           return ir_arrays[i].EmitReadArrayElement(index, &b_);
5746         });
5747   }
5748   for (const HloInstruction* slice : slice_instructions) {
5749     auto input_generator = *fused_emitter.GetGenerator(slice->operand(0));
5750     input_ir_values.push_back(input_generator(index).ValueOrDie());
5751   }
5752 
5753   // Emit for slice_instructions.
5754   KernelSupportLibrary ksl(&b_, llvm_ir::UnrollMode::kDefaultUnroll);
5755   for (int64 i = 0; i < slice_instructions.size(); ++i) {
5756     HloInstruction* slice = slice_instructions[i];
5757 
5758     // guarding_cond := index >= start && index < limit, for each dim.
5759     std::vector<llvm::Value*> index_within_ranges;
5760     for (size_t dim = 0; dim < slice->slice_starts().size(); ++dim) {
5761       CHECK_EQ(slice->slice_strides(dim), 1);
5762       auto larger_or_equal_than_start = b_.CreateICmpSGE(
5763           index.multidim()[dim],
5764           index.GetConstantWithIndexType(slice->slice_starts(dim)));
5765       llvm::Value* smaller_than_limit = b_.CreateICmpSLT(
5766           index.multidim()[dim],
5767           index.GetConstantWithIndexType(slice->slice_limits(dim)));
5768       llvm::Value* within_range =
5769           b_.CreateAnd(larger_or_equal_than_start, smaller_than_limit);
5770       index_within_ranges.push_back(within_range);
5771     }
5772     llvm::Value* guarding_cond = b_.CreateAnd(index_within_ranges);
5773 
5774     auto emit_slice_elem_func = [&] {
5775       const std::vector<llvm::Value*>& src_multidim = index.multidim();
5776       std::vector<llvm::Value*> dst_multidim(src_multidim.size());
5777       for (size_t dim = 0; dim < src_multidim.size(); ++dim) {
5778         dst_multidim[dim] =
5779             Sub(src_multidim[dim],
5780                 index.GetConstantWithIndexType(slice->slice_starts(dim)));
5781       }
5782       llvm_ir::IrArray src_ir_array =
5783           ir_arrays[fused_computation->num_parameters() + i];
5784       IrArray::Index slice_dst_index(dst_multidim, slice->shape(),
5785                                      index.GetType());
5786       src_ir_array.EmitWriteArrayElement(slice_dst_index, input_ir_values[i],
5787                                          &b_);
5788     };
5789 
5790     ksl.If(StrCat("slice", i), guarding_cond, emit_slice_elem_func);
5791   }
5792   return Status::OK();
5793 }
5794 
EmitInputFusibleNonStridedSlices(MlirEmitterInput mlir_input)5795 Status IrEmitterUnnested::EmitInputFusibleNonStridedSlices(
5796     MlirEmitterInput mlir_input) {
5797   auto fusion = mlir::cast<mlir::lmhlo::FusionOp>(mlir_input.op);
5798 
5799   constexpr int unroll_factor = 1;
5800 
5801   std::vector<llvm_ir::IrArray> ir_arrays;
5802   TF_ASSIGN_OR_RETURN(
5803       auto kernel_thunk,
5804       BuildKernelThunkForMlir(fusion, mlir_input.thunk_info,
5805                               mlir_input.extra_slice, &ir_arrays));
5806 
5807   TF_ASSIGN_OR_RETURN(Shape element_shape,
5808                       GetConsistentInputShapeForRootSlices(fusion));
5809   LaunchDimensions launch_dimensions = CalculateLaunchDimensions(
5810       element_shape, ir_emitter_context_->gpu_device_info(), unroll_factor);
5811   UpdateLaunchDimensions(launch_dimensions, kernel_thunk.get(),
5812                          ir_emitter_context_->llvm_module());
5813 
5814   Status emit_status =
5815       ParallelLoopEmitter(
5816           [&](const llvm_ir::IrArray::Index index) -> Status {
5817             return EmitElementForInputFusibleSlices(fusion, ir_arrays, index);
5818           },
5819           element_shape, launch_dimensions, &b_)
5820           .EmitLoop(IrName(mlir::GetNameFromLoc(fusion.getLoc())),
5821                     GetIndexTypeForKernelFromMlir(
5822                         fusion, launch_dimensions.launch_bound(), &b_));
5823 
5824   thunk_sequence_.emplace_back(std::move(kernel_thunk));
5825 
5826   return emit_status;
5827 }
5828 
GetThunkInfo(const HloInstruction * hlo) const5829 Thunk::ThunkInfo IrEmitterUnnested::GetThunkInfo(
5830     const HloInstruction* hlo) const {
5831   auto info = ThunkEmitter::EmissionContext::GetThunkInfo(hlo);
5832   if (const auto* index_map = ir_emitter_context_->profile_index_map()) {
5833     info.profile_index.emplace(
5834         static_cast<int64>(index_map->GetProfileIndexFor(*hlo)));
5835   }
5836   return info;
5837 }
5838 
EmitOp(MlirEmitterInput mlir_input)5839 Status IrEmitterUnnested::EmitOp(MlirEmitterInput mlir_input) {
5840   if (mlir::isa<mlir::lmhlo::SortOp>(mlir_input.op)) {
5841     return EmitSortFromMlir(mlir_input);
5842   }
5843   if (mlir::isa<mlir::lmhlo::CollectivePermuteOp>(mlir_input.op)) {
5844     return EmitCollectivePermuteFromMlir(mlir_input);
5845   }
5846   LOG(FATAL)
5847       << "This function is for test only, and the op is not implemented: "
5848       << MlirToString(mlir_input.op);
5849 }
5850 
SetOperation(mlir::Operation * op)5851 void MlirEmitterContext::SetOperation(mlir::Operation* op) {
5852   this->name = mlir::GetNameFromLoc(op->getLoc());
5853 
5854   auto operands = GetHloOperands(op);
5855   auto outputs = GetHloOutputs(op);
5856   for (auto operand : operands) {
5857     operand_shapes.push_back(TypeToShape(operand.getType()));
5858   }
5859   for (auto output : outputs) {
5860     output_shapes.push_back(TypeToShape(output.getType()));
5861   }
5862 }
5863 
HandleBitcast(HloInstruction * bitcast)5864 Status IrEmitterUnnested::HandleBitcast(HloInstruction* bitcast) {
5865   TF_ASSIGN_OR_RETURN(auto input, GetMlirEmitterInput(bitcast));
5866   DCHECK_EQ(nullptr, input.op);
5867   return Status::OK();
5868 }
5869 
5870 }  // namespace gpu
5871 }  // namespace xla
5872