//===-- FIROps.cpp --------------------------------------------------------===//
//
// Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions.
// See https://llvm.org/LICENSE.txt for license information.
// SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception
//
//===----------------------------------------------------------------------===//

#include "flang/Optimizer/Dialect/FIROps.h"
#include "flang/Optimizer/Dialect/FIRAttr.h"
#include "flang/Optimizer/Dialect/FIROpsSupport.h"
#include "flang/Optimizer/Dialect/FIRType.h"
#include "mlir/Dialect/CommonFolders.h"
#include "mlir/Dialect/StandardOps/IR/Ops.h"
#include "mlir/IR/BuiltinOps.h"
#include "mlir/IR/Diagnostics.h"
#include "mlir/IR/Matchers.h"
#include "llvm/ADT/StringSwitch.h"
#include "llvm/ADT/TypeSwitch.h"

using namespace fir;

/// Return true if a sequence type is of some incomplete size or a record type
/// is malformed or contains an incomplete sequence type. An incomplete sequence
/// type is one with more unknown extents in the type than have been provided
/// via `dynamicExtents`. Sequence types with an unknown rank are incomplete by
/// definition.
static bool verifyInType(mlir::Type inType, llvm::SmallVectorImpl &visited, unsigned dynamicExtents = 0) { if (auto st = inType.dyn_cast()) { auto shape = st.getShape(); if (shape.size() == 0) return true; for (std::size_t i = 0, end{shape.size()}; i < end; ++i) { if (shape[i] != fir::SequenceType::getUnknownExtent()) continue; if (dynamicExtents-- == 0) return true; } } else if (auto rt = inType.dyn_cast()) { // don't recurse if we're already visiting this one if (llvm::is_contained(visited, rt.getName())) return false; // keep track of record types currently being visited visited.push_back(rt.getName()); for (auto &field : rt.getTypeList()) if (verifyInType(field.second, visited)) return true; visited.pop_back(); } else if (auto rt = inType.dyn_cast()) { return verifyInType(rt.getEleTy(), visited); } return false; } static bool verifyRecordLenParams(mlir::Type inType, unsigned numLenParams) { if (numLenParams > 0) { if (auto rt = inType.dyn_cast()) return numLenParams != rt.getNumLenParams(); return true; } return false; } //===----------------------------------------------------------------------===// // AddfOp //===----------------------------------------------------------------------===// mlir::OpFoldResult fir::AddfOp::fold(llvm::ArrayRef opnds) { return mlir::constFoldBinaryOp( opnds, [](APFloat a, APFloat b) { return a + b; }); } //===----------------------------------------------------------------------===// // AllocaOp //===----------------------------------------------------------------------===// mlir::Type fir::AllocaOp::getAllocatedType() { return getType().cast().getEleTy(); } /// Create a legal memory reference as return type mlir::Type fir::AllocaOp::wrapResultType(mlir::Type intype) { // FIR semantics: memory references to memory references are disallowed if (intype.isa()) return {}; return ReferenceType::get(intype); } mlir::Type fir::AllocaOp::getRefTy(mlir::Type ty) { return ReferenceType::get(ty); } 
//===----------------------------------------------------------------------===// // AllocMemOp //===----------------------------------------------------------------------===// mlir::Type fir::AllocMemOp::getAllocatedType() { return getType().cast().getEleTy(); } mlir::Type fir::AllocMemOp::getRefTy(mlir::Type ty) { return HeapType::get(ty); } /// Create a legal heap reference as return type mlir::Type fir::AllocMemOp::wrapResultType(mlir::Type intype) { // Fortran semantics: C852 an entity cannot be both ALLOCATABLE and POINTER // 8.5.3 note 1 prohibits ALLOCATABLE procedures as well // FIR semantics: one may not allocate a memory reference value if (intype.isa() || intype.isa() || intype.isa() || intype.isa()) return {}; return HeapType::get(intype); } //===----------------------------------------------------------------------===// // BoxAddrOp //===----------------------------------------------------------------------===// mlir::OpFoldResult fir::BoxAddrOp::fold(llvm::ArrayRef opnds) { if (auto v = val().getDefiningOp()) { if (auto box = dyn_cast(v)) return box.memref(); if (auto box = dyn_cast(v)) return box.memref(); } return {}; } //===----------------------------------------------------------------------===// // BoxCharLenOp //===----------------------------------------------------------------------===// mlir::OpFoldResult fir::BoxCharLenOp::fold(llvm::ArrayRef opnds) { if (auto v = val().getDefiningOp()) { if (auto box = dyn_cast(v)) return box.len(); } return {}; } //===----------------------------------------------------------------------===// // BoxDimsOp //===----------------------------------------------------------------------===// /// Get the result types packed in a tuple tuple mlir::Type fir::BoxDimsOp::getTupleType() { // note: triple, but 4 is nearest power of 2 llvm::SmallVector triple{ getResult(0).getType(), getResult(1).getType(), getResult(2).getType()}; return mlir::TupleType::get(triple, getContext()); } 
//===----------------------------------------------------------------------===// // CallOp //===----------------------------------------------------------------------===// static void printCallOp(mlir::OpAsmPrinter &p, fir::CallOp &op) { auto callee = op.callee(); bool isDirect = callee.hasValue(); p << op.getOperationName() << ' '; if (isDirect) p << callee.getValue(); else p << op.getOperand(0); p << '(' << op.getOperands().drop_front(isDirect ? 0 : 1) << ')'; p.printOptionalAttrDict(op.getAttrs(), {fir::CallOp::calleeAttrName()}); auto resultTypes{op.getResultTypes()}; llvm::SmallVector argTypes( llvm::drop_begin(op.getOperandTypes(), isDirect ? 0 : 1)); p << " : " << FunctionType::get(argTypes, resultTypes, op.getContext()); } static mlir::ParseResult parseCallOp(mlir::OpAsmParser &parser, mlir::OperationState &result) { llvm::SmallVector operands; if (parser.parseOperandList(operands)) return mlir::failure(); mlir::NamedAttrList attrs; mlir::SymbolRefAttr funcAttr; bool isDirect = operands.empty(); if (isDirect) if (parser.parseAttribute(funcAttr, fir::CallOp::calleeAttrName(), attrs)) return mlir::failure(); Type type; if (parser.parseOperandList(operands, mlir::OpAsmParser::Delimiter::Paren) || parser.parseOptionalAttrDict(attrs) || parser.parseColon() || parser.parseType(type)) return mlir::failure(); auto funcType = type.dyn_cast(); if (!funcType) return parser.emitError(parser.getNameLoc(), "expected function type"); if (isDirect) { if (parser.resolveOperands(operands, funcType.getInputs(), parser.getNameLoc(), result.operands)) return mlir::failure(); } else { auto funcArgs = llvm::ArrayRef(operands).drop_front(); llvm::SmallVector resultArgs( result.operands.begin() + (result.operands.empty() ? 
0 : 1), result.operands.end()); if (parser.resolveOperand(operands[0], funcType, result.operands) || parser.resolveOperands(funcArgs, funcType.getInputs(), parser.getNameLoc(), resultArgs)) return mlir::failure(); } result.addTypes(funcType.getResults()); result.attributes = attrs; return mlir::success(); } //===----------------------------------------------------------------------===// // CmpfOp //===----------------------------------------------------------------------===// // Note: getCmpFPredicateNames() is inline static in StandardOps/IR/Ops.cpp mlir::CmpFPredicate fir::CmpfOp::getPredicateByName(llvm::StringRef name) { auto pred = mlir::symbolizeCmpFPredicate(name); assert(pred.hasValue() && "invalid predicate name"); return pred.getValue(); } void fir::buildCmpFOp(OpBuilder &builder, OperationState &result, CmpFPredicate predicate, Value lhs, Value rhs) { result.addOperands({lhs, rhs}); result.types.push_back(builder.getI1Type()); result.addAttribute( CmpfOp::getPredicateAttrName(), builder.getI64IntegerAttr(static_cast(predicate))); } template static void printCmpOp(OpAsmPrinter &p, OPTY op) { p << op.getOperationName() << ' '; auto predSym = mlir::symbolizeCmpFPredicate( op.template getAttrOfType(OPTY::getPredicateAttrName()) .getInt()); assert(predSym.hasValue() && "invalid symbol value for predicate"); p << '"' << mlir::stringifyCmpFPredicate(predSym.getValue()) << '"' << ", "; p.printOperand(op.lhs()); p << ", "; p.printOperand(op.rhs()); p.printOptionalAttrDict(op.getAttrs(), /*elidedAttrs=*/{OPTY::getPredicateAttrName()}); p << " : " << op.lhs().getType(); } static void printCmpfOp(OpAsmPrinter &p, CmpfOp op) { printCmpOp(p, op); } template static mlir::ParseResult parseCmpOp(mlir::OpAsmParser &parser, mlir::OperationState &result) { llvm::SmallVector ops; mlir::NamedAttrList attrs; mlir::Attribute predicateNameAttr; mlir::Type type; if (parser.parseAttribute(predicateNameAttr, OPTY::getPredicateAttrName(), attrs) || parser.parseComma() || 
parser.parseOperandList(ops, 2) || parser.parseOptionalAttrDict(attrs) || parser.parseColonType(type) || parser.resolveOperands(ops, type, result.operands)) return failure(); if (!predicateNameAttr.isa()) return parser.emitError(parser.getNameLoc(), "expected string comparison predicate attribute"); // Rewrite string attribute to an enum value. llvm::StringRef predicateName = predicateNameAttr.cast().getValue(); auto predicate = fir::CmpfOp::getPredicateByName(predicateName); auto builder = parser.getBuilder(); mlir::Type i1Type = builder.getI1Type(); attrs.set(OPTY::getPredicateAttrName(), builder.getI64IntegerAttr(static_cast(predicate))); result.attributes = attrs; result.addTypes({i1Type}); return success(); } mlir::ParseResult fir::parseCmpfOp(mlir::OpAsmParser &parser, mlir::OperationState &result) { return parseCmpOp(parser, result); } //===----------------------------------------------------------------------===// // CmpcOp //===----------------------------------------------------------------------===// void fir::buildCmpCOp(OpBuilder &builder, OperationState &result, CmpFPredicate predicate, Value lhs, Value rhs) { result.addOperands({lhs, rhs}); result.types.push_back(builder.getI1Type()); result.addAttribute( fir::CmpcOp::getPredicateAttrName(), builder.getI64IntegerAttr(static_cast(predicate))); } static void printCmpcOp(OpAsmPrinter &p, fir::CmpcOp op) { printCmpOp(p, op); } mlir::ParseResult fir::parseCmpcOp(mlir::OpAsmParser &parser, mlir::OperationState &result) { return parseCmpOp(parser, result); } //===----------------------------------------------------------------------===// // ConvertOp //===----------------------------------------------------------------------===// mlir::OpFoldResult fir::ConvertOp::fold(llvm::ArrayRef opnds) { if (value().getType() == getType()) return value(); if (matchPattern(value(), m_Op())) { auto inner = cast(value().getDefiningOp()); // (convert (convert 'a : logical -> i1) : i1 -> logical) ==> forward 'a if (auto 
toTy = getType().dyn_cast()) if (auto fromTy = inner.value().getType().dyn_cast()) if (inner.getType().isa() && (toTy == fromTy)) return inner.value(); // (convert (convert 'a : i1 -> logical) : logical -> i1) ==> forward 'a if (auto toTy = getType().dyn_cast()) if (auto fromTy = inner.value().getType().dyn_cast()) if (inner.getType().isa() && (toTy == fromTy) && (fromTy.getWidth() == 1)) return inner.value(); } return {}; } bool fir::ConvertOp::isIntegerCompatible(mlir::Type ty) { return ty.isa() || ty.isa() || ty.isa() || ty.isa() || ty.isa(); } bool fir::ConvertOp::isFloatCompatible(mlir::Type ty) { return ty.isa() || ty.isa(); } bool fir::ConvertOp::isPointerCompatible(mlir::Type ty) { return ty.isa() || ty.isa() || ty.isa() || ty.isa() || ty.isa(); } //===----------------------------------------------------------------------===// // CoordinateOp //===----------------------------------------------------------------------===// static mlir::ParseResult parseCoordinateOp(mlir::OpAsmParser &parser, mlir::OperationState &result) { llvm::ArrayRef allOperandTypes; llvm::ArrayRef allResultTypes; llvm::SMLoc allOperandLoc = parser.getCurrentLocation(); llvm::SmallVector allOperands; if (parser.parseOperandList(allOperands)) return failure(); if (parser.parseOptionalAttrDict(result.attributes)) return failure(); if (parser.parseColon()) return failure(); mlir::FunctionType funcTy; if (parser.parseType(funcTy)) return failure(); allOperandTypes = funcTy.getInputs(); allResultTypes = funcTy.getResults(); result.addTypes(allResultTypes); if (parser.resolveOperands(allOperands, allOperandTypes, allOperandLoc, result.operands)) return failure(); if (funcTy.getNumInputs()) { // No inputs handled by verify result.addAttribute(fir::CoordinateOp::baseType(), mlir::TypeAttr::get(funcTy.getInput(0))); } return success(); } mlir::Type fir::CoordinateOp::getBaseType() { return getAttr(CoordinateOp::baseType()).cast().getValue(); } void fir::CoordinateOp::build(OpBuilder &, 
OperationState &result, mlir::Type resType, ValueRange operands, ArrayRef attrs) { assert(operands.size() >= 1u && "mismatched number of parameters"); result.addOperands(operands); result.addAttribute(fir::CoordinateOp::baseType(), mlir::TypeAttr::get(operands[0].getType())); result.attributes.append(attrs.begin(), attrs.end()); result.addTypes({resType}); } void fir::CoordinateOp::build(OpBuilder &builder, OperationState &result, mlir::Type resType, mlir::Value ref, ValueRange coor, ArrayRef attrs) { llvm::SmallVector operands{ref}; operands.append(coor.begin(), coor.end()); build(builder, result, resType, operands, attrs); } //===----------------------------------------------------------------------===// // DispatchOp //===----------------------------------------------------------------------===// mlir::FunctionType fir::DispatchOp::getFunctionType() { auto attr = getAttr("fn_type").cast(); return attr.getValue().cast(); } //===----------------------------------------------------------------------===// // DispatchTableOp //===----------------------------------------------------------------------===// void fir::DispatchTableOp::appendTableEntry(mlir::Operation *op) { assert(mlir::isa(*op) && "operation must be a DTEntryOp"); auto &block = getBlock(); block.getOperations().insert(block.end(), op); } //===----------------------------------------------------------------------===// // EmboxOp //===----------------------------------------------------------------------===// static mlir::ParseResult parseEmboxOp(mlir::OpAsmParser &parser, mlir::OperationState &result) { mlir::FunctionType type; llvm::SmallVector operands; mlir::OpAsmParser::OperandType memref; if (parser.parseOperand(memref)) return mlir::failure(); operands.push_back(memref); auto &builder = parser.getBuilder(); if (!parser.parseOptionalLParen()) { if (parser.parseOperandList(operands, mlir::OpAsmParser::Delimiter::None) || parser.parseRParen()) return mlir::failure(); auto lens = 
builder.getI32IntegerAttr(operands.size()); result.addAttribute(fir::EmboxOp::lenpName(), lens); } if (!parser.parseOptionalComma()) { mlir::OpAsmParser::OperandType dims; if (parser.parseOperand(dims)) return mlir::failure(); operands.push_back(dims); } else if (!parser.parseOptionalLSquare()) { mlir::AffineMapAttr map; if (parser.parseAttribute(map, fir::EmboxOp::layoutName(), result.attributes) || parser.parseRSquare()) return mlir::failure(); } if (parser.parseOptionalAttrDict(result.attributes) || parser.parseColonType(type) || parser.resolveOperands(operands, type.getInputs(), parser.getNameLoc(), result.operands) || parser.addTypesToList(type.getResults(), result.types)) return mlir::failure(); return mlir::success(); } //===----------------------------------------------------------------------===// // GenTypeDescOp //===----------------------------------------------------------------------===// void fir::GenTypeDescOp::build(OpBuilder &, OperationState &result, mlir::TypeAttr inty) { result.addAttribute("in_type", inty); result.addTypes(TypeDescType::get(inty.getValue())); } //===----------------------------------------------------------------------===// // GlobalOp //===----------------------------------------------------------------------===// static ParseResult parseGlobalOp(OpAsmParser &parser, OperationState &result) { // Parse the optional linkage llvm::StringRef linkage; auto &builder = parser.getBuilder(); if (mlir::succeeded(parser.parseOptionalKeyword(&linkage))) { if (fir::GlobalOp::verifyValidLinkage(linkage)) return failure(); mlir::StringAttr linkAttr = builder.getStringAttr(linkage); result.addAttribute(fir::GlobalOp::linkageAttrName(), linkAttr); } // Parse the name as a symbol reference attribute. 
mlir::SymbolRefAttr nameAttr; if (parser.parseAttribute(nameAttr, fir::GlobalOp::symbolAttrName(), result.attributes)) return failure(); result.addAttribute(mlir::SymbolTable::getSymbolAttrName(), builder.getStringAttr(nameAttr.getRootReference())); bool simpleInitializer = false; if (mlir::succeeded(parser.parseOptionalLParen())) { Attribute attr; if (parser.parseAttribute(attr, fir::GlobalOp::initValAttrName(), result.attributes) || parser.parseRParen()) return failure(); simpleInitializer = true; } if (succeeded(parser.parseOptionalKeyword("constant"))) { // if "constant" keyword then mark this as a constant, not a variable result.addAttribute(fir::GlobalOp::constantAttrName(), builder.getUnitAttr()); } mlir::Type globalType; if (parser.parseColonType(globalType)) return failure(); result.addAttribute(fir::GlobalOp::typeAttrName(), mlir::TypeAttr::get(globalType)); if (simpleInitializer) { result.addRegion(); } else { // Parse the optional initializer body. if (parser.parseRegion(*result.addRegion(), llvm::None, llvm::None)) return failure(); } return success(); } void fir::GlobalOp::appendInitialValue(mlir::Operation *op) { getBlock().getOperations().push_back(op); } void fir::GlobalOp::build(mlir::OpBuilder &builder, OperationState &result, StringRef name, bool isConstant, Type type, Attribute initialVal, StringAttr linkage, ArrayRef attrs) { result.addRegion(); result.addAttribute(typeAttrName(), mlir::TypeAttr::get(type)); result.addAttribute(mlir::SymbolTable::getSymbolAttrName(), builder.getStringAttr(name)); result.addAttribute(symbolAttrName(), builder.getSymbolRefAttr(name)); if (isConstant) result.addAttribute(constantAttrName(), builder.getUnitAttr()); if (initialVal) result.addAttribute(initValAttrName(), initialVal); if (linkage) result.addAttribute(linkageAttrName(), linkage); result.attributes.append(attrs.begin(), attrs.end()); } void fir::GlobalOp::build(mlir::OpBuilder &builder, OperationState &result, StringRef name, Type type, Attribute 
initialVal, StringAttr linkage, ArrayRef attrs) { build(builder, result, name, /*isConstant=*/false, type, {}, linkage, attrs); } void fir::GlobalOp::build(mlir::OpBuilder &builder, OperationState &result, StringRef name, bool isConstant, Type type, StringAttr linkage, ArrayRef attrs) { build(builder, result, name, isConstant, type, {}, linkage, attrs); } void fir::GlobalOp::build(mlir::OpBuilder &builder, OperationState &result, StringRef name, Type type, StringAttr linkage, ArrayRef attrs) { build(builder, result, name, /*isConstant=*/false, type, {}, linkage, attrs); } void fir::GlobalOp::build(mlir::OpBuilder &builder, OperationState &result, StringRef name, bool isConstant, Type type, ArrayRef attrs) { build(builder, result, name, isConstant, type, StringAttr{}, attrs); } void fir::GlobalOp::build(mlir::OpBuilder &builder, OperationState &result, StringRef name, Type type, ArrayRef attrs) { build(builder, result, name, /*isConstant=*/false, type, attrs); } mlir::ParseResult fir::GlobalOp::verifyValidLinkage(StringRef linkage) { // Supporting only a subset of the LLVM linkage types for now static const llvm::SmallVector validNames = { "internal", "common", "weak"}; return mlir::success(llvm::is_contained(validNames, linkage)); } //===----------------------------------------------------------------------===// // IterWhileOp //===----------------------------------------------------------------------===// void fir::IterWhileOp::build(mlir::OpBuilder &builder, mlir::OperationState &result, mlir::Value lb, mlir::Value ub, mlir::Value step, mlir::Value iterate, mlir::ValueRange iterArgs, llvm::ArrayRef attributes) { result.addOperands({lb, ub, step, iterate}); result.addTypes(iterate.getType()); result.addOperands(iterArgs); for (auto v : iterArgs) result.addTypes(v.getType()); mlir::Region *bodyRegion = result.addRegion(); bodyRegion->push_back(new Block{}); bodyRegion->front().addArgument(builder.getIndexType()); bodyRegion->front().addArgument(iterate.getType()); 
bodyRegion->front().addArguments(iterArgs.getTypes()); result.addAttributes(attributes); } static mlir::ParseResult parseIterWhileOp(mlir::OpAsmParser &parser, mlir::OperationState &result) { auto &builder = parser.getBuilder(); mlir::OpAsmParser::OperandType inductionVariable, lb, ub, step; if (parser.parseLParen() || parser.parseRegionArgument(inductionVariable) || parser.parseEqual()) return mlir::failure(); // Parse loop bounds. auto indexType = builder.getIndexType(); auto i1Type = builder.getIntegerType(1); if (parser.parseOperand(lb) || parser.resolveOperand(lb, indexType, result.operands) || parser.parseKeyword("to") || parser.parseOperand(ub) || parser.resolveOperand(ub, indexType, result.operands) || parser.parseKeyword("step") || parser.parseOperand(step) || parser.parseRParen() || parser.resolveOperand(step, indexType, result.operands)) return mlir::failure(); mlir::OpAsmParser::OperandType iterateVar, iterateInput; if (parser.parseKeyword("and") || parser.parseLParen() || parser.parseRegionArgument(iterateVar) || parser.parseEqual() || parser.parseOperand(iterateInput) || parser.parseRParen() || parser.resolveOperand(iterateInput, i1Type, result.operands)) return mlir::failure(); // Parse the initial iteration arguments. llvm::SmallVector regionArgs; // Induction variable. regionArgs.push_back(inductionVariable); regionArgs.push_back(iterateVar); result.addTypes(i1Type); if (mlir::succeeded(parser.parseOptionalKeyword("iter_args"))) { llvm::SmallVector operands; llvm::SmallVector regionTypes; // Parse assignment list and results type list. if (parser.parseAssignmentList(regionArgs, operands) || parser.parseArrowTypeList(regionTypes)) return mlir::failure(); // Resolve input operands. 
for (auto operand_type : llvm::zip(operands, regionTypes)) if (parser.resolveOperand(std::get<0>(operand_type), std::get<1>(operand_type), result.operands)) return mlir::failure(); result.addTypes(regionTypes); } if (parser.parseOptionalAttrDictWithKeyword(result.attributes)) return mlir::failure(); llvm::SmallVector argTypes; // Induction variable (hidden) argTypes.push_back(indexType); // Loop carried variables (including iterate) argTypes.append(result.types.begin(), result.types.end()); // Parse the body region. auto *body = result.addRegion(); if (regionArgs.size() != argTypes.size()) return parser.emitError( parser.getNameLoc(), "mismatch in number of loop-carried values and defined values"); if (parser.parseRegion(*body, regionArgs, argTypes)) return failure(); fir::IterWhileOp::ensureTerminator(*body, builder, result.location); return mlir::success(); } static mlir::LogicalResult verify(fir::IterWhileOp op) { if (auto cst = dyn_cast_or_null(op.step().getDefiningOp())) if (cst.getValue() <= 0) return op.emitOpError("constant step operand must be positive"); // Check that the body defines as single block argument for the induction // variable. 
auto *body = op.getBody(); if (!body->getArgument(1).getType().isInteger(1)) return op.emitOpError( "expected body second argument to be an index argument for " "the induction variable"); if (!body->getArgument(0).getType().isIndex()) return op.emitOpError( "expected body first argument to be an index argument for " "the induction variable"); auto opNumResults = op.getNumResults(); if (opNumResults == 0) return mlir::failure(); if (op.getNumIterOperands() != opNumResults) return op.emitOpError( "mismatch in number of loop-carried values and defined values"); if (op.getNumRegionIterArgs() != opNumResults) return op.emitOpError( "mismatch in number of basic block args and defined values"); auto iterOperands = op.getIterOperands(); auto iterArgs = op.getRegionIterArgs(); auto opResults = op.getResults(); unsigned i = 0; for (auto e : llvm::zip(iterOperands, iterArgs, opResults)) { if (std::get<0>(e).getType() != std::get<2>(e).getType()) return op.emitOpError() << "types mismatch between " << i << "th iter operand and defined value"; if (std::get<1>(e).getType() != std::get<2>(e).getType()) return op.emitOpError() << "types mismatch between " << i << "th iter region arg and defined value"; i++; } return mlir::success(); } static void print(mlir::OpAsmPrinter &p, fir::IterWhileOp op) { p << fir::IterWhileOp::getOperationName() << " (" << op.getInductionVar() << " = " << op.lowerBound() << " to " << op.upperBound() << " step " << op.step() << ") and ("; assert(op.hasIterOperands()); auto regionArgs = op.getRegionIterArgs(); auto operands = op.getIterOperands(); p << regionArgs.front() << " = " << *operands.begin() << ")"; if (regionArgs.size() > 1) { p << " iter_args("; llvm::interleaveComma( llvm::zip(regionArgs.drop_front(), operands.drop_front()), p, [&](auto it) { p << std::get<0>(it) << " = " << std::get<1>(it); }); p << ") -> (" << op.getResultTypes().drop_front() << ')'; } p.printOptionalAttrDictWithKeyword(op.getAttrs(), {}); p.printRegion(op.region(), 
/*printEntryBlockArgs=*/false, /*printBlockTerminators=*/true); } mlir::Region &fir::IterWhileOp::getLoopBody() { return region(); } bool fir::IterWhileOp::isDefinedOutsideOfLoop(mlir::Value value) { return !region().isAncestor(value.getParentRegion()); } mlir::LogicalResult fir::IterWhileOp::moveOutOfLoop(llvm::ArrayRef ops) { for (auto op : ops) op->moveBefore(*this); return success(); } //===----------------------------------------------------------------------===// // LoadOp //===----------------------------------------------------------------------===// /// Get the element type of a reference like type; otherwise null static mlir::Type elementTypeOf(mlir::Type ref) { return llvm::TypeSwitch(ref) .Case( [](auto type) { return type.getEleTy(); }) .Default([](mlir::Type) { return mlir::Type{}; }); } mlir::ParseResult fir::LoadOp::getElementOf(mlir::Type &ele, mlir::Type ref) { if ((ele = elementTypeOf(ref))) return mlir::success(); return mlir::failure(); } //===----------------------------------------------------------------------===// // LoopOp //===----------------------------------------------------------------------===// void fir::LoopOp::build(mlir::OpBuilder &builder, mlir::OperationState &result, mlir::Value lb, mlir::Value ub, mlir::Value step, bool unordered, mlir::ValueRange iterArgs, llvm::ArrayRef attributes) { result.addOperands({lb, ub, step}); result.addOperands(iterArgs); for (auto v : iterArgs) result.addTypes(v.getType()); mlir::Region *bodyRegion = result.addRegion(); bodyRegion->push_back(new Block{}); if (iterArgs.empty()) LoopOp::ensureTerminator(*bodyRegion, builder, result.location); bodyRegion->front().addArgument(builder.getIndexType()); bodyRegion->front().addArguments(iterArgs.getTypes()); if (unordered) result.addAttribute(unorderedAttrName(), builder.getUnitAttr()); result.addAttributes(attributes); } static mlir::ParseResult parseLoopOp(mlir::OpAsmParser &parser, mlir::OperationState &result) { auto &builder = parser.getBuilder(); 
mlir::OpAsmParser::OperandType inductionVariable, lb, ub, step; // Parse the induction variable followed by '='. if (parser.parseRegionArgument(inductionVariable) || parser.parseEqual()) return mlir::failure(); // Parse loop bounds. auto indexType = builder.getIndexType(); if (parser.parseOperand(lb) || parser.resolveOperand(lb, indexType, result.operands) || parser.parseKeyword("to") || parser.parseOperand(ub) || parser.resolveOperand(ub, indexType, result.operands) || parser.parseKeyword("step") || parser.parseOperand(step) || parser.resolveOperand(step, indexType, result.operands)) return failure(); if (mlir::succeeded(parser.parseOptionalKeyword("unordered"))) result.addAttribute(fir::LoopOp::unorderedAttrName(), builder.getUnitAttr()); // Parse the optional initial iteration arguments. llvm::SmallVector regionArgs, operands; llvm::SmallVector argTypes; regionArgs.push_back(inductionVariable); if (succeeded(parser.parseOptionalKeyword("iter_args"))) { // Parse assignment list and results type list. if (parser.parseAssignmentList(regionArgs, operands) || parser.parseArrowTypeList(result.types)) return failure(); // Resolve input operands. for (auto operand_type : llvm::zip(operands, result.types)) if (parser.resolveOperand(std::get<0>(operand_type), std::get<1>(operand_type), result.operands)) return failure(); } if (parser.parseOptionalAttrDictWithKeyword(result.attributes)) return mlir::failure(); // Induction variable. argTypes.push_back(indexType); // Loop carried variables argTypes.append(result.types.begin(), result.types.end()); // Parse the body region. 
auto *body = result.addRegion(); if (regionArgs.size() != argTypes.size()) return parser.emitError( parser.getNameLoc(), "mismatch in number of loop-carried values and defined values"); if (parser.parseRegion(*body, regionArgs, argTypes)) return failure(); fir::LoopOp::ensureTerminator(*body, builder, result.location); return mlir::success(); } fir::LoopOp fir::getForInductionVarOwner(mlir::Value val) { auto ivArg = val.dyn_cast(); if (!ivArg) return {}; assert(ivArg.getOwner() && "unlinked block argument"); auto *containingInst = ivArg.getOwner()->getParentOp(); return dyn_cast_or_null(containingInst); } // Lifted from loop.loop static mlir::LogicalResult verify(fir::LoopOp op) { if (auto cst = dyn_cast_or_null(op.step().getDefiningOp())) if (cst.getValue() <= 0) return op.emitOpError("constant step operand must be positive"); // Check that the body defines as single block argument for the induction // variable. auto *body = op.getBody(); if (!body->getArgument(0).getType().isIndex()) return op.emitOpError( "expected body first argument to be an index argument for " "the induction variable"); auto opNumResults = op.getNumResults(); if (opNumResults == 0) return success(); if (op.getNumIterOperands() != opNumResults) return op.emitOpError( "mismatch in number of loop-carried values and defined values"); if (op.getNumRegionIterArgs() != opNumResults) return op.emitOpError( "mismatch in number of basic block args and defined values"); auto iterOperands = op.getIterOperands(); auto iterArgs = op.getRegionIterArgs(); auto opResults = op.getResults(); unsigned i = 0; for (auto e : llvm::zip(iterOperands, iterArgs, opResults)) { if (std::get<0>(e).getType() != std::get<2>(e).getType()) return op.emitOpError() << "types mismatch between " << i << "th iter operand and defined value"; if (std::get<1>(e).getType() != std::get<2>(e).getType()) return op.emitOpError() << "types mismatch between " << i << "th iter region arg and defined value"; i++; } return success(); } static 
void print(mlir::OpAsmPrinter &p, fir::LoopOp op) { bool printBlockTerminators = false; p << fir::LoopOp::getOperationName() << ' ' << op.getInductionVar() << " = " << op.lowerBound() << " to " << op.upperBound() << " step " << op.step(); if (op.unordered()) p << " unordered"; if (op.hasIterOperands()) { p << " iter_args("; auto regionArgs = op.getRegionIterArgs(); auto operands = op.getIterOperands(); llvm::interleaveComma(llvm::zip(regionArgs, operands), p, [&](auto it) { p << std::get<0>(it) << " = " << std::get<1>(it); }); p << ") -> (" << op.getResultTypes() << ')'; printBlockTerminators = true; } p.printOptionalAttrDictWithKeyword(op.getAttrs(), {fir::LoopOp::unorderedAttrName()}); p.printRegion(op.region(), /*printEntryBlockArgs=*/false, printBlockTerminators); } mlir::Region &fir::LoopOp::getLoopBody() { return region(); } bool fir::LoopOp::isDefinedOutsideOfLoop(mlir::Value value) { return !region().isAncestor(value.getParentRegion()); } mlir::LogicalResult fir::LoopOp::moveOutOfLoop(llvm::ArrayRef ops) { for (auto op : ops) op->moveBefore(*this); return success(); } //===----------------------------------------------------------------------===// // MulfOp //===----------------------------------------------------------------------===// mlir::OpFoldResult fir::MulfOp::fold(llvm::ArrayRef opnds) { return mlir::constFoldBinaryOp( opnds, [](APFloat a, APFloat b) { return a * b; }); } //===----------------------------------------------------------------------===// // ResultOp //===----------------------------------------------------------------------===// static mlir::LogicalResult verify(fir::ResultOp op) { auto parentOp = op.getParentOp(); auto results = parentOp->getResults(); auto operands = op.getOperands(); if (parentOp->getNumResults() != op.getNumOperands()) return op.emitOpError() << "parent of result must have same arity"; for (auto e : llvm::zip(results, operands)) if (std::get<0>(e).getType() != std::get<1>(e).getType()) return op.emitOpError() << 
"types mismatch between result op and its parent"; return success(); } //===----------------------------------------------------------------------===// // SelectOp //===----------------------------------------------------------------------===// static constexpr llvm::StringRef getCompareOffsetAttr() { return "compare_operand_offsets"; } static constexpr llvm::StringRef getTargetOffsetAttr() { return "target_operand_offsets"; } template static A getSubOperands(unsigned pos, A allArgs, mlir::DenseIntElementsAttr ranges, AdditionalArgs &&... additionalArgs) { unsigned start = 0; for (unsigned i = 0; i < pos; ++i) start += (*(ranges.begin() + i)).getZExtValue(); return allArgs.slice(start, (*(ranges.begin() + pos)).getZExtValue(), std::forward(additionalArgs)...); } static mlir::MutableOperandRange getMutableSuccessorOperands(unsigned pos, mlir::MutableOperandRange operands, StringRef offsetAttr) { Operation *owner = operands.getOwner(); NamedAttribute targetOffsetAttr = *owner->getMutableAttrDict().getNamed(offsetAttr); return getSubOperands( pos, operands, targetOffsetAttr.second.cast(), mlir::MutableOperandRange::OperandSegment(pos, targetOffsetAttr)); } static unsigned denseElementsSize(mlir::DenseIntElementsAttr attr) { return attr.getNumElements(); } llvm::Optional fir::SelectOp::getCompareOperands(unsigned) { return {}; } llvm::Optional> fir::SelectOp::getCompareOperands(llvm::ArrayRef, unsigned) { return {}; } llvm::Optional fir::SelectOp::getMutableSuccessorOperands(unsigned oper) { return ::getMutableSuccessorOperands(oper, targetArgsMutable(), getTargetOffsetAttr()); } llvm::Optional> fir::SelectOp::getSuccessorOperands(llvm::ArrayRef operands, unsigned oper) { auto a = getAttrOfType(getTargetOffsetAttr()); auto segments = getAttrOfType(getOperandSegmentSizeAttr()); return {getSubOperands(oper, getSubOperands(2, operands, segments), a)}; } unsigned fir::SelectOp::targetOffsetSize() { return denseElementsSize( getAttrOfType(getTargetOffsetAttr())); } 
//===----------------------------------------------------------------------===// // SelectCaseOp //===----------------------------------------------------------------------===// llvm::Optional fir::SelectCaseOp::getCompareOperands(unsigned cond) { auto a = getAttrOfType(getCompareOffsetAttr()); return {getSubOperands(cond, compareArgs(), a)}; } llvm::Optional> fir::SelectCaseOp::getCompareOperands(llvm::ArrayRef operands, unsigned cond) { auto a = getAttrOfType(getCompareOffsetAttr()); auto segments = getAttrOfType(getOperandSegmentSizeAttr()); return {getSubOperands(cond, getSubOperands(1, operands, segments), a)}; } llvm::Optional fir::SelectCaseOp::getMutableSuccessorOperands(unsigned oper) { return ::getMutableSuccessorOperands(oper, targetArgsMutable(), getTargetOffsetAttr()); } llvm::Optional> fir::SelectCaseOp::getSuccessorOperands(llvm::ArrayRef operands, unsigned oper) { auto a = getAttrOfType(getTargetOffsetAttr()); auto segments = getAttrOfType(getOperandSegmentSizeAttr()); return {getSubOperands(oper, getSubOperands(2, operands, segments), a)}; } // parser for fir.select_case Op static mlir::ParseResult parseSelectCase(mlir::OpAsmParser &parser, mlir::OperationState &result) { mlir::OpAsmParser::OperandType selector; mlir::Type type; if (parseSelector(parser, result, selector, type)) return mlir::failure(); llvm::SmallVector attrs; llvm::SmallVector opers; llvm::SmallVector dests; llvm::SmallVector, 8> destArgs; llvm::SmallVector argOffs; int32_t offSize = 0; while (true) { mlir::Attribute attr; mlir::Block *dest; llvm::SmallVector destArg; mlir::NamedAttrList temp; if (parser.parseAttribute(attr, "a", temp) || isValidCaseAttr(attr) || parser.parseComma()) return mlir::failure(); attrs.push_back(attr); if (attr.dyn_cast_or_null()) { argOffs.push_back(0); } else if (attr.dyn_cast_or_null()) { mlir::OpAsmParser::OperandType oper1; mlir::OpAsmParser::OperandType oper2; if (parser.parseOperand(oper1) || parser.parseComma() || parser.parseOperand(oper2) || 
parser.parseComma()) return mlir::failure(); opers.push_back(oper1); opers.push_back(oper2); argOffs.push_back(2); offSize += 2; } else { mlir::OpAsmParser::OperandType oper; if (parser.parseOperand(oper) || parser.parseComma()) return mlir::failure(); opers.push_back(oper); argOffs.push_back(1); ++offSize; } if (parser.parseSuccessorAndUseList(dest, destArg)) return mlir::failure(); dests.push_back(dest); destArgs.push_back(destArg); if (!parser.parseOptionalRSquare()) break; if (parser.parseComma()) return mlir::failure(); } result.addAttribute(fir::SelectCaseOp::getCasesAttr(), parser.getBuilder().getArrayAttr(attrs)); if (parser.resolveOperands(opers, type, result.operands)) return mlir::failure(); llvm::SmallVector targOffs; int32_t toffSize = 0; const auto count = dests.size(); for (std::remove_const_t i = 0; i != count; ++i) { result.addSuccessors(dests[i]); result.addOperands(destArgs[i]); auto argSize = destArgs[i].size(); targOffs.push_back(argSize); toffSize += argSize; } auto &bld = parser.getBuilder(); result.addAttribute(fir::SelectCaseOp::getOperandSegmentSizeAttr(), bld.getI32VectorAttr({1, offSize, toffSize})); result.addAttribute(getCompareOffsetAttr(), bld.getI32VectorAttr(argOffs)); result.addAttribute(getTargetOffsetAttr(), bld.getI32VectorAttr(targOffs)); return mlir::success(); } unsigned fir::SelectCaseOp::compareOffsetSize() { return denseElementsSize( getAttrOfType(getCompareOffsetAttr())); } unsigned fir::SelectCaseOp::targetOffsetSize() { return denseElementsSize( getAttrOfType(getTargetOffsetAttr())); } void fir::SelectCaseOp::build(mlir::OpBuilder &builder, mlir::OperationState &result, mlir::Value selector, llvm::ArrayRef compareAttrs, llvm::ArrayRef cmpOperands, llvm::ArrayRef destinations, llvm::ArrayRef destOperands, llvm::ArrayRef attributes) { result.addOperands(selector); result.addAttribute(getCasesAttr(), builder.getArrayAttr(compareAttrs)); llvm::SmallVector operOffs; int32_t operSize = 0; for (auto attr : compareAttrs) { if 
(attr.isa()) { operOffs.push_back(2); operSize += 2; } else if (attr.isa()) { operOffs.push_back(0); } else { operOffs.push_back(1); ++operSize; } } for (auto ops : cmpOperands) result.addOperands(ops); result.addAttribute(getCompareOffsetAttr(), builder.getI32VectorAttr(operOffs)); const auto count = destinations.size(); for (auto d : destinations) result.addSuccessors(d); const auto opCount = destOperands.size(); llvm::SmallVector argOffs; int32_t sumArgs = 0; for (std::remove_const_t i = 0; i != count; ++i) { if (i < opCount) { result.addOperands(destOperands[i]); const auto argSz = destOperands[i].size(); argOffs.push_back(argSz); sumArgs += argSz; } else { argOffs.push_back(0); } } result.addAttribute(getOperandSegmentSizeAttr(), builder.getI32VectorAttr({1, operSize, sumArgs})); result.addAttribute(getTargetOffsetAttr(), builder.getI32VectorAttr(argOffs)); result.addAttributes(attributes); } /// This builder has a slightly simplified interface in that the list of /// operands need not be partitioned by the builder. Instead the operands are /// partitioned here, before being passed to the default builder. This /// partitioning is unchecked, so can go awry on bad input. 
void fir::SelectCaseOp::build(mlir::OpBuilder &builder, mlir::OperationState &result, mlir::Value selector, llvm::ArrayRef compareAttrs, llvm::ArrayRef cmpOpList, llvm::ArrayRef destinations, llvm::ArrayRef destOperands, llvm::ArrayRef attributes) { llvm::SmallVector cmpOpers; auto iter = cmpOpList.begin(); for (auto &attr : compareAttrs) { if (attr.isa()) { cmpOpers.push_back(mlir::ValueRange({iter, iter + 2})); iter += 2; } else if (attr.isa()) { cmpOpers.push_back(mlir::ValueRange{}); } else { cmpOpers.push_back(mlir::ValueRange({iter, iter + 1})); ++iter; } } build(builder, result, selector, compareAttrs, cmpOpers, destinations, destOperands, attributes); } //===----------------------------------------------------------------------===// // SelectRankOp //===----------------------------------------------------------------------===// llvm::Optional fir::SelectRankOp::getCompareOperands(unsigned) { return {}; } llvm::Optional> fir::SelectRankOp::getCompareOperands(llvm::ArrayRef, unsigned) { return {}; } llvm::Optional fir::SelectRankOp::getMutableSuccessorOperands(unsigned oper) { return ::getMutableSuccessorOperands(oper, targetArgsMutable(), getTargetOffsetAttr()); } llvm::Optional> fir::SelectRankOp::getSuccessorOperands(llvm::ArrayRef operands, unsigned oper) { auto a = getAttrOfType(getTargetOffsetAttr()); auto segments = getAttrOfType(getOperandSegmentSizeAttr()); return {getSubOperands(oper, getSubOperands(2, operands, segments), a)}; } unsigned fir::SelectRankOp::targetOffsetSize() { return denseElementsSize( getAttrOfType(getTargetOffsetAttr())); } //===----------------------------------------------------------------------===// // SelectTypeOp //===----------------------------------------------------------------------===// llvm::Optional fir::SelectTypeOp::getCompareOperands(unsigned) { return {}; } llvm::Optional> fir::SelectTypeOp::getCompareOperands(llvm::ArrayRef, unsigned) { return {}; } llvm::Optional 
fir::SelectTypeOp::getMutableSuccessorOperands(unsigned oper) { return ::getMutableSuccessorOperands(oper, targetArgsMutable(), getTargetOffsetAttr()); } llvm::Optional> fir::SelectTypeOp::getSuccessorOperands(llvm::ArrayRef operands, unsigned oper) { auto a = getAttrOfType(getTargetOffsetAttr()); auto segments = getAttrOfType(getOperandSegmentSizeAttr()); return {getSubOperands(oper, getSubOperands(2, operands, segments), a)}; } static ParseResult parseSelectType(OpAsmParser &parser, OperationState &result) { mlir::OpAsmParser::OperandType selector; mlir::Type type; if (parseSelector(parser, result, selector, type)) return mlir::failure(); llvm::SmallVector attrs; llvm::SmallVector dests; llvm::SmallVector, 8> destArgs; while (true) { mlir::Attribute attr; mlir::Block *dest; llvm::SmallVector destArg; mlir::NamedAttrList temp; if (parser.parseAttribute(attr, "a", temp) || parser.parseComma() || parser.parseSuccessorAndUseList(dest, destArg)) return mlir::failure(); attrs.push_back(attr); dests.push_back(dest); destArgs.push_back(destArg); if (!parser.parseOptionalRSquare()) break; if (parser.parseComma()) return mlir::failure(); } auto &bld = parser.getBuilder(); result.addAttribute(fir::SelectTypeOp::getCasesAttr(), bld.getArrayAttr(attrs)); llvm::SmallVector argOffs; int32_t offSize = 0; const auto count = dests.size(); for (std::remove_const_t i = 0; i != count; ++i) { result.addSuccessors(dests[i]); result.addOperands(destArgs[i]); auto argSize = destArgs[i].size(); argOffs.push_back(argSize); offSize += argSize; } result.addAttribute(fir::SelectTypeOp::getOperandSegmentSizeAttr(), bld.getI32VectorAttr({1, 0, offSize})); result.addAttribute(getTargetOffsetAttr(), bld.getI32VectorAttr(argOffs)); return mlir::success(); } unsigned fir::SelectTypeOp::targetOffsetSize() { return denseElementsSize( getAttrOfType(getTargetOffsetAttr())); } //===----------------------------------------------------------------------===// // StoreOp 
//===----------------------------------------------------------------------===// mlir::Type fir::StoreOp::elementType(mlir::Type refType) { if (auto ref = refType.dyn_cast()) return ref.getEleTy(); if (auto ref = refType.dyn_cast()) return ref.getEleTy(); if (auto ref = refType.dyn_cast()) return ref.getEleTy(); return {}; } //===----------------------------------------------------------------------===// // StringLitOp //===----------------------------------------------------------------------===// bool fir::StringLitOp::isWideValue() { auto eleTy = getType().cast().getEleTy(); return eleTy.cast().getFKind() != 1; } //===----------------------------------------------------------------------===// // SubfOp //===----------------------------------------------------------------------===// mlir::OpFoldResult fir::SubfOp::fold(llvm::ArrayRef opnds) { return mlir::constFoldBinaryOp( opnds, [](APFloat a, APFloat b) { return a - b; }); } //===----------------------------------------------------------------------===// // WhereOp //===----------------------------------------------------------------------===// void fir::WhereOp::build(mlir::OpBuilder &builder, OperationState &result, mlir::Value cond, bool withElseRegion) { build(builder, result, llvm::None, cond, withElseRegion); } void fir::WhereOp::build(mlir::OpBuilder &builder, OperationState &result, mlir::TypeRange resultTypes, mlir::Value cond, bool withElseRegion) { result.addOperands(cond); result.addTypes(resultTypes); mlir::Region *thenRegion = result.addRegion(); thenRegion->push_back(new mlir::Block()); if (resultTypes.empty()) WhereOp::ensureTerminator(*thenRegion, builder, result.location); mlir::Region *elseRegion = result.addRegion(); if (withElseRegion) { elseRegion->push_back(new mlir::Block()); if (resultTypes.empty()) WhereOp::ensureTerminator(*elseRegion, builder, result.location); } } static mlir::ParseResult parseWhereOp(OpAsmParser &parser, OperationState &result) { result.regions.reserve(2); 
mlir::Region *thenRegion = result.addRegion(); mlir::Region *elseRegion = result.addRegion(); auto &builder = parser.getBuilder(); OpAsmParser::OperandType cond; mlir::Type i1Type = builder.getIntegerType(1); if (parser.parseOperand(cond) || parser.resolveOperand(cond, i1Type, result.operands)) return mlir::failure(); if (parser.parseRegion(*thenRegion, {}, {})) return mlir::failure(); WhereOp::ensureTerminator(*thenRegion, parser.getBuilder(), result.location); if (!parser.parseOptionalKeyword("else")) { if (parser.parseRegion(*elseRegion, {}, {})) return mlir::failure(); WhereOp::ensureTerminator(*elseRegion, parser.getBuilder(), result.location); } // Parse the optional attribute list. if (parser.parseOptionalAttrDict(result.attributes)) return mlir::failure(); return mlir::success(); } static LogicalResult verify(fir::WhereOp op) { if (op.getNumResults() != 0 && op.otherRegion().empty()) return op.emitOpError("must have an else block if defining values"); return mlir::success(); } static void print(mlir::OpAsmPrinter &p, fir::WhereOp op) { bool printBlockTerminators = false; p << fir::WhereOp::getOperationName() << ' ' << op.condition(); if (!op.results().empty()) { p << " -> (" << op.getResultTypes() << ')'; printBlockTerminators = true; } p.printRegion(op.whereRegion(), /*printEntryBlockArgs=*/false, printBlockTerminators); // Print the 'else' regions if it exists and has a block. 
auto &otherReg = op.otherRegion(); if (!otherReg.empty()) { p << " else"; p.printRegion(otherReg, /*printEntryBlockArgs=*/false, printBlockTerminators); } p.printOptionalAttrDict(op.getAttrs()); } //===----------------------------------------------------------------------===// mlir::ParseResult fir::isValidCaseAttr(mlir::Attribute attr) { if (attr.dyn_cast_or_null() || attr.dyn_cast_or_null() || attr.dyn_cast_or_null() || attr.dyn_cast_or_null() || attr.dyn_cast_or_null()) return mlir::success(); return mlir::failure(); } unsigned fir::getCaseArgumentOffset(llvm::ArrayRef cases, unsigned dest) { unsigned o = 0; for (unsigned i = 0; i < dest; ++i) { auto &attr = cases[i]; if (!attr.dyn_cast_or_null()) { ++o; if (attr.dyn_cast_or_null()) ++o; } } return o; } mlir::ParseResult fir::parseSelector(mlir::OpAsmParser &parser, mlir::OperationState &result, mlir::OpAsmParser::OperandType &selector, mlir::Type &type) { if (parser.parseOperand(selector) || parser.parseColonType(type) || parser.resolveOperand(selector, type, result.operands) || parser.parseLSquare()) return mlir::failure(); return mlir::success(); } /// Generic pretty-printer of a binary operation static void printBinaryOp(Operation *op, OpAsmPrinter &p) { assert(op->getNumOperands() == 2 && "binary op must have two operands"); assert(op->getNumResults() == 1 && "binary op must have one result"); p << op->getName() << ' ' << op->getOperand(0) << ", " << op->getOperand(1); p.printOptionalAttrDict(op->getAttrs()); p << " : " << op->getResult(0).getType(); } /// Generic pretty-printer of an unary operation static void printUnaryOp(Operation *op, OpAsmPrinter &p) { assert(op->getNumOperands() == 1 && "unary op must have one operand"); assert(op->getNumResults() == 1 && "unary op must have one result"); p << op->getName() << ' ' << op->getOperand(0); p.printOptionalAttrDict(op->getAttrs()); p << " : " << op->getResult(0).getType(); } bool fir::isReferenceLike(mlir::Type type) { return type.isa() || type.isa() || 
type.isa(); } mlir::FuncOp fir::createFuncOp(mlir::Location loc, mlir::ModuleOp module, StringRef name, mlir::FunctionType type, llvm::ArrayRef attrs) { if (auto f = module.lookupSymbol(name)) return f; mlir::OpBuilder modBuilder(module.getBodyRegion()); modBuilder.setInsertionPoint(module.getBody()->getTerminator()); return modBuilder.create(loc, name, type, attrs); } fir::GlobalOp fir::createGlobalOp(mlir::Location loc, mlir::ModuleOp module, StringRef name, mlir::Type type, llvm::ArrayRef attrs) { if (auto g = module.lookupSymbol(name)) return g; mlir::OpBuilder modBuilder(module.getBodyRegion()); return modBuilder.create(loc, name, type, attrs); } // Tablegen operators #define GET_OP_CLASSES #include "flang/Optimizer/Dialect/FIROps.cpp.inc"