//===- ConvertFromLLVMIR.cpp - LLVM IR to MLIR conversion -----------------===//
2 //
3 // Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions.
4 // See https://llvm.org/LICENSE.txt for license information.
5 // SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception
6 //
7 //===----------------------------------------------------------------------===//
8 //
9 // This file implements a translation between LLVM IR and the MLIR LLVM dialect.
10 //
11 //===----------------------------------------------------------------------===//
12 
13 #include "mlir/Dialect/LLVMIR/LLVMDialect.h"
14 #include "mlir/IR/Builders.h"
15 #include "mlir/IR/BuiltinOps.h"
16 #include "mlir/IR/BuiltinTypes.h"
17 #include "mlir/IR/MLIRContext.h"
18 #include "mlir/Target/LLVMIR.h"
19 #include "mlir/Target/LLVMIR/TypeTranslation.h"
20 #include "mlir/Translation.h"
21 
22 #include "llvm/IR/Attributes.h"
23 #include "llvm/IR/Constants.h"
24 #include "llvm/IR/Function.h"
25 #include "llvm/IR/InlineAsm.h"
26 #include "llvm/IR/Instructions.h"
27 #include "llvm/IR/Type.h"
28 #include "llvm/IRReader/IRReader.h"
29 #include "llvm/Support/Error.h"
30 #include "llvm/Support/SourceMgr.h"
31 
32 using namespace mlir;
33 using namespace mlir::LLVM;
34 
35 #include "mlir/Dialect/LLVMIR/LLVMConversionEnumsFromLLVM.inc"
36 
37 // Utility to print an LLVM value as a string for passing to emitError().
38 // FIXME: Diagnostic should be able to natively handle types that have
39 // operator << (raw_ostream&) defined.
diag(llvm::Value & v)40 static std::string diag(llvm::Value &v) {
41   std::string s;
42   llvm::raw_string_ostream os(s);
43   os << v;
44   return os.str();
45 }
46 
47 // Handles importing globals and functions from an LLVM module.
48 namespace {
49 class Importer {
50 public:
Importer(MLIRContext * context,ModuleOp module)51   Importer(MLIRContext *context, ModuleOp module)
52       : b(context), context(context), module(module),
53         unknownLoc(FileLineColLoc::get("imported-bitcode", 0, 0, context)),
54         typeTranslator(*context) {
55     b.setInsertionPointToStart(module.getBody());
56   }
57 
58   /// Imports `f` into the current module.
59   LogicalResult processFunction(llvm::Function *f);
60 
61   /// Imports GV as a GlobalOp, creating it if it doesn't exist.
62   GlobalOp processGlobal(llvm::GlobalVariable *GV);
63 
64 private:
65   /// Returns personality of `f` as a FlatSymbolRefAttr.
66   FlatSymbolRefAttr getPersonalityAsAttr(llvm::Function *f);
67   /// Imports `bb` into `block`, which must be initially empty.
68   LogicalResult processBasicBlock(llvm::BasicBlock *bb, Block *block);
69   /// Imports `inst` and populates instMap[inst] with the imported Value.
70   LogicalResult processInstruction(llvm::Instruction *inst);
71   /// Creates an LLVMType for `type`.
72   LLVMType processType(llvm::Type *type);
73   /// `value` is an SSA-use. Return the remapped version of `value` or a
74   /// placeholder that will be remapped later if this is an instruction that
75   /// has not yet been visited.
76   Value processValue(llvm::Value *value);
77   /// Create the most accurate Location possible using a llvm::DebugLoc and
78   /// possibly an llvm::Instruction to narrow the Location if debug information
79   /// is unavailable.
80   Location processDebugLoc(const llvm::DebugLoc &loc,
81                            llvm::Instruction *inst = nullptr);
82   /// `br` branches to `target`. Append the block arguments to attach to the
83   /// generated branch op to `blockArguments`. These should be in the same order
84   /// as the PHIs in `target`.
85   LogicalResult processBranchArgs(llvm::Instruction *br,
86                                   llvm::BasicBlock *target,
87                                   SmallVectorImpl<Value> &blockArguments);
88   /// Returns the builtin type equivalent to be used in attributes for the given
89   /// LLVM IR dialect type.
90   Type getStdTypeForAttr(LLVMType type);
91   /// Return `value` as an attribute to attach to a GlobalOp.
92   Attribute getConstantAsAttr(llvm::Constant *value);
93   /// Return `c` as an MLIR Value. This could either be a ConstantOp, or
94   /// an expanded sequence of ops in the current function's entry block (for
95   /// ConstantExprs or ConstantGEPs).
96   Value processConstant(llvm::Constant *c);
97 
98   /// The current builder, pointing at where the next Instruction should be
99   /// generated.
100   OpBuilder b;
101   /// The current context.
102   MLIRContext *context;
103   /// The current module being created.
104   ModuleOp module;
105   /// The entry block of the current function being processed.
106   Block *currentEntryBlock;
107 
108   /// Globals are inserted before the first function, if any.
getGlobalInsertPt()109   Block::iterator getGlobalInsertPt() {
110     auto i = module.getBody()->begin();
111     while (!isa<LLVMFuncOp, ModuleTerminatorOp>(i))
112       ++i;
113     return i;
114   }
115 
116   /// Functions are always inserted before the module terminator.
getFuncInsertPt()117   Block::iterator getFuncInsertPt() {
118     return std::prev(module.getBody()->end());
119   }
120 
121   /// Remapped blocks, for the current function.
122   DenseMap<llvm::BasicBlock *, Block *> blocks;
123   /// Remapped values. These are function-local.
124   DenseMap<llvm::Value *, Value> instMap;
125   /// Instructions that had not been defined when first encountered as a use.
126   /// Maps to the dummy Operation that was created in processValue().
127   DenseMap<llvm::Value *, Operation *> unknownInstMap;
128   /// Uniquing map of GlobalVariables.
129   DenseMap<llvm::GlobalVariable *, GlobalOp> globals;
130   /// Cached FileLineColLoc::get("imported-bitcode", 0, 0).
131   Location unknownLoc;
132   /// The stateful type translator (contains named structs).
133   LLVM::TypeFromLLVMIRTranslator typeTranslator;
134 };
135 } // namespace
136 
processDebugLoc(const llvm::DebugLoc & loc,llvm::Instruction * inst)137 Location Importer::processDebugLoc(const llvm::DebugLoc &loc,
138                                    llvm::Instruction *inst) {
139   if (!loc && inst) {
140     std::string s;
141     llvm::raw_string_ostream os(s);
142     os << "llvm-imported-inst-%";
143     inst->printAsOperand(os, /*PrintType=*/false);
144     return FileLineColLoc::get(os.str(), 0, 0, context);
145   } else if (!loc) {
146     return unknownLoc;
147   }
148   // FIXME: Obtain the filename from DILocationInfo.
149   return FileLineColLoc::get("imported-bitcode", loc.getLine(), loc.getCol(),
150                              context);
151 }
152 
processType(llvm::Type * type)153 LLVMType Importer::processType(llvm::Type *type) {
154   if (LLVMType result = typeTranslator.translateType(type))
155     return result;
156 
157   // FIXME: Diagnostic should be able to natively handle types that have
158   // operator<<(raw_ostream&) defined.
159   std::string s;
160   llvm::raw_string_ostream os(s);
161   os << *type;
162   emitError(unknownLoc) << "unhandled type: " << os.str();
163   return nullptr;
164 }
165 
166 // We only need integers, floats, doubles, and vectors and tensors thereof for
167 // attributes. Scalar and vector types are converted to the standard
168 // equivalents. Array types are converted to ranked tensors; nested array types
169 // are converted to multi-dimensional tensors or vectors, depending on the
170 // innermost type being a scalar or a vector.
getStdTypeForAttr(LLVMType type)171 Type Importer::getStdTypeForAttr(LLVMType type) {
172   if (!type)
173     return nullptr;
174 
175   if (type.isIntegerTy())
176     return b.getIntegerType(type.getIntegerBitWidth());
177 
178   if (type.isFloatTy())
179     return b.getF32Type();
180 
181   if (type.isDoubleTy())
182     return b.getF64Type();
183 
184   // LLVM vectors can only contain scalars.
185   if (type.isVectorTy()) {
186     auto numElements = type.getVectorElementCount();
187     if (numElements.isScalable()) {
188       emitError(unknownLoc) << "scalable vectors not supported";
189       return nullptr;
190     }
191     Type elementType = getStdTypeForAttr(type.getVectorElementType());
192     if (!elementType)
193       return nullptr;
194     return VectorType::get(numElements.getKnownMinValue(), elementType);
195   }
196 
197   // LLVM arrays can contain other arrays or vectors.
198   if (type.isArrayTy()) {
199     // Recover the nested array shape.
200     SmallVector<int64_t, 4> shape;
201     shape.push_back(type.getArrayNumElements());
202     while (type.getArrayElementType().isArrayTy()) {
203       type = type.getArrayElementType();
204       shape.push_back(type.getArrayNumElements());
205     }
206 
207     // If the innermost type is a vector, use the multi-dimensional vector as
208     // attribute type.
209     if (type.getArrayElementType().isVectorTy()) {
210       LLVMType vectorType = type.getArrayElementType();
211       auto numElements = vectorType.getVectorElementCount();
212       if (numElements.isScalable()) {
213         emitError(unknownLoc) << "scalable vectors not supported";
214         return nullptr;
215       }
216       shape.push_back(numElements.getKnownMinValue());
217 
218       Type elementType = getStdTypeForAttr(vectorType.getVectorElementType());
219       if (!elementType)
220         return nullptr;
221       return VectorType::get(shape, elementType);
222     }
223 
224     // Otherwise use a tensor.
225     Type elementType = getStdTypeForAttr(type.getArrayElementType());
226     if (!elementType)
227       return nullptr;
228     return RankedTensorType::get(shape, elementType);
229   }
230 
231   return nullptr;
232 }
233 
234 // Get the given constant as an attribute. Not all constants can be represented
235 // as attributes.
getConstantAsAttr(llvm::Constant * value)236 Attribute Importer::getConstantAsAttr(llvm::Constant *value) {
237   if (auto *ci = dyn_cast<llvm::ConstantInt>(value))
238     return b.getIntegerAttr(
239         IntegerType::get(ci->getType()->getBitWidth(), context),
240         ci->getValue());
241   if (auto *c = dyn_cast<llvm::ConstantDataArray>(value))
242     if (c->isString())
243       return b.getStringAttr(c->getAsString());
244   if (auto *c = dyn_cast<llvm::ConstantFP>(value)) {
245     if (c->getType()->isDoubleTy())
246       return b.getFloatAttr(FloatType::getF64(context), c->getValueAPF());
247     else if (c->getType()->isFloatingPointTy())
248       return b.getFloatAttr(FloatType::getF32(context), c->getValueAPF());
249   }
250   if (auto *f = dyn_cast<llvm::Function>(value))
251     return b.getSymbolRefAttr(f->getName());
252 
253   // Convert constant data to a dense elements attribute.
254   if (auto *cd = dyn_cast<llvm::ConstantDataSequential>(value)) {
255     LLVMType type = processType(cd->getElementType());
256     if (!type)
257       return nullptr;
258 
259     auto attrType = getStdTypeForAttr(processType(cd->getType()))
260                         .dyn_cast_or_null<ShapedType>();
261     if (!attrType)
262       return nullptr;
263 
264     if (type.isIntegerTy()) {
265       SmallVector<APInt, 8> values;
266       values.reserve(cd->getNumElements());
267       for (unsigned i = 0, e = cd->getNumElements(); i < e; ++i)
268         values.push_back(cd->getElementAsAPInt(i));
269       return DenseElementsAttr::get(attrType, values);
270     }
271 
272     if (type.isFloatTy() || type.isDoubleTy()) {
273       SmallVector<APFloat, 8> values;
274       values.reserve(cd->getNumElements());
275       for (unsigned i = 0, e = cd->getNumElements(); i < e; ++i)
276         values.push_back(cd->getElementAsAPFloat(i));
277       return DenseElementsAttr::get(attrType, values);
278     }
279 
280     return nullptr;
281   }
282 
283   // Unpack constant aggregates to create dense elements attribute whenever
284   // possible. Return nullptr (failure) otherwise.
285   if (isa<llvm::ConstantAggregate>(value)) {
286     auto outerType = getStdTypeForAttr(processType(value->getType()))
287                          .dyn_cast_or_null<ShapedType>();
288     if (!outerType)
289       return nullptr;
290 
291     SmallVector<Attribute, 8> values;
292     SmallVector<int64_t, 8> shape;
293 
294     for (unsigned i = 0, e = value->getNumOperands(); i < e; ++i) {
295       auto nested = getConstantAsAttr(value->getAggregateElement(i))
296                         .dyn_cast_or_null<DenseElementsAttr>();
297       if (!nested)
298         return nullptr;
299 
300       values.append(nested.attr_value_begin(), nested.attr_value_end());
301     }
302 
303     return DenseElementsAttr::get(outerType, values);
304   }
305 
306   return nullptr;
307 }
308 
processGlobal(llvm::GlobalVariable * GV)309 GlobalOp Importer::processGlobal(llvm::GlobalVariable *GV) {
310   auto it = globals.find(GV);
311   if (it != globals.end())
312     return it->second;
313 
314   OpBuilder b(module.getBody(), getGlobalInsertPt());
315   Attribute valueAttr;
316   if (GV->hasInitializer())
317     valueAttr = getConstantAsAttr(GV->getInitializer());
318   LLVMType type = processType(GV->getValueType());
319   if (!type)
320     return nullptr;
321   GlobalOp op = b.create<GlobalOp>(
322       UnknownLoc::get(context), type, GV->isConstant(),
323       convertLinkageFromLLVM(GV->getLinkage()), GV->getName(), valueAttr);
324   if (GV->hasInitializer() && !valueAttr) {
325     Region &r = op.getInitializerRegion();
326     currentEntryBlock = b.createBlock(&r);
327     b.setInsertionPoint(currentEntryBlock, currentEntryBlock->begin());
328     Value v = processConstant(GV->getInitializer());
329     if (!v)
330       return nullptr;
331     b.create<ReturnOp>(op.getLoc(), ArrayRef<Value>({v}));
332   }
333   return globals[GV] = op;
334 }
335 
processConstant(llvm::Constant * c)336 Value Importer::processConstant(llvm::Constant *c) {
337   OpBuilder bEntry(currentEntryBlock, currentEntryBlock->begin());
338   if (Attribute attr = getConstantAsAttr(c)) {
339     // These constants can be represented as attributes.
340     OpBuilder b(currentEntryBlock, currentEntryBlock->begin());
341     LLVMType type = processType(c->getType());
342     if (!type)
343       return nullptr;
344     if (auto symbolRef = attr.dyn_cast<FlatSymbolRefAttr>())
345       return instMap[c] = bEntry.create<AddressOfOp>(unknownLoc, type,
346                                                      symbolRef.getValue());
347     return instMap[c] = bEntry.create<ConstantOp>(unknownLoc, type, attr);
348   }
349   if (auto *cn = dyn_cast<llvm::ConstantPointerNull>(c)) {
350     LLVMType type = processType(cn->getType());
351     if (!type)
352       return nullptr;
353     return instMap[c] = bEntry.create<NullOp>(unknownLoc, type);
354   }
355   if (auto *GV = dyn_cast<llvm::GlobalVariable>(c))
356     return bEntry.create<AddressOfOp>(UnknownLoc::get(context),
357                                       processGlobal(GV));
358 
359   if (auto *ce = dyn_cast<llvm::ConstantExpr>(c)) {
360     llvm::Instruction *i = ce->getAsInstruction();
361     OpBuilder::InsertionGuard guard(b);
362     b.setInsertionPoint(currentEntryBlock, currentEntryBlock->begin());
363     if (failed(processInstruction(i)))
364       return nullptr;
365     assert(instMap.count(i));
366 
367     // Remove this zombie LLVM instruction now, leaving us only with the MLIR
368     // op.
369     i->deleteValue();
370     return instMap[c] = instMap[i];
371   }
372   if (auto *ue = dyn_cast<llvm::UndefValue>(c)) {
373     LLVMType type = processType(ue->getType());
374     if (!type)
375       return nullptr;
376     return instMap[c] = bEntry.create<UndefOp>(UnknownLoc::get(context), type);
377   }
378   emitError(unknownLoc) << "unhandled constant: " << diag(*c);
379   return nullptr;
380 }
381 
processValue(llvm::Value * value)382 Value Importer::processValue(llvm::Value *value) {
383   auto it = instMap.find(value);
384   if (it != instMap.end())
385     return it->second;
386 
387   // We don't expect to see instructions in dominator order. If we haven't seen
388   // this instruction yet, create an unknown op and remap it later.
389   if (isa<llvm::Instruction>(value)) {
390     OperationState state(UnknownLoc::get(context), "llvm.unknown");
391     LLVMType type = processType(value->getType());
392     if (!type)
393       return nullptr;
394     state.addTypes(type);
395     unknownInstMap[value] = b.createOperation(state);
396     return unknownInstMap[value]->getResult(0);
397   }
398 
399   if (auto *c = dyn_cast<llvm::Constant>(value))
400     return processConstant(c);
401 
402   emitError(unknownLoc) << "unhandled value: " << diag(*value);
403   return nullptr;
404 }
405 
406 /// Return the MLIR OperationName for the given LLVM opcode.
lookupOperationNameFromOpcode(unsigned opcode)407 static StringRef lookupOperationNameFromOpcode(unsigned opcode) {
408 // Maps from LLVM opcode to MLIR OperationName. This is deliberately ordered
409 // as in llvm/IR/Instructions.def to aid comprehension and spot missing
410 // instructions.
411 #define INST(llvm_n, mlir_n)                                                   \
412   { llvm::Instruction::llvm_n, LLVM::mlir_n##Op::getOperationName() }
413   static const DenseMap<unsigned, StringRef> opcMap = {
414       // Ret is handled specially.
415       // Br is handled specially.
416       // FIXME: switch
417       // FIXME: indirectbr
418       // FIXME: invoke
419       INST(Resume, Resume),
420       // FIXME: unreachable
421       // FIXME: cleanupret
422       // FIXME: catchret
423       // FIXME: catchswitch
424       // FIXME: callbr
425       // FIXME: fneg
426       INST(Add, Add), INST(FAdd, FAdd), INST(Sub, Sub), INST(FSub, FSub),
427       INST(Mul, Mul), INST(FMul, FMul), INST(UDiv, UDiv), INST(SDiv, SDiv),
428       INST(FDiv, FDiv), INST(URem, URem), INST(SRem, SRem), INST(FRem, FRem),
429       INST(Shl, Shl), INST(LShr, LShr), INST(AShr, AShr), INST(And, And),
430       INST(Or, Or), INST(Xor, XOr), INST(Alloca, Alloca), INST(Load, Load),
431       INST(Store, Store),
432       // Getelementptr is handled specially.
433       INST(Ret, Return), INST(Fence, Fence),
434       // FIXME: atomiccmpxchg
435       // FIXME: atomicrmw
436       INST(Trunc, Trunc), INST(ZExt, ZExt), INST(SExt, SExt),
437       INST(FPToUI, FPToUI), INST(FPToSI, FPToSI), INST(UIToFP, UIToFP),
438       INST(SIToFP, SIToFP), INST(FPTrunc, FPTrunc), INST(FPExt, FPExt),
439       INST(PtrToInt, PtrToInt), INST(IntToPtr, IntToPtr),
440       INST(BitCast, Bitcast), INST(AddrSpaceCast, AddrSpaceCast),
441       // FIXME: cleanuppad
442       // FIXME: catchpad
443       // ICmp is handled specially.
444       // FIXME: fcmp
445       // PHI is handled specially.
446       INST(Freeze, Freeze), INST(Call, Call),
447       // FIXME: select
448       // FIXME: vaarg
449       // FIXME: extractelement
450       // FIXME: insertelement
451       // FIXME: shufflevector
452       // FIXME: extractvalue
453       // FIXME: insertvalue
454       // FIXME: landingpad
455   };
456 #undef INST
457 
458   return opcMap.lookup(opcode);
459 }
460 
getICmpPredicate(llvm::CmpInst::Predicate p)461 static ICmpPredicate getICmpPredicate(llvm::CmpInst::Predicate p) {
462   switch (p) {
463   default:
464     llvm_unreachable("incorrect comparison predicate");
465   case llvm::CmpInst::Predicate::ICMP_EQ:
466     return LLVM::ICmpPredicate::eq;
467   case llvm::CmpInst::Predicate::ICMP_NE:
468     return LLVM::ICmpPredicate::ne;
469   case llvm::CmpInst::Predicate::ICMP_SLT:
470     return LLVM::ICmpPredicate::slt;
471   case llvm::CmpInst::Predicate::ICMP_SLE:
472     return LLVM::ICmpPredicate::sle;
473   case llvm::CmpInst::Predicate::ICMP_SGT:
474     return LLVM::ICmpPredicate::sgt;
475   case llvm::CmpInst::Predicate::ICMP_SGE:
476     return LLVM::ICmpPredicate::sge;
477   case llvm::CmpInst::Predicate::ICMP_ULT:
478     return LLVM::ICmpPredicate::ult;
479   case llvm::CmpInst::Predicate::ICMP_ULE:
480     return LLVM::ICmpPredicate::ule;
481   case llvm::CmpInst::Predicate::ICMP_UGT:
482     return LLVM::ICmpPredicate::ugt;
483   case llvm::CmpInst::Predicate::ICMP_UGE:
484     return LLVM::ICmpPredicate::uge;
485   }
486   llvm_unreachable("incorrect comparison predicate");
487 }
488 
getLLVMAtomicOrdering(llvm::AtomicOrdering ordering)489 static AtomicOrdering getLLVMAtomicOrdering(llvm::AtomicOrdering ordering) {
490   switch (ordering) {
491   case llvm::AtomicOrdering::NotAtomic:
492     return LLVM::AtomicOrdering::not_atomic;
493   case llvm::AtomicOrdering::Unordered:
494     return LLVM::AtomicOrdering::unordered;
495   case llvm::AtomicOrdering::Monotonic:
496     return LLVM::AtomicOrdering::monotonic;
497   case llvm::AtomicOrdering::Acquire:
498     return LLVM::AtomicOrdering::acquire;
499   case llvm::AtomicOrdering::Release:
500     return LLVM::AtomicOrdering::release;
501   case llvm::AtomicOrdering::AcquireRelease:
502     return LLVM::AtomicOrdering::acq_rel;
503   case llvm::AtomicOrdering::SequentiallyConsistent:
504     return LLVM::AtomicOrdering::seq_cst;
505   }
506   llvm_unreachable("incorrect atomic ordering");
507 }
508 
509 // `br` branches to `target`. Return the branch arguments to `br`, in the
510 // same order of the PHIs in `target`.
511 LogicalResult
processBranchArgs(llvm::Instruction * br,llvm::BasicBlock * target,SmallVectorImpl<Value> & blockArguments)512 Importer::processBranchArgs(llvm::Instruction *br, llvm::BasicBlock *target,
513                             SmallVectorImpl<Value> &blockArguments) {
514   for (auto inst = target->begin(); isa<llvm::PHINode>(inst); ++inst) {
515     auto *PN = cast<llvm::PHINode>(&*inst);
516     Value value = processValue(PN->getIncomingValueForBlock(br->getParent()));
517     if (!value)
518       return failure();
519     blockArguments.push_back(value);
520   }
521   return success();
522 }
523 
processInstruction(llvm::Instruction * inst)524 LogicalResult Importer::processInstruction(llvm::Instruction *inst) {
525   // FIXME: Support uses of SubtargetData. Currently inbounds GEPs, fast-math
526   // flags and call / operand attributes are not supported.
527   Location loc = processDebugLoc(inst->getDebugLoc(), inst);
528   Value &v = instMap[inst];
529   assert(!v && "processInstruction must be called only once per instruction!");
530   switch (inst->getOpcode()) {
531   default:
532     return emitError(loc) << "unknown instruction: " << diag(*inst);
533   case llvm::Instruction::Add:
534   case llvm::Instruction::FAdd:
535   case llvm::Instruction::Sub:
536   case llvm::Instruction::FSub:
537   case llvm::Instruction::Mul:
538   case llvm::Instruction::FMul:
539   case llvm::Instruction::UDiv:
540   case llvm::Instruction::SDiv:
541   case llvm::Instruction::FDiv:
542   case llvm::Instruction::URem:
543   case llvm::Instruction::SRem:
544   case llvm::Instruction::FRem:
545   case llvm::Instruction::Shl:
546   case llvm::Instruction::LShr:
547   case llvm::Instruction::AShr:
548   case llvm::Instruction::And:
549   case llvm::Instruction::Or:
550   case llvm::Instruction::Xor:
551   case llvm::Instruction::Alloca:
552   case llvm::Instruction::Load:
553   case llvm::Instruction::Store:
554   case llvm::Instruction::Ret:
555   case llvm::Instruction::Resume:
556   case llvm::Instruction::Trunc:
557   case llvm::Instruction::ZExt:
558   case llvm::Instruction::SExt:
559   case llvm::Instruction::FPToUI:
560   case llvm::Instruction::FPToSI:
561   case llvm::Instruction::UIToFP:
562   case llvm::Instruction::SIToFP:
563   case llvm::Instruction::FPTrunc:
564   case llvm::Instruction::FPExt:
565   case llvm::Instruction::PtrToInt:
566   case llvm::Instruction::IntToPtr:
567   case llvm::Instruction::AddrSpaceCast:
568   case llvm::Instruction::Freeze:
569   case llvm::Instruction::BitCast: {
570     OperationState state(loc, lookupOperationNameFromOpcode(inst->getOpcode()));
571     SmallVector<Value, 4> ops;
572     ops.reserve(inst->getNumOperands());
573     for (auto *op : inst->operand_values()) {
574       Value value = processValue(op);
575       if (!value)
576         return failure();
577       ops.push_back(value);
578     }
579     state.addOperands(ops);
580     if (!inst->getType()->isVoidTy()) {
581       LLVMType type = processType(inst->getType());
582       if (!type)
583         return failure();
584       state.addTypes(type);
585     }
586     Operation *op = b.createOperation(state);
587     if (!inst->getType()->isVoidTy())
588       v = op->getResult(0);
589     return success();
590   }
591   case llvm::Instruction::ICmp: {
592     Value lhs = processValue(inst->getOperand(0));
593     Value rhs = processValue(inst->getOperand(1));
594     if (!lhs || !rhs)
595       return failure();
596     v = b.create<ICmpOp>(
597         loc, getICmpPredicate(cast<llvm::ICmpInst>(inst)->getPredicate()), lhs,
598         rhs);
599     return success();
600   }
601   case llvm::Instruction::Br: {
602     auto *brInst = cast<llvm::BranchInst>(inst);
603     OperationState state(loc,
604                          brInst->isConditional() ? "llvm.cond_br" : "llvm.br");
605     if (brInst->isConditional()) {
606       Value condition = processValue(brInst->getCondition());
607       if (!condition)
608         return failure();
609       state.addOperands(condition);
610     }
611 
612     std::array<int32_t, 3> operandSegmentSizes = {1, 0, 0};
613     for (int i : llvm::seq<int>(0, brInst->getNumSuccessors())) {
614       auto *succ = brInst->getSuccessor(i);
615       SmallVector<Value, 4> blockArguments;
616       if (failed(processBranchArgs(brInst, succ, blockArguments)))
617         return failure();
618       state.addSuccessors(blocks[succ]);
619       state.addOperands(blockArguments);
620       operandSegmentSizes[i + 1] = blockArguments.size();
621     }
622 
623     if (brInst->isConditional()) {
624       state.addAttribute(LLVM::CondBrOp::getOperandSegmentSizeAttr(),
625                          b.getI32VectorAttr(operandSegmentSizes));
626     }
627 
628     b.createOperation(state);
629     return success();
630   }
631   case llvm::Instruction::PHI: {
632     LLVMType type = processType(inst->getType());
633     if (!type)
634       return failure();
635     v = b.getInsertionBlock()->addArgument(type);
636     return success();
637   }
638   case llvm::Instruction::Call: {
639     llvm::CallInst *ci = cast<llvm::CallInst>(inst);
640     SmallVector<Value, 4> ops;
641     ops.reserve(inst->getNumOperands());
642     for (auto &op : ci->arg_operands()) {
643       Value arg = processValue(op.get());
644       if (!arg)
645         return failure();
646       ops.push_back(arg);
647     }
648 
649     SmallVector<Type, 2> tys;
650     if (!ci->getType()->isVoidTy()) {
651       LLVMType type = processType(inst->getType());
652       if (!type)
653         return failure();
654       tys.push_back(type);
655     }
656     Operation *op;
657     if (llvm::Function *callee = ci->getCalledFunction()) {
658       op = b.create<CallOp>(loc, tys, b.getSymbolRefAttr(callee->getName()),
659                             ops);
660     } else {
661       Value calledValue = processValue(ci->getCalledOperand());
662       if (!calledValue)
663         return failure();
664       ops.insert(ops.begin(), calledValue);
665       op = b.create<CallOp>(loc, tys, ops);
666     }
667     if (!ci->getType()->isVoidTy())
668       v = op->getResult(0);
669     return success();
670   }
671   case llvm::Instruction::LandingPad: {
672     llvm::LandingPadInst *lpi = cast<llvm::LandingPadInst>(inst);
673     SmallVector<Value, 4> ops;
674 
675     for (unsigned i = 0, ie = lpi->getNumClauses(); i < ie; i++)
676       ops.push_back(processConstant(lpi->getClause(i)));
677 
678     Type ty = processType(lpi->getType());
679     if (!ty)
680       return failure();
681 
682     v = b.create<LandingpadOp>(loc, ty, lpi->isCleanup(), ops);
683     return success();
684   }
685   case llvm::Instruction::Invoke: {
686     llvm::InvokeInst *ii = cast<llvm::InvokeInst>(inst);
687 
688     SmallVector<Type, 2> tys;
689     if (!ii->getType()->isVoidTy())
690       tys.push_back(processType(inst->getType()));
691 
692     SmallVector<Value, 4> ops;
693     ops.reserve(inst->getNumOperands() + 1);
694     for (auto &op : ii->arg_operands())
695       ops.push_back(processValue(op.get()));
696 
697     SmallVector<Value, 4> normalArgs, unwindArgs;
698     processBranchArgs(ii, ii->getNormalDest(), normalArgs);
699     processBranchArgs(ii, ii->getUnwindDest(), unwindArgs);
700 
701     Operation *op;
702     if (llvm::Function *callee = ii->getCalledFunction()) {
703       op = b.create<InvokeOp>(loc, tys, b.getSymbolRefAttr(callee->getName()),
704                               ops, blocks[ii->getNormalDest()], normalArgs,
705                               blocks[ii->getUnwindDest()], unwindArgs);
706     } else {
707       ops.insert(ops.begin(), processValue(ii->getCalledOperand()));
708       op = b.create<InvokeOp>(loc, tys, ops, blocks[ii->getNormalDest()],
709                               normalArgs, blocks[ii->getUnwindDest()],
710                               unwindArgs);
711     }
712 
713     if (!ii->getType()->isVoidTy())
714       v = op->getResult(0);
715     return success();
716   }
717   case llvm::Instruction::Fence: {
718     StringRef syncscope;
719     SmallVector<StringRef, 4> ssNs;
720     llvm::LLVMContext &llvmContext = inst->getContext();
721     llvm::FenceInst *fence = cast<llvm::FenceInst>(inst);
722     llvmContext.getSyncScopeNames(ssNs);
723     int fenceSyncScopeID = fence->getSyncScopeID();
724     for (unsigned i = 0, e = ssNs.size(); i != e; i++) {
725       if (fenceSyncScopeID == llvmContext.getOrInsertSyncScopeID(ssNs[i])) {
726         syncscope = ssNs[i];
727         break;
728       }
729     }
730     b.create<FenceOp>(loc, getLLVMAtomicOrdering(fence->getOrdering()),
731                       syncscope);
732     return success();
733   }
734   case llvm::Instruction::GetElementPtr: {
735     // FIXME: Support inbounds GEPs.
736     llvm::GetElementPtrInst *gep = cast<llvm::GetElementPtrInst>(inst);
737     SmallVector<Value, 4> ops;
738     for (auto *op : gep->operand_values()) {
739       Value value = processValue(op);
740       if (!value)
741         return failure();
742       ops.push_back(value);
743     }
744     Type type = processType(inst->getType());
745     if (!type)
746       return failure();
747     v = b.create<GEPOp>(loc, type, ops);
748     return success();
749   }
750   }
751 }
752 
getPersonalityAsAttr(llvm::Function * f)753 FlatSymbolRefAttr Importer::getPersonalityAsAttr(llvm::Function *f) {
754   if (!f->hasPersonalityFn())
755     return nullptr;
756 
757   llvm::Constant *pf = f->getPersonalityFn();
758 
759   // If it directly has a name, we can use it.
760   if (pf->hasName())
761     return b.getSymbolRefAttr(pf->getName());
762 
763   // If it doesn't have a name, currently, only function pointers that are
764   // bitcast to i8* are parsed.
765   if (auto ce = dyn_cast<llvm::ConstantExpr>(pf)) {
766     if (ce->getOpcode() == llvm::Instruction::BitCast &&
767         ce->getType() == llvm::Type::getInt8PtrTy(f->getContext())) {
768       if (auto func = dyn_cast<llvm::Function>(ce->getOperand(0)))
769         return b.getSymbolRefAttr(func->getName());
770     }
771   }
772   return FlatSymbolRefAttr();
773 }
774 
processFunction(llvm::Function * f)775 LogicalResult Importer::processFunction(llvm::Function *f) {
776   blocks.clear();
777   instMap.clear();
778   unknownInstMap.clear();
779 
780   LLVMType functionType = processType(f->getFunctionType());
781   if (!functionType)
782     return failure();
783 
784   b.setInsertionPoint(module.getBody(), getFuncInsertPt());
785   LLVMFuncOp fop =
786       b.create<LLVMFuncOp>(UnknownLoc::get(context), f->getName(), functionType,
787                            convertLinkageFromLLVM(f->getLinkage()));
788 
789   if (FlatSymbolRefAttr personality = getPersonalityAsAttr(f))
790     fop.setAttr(b.getIdentifier("personality"), personality);
791   else if (f->hasPersonalityFn())
792     emitWarning(UnknownLoc::get(context),
793                 "could not deduce personality, skipping it");
794 
795   if (f->isDeclaration())
796     return success();
797 
798   // Eagerly create all blocks.
799   SmallVector<Block *, 4> blockList;
800   for (llvm::BasicBlock &bb : *f) {
801     blockList.push_back(b.createBlock(&fop.body(), fop.body().end()));
802     blocks[&bb] = blockList.back();
803   }
804   currentEntryBlock = blockList[0];
805 
806   // Add function arguments to the entry block.
807   for (auto kv : llvm::enumerate(f->args()))
808     instMap[&kv.value()] = blockList[0]->addArgument(
809         functionType.getFunctionParamType(kv.index()));
810 
811   for (auto bbs : llvm::zip(*f, blockList)) {
812     if (failed(processBasicBlock(&std::get<0>(bbs), std::get<1>(bbs))))
813       return failure();
814   }
815 
816   // Now that all instructions are guaranteed to have been visited, ensure
817   // any unknown uses we encountered are remapped.
818   for (auto &llvmAndUnknown : unknownInstMap) {
819     assert(instMap.count(llvmAndUnknown.first));
820     Value newValue = instMap[llvmAndUnknown.first];
821     Value oldValue = llvmAndUnknown.second->getResult(0);
822     oldValue.replaceAllUsesWith(newValue);
823     llvmAndUnknown.second->erase();
824   }
825   return success();
826 }
827 
processBasicBlock(llvm::BasicBlock * bb,Block * block)828 LogicalResult Importer::processBasicBlock(llvm::BasicBlock *bb, Block *block) {
829   b.setInsertionPointToStart(block);
830   for (llvm::Instruction &inst : *bb) {
831     if (failed(processInstruction(&inst)))
832       return failure();
833   }
834   return success();
835 }
836 
837 OwningModuleRef
translateLLVMIRToModule(std::unique_ptr<llvm::Module> llvmModule,MLIRContext * context)838 mlir::translateLLVMIRToModule(std::unique_ptr<llvm::Module> llvmModule,
839                               MLIRContext *context) {
840   context->loadDialect<LLVMDialect>();
841   OwningModuleRef module(ModuleOp::create(
842       FileLineColLoc::get("", /*line=*/0, /*column=*/0, context)));
843 
844   Importer deserializer(context, module.get());
845   for (llvm::GlobalVariable &gv : llvmModule->globals()) {
846     if (!deserializer.processGlobal(&gv))
847       return {};
848   }
849   for (llvm::Function &f : llvmModule->functions()) {
850     if (failed(deserializer.processFunction(&f)))
851       return {};
852   }
853 
854   return module;
855 }
856 
857 // Deserializes the LLVM bitcode stored in `input` into an MLIR module in the
858 // LLVM dialect.
translateLLVMIRToModule(llvm::SourceMgr & sourceMgr,MLIRContext * context)859 OwningModuleRef translateLLVMIRToModule(llvm::SourceMgr &sourceMgr,
860                                         MLIRContext *context) {
861   llvm::SMDiagnostic err;
862   llvm::LLVMContext llvmContext;
863   std::unique_ptr<llvm::Module> llvmModule = llvm::parseIR(
864       *sourceMgr.getMemoryBuffer(sourceMgr.getMainFileID()), err, llvmContext);
865   if (!llvmModule) {
866     std::string errStr;
867     llvm::raw_string_ostream errStream(errStr);
868     err.print(/*ProgName=*/"", errStream);
869     emitError(UnknownLoc::get(context)) << errStream.str();
870     return {};
871   }
872   return translateLLVMIRToModule(std::move(llvmModule), context);
873 }
874 
875 namespace mlir {
registerFromLLVMIRTranslation()876 void registerFromLLVMIRTranslation() {
877   TranslateToMLIRRegistration fromLLVM(
878       "import-llvm", [](llvm::SourceMgr &sourceMgr, MLIRContext *context) {
879         return ::translateLLVMIRToModule(sourceMgr, context);
880       });
881 }
882 } // namespace mlir
883