1 //===- LLVMDialect.cpp - LLVM IR Ops and Dialect registration -------------===//
2 //
3 // Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions.
4 // See https://llvm.org/LICENSE.txt for license information.
5 // SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception
6 //
7 //===----------------------------------------------------------------------===//
8 //
9 // This file defines the types and operation details for the LLVM IR dialect in
10 // MLIR, and the LLVM IR dialect.  It also registers the dialect.
11 //
12 //===----------------------------------------------------------------------===//
13 #include "mlir/Dialect/LLVMIR/LLVMDialect.h"
14 #include "mlir/Dialect/LLVMIR/LLVMTypes.h"
15 #include "mlir/IR/Builders.h"
16 #include "mlir/IR/BuiltinOps.h"
17 #include "mlir/IR/BuiltinTypes.h"
18 #include "mlir/IR/DialectImplementation.h"
19 #include "mlir/IR/FunctionImplementation.h"
20 #include "mlir/IR/MLIRContext.h"
21 
22 #include "llvm/ADT/StringSwitch.h"
23 #include "llvm/AsmParser/Parser.h"
24 #include "llvm/Bitcode/BitcodeReader.h"
25 #include "llvm/Bitcode/BitcodeWriter.h"
26 #include "llvm/IR/Attributes.h"
27 #include "llvm/IR/Function.h"
28 #include "llvm/IR/Type.h"
29 #include "llvm/Support/Mutex.h"
30 #include "llvm/Support/SourceMgr.h"
31 
32 using namespace mlir;
33 using namespace mlir::LLVM;
34 
// Names of the unit attributes that the load/store builders and parsers
// attach when an access is volatile or non-temporal, respectively.
static constexpr char kVolatileAttrName[] = "volatile_";
static constexpr char kNonTemporalAttrName[] = "nontemporal";
37 
38 #include "mlir/Dialect/LLVMIR/LLVMOpsEnums.cpp.inc"
39 
40 //===----------------------------------------------------------------------===//
41 // Printing/parsing for LLVM::CmpOp.
42 //===----------------------------------------------------------------------===//
printICmpOp(OpAsmPrinter & p,ICmpOp & op)43 static void printICmpOp(OpAsmPrinter &p, ICmpOp &op) {
44   p << op.getOperationName() << " \"" << stringifyICmpPredicate(op.predicate())
45     << "\" " << op.getOperand(0) << ", " << op.getOperand(1);
46   p.printOptionalAttrDict(op.getAttrs(), {"predicate"});
47   p << " : " << op.lhs().getType();
48 }
49 
printFCmpOp(OpAsmPrinter & p,FCmpOp & op)50 static void printFCmpOp(OpAsmPrinter &p, FCmpOp &op) {
51   p << op.getOperationName() << " \"" << stringifyFCmpPredicate(op.predicate())
52     << "\" " << op.getOperand(0) << ", " << op.getOperand(1);
53   p.printOptionalAttrDict(op.getAttrs(), {"predicate"});
54   p << " : " << op.lhs().getType();
55 }
56 
57 // <operation> ::= `llvm.icmp` string-literal ssa-use `,` ssa-use
58 //                 attribute-dict? `:` type
59 // <operation> ::= `llvm.fcmp` string-literal ssa-use `,` ssa-use
60 //                 attribute-dict? `:` type
61 template <typename CmpPredicateType>
parseCmpOp(OpAsmParser & parser,OperationState & result)62 static ParseResult parseCmpOp(OpAsmParser &parser, OperationState &result) {
63   Builder &builder = parser.getBuilder();
64 
65   StringAttr predicateAttr;
66   OpAsmParser::OperandType lhs, rhs;
67   Type type;
68   llvm::SMLoc predicateLoc, trailingTypeLoc;
69   if (parser.getCurrentLocation(&predicateLoc) ||
70       parser.parseAttribute(predicateAttr, "predicate", result.attributes) ||
71       parser.parseOperand(lhs) || parser.parseComma() ||
72       parser.parseOperand(rhs) ||
73       parser.parseOptionalAttrDict(result.attributes) || parser.parseColon() ||
74       parser.getCurrentLocation(&trailingTypeLoc) || parser.parseType(type) ||
75       parser.resolveOperand(lhs, type, result.operands) ||
76       parser.resolveOperand(rhs, type, result.operands))
77     return failure();
78 
79   // Replace the string attribute `predicate` with an integer attribute.
80   int64_t predicateValue = 0;
81   if (std::is_same<CmpPredicateType, ICmpPredicate>()) {
82     Optional<ICmpPredicate> predicate =
83         symbolizeICmpPredicate(predicateAttr.getValue());
84     if (!predicate)
85       return parser.emitError(predicateLoc)
86              << "'" << predicateAttr.getValue()
87              << "' is an incorrect value of the 'predicate' attribute";
88     predicateValue = static_cast<int64_t>(predicate.getValue());
89   } else {
90     Optional<FCmpPredicate> predicate =
91         symbolizeFCmpPredicate(predicateAttr.getValue());
92     if (!predicate)
93       return parser.emitError(predicateLoc)
94              << "'" << predicateAttr.getValue()
95              << "' is an incorrect value of the 'predicate' attribute";
96     predicateValue = static_cast<int64_t>(predicate.getValue());
97   }
98 
99   result.attributes.set("predicate",
100                         parser.getBuilder().getI64IntegerAttr(predicateValue));
101 
102   // The result type is either i1 or a vector type <? x i1> if the inputs are
103   // vectors.
104   auto resultType = LLVMType::getInt1Ty(builder.getContext());
105   auto argType = type.dyn_cast<LLVM::LLVMType>();
106   if (!argType)
107     return parser.emitError(trailingTypeLoc, "expected LLVM IR dialect type");
108   if (argType.isVectorTy())
109     resultType =
110         LLVMType::getVectorTy(resultType, argType.getVectorNumElements());
111 
112   result.addTypes({resultType});
113   return success();
114 }
115 
116 //===----------------------------------------------------------------------===//
117 // Printing/parsing for LLVM::AllocaOp.
118 //===----------------------------------------------------------------------===//
119 
printAllocaOp(OpAsmPrinter & p,AllocaOp & op)120 static void printAllocaOp(OpAsmPrinter &p, AllocaOp &op) {
121   auto elemTy = op.getType().cast<LLVM::LLVMType>().getPointerElementTy();
122 
123   auto funcTy = FunctionType::get({op.arraySize().getType()}, {op.getType()},
124                                   op.getContext());
125 
126   p << op.getOperationName() << ' ' << op.arraySize() << " x " << elemTy;
127   if (op.alignment().hasValue() && *op.alignment() != 0)
128     p.printOptionalAttrDict(op.getAttrs());
129   else
130     p.printOptionalAttrDict(op.getAttrs(), {"alignment"});
131   p << " : " << funcTy;
132 }
133 
134 // <operation> ::= `llvm.alloca` ssa-use `x` type attribute-dict?
135 //                 `:` type `,` type
parseAllocaOp(OpAsmParser & parser,OperationState & result)136 static ParseResult parseAllocaOp(OpAsmParser &parser, OperationState &result) {
137   OpAsmParser::OperandType arraySize;
138   Type type, elemType;
139   llvm::SMLoc trailingTypeLoc;
140   if (parser.parseOperand(arraySize) || parser.parseKeyword("x") ||
141       parser.parseType(elemType) ||
142       parser.parseOptionalAttrDict(result.attributes) || parser.parseColon() ||
143       parser.getCurrentLocation(&trailingTypeLoc) || parser.parseType(type))
144     return failure();
145 
146   Optional<NamedAttribute> alignmentAttr =
147       result.attributes.getNamed("alignment");
148   if (alignmentAttr.hasValue()) {
149     auto alignmentInt = alignmentAttr.getValue().second.dyn_cast<IntegerAttr>();
150     if (!alignmentInt)
151       return parser.emitError(parser.getNameLoc(),
152                               "expected integer alignment");
153     if (alignmentInt.getValue().isNullValue())
154       result.attributes.erase("alignment");
155   }
156 
157   // Extract the result type from the trailing function type.
158   auto funcType = type.dyn_cast<FunctionType>();
159   if (!funcType || funcType.getNumInputs() != 1 ||
160       funcType.getNumResults() != 1)
161     return parser.emitError(
162         trailingTypeLoc,
163         "expected trailing function type with one argument and one result");
164 
165   if (parser.resolveOperand(arraySize, funcType.getInput(0), result.operands))
166     return failure();
167 
168   result.addTypes({funcType.getResult(0)});
169   return success();
170 }
171 
172 //===----------------------------------------------------------------------===//
173 // LLVM::BrOp
174 //===----------------------------------------------------------------------===//
175 
176 Optional<MutableOperandRange>
getMutableSuccessorOperands(unsigned index)177 BrOp::getMutableSuccessorOperands(unsigned index) {
178   assert(index == 0 && "invalid successor index");
179   return destOperandsMutable();
180 }
181 
182 //===----------------------------------------------------------------------===//
183 // LLVM::CondBrOp
184 //===----------------------------------------------------------------------===//
185 
186 Optional<MutableOperandRange>
getMutableSuccessorOperands(unsigned index)187 CondBrOp::getMutableSuccessorOperands(unsigned index) {
188   assert(index < getNumSuccessors() && "invalid successor index");
189   return index == 0 ? trueDestOperandsMutable() : falseDestOperandsMutable();
190 }
191 
192 //===----------------------------------------------------------------------===//
// Builder, printer and parser for LLVM::LoadOp.
194 //===----------------------------------------------------------------------===//
195 
build(OpBuilder & builder,OperationState & result,Type t,Value addr,unsigned alignment,bool isVolatile,bool isNonTemporal)196 void LoadOp::build(OpBuilder &builder, OperationState &result, Type t,
197                    Value addr, unsigned alignment, bool isVolatile,
198                    bool isNonTemporal) {
199   result.addOperands(addr);
200   result.addTypes(t);
201   if (isVolatile)
202     result.addAttribute(kVolatileAttrName, builder.getUnitAttr());
203   if (isNonTemporal)
204     result.addAttribute(kNonTemporalAttrName, builder.getUnitAttr());
205   if (alignment != 0)
206     result.addAttribute("alignment", builder.getI64IntegerAttr(alignment));
207 }
208 
printLoadOp(OpAsmPrinter & p,LoadOp & op)209 static void printLoadOp(OpAsmPrinter &p, LoadOp &op) {
210   p << op.getOperationName() << ' ';
211   if (op.volatile_())
212     p << "volatile ";
213   p << op.addr();
214   p.printOptionalAttrDict(op.getAttrs(), {kVolatileAttrName});
215   p << " : " << op.addr().getType();
216 }
217 
218 // Extract the pointee type from the LLVM pointer type wrapped in MLIR.  Return
219 // the resulting type wrapped in MLIR, or nullptr on error.
getLoadStoreElementType(OpAsmParser & parser,Type type,llvm::SMLoc trailingTypeLoc)220 static Type getLoadStoreElementType(OpAsmParser &parser, Type type,
221                                     llvm::SMLoc trailingTypeLoc) {
222   auto llvmTy = type.dyn_cast<LLVM::LLVMType>();
223   if (!llvmTy)
224     return parser.emitError(trailingTypeLoc, "expected LLVM IR dialect type"),
225            nullptr;
226   if (!llvmTy.isPointerTy())
227     return parser.emitError(trailingTypeLoc, "expected LLVM pointer type"),
228            nullptr;
229   return llvmTy.getPointerElementTy();
230 }
231 
232 // <operation> ::= `llvm.load` `volatile` ssa-use attribute-dict? `:` type
parseLoadOp(OpAsmParser & parser,OperationState & result)233 static ParseResult parseLoadOp(OpAsmParser &parser, OperationState &result) {
234   OpAsmParser::OperandType addr;
235   Type type;
236   llvm::SMLoc trailingTypeLoc;
237 
238   if (succeeded(parser.parseOptionalKeyword("volatile")))
239     result.addAttribute(kVolatileAttrName, parser.getBuilder().getUnitAttr());
240 
241   if (parser.parseOperand(addr) ||
242       parser.parseOptionalAttrDict(result.attributes) || parser.parseColon() ||
243       parser.getCurrentLocation(&trailingTypeLoc) || parser.parseType(type) ||
244       parser.resolveOperand(addr, type, result.operands))
245     return failure();
246 
247   Type elemTy = getLoadStoreElementType(parser, type, trailingTypeLoc);
248 
249   result.addTypes(elemTy);
250   return success();
251 }
252 
253 //===----------------------------------------------------------------------===//
254 // Builder, printer and parser for LLVM::StoreOp.
255 //===----------------------------------------------------------------------===//
256 
build(OpBuilder & builder,OperationState & result,Value value,Value addr,unsigned alignment,bool isVolatile,bool isNonTemporal)257 void StoreOp::build(OpBuilder &builder, OperationState &result, Value value,
258                     Value addr, unsigned alignment, bool isVolatile,
259                     bool isNonTemporal) {
260   result.addOperands({value, addr});
261   result.addTypes({});
262   if (isVolatile)
263     result.addAttribute(kVolatileAttrName, builder.getUnitAttr());
264   if (isNonTemporal)
265     result.addAttribute(kNonTemporalAttrName, builder.getUnitAttr());
266   if (alignment != 0)
267     result.addAttribute("alignment", builder.getI64IntegerAttr(alignment));
268 }
269 
printStoreOp(OpAsmPrinter & p,StoreOp & op)270 static void printStoreOp(OpAsmPrinter &p, StoreOp &op) {
271   p << op.getOperationName() << ' ';
272   if (op.volatile_())
273     p << "volatile ";
274   p << op.value() << ", " << op.addr();
275   p.printOptionalAttrDict(op.getAttrs(), {kVolatileAttrName});
276   p << " : " << op.addr().getType();
277 }
278 
279 // <operation> ::= `llvm.store` `volatile` ssa-use `,` ssa-use
280 //                 attribute-dict? `:` type
parseStoreOp(OpAsmParser & parser,OperationState & result)281 static ParseResult parseStoreOp(OpAsmParser &parser, OperationState &result) {
282   OpAsmParser::OperandType addr, value;
283   Type type;
284   llvm::SMLoc trailingTypeLoc;
285 
286   if (succeeded(parser.parseOptionalKeyword("volatile")))
287     result.addAttribute(kVolatileAttrName, parser.getBuilder().getUnitAttr());
288 
289   if (parser.parseOperand(value) || parser.parseComma() ||
290       parser.parseOperand(addr) ||
291       parser.parseOptionalAttrDict(result.attributes) || parser.parseColon() ||
292       parser.getCurrentLocation(&trailingTypeLoc) || parser.parseType(type))
293     return failure();
294 
295   Type elemTy = getLoadStoreElementType(parser, type, trailingTypeLoc);
296   if (!elemTy)
297     return failure();
298 
299   if (parser.resolveOperand(value, elemTy, result.operands) ||
300       parser.resolveOperand(addr, type, result.operands))
301     return failure();
302 
303   return success();
304 }
305 
306 ///===---------------------------------------------------------------------===//
307 /// LLVM::InvokeOp
308 ///===---------------------------------------------------------------------===//
309 
310 Optional<MutableOperandRange>
getMutableSuccessorOperands(unsigned index)311 InvokeOp::getMutableSuccessorOperands(unsigned index) {
312   assert(index < getNumSuccessors() && "invalid successor index");
313   return index == 0 ? normalDestOperandsMutable() : unwindDestOperandsMutable();
314 }
315 
verify(InvokeOp op)316 static LogicalResult verify(InvokeOp op) {
317   if (op.getNumResults() > 1)
318     return op.emitOpError("must have 0 or 1 result");
319 
320   Block *unwindDest = op.unwindDest();
321   if (unwindDest->empty())
322     return op.emitError(
323         "must have at least one operation in unwind destination");
324 
325   // In unwind destination, first operation must be LandingpadOp
326   if (!isa<LandingpadOp>(unwindDest->front()))
327     return op.emitError("first operation in unwind destination should be a "
328                         "llvm.landingpad operation");
329 
330   return success();
331 }
332 
printInvokeOp(OpAsmPrinter & p,InvokeOp op)333 static void printInvokeOp(OpAsmPrinter &p, InvokeOp op) {
334   auto callee = op.callee();
335   bool isDirect = callee.hasValue();
336 
337   p << op.getOperationName() << ' ';
338 
339   // Either function name or pointer
340   if (isDirect)
341     p.printSymbolName(callee.getValue());
342   else
343     p << op.getOperand(0);
344 
345   p << '(' << op.getOperands().drop_front(isDirect ? 0 : 1) << ')';
346   p << " to ";
347   p.printSuccessorAndUseList(op.normalDest(), op.normalDestOperands());
348   p << " unwind ";
349   p.printSuccessorAndUseList(op.unwindDest(), op.unwindDestOperands());
350 
351   p.printOptionalAttrDict(op.getAttrs(),
352                           {InvokeOp::getOperandSegmentSizeAttr(), "callee"});
353   p << " : ";
354   p.printFunctionalType(
355       llvm::drop_begin(op.getOperandTypes(), isDirect ? 0 : 1),
356       op.getResultTypes());
357 }
358 
359 /// <operation> ::= `llvm.invoke` (function-id | ssa-use) `(` ssa-use-list `)`
360 ///                  `to` bb-id (`[` ssa-use-and-type-list `]`)?
361 ///                  `unwind` bb-id (`[` ssa-use-and-type-list `]`)?
362 ///                  attribute-dict? `:` function-type
parseInvokeOp(OpAsmParser & parser,OperationState & result)363 static ParseResult parseInvokeOp(OpAsmParser &parser, OperationState &result) {
364   SmallVector<OpAsmParser::OperandType, 8> operands;
365   FunctionType funcType;
366   SymbolRefAttr funcAttr;
367   llvm::SMLoc trailingTypeLoc;
368   Block *normalDest, *unwindDest;
369   SmallVector<Value, 4> normalOperands, unwindOperands;
370   Builder &builder = parser.getBuilder();
371 
372   // Parse an operand list that will, in practice, contain 0 or 1 operand.  In
373   // case of an indirect call, there will be 1 operand before `(`.  In case of a
374   // direct call, there will be no operands and the parser will stop at the
375   // function identifier without complaining.
376   if (parser.parseOperandList(operands))
377     return failure();
378   bool isDirect = operands.empty();
379 
380   // Optionally parse a function identifier.
381   if (isDirect && parser.parseAttribute(funcAttr, "callee", result.attributes))
382     return failure();
383 
384   if (parser.parseOperandList(operands, OpAsmParser::Delimiter::Paren) ||
385       parser.parseKeyword("to") ||
386       parser.parseSuccessorAndUseList(normalDest, normalOperands) ||
387       parser.parseKeyword("unwind") ||
388       parser.parseSuccessorAndUseList(unwindDest, unwindOperands) ||
389       parser.parseOptionalAttrDict(result.attributes) || parser.parseColon() ||
390       parser.getCurrentLocation(&trailingTypeLoc) || parser.parseType(funcType))
391     return failure();
392 
393   if (isDirect) {
394     // Make sure types match.
395     if (parser.resolveOperands(operands, funcType.getInputs(),
396                                parser.getNameLoc(), result.operands))
397       return failure();
398     result.addTypes(funcType.getResults());
399   } else {
400     // Construct the LLVM IR Dialect function type that the first operand
401     // should match.
402     if (funcType.getNumResults() > 1)
403       return parser.emitError(trailingTypeLoc,
404                               "expected function with 0 or 1 result");
405 
406     LLVM::LLVMType llvmResultType;
407     if (funcType.getNumResults() == 0) {
408       llvmResultType = LLVM::LLVMType::getVoidTy(builder.getContext());
409     } else {
410       llvmResultType = funcType.getResult(0).dyn_cast<LLVM::LLVMType>();
411       if (!llvmResultType)
412         return parser.emitError(trailingTypeLoc,
413                                 "expected result to have LLVM type");
414     }
415 
416     SmallVector<LLVM::LLVMType, 8> argTypes;
417     argTypes.reserve(funcType.getNumInputs());
418     for (Type ty : funcType.getInputs()) {
419       if (auto argType = ty.dyn_cast<LLVM::LLVMType>())
420         argTypes.push_back(argType);
421       else
422         return parser.emitError(trailingTypeLoc,
423                                 "expected LLVM types as inputs");
424     }
425 
426     auto llvmFuncType = LLVM::LLVMType::getFunctionTy(llvmResultType, argTypes,
427                                                       /*isVarArg=*/false);
428     auto wrappedFuncType = llvmFuncType.getPointerTo();
429 
430     auto funcArguments = llvm::makeArrayRef(operands).drop_front();
431 
432     // Make sure that the first operand (indirect callee) matches the wrapped
433     // LLVM IR function type, and that the types of the other call operands
434     // match the types of the function arguments.
435     if (parser.resolveOperand(operands[0], wrappedFuncType, result.operands) ||
436         parser.resolveOperands(funcArguments, funcType.getInputs(),
437                                parser.getNameLoc(), result.operands))
438       return failure();
439 
440     result.addTypes(llvmResultType);
441   }
442   result.addSuccessors({normalDest, unwindDest});
443   result.addOperands(normalOperands);
444   result.addOperands(unwindOperands);
445 
446   result.addAttribute(
447       InvokeOp::getOperandSegmentSizeAttr(),
448       builder.getI32VectorAttr({static_cast<int32_t>(operands.size()),
449                                 static_cast<int32_t>(normalOperands.size()),
450                                 static_cast<int32_t>(unwindOperands.size())}));
451   return success();
452 }
453 
454 ///===----------------------------------------------------------------------===//
455 /// Verifying/Printing/Parsing for LLVM::LandingpadOp.
456 ///===----------------------------------------------------------------------===//
457 
verify(LandingpadOp op)458 static LogicalResult verify(LandingpadOp op) {
459   Value value;
460   if (LLVMFuncOp func = op->getParentOfType<LLVMFuncOp>()) {
461     if (!func.personality().hasValue())
462       return op.emitError(
463           "llvm.landingpad needs to be in a function with a personality");
464   }
465 
466   if (!op.cleanup() && op.getOperands().empty())
467     return op.emitError("landingpad instruction expects at least one clause or "
468                         "cleanup attribute");
469 
470   for (unsigned idx = 0, ie = op.getNumOperands(); idx < ie; idx++) {
471     value = op.getOperand(idx);
472     bool isFilter = value.getType().cast<LLVMType>().isArrayTy();
473     if (isFilter) {
474       // FIXME: Verify filter clauses when arrays are appropriately handled
475     } else {
476       // catch - global addresses only.
477       // Bitcast ops should have global addresses as their args.
478       if (auto bcOp = value.getDefiningOp<BitcastOp>()) {
479         if (auto addrOp = bcOp.arg().getDefiningOp<AddressOfOp>())
480           continue;
481         return op.emitError("constant clauses expected")
482                    .attachNote(bcOp.getLoc())
483                << "global addresses expected as operand to "
484                   "bitcast used in clauses for landingpad";
485       }
486       // NullOp and AddressOfOp allowed
487       if (value.getDefiningOp<NullOp>())
488         continue;
489       if (value.getDefiningOp<AddressOfOp>())
490         continue;
491       return op.emitError("clause #")
492              << idx << " is not a known constant - null, addressof, bitcast";
493     }
494   }
495   return success();
496 }
497 
printLandingpadOp(OpAsmPrinter & p,LandingpadOp & op)498 static void printLandingpadOp(OpAsmPrinter &p, LandingpadOp &op) {
499   p << op.getOperationName() << (op.cleanup() ? " cleanup " : " ");
500 
501   // Clauses
502   for (auto value : op.getOperands()) {
503     // Similar to llvm - if clause is an array type then it is filter
504     // clause else catch clause
505     bool isArrayTy = value.getType().cast<LLVMType>().isArrayTy();
506     p << '(' << (isArrayTy ? "filter " : "catch ") << value << " : "
507       << value.getType() << ") ";
508   }
509 
510   p.printOptionalAttrDict(op.getAttrs(), {"cleanup"});
511 
512   p << ": " << op.getType();
513 }
514 
515 /// <operation> ::= `llvm.landingpad` `cleanup`?
516 ///                 ((`catch` | `filter`) operand-type ssa-use)* attribute-dict?
parseLandingpadOp(OpAsmParser & parser,OperationState & result)517 static ParseResult parseLandingpadOp(OpAsmParser &parser,
518                                      OperationState &result) {
519   // Check for cleanup
520   if (succeeded(parser.parseOptionalKeyword("cleanup")))
521     result.addAttribute("cleanup", parser.getBuilder().getUnitAttr());
522 
523   // Parse clauses with types
524   while (succeeded(parser.parseOptionalLParen()) &&
525          (succeeded(parser.parseOptionalKeyword("filter")) ||
526           succeeded(parser.parseOptionalKeyword("catch")))) {
527     OpAsmParser::OperandType operand;
528     Type ty;
529     if (parser.parseOperand(operand) || parser.parseColon() ||
530         parser.parseType(ty) ||
531         parser.resolveOperand(operand, ty, result.operands) ||
532         parser.parseRParen())
533       return failure();
534   }
535 
536   Type type;
537   if (parser.parseColon() || parser.parseType(type))
538     return failure();
539 
540   result.addTypes(type);
541   return success();
542 }
543 
544 //===----------------------------------------------------------------------===//
545 // Verifying/Printing/parsing for LLVM::CallOp.
546 //===----------------------------------------------------------------------===//
547 
verify(CallOp & op)548 static LogicalResult verify(CallOp &op) {
549   if (op.getNumResults() > 1)
550     return op.emitOpError("must have 0 or 1 result");
551 
552   // Type for the callee, we'll get it differently depending if it is a direct
553   // or indirect call.
554   LLVMType fnType;
555 
556   bool isIndirect = false;
557 
558   // If this is an indirect call, the callee attribute is missing.
559   Optional<StringRef> calleeName = op.callee();
560   if (!calleeName) {
561     isIndirect = true;
562     if (!op.getNumOperands())
563       return op.emitOpError(
564           "must have either a `callee` attribute or at least an operand");
565     fnType = op.getOperand(0).getType().dyn_cast<LLVMType>();
566     if (!fnType)
567       return op.emitOpError("indirect call to a non-llvm type: ")
568              << op.getOperand(0).getType();
569     auto ptrType = fnType.dyn_cast<LLVMPointerType>();
570     if (!ptrType)
571       return op.emitOpError("indirect call expects a pointer as callee: ")
572              << fnType;
573     fnType = ptrType.getElementType();
574   } else {
575     Operation *callee = SymbolTable::lookupNearestSymbolFrom(op, *calleeName);
576     if (!callee)
577       return op.emitOpError()
578              << "'" << *calleeName
579              << "' does not reference a symbol in the current scope";
580     auto fn = dyn_cast<LLVMFuncOp>(callee);
581     if (!fn)
582       return op.emitOpError() << "'" << *calleeName
583                               << "' does not reference a valid LLVM function";
584 
585     fnType = fn.getType();
586   }
587   if (!fnType.isFunctionTy())
588     return op.emitOpError("callee does not have a functional type: ") << fnType;
589 
590   // Verify that the operand and result types match the callee.
591 
592   if (!fnType.isFunctionVarArg() &&
593       fnType.getFunctionNumParams() != (op.getNumOperands() - isIndirect))
594     return op.emitOpError()
595            << "incorrect number of operands ("
596            << (op.getNumOperands() - isIndirect)
597            << ") for callee (expecting: " << fnType.getFunctionNumParams()
598            << ")";
599 
600   if (fnType.getFunctionNumParams() > (op.getNumOperands() - isIndirect))
601     return op.emitOpError() << "incorrect number of operands ("
602                             << (op.getNumOperands() - isIndirect)
603                             << ") for varargs callee (expecting at least: "
604                             << fnType.getFunctionNumParams() << ")";
605 
606   for (unsigned i = 0, e = fnType.getFunctionNumParams(); i != e; ++i)
607     if (op.getOperand(i + isIndirect).getType() !=
608         fnType.getFunctionParamType(i))
609       return op.emitOpError() << "operand type mismatch for operand " << i
610                               << ": " << op.getOperand(i + isIndirect).getType()
611                               << " != " << fnType.getFunctionParamType(i);
612 
613   if (op.getNumResults() &&
614       op.getResult(0).getType() != fnType.getFunctionResultType())
615     return op.emitOpError()
616            << "result type mismatch: " << op.getResult(0).getType()
617            << " != " << fnType.getFunctionResultType();
618 
619   return success();
620 }
621 
printCallOp(OpAsmPrinter & p,CallOp & op)622 static void printCallOp(OpAsmPrinter &p, CallOp &op) {
623   auto callee = op.callee();
624   bool isDirect = callee.hasValue();
625 
626   // Print the direct callee if present as a function attribute, or an indirect
627   // callee (first operand) otherwise.
628   p << op.getOperationName() << ' ';
629   if (isDirect)
630     p.printSymbolName(callee.getValue());
631   else
632     p << op.getOperand(0);
633 
634   auto args = op.getOperands().drop_front(isDirect ? 0 : 1);
635   p << '(' << args << ')';
636   p.printOptionalAttrDict(op.getAttrs(), {"callee"});
637 
638   // Reconstruct the function MLIR function type from operand and result types.
639   p << " : "
640     << FunctionType::get(args.getTypes(), op.getResultTypes(), op.getContext());
641 }
642 
643 // <operation> ::= `llvm.call` (function-id | ssa-use) `(` ssa-use-list `)`
644 //                 attribute-dict? `:` function-type
parseCallOp(OpAsmParser & parser,OperationState & result)645 static ParseResult parseCallOp(OpAsmParser &parser, OperationState &result) {
646   SmallVector<OpAsmParser::OperandType, 8> operands;
647   Type type;
648   SymbolRefAttr funcAttr;
649   llvm::SMLoc trailingTypeLoc;
650 
651   // Parse an operand list that will, in practice, contain 0 or 1 operand.  In
652   // case of an indirect call, there will be 1 operand before `(`.  In case of a
653   // direct call, there will be no operands and the parser will stop at the
654   // function identifier without complaining.
655   if (parser.parseOperandList(operands))
656     return failure();
657   bool isDirect = operands.empty();
658 
659   // Optionally parse a function identifier.
660   if (isDirect)
661     if (parser.parseAttribute(funcAttr, "callee", result.attributes))
662       return failure();
663 
664   if (parser.parseOperandList(operands, OpAsmParser::Delimiter::Paren) ||
665       parser.parseOptionalAttrDict(result.attributes) || parser.parseColon() ||
666       parser.getCurrentLocation(&trailingTypeLoc) || parser.parseType(type))
667     return failure();
668 
669   auto funcType = type.dyn_cast<FunctionType>();
670   if (!funcType)
671     return parser.emitError(trailingTypeLoc, "expected function type");
672   if (isDirect) {
673     // Make sure types match.
674     if (parser.resolveOperands(operands, funcType.getInputs(),
675                                parser.getNameLoc(), result.operands))
676       return failure();
677     result.addTypes(funcType.getResults());
678   } else {
679     // Construct the LLVM IR Dialect function type that the first operand
680     // should match.
681     if (funcType.getNumResults() > 1)
682       return parser.emitError(trailingTypeLoc,
683                               "expected function with 0 or 1 result");
684 
685     Builder &builder = parser.getBuilder();
686     LLVM::LLVMType llvmResultType;
687     if (funcType.getNumResults() == 0) {
688       llvmResultType = LLVM::LLVMType::getVoidTy(builder.getContext());
689     } else {
690       llvmResultType = funcType.getResult(0).dyn_cast<LLVM::LLVMType>();
691       if (!llvmResultType)
692         return parser.emitError(trailingTypeLoc,
693                                 "expected result to have LLVM type");
694     }
695 
696     SmallVector<LLVM::LLVMType, 8> argTypes;
697     argTypes.reserve(funcType.getNumInputs());
698     for (int i = 0, e = funcType.getNumInputs(); i < e; ++i) {
699       auto argType = funcType.getInput(i).dyn_cast<LLVM::LLVMType>();
700       if (!argType)
701         return parser.emitError(trailingTypeLoc,
702                                 "expected LLVM types as inputs");
703       argTypes.push_back(argType);
704     }
705     auto llvmFuncType = LLVM::LLVMType::getFunctionTy(llvmResultType, argTypes,
706                                                       /*isVarArg=*/false);
707     auto wrappedFuncType = llvmFuncType.getPointerTo();
708 
709     auto funcArguments =
710         ArrayRef<OpAsmParser::OperandType>(operands).drop_front();
711 
712     // Make sure that the first operand (indirect callee) matches the wrapped
713     // LLVM IR function type, and that the types of the other call operands
714     // match the types of the function arguments.
715     if (parser.resolveOperand(operands[0], wrappedFuncType, result.operands) ||
716         parser.resolveOperands(funcArguments, funcType.getInputs(),
717                                parser.getNameLoc(), result.operands))
718       return failure();
719 
720     result.addTypes(llvmResultType);
721   }
722 
723   return success();
724 }
725 
726 //===----------------------------------------------------------------------===//
727 // Printing/parsing for LLVM::ExtractElementOp.
728 //===----------------------------------------------------------------------===//
729 // Expects vector to be of wrapped LLVM vector type and position to be of
730 // wrapped LLVM i32 type.
build(OpBuilder & b,OperationState & result,Value vector,Value position,ArrayRef<NamedAttribute> attrs)731 void LLVM::ExtractElementOp::build(OpBuilder &b, OperationState &result,
732                                    Value vector, Value position,
733                                    ArrayRef<NamedAttribute> attrs) {
734   auto wrappedVectorType = vector.getType().cast<LLVM::LLVMType>();
735   auto llvmType = wrappedVectorType.getVectorElementType();
736   build(b, result, llvmType, vector, position);
737   result.addAttributes(attrs);
738 }
739 
printExtractElementOp(OpAsmPrinter & p,ExtractElementOp & op)740 static void printExtractElementOp(OpAsmPrinter &p, ExtractElementOp &op) {
741   p << op.getOperationName() << ' ' << op.vector() << "[" << op.position()
742     << " : " << op.position().getType() << "]";
743   p.printOptionalAttrDict(op.getAttrs());
744   p << " : " << op.vector().getType();
745 }
746 
747 // <operation> ::= `llvm.extractelement` ssa-use `, ` ssa-use
748 //                 attribute-dict? `:` type
parseExtractElementOp(OpAsmParser & parser,OperationState & result)749 static ParseResult parseExtractElementOp(OpAsmParser &parser,
750                                          OperationState &result) {
751   llvm::SMLoc loc;
752   OpAsmParser::OperandType vector, position;
753   Type type, positionType;
754   if (parser.getCurrentLocation(&loc) || parser.parseOperand(vector) ||
755       parser.parseLSquare() || parser.parseOperand(position) ||
756       parser.parseColonType(positionType) || parser.parseRSquare() ||
757       parser.parseOptionalAttrDict(result.attributes) ||
758       parser.parseColonType(type) ||
759       parser.resolveOperand(vector, type, result.operands) ||
760       parser.resolveOperand(position, positionType, result.operands))
761     return failure();
762   auto wrappedVectorType = type.dyn_cast<LLVM::LLVMType>();
763   if (!wrappedVectorType || !wrappedVectorType.isVectorTy())
764     return parser.emitError(
765         loc, "expected LLVM IR dialect vector type for operand #1");
766   result.addTypes(wrappedVectorType.getVectorElementType());
767   return success();
768 }
769 
770 //===----------------------------------------------------------------------===//
771 // Printing/parsing for LLVM::ExtractValueOp.
772 //===----------------------------------------------------------------------===//
773 
printExtractValueOp(OpAsmPrinter & p,ExtractValueOp & op)774 static void printExtractValueOp(OpAsmPrinter &p, ExtractValueOp &op) {
775   p << op.getOperationName() << ' ' << op.container() << op.position();
776   p.printOptionalAttrDict(op.getAttrs(), {"position"});
777   p << " : " << op.container().getType();
778 }
779 
780 // Extract the type at `position` in the wrapped LLVM IR aggregate type
781 // `containerType`.  Position is an integer array attribute where each value
782 // is a zero-based position of the element in the aggregate type.  Return the
783 // resulting type wrapped in MLIR, or nullptr on error.
getInsertExtractValueElementType(OpAsmParser & parser,Type containerType,ArrayAttr positionAttr,llvm::SMLoc attributeLoc,llvm::SMLoc typeLoc)784 static LLVM::LLVMType getInsertExtractValueElementType(OpAsmParser &parser,
785                                                        Type containerType,
786                                                        ArrayAttr positionAttr,
787                                                        llvm::SMLoc attributeLoc,
788                                                        llvm::SMLoc typeLoc) {
789   auto wrappedContainerType = containerType.dyn_cast<LLVM::LLVMType>();
790   if (!wrappedContainerType)
791     return parser.emitError(typeLoc, "expected LLVM IR Dialect type"), nullptr;
792 
793   // Infer the element type from the structure type: iteratively step inside the
794   // type by taking the element type, indexed by the position attribute for
795   // structures.  Check the position index before accessing, it is supposed to
796   // be in bounds.
797   for (Attribute subAttr : positionAttr) {
798     auto positionElementAttr = subAttr.dyn_cast<IntegerAttr>();
799     if (!positionElementAttr)
800       return parser.emitError(attributeLoc,
801                               "expected an array of integer literals"),
802              nullptr;
803     int position = positionElementAttr.getInt();
804     if (wrappedContainerType.isArrayTy()) {
805       if (position < 0 || static_cast<unsigned>(position) >=
806                               wrappedContainerType.getArrayNumElements())
807         return parser.emitError(attributeLoc, "position out of bounds"),
808                nullptr;
809       wrappedContainerType = wrappedContainerType.getArrayElementType();
810     } else if (wrappedContainerType.isStructTy()) {
811       if (position < 0 || static_cast<unsigned>(position) >=
812                               wrappedContainerType.getStructNumElements())
813         return parser.emitError(attributeLoc, "position out of bounds"),
814                nullptr;
815       wrappedContainerType =
816           wrappedContainerType.getStructElementType(position);
817     } else {
818       return parser.emitError(typeLoc,
819                               "expected wrapped LLVM IR structure/array type"),
820              nullptr;
821     }
822   }
823   return wrappedContainerType;
824 }
825 
826 // <operation> ::= `llvm.extractvalue` ssa-use
827 //                 `[` integer-literal (`,` integer-literal)* `]`
828 //                 attribute-dict? `:` type
parseExtractValueOp(OpAsmParser & parser,OperationState & result)829 static ParseResult parseExtractValueOp(OpAsmParser &parser,
830                                        OperationState &result) {
831   OpAsmParser::OperandType container;
832   Type containerType;
833   ArrayAttr positionAttr;
834   llvm::SMLoc attributeLoc, trailingTypeLoc;
835 
836   if (parser.parseOperand(container) ||
837       parser.getCurrentLocation(&attributeLoc) ||
838       parser.parseAttribute(positionAttr, "position", result.attributes) ||
839       parser.parseOptionalAttrDict(result.attributes) || parser.parseColon() ||
840       parser.getCurrentLocation(&trailingTypeLoc) ||
841       parser.parseType(containerType) ||
842       parser.resolveOperand(container, containerType, result.operands))
843     return failure();
844 
845   auto elementType = getInsertExtractValueElementType(
846       parser, containerType, positionAttr, attributeLoc, trailingTypeLoc);
847   if (!elementType)
848     return failure();
849 
850   result.addTypes(elementType);
851   return success();
852 }
853 
854 //===----------------------------------------------------------------------===//
855 // Printing/parsing for LLVM::InsertElementOp.
856 //===----------------------------------------------------------------------===//
857 
printInsertElementOp(OpAsmPrinter & p,InsertElementOp & op)858 static void printInsertElementOp(OpAsmPrinter &p, InsertElementOp &op) {
859   p << op.getOperationName() << ' ' << op.value() << ", " << op.vector() << "["
860     << op.position() << " : " << op.position().getType() << "]";
861   p.printOptionalAttrDict(op.getAttrs());
862   p << " : " << op.vector().getType();
863 }
864 
865 // <operation> ::= `llvm.insertelement` ssa-use `,` ssa-use `,` ssa-use
866 //                 attribute-dict? `:` type
parseInsertElementOp(OpAsmParser & parser,OperationState & result)867 static ParseResult parseInsertElementOp(OpAsmParser &parser,
868                                         OperationState &result) {
869   llvm::SMLoc loc;
870   OpAsmParser::OperandType vector, value, position;
871   Type vectorType, positionType;
872   if (parser.getCurrentLocation(&loc) || parser.parseOperand(value) ||
873       parser.parseComma() || parser.parseOperand(vector) ||
874       parser.parseLSquare() || parser.parseOperand(position) ||
875       parser.parseColonType(positionType) || parser.parseRSquare() ||
876       parser.parseOptionalAttrDict(result.attributes) ||
877       parser.parseColonType(vectorType))
878     return failure();
879 
880   auto wrappedVectorType = vectorType.dyn_cast<LLVM::LLVMType>();
881   if (!wrappedVectorType || !wrappedVectorType.isVectorTy())
882     return parser.emitError(
883         loc, "expected LLVM IR dialect vector type for operand #1");
884   auto valueType = wrappedVectorType.getVectorElementType();
885   if (!valueType)
886     return failure();
887 
888   if (parser.resolveOperand(vector, vectorType, result.operands) ||
889       parser.resolveOperand(value, valueType, result.operands) ||
890       parser.resolveOperand(position, positionType, result.operands))
891     return failure();
892 
893   result.addTypes(vectorType);
894   return success();
895 }
896 
897 //===----------------------------------------------------------------------===//
898 // Printing/parsing for LLVM::InsertValueOp.
899 //===----------------------------------------------------------------------===//
900 
printInsertValueOp(OpAsmPrinter & p,InsertValueOp & op)901 static void printInsertValueOp(OpAsmPrinter &p, InsertValueOp &op) {
902   p << op.getOperationName() << ' ' << op.value() << ", " << op.container()
903     << op.position();
904   p.printOptionalAttrDict(op.getAttrs(), {"position"});
905   p << " : " << op.container().getType();
906 }
907 
908 // <operation> ::= `llvm.insertvaluevalue` ssa-use `,` ssa-use
909 //                 `[` integer-literal (`,` integer-literal)* `]`
910 //                 attribute-dict? `:` type
parseInsertValueOp(OpAsmParser & parser,OperationState & result)911 static ParseResult parseInsertValueOp(OpAsmParser &parser,
912                                       OperationState &result) {
913   OpAsmParser::OperandType container, value;
914   Type containerType;
915   ArrayAttr positionAttr;
916   llvm::SMLoc attributeLoc, trailingTypeLoc;
917 
918   if (parser.parseOperand(value) || parser.parseComma() ||
919       parser.parseOperand(container) ||
920       parser.getCurrentLocation(&attributeLoc) ||
921       parser.parseAttribute(positionAttr, "position", result.attributes) ||
922       parser.parseOptionalAttrDict(result.attributes) || parser.parseColon() ||
923       parser.getCurrentLocation(&trailingTypeLoc) ||
924       parser.parseType(containerType))
925     return failure();
926 
927   auto valueType = getInsertExtractValueElementType(
928       parser, containerType, positionAttr, attributeLoc, trailingTypeLoc);
929   if (!valueType)
930     return failure();
931 
932   if (parser.resolveOperand(container, containerType, result.operands) ||
933       parser.resolveOperand(value, valueType, result.operands))
934     return failure();
935 
936   result.addTypes(containerType);
937   return success();
938 }
939 
940 //===----------------------------------------------------------------------===//
941 // Printing/parsing for LLVM::ReturnOp.
942 //===----------------------------------------------------------------------===//
943 
printReturnOp(OpAsmPrinter & p,ReturnOp & op)944 static void printReturnOp(OpAsmPrinter &p, ReturnOp &op) {
945   p << op.getOperationName();
946   p.printOptionalAttrDict(op.getAttrs());
947   assert(op.getNumOperands() <= 1);
948 
949   if (op.getNumOperands() == 0)
950     return;
951 
952   p << ' ' << op.getOperand(0) << " : " << op.getOperand(0).getType();
953 }
954 
955 // <operation> ::= `llvm.return` ssa-use-list attribute-dict? `:`
956 //                 type-list-no-parens
parseReturnOp(OpAsmParser & parser,OperationState & result)957 static ParseResult parseReturnOp(OpAsmParser &parser, OperationState &result) {
958   SmallVector<OpAsmParser::OperandType, 1> operands;
959   Type type;
960 
961   if (parser.parseOperandList(operands) ||
962       parser.parseOptionalAttrDict(result.attributes))
963     return failure();
964   if (operands.empty())
965     return success();
966 
967   if (parser.parseColonType(type) ||
968       parser.resolveOperand(operands[0], type, result.operands))
969     return failure();
970   return success();
971 }
972 
973 //===----------------------------------------------------------------------===//
974 // Verifier for LLVM::AddressOfOp.
975 //===----------------------------------------------------------------------===//
976 
977 template <typename OpTy>
lookupSymbolInModule(Operation * parent,StringRef name)978 static OpTy lookupSymbolInModule(Operation *parent, StringRef name) {
979   Operation *module = parent;
980   while (module && !satisfiesLLVMModule(module))
981     module = module->getParentOp();
982   assert(module && "unexpected operation outside of a module");
983   return dyn_cast_or_null<OpTy>(
984       mlir::SymbolTable::lookupSymbolIn(module, name));
985 }
986 
getGlobal()987 GlobalOp AddressOfOp::getGlobal() {
988   return lookupSymbolInModule<LLVM::GlobalOp>((*this)->getParentOp(),
989                                               global_name());
990 }
991 
getFunction()992 LLVMFuncOp AddressOfOp::getFunction() {
993   return lookupSymbolInModule<LLVM::LLVMFuncOp>((*this)->getParentOp(),
994                                                 global_name());
995 }
996 
verify(AddressOfOp op)997 static LogicalResult verify(AddressOfOp op) {
998   auto global = op.getGlobal();
999   auto function = op.getFunction();
1000   if (!global && !function)
1001     return op.emitOpError(
1002         "must reference a global defined by 'llvm.mlir.global' or 'llvm.func'");
1003 
1004   if (global && global.getType().getPointerTo(global.addr_space()) !=
1005                     op.getResult().getType())
1006     return op.emitOpError(
1007         "the type must be a pointer to the type of the referenced global");
1008 
1009   if (function && function.getType().getPointerTo() != op.getResult().getType())
1010     return op.emitOpError(
1011         "the type must be a pointer to the type of the referenced function");
1012 
1013   return success();
1014 }
1015 
1016 //===----------------------------------------------------------------------===//
1017 // Builder, printer and verifier for LLVM::GlobalOp.
1018 //===----------------------------------------------------------------------===//
1019 
1020 /// Returns the name used for the linkage attribute. This *must* correspond to
1021 /// the name of the attribute in ODS.
getLinkageAttrName()1022 static StringRef getLinkageAttrName() { return "linkage"; }
1023 
build(OpBuilder & builder,OperationState & result,LLVMType type,bool isConstant,Linkage linkage,StringRef name,Attribute value,unsigned addrSpace,ArrayRef<NamedAttribute> attrs)1024 void GlobalOp::build(OpBuilder &builder, OperationState &result, LLVMType type,
1025                      bool isConstant, Linkage linkage, StringRef name,
1026                      Attribute value, unsigned addrSpace,
1027                      ArrayRef<NamedAttribute> attrs) {
1028   result.addAttribute(SymbolTable::getSymbolAttrName(),
1029                       builder.getStringAttr(name));
1030   result.addAttribute("type", TypeAttr::get(type));
1031   if (isConstant)
1032     result.addAttribute("constant", builder.getUnitAttr());
1033   if (value)
1034     result.addAttribute("value", value);
1035   result.addAttribute(getLinkageAttrName(),
1036                       builder.getI64IntegerAttr(static_cast<int64_t>(linkage)));
1037   if (addrSpace != 0)
1038     result.addAttribute("addr_space", builder.getI32IntegerAttr(addrSpace));
1039   result.attributes.append(attrs.begin(), attrs.end());
1040   result.addRegion();
1041 }
1042 
printGlobalOp(OpAsmPrinter & p,GlobalOp op)1043 static void printGlobalOp(OpAsmPrinter &p, GlobalOp op) {
1044   p << op.getOperationName() << ' ' << stringifyLinkage(op.linkage()) << ' ';
1045   if (op.constant())
1046     p << "constant ";
1047   p.printSymbolName(op.sym_name());
1048   p << '(';
1049   if (auto value = op.getValueOrNull())
1050     p.printAttribute(value);
1051   p << ')';
1052   p.printOptionalAttrDict(op.getAttrs(),
1053                           {SymbolTable::getSymbolAttrName(), "type", "constant",
1054                            "value", getLinkageAttrName()});
1055 
1056   // Print the trailing type unless it's a string global.
1057   if (op.getValueOrNull().dyn_cast_or_null<StringAttr>())
1058     return;
1059   p << " : " << op.type();
1060 
1061   Region &initializer = op.getInitializerRegion();
1062   if (!initializer.empty())
1063     p.printRegion(initializer, /*printEntryBlockArgs=*/false);
1064 }
1065 
1066 //===----------------------------------------------------------------------===//
1067 // Verifier for LLVM::DialectCastOp.
1068 //===----------------------------------------------------------------------===//
1069 
verify(DialectCastOp op)1070 static LogicalResult verify(DialectCastOp op) {
1071   auto verifyMLIRCastType = [&op](Type type) -> LogicalResult {
1072     if (auto llvmType = type.dyn_cast<LLVM::LLVMType>()) {
1073       if (llvmType.isVectorTy())
1074         llvmType = llvmType.getVectorElementType();
1075       if (llvmType.isIntegerTy() || llvmType.isBFloatTy() ||
1076           llvmType.isHalfTy() || llvmType.isFloatTy() ||
1077           llvmType.isDoubleTy()) {
1078         return success();
1079       }
1080       return op.emitOpError("type must be non-index integer types, float "
1081                             "types, or vector of mentioned types.");
1082     }
1083     if (auto vectorType = type.dyn_cast<VectorType>()) {
1084       if (vectorType.getShape().size() > 1)
1085         return op.emitOpError("only 1-d vector is allowed");
1086       type = vectorType.getElementType();
1087     }
1088     if (type.isSignlessIntOrFloat())
1089       return success();
1090     // Note that memrefs are not supported. We currently don't have a use case
1091     // for it, but even if we do, there are challenges:
1092     // * if we allow memrefs to cast from/to memref descriptors, then the
1093     // semantics of the cast op depends on the implementation detail of the
1094     // descriptor.
1095     // * if we allow memrefs to cast from/to bare pointers, some users might
1096     // alternatively want metadata that only present in the descriptor.
1097     //
1098     // TODO: re-evaluate the memref cast design when it's needed.
1099     return op.emitOpError("type must be non-index integer types, float types, "
1100                           "or vector of mentioned types.");
1101   };
1102   return failure(failed(verifyMLIRCastType(op.in().getType())) ||
1103                  failed(verifyMLIRCastType(op.getType())));
1104 }
1105 
1106 // Parses one of the keywords provided in the list `keywords` and returns the
1107 // position of the parsed keyword in the list. If none of the keywords from the
1108 // list is parsed, returns -1.
parseOptionalKeywordAlternative(OpAsmParser & parser,ArrayRef<StringRef> keywords)1109 static int parseOptionalKeywordAlternative(OpAsmParser &parser,
1110                                            ArrayRef<StringRef> keywords) {
1111   for (auto en : llvm::enumerate(keywords)) {
1112     if (succeeded(parser.parseOptionalKeyword(en.value())))
1113       return en.index();
1114   }
1115   return -1;
1116 }
1117 
1118 namespace {
1119 template <typename Ty> struct EnumTraits {};
1120 
1121 #define REGISTER_ENUM_TYPE(Ty)                                                 \
1122   template <> struct EnumTraits<Ty> {                                          \
1123     static StringRef stringify(Ty value) { return stringify##Ty(value); }      \
1124     static unsigned getMaxEnumVal() { return getMaxEnumValFor##Ty(); }         \
1125   }
1126 
1127 REGISTER_ENUM_TYPE(Linkage);
1128 } // end namespace
1129 
1130 template <typename EnumTy>
parseOptionalLLVMKeyword(OpAsmParser & parser,OperationState & result,StringRef name)1131 static ParseResult parseOptionalLLVMKeyword(OpAsmParser &parser,
1132                                             OperationState &result,
1133                                             StringRef name) {
1134   SmallVector<StringRef, 10> names;
1135   for (unsigned i = 0, e = getMaxEnumValForLinkage(); i <= e; ++i)
1136     names.push_back(EnumTraits<EnumTy>::stringify(static_cast<EnumTy>(i)));
1137 
1138   int index = parseOptionalKeywordAlternative(parser, names);
1139   if (index == -1)
1140     return failure();
1141   result.addAttribute(name, parser.getBuilder().getI64IntegerAttr(index));
1142   return success();
1143 }
1144 
1145 // operation ::= `llvm.mlir.global` linkage? `constant`? `@` identifier
1146 //               `(` attribute? `)` attribute-list? (`:` type)? region?
1147 //
1148 // The type can be omitted for string attributes, in which case it will be
1149 // inferred from the value of the string as [strlen(value) x i8].
parseGlobalOp(OpAsmParser & parser,OperationState & result)1150 static ParseResult parseGlobalOp(OpAsmParser &parser, OperationState &result) {
1151   if (failed(parseOptionalLLVMKeyword<Linkage>(parser, result,
1152                                                getLinkageAttrName())))
1153     result.addAttribute(getLinkageAttrName(),
1154                         parser.getBuilder().getI64IntegerAttr(
1155                             static_cast<int64_t>(LLVM::Linkage::External)));
1156 
1157   if (succeeded(parser.parseOptionalKeyword("constant")))
1158     result.addAttribute("constant", parser.getBuilder().getUnitAttr());
1159 
1160   StringAttr name;
1161   if (parser.parseSymbolName(name, SymbolTable::getSymbolAttrName(),
1162                              result.attributes) ||
1163       parser.parseLParen())
1164     return failure();
1165 
1166   Attribute value;
1167   if (parser.parseOptionalRParen()) {
1168     if (parser.parseAttribute(value, "value", result.attributes) ||
1169         parser.parseRParen())
1170       return failure();
1171   }
1172 
1173   SmallVector<Type, 1> types;
1174   if (parser.parseOptionalAttrDict(result.attributes) ||
1175       parser.parseOptionalColonTypeList(types))
1176     return failure();
1177 
1178   if (types.size() > 1)
1179     return parser.emitError(parser.getNameLoc(), "expected zero or one type");
1180 
1181   Region &initRegion = *result.addRegion();
1182   if (types.empty()) {
1183     if (auto strAttr = value.dyn_cast_or_null<StringAttr>()) {
1184       MLIRContext *context = parser.getBuilder().getContext();
1185       auto arrayType = LLVM::LLVMType::getArrayTy(
1186           LLVM::LLVMType::getInt8Ty(context), strAttr.getValue().size());
1187       types.push_back(arrayType);
1188     } else {
1189       return parser.emitError(parser.getNameLoc(),
1190                               "type can only be omitted for string globals");
1191     }
1192   } else {
1193     OptionalParseResult parseResult =
1194         parser.parseOptionalRegion(initRegion, /*arguments=*/{},
1195                                    /*argTypes=*/{});
1196     if (parseResult.hasValue() && failed(*parseResult))
1197       return failure();
1198   }
1199 
1200   result.addAttribute("type", TypeAttr::get(types[0]));
1201   return success();
1202 }
1203 
verify(GlobalOp op)1204 static LogicalResult verify(GlobalOp op) {
1205   if (!LLVMPointerType::isValidElementType(op.getType()))
1206     return op.emitOpError(
1207         "expects type to be a valid element type for an LLVM pointer");
1208   if (op->getParentOp() && !satisfiesLLVMModule(op->getParentOp()))
1209     return op.emitOpError("must appear at the module level");
1210 
1211   if (auto strAttr = op.getValueOrNull().dyn_cast_or_null<StringAttr>()) {
1212     auto type = op.getType();
1213     if (!type.isArrayTy() || !type.getArrayElementType().isIntegerTy(8) ||
1214         type.getArrayNumElements() != strAttr.getValue().size())
1215       return op.emitOpError(
1216           "requires an i8 array type of the length equal to that of the string "
1217           "attribute");
1218   }
1219 
1220   if (Block *b = op.getInitializerBlock()) {
1221     ReturnOp ret = cast<ReturnOp>(b->getTerminator());
1222     if (ret.operand_type_begin() == ret.operand_type_end())
1223       return op.emitOpError("initializer region cannot return void");
1224     if (*ret.operand_type_begin() != op.getType())
1225       return op.emitOpError("initializer region type ")
1226              << *ret.operand_type_begin() << " does not match global type "
1227              << op.getType();
1228 
1229     if (op.getValueOrNull())
1230       return op.emitOpError("cannot have both initializer value and region");
1231   }
1232   return success();
1233 }
1234 
1235 //===----------------------------------------------------------------------===//
1236 // Printing/parsing for LLVM::ShuffleVectorOp.
1237 //===----------------------------------------------------------------------===//
1238 // Expects vector to be of wrapped LLVM vector type and position to be of
1239 // wrapped LLVM i32 type.
build(OpBuilder & b,OperationState & result,Value v1,Value v2,ArrayAttr mask,ArrayRef<NamedAttribute> attrs)1240 void LLVM::ShuffleVectorOp::build(OpBuilder &b, OperationState &result,
1241                                   Value v1, Value v2, ArrayAttr mask,
1242                                   ArrayRef<NamedAttribute> attrs) {
1243   auto wrappedContainerType1 = v1.getType().cast<LLVM::LLVMType>();
1244   auto vType = LLVMType::getVectorTy(
1245       wrappedContainerType1.getVectorElementType(), mask.size());
1246   build(b, result, vType, v1, v2, mask);
1247   result.addAttributes(attrs);
1248 }
1249 
printShuffleVectorOp(OpAsmPrinter & p,ShuffleVectorOp & op)1250 static void printShuffleVectorOp(OpAsmPrinter &p, ShuffleVectorOp &op) {
1251   p << op.getOperationName() << ' ' << op.v1() << ", " << op.v2() << " "
1252     << op.mask();
1253   p.printOptionalAttrDict(op.getAttrs(), {"mask"});
1254   p << " : " << op.v1().getType() << ", " << op.v2().getType();
1255 }
1256 
1257 // <operation> ::= `llvm.shufflevector` ssa-use `, ` ssa-use
1258 //                 `[` integer-literal (`,` integer-literal)* `]`
1259 //                 attribute-dict? `:` type
parseShuffleVectorOp(OpAsmParser & parser,OperationState & result)1260 static ParseResult parseShuffleVectorOp(OpAsmParser &parser,
1261                                         OperationState &result) {
1262   llvm::SMLoc loc;
1263   OpAsmParser::OperandType v1, v2;
1264   ArrayAttr maskAttr;
1265   Type typeV1, typeV2;
1266   if (parser.getCurrentLocation(&loc) || parser.parseOperand(v1) ||
1267       parser.parseComma() || parser.parseOperand(v2) ||
1268       parser.parseAttribute(maskAttr, "mask", result.attributes) ||
1269       parser.parseOptionalAttrDict(result.attributes) ||
1270       parser.parseColonType(typeV1) || parser.parseComma() ||
1271       parser.parseType(typeV2) ||
1272       parser.resolveOperand(v1, typeV1, result.operands) ||
1273       parser.resolveOperand(v2, typeV2, result.operands))
1274     return failure();
1275   auto wrappedContainerType1 = typeV1.dyn_cast<LLVM::LLVMType>();
1276   if (!wrappedContainerType1 || !wrappedContainerType1.isVectorTy())
1277     return parser.emitError(
1278         loc, "expected LLVM IR dialect vector type for operand #1");
1279   auto vType = LLVMType::getVectorTy(
1280       wrappedContainerType1.getVectorElementType(), maskAttr.size());
1281   result.addTypes(vType);
1282   return success();
1283 }
1284 
1285 //===----------------------------------------------------------------------===//
1286 // Implementations for LLVM::LLVMFuncOp.
1287 //===----------------------------------------------------------------------===//
1288 
1289 // Add the entry block to the function.
addEntryBlock()1290 Block *LLVMFuncOp::addEntryBlock() {
1291   assert(empty() && "function already has an entry block");
1292   assert(!isVarArg() && "unimplemented: non-external variadic functions");
1293 
1294   auto *entry = new Block;
1295   push_back(entry);
1296 
1297   LLVMType type = getType();
1298   for (unsigned i = 0, e = type.getFunctionNumParams(); i < e; ++i)
1299     entry->addArgument(type.getFunctionParamType(i));
1300   return entry;
1301 }
1302 
build(OpBuilder & builder,OperationState & result,StringRef name,LLVMType type,LLVM::Linkage linkage,ArrayRef<NamedAttribute> attrs,ArrayRef<MutableDictionaryAttr> argAttrs)1303 void LLVMFuncOp::build(OpBuilder &builder, OperationState &result,
1304                        StringRef name, LLVMType type, LLVM::Linkage linkage,
1305                        ArrayRef<NamedAttribute> attrs,
1306                        ArrayRef<MutableDictionaryAttr> argAttrs) {
1307   result.addRegion();
1308   result.addAttribute(SymbolTable::getSymbolAttrName(),
1309                       builder.getStringAttr(name));
1310   result.addAttribute("type", TypeAttr::get(type));
1311   result.addAttribute(getLinkageAttrName(),
1312                       builder.getI64IntegerAttr(static_cast<int64_t>(linkage)));
1313   result.attributes.append(attrs.begin(), attrs.end());
1314   if (argAttrs.empty())
1315     return;
1316 
1317   unsigned numInputs = type.getFunctionNumParams();
1318   assert(numInputs == argAttrs.size() &&
1319          "expected as many argument attribute lists as arguments");
1320   SmallString<8> argAttrName;
1321   for (unsigned i = 0; i < numInputs; ++i)
1322     if (auto argDict = argAttrs[i].getDictionary(builder.getContext()))
1323       result.addAttribute(getArgAttrName(i, argAttrName), argDict);
1324 }
1325 
1326 // Builds an LLVM function type from the given lists of input and output types.
1327 // Returns a null type if any of the types provided are non-LLVM types, or if
1328 // there is more than one output type.
buildLLVMFunctionType(OpAsmParser & parser,llvm::SMLoc loc,ArrayRef<Type> inputs,ArrayRef<Type> outputs,impl::VariadicFlag variadicFlag)1329 static Type buildLLVMFunctionType(OpAsmParser &parser, llvm::SMLoc loc,
1330                                   ArrayRef<Type> inputs, ArrayRef<Type> outputs,
1331                                   impl::VariadicFlag variadicFlag) {
1332   Builder &b = parser.getBuilder();
1333   if (outputs.size() > 1) {
1334     parser.emitError(loc, "failed to construct function type: expected zero or "
1335                           "one function result");
1336     return {};
1337   }
1338 
1339   // Convert inputs to LLVM types, exit early on error.
1340   SmallVector<LLVMType, 4> llvmInputs;
1341   for (auto t : inputs) {
1342     auto llvmTy = t.dyn_cast<LLVMType>();
1343     if (!llvmTy) {
1344       parser.emitError(loc, "failed to construct function type: expected LLVM "
1345                             "type for function arguments");
1346       return {};
1347     }
1348     llvmInputs.push_back(llvmTy);
1349   }
1350 
1351   // No output is denoted as "void" in LLVM type system.
1352   LLVMType llvmOutput = outputs.empty() ? LLVMType::getVoidTy(b.getContext())
1353                                         : outputs.front().dyn_cast<LLVMType>();
1354   if (!llvmOutput) {
1355     parser.emitError(loc, "failed to construct function type: expected LLVM "
1356                           "type for function results");
1357     return {};
1358   }
1359   return LLVMType::getFunctionTy(llvmOutput, llvmInputs,
1360                                  variadicFlag.isVariadic());
1361 }
1362 
1363 // Parses an LLVM function.
1364 //
1365 // operation ::= `llvm.func` linkage? function-signature function-attributes?
1366 //               function-body
1367 //
parseLLVMFuncOp(OpAsmParser & parser,OperationState & result)1368 static ParseResult parseLLVMFuncOp(OpAsmParser &parser,
1369                                    OperationState &result) {
1370   // Default to external linkage if no keyword is provided.
1371   if (failed(parseOptionalLLVMKeyword<Linkage>(parser, result,
1372                                                getLinkageAttrName())))
1373     result.addAttribute(getLinkageAttrName(),
1374                         parser.getBuilder().getI64IntegerAttr(
1375                             static_cast<int64_t>(LLVM::Linkage::External)));
1376 
1377   StringAttr nameAttr;
1378   SmallVector<OpAsmParser::OperandType, 8> entryArgs;
1379   SmallVector<NamedAttrList, 1> argAttrs;
1380   SmallVector<NamedAttrList, 1> resultAttrs;
1381   SmallVector<Type, 8> argTypes;
1382   SmallVector<Type, 4> resultTypes;
1383   bool isVariadic;
1384 
1385   auto signatureLocation = parser.getCurrentLocation();
1386   if (parser.parseSymbolName(nameAttr, SymbolTable::getSymbolAttrName(),
1387                              result.attributes) ||
1388       impl::parseFunctionSignature(parser, /*allowVariadic=*/true, entryArgs,
1389                                    argTypes, argAttrs, isVariadic, resultTypes,
1390                                    resultAttrs))
1391     return failure();
1392 
1393   auto type =
1394       buildLLVMFunctionType(parser, signatureLocation, argTypes, resultTypes,
1395                             impl::VariadicFlag(isVariadic));
1396   if (!type)
1397     return failure();
1398   result.addAttribute(impl::getTypeAttrName(), TypeAttr::get(type));
1399 
1400   if (failed(parser.parseOptionalAttrDictWithKeyword(result.attributes)))
1401     return failure();
1402   impl::addArgAndResultAttrs(parser.getBuilder(), result, argAttrs,
1403                              resultAttrs);
1404 
1405   auto *body = result.addRegion();
1406   OptionalParseResult parseResult = parser.parseOptionalRegion(
1407       *body, entryArgs, entryArgs.empty() ? ArrayRef<Type>() : argTypes);
1408   return failure(parseResult.hasValue() && failed(*parseResult));
1409 }
1410 
1411 // Print the LLVMFuncOp. Collects argument and result types and passes them to
1412 // helper functions. Drops "void" result since it cannot be parsed back. Skips
1413 // the external linkage since it is the default value.
printLLVMFuncOp(OpAsmPrinter & p,LLVMFuncOp op)1414 static void printLLVMFuncOp(OpAsmPrinter &p, LLVMFuncOp op) {
1415   p << op.getOperationName() << ' ';
1416   if (op.linkage() != LLVM::Linkage::External)
1417     p << stringifyLinkage(op.linkage()) << ' ';
1418   p.printSymbolName(op.getName());
1419 
1420   LLVMType fnType = op.getType();
1421   SmallVector<Type, 8> argTypes;
1422   SmallVector<Type, 1> resTypes;
1423   argTypes.reserve(fnType.getFunctionNumParams());
1424   for (unsigned i = 0, e = fnType.getFunctionNumParams(); i < e; ++i)
1425     argTypes.push_back(fnType.getFunctionParamType(i));
1426 
1427   LLVMType returnType = fnType.getFunctionResultType();
1428   if (!returnType.isVoidTy())
1429     resTypes.push_back(returnType);
1430 
1431   impl::printFunctionSignature(p, op, argTypes, op.isVarArg(), resTypes);
1432   impl::printFunctionAttributes(p, op, argTypes.size(), resTypes.size(),
1433                                 {getLinkageAttrName()});
1434 
1435   // Print the body if this is not an external function.
1436   Region &body = op.body();
1437   if (!body.empty())
1438     p.printRegion(body, /*printEntryBlockArgs=*/false,
1439                   /*printBlockTerminators=*/true);
1440 }
1441 
1442 // Hook for OpTrait::FunctionLike, called after verifying that the 'type'
1443 // attribute is present.  This can check for preconditions of the
1444 // getNumArguments hook not failing.
verifyType()1445 LogicalResult LLVMFuncOp::verifyType() {
1446   auto llvmType = getTypeAttr().getValue().dyn_cast_or_null<LLVMType>();
1447   if (!llvmType || !llvmType.isFunctionTy())
1448     return emitOpError("requires '" + getTypeAttrName() +
1449                        "' attribute of wrapped LLVM function type");
1450 
1451   return success();
1452 }
1453 
1454 // Hook for OpTrait::FunctionLike, returns the number of function arguments.
1455 // Depends on the type attribute being correct as checked by verifyType
getNumFuncArguments()1456 unsigned LLVMFuncOp::getNumFuncArguments() {
1457   return getType().getFunctionNumParams();
1458 }
1459 
1460 // Hook for OpTrait::FunctionLike, returns the number of function results.
1461 // Depends on the type attribute being correct as checked by verifyType
getNumFuncResults()1462 unsigned LLVMFuncOp::getNumFuncResults() {
1463   // We model LLVM functions that return void as having zero results,
1464   // and all others as having one result.
1465   // If we modeled a void return as one result, then it would be possible to
1466   // attach an MLIR result attribute to it, and it isn't clear what semantics we
1467   // would assign to that.
1468   if (getType().getFunctionResultType().isVoidTy())
1469     return 0;
1470   return 1;
1471 }
1472 
1473 // Verifies LLVM- and implementation-specific properties of the LLVM func Op:
1474 // - functions don't have 'common' linkage
1475 // - external functions have 'external' or 'extern_weak' linkage;
1476 // - vararg is (currently) only supported for external functions;
1477 // - entry block arguments are of LLVM types and match the function signature.
verify(LLVMFuncOp op)1478 static LogicalResult verify(LLVMFuncOp op) {
1479   if (op.linkage() == LLVM::Linkage::Common)
1480     return op.emitOpError()
1481            << "functions cannot have '"
1482            << stringifyLinkage(LLVM::Linkage::Common) << "' linkage";
1483 
1484   if (op.isExternal()) {
1485     if (op.linkage() != LLVM::Linkage::External &&
1486         op.linkage() != LLVM::Linkage::ExternWeak)
1487       return op.emitOpError()
1488              << "external functions must have '"
1489              << stringifyLinkage(LLVM::Linkage::External) << "' or '"
1490              << stringifyLinkage(LLVM::Linkage::ExternWeak) << "' linkage";
1491     return success();
1492   }
1493 
1494   if (op.isVarArg())
1495     return op.emitOpError("only external functions can be variadic");
1496 
1497   unsigned numArguments = op.getType().getFunctionNumParams();
1498   Block &entryBlock = op.front();
1499   for (unsigned i = 0; i < numArguments; ++i) {
1500     Type argType = entryBlock.getArgument(i).getType();
1501     auto argLLVMType = argType.dyn_cast<LLVMType>();
1502     if (!argLLVMType)
1503       return op.emitOpError("entry block argument #")
1504              << i << " is not of LLVM type";
1505     if (op.getType().getFunctionParamType(i) != argLLVMType)
1506       return op.emitOpError("the type of entry block argument #")
1507              << i << " does not match the function signature";
1508   }
1509 
1510   return success();
1511 }
1512 
1513 //===----------------------------------------------------------------------===//
1514 // Verification for LLVM::NullOp.
1515 //===----------------------------------------------------------------------===//
1516 
1517 // Only LLVM pointer types are supported.
verify(LLVM::NullOp op)1518 static LogicalResult verify(LLVM::NullOp op) {
1519   auto llvmType = op.getType().dyn_cast<LLVM::LLVMType>();
1520   if (!llvmType || !llvmType.isPointerTy())
1521     return op.emitOpError("expected LLVM IR pointer type");
1522   return success();
1523 }
1524 
1525 //===----------------------------------------------------------------------===//
1526 // Verification for LLVM::ConstantOp.
1527 //===----------------------------------------------------------------------===//
1528 
verify(LLVM::ConstantOp op)1529 static LogicalResult verify(LLVM::ConstantOp op) {
1530   if (!(op.value().isa<IntegerAttr>() || op.value().isa<FloatAttr>() ||
1531         op.value().isa<ElementsAttr>() || op.value().isa<StringAttr>()))
1532     return op.emitOpError()
1533            << "only supports integer, float, string or elements attributes";
1534   return success();
1535 }
1536 
1537 //===----------------------------------------------------------------------===//
1538 // Utility functions for parsing atomic ops
1539 //===----------------------------------------------------------------------===//
1540 
1541 // Helper function to parse a keyword into the specified attribute named by
1542 // `attrName`. The keyword must match one of the string values defined by the
1543 // AtomicBinOp enum. The resulting I64 attribute is added to the `result`
1544 // state.
parseAtomicBinOp(OpAsmParser & parser,OperationState & result,StringRef attrName)1545 static ParseResult parseAtomicBinOp(OpAsmParser &parser, OperationState &result,
1546                                     StringRef attrName) {
1547   llvm::SMLoc loc;
1548   StringRef keyword;
1549   if (parser.getCurrentLocation(&loc) || parser.parseKeyword(&keyword))
1550     return failure();
1551 
1552   // Replace the keyword `keyword` with an integer attribute.
1553   auto kind = symbolizeAtomicBinOp(keyword);
1554   if (!kind) {
1555     return parser.emitError(loc)
1556            << "'" << keyword << "' is an incorrect value of the '" << attrName
1557            << "' attribute";
1558   }
1559 
1560   auto value = static_cast<int64_t>(kind.getValue());
1561   auto attr = parser.getBuilder().getI64IntegerAttr(value);
1562   result.addAttribute(attrName, attr);
1563 
1564   return success();
1565 }
1566 
1567 // Helper function to parse a keyword into the specified attribute named by
1568 // `attrName`. The keyword must match one of the string values defined by the
1569 // AtomicOrdering enum. The resulting I64 attribute is added to the `result`
1570 // state.
parseAtomicOrdering(OpAsmParser & parser,OperationState & result,StringRef attrName)1571 static ParseResult parseAtomicOrdering(OpAsmParser &parser,
1572                                        OperationState &result,
1573                                        StringRef attrName) {
1574   llvm::SMLoc loc;
1575   StringRef ordering;
1576   if (parser.getCurrentLocation(&loc) || parser.parseKeyword(&ordering))
1577     return failure();
1578 
1579   // Replace the keyword `ordering` with an integer attribute.
1580   auto kind = symbolizeAtomicOrdering(ordering);
1581   if (!kind) {
1582     return parser.emitError(loc)
1583            << "'" << ordering << "' is an incorrect value of the '" << attrName
1584            << "' attribute";
1585   }
1586 
1587   auto value = static_cast<int64_t>(kind.getValue());
1588   auto attr = parser.getBuilder().getI64IntegerAttr(value);
1589   result.addAttribute(attrName, attr);
1590 
1591   return success();
1592 }
1593 
1594 //===----------------------------------------------------------------------===//
1595 // Printer, parser and verifier for LLVM::AtomicRMWOp.
1596 //===----------------------------------------------------------------------===//
1597 
printAtomicRMWOp(OpAsmPrinter & p,AtomicRMWOp & op)1598 static void printAtomicRMWOp(OpAsmPrinter &p, AtomicRMWOp &op) {
1599   p << op.getOperationName() << ' ' << stringifyAtomicBinOp(op.bin_op()) << ' '
1600     << op.ptr() << ", " << op.val() << ' '
1601     << stringifyAtomicOrdering(op.ordering()) << ' ';
1602   p.printOptionalAttrDict(op.getAttrs(), {"bin_op", "ordering"});
1603   p << " : " << op.res().getType();
1604 }
1605 
1606 // <operation> ::= `llvm.atomicrmw` keyword ssa-use `,` ssa-use keyword
1607 //                 attribute-dict? `:` type
parseAtomicRMWOp(OpAsmParser & parser,OperationState & result)1608 static ParseResult parseAtomicRMWOp(OpAsmParser &parser,
1609                                     OperationState &result) {
1610   LLVMType type;
1611   OpAsmParser::OperandType ptr, val;
1612   if (parseAtomicBinOp(parser, result, "bin_op") || parser.parseOperand(ptr) ||
1613       parser.parseComma() || parser.parseOperand(val) ||
1614       parseAtomicOrdering(parser, result, "ordering") ||
1615       parser.parseOptionalAttrDict(result.attributes) ||
1616       parser.parseColonType(type) ||
1617       parser.resolveOperand(ptr, type.getPointerTo(), result.operands) ||
1618       parser.resolveOperand(val, type, result.operands))
1619     return failure();
1620 
1621   result.addTypes(type);
1622   return success();
1623 }
1624 
verify(AtomicRMWOp op)1625 static LogicalResult verify(AtomicRMWOp op) {
1626   auto ptrType = op.ptr().getType().cast<LLVM::LLVMType>();
1627   auto valType = op.val().getType().cast<LLVM::LLVMType>();
1628   if (valType != ptrType.getPointerElementTy())
1629     return op.emitOpError("expected LLVM IR element type for operand #0 to "
1630                           "match type for operand #1");
1631   auto resType = op.res().getType().cast<LLVM::LLVMType>();
1632   if (resType != valType)
1633     return op.emitOpError(
1634         "expected LLVM IR result type to match type for operand #1");
1635   if (op.bin_op() == AtomicBinOp::fadd || op.bin_op() == AtomicBinOp::fsub) {
1636     if (!valType.isFloatingPointTy())
1637       return op.emitOpError("expected LLVM IR floating point type");
1638   } else if (op.bin_op() == AtomicBinOp::xchg) {
1639     if (!valType.isIntegerTy(8) && !valType.isIntegerTy(16) &&
1640         !valType.isIntegerTy(32) && !valType.isIntegerTy(64) &&
1641         !valType.isBFloatTy() && !valType.isHalfTy() && !valType.isFloatTy() &&
1642         !valType.isDoubleTy())
1643       return op.emitOpError("unexpected LLVM IR type for 'xchg' bin_op");
1644   } else {
1645     if (!valType.isIntegerTy(8) && !valType.isIntegerTy(16) &&
1646         !valType.isIntegerTy(32) && !valType.isIntegerTy(64))
1647       return op.emitOpError("expected LLVM IR integer type");
1648   }
1649   return success();
1650 }
1651 
1652 //===----------------------------------------------------------------------===//
1653 // Printer, parser and verifier for LLVM::AtomicCmpXchgOp.
1654 //===----------------------------------------------------------------------===//
1655 
printAtomicCmpXchgOp(OpAsmPrinter & p,AtomicCmpXchgOp & op)1656 static void printAtomicCmpXchgOp(OpAsmPrinter &p, AtomicCmpXchgOp &op) {
1657   p << op.getOperationName() << ' ' << op.ptr() << ", " << op.cmp() << ", "
1658     << op.val() << ' ' << stringifyAtomicOrdering(op.success_ordering()) << ' '
1659     << stringifyAtomicOrdering(op.failure_ordering());
1660   p.printOptionalAttrDict(op.getAttrs(),
1661                           {"success_ordering", "failure_ordering"});
1662   p << " : " << op.val().getType();
1663 }
1664 
1665 // <operation> ::= `llvm.cmpxchg` ssa-use `,` ssa-use `,` ssa-use
1666 //                 keyword keyword attribute-dict? `:` type
parseAtomicCmpXchgOp(OpAsmParser & parser,OperationState & result)1667 static ParseResult parseAtomicCmpXchgOp(OpAsmParser &parser,
1668                                         OperationState &result) {
1669   auto &builder = parser.getBuilder();
1670   LLVMType type;
1671   OpAsmParser::OperandType ptr, cmp, val;
1672   if (parser.parseOperand(ptr) || parser.parseComma() ||
1673       parser.parseOperand(cmp) || parser.parseComma() ||
1674       parser.parseOperand(val) ||
1675       parseAtomicOrdering(parser, result, "success_ordering") ||
1676       parseAtomicOrdering(parser, result, "failure_ordering") ||
1677       parser.parseOptionalAttrDict(result.attributes) ||
1678       parser.parseColonType(type) ||
1679       parser.resolveOperand(ptr, type.getPointerTo(), result.operands) ||
1680       parser.resolveOperand(cmp, type, result.operands) ||
1681       parser.resolveOperand(val, type, result.operands))
1682     return failure();
1683 
1684   auto boolType = LLVMType::getInt1Ty(builder.getContext());
1685   auto resultType = LLVMType::getStructTy(type, boolType);
1686   result.addTypes(resultType);
1687 
1688   return success();
1689 }
1690 
verify(AtomicCmpXchgOp op)1691 static LogicalResult verify(AtomicCmpXchgOp op) {
1692   auto ptrType = op.ptr().getType().cast<LLVM::LLVMType>();
1693   if (!ptrType.isPointerTy())
1694     return op.emitOpError("expected LLVM IR pointer type for operand #0");
1695   auto cmpType = op.cmp().getType().cast<LLVM::LLVMType>();
1696   auto valType = op.val().getType().cast<LLVM::LLVMType>();
1697   if (cmpType != ptrType.getPointerElementTy() || cmpType != valType)
1698     return op.emitOpError("expected LLVM IR element type for operand #0 to "
1699                           "match type for all other operands");
1700   if (!valType.isPointerTy() && !valType.isIntegerTy(8) &&
1701       !valType.isIntegerTy(16) && !valType.isIntegerTy(32) &&
1702       !valType.isIntegerTy(64) && !valType.isBFloatTy() &&
1703       !valType.isHalfTy() && !valType.isFloatTy() && !valType.isDoubleTy())
1704     return op.emitOpError("unexpected LLVM IR type");
1705   if (op.success_ordering() < AtomicOrdering::monotonic ||
1706       op.failure_ordering() < AtomicOrdering::monotonic)
1707     return op.emitOpError("ordering must be at least 'monotonic'");
1708   if (op.failure_ordering() == AtomicOrdering::release ||
1709       op.failure_ordering() == AtomicOrdering::acq_rel)
1710     return op.emitOpError("failure ordering cannot be 'release' or 'acq_rel'");
1711   return success();
1712 }
1713 
1714 //===----------------------------------------------------------------------===//
1715 // Printer, parser and verifier for LLVM::FenceOp.
1716 //===----------------------------------------------------------------------===//
1717 
1718 // <operation> ::= `llvm.fence` (`syncscope(`strAttr`)`)? keyword
1719 // attribute-dict?
parseFenceOp(OpAsmParser & parser,OperationState & result)1720 static ParseResult parseFenceOp(OpAsmParser &parser, OperationState &result) {
1721   StringAttr sScope;
1722   StringRef syncscopeKeyword = "syncscope";
1723   if (!failed(parser.parseOptionalKeyword(syncscopeKeyword))) {
1724     if (parser.parseLParen() ||
1725         parser.parseAttribute(sScope, syncscopeKeyword, result.attributes) ||
1726         parser.parseRParen())
1727       return failure();
1728   } else {
1729     result.addAttribute(syncscopeKeyword,
1730                         parser.getBuilder().getStringAttr(""));
1731   }
1732   if (parseAtomicOrdering(parser, result, "ordering") ||
1733       parser.parseOptionalAttrDict(result.attributes))
1734     return failure();
1735   return success();
1736 }
1737 
printFenceOp(OpAsmPrinter & p,FenceOp & op)1738 static void printFenceOp(OpAsmPrinter &p, FenceOp &op) {
1739   StringRef syncscopeKeyword = "syncscope";
1740   p << op.getOperationName() << ' ';
1741   if (!op.getAttr(syncscopeKeyword).cast<StringAttr>().getValue().empty())
1742     p << "syncscope(" << op.getAttr(syncscopeKeyword) << ") ";
1743   p << stringifyAtomicOrdering(op.ordering());
1744 }
1745 
verify(FenceOp & op)1746 static LogicalResult verify(FenceOp &op) {
1747   if (op.ordering() == AtomicOrdering::not_atomic ||
1748       op.ordering() == AtomicOrdering::unordered ||
1749       op.ordering() == AtomicOrdering::monotonic)
1750     return op.emitOpError("can be given only acquire, release, acq_rel, "
1751                           "and seq_cst orderings");
1752   return success();
1753 }
1754 
1755 //===----------------------------------------------------------------------===//
1756 // LLVMDialect initialization, type parsing, and registration.
1757 //===----------------------------------------------------------------------===//
1758 
initialize()1759 void LLVMDialect::initialize() {
1760   // clang-format off
1761   addTypes<LLVMVoidType,
1762            LLVMHalfType,
1763            LLVMBFloatType,
1764            LLVMFloatType,
1765            LLVMDoubleType,
1766            LLVMFP128Type,
1767            LLVMX86FP80Type,
1768            LLVMPPCFP128Type,
1769            LLVMX86MMXType,
1770            LLVMTokenType,
1771            LLVMLabelType,
1772            LLVMMetadataType,
1773            LLVMFunctionType,
1774            LLVMIntegerType,
1775            LLVMPointerType,
1776            LLVMFixedVectorType,
1777            LLVMScalableVectorType,
1778            LLVMArrayType,
1779            LLVMStructType>();
1780   // clang-format on
1781   addOperations<
1782 #define GET_OP_LIST
1783 #include "mlir/Dialect/LLVMIR/LLVMOps.cpp.inc"
1784       >();
1785 
1786   // Support unknown operations because not all LLVM operations are registered.
1787   allowUnknownOperations();
1788 }
1789 
1790 #define GET_OP_CLASSES
1791 #include "mlir/Dialect/LLVMIR/LLVMOps.cpp.inc"
1792 
1793 /// Parse a type registered to this dialect.
parseType(DialectAsmParser & parser) const1794 Type LLVMDialect::parseType(DialectAsmParser &parser) const {
1795   return detail::parseType(parser);
1796 }
1797 
1798 /// Print a type registered to this dialect.
printType(Type type,DialectAsmPrinter & os) const1799 void LLVMDialect::printType(Type type, DialectAsmPrinter &os) const {
1800   return detail::printType(type.cast<LLVMType>(), os);
1801 }
1802 
verifyDataLayoutString(StringRef descr,llvm::function_ref<void (const Twine &)> reportError)1803 LogicalResult LLVMDialect::verifyDataLayoutString(
1804     StringRef descr, llvm::function_ref<void(const Twine &)> reportError) {
1805   llvm::Expected<llvm::DataLayout> maybeDataLayout =
1806       llvm::DataLayout::parse(descr);
1807   if (maybeDataLayout)
1808     return success();
1809 
1810   std::string message;
1811   llvm::raw_string_ostream messageStream(message);
1812   llvm::logAllUnhandledErrors(maybeDataLayout.takeError(), messageStream);
1813   reportError("invalid data layout descriptor: " + messageStream.str());
1814   return failure();
1815 }
1816 
1817 /// Verify LLVM dialect attributes.
verifyOperationAttribute(Operation * op,NamedAttribute attr)1818 LogicalResult LLVMDialect::verifyOperationAttribute(Operation *op,
1819                                                     NamedAttribute attr) {
1820   // If the data layout attribute is present, it must use the LLVM data layout
1821   // syntax. Try parsing it and report errors in case of failure. Users of this
1822   // attribute may assume it is well-formed and can pass it to the (asserting)
1823   // llvm::DataLayout constructor.
1824   if (attr.first.strref() != LLVM::LLVMDialect::getDataLayoutAttrName())
1825     return success();
1826   if (auto stringAttr = attr.second.dyn_cast<StringAttr>())
1827     return verifyDataLayoutString(
1828         stringAttr.getValue(),
1829         [op](const Twine &message) { op->emitOpError() << message.str(); });
1830 
1831   return op->emitOpError() << "expected '"
1832                            << LLVM::LLVMDialect::getDataLayoutAttrName()
1833                            << "' to be a string attribute";
1834 }
1835 
1836 /// Verify LLVMIR function argument attributes.
verifyRegionArgAttribute(Operation * op,unsigned regionIdx,unsigned argIdx,NamedAttribute argAttr)1837 LogicalResult LLVMDialect::verifyRegionArgAttribute(Operation *op,
1838                                                     unsigned regionIdx,
1839                                                     unsigned argIdx,
1840                                                     NamedAttribute argAttr) {
1841   // Check that llvm.noalias is a boolean attribute.
1842   if (argAttr.first == LLVMDialect::getNoAliasAttrName() &&
1843       !argAttr.second.isa<BoolAttr>())
1844     return op->emitError()
1845            << "llvm.noalias argument attribute of non boolean type";
1846   // Check that llvm.align is an integer attribute.
1847   if (argAttr.first == LLVMDialect::getAlignAttrName() &&
1848       !argAttr.second.isa<IntegerAttr>())
1849     return op->emitError()
1850            << "llvm.align argument attribute of non integer type";
1851   return success();
1852 }
1853 
1854 //===----------------------------------------------------------------------===//
1855 // Utility functions.
1856 //===----------------------------------------------------------------------===//
1857 
createGlobalString(Location loc,OpBuilder & builder,StringRef name,StringRef value,LLVM::Linkage linkage)1858 Value mlir::LLVM::createGlobalString(Location loc, OpBuilder &builder,
1859                                      StringRef name, StringRef value,
1860                                      LLVM::Linkage linkage) {
1861   assert(builder.getInsertionBlock() &&
1862          builder.getInsertionBlock()->getParentOp() &&
1863          "expected builder to point to a block constrained in an op");
1864   auto module =
1865       builder.getInsertionBlock()->getParentOp()->getParentOfType<ModuleOp>();
1866   assert(module && "builder points to an op outside of a module");
1867 
1868   // Create the global at the entry of the module.
1869   OpBuilder moduleBuilder(module.getBodyRegion());
1870   MLIRContext *ctx = builder.getContext();
1871   auto type =
1872       LLVM::LLVMType::getArrayTy(LLVM::LLVMType::getInt8Ty(ctx), value.size());
1873   auto global = moduleBuilder.create<LLVM::GlobalOp>(
1874       loc, type, /*isConstant=*/true, linkage, name,
1875       builder.getStringAttr(value));
1876 
1877   // Get the pointer to the first character in the global string.
1878   Value globalPtr = builder.create<LLVM::AddressOfOp>(loc, global);
1879   Value cst0 = builder.create<LLVM::ConstantOp>(
1880       loc, LLVM::LLVMType::getInt64Ty(ctx),
1881       builder.getIntegerAttr(builder.getIndexType(), 0));
1882   return builder.create<LLVM::GEPOp>(loc, LLVM::LLVMType::getInt8PtrTy(ctx),
1883                                      globalPtr, ValueRange{cst0, cst0});
1884 }
1885 
satisfiesLLVMModule(Operation * op)1886 bool mlir::LLVM::satisfiesLLVMModule(Operation *op) {
1887   return op->hasTrait<OpTrait::SymbolTable>() &&
1888          op->hasTrait<OpTrait::IsIsolatedFromAbove>();
1889 }
1890