1 /* Copyright 2020 The TensorFlow Authors. All Rights Reserved.
2 
3 Licensed under the Apache License, Version 2.0 (the "License");
4 you may not use this file except in compliance with the License.
5 You may obtain a copy of the License at
6 
7     http://www.apache.org/licenses/LICENSE-2.0
8 
9 Unless required by applicable law or agreed to in writing, software
10 distributed under the License is distributed on an "AS IS" BASIS,
11 WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
12 See the License for the specific language governing permissions and
13 limitations under the License.
14 ==============================================================================*/
15 
16 #include "tensorflow/compiler/mlir/xla/transforms/mhlo_to_lhlo_with_xla.h"
17 
18 #include <climits>
19 #include <memory>
20 #include <tuple>
21 
22 #include "absl/algorithm/container.h"
23 #include "llvm/ADT/STLExtras.h"
24 #include "llvm/ADT/SmallVector.h"
25 #include "mlir/Dialect/StandardOps/IR/Ops.h"  // from @llvm-project
26 #include "mlir/IR/AffineExpr.h"  // from @llvm-project
27 #include "mlir/IR/AffineMap.h"  // from @llvm-project
28 #include "mlir/IR/Attributes.h"  // from @llvm-project
29 #include "mlir/IR/Builders.h"  // from @llvm-project
30 #include "mlir/IR/BuiltinAttributes.h"  // from @llvm-project
31 #include "mlir/IR/BuiltinOps.h"  // from @llvm-project
32 #include "mlir/IR/BuiltinTypes.h"  // from @llvm-project
33 #include "mlir/IR/Dialect.h"  // from @llvm-project
34 #include "mlir/IR/Location.h"  // from @llvm-project
35 #include "mlir/IR/MLIRContext.h"  // from @llvm-project
36 #include "mlir/IR/OpDefinition.h"  // from @llvm-project
37 #include "mlir/IR/Operation.h"  // from @llvm-project
38 #include "mlir/IR/PatternMatch.h"  // from @llvm-project
39 #include "mlir/IR/SymbolTable.h"  // from @llvm-project
40 #include "mlir/Pass/Pass.h"  // from @llvm-project
41 #include "mlir/Pass/PassOptions.h"  // from @llvm-project
42 #include "mlir/Translation.h"  // from @llvm-project
43 #include "tensorflow/compiler/mlir/hlo/include/mlir-hlo/Dialect/mhlo/IR/hlo_ops.h"
44 #include "tensorflow/compiler/mlir/hlo/include/mlir-hlo/Dialect/mhlo/IR/hlo_ops_base_enums.h"
45 #include "tensorflow/compiler/mlir/hlo/include/mlir-hlo/Dialect/mhlo/IR/lhlo_gpu_ops.h"
46 #include "tensorflow/compiler/mlir/hlo/include/mlir-hlo/Dialect/mhlo/IR/lhlo_ops.h"
47 #include "tensorflow/compiler/mlir/xla/attribute_importer.h"
48 #include "tensorflow/compiler/mlir/xla/hlo_function_importer.h"
49 #include "tensorflow/compiler/mlir/xla/hlo_utils.h"
50 #include "tensorflow/compiler/mlir/xla/mlir_hlo_to_hlo.h"
51 #include "tensorflow/compiler/mlir/xla/xla_mlir_translate_cl.h"
52 #include "tensorflow/compiler/xla/debug_options_flags.h"
53 #include "tensorflow/compiler/xla/service/backend.h"
54 #include "tensorflow/compiler/xla/service/buffer_assignment.h"
55 #include "tensorflow/compiler/xla/service/gpu/backend_configs.pb.h"
56 #include "tensorflow/compiler/xla/service/gpu/ir_emission_utils.h"
57 #include "tensorflow/compiler/xla/service/hlo_casting_utils.h"
58 #include "tensorflow/compiler/xla/service/hlo_computation.h"
59 #include "tensorflow/compiler/xla/service/hlo_instruction.h"
60 #include "tensorflow/compiler/xla/service/hlo_instructions.h"
61 #include "tensorflow/compiler/xla/service/hlo_module.h"
62 #include "tensorflow/compiler/xla/service/hlo_parser.h"
63 #include "tensorflow/compiler/xla/service/llvm_ir/buffer_assignment_util.h"
64 #include "tensorflow/compiler/xla/shape_util.h"
65 #include "tensorflow/compiler/xla/statusor.h"
66 #include "tensorflow/compiler/xla/util.h"
67 #include "tensorflow/compiler/xla/window_util.h"
68 #include "tensorflow/compiler/xla/xla_data.pb.h"
69 
70 using xla::BufferAllocation;
71 using xla::BufferAssignment;
72 using xla::HloComputation;
73 using xla::HloCustomCallInstruction;
74 using xla::HloInfeedInstruction;
75 using xla::HloInstruction;
76 using xla::HloModule;
77 using xla::HloModuleProto;
78 using xla::HloOutfeedInstruction;
79 using xla::HloProto;
80 using xla::Shape;
81 using xla::StatusOr;
82 
83 namespace mlir {
84 namespace {
85 
StringRefToView(llvm::StringRef ref)86 absl::string_view StringRefToView(llvm::StringRef ref) {
87   return {ref.data(), ref.size()};
88 }
89 
HloModuleFromProto(const HloProto & hlo_proto)90 StatusOr<std::unique_ptr<HloModule>> HloModuleFromProto(
91     const HloProto& hlo_proto) {
92   const HloModuleProto& module_proto = hlo_proto.hlo_module();
93   TF_ASSIGN_OR_RETURN(const xla::HloModuleConfig module_config,
94                       HloModule::CreateModuleConfigFromProto(
95                           module_proto, xla::GetDebugOptionsFromFlags()));
96   return HloModule::CreateFromProto(module_proto, module_config);
97 }
98 
99 // Convert the MLIR `module` from HLO dialect to LHLO dialect using XLA for the
100 // given platform.
ConvertModule(std::unique_ptr<HloModule> hlo_module,ModuleOp module,StringRef platform_name)101 Status ConvertModule(std::unique_ptr<HloModule> hlo_module, ModuleOp module,
102                      StringRef platform_name) {
103   auto platform = xla::se::MultiPlatformManager::PlatformWithName(
104       StringRefToView(platform_name));
105   if (!platform.ok()) {
106     std::string error_msg;
107     llvm::raw_string_ostream os(error_msg);
108     os << "failed to get platform: " << platform.status().ToString()
109        << " (available Platform: ";
110     std::vector<std::string> available_platforms;
111     (void)xla::se::MultiPlatformManager::PlatformsWithFilter(
112         [&](const stream_executor::Platform* p) {
113           available_platforms.push_back(p->Name());
114           return false;
115         });
116     llvm::interleaveComma(available_platforms, os);
117     os << ")";
118     return xla::InvalidArgument("%s", os.str().c_str());
119   }
120 
121   xla::BackendOptions backend_options;
122   backend_options.set_platform(platform.ValueOrDie());
123   auto backend_or_err = xla::Backend::CreateBackend(backend_options);
124   TF_RETURN_WITH_CONTEXT_IF_ERROR(backend_or_err.status(),
125                                   "failed to create XLA Backend ");
126   auto backend = std::move(backend_or_err.ValueOrDie());
127 
128   // Run all HLO passes to produce an optimized module.
129   auto result_or = backend->compiler()->RunHloPassesAndBufferAssignement(
130       std::move(hlo_module), backend->default_stream_executor(),
131       optimize_xla_hlo, {backend->memory_allocator()});
132   TF_RETURN_WITH_CONTEXT_IF_ERROR(result_or.status(),
133                                   "running XLA pass pipeline");
134   std::unique_ptr<HloModule> optimized_hlo_module =
135       std::move(std::get<0>(result_or.ValueOrDie()));
136   std::unique_ptr<BufferAssignment> assignment =
137       std::move(std::get<1>(result_or.ValueOrDie()));
138 
139   // Clear the module before populating it back with the result of the
140   // conversion.
141   module.getBody()->clear();
142   OpBuilder builder(module);
143   module.ensureTerminator(module.getBodyRegion(), builder, module.getLoc());
144 
145   TF_RETURN_WITH_CONTEXT_IF_ERROR(
146       HloToLhloModule(*assignment, *optimized_hlo_module, module),
147       "converting HLO to LHLO");
148 
149   return Status::OK();
150 }
151 
152 // This pass takes an MLIR HLO module, converts it to XLA to perform the HLO
153 // optimization pipeline for the required platform, and then converts it back to
154 // MLIR LHLO.
155 class XlaHloToLhloPass
156     : public PassWrapper<XlaHloToLhloPass, OperationPass<ModuleOp>> {
getDependentDialects(DialectRegistry & registry) const157   void getDependentDialects(DialectRegistry& registry) const override {
158     registry
159         .insert<mlir::StandardOpsDialect, mlir::mhlo::MhloDialect,
160                 mlir::lmhlo::LmhloDialect, mlir::lmhlo_gpu::LmhloGpuDialect>();
161   }
162 
163  public:
164   XlaHloToLhloPass() = default;
XlaHloToLhloPass(const XlaHloToLhloPass &)165   XlaHloToLhloPass(const XlaHloToLhloPass&) {}
166 
167  private:
runOnOperation()168   void runOnOperation() final {
169     ModuleOp module = getOperation();
170 
171     auto status = [&module, this]() -> Status {
172       SymbolTable symbol_table(module);
173       if (!symbol_table.lookup("main")) {
174         return xla::InvalidArgument(
175             "conversion to HLO module failed: missing main()");
176       }
177       HloProto hlo_proto;
178       TF_RETURN_WITH_CONTEXT_IF_ERROR(
179           ConvertMlirHloToHlo(module, &hlo_proto,
180                               /*use_tuple_args=*/false,
181                               /*return_tuple=*/false,
182                               /*shape_representation_fn=*/nullptr),
183           "conversion to XLA HLO proto failed");
184 
185       auto statusOrHloModule = HloModuleFromProto(hlo_proto);
186       TF_RETURN_WITH_CONTEXT_IF_ERROR(statusOrHloModule.status(),
187                                       "parsing HLO proto to HLO module failed");
188       std::unique_ptr<HloModule> hlo_module =
189           std::move(statusOrHloModule.ValueOrDie());
190 
191       return ConvertModule(std::move(hlo_module), module, platform_);
192     }();
193     if (!status.ok()) {
194       module.emitError() << status.ToString();
195       return signalPassFailure();
196     }
197   }
198 
199   Option<std::string> platform_{
200       *this, "platform",
201       llvm::cl::desc("The platform to use for the XLA optimization pipeline."),
202       llvm::cl::init("Host")};
203 };
204 
205 }  // namespace
206 
207 // Creates MLIR operands corresponding to operands and results of the XLA HLO
208 // instruction. If `num_operands` is valid, then only the first `num_operands`
209 // operands of the HLO instruction will be considered.
CreateOperands(const HloInstruction * instr,absl::optional<xla::int64> num_operands,llvm::SmallVectorImpl<Value> & operands,size_t & num_arguments,size_t & num_results)210 Status LhloDialectEmitter::CreateOperands(
211     const HloInstruction* instr, absl::optional<xla::int64> num_operands,
212     llvm::SmallVectorImpl<Value>& operands, size_t& num_arguments,
213     size_t& num_results) {
214   if (num_operands.value_or(0) > instr->operand_count())
215     return xla::InvalidArgument("num_operands must be <= operand count");
216   for (xla::int64 i = 0; i < num_operands.value_or(instr->operand_count());
217        ++i) {
218     TF_RETURN_IF_ERROR(GetOrCreateView(instr->operand(i), &operands));
219   }
220   num_arguments = operands.size();
221   TF_RETURN_IF_ERROR(GetOrCreateView(instr, &operands));
222   num_results = operands.size() - num_arguments;
223   return Status::OK();
224 }
225 
226 template <typename OpType>
CreateOpWithoutAttrs(const HloInstruction * instr,ValueRange operands)227 OpType LhloDialectEmitter::CreateOpWithoutAttrs(const HloInstruction* instr,
228                                                 ValueRange operands) {
229   Location loc = getLocation(instr);
230   return builder_.create<OpType>(loc, llvm::None, operands,
231                                  llvm::ArrayRef<NamedAttribute>{});
232 }
233 
234 template <typename OpType>
CreateOpWithoutAttrs(const HloInstruction * instr,size_t & num_arguments,size_t & num_results,absl::optional<xla::int64> num_operands)235 StatusOr<OpType> LhloDialectEmitter::CreateOpWithoutAttrs(
236     const HloInstruction* instr, size_t& num_arguments, size_t& num_results,
237     absl::optional<xla::int64> num_operands) {
238   llvm::SmallVector<Value, 4> operands;
239   TF_RETURN_IF_ERROR(CreateOperands(instr, num_operands, operands,
240                                     num_arguments, num_results));
241   return CreateOpWithoutAttrs<OpType>(instr, operands);
242 }
243 
// Emits LHLO for a single HLO instruction `instr` through `builder_`.
//
// Opcodes routed to CreateOpWithoutAttrs become the matching lmhlo op with no
// attributes. Opcodes that need attributes, regions or special operand handling
// are dispatched to a dedicated Emit* method. Returns the emitted operation,
// nullptr for kBitcast (nothing is emitted), or an Internal error for any
// opcode not listed here.
StatusOr<mlir::Operation*> LhloDialectEmitter::EmitOp(
    const HloInstruction* instr) {
  using xla::HloOpcode;
  switch (instr->opcode()) {
    case HloOpcode::kAbs:
      return CreateOpWithoutAttrs<lmhlo::AbsOp>(instr);
    case HloOpcode::kAdd:
      return CreateOpWithoutAttrs<lmhlo::AddOp>(instr);
    case HloOpcode::kAllToAll:
      return EmitAllToAllOp(instr);
    case HloOpcode::kAllGather:
      return EmitAllGatherOp(instr);
    case HloOpcode::kAllReduce:
      return EmitAllReduceOp(instr);
    case HloOpcode::kAnd:
      return CreateOpWithoutAttrs<lmhlo::AndOp>(instr);
    case HloOpcode::kAtan2:
      return CreateOpWithoutAttrs<lmhlo::Atan2Op>(instr);
    case HloOpcode::kBitcast:
      // No LHLO op is emitted for a bitcast. NOTE(review): presumably it
      // reuses its operand's buffer — confirm.
      return nullptr;
    case HloOpcode::kBitcastConvert:
      return CreateOpWithoutAttrs<lmhlo::BitcastConvertOp>(instr);
    case HloOpcode::kBroadcast:
      return EmitBroadcastOp(instr);
    case HloOpcode::kCeil:
      return CreateOpWithoutAttrs<lmhlo::CeilOp>(instr);
    case HloOpcode::kCbrt:
      return CreateOpWithoutAttrs<lmhlo::CbrtOp>(instr);
    case HloOpcode::kClamp:
      return CreateOpWithoutAttrs<lmhlo::ClampOp>(instr);
    case HloOpcode::kCollectivePermute:
      return EmitCollectivePermuteOp(instr);
    case HloOpcode::kClz:
      return CreateOpWithoutAttrs<lmhlo::ClzOp>(instr);
    case HloOpcode::kCompare:
      return EmitCompareOp(instr);
    case HloOpcode::kComplex:
      return CreateOpWithoutAttrs<lmhlo::ComplexOp>(instr);
    case HloOpcode::kConcatenate:
      return EmitConcatenateOp(instr);
    case HloOpcode::kConvert:
      return CreateOpWithoutAttrs<lmhlo::ConvertOp>(instr);
    case HloOpcode::kCopy:
      return CreateOpWithoutAttrs<lmhlo::CopyOp>(instr);
    case HloOpcode::kCos:
      return CreateOpWithoutAttrs<lmhlo::CosOp>(instr);
    case HloOpcode::kDivide:
      return CreateOpWithoutAttrs<lmhlo::DivOp>(instr);
    case HloOpcode::kDot:
      return EmitDotOp(instr);
    case HloOpcode::kDynamicSlice:
      return EmitDynamicSliceOp(instr);
    case HloOpcode::kDynamicUpdateSlice:
      return CreateOpWithoutAttrs<lmhlo::DynamicUpdateSliceOp>(instr);
    case HloOpcode::kFft:
      return EmitFftOp(instr);
    case HloOpcode::kExp:
      return CreateOpWithoutAttrs<lmhlo::ExpOp>(instr);
    case HloOpcode::kExpm1:
      return CreateOpWithoutAttrs<lmhlo::Expm1Op>(instr);
    case HloOpcode::kFloor:
      return CreateOpWithoutAttrs<lmhlo::FloorOp>(instr);
    case HloOpcode::kGather:
      return EmitGatherOp(instr);
    case HloOpcode::kImag:
      return CreateOpWithoutAttrs<lmhlo::ImagOp>(instr);
    case HloOpcode::kInfeed:
      return EmitInfeedOp(instr);
    case HloOpcode::kIota:
      return EmitIotaOp(instr);
    case HloOpcode::kIsFinite:
      return CreateOpWithoutAttrs<lmhlo::IsFiniteOp>(instr);
    case HloOpcode::kLog:
      return CreateOpWithoutAttrs<lmhlo::LogOp>(instr);
    case HloOpcode::kLog1p:
      return CreateOpWithoutAttrs<lmhlo::Log1pOp>(instr);
    case HloOpcode::kMap:
      return EmitMapOp(instr);
    case HloOpcode::kMaximum:
      return CreateOpWithoutAttrs<lmhlo::MaxOp>(instr);
    case HloOpcode::kMinimum:
      return CreateOpWithoutAttrs<lmhlo::MinOp>(instr);
    case HloOpcode::kMultiply:
      return CreateOpWithoutAttrs<lmhlo::MulOp>(instr);
    case HloOpcode::kNegate:
      return CreateOpWithoutAttrs<lmhlo::NegOp>(instr);
    case HloOpcode::kNot:
      return CreateOpWithoutAttrs<lmhlo::NotOp>(instr);
    case HloOpcode::kOr:
      return CreateOpWithoutAttrs<lmhlo::OrOp>(instr);
    case HloOpcode::kOutfeed:
      return EmitOutfeedOp(instr);
    case HloOpcode::kPartitionId:
      return CreateOpWithoutAttrs<lmhlo::PartitionIdOp>(instr);
    case HloOpcode::kPad:
      return EmitPadOp(instr);
    case HloOpcode::kPopulationCount:
      return CreateOpWithoutAttrs<lmhlo::PopulationCountOp>(instr);
    case HloOpcode::kPower:
      return CreateOpWithoutAttrs<lmhlo::PowOp>(instr);
    case HloOpcode::kReal:
      return CreateOpWithoutAttrs<lmhlo::RealOp>(instr);
    case HloOpcode::kReshape:
      return CreateOpWithoutAttrs<lmhlo::ReshapeOp>(instr);
    case HloOpcode::kReducePrecision:
      return EmitReducePrecisionOp(instr);
    case HloOpcode::kReduceWindow:
      return EmitReduceWindowOp(instr);
    case HloOpcode::kRemainder:
      return CreateOpWithoutAttrs<lmhlo::RemOp>(instr);
    case HloOpcode::kReplicaId:
      return CreateOpWithoutAttrs<lmhlo::ReplicaIdOp>(instr);
    case HloOpcode::kReverse:
      return EmitReverseOp(instr);
    case HloOpcode::kRoundNearestAfz:
      return CreateOpWithoutAttrs<lmhlo::RoundOp>(instr);
    case HloOpcode::kRsqrt:
      return CreateOpWithoutAttrs<lmhlo::RsqrtOp>(instr);
    case HloOpcode::kSelect:
      return CreateOpWithoutAttrs<lmhlo::SelectOp>(instr);
    case HloOpcode::kShiftLeft:
      return CreateOpWithoutAttrs<lmhlo::ShiftLeftOp>(instr);
    case HloOpcode::kShiftRightLogical:
      return CreateOpWithoutAttrs<lmhlo::ShiftRightLogicalOp>(instr);
    case HloOpcode::kShiftRightArithmetic:
      return CreateOpWithoutAttrs<lmhlo::ShiftRightArithmeticOp>(instr);
    case HloOpcode::kSign:
      return CreateOpWithoutAttrs<lmhlo::SignOp>(instr);
    case HloOpcode::kSin:
      return CreateOpWithoutAttrs<lmhlo::SinOp>(instr);
    case HloOpcode::kSlice:
      return EmitSliceOp(instr);
    case HloOpcode::kSqrt:
      return CreateOpWithoutAttrs<lmhlo::SqrtOp>(instr);
    case HloOpcode::kSubtract:
      return CreateOpWithoutAttrs<lmhlo::SubOp>(instr);
    case HloOpcode::kTanh:
      return CreateOpWithoutAttrs<lmhlo::TanhOp>(instr);
    case HloOpcode::kTranspose:
      return EmitTransposeOp(instr);
    case HloOpcode::kTriangularSolve:
      return EmitTriangularSolveOp(instr);
    case HloOpcode::kXor:
      return CreateOpWithoutAttrs<lmhlo::XorOp>(instr);
    case HloOpcode::kSort:
      return EmitSortOp(instr);
    case HloOpcode::kFusion:
      return EmitFusionOp(instr);
    case HloOpcode::kScatter:
      return EmitScatterOp(instr);
    case HloOpcode::kSelectAndScatter:
      return EmitSelectAndScatterOp(instr);
    case HloOpcode::kCustomCall:
      return EmitCustomCallOp(instr);
    case HloOpcode::kConstant:
      return EmitConstant(instr);
    case HloOpcode::kReduce:
      return EmitReduceOp(instr);
    case HloOpcode::kRngGetAndUpdateState:
      return EmitRngGetAndUpdateStateOp(instr);
    default:
      // Unsupported opcode: dump the instruction to stderr (no trailing
      // newline is written) and fail with an Internal error naming the opcode.
      llvm::errs() << instr->ToString();
      return tensorflow::errors::Internal(
          absl::StrCat("LHLO opcode ", xla::HloOpcodeString(instr->opcode()),
                       " is not supported."));
  }
}
411 
DefaultAction(const HloInstruction * instr)412 Status LhloDialectEmitter::DefaultAction(const HloInstruction* instr) {
413   return EmitOp(instr).status();
414 }
415 
EmitSortOp(const HloInstruction * instr)416 StatusOr<lmhlo::SortOp> LhloDialectEmitter::EmitSortOp(
417     const HloInstruction* instr) {
418   TF_ASSIGN_OR_RETURN(auto sort, CreateOpWithoutAttrs<lmhlo::SortOp>(instr));
419   auto* sort_instr = xla::Cast<xla::HloSortInstruction>(instr);
420   sort.dimensionAttr(builder_.getI64IntegerAttr(sort_instr->sort_dimension()));
421   sort.is_stableAttr(builder_.getBoolAttr(sort_instr->is_stable()));
422   TF_RETURN_IF_ERROR(xla::HloFunctionImporter::ImportAsRegion(
423       *sort_instr->called_computations()[0], &sort.comparator(), &builder_));
424   return sort;
425 }
426 
427 // Walks MHLO::TupleOp recursively.
WalkTuplePostOrder(Value v,const std::function<Status (Value)> & visitor)428 Status WalkTuplePostOrder(Value v,
429                           const std::function<Status(Value)>& visitor) {
430   if (auto* op = v.getDefiningOp()) {
431     if (auto tuple = dyn_cast<mhlo::TupleOp>(op)) {
432       for (Value sub_v : tuple.val()) {
433         TF_RETURN_IF_ERROR(WalkTuplePostOrder(sub_v, visitor));
434       }
435       return Status::OK();
436     }
437   }
438   return visitor(v);
439 }
440 
441 // This function removes all uses of a fused region argument, and rewire those
442 // uses to a `tensor_load %memref`, where %memref is caller argument.
443 //
444 // It also flattens all input/output tuples into more region arguments /
445 // results.
RewriteFusionOperand(const HloInstruction * root,const Shape & shape,xla::ShapeIndex * shape_index,OpBuilder * b,Location loc)446 StatusOr<Value> LhloDialectEmitter::RewriteFusionOperand(
447     const HloInstruction* root, const Shape& shape,
448     xla::ShapeIndex* shape_index, OpBuilder* b, Location loc) {
449   if (shape.IsTuple()) {
450     llvm::SmallVector<Value, 4> values;
451     for (int i = 0; i < shape.tuple_shapes_size(); ++i) {
452       shape_index->push_back(i);
453       TF_ASSIGN_OR_RETURN(
454           auto v, RewriteFusionOperand(root, shape.tuple_shapes(i), shape_index,
455                                        b, loc));
456       values.push_back(v);
457       shape_index->pop_back();
458     }
459     return Value(b->create<mhlo::TupleOp>(loc, values));
460   }
461   TF_ASSIGN_OR_RETURN(Value memref,
462                       GetOrCreateArrayView(root, shape, *shape_index));
463   auto load = b->create<TensorLoadOp>(loc, memref);
464   if (shape.layout() !=
465       xla::LayoutUtil::MakeDescendingLayout(shape.dimensions().size())) {
466     llvm::SmallVector<int64_t, 4> minor_to_major(
467         shape.layout().minor_to_major().begin(),
468         shape.layout().minor_to_major().end());
469     load->setAttr("minor_to_major", GetLayoutAttribute(shape.layout(), b));
470   }
471   return load.getResult();
472 }
473 
EmitFusionOp(const HloInstruction * instr)474 StatusOr<lmhlo::FusionOp> LhloDialectEmitter::EmitFusionOp(
475     const HloInstruction* instr) {
476   Location loc = getLocation(instr);
477 
478   auto* fusion_instr = xla::Cast<xla::HloFusionInstruction>(instr);
479 
480   auto fusion = builder_.create<lmhlo::FusionOp>(getLocation(instr));
481   auto after_fusion = builder_.saveInsertionPoint();
482   builder_ = mlir::OpBuilder(fusion);
483 
484   auto region_builder = OpBuilder::atBlockBegin(&fusion.region().front());
485 
486   llvm::SmallVector<Value, 8> arguments;
487   for (int i = 0; i < instr->operands().size(); ++i) {
488     const HloInstruction* operand = instr->operand(i);
489     xla::ShapeIndex shape_index;
490     TF_ASSIGN_OR_RETURN(
491         auto arg, RewriteFusionOperand(operand, operand->shape(), &shape_index,
492                                        &region_builder, loc));
493     arguments.push_back(arg);
494   }
495 
496   TF_ASSIGN_OR_RETURN(Value result,
497                       xla::HloFunctionImporter::ImportInstructions(
498                           *fusion_instr->fused_instructions_computation(),
499                           arguments, &region_builder));
500 
501   {
502     int i = 0;
503     llvm::SmallVector<Value, 4> output;
504     TF_RETURN_IF_ERROR(GetOrCreateView(instr, &output));
505     TF_RETURN_IF_ERROR(WalkTuplePostOrder(result, [&](Value v) mutable {
506       region_builder.create<TensorStoreOp>(loc, v, output[i++]);
507       return Status::OK();
508     }));
509     if (i != output.size()) {
510       return xla::InternalError("output sizes don't match");
511     }
512   }
513 
514   // Fold GTE/Tuple pairs.
515   //
516   // Since the fused region refers to values in its parent region, we can't
517   // call applyPatternAndFoldGreedily. We optimize it manually.
518   //
519   // Only walk once, because post-ordering is exactly what we need for GTE
520   // optimizations.
521   fusion.region().walk([](mhlo::GetTupleElementOp gte) {
522     SmallVector<Value, 4> folded_values;
523     if (succeeded(OpBuilder(gte).tryFold(gte, folded_values))) {
524       gte.replaceAllUsesWith(folded_values[0]);
525     }
526   });
527 
528   // Effectively a DCE on the region.
529   {
530     llvm::SmallVector<mlir::Operation*, 4> ops;
531     fusion.region().walk([&](mlir::Operation* op) { ops.push_back(op); });
532     // Visit the user first.
533     std::reverse(ops.begin(), ops.end());
534     for (auto op : ops) {
535       if (isOpTriviallyDead(op)) op->erase();
536     }
537   }
538 
539   builder_.restoreInsertionPoint(after_fusion);
540   return fusion;
541 }
542 
543 StatusOr<mhlo::ScatterDimensionNumbers>
GetScatterDimensionNumbers(const HloInstruction * instr)544 LhloDialectEmitter::GetScatterDimensionNumbers(const HloInstruction* instr) {
545   auto* scatter_instr = xla::Cast<xla::HloScatterInstruction>(instr);
546 
547   const xla::ScatterDimensionNumbers& xla_scatter_dim =
548       scatter_instr->scatter_dimension_numbers();
549   auto scatter_dimension_numbers = mhlo::ScatterDimensionNumbers::get(
550       GetI64DenseElementsAttr(xla_scatter_dim.update_window_dims()),
551       GetI64DenseElementsAttr(xla_scatter_dim.inserted_window_dims()),
552       GetI64DenseElementsAttr(xla_scatter_dim.scatter_dims_to_operand_dims()),
553       builder_.getI64IntegerAttr(xla_scatter_dim.index_vector_dim()),
554       module_.getContext());
555   return scatter_dimension_numbers;
556 }
557 
EmitScatterOp(const HloInstruction * instr)558 StatusOr<lmhlo::ScatterOp> LhloDialectEmitter::EmitScatterOp(
559     const HloInstruction* instr) {
560   TF_ASSIGN_OR_RETURN(auto scatter,
561                       CreateOpWithoutAttrs<lmhlo::ScatterOp>(instr));
562 
563   // copy attributes
564   auto* scatter_instr = xla::Cast<xla::HloScatterInstruction>(instr);
565 
566   TF_ASSIGN_OR_RETURN(auto scatter_dimension_numbers,
567                       GetScatterDimensionNumbers(instr));
568   scatter.scatter_dimension_numbersAttr(scatter_dimension_numbers);
569   scatter.indices_are_sortedAttr(
570       builder_.getBoolAttr(scatter_instr->indices_are_sorted()));
571   scatter.unique_indicesAttr(
572       builder_.getBoolAttr(scatter_instr->unique_indices()));
573 
574   // import update computation as region
575   TF_RETURN_IF_ERROR(xla::HloFunctionImporter::ImportAsRegion(
576       *scatter_instr->called_computations()[0], &scatter.update_computation(),
577       &builder_));
578 
579   return scatter;
580 }
581 
EmitSelectAndScatterOp(const HloInstruction * instr)582 StatusOr<lmhlo::SelectAndScatterOp> LhloDialectEmitter::EmitSelectAndScatterOp(
583     const HloInstruction* instr) {
584   TF_ASSIGN_OR_RETURN(auto select_and_scatter,
585                       CreateOpWithoutAttrs<lmhlo::SelectAndScatterOp>(instr));
586 
587   // copy attributes
588   auto* select_and_scatter_instr =
589       xla::Cast<xla::HloSelectAndScatterInstruction>(instr);
590   const xla::Window& window = select_and_scatter_instr->window();
591 
592   select_and_scatter.window_dimensionsAttr(
593       GetWindowElements(window, [](const xla::WindowDimension& dim) {
594         return static_cast<int64_t>(dim.size());
595       }));
596   select_and_scatter.window_stridesAttr(
597       GetWindowElements(window, [](const xla::WindowDimension& dim) {
598         return static_cast<int64_t>(dim.stride());
599       }));
600   select_and_scatter.paddingAttr(
601       GetWindowElements(window, [](const xla::WindowDimension& dim) {
602         return static_cast<int64_t>(dim.padding_low());
603       }));
604 
605   // import select and scatter computation as region
606   TF_RETURN_IF_ERROR(xla::HloFunctionImporter::ImportAsRegion(
607       *select_and_scatter_instr->select(), &select_and_scatter.select(),
608       &builder_));
609   TF_RETURN_IF_ERROR(xla::HloFunctionImporter::ImportAsRegion(
610       *select_and_scatter_instr->scatter(), &select_and_scatter.scatter(),
611       &builder_));
612   return select_and_scatter;
613 }
614 
EmitCustomCallOp(const HloInstruction * instr)615 StatusOr<mlir::Operation*> LhloDialectEmitter::EmitCustomCallOp(
616     const HloInstruction* instr) {
617   auto* custom_call_instr = xla::Cast<xla::HloCustomCallInstruction>(instr);
618 
619   if (xla::gpu::IsCustomCallToCusolver(*instr)) {
620     return EmitCholesky(custom_call_instr);
621   }
622 
623   if (xla::gpu::IsCublasGemm(*instr)) {
624     return EmitGemm(custom_call_instr);
625   }
626 
627   if (xla::gpu::IsCustomCallToDnnConvolution(*instr)) {
628     return EmitDnnConvolution(custom_call_instr);
629   }
630 
631   if (xla::gpu::IsCustomCallToDnnBatchNorm(*instr)) {
632     return EmitDnnBatchNorm(custom_call_instr);
633   }
634 
635   size_t num_arguments, num_results;
636   TF_ASSIGN_OR_RETURN(auto custom_call,
637                       CreateOpWithoutAttrs<lmhlo::CustomCallOp>(
638                           instr, num_arguments, num_results));
639   custom_call.call_target_nameAttr(
640       builder_.getStringAttr(custom_call_instr->custom_call_target()));
641   custom_call.backend_configAttr(
642       builder_.getStringAttr(custom_call_instr->opaque()));
643   const int32_t segments[2] = {static_cast<int32_t>(num_arguments),
644                                static_cast<int32_t>(num_results)};
645   custom_call->setAttr(lmhlo::CustomCallOp::getOperandSegmentSizeAttr(),
646                        builder_.getI32VectorAttr(segments));
647   return custom_call.getOperation();
648 }
649 
EmitCholesky(const HloCustomCallInstruction * custom_call)650 StatusOr<lmhlo_gpu::CholeskyOp> LhloDialectEmitter::EmitCholesky(
651     const HloCustomCallInstruction* custom_call) {
652   TF_ASSIGN_OR_RETURN(auto cholesky_op,
653                       CreateOpWithoutAttrs<lmhlo_gpu::CholeskyOp>(custom_call));
654   TF_ASSIGN_OR_RETURN(xla::CholeskyOptions options,
655                       custom_call->backend_config<xla::CholeskyOptions>());
656   cholesky_op.is_lowerAttr(builder_.getBoolAttr(options.lower()));
657   return cholesky_op;
658 }
659 
EmitGemm(const HloCustomCallInstruction * custom_call)660 StatusOr<Operation*> LhloDialectEmitter::EmitGemm(
661     const HloCustomCallInstruction* custom_call) {
662   TF_ASSIGN_OR_RETURN(
663       auto const config,
664       custom_call->backend_config<xla::gpu::GemmBackendConfig>());
665 
666   auto set_common_attributes = [&](auto op) -> Operation* {
667     auto hlo_dims = config.dot_dimension_numbers();
668     auto mlir_dims = mhlo::DotDimensionNumbers::get(
669         GetI64DenseElementsAttr(hlo_dims.lhs_batch_dimensions()),
670         GetI64DenseElementsAttr(hlo_dims.rhs_batch_dimensions()),
671         GetI64DenseElementsAttr(hlo_dims.lhs_contracting_dimensions()),
672         GetI64DenseElementsAttr(hlo_dims.rhs_contracting_dimensions()),
673         builder_.getContext());
674     op.dot_dimension_numbersAttr(mlir_dims);
675     op.alpha_realAttr(builder_.getF64FloatAttr(config.alpha_real()));
676     op.alpha_imagAttr(builder_.getF64FloatAttr(config.alpha_imag()));
677     op.batch_sizeAttr(builder_.getI64IntegerAttr(config.batch_size()));
678     if (config.algorithm_case() ==
679         xla::gpu::GemmBackendConfig::kSelectedAlgorithm) {
680       op.algorithmAttr(builder_.getI64IntegerAttr(config.selected_algorithm()));
681     }
682     return op.getOperation();
683   };
684 
685   if (custom_call->operand_count() == 2) {
686     TF_ASSIGN_OR_RETURN(auto gemm,
687                         CreateOpWithoutAttrs<lmhlo_gpu::GEMMOp>(custom_call));
688     return set_common_attributes(gemm);
689   }
690 
691   if (custom_call->operand_count() == 3) {
692     TF_ASSIGN_OR_RETURN(
693         auto gemm_bias,
694         CreateOpWithoutAttrs<lmhlo_gpu::GEMM_BiasOp>(custom_call));
695     gemm_bias.betaAttr(builder_.getF64FloatAttr(config.beta()));
696     return set_common_attributes(gemm_bias);
697   }
698 
699   return xla::InvalidArgument("GEMM custom call should have 2 or 3 operands");
700 }
701 
GetLHLOActivation(stream_executor::dnn::ActivationMode activation)702 static StatusOr<mlir::lmhlo_gpu::Activation> GetLHLOActivation(
703     stream_executor::dnn::ActivationMode activation) {
704   switch (activation) {
705     case stream_executor::dnn::kNone:
706       return mlir::lmhlo_gpu::Activation::None;
707     case stream_executor::dnn::kSigmoid:
708       return mlir::lmhlo_gpu::Activation::Sigmoid;
709     case stream_executor::dnn::kRelu:
710       return mlir::lmhlo_gpu::Activation::Relu;
711     case stream_executor::dnn::kRelu6:
712       return mlir::lmhlo_gpu::Activation::Relu6;
713     case stream_executor::dnn::kReluX:
714       return mlir::lmhlo_gpu::Activation::ReluX;
715     case stream_executor::dnn::kTanh:
716       return mlir::lmhlo_gpu::Activation::Tanh;
717     case stream_executor::dnn::kBandPass:
718       return mlir::lmhlo_gpu::Activation::BandPass;
719     default:
720       return xla::InternalError("Unknown activation");
721   }
722 }
723 
EmitDnnConvolution(const HloCustomCallInstruction * custom_call)724 StatusOr<Operation*> LhloDialectEmitter::EmitDnnConvolution(
725     const HloCustomCallInstruction* custom_call) {
726   TF_ASSIGN_OR_RETURN(
727       auto const backend_config,
728       custom_call->backend_config<xla::gpu::CudnnConvBackendConfig>());
729 
730   TF_ASSIGN_OR_RETURN(const xla::gpu::CudnnConvKind kind,
731                       xla::gpu::GetCudnnConvKind(custom_call));
732 
733   auto get_layout_attribute = [&](const xla::Layout& layout) {
734     std::vector<int64_t> minor_to_major(layout.minor_to_major_size());
735     absl::c_transform(layout.minor_to_major(), minor_to_major.begin(),
736                       [](xla::int64 x) { return static_cast<int64_t>(x); });
737     return builder_.getI64ArrayAttr(minor_to_major);
738   };
739 
740   auto set_common_conv_attributes = [&, this](auto op) -> Operation* {
741     const xla::Window& window = custom_call->window();
742     // Window size for Cudnn Conv is same as the kernel size.
743     op.window_stridesAttr(
744         GetWindowElements(window, [](const xla::WindowDimension& dim) {
745           return static_cast<int64_t>(dim.stride());
746         }));
747     // Cudnn Conv requires low and high padding to be equal.
748     op.paddingAttr(
749         GetWindowElements(window, [](const xla::WindowDimension& dim) {
750           return static_cast<int64_t>(dim.padding_low());
751         }));
752     // LHS dilation is encoded in base_dilation of the backend config.
753     // RHS dilation is encoded in window_dilation of the backend config.
754     op.lhs_dilationAttr(
755         GetWindowElements(window, [](const xla::WindowDimension& dim) {
756           return static_cast<int64_t>(dim.base_dilation());
757         }));
758     op.rhs_dilationAttr(
759         GetWindowElements(window, [](const xla::WindowDimension& dim) {
760           return static_cast<int64_t>(dim.window_dilation());
761         }));
762     // Setup window reversal.
763     auto window_reversal = llvm::to_vector<4>(llvm::map_range(
764         window.dimensions(),
765         [](const xla::WindowDimension& dim) { return dim.window_reversal(); }));
766     auto type = RankedTensorType::get(op.window_strides()->getType().getShape(),
767                                       builder_.getIntegerType(/*width=*/1));
768     op.window_reversalAttr(DenseElementsAttr::get(type, window_reversal));
769 
770     op.dimension_numbersAttr(xla::ConvertConvDimensionNumbers(
771         custom_call->convolution_dimension_numbers(), &builder_));
772     op.feature_group_countAttr(
773         builder_.getI64IntegerAttr(custom_call->feature_group_count()));
774     op.batch_group_countAttr(
775         builder_.getI64IntegerAttr(custom_call->batch_group_count()));
776     op.precision_configAttr(xla::ConvertPrecisionConfig(
777         &custom_call->precision_config(), &builder_));
778     op.result_scaleAttr(
779         builder_.getF64FloatAttr(backend_config.conv_result_scale()));
780     auto config = mlir::lmhlo_gpu::ConvolutionBackendConfig::get(
781         builder_.getI64IntegerAttr(backend_config.algorithm()),
782         builder_.getBoolAttr(backend_config.tensor_ops_enabled()),
783         get_layout_attribute(custom_call->operand(0)->shape().layout()),
784         get_layout_attribute(custom_call->operand(1)->shape().layout()),
785         get_layout_attribute(custom_call->shape().tuple_shapes(0).layout()),
786         builder_.getContext());
787     op.backend_configAttr(config);
788 
789     return op.getOperation();
790   };
791 
792   auto set_activation = [&, this](auto op) -> Status {
793     auto se_activation = static_cast<stream_executor::dnn::ActivationMode>(
794         backend_config.activation_mode());
795     TF_ASSIGN_OR_RETURN(mlir::lmhlo_gpu::Activation activation,
796                         GetLHLOActivation(se_activation));
797     StringAttr activation_attr = builder_.getStringAttr(
798         mlir::lmhlo_gpu::stringifyActivation(activation));
799     op.activation_modeAttr(activation_attr);
800     return Status::OK();
801   };
802 
803   switch (kind) {
804     case xla::gpu::CudnnConvKind::kForward: {
805       TF_ASSIGN_OR_RETURN(
806           auto cnn_forward,
807           CreateOpWithoutAttrs<lmhlo_gpu::ConvForwardOp>(custom_call));
808       return set_common_conv_attributes(cnn_forward);
809     }
810     case xla::gpu::CudnnConvKind::kBackwardInput: {
811       TF_ASSIGN_OR_RETURN(
812           auto cnn_backward,
813           CreateOpWithoutAttrs<lmhlo_gpu::ConvBackwardInputOp>(custom_call));
814       return set_common_conv_attributes(cnn_backward);
815     }
816     case xla::gpu::CudnnConvKind::kBackwardFilter: {
817       TF_ASSIGN_OR_RETURN(
818           auto cnn_backward,
819           CreateOpWithoutAttrs<lmhlo_gpu::ConvBackwardFilterOp>(custom_call));
820       return set_common_conv_attributes(cnn_backward);
821     }
822     case xla::gpu::CudnnConvKind::kForwardActivation: {
823       // Fused conv can be either with side input or without.
824       if (custom_call->operand_count() == 3) {
825         TF_ASSIGN_OR_RETURN(
826             auto cnn_fused,
827             CreateOpWithoutAttrs<lmhlo_gpu::ConvForwardFusedOp>(custom_call));
828         TF_RETURN_IF_ERROR(set_activation(cnn_fused));
829         return set_common_conv_attributes(cnn_fused);
830       }
831 
832       TF_RET_CHECK(custom_call->operand_count() == 4);
833       TF_ASSIGN_OR_RETURN(
834           auto cnn_fused_side_input,
835           CreateOpWithoutAttrs<lmhlo_gpu::ConvForwardFusedSideInputOp>(
836               custom_call));
837       cnn_fused_side_input.side_input_scaleAttr(
838           builder_.getF64FloatAttr(backend_config.side_input_scale()));
839       TF_RETURN_IF_ERROR(set_activation(cnn_fused_side_input));
840       return set_common_conv_attributes(cnn_fused_side_input);
841     }
842   }
843 }
844 
EmitDnnBatchNorm(const HloCustomCallInstruction * custom_call)845 StatusOr<Operation*> LhloDialectEmitter::EmitDnnBatchNorm(
846     const HloCustomCallInstruction* custom_call) {
847   const xla::int64 num_operands = custom_call->operand_count();
848   auto set_batchnorm_attributes = [&](auto op) -> StatusOr<Operation*> {
849     // The last 2 operands of a custom call for batch norm are the epsilon and
850     // feature_index.
851     const HloInstruction* epsilon = custom_call->operand(num_operands - 2);
852     TF_RET_CHECK(epsilon->IsConstant());
853     float epsilon_value = epsilon->literal().Get<float>({});
854 
855     const HloInstruction* feature_index =
856         custom_call->operand(num_operands - 1);
857     TF_RET_CHECK(feature_index->IsConstant());
858     xla::int64 feature_index_value =
859         feature_index->literal().Get<xla::int64>({});
860 
861     op.epsilonAttr(builder_.getF32FloatAttr(epsilon_value));
862     op.feature_indexAttr(builder_.getI64IntegerAttr(feature_index_value));
863     return op.getOperation();
864   };
865 
866   const std::string& target = custom_call->custom_call_target();
867   if (target == xla::gpu::kCudnnBatchNormForwardTrainingCallTarget) {
868     TF_ASSIGN_OR_RETURN(auto fwd_training,
869                         CreateOpWithoutAttrs<lmhlo_gpu::BatchNormTrainingOp>(
870                             custom_call, num_operands - 2));
871     return set_batchnorm_attributes(fwd_training);
872   }
873 
874   if (target == xla::gpu::kCudnnBatchNormBackwardCallTarget) {
875     TF_ASSIGN_OR_RETURN(auto backward,
876                         CreateOpWithoutAttrs<lmhlo_gpu::BatchNormGradOp>(
877                             custom_call, num_operands - 2));
878     return set_batchnorm_attributes(backward);
879   }
880 
881   if (target == xla::gpu::kCudnnBatchNormForwardInferenceCallTarget) {
882     TF_ASSIGN_OR_RETURN(auto fwd_inference,
883                         CreateOpWithoutAttrs<lmhlo_gpu::BatchNormInferenceOp>(
884                             custom_call, num_operands - 2));
885     return set_batchnorm_attributes(fwd_inference);
886   }
887 
888   return xla::Unimplemented("Unsupported batch norm operation");
889 }
890 
891 // Convert an XLA HLO constant to a global_memref + get_global_memref pair.
EmitConstant(const HloInstruction * instr)892 StatusOr<mlir::GetGlobalMemrefOp> LhloDialectEmitter::EmitConstant(
893     const HloInstruction* instr) {
894   // Insert a global_memref in the module.
895   Location loc = getLocation(instr);
896 
897   auto const_instr = xla::Cast<xla::HloConstantInstruction>(instr);
898   TF_RET_CHECK(const_instr->shape().IsArray() &&
899                const_instr->shape().is_static());
900   TF_ASSIGN_OR_RETURN(Type type, xla::ConvertShapeToType<MemRefType>(
901                                      const_instr->shape(), builder_));
902   auto memref_type = type.dyn_cast<MemRefType>();
903   TF_RET_CHECK(memref_type != nullptr);
904 
905   TF_ASSIGN_OR_RETURN(
906       DenseElementsAttr initial_value,
907       CreateDenseElementsAttrFromLiteral(const_instr->literal(), builder_));
908 
909   std::string constant_name = xla::llvm_ir::ConstantNameToGlobalName(
910       xla::llvm_ir::SanitizeConstantName(instr->name()));
911 
912   // Insert the global memref at the top level.
913   {
914     OpBuilder::InsertionGuard guard(builder_);
915     builder_.clearInsertionPoint();
916     auto global_var = builder_.create<GlobalMemrefOp>(
917         loc, constant_name, builder_.getStringAttr("private"),
918         TypeAttr::get(memref_type), initial_value, true);
919     SymbolTable(module_).insert(global_var);
920     global_var.getOperation()->moveBefore(&module_.front());
921 
922     // For operations that do not fold this constant value in their codegen, we
923     // still need to materialize it into a buffer. Since buffer allocation is
924     // already done, annotate the global_memref with the information to get to
925     // the allocated buffer slice for this constant if need be.
926     TF_ASSIGN_OR_RETURN(BufferAllocation::Slice slice,
927                         assignment_.GetUniqueTopLevelSlice(instr));
928     global_var->setAttr("lmhlo.alloc", builder_.getIndexAttr(slice.index()));
929     TF_RET_CHECK(slice.offset() == 0)
930         << "Each constant should have its own allocation from BufferAssignment";
931     TF_RET_CHECK(slice.allocation()->size() == slice.size())
932         << "Each constant should have its own allocation from BufferAssignment";
933   }
934 
935   auto get_global_memref =
936       builder_.create<GetGlobalMemrefOp>(loc, memref_type, constant_name);
937 
938   // Update the cache to remember this value.
939   auto& cached_value = slices_[std::make_pair(instr, xla::ShapeIndex())];
940   TF_RET_CHECK(cached_value == nullptr);
941   cached_value = get_global_memref;
942   return get_global_memref;
943 }
944 
EmitReduceOp(const HloInstruction * instr)945 StatusOr<lmhlo::ReduceOp> LhloDialectEmitter::EmitReduceOp(
946     const HloInstruction* instr) {
947   TF_ASSIGN_OR_RETURN(auto reduce_op,
948                       CreateOpWithoutAttrs<lmhlo::ReduceOp>(instr));
949   auto* reduce = xla::Cast<xla::HloReduceInstruction>(instr);
950   std::vector<int64_t> dimensions(reduce->dimensions().begin(),
951                                   reduce->dimensions().end());
952   reduce_op.dimensionsAttr(GetI64DenseElementsAttr(dimensions));
953   TF_RETURN_IF_ERROR(xla::HloFunctionImporter::ImportAsRegion(
954       *instr->called_computations()[0], &reduce_op.body(), &builder_));
955   return reduce_op;
956 }
957 
EmitMapOp(const HloInstruction * instr)958 StatusOr<lmhlo::MapOp> LhloDialectEmitter::EmitMapOp(
959     const HloInstruction* instr) {
960   TF_ASSIGN_OR_RETURN(auto map_op, CreateOpWithoutAttrs<lmhlo::MapOp>(instr));
961   auto* map = xla::Cast<xla::HloMapInstruction>(instr);
962   std::vector<int64_t> dimensions(map->dimensions().begin(),
963                                   map->dimensions().end());
964   map_op.dimensionsAttr(GetI64DenseElementsAttr(dimensions));
965   TF_RETURN_IF_ERROR(xla::HloFunctionImporter::ImportAsRegion(
966       *instr->called_computations()[0], &map_op.computation(), &builder_));
967   return map_op;
968 }
969 
EmitCompareOp(const HloInstruction * instr)970 StatusOr<lmhlo::CompareOp> LhloDialectEmitter::EmitCompareOp(
971     const HloInstruction* instr) {
972   TF_ASSIGN_OR_RETURN(auto compare_op,
973                       CreateOpWithoutAttrs<lmhlo::CompareOp>(instr));
974 
975   auto* compare = xla::Cast<xla::HloCompareInstruction>(instr);
976   auto direction = [&]() {
977     switch (compare->direction()) {
978       case xla::ComparisonDirection::kEq:
979         return mhlo::ComparisonDirection::EQ;
980       case xla::ComparisonDirection::kNe:
981         return mhlo::ComparisonDirection::NE;
982       case xla::ComparisonDirection::kGe:
983         return mhlo::ComparisonDirection::GE;
984       case xla::ComparisonDirection::kGt:
985         return mhlo::ComparisonDirection::GT;
986       case xla::ComparisonDirection::kLe:
987         return mhlo::ComparisonDirection::LE;
988       case xla::ComparisonDirection::kLt:
989         return mhlo::ComparisonDirection::LT;
990     }
991   }();
992   compare_op.comparison_directionAttr(
993       builder_.getStringAttr(stringifyComparisonDirection(direction)));
994   auto compare_type = [&]() {
995     switch (compare->type()) {
996       case xla::Comparison::Type::kFloat:
997         return mhlo::ComparisonType::FLOAT;
998       case xla::Comparison::Type::kFloatTotalOrder:
999         return mhlo::ComparisonType::TOTALORDER;
1000       case xla::Comparison::Type::kSigned:
1001         return mhlo::ComparisonType::SIGNED;
1002       case xla::Comparison::Type::kUnsigned:
1003         return mhlo::ComparisonType::UNSIGNED;
1004     }
1005   }();
1006   compare_op.compare_typeAttr(
1007       builder_.getStringAttr(stringifyComparisonType(compare_type)));
1008   return compare_op;
1009 }
1010 
EmitReducePrecisionOp(const HloInstruction * instr)1011 StatusOr<lmhlo::ReducePrecisionOp> LhloDialectEmitter::EmitReducePrecisionOp(
1012     const HloInstruction* instr) {
1013   TF_ASSIGN_OR_RETURN(auto reduce_precision_op,
1014                       CreateOpWithoutAttrs<lmhlo::ReducePrecisionOp>(instr));
1015   auto* reduce_precision = xla::Cast<xla::HloReducePrecisionInstruction>(instr);
1016   reduce_precision_op.exponent_bitsAttr(
1017       builder_.getI32IntegerAttr(reduce_precision->exponent_bits()));
1018   reduce_precision_op.mantissa_bitsAttr(
1019       builder_.getI32IntegerAttr(reduce_precision->mantissa_bits()));
1020   return reduce_precision_op;
1021 }
1022 
1023 namespace {
1024 template <typename OpT>
SetupChannelIdAttribute(OpT op,const xla::HloChannelInstruction * instr,mlir::Builder builder)1025 void SetupChannelIdAttribute(OpT op, const xla::HloChannelInstruction* instr,
1026                              mlir::Builder builder) {
1027   if (instr->channel_id().has_value()) {
1028     op.channel_idAttr(mlir::mhlo::ChannelHandle::get(
1029         builder.getI64IntegerAttr(*instr->channel_id()),
1030         builder.getI64IntegerAttr(0), builder.getContext()));
1031   }
1032 }
1033 
1034 template <typename OpT>
SetupCommonCollectiveOpAttributes(OpT op,const HloInstruction * instr,mlir::OpBuilder & builder)1035 Status SetupCommonCollectiveOpAttributes(OpT op, const HloInstruction* instr,
1036                                          mlir::OpBuilder& builder) {
1037   auto* collective = xla::Cast<xla::HloCollectiveInstruction>(instr);
1038   auto replica_groups_attr = xla::HloFunctionImporter::ConvertReplicaGroups(
1039       collective->replica_groups(), &builder);
1040   op->setAttr(replica_groups_attr.first, replica_groups_attr.second);
1041   op.constrain_layoutAttr(builder.getBoolAttr(collective->constrain_layout()));
1042   SetupChannelIdAttribute(op, collective, builder);
1043   return Status::OK();
1044 }
1045 }  // namespace
1046 
EmitAllToAllOp(const HloInstruction * instr)1047 StatusOr<lmhlo::AllToAllOp> LhloDialectEmitter::EmitAllToAllOp(
1048     const HloInstruction* instr) {
1049   TF_ASSIGN_OR_RETURN(auto all_to_all_op,
1050                       CreateOpWithoutAttrs<lmhlo::AllToAllOp>(instr));
1051   auto* all_to_all = xla::Cast<xla::HloAllToAllInstruction>(instr);
1052   TF_RETURN_IF_ERROR(
1053       SetupCommonCollectiveOpAttributes(all_to_all_op, instr, builder_));
1054   if (all_to_all->split_dimension().has_value()) {
1055     all_to_all_op.split_dimensionAttr(
1056         builder_.getI64IntegerAttr(*all_to_all->split_dimension()));
1057   }
1058   return all_to_all_op;
1059 }
1060 
EmitAllGatherOp(const HloInstruction * instr)1061 StatusOr<lmhlo::AllGatherOp> LhloDialectEmitter::EmitAllGatherOp(
1062     const HloInstruction* instr) {
1063   TF_ASSIGN_OR_RETURN(auto all_gather_op,
1064                       CreateOpWithoutAttrs<lmhlo::AllGatherOp>(instr));
1065   auto* all_gather = xla::Cast<xla::HloAllGatherInstruction>(instr);
1066   TF_RETURN_IF_ERROR(
1067       SetupCommonCollectiveOpAttributes(all_gather_op, instr, builder_));
1068   all_gather_op.use_global_device_idsAttr(
1069       builder_.getBoolAttr(all_gather->use_global_device_ids()));
1070   all_gather_op.all_gather_dimensionAttr(
1071       builder_.getI64IntegerAttr(all_gather->all_gather_dimension()));
1072   return all_gather_op;
1073 }
1074 
EmitAllReduceOp(const HloInstruction * instr)1075 StatusOr<lmhlo::AllReduceOp> LhloDialectEmitter::EmitAllReduceOp(
1076     const HloInstruction* instr) {
1077   TF_ASSIGN_OR_RETURN(auto all_reduce_op,
1078                       CreateOpWithoutAttrs<lmhlo::AllReduceOp>(instr));
1079   auto* all_reduce = xla::Cast<xla::HloAllReduceInstruction>(instr);
1080   TF_RETURN_IF_ERROR(
1081       SetupCommonCollectiveOpAttributes(all_reduce_op, instr, builder_));
1082   all_reduce_op.use_global_device_idsAttr(
1083       builder_.getBoolAttr(all_reduce->use_global_device_ids()));
1084   TF_RETURN_IF_ERROR(xla::HloFunctionImporter::ImportAsRegion(
1085       *instr->called_computations()[0], &all_reduce_op.computation(),
1086       &builder_));
1087   return all_reduce_op;
1088 }
1089 
1090 StatusOr<lmhlo::CollectivePermuteOp>
EmitCollectivePermuteOp(const HloInstruction * instr)1091 LhloDialectEmitter::EmitCollectivePermuteOp(const HloInstruction* instr) {
1092   TF_ASSIGN_OR_RETURN(auto permute_op,
1093                       CreateOpWithoutAttrs<lmhlo::CollectivePermuteOp>(instr));
1094   auto* permute = xla::Cast<xla::HloCollectivePermuteInstruction>(instr);
1095   SetupChannelIdAttribute(permute_op, permute, builder_);
1096   mlir::NamedAttribute source_target_pairs_attr =
1097       xla::HloFunctionImporter::ConvertSourceTargetPairs(
1098           permute->source_target_pairs(), &builder_);
1099   permute_op->setAttr(source_target_pairs_attr.first,
1100                       source_target_pairs_attr.second);
1101   return permute_op;
1102 }
1103 
EmitInfeedOp(const HloInstruction * instr)1104 StatusOr<lmhlo::InfeedOp> LhloDialectEmitter::EmitInfeedOp(
1105     const HloInstruction* instr) {
1106   const HloInfeedInstruction* infeed = xla::Cast<HloInfeedInstruction>(instr);
1107   // HLO Infeed instruction has a single operand of token type and a tuple
1108   // with buffers and a token as its output. LMHLO Infeed operation does not
1109   // need the token operand or result, so drop it.
1110   SmallVector<Value, 2> operands;
1111   TF_RETURN_IF_ERROR(GetOrCreateView(instr, &operands, /*result_subset=*/{0}));
1112   auto infeed_op = CreateOpWithoutAttrs<lmhlo::InfeedOp>(instr, operands);
1113   infeed_op.configAttr(builder_.getStringAttr(infeed->infeed_config()));
1114   return infeed_op;
1115 }
1116 
EmitOutfeedOp(const HloInstruction * instr)1117 StatusOr<lmhlo::OutfeedOp> LhloDialectEmitter::EmitOutfeedOp(
1118     const HloInstruction* instr) {
1119   const HloOutfeedInstruction* outfeed =
1120       xla::Cast<HloOutfeedInstruction>(instr);
1121   // HLO outfeed instruction has 2 operands, the source and a token, and a
1122   // single token output. LMHLO Outfeed does not need the token operand and
1123   // result, do drop it.
1124   SmallVector<Value, 2> operands;
1125   TF_RETURN_IF_ERROR(GetOrCreateView(instr->operand(0), &operands));
1126   auto outfeed_op = CreateOpWithoutAttrs<lmhlo::OutfeedOp>(instr, operands);
1127   outfeed_op.configAttr(builder_.getStringAttr(outfeed->outfeed_config()));
1128   return outfeed_op;
1129 }
1130 
// Emits an lmhlo.broadcast_in_dim for `instr`, copying the HLO `dimensions()`
// verbatim into the broadcast_dimensions attribute.
xla::StatusOr<lmhlo::BroadcastInDimOp> LhloDialectEmitter::EmitBroadcastOp(
    const xla::HloInstruction* instr) {
  TF_ASSIGN_OR_RETURN(auto broadcast_op,
                      CreateOpWithoutAttrs<lmhlo::BroadcastInDimOp>(instr));
  auto dims_attr = builder_.getI64TensorAttr(instr->dimensions());
  broadcast_op.broadcast_dimensionsAttr(dims_attr);
  return broadcast_op;
}
1139 
// Emits an lmhlo.concatenate for `instr`, carrying over the concatenation
// dimension as an i64 attribute.
xla::StatusOr<lmhlo::ConcatenateOp> LhloDialectEmitter::EmitConcatenateOp(
    const xla::HloInstruction* instr) {
  TF_ASSIGN_OR_RETURN(auto concat_op,
                      CreateOpWithoutAttrs<lmhlo::ConcatenateOp>(instr));
  const auto* hlo_concat = xla::Cast<xla::HloConcatenateInstruction>(instr);
  const auto concat_dim = hlo_concat->concatenate_dimension();
  concat_op.dimensionAttr(builder_.getI64IntegerAttr(concat_dim));
  return concat_op;
}
1149 
// Emits an lmhlo.iota for `instr`, carrying over the iota dimension as an i64
// attribute.
xla::StatusOr<lmhlo::IotaOp> LhloDialectEmitter::EmitIotaOp(
    const xla::HloInstruction* instr) {
  TF_ASSIGN_OR_RETURN(auto iota_op, CreateOpWithoutAttrs<lmhlo::IotaOp>(instr));
  const auto* hlo_iota = xla::Cast<xla::HloIotaInstruction>(instr);
  const auto iota_dim = hlo_iota->iota_dimension();
  iota_op.iota_dimensionAttr(builder_.getI64IntegerAttr(iota_dim));
  return iota_op;
}
1158 
// Emits an lmhlo.reverse for `instr`. The reversed dimensions come from the HLO
// `dimensions()`.
xla::StatusOr<lmhlo::ReverseOp> LhloDialectEmitter::EmitReverseOp(
    const xla::HloInstruction* instr) {
  TF_ASSIGN_OR_RETURN(auto reverse_op,
                      CreateOpWithoutAttrs<lmhlo::ReverseOp>(instr));
  auto reversed_dims = builder_.getI64TensorAttr(instr->dimensions());
  reverse_op.dimensionsAttr(reversed_dims);
  return reverse_op;
}
1166 
// Emits an lmhlo.transpose for `instr`. The HLO `dimensions()` become the
// permutation attribute.
xla::StatusOr<lmhlo::TransposeOp> LhloDialectEmitter::EmitTransposeOp(
    const xla::HloInstruction* instr) {
  TF_ASSIGN_OR_RETURN(auto transpose_op,
                      CreateOpWithoutAttrs<lmhlo::TransposeOp>(instr));
  auto permutation = builder_.getI64TensorAttr(instr->dimensions());
  transpose_op.permutationAttr(permutation);
  return transpose_op;
}
1174 
// Emits an lmhlo.pad for `instr`. The HLO padding config is split into three
// per-dimension i64 tensors: edge-low, edge-high and interior padding.
xla::StatusOr<lmhlo::PadOp> LhloDialectEmitter::EmitPadOp(
    const xla::HloInstruction* instr) {
  TF_ASSIGN_OR_RETURN(auto pad_op, CreateOpWithoutAttrs<lmhlo::PadOp>(instr));
  const auto* hlo_pad = xla::Cast<xla::HloPadInstruction>(instr);
  const auto& pad_dims = hlo_pad->padding_config().dimensions();

  std::vector<xla::int64> edge_low;
  std::vector<xla::int64> edge_high;
  std::vector<xla::int64> interior;
  edge_low.reserve(pad_dims.size());
  edge_high.reserve(pad_dims.size());
  interior.reserve(pad_dims.size());
  for (const auto& pad_dim : pad_dims) {
    edge_low.push_back(pad_dim.edge_padding_low());
    edge_high.push_back(pad_dim.edge_padding_high());
    interior.push_back(pad_dim.interior_padding());
  }

  pad_op.edge_padding_lowAttr(builder_.getI64TensorAttr(edge_low));
  pad_op.edge_padding_highAttr(builder_.getI64TensorAttr(edge_high));
  pad_op.interior_paddingAttr(builder_.getI64TensorAttr(interior));
  return pad_op;
}
1190 
// Emits an lmhlo.reduce_window for `instr`. Window sizes are always recorded.
// Strides, base/window dilations and padding are attached only when the HLO
// window uses them, as reported by xla::window_util. Padding is stored as an
// (N x 2) i64 tensor of (low, high) pairs. The first called computation is
// imported as the op's body region.
xla::StatusOr<lmhlo::ReduceWindowOp> LhloDialectEmitter::EmitReduceWindowOp(
    const xla::HloInstruction* instr) {
  TF_ASSIGN_OR_RETURN(auto reduce_window_op,
                      CreateOpWithoutAttrs<lmhlo::ReduceWindowOp>(instr));
  const auto* hlo_reduce_window =
      xla::Cast<xla::HloReduceWindowInstruction>(instr);
  const xla::Window& window = hlo_reduce_window->window();

  std::vector<xla::int64> sizes;
  std::vector<xla::int64> strides;
  std::vector<xla::int64> base_dilations;
  std::vector<xla::int64> window_dilations;
  // Flattened as low_0, high_0, low_1, high_1, ... so the length is always
  // exactly twice the number of window dimensions.
  std::vector<xla::int64> padding_pairs;
  for (const xla::WindowDimension& win_dim : window.dimensions()) {
    sizes.push_back(win_dim.size());
    strides.push_back(win_dim.stride());
    base_dilations.push_back(win_dim.base_dilation());
    window_dilations.push_back(win_dim.window_dilation());
    padding_pairs.push_back(win_dim.padding_low());
    padding_pairs.push_back(win_dim.padding_high());
  }

  reduce_window_op.window_dimensionsAttr(builder_.getI64TensorAttr(sizes));
  if (xla::window_util::HasStride(window)) {
    reduce_window_op.window_stridesAttr(builder_.getI64TensorAttr(strides));
  }
  if (xla::window_util::HasBaseDilation(window)) {
    reduce_window_op.base_dilationsAttr(
        builder_.getI64TensorAttr(base_dilations));
  }
  if (xla::window_util::HasWindowDilation(window)) {
    reduce_window_op.window_dilationsAttr(
        builder_.getI64TensorAttr(window_dilations));
  }
  if (xla::window_util::HasPadding(window)) {
    const int64_t num_rows = static_cast<int64_t>(padding_pairs.size() / 2);
    auto padding_type =
        RankedTensorType::get({num_rows, 2}, builder_.getIntegerType(64));
    reduce_window_op.paddingAttr(
        DenseIntElementsAttr::get(padding_type, padding_pairs));
  }

  TF_RETURN_IF_ERROR(xla::HloFunctionImporter::ImportAsRegion(
      *hlo_reduce_window->called_computations()[0], &reduce_window_op.body(),
      &builder_));
  return reduce_window_op;
}
1229 
// Emits an lmhlo.slice for `instr`, copying the HLO slice starts, limits and
// strides into the corresponding i64 tensor attributes.
xla::StatusOr<lmhlo::SliceOp> LhloDialectEmitter::EmitSliceOp(
    const xla::HloInstruction* instr) {
  TF_ASSIGN_OR_RETURN(auto slice_op,
                      CreateOpWithoutAttrs<lmhlo::SliceOp>(instr));
  const auto* hlo_slice = xla::Cast<xla::HloSliceInstruction>(instr);
  auto starts = builder_.getI64TensorAttr(hlo_slice->slice_starts());
  slice_op.start_indicesAttr(starts);
  auto limits = builder_.getI64TensorAttr(hlo_slice->slice_limits());
  slice_op.limit_indicesAttr(limits);
  auto steps = builder_.getI64TensorAttr(hlo_slice->slice_strides());
  slice_op.stridesAttr(steps);
  return slice_op;
}
1239 
// Emits an lmhlo.gather for `instr`. It converts the HLO gather dimension
// numbers and copies the slice sizes into an i64 tensor attribute.
xla::StatusOr<lmhlo::GatherOp> LhloDialectEmitter::EmitGatherOp(
    const xla::HloInstruction* instr) {
  TF_ASSIGN_OR_RETURN(auto gather_op,
                      CreateOpWithoutAttrs<lmhlo::GatherOp>(instr));
  const auto* hlo_gather = xla::Cast<xla::HloGatherInstruction>(instr);
  gather_op.dimension_numbersAttr(xla::ConvertGatherDimensionNumbers(
      hlo_gather->gather_dimension_numbers(), &builder_));
  const auto& hlo_slice_sizes = hlo_gather->gather_slice_sizes();
  std::vector<int64_t> slice_sizes(hlo_slice_sizes.begin(),
                                   hlo_slice_sizes.end());
  gather_op.slice_sizesAttr(builder_.getI64TensorAttr(slice_sizes));
  return gather_op;
}
1252 
// Emits an lmhlo.dynamic_slice for `instr`, copying the HLO dynamic slice sizes
// into the slice_sizes attribute.
xla::StatusOr<lmhlo::DynamicSliceOp> LhloDialectEmitter::EmitDynamicSliceOp(
    const xla::HloInstruction* instr) {
  TF_ASSIGN_OR_RETURN(auto dyn_slice_op,
                      CreateOpWithoutAttrs<lmhlo::DynamicSliceOp>(instr));
  const auto* hlo_dyn_slice = xla::Cast<xla::HloDynamicSliceInstruction>(instr);
  auto sizes_attr =
      builder_.getI64TensorAttr(hlo_dyn_slice->dynamic_slice_sizes());
  dyn_slice_op.slice_sizesAttr(sizes_attr);
  return dyn_slice_op;
}
1262 
// Emits an lmhlo.dot for `instr`, converting the HLO dot dimension numbers and
// precision config into their mhlo attribute forms.
xla::StatusOr<lmhlo::DotOp> LhloDialectEmitter::EmitDotOp(
    const xla::HloInstruction* instr) {
  TF_ASSIGN_OR_RETURN(auto dot_op, CreateOpWithoutAttrs<lmhlo::DotOp>(instr));
  const auto* hlo_dot = xla::Cast<xla::HloDotInstruction>(instr);
  auto dnums_attr = xla::ConvertDotDimensionNumbers(
      hlo_dot->dot_dimension_numbers(), &builder_);
  auto precision_attr =
      xla::ConvertPrecisionConfig(&hlo_dot->precision_config(), &builder_);
  dot_op.dot_dimension_numbersAttr(dnums_attr);
  dot_op.precision_configAttr(precision_attr);
  return dot_op;
}
1273 
1274 xla::StatusOr<lmhlo::RngGetAndUpdateStateOp>
EmitRngGetAndUpdateStateOp(const xla::HloInstruction * instr)1275 LhloDialectEmitter::EmitRngGetAndUpdateStateOp(
1276     const xla::HloInstruction* instr) {
1277   TF_ASSIGN_OR_RETURN(
1278       auto rng, CreateOpWithoutAttrs<lmhlo::RngGetAndUpdateStateOp>(instr));
1279   auto hlo_rng = xla::Cast<xla::HloRngGetAndUpdateStateInstruction>(instr);
1280   rng.deltaAttr(builder_.getI64IntegerAttr(hlo_rng->delta()));
1281   return rng;
1282 }
1283 
EmitFftOp(const HloInstruction * instr)1284 xla::StatusOr<lmhlo::FftOp> LhloDialectEmitter::EmitFftOp(
1285     const HloInstruction* instr) {
1286   auto hlo_fft = xla::Cast<xla::HloFftInstruction>(instr);
1287   TF_ASSIGN_OR_RETURN(auto fft, CreateOpWithoutAttrs<lmhlo::FftOp>(instr));
1288   TF_ASSIGN_OR_RETURN(mlir::mhlo::FftType fft_type,
1289                       xla::ConvertFftType(hlo_fft->fft_type()));
1290   StringAttr fft_type_attr =
1291       builder_.getStringAttr(mlir::mhlo::stringifyFftType(fft_type));
1292   fft.fft_typeAttr(fft_type_attr);
1293   fft.fft_lengthAttr(GetI64DenseElementsAttr(instr->fft_length()));
1294   return fft;
1295 }
1296 
1297 xla::StatusOr<lmhlo::TriangularSolveOp>
EmitTriangularSolveOp(const xla::HloInstruction * instr)1298 LhloDialectEmitter::EmitTriangularSolveOp(const xla::HloInstruction* instr) {
1299   auto hlo_triangular_solve =
1300       xla::Cast<xla::HloTriangularSolveInstruction>(instr);
1301   TF_ASSIGN_OR_RETURN(auto triangular_solve,
1302                       CreateOpWithoutAttrs<lmhlo::TriangularSolveOp>(instr));
1303   const xla::TriangularSolveOptions& options =
1304       hlo_triangular_solve->triangular_solve_options();
1305   triangular_solve.left_sideAttr(builder_.getBoolAttr(options.left_side()));
1306   triangular_solve.lowerAttr(builder_.getBoolAttr(options.lower()));
1307   triangular_solve.unit_diagonalAttr(
1308       builder_.getBoolAttr(options.unit_diagonal()));
1309   TF_ASSIGN_OR_RETURN(mlir::mhlo::Transpose transpose,
1310                       xla::ConvertTranspose(options.transpose_a()));
1311   triangular_solve.transpose_aAttr(
1312       builder_.getStringAttr(mlir::mhlo::stringifyTranspose(transpose)));
1313   triangular_solve.layout_aAttr(
1314       GetLayoutAttribute(instr->operand(0)->shape().layout(), &builder_));
1315   triangular_solve.layout_bAttr(
1316       GetLayoutAttribute(instr->operand(1)->shape().layout(), &builder_));
1317   triangular_solve.layout_outputAttr(
1318       GetLayoutAttribute(instr->shape().layout(), &builder_));
1319   return triangular_solve;
1320 }
1321 
GetLayoutAttribute(const xla::Layout & layout,Builder * builder)1322 mlir::DenseIntElementsAttr LhloDialectEmitter::GetLayoutAttribute(
1323     const xla::Layout& layout, Builder* builder) {
1324   llvm::SmallVector<int64_t, 4> minor_to_major(layout.minor_to_major().begin(),
1325                                                layout.minor_to_major().end());
1326   return builder->getIndexTensorAttr(minor_to_major);
1327 }
1328 
GetOrCreateArrayView(const xla::HloInstruction * instr,const xla::Shape & current_shape,const xla::ShapeIndex & shape_index)1329 StatusOr<Value> LhloDialectEmitter::GetOrCreateArrayView(
1330     const xla::HloInstruction* instr, const xla::Shape& current_shape,
1331     const xla::ShapeIndex& shape_index) {
1332   // Cache generated ViewOp and StaticMemRefCastOp by (instruction,
1333   // shape_index).
1334   auto& cached_value = slices_[std::make_pair(instr, shape_index)];
1335   if (cached_value) {
1336     return cached_value;
1337   }
1338 
1339   if (instr->IsConstant() && shape_index.empty()) {
1340     TF_ASSIGN_OR_RETURN(Value constant_memref, EmitConstant(instr));
1341     return cached_value = constant_memref;
1342   }
1343 
1344   // If the shape happens to have dynamic dimensions, create the memref using
1345   // the underlying static shape.
1346   // TODO(jurahul): Revisit this when we can model memrefs with dynamic shape
1347   // but static bounds in MLIR.
1348   const Shape static_shape = xla::ShapeUtil::MakeStaticShape(current_shape);
1349 
1350   TF_ASSIGN_OR_RETURN(Type out_type, xla::ConvertShapeToType<MemRefType>(
1351                                          static_shape, builder_));
1352   TF_ASSIGN_OR_RETURN(BufferAllocation::Slice slice,
1353                       assignment_.GetUniqueSlice(instr, shape_index));
1354   Value alloc = allocations_[slice.allocation()];
1355   if (alloc.getType() == out_type && slice.offset() == 0) {
1356     return cached_value = alloc;
1357   }
1358 
1359   auto out_memref_type = out_type.dyn_cast<MemRefType>();
1360   if (!out_memref_type)
1361     return tensorflow::errors::Internal(
1362         "Expected memref type when creating a view for leaf type of a "
1363         "tuple.");
1364 
1365   Value byte_shift =
1366       builder_.create<ConstantIndexOp>(alloc.getLoc(), slice.offset());
1367 
1368   xla::Shape physical_shape =
1369       xla::ShapeUtil::MakeShapeWithDescendingLayoutAndSamePhysicalLayout(
1370           static_shape);
1371   TF_ASSIGN_OR_RETURN(
1372       Type physical_out_type,
1373       xla::ConvertShapeToType<MemRefType>(physical_shape, builder_));
1374 
1375   // TODO(timshen): revisit location handling.
1376   Location loc = builder_.getUnknownLoc();
1377 
1378   // ViewOp only takes memrefs without affine maps (layouts). Let ViewOp produce
1379   // the physical shape (where dimensions are ordered in major to minor) first,
1380   // then follow up with a MemRefReinterpretCast to cast the resulting memref to
1381   // the original layout.
1382   Value result =
1383       builder_.create<ViewOp>(loc, physical_out_type, alloc, byte_shift,
1384                               /*sizes=*/ValueRange{});
1385   if (physical_out_type != out_type) {
1386     int64_t out_offset;
1387     SmallVector<int64_t, 4> out_strides;
1388     if (failed(getStridesAndOffset(out_memref_type, out_strides, out_offset)))
1389       return tensorflow::errors::Internal(
1390           "Failed to get strides and offset from the output type.");
1391     result = builder_.create<MemRefReinterpretCastOp>(
1392         loc, out_memref_type, result, out_offset, out_memref_type.getShape(),
1393         out_strides);
1394   }
1395   return cached_value = result;
1396 }
1397 
GetOrCreateViewImpl(const HloInstruction * instr,const Shape & current_shape,xla::ShapeIndex * current_shape_index,SmallVectorImpl<Value> * values)1398 Status LhloDialectEmitter::GetOrCreateViewImpl(
1399     const HloInstruction* instr, const Shape& current_shape,
1400     xla::ShapeIndex* current_shape_index, SmallVectorImpl<Value>* values) {
1401   if (current_shape.IsTuple()) {
1402     for (int i = 0; i < current_shape.tuple_shapes().size(); ++i) {
1403       current_shape_index->push_back(i);
1404       TF_RETURN_IF_ERROR(GetOrCreateViewImpl(
1405           instr, current_shape.tuple_shapes(i), current_shape_index, values));
1406       current_shape_index->pop_back();
1407     }
1408     return Status::OK();
1409   }
1410   if (current_shape.IsArray()) {
1411     TF_ASSIGN_OR_RETURN(auto v, GetOrCreateArrayView(instr, current_shape,
1412                                                      *current_shape_index));
1413     values->push_back(v);
1414     return Status::OK();
1415   }
1416   return xla::InternalError("Unexpected shape kind for %s and shape index %s",
1417                             instr->ToString(), current_shape_index->ToString());
1418 }
1419 
1420 // Returns a view for the result of an instruction.
1421 // We first get a view for the slice in the allocation, and then may need to
1422 // create another view to adjust the slice for the shape of the instruction.
GetOrCreateView(const HloInstruction * instr,SmallVectorImpl<Value> * values,const xla::ShapeIndex & result_subset)1423 Status LhloDialectEmitter::GetOrCreateView(
1424     const HloInstruction* instr, SmallVectorImpl<Value>* values,
1425     const xla::ShapeIndex& result_subset) {
1426   xla::ShapeIndex shape_index = result_subset;
1427   const Shape& sub_shape =
1428       xla::ShapeUtil::GetSubshape(instr->shape(), shape_index);
1429   return GetOrCreateViewImpl(instr, sub_shape, &shape_index, values);
1430 }
1431 
Initialize()1432 Status LhloDialectEmitter::Initialize() {
1433   mlir::IntegerAttr unique_id =
1434       builder_.getI32IntegerAttr(computation_.parent()->unique_id());
1435   module_->setAttr("hlo.unique_id", unique_id);
1436   std::string function_name =
1437       computation_.name().empty() ? "__compute" : computation_.name();
1438 
1439   // Create the function as () -> (), we'll compute the arguments from the
1440   // buffer allocation and update the type then.
1441   auto func_op = FuncOp::create(builder_.getUnknownLoc(), function_name,
1442                                 builder_.getFunctionType({}, {}));
1443   Block* block = func_op.addEntryBlock();
1444 
1445   llvm::SmallVector<const BufferAllocation*, 8> ordered_allocations;
1446   for (const BufferAllocation& alloc : assignment_.Allocations())
1447     ordered_allocations.push_back(&alloc);
1448 
1449   if (computation_.IsEntryComputation()) {
1450     // Sort the rather arbitrarily ordered allocations to match the input/output
1451     // parameters. Specifically we want to sort buffer allocations in the
1452     // following order:
1453     // * Parameters always order before non-parameters.
1454     // * Different parameters order by parameter number.
1455     // * Different allocations for the same parameter order by the shape index.
1456     //
1457     // TODO(timshen): there should be only one non-parameter buffer, the temp
1458     // buffer. Check on that.
1459     const auto allocation_comparator = [](const BufferAllocation* lhs,
1460                                           const BufferAllocation* rhs) {
1461       if (lhs->is_entry_computation_parameter() !=
1462           rhs->is_entry_computation_parameter()) {
1463         return lhs->is_entry_computation_parameter() >
1464                rhs->is_entry_computation_parameter();
1465       }
1466       if (lhs->is_entry_computation_parameter()) {
1467         return std::tuple<int, const xla::ShapeIndex&>(
1468                    lhs->parameter_number(), lhs->param_shape_index()) <
1469                std::tuple<int, const xla::ShapeIndex&>(
1470                    rhs->parameter_number(), rhs->param_shape_index());
1471       }
1472       return false;
1473     };
1474 
1475     std::stable_sort(ordered_allocations.begin(), ordered_allocations.end(),
1476                      allocation_comparator);
1477   }
1478 
1479   // The function signature will be composed of:
1480   // - one memref for each of the parameters.
1481   // - one memref for each other buffer allocation.
1482   llvm::SmallVector<DictionaryAttr, 8> args_attrs;
1483   for (const BufferAllocation* alloc : ordered_allocations) {
1484     if (computation_.IsEntryComputation() &&
1485         alloc->is_entry_computation_parameter()) {
1486       const xla::Shape& buffer_shape = xla::ShapeUtil::GetSubshape(
1487           computation_.parameter_instruction(alloc->parameter_number())
1488               ->shape(),
1489           alloc->param_shape_index());
1490 
1491       TF_ASSIGN_OR_RETURN(auto arg_type, xla::ConvertShapeToType<MemRefType>(
1492                                              buffer_shape, builder_));
1493 
1494       // First map parameters to memrefs on the operation.
1495       block->addArgument(arg_type);
1496       allocations_[alloc] = block->getArguments().back();
1497       NamedAttrList arg_attr_list;
1498       arg_attr_list.set("lmhlo.alloc", builder_.getIndexAttr(alloc->index()));
1499       arg_attr_list.set("lmhlo.params",
1500                         builder_.getIndexAttr(alloc->parameter_number()));
1501       args_attrs.push_back(arg_attr_list.getDictionary(builder_.getContext()));
1502     } else {
1503       block->addArgument(MemRefType::get({alloc->size()}, i8_type_));
1504       allocations_[alloc] = block->getArguments().back();
1505 
1506       NamedAttrList arg_attr_list;
1507       arg_attr_list.set("lmhlo.alloc", builder_.getIndexAttr(alloc->index()));
1508       arg_attr_list.set("lmhlo.liveout", builder_.getBoolAttr(true));
1509       args_attrs.push_back(arg_attr_list.getDictionary(builder_.getContext()));
1510     }
1511   }
1512 
1513   FunctionType function_type =
1514       builder_.getFunctionType(block->getArgumentTypes(), {});
1515   func_op.setType(function_type);
1516   func_op.setAllArgAttrs(args_attrs);
1517 
1518   SymbolTable symbol_table(module_);
1519   symbol_table.insert(func_op);
1520   builder_.setInsertionPointToEnd(block);
1521 
1522   auto return_op = builder_.create<ReturnOp>(builder_.getUnknownLoc());
1523   builder_ = OpBuilder(return_op);
1524 
1525   return Status::OK();
1526 }
1527 
createXlaHloToLhloWithXlaPass()1528 std::unique_ptr<OperationPass<ModuleOp>> createXlaHloToLhloWithXlaPass() {
1529   return std::make_unique<XlaHloToLhloPass>();
1530 }
1531 
HloToLhloModule(const BufferAssignment & assignment,const HloModule & hlo_module,ModuleOp module)1532 Status HloToLhloModule(const BufferAssignment& assignment,
1533                        const HloModule& hlo_module, ModuleOp module) {
1534   module.getContext()
1535       ->loadDialect<StandardOpsDialect, mhlo::MhloDialect, lmhlo::LmhloDialect,
1536                     lmhlo_gpu::LmhloGpuDialect>();
1537   const HloComputation* computation = hlo_module.entry_computation();
1538 
1539   LhloDialectEmitter emitter(assignment, *computation, module);
1540   TF_RETURN_IF_ERROR(emitter.Initialize());
1541 
1542   const xla::HloInstructionSequence* schedule =
1543       assignment.hlo_ordering().SequentialOrder(*computation);
1544   if (!schedule)
1545     return xla::Unimplemented("Missing sequential order for the computation");
1546   const std::vector<HloInstruction*>& ordering = schedule->instructions();
1547   return computation->AcceptOrdered(&emitter, ordering);
1548 }
1549 
HloTextToLhloTranslateFunction(llvm::StringRef input,MLIRContext * context)1550 OwningModuleRef HloTextToLhloTranslateFunction(llvm::StringRef input,
1551                                                MLIRContext* context) {
1552   StatusOr<std::unique_ptr<HloModule>> maybe_module =
1553       xla::ParseAndReturnUnverifiedModule(
1554           absl::string_view(input.data(), input.size()));
1555   TF_CHECK_OK(maybe_module.status());
1556 
1557   OwningModuleRef module = ModuleOp::create(UnknownLoc::get(context));
1558 
1559   TF_CHECK_OK(
1560       ConvertModule(maybe_module.ConsumeValueOrDie(), module.get(), "Host"));
1561 
1562   return module;
1563 }
1564 
1565 static PassRegistration<XlaHloToLhloPass> registration(
1566     "xla-hlo-to-lhlo-with-xla",
1567     "Emit LHLO from HLO using the existing XLA implementation");
1568 
1569 }  // namespace mlir
1570