//===- PDL.cpp - Pattern Descriptor Language Dialect ----------------------===// // // Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions. // See https://llvm.org/LICENSE.txt for license information. // SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception // //===----------------------------------------------------------------------===// #include "mlir/Dialect/PDL/IR/PDL.h" #include "mlir/Dialect/PDL/IR/PDLTypes.h" #include "mlir/IR/BuiltinTypes.h" #include "mlir/IR/DialectImplementation.h" #include "mlir/Interfaces/InferTypeOpInterface.h" #include "llvm/ADT/StringSwitch.h" using namespace mlir; using namespace mlir::pdl; //===----------------------------------------------------------------------===// // PDLDialect //===----------------------------------------------------------------------===// void PDLDialect::initialize() { addOperations< #define GET_OP_LIST #include "mlir/Dialect/PDL/IR/PDLOps.cpp.inc" >(); addTypes(); } Type PDLDialect::parseType(DialectAsmParser &parser) const { StringRef keyword; if (parser.parseKeyword(&keyword)) return Type(); Builder &builder = parser.getBuilder(); Type result = StringSwitch(keyword) .Case("attribute", builder.getType()) .Case("operation", builder.getType()) .Case("type", builder.getType()) .Case("value", builder.getType()) .Default(Type()); if (!result) parser.emitError(parser.getNameLoc(), "invalid 'pdl' type: `") << keyword << "'"; return result; } void PDLDialect::printType(Type type, DialectAsmPrinter &printer) const { if (type.isa()) printer << "attribute"; else if (type.isa()) printer << "operation"; else if (type.isa()) printer << "type"; else if (type.isa()) printer << "value"; else llvm_unreachable("unknown 'pdl' type"); } /// Returns true if the given operation is used by a "binding" pdl operation /// within the main matcher body of a `pdl.pattern`. static LogicalResult verifyHasBindingUseInMatcher(Operation *op, StringRef bindableContextStr = "`pdl.operation`") { // If the pattern is not a pattern, there is nothing to do. if (!isa(op->getParentOp())) return success(); Block *matcherBlock = op->getBlock(); for (Operation *user : op->getUsers()) { if (user->getBlock() != matcherBlock) continue; if (isa(user)) return success(); } return op->emitOpError() << "expected a bindable (i.e. " << bindableContextStr << ") user when defined in the matcher body of a `pdl.pattern`"; } //===----------------------------------------------------------------------===// // pdl::ApplyConstraintOp //===----------------------------------------------------------------------===// static LogicalResult verify(ApplyConstraintOp op) { if (op.getNumOperands() == 0) return op.emitOpError("expected at least one argument"); return success(); } //===----------------------------------------------------------------------===// // pdl::AttributeOp //===----------------------------------------------------------------------===// static LogicalResult verify(AttributeOp op) { Value attrType = op.type(); Optional attrValue = op.value(); if (!attrValue && isa(op->getParentOp())) return op.emitOpError("expected constant value when specified within a " "`pdl.rewrite`"); if (attrValue && attrType) return op.emitOpError("expected only one of [`type`, `value`] to be set"); return verifyHasBindingUseInMatcher(op); } //===----------------------------------------------------------------------===// // pdl::InputOp //===----------------------------------------------------------------------===// static LogicalResult verify(InputOp op) { return verifyHasBindingUseInMatcher(op); } //===----------------------------------------------------------------------===// // pdl::OperationOp //===----------------------------------------------------------------------===// static ParseResult parseOperationOp(OpAsmParser &p, OperationState &state) { Builder &builder = p.getBuilder(); // Parse the optional operation name. bool startsWithOperands = succeeded(p.parseOptionalLParen()); bool startsWithAttributes = !startsWithOperands && succeeded(p.parseOptionalLBrace()); bool startsWithOpName = false; if (!startsWithAttributes && !startsWithOperands) { StringAttr opName; OptionalParseResult opNameResult = p.parseOptionalAttribute(opName, "name", state.attributes); startsWithOpName = opNameResult.hasValue(); if (startsWithOpName && failed(*opNameResult)) return failure(); } // Parse the operands. SmallVector operands; if (startsWithOperands || (!startsWithAttributes && succeeded(p.parseOptionalLParen()))) { if (p.parseOperandList(operands) || p.parseRParen() || p.resolveOperands(operands, builder.getType(), state.operands)) return failure(); } // Parse the attributes. SmallVector attrNames; if (startsWithAttributes || succeeded(p.parseOptionalLBrace())) { SmallVector attrOps; do { StringAttr nameAttr; OpAsmParser::OperandType operand; if (p.parseAttribute(nameAttr) || p.parseEqual() || p.parseOperand(operand)) return failure(); attrNames.push_back(nameAttr); attrOps.push_back(operand); } while (succeeded(p.parseOptionalComma())); if (p.parseRBrace() || p.resolveOperands(attrOps, builder.getType(), state.operands)) return failure(); } state.addAttribute("attributeNames", builder.getArrayAttr(attrNames)); state.addTypes(builder.getType()); // Parse the result types. SmallVector opResultTypes; if (succeeded(p.parseOptionalArrow())) { if (p.parseOperandList(opResultTypes) || p.resolveOperands(opResultTypes, builder.getType(), state.operands)) return failure(); state.types.append(opResultTypes.size(), builder.getType()); } if (p.parseOptionalAttrDict(state.attributes)) return failure(); int32_t operandSegmentSizes[] = {static_cast(operands.size()), static_cast(attrNames.size()), static_cast(opResultTypes.size())}; state.addAttribute("operand_segment_sizes", builder.getI32VectorAttr(operandSegmentSizes)); return success(); } static void print(OpAsmPrinter &p, OperationOp op) { p << "pdl.operation "; if (Optional name = op.name()) p << '"' << *name << '"'; auto operandValues = op.operands(); if (!operandValues.empty()) p << '(' << operandValues << ')'; // Emit the optional attributes. ArrayAttr attrNames = op.attributeNames(); if (!attrNames.empty()) { Operation::operand_range attrArgs = op.attributes(); p << " {"; interleaveComma(llvm::seq(0, attrNames.size()), p, [&](int i) { p << attrNames[i] << " = " << attrArgs[i]; }); p << '}'; } // Print the result type constraints of the operation. if (!op.results().empty()) p << " -> " << op.types(); p.printOptionalAttrDict(op.getAttrs(), {"attributeNames", "name", "operand_segment_sizes"}); } /// Verifies that the result types of this operation, defined within a /// `pdl.rewrite`, can be inferred. static LogicalResult verifyResultTypesAreInferrable(OperationOp op, ResultRange opResults, OperandRange resultTypes) { // Functor that returns if the given use can be used to infer a type. Block *rewriterBlock = op->getBlock(); auto canInferTypeFromUse = [&](OpOperand &use) { // If the use is within a ReplaceOp and isn't the operation being replaced // (i.e. is not the first operand of the replacement), we can infer a type. ReplaceOp replOpUser = dyn_cast(use.getOwner()); if (!replOpUser || use.getOperandNumber() == 0) return false; // Make sure the replaced operation was defined before this one. Operation *replacedOp = replOpUser.operation().getDefiningOp(); return replacedOp->getBlock() != rewriterBlock || replacedOp->isBeforeInBlock(op); }; // Check to see if the uses of the operation itself can be used to infer // types. if (llvm::any_of(op.op().getUses(), canInferTypeFromUse)) return success(); // Otherwise, make sure each of the types can be inferred. for (int i : llvm::seq(0, opResults.size())) { Operation *resultTypeOp = resultTypes[i].getDefiningOp(); assert(resultTypeOp && "expected valid result type operation"); // If the op was defined by a `create_native`, it is guaranteed to be // usable. if (isa(resultTypeOp)) continue; // If the type is already constrained, there is nothing to do. TypeOp typeOp = cast(resultTypeOp); if (typeOp.type()) continue; // If the type operation was defined in the matcher and constrains the // result of an input operation, it can be used. auto constrainsInputOp = [rewriterBlock](Operation *user) { return user->getBlock() != rewriterBlock && isa(user); }; if (llvm::any_of(typeOp.getResult().getUsers(), constrainsInputOp)) continue; // Otherwise, check to see if any uses of the result can infer the type. if (llvm::any_of(opResults[i].getUses(), canInferTypeFromUse)) continue; return op .emitOpError("must have inferable or constrained result types when " "nested within `pdl.rewrite`") .attachNote() .append("result type #", i, " was not constrained"); } return success(); } static LogicalResult verify(OperationOp op) { bool isWithinRewrite = isa(op->getParentOp()); if (isWithinRewrite && !op.name()) return op.emitOpError("must have an operation name when nested within " "a `pdl.rewrite`"); ArrayAttr attributeNames = op.attributeNames(); auto attributeValues = op.attributes(); if (attributeNames.size() != attributeValues.size()) { return op.emitOpError() << "expected the same number of attribute values and attribute " "names, got " << attributeNames.size() << " names and " << attributeValues.size() << " values"; } OperandRange resultTypes = op.types(); auto opResults = op.results(); if (resultTypes.size() != opResults.size()) { return op.emitOpError() << "expected the same number of result values and " "result type constraints, got " << opResults.size() << " results and " << resultTypes.size() << " constraints"; } // If the operation is within a rewrite body and doesn't have type inferrence, // ensure that the result types can be resolved. if (isWithinRewrite && !op.hasTypeInference()) { if (failed(verifyResultTypesAreInferrable(op, opResults, resultTypes))) return failure(); } return verifyHasBindingUseInMatcher(op, "`pdl.operation` or `pdl.rewrite`"); } bool OperationOp::hasTypeInference() { Optional opName = name(); if (!opName) return false; OperationName name(*opName, getContext()); if (const AbstractOperation *op = name.getAbstractOperation()) return op->getInterface(); return false; } //===----------------------------------------------------------------------===// // pdl::PatternOp //===----------------------------------------------------------------------===// static LogicalResult verify(PatternOp pattern) { Region &body = pattern.body(); auto *term = body.front().getTerminator(); if (!isa(term)) { return pattern.emitOpError("expected body to terminate with `pdl.rewrite`") .attachNote(term->getLoc()) .append("see terminator defined here"); } // Check that all values defined in the top-level pattern are referenced at // least once in the source tree. WalkResult result = body.walk([&](Operation *op) -> WalkResult { if (!isa_and_nonnull(op->getDialect())) { pattern .emitOpError("expected only `pdl` operations within the pattern body") .attachNote(op->getLoc()) .append("see non-`pdl` operation defined here"); return WalkResult::interrupt(); } return WalkResult::advance(); }); return failure(result.wasInterrupted()); } void PatternOp::build(OpBuilder &builder, OperationState &state, Optional rootKind, Optional benefit, Optional name) { build(builder, state, rootKind ? builder.getStringAttr(*rootKind) : StringAttr(), builder.getI16IntegerAttr(benefit ? *benefit : 0), name ? builder.getStringAttr(*name) : StringAttr()); builder.createBlock(state.addRegion()); } /// Returns the rewrite operation of this pattern. RewriteOp PatternOp::getRewriter() { return cast(body().front().getTerminator()); } /// Return the root operation kind that this pattern matches, or None if /// there isn't a specific root. Optional PatternOp::getRootKind() { OperationOp rootOp = cast(getRewriter().root().getDefiningOp()); return rootOp.name(); } //===----------------------------------------------------------------------===// // pdl::ReplaceOp //===----------------------------------------------------------------------===// static LogicalResult verify(ReplaceOp op) { auto sourceOp = cast(op.operation().getDefiningOp()); auto sourceOpResults = sourceOp.results(); auto replValues = op.replValues(); if (Value replOpVal = op.replOperation()) { auto replOp = cast(replOpVal.getDefiningOp()); auto replOpResults = replOp.results(); if (sourceOpResults.size() != replOpResults.size()) { return op.emitOpError() << "expected source operation to have the same number of results " "as the replacement operation, replacement operation provided " << replOpResults.size() << " but expected " << sourceOpResults.size(); } if (!replValues.empty()) { return op.emitOpError() << "expected no replacement values to be provided" " when the replacement operation is present"; } return success(); } if (sourceOpResults.size() != replValues.size()) { return op.emitOpError() << "expected source operation to have the same number of results " "as the provided replacement values, found " << replValues.size() << " replacement values but expected " << sourceOpResults.size(); } return success(); } //===----------------------------------------------------------------------===// // pdl::RewriteOp //===----------------------------------------------------------------------===// static LogicalResult verify(RewriteOp op) { Region &rewriteRegion = op.body(); // Handle the case where the rewrite is external. if (op.name()) { if (!rewriteRegion.empty()) { return op.emitOpError() << "expected rewrite region to be empty when rewrite is external"; } return success(); } // Otherwise, check that the rewrite region only contains a single block. if (rewriteRegion.empty()) { return op.emitOpError() << "expected rewrite region to be non-empty if " "external name is not specified"; } // Check that no additional arguments were provided. if (!op.externalArgs().empty()) { return op.emitOpError() << "expected no external arguments when the " "rewrite is specified inline"; } if (op.externalConstParams()) { return op.emitOpError() << "expected no external constant parameters when " "the rewrite is specified inline"; } return success(); } //===----------------------------------------------------------------------===// // pdl::TypeOp //===----------------------------------------------------------------------===// static LogicalResult verify(TypeOp op) { return verifyHasBindingUseInMatcher( op, "`pdl.attribute`, `pdl.input`, or `pdl.operation`"); } //===----------------------------------------------------------------------===// // TableGen'd op method definitions //===----------------------------------------------------------------------===// #define GET_OP_CLASSES #include "mlir/Dialect/PDL/IR/PDLOps.cpp.inc"