1 //===- TestLinalgFusionTransforms.cpp - Test Linalg fusion patterns -------===//
2 //
3 // Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions.
4 // See https://llvm.org/LICENSE.txt for license information.
5 // SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception
6 //
7 //===----------------------------------------------------------------------===//
8 //
9 // This file implements logic for testing Linalg fusion patterns.
10 //
11 //===----------------------------------------------------------------------===//
12
13 #include "mlir/Dialect/Linalg/Analysis/DependenceAnalysis.h"
14 #include "mlir/Dialect/Linalg/Transforms/Transforms.h"
15 #include "mlir/Pass/Pass.h"
16 #include "mlir/Pass/PassManager.h"
17 #include "mlir/Transforms/GreedyPatternRewriteDriver.h"
18 #include "mlir/Transforms/Passes.h"
19
20 using namespace mlir;
21 using namespace mlir::linalg;
22
23 namespace {
24 struct TestLinalgFusionTransforms
25 : public PassWrapper<TestLinalgFusionTransforms, FunctionPass> {
26 TestLinalgFusionTransforms() = default;
TestLinalgFusionTransforms__anoneb0e8f600111::TestLinalgFusionTransforms27 TestLinalgFusionTransforms(const TestLinalgFusionTransforms &pass) {}
28
getDependentDialects__anoneb0e8f600111::TestLinalgFusionTransforms29 void getDependentDialects(DialectRegistry ®istry) const override {
30 registry.insert<AffineDialect, linalg::LinalgDialect, scf::SCFDialect,
31 StandardOpsDialect>();
32 }
33
34 void runOnFunction() override;
35 };
36 } // namespace
37
fillFusionPatterns(MLIRContext * context,const LinalgDependenceGraph & dependenceGraph,OwningRewritePatternList & patterns)38 static void fillFusionPatterns(MLIRContext *context,
39 const LinalgDependenceGraph &dependenceGraph,
40 OwningRewritePatternList &patterns) {
41 patterns.insert<LinalgTileAndFusePattern<MatmulOp>,
42 LinalgTileAndFusePattern<ConvOp>>(
43 context, dependenceGraph,
44 LinalgTilingOptions()
45 .setTileSizes({32, 64, 16})
46 .setLoopType(LinalgTilingLoopType::ParallelLoops),
47 LinalgFusionOptions().setIndicesToFuse({2}),
48 LinalgMarker(Identifier::get("basic_fusion", context),
49 Identifier::get("after_basic_fusion", context)),
50 LinalgMarker(ArrayRef<Identifier>(),
51 Identifier::get("after_basic_fusion_producer", context)),
52 LinalgMarker(ArrayRef<Identifier>(),
53 Identifier::get("after_basic_fusion_original", context)));
54
55 patterns.insert<LinalgTileAndFusePattern<MatmulOp>>(
56 context, dependenceGraph,
57 LinalgTilingOptions()
58 .setTileSizes({32, 64, 16})
59 .setLoopType(LinalgTilingLoopType::ParallelLoops),
60 LinalgFusionOptions().setIndicesToFuse({0}),
61 LinalgMarker(Identifier::get("lhs_fusion", context),
62 Identifier::get("after_lhs_fusion", context)),
63 LinalgMarker(ArrayRef<Identifier>(),
64 Identifier::get("after_lhs_fusion_producer", context)),
65 LinalgMarker(ArrayRef<Identifier>(),
66 Identifier::get("after_lhs_fusion_original", context)));
67
68 patterns.insert<LinalgTileAndFusePattern<MatmulOp>>(
69 context, dependenceGraph,
70 LinalgTilingOptions()
71 .setTileSizes({32, 64, 16})
72 .setLoopType(LinalgTilingLoopType::ParallelLoops),
73 LinalgFusionOptions().setIndicesToFuse({1}),
74 LinalgMarker(Identifier::get("rhs_fusion", context),
75 Identifier::get("after_rhs_fusion", context)),
76 LinalgMarker(ArrayRef<Identifier>(),
77 Identifier::get("after_rhs_fusion_producer", context)),
78 LinalgMarker(ArrayRef<Identifier>(),
79 Identifier::get("after_rhs_fusion_original", context)));
80
81 patterns.insert<LinalgTileAndFusePattern<MatmulOp>>(
82 context, dependenceGraph,
83 LinalgTilingOptions()
84 .setTileSizes({32, 64, 16})
85 .setLoopType(LinalgTilingLoopType::ParallelLoops),
86 LinalgFusionOptions().setIndicesToFuse({0, 2}),
87 LinalgMarker(Identifier::get("two_operand_fusion", context),
88 Identifier::get("after_two_operand_fusion", context)),
89 LinalgMarker(
90 ArrayRef<Identifier>(),
91 Identifier::get("after_two_operand_fusion_producer", context)),
92 LinalgMarker(
93 ArrayRef<Identifier>(),
94 Identifier::get("after_two_operand_fusion_original", context)));
95
96 patterns.insert<LinalgTileAndFusePattern<GenericOp>>(
97 context, dependenceGraph,
98 LinalgTilingOptions().setTileSizes({32, 64}).setLoopType(
99 LinalgTilingLoopType::ParallelLoops),
100 LinalgFusionOptions().setIndicesToFuse({0, 1}),
101 LinalgMarker(Identifier::get("transpose_fusion", context),
102 Identifier::get("after_transpose_fusion", context)),
103 LinalgMarker(ArrayRef<Identifier>(),
104 Identifier::get("after_transpose_fusion_producer", context)),
105 LinalgMarker(
106 ArrayRef<Identifier>(),
107 Identifier::get("after_transpose_fusion_original", context)));
108 }
109
applyFusionPatterns(MLIRContext * context,FuncOp funcOp)110 static void applyFusionPatterns(MLIRContext *context, FuncOp funcOp) {
111 OwningRewritePatternList fusionPatterns;
112 Aliases alias;
113 LinalgDependenceGraph dependenceGraph =
114 LinalgDependenceGraph::buildDependenceGraph(alias, funcOp);
115 fillFusionPatterns(context, dependenceGraph, fusionPatterns);
116 applyPatternsAndFoldGreedily(funcOp, std::move(fusionPatterns));
117 }
118
runOnFunction()119 void TestLinalgFusionTransforms::runOnFunction() {
120 applyFusionPatterns(&getContext(), getFunction());
121 }
122
fuseLinalgOpsGreedily(FuncOp f)123 static LogicalResult fuseLinalgOpsGreedily(FuncOp f) {
124 OpBuilder b(f);
125 DenseSet<Operation *> eraseSet;
126
127 // Save original Linalg ops, we only want to make a pass over those.
128 SmallVector<LinalgOp, 8> linalgOps;
129 f.walk([&](LinalgOp op) {
130 // TODO: support multi-results.
131 if (op->getNumResults() <= 1)
132 linalgOps.push_back(op);
133 });
134
135 // Tile and Fuse for tensors inputs (TODO: all tensor operands).
136 bool changed = false;
137 for (LinalgOp linalgOp : llvm::reverse(linalgOps)) {
138 for (auto en : llvm::enumerate(linalgOp.getShapedOperands())) {
139 if (en.value().getType().isa<MemRefType>()) {
140 // TODO: LinalgDependenceGraph should be able to update itself.
141 // The current naive and expensive reconstruction of the graph should be
142 // removed.
143 linalg::Aliases aliases;
144 linalg::LinalgDependenceGraph graph(aliases, linalgOps);
145 if (auto info = fuseProducerOfBuffer(b, linalgOp, en.index(), graph)) {
146 auto *originalOp = info->originalProducer.getOperation();
147 eraseSet.insert(originalOp);
148 auto *originalOpInLinalgOpsVector =
149 std::find(linalgOps.begin(), linalgOps.end(), originalOp);
150 *originalOpInLinalgOpsVector = info->fusedProducer.getOperation();
151 changed = true;
152 }
153 } else {
154 assert(en.value().getType().isa<RankedTensorType>());
155 // Tile and Fuse tensor input (TODO: init_tensors too).
156 if (en.index() >= linalgOp.getNumInputs())
157 continue;
158 if (auto info = fuseProducerOfTensor(b, linalgOp, en.index())) {
159 auto *originalOp = info->originalProducer.getOperation();
160 auto *originalOpInLinalgOpsVector =
161 std::find(linalgOps.begin(), linalgOps.end(), originalOp);
162 *originalOpInLinalgOpsVector = info->fusedProducer.getOperation();
163 // Don't mark for erasure in the tensor case, let DCE handle this.
164 changed = true;
165 }
166 }
167 }
168 }
169 // The `fuseProducerOfBuffer` function performs structural checks and in
170 // particular that no covering read or write exist between the consumer and
171 // the producer. As a consequence, the only fusions that may occur preserve
172 // subsequent dependences and are guaranteed by construction to produce the
173 // whole view. We may thus erase the producer once it is fused.
174 for (auto *e : eraseSet)
175 e->erase();
176
177 return changed ? success() : failure();
178 }
179
180 namespace {
181 struct TestLinalgGreedyFusion
182 : public PassWrapper<TestLinalgGreedyFusion, FunctionPass> {
runOnFunction__anoneb0e8f600311::TestLinalgGreedyFusion183 void runOnFunction() override {
184 MLIRContext *context = &getContext();
185 OwningRewritePatternList patterns =
186 linalg::getLinalgTilingCanonicalizationPatterns(context);
187 patterns.insert<AffineMinSCFCanonicalizationPattern>(context);
188 FrozenRewritePatternList frozenPatterns(std::move(patterns));
189 while (succeeded(fuseLinalgOpsGreedily(getFunction()))) {
190 applyPatternsAndFoldGreedily(getFunction(), frozenPatterns);
191 PassManager pm(context);
192 pm.addPass(createLoopInvariantCodeMotionPass());
193 pm.addPass(createCanonicalizerPass());
194 pm.addPass(createCSEPass());
195 LogicalResult res = pm.run(getFunction()->getParentOfType<ModuleOp>());
196 if (failed(res))
197 this->signalPassFailure();
198 }
199 }
200 };
201
202 /// Pass to test tile and fuse of sequence of operations. Intended only for
203 /// testing.
204 struct TestLinalgTileAndFuseSequencePass
205 : public PassWrapper<TestLinalgTileAndFuseSequencePass, FunctionPass> {
206 TestLinalgTileAndFuseSequencePass() = default;
TestLinalgTileAndFuseSequencePass__anoneb0e8f600311::TestLinalgTileAndFuseSequencePass207 TestLinalgTileAndFuseSequencePass(
208 const TestLinalgTileAndFuseSequencePass &pass){};
209
210 ListOption<int64_t> tileSizes{
211 *this, "tile-sizes", llvm::cl::desc("Tile sizes to use for ops"),
212 llvm::cl::ZeroOrMore, llvm::cl::MiscFlags::CommaSeparated};
213
getDependentDialects__anoneb0e8f600311::TestLinalgTileAndFuseSequencePass214 void getDependentDialects(DialectRegistry ®istry) const override {
215 registry.insert<AffineDialect, linalg::LinalgDialect, scf::SCFDialect>();
216 }
217
runOnFunction__anoneb0e8f600311::TestLinalgTileAndFuseSequencePass218 void runOnFunction() override {
219 FuncOp funcOp = getOperation();
220 auto &blocks = funcOp.getBody().getBlocks();
221 if (!llvm::hasSingleElement(blocks)) {
222 return;
223 }
224 SmallVector<LinalgOp, 2> linalgOps =
225 llvm::to_vector<2>(blocks.front().getOps<LinalgOp>());
226 Aliases aliases;
227 LinalgDependenceGraph dependenceGraph(aliases, linalgOps);
228 OpBuilder builder(funcOp.getContext());
229 Optional<TiledAndFusedLinalgOps> tileAndFuseOps = tileAndFuseLinalgOps(
230 builder, linalgOps, dependenceGraph,
231 LinalgTilingOptions().setTileSizes(tileSizes).setLoopType(
232 LinalgTilingLoopType::ParallelLoops));
233 if (!tileAndFuseOps)
234 return signalPassFailure();
235 for (auto op : linalgOps)
236 op.erase();
237 }
238 };
239 } // namespace
240
241 namespace mlir {
242 namespace test {
registerTestLinalgFusionTransforms()243 void registerTestLinalgFusionTransforms() {
244 PassRegistration<TestLinalgFusionTransforms> testFusionTransformsPass(
245 "test-linalg-fusion-transform-patterns",
246 "Test Linalg fusion transformation patterns by applying them greedily.");
247 }
registerTestLinalgGreedyFusion()248 void registerTestLinalgGreedyFusion() {
249 PassRegistration<TestLinalgGreedyFusion> testFusionTransformsPass(
250 "test-linalg-greedy-fusion",
251 "Test Linalg fusion by applying a greedy test transformation.");
252 }
registerTestLinalgTileAndFuseSequencePass()253 void registerTestLinalgTileAndFuseSequencePass() {
254 PassRegistration<TestLinalgTileAndFuseSequencePass>
255 testTileAndFuseSequencePass(
256 "test-linalg-tile-and-fuse",
257 "Test Linalg tiling and fusion of a sequence of Linalg operations.");
258 }
259
260 } // namespace test
261 } // namespace mlir
262