1 //===------ TestDynamicPipeline.cpp --- dynamic pipeline test pass --------===//
2 //
3 // Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions.
4 // See https://llvm.org/LICENSE.txt for license information.
5 // SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception
6 //
7 //===----------------------------------------------------------------------===//
8 //
9 // This file implements a pass to test the dynamic pipeline feature.
10 //
11 //===----------------------------------------------------------------------===//
12 
13 #include "mlir/Dialect/SCF/SCF.h"
14 #include "mlir/IR/Builders.h"
15 #include "mlir/Pass/Pass.h"
16 #include "mlir/Pass/PassManager.h"
17 #include "mlir/Transforms/LoopUtils.h"
18 #include "mlir/Transforms/Passes.h"
19 
20 using namespace mlir;
21 
22 namespace {
23 
24 class TestDynamicPipelinePass
25     : public PassWrapper<TestDynamicPipelinePass, OperationPass<>> {
26 public:
getDependentDialects(DialectRegistry & registry) const27   void getDependentDialects(DialectRegistry &registry) const override {
28     OpPassManager pm(ModuleOp::getOperationName(),
29                      OpPassManager::Nesting::Implicit);
30     parsePassPipeline(pipeline, pm, llvm::errs());
31     pm.getDependentDialects(registry);
32   }
33 
TestDynamicPipelinePass()34   TestDynamicPipelinePass(){};
TestDynamicPipelinePass(const TestDynamicPipelinePass &)35   TestDynamicPipelinePass(const TestDynamicPipelinePass &) {}
36 
runOnOperation()37   void runOnOperation() override {
38     llvm::errs() << "Dynamic execute '" << pipeline << "' on "
39                  << getOperation()->getName() << "\n";
40     if (pipeline.empty()) {
41       llvm::errs() << "Empty pipeline\n";
42       return;
43     }
44     auto symbolOp = dyn_cast<SymbolOpInterface>(getOperation());
45     if (!symbolOp) {
46       getOperation()->emitWarning()
47           << "Ignoring because not implementing SymbolOpInterface\n";
48       return;
49     }
50 
51     auto opName = symbolOp.getName();
52     if (!opNames.empty() && !llvm::is_contained(opNames, opName)) {
53       llvm::errs() << "dynamic-pipeline skip op name: " << opName << "\n";
54       return;
55     }
56     if (!pm) {
57       pm = std::make_unique<OpPassManager>(
58           getOperation()->getName().getIdentifier(),
59           OpPassManager::Nesting::Implicit);
60       parsePassPipeline(pipeline, *pm, llvm::errs());
61     }
62 
63     // Check that running on the parent operation always immediately fails.
64     if (runOnParent) {
65       if (getOperation()->getParentOp())
66         if (!failed(runPipeline(*pm, getOperation()->getParentOp())))
67           signalPassFailure();
68       return;
69     }
70 
71     if (runOnNestedOp) {
72       llvm::errs() << "Run on nested op\n";
73       getOperation()->walk([&](Operation *op) {
74         if (op == getOperation() || !op->isKnownIsolatedFromAbove())
75           return;
76         llvm::errs() << "Run on " << *op << "\n";
77         // Run on the current operation
78         if (failed(runPipeline(*pm, op)))
79           signalPassFailure();
80       });
81     } else {
82       // Run on the current operation
83       if (failed(runPipeline(*pm, getOperation())))
84         signalPassFailure();
85     }
86   }
87 
88   std::unique_ptr<OpPassManager> pm;
89 
90   Option<bool> runOnNestedOp{
91       *this, "run-on-nested-operations",
92       llvm::cl::desc("This will apply the pipeline on nested operations under "
93                      "the visited operation.")};
94   Option<bool> runOnParent{
95       *this, "run-on-parent",
96       llvm::cl::desc("This will apply the pipeline on the parent operation if "
97                      "it exist, this is expected to fail.")};
98   Option<std::string> pipeline{
99       *this, "dynamic-pipeline",
100       llvm::cl::desc("The pipeline description that "
101                      "will run on the filtered function.")};
102   ListOption<std::string> opNames{
103       *this, "op-name", llvm::cl::MiscFlags::CommaSeparated,
104       llvm::cl::desc("List of function name to apply the pipeline to")};
105 };
106 } // namespace
107 
108 namespace mlir {
109 namespace test {
registerTestDynamicPipelinePass()110 void registerTestDynamicPipelinePass() {
111   PassRegistration<TestDynamicPipelinePass>(
112       "test-dynamic-pipeline", "Tests the dynamic pipeline feature by applying "
113                                "a pipeline on a selected set of functions");
114 }
115 } // namespace test
116 } // namespace mlir
117