1 //===- TestPDLByteCode.cpp - Test rewriter bytecode functionality ---------===//
2 //
3 // Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions.
4 // See https://llvm.org/LICENSE.txt for license information.
5 // SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception
6 //
7 //===----------------------------------------------------------------------===//
8 
9 #include "mlir/Pass/Pass.h"
10 #include "mlir/Pass/PassManager.h"
11 #include "mlir/Transforms/GreedyPatternRewriteDriver.h"
12 
13 using namespace mlir;
14 
15 /// Custom constraint invoked from PDL.
customSingleEntityConstraint(PDLValue value,ArrayAttr constantParams,PatternRewriter & rewriter)16 static LogicalResult customSingleEntityConstraint(PDLValue value,
17                                                   ArrayAttr constantParams,
18                                                   PatternRewriter &rewriter) {
19   Operation *rootOp = value.cast<Operation *>();
20   return success(rootOp->getName().getStringRef() == "test.op");
21 }
customMultiEntityConstraint(ArrayRef<PDLValue> values,ArrayAttr constantParams,PatternRewriter & rewriter)22 static LogicalResult customMultiEntityConstraint(ArrayRef<PDLValue> values,
23                                                  ArrayAttr constantParams,
24                                                  PatternRewriter &rewriter) {
25   return customSingleEntityConstraint(values[1], constantParams, rewriter);
26 }
27 
28 // Custom creator invoked from PDL.
customCreate(ArrayRef<PDLValue> args,ArrayAttr constantParams,PatternRewriter & rewriter)29 static PDLValue customCreate(ArrayRef<PDLValue> args, ArrayAttr constantParams,
30                              PatternRewriter &rewriter) {
31   return rewriter.createOperation(
32       OperationState(args[0].cast<Operation *>()->getLoc(), "test.success"));
33 }
34 
35 /// Custom rewriter invoked from PDL.
customRewriter(Operation * root,ArrayRef<PDLValue> args,ArrayAttr constantParams,PatternRewriter & rewriter)36 static void customRewriter(Operation *root, ArrayRef<PDLValue> args,
37                            ArrayAttr constantParams,
38                            PatternRewriter &rewriter) {
39   OperationState successOpState(root->getLoc(), "test.success");
40   successOpState.addOperands(args[0].cast<Value>());
41   successOpState.addAttribute("constantParams", constantParams);
42   rewriter.createOperation(successOpState);
43   rewriter.eraseOp(root);
44 }
45 
46 namespace {
47 struct TestPDLByteCodePass
48     : public PassWrapper<TestPDLByteCodePass, OperationPass<ModuleOp>> {
runOnOperation__anon1ab63e080111::TestPDLByteCodePass49   void runOnOperation() final {
50     ModuleOp module = getOperation();
51 
52     // The test cases are encompassed via two modules, one containing the
53     // patterns and one containing the operations to rewrite.
54     ModuleOp patternModule = module.lookupSymbol<ModuleOp>("patterns");
55     ModuleOp irModule = module.lookupSymbol<ModuleOp>("ir");
56     if (!patternModule || !irModule)
57       return;
58 
59     // Process the pattern module.
60     patternModule.getOperation()->remove();
61     PDLPatternModule pdlPattern(patternModule);
62     pdlPattern.registerConstraintFunction("multi_entity_constraint",
63                                           customMultiEntityConstraint);
64     pdlPattern.registerConstraintFunction("single_entity_constraint",
65                                           customSingleEntityConstraint);
66     pdlPattern.registerCreateFunction("creator", customCreate);
67     pdlPattern.registerRewriteFunction("rewriter", customRewriter);
68 
69     OwningRewritePatternList patternList(std::move(pdlPattern));
70 
71     // Invoke the pattern driver with the provided patterns.
72     (void)applyPatternsAndFoldGreedily(irModule.getBodyRegion(),
73                                        std::move(patternList));
74   }
75 };
76 } // end anonymous namespace
77 
78 namespace mlir {
79 namespace test {
registerTestPDLByteCodePass()80 void registerTestPDLByteCodePass() {
81   PassRegistration<TestPDLByteCodePass>("test-pdl-bytecode-pass",
82                                         "Test PDL ByteCode functionality");
83 }
84 } // namespace test
85 } // namespace mlir
86