//===------ TestDynamicPipeline.cpp --- dynamic pipeline test pass --------===//
//
// Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions.
// See https://llvm.org/LICENSE.txt for license information.
// SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception
//
//===----------------------------------------------------------------------===//
//
// This file implements a test pass for the dynamic pass pipeline feature,
// i.e. running a pass pipeline from within a running pass.
//
//===----------------------------------------------------------------------===//
12
13 #include "mlir/Dialect/SCF/SCF.h"
14 #include "mlir/IR/Builders.h"
15 #include "mlir/Pass/Pass.h"
16 #include "mlir/Pass/PassManager.h"
17 #include "mlir/Transforms/LoopUtils.h"
18 #include "mlir/Transforms/Passes.h"
19
20 using namespace mlir;
21
22 namespace {
23
24 class TestDynamicPipelinePass
25 : public PassWrapper<TestDynamicPipelinePass, OperationPass<>> {
26 public:
getDependentDialects(DialectRegistry & registry) const27 void getDependentDialects(DialectRegistry ®istry) const override {
28 OpPassManager pm(ModuleOp::getOperationName(),
29 OpPassManager::Nesting::Implicit);
30 parsePassPipeline(pipeline, pm, llvm::errs());
31 pm.getDependentDialects(registry);
32 }
33
TestDynamicPipelinePass()34 TestDynamicPipelinePass(){};
TestDynamicPipelinePass(const TestDynamicPipelinePass &)35 TestDynamicPipelinePass(const TestDynamicPipelinePass &) {}
36
runOnOperation()37 void runOnOperation() override {
38 llvm::errs() << "Dynamic execute '" << pipeline << "' on "
39 << getOperation()->getName() << "\n";
40 if (pipeline.empty()) {
41 llvm::errs() << "Empty pipeline\n";
42 return;
43 }
44 auto symbolOp = dyn_cast<SymbolOpInterface>(getOperation());
45 if (!symbolOp) {
46 getOperation()->emitWarning()
47 << "Ignoring because not implementing SymbolOpInterface\n";
48 return;
49 }
50
51 auto opName = symbolOp.getName();
52 if (!opNames.empty() && !llvm::is_contained(opNames, opName)) {
53 llvm::errs() << "dynamic-pipeline skip op name: " << opName << "\n";
54 return;
55 }
56 if (!pm) {
57 pm = std::make_unique<OpPassManager>(
58 getOperation()->getName().getIdentifier(),
59 OpPassManager::Nesting::Implicit);
60 parsePassPipeline(pipeline, *pm, llvm::errs());
61 }
62
63 // Check that running on the parent operation always immediately fails.
64 if (runOnParent) {
65 if (getOperation()->getParentOp())
66 if (!failed(runPipeline(*pm, getOperation()->getParentOp())))
67 signalPassFailure();
68 return;
69 }
70
71 if (runOnNestedOp) {
72 llvm::errs() << "Run on nested op\n";
73 getOperation()->walk([&](Operation *op) {
74 if (op == getOperation() || !op->isKnownIsolatedFromAbove())
75 return;
76 llvm::errs() << "Run on " << *op << "\n";
77 // Run on the current operation
78 if (failed(runPipeline(*pm, op)))
79 signalPassFailure();
80 });
81 } else {
82 // Run on the current operation
83 if (failed(runPipeline(*pm, getOperation())))
84 signalPassFailure();
85 }
86 }
87
88 std::unique_ptr<OpPassManager> pm;
89
90 Option<bool> runOnNestedOp{
91 *this, "run-on-nested-operations",
92 llvm::cl::desc("This will apply the pipeline on nested operations under "
93 "the visited operation.")};
94 Option<bool> runOnParent{
95 *this, "run-on-parent",
96 llvm::cl::desc("This will apply the pipeline on the parent operation if "
97 "it exist, this is expected to fail.")};
98 Option<std::string> pipeline{
99 *this, "dynamic-pipeline",
100 llvm::cl::desc("The pipeline description that "
101 "will run on the filtered function.")};
102 ListOption<std::string> opNames{
103 *this, "op-name", llvm::cl::MiscFlags::CommaSeparated,
104 llvm::cl::desc("List of function name to apply the pipeline to")};
105 };
106 } // namespace
107
108 namespace mlir {
109 namespace test {
registerTestDynamicPipelinePass()110 void registerTestDynamicPipelinePass() {
111 PassRegistration<TestDynamicPipelinePass>(
112 "test-dynamic-pipeline", "Tests the dynamic pipeline feature by applying "
113 "a pipeline on a selected set of functions");
114 }
115 } // namespace test
116 } // namespace mlir
117