1 //===- TestLinalgFusionTransforms.cpp - Test Linalg fusion patterns -------===//
2 //
3 // Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions.
4 // See https://llvm.org/LICENSE.txt for license information.
5 // SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception
6 //
7 //===----------------------------------------------------------------------===//
8 //
9 // This file implements logic for testing Linalg fusion patterns.
10 //
11 //===----------------------------------------------------------------------===//
12 
13 #include "mlir/Dialect/Linalg/Analysis/DependenceAnalysis.h"
14 #include "mlir/Dialect/Linalg/Transforms/Transforms.h"
15 #include "mlir/Pass/Pass.h"
16 #include "mlir/Pass/PassManager.h"
17 #include "mlir/Transforms/GreedyPatternRewriteDriver.h"
18 #include "mlir/Transforms/Passes.h"
19 
20 using namespace mlir;
21 using namespace mlir::linalg;
22 
23 namespace {
24 struct TestLinalgFusionTransforms
25     : public PassWrapper<TestLinalgFusionTransforms, FunctionPass> {
26   TestLinalgFusionTransforms() = default;
TestLinalgFusionTransforms__anoneb0e8f600111::TestLinalgFusionTransforms27   TestLinalgFusionTransforms(const TestLinalgFusionTransforms &pass) {}
28 
getDependentDialects__anoneb0e8f600111::TestLinalgFusionTransforms29   void getDependentDialects(DialectRegistry &registry) const override {
30     registry.insert<AffineDialect, linalg::LinalgDialect, scf::SCFDialect,
31                     StandardOpsDialect>();
32   }
33 
34   void runOnFunction() override;
35 };
36 } // namespace
37 
fillFusionPatterns(MLIRContext * context,const LinalgDependenceGraph & dependenceGraph,OwningRewritePatternList & patterns)38 static void fillFusionPatterns(MLIRContext *context,
39                                const LinalgDependenceGraph &dependenceGraph,
40                                OwningRewritePatternList &patterns) {
41   patterns.insert<LinalgTileAndFusePattern<MatmulOp>,
42                   LinalgTileAndFusePattern<ConvOp>>(
43       context, dependenceGraph,
44       LinalgTilingOptions()
45           .setTileSizes({32, 64, 16})
46           .setLoopType(LinalgTilingLoopType::ParallelLoops),
47       LinalgFusionOptions().setIndicesToFuse({2}),
48       LinalgMarker(Identifier::get("basic_fusion", context),
49                    Identifier::get("after_basic_fusion", context)),
50       LinalgMarker(ArrayRef<Identifier>(),
51                    Identifier::get("after_basic_fusion_producer", context)),
52       LinalgMarker(ArrayRef<Identifier>(),
53                    Identifier::get("after_basic_fusion_original", context)));
54 
55   patterns.insert<LinalgTileAndFusePattern<MatmulOp>>(
56       context, dependenceGraph,
57       LinalgTilingOptions()
58           .setTileSizes({32, 64, 16})
59           .setLoopType(LinalgTilingLoopType::ParallelLoops),
60       LinalgFusionOptions().setIndicesToFuse({0}),
61       LinalgMarker(Identifier::get("lhs_fusion", context),
62                    Identifier::get("after_lhs_fusion", context)),
63       LinalgMarker(ArrayRef<Identifier>(),
64                    Identifier::get("after_lhs_fusion_producer", context)),
65       LinalgMarker(ArrayRef<Identifier>(),
66                    Identifier::get("after_lhs_fusion_original", context)));
67 
68   patterns.insert<LinalgTileAndFusePattern<MatmulOp>>(
69       context, dependenceGraph,
70       LinalgTilingOptions()
71           .setTileSizes({32, 64, 16})
72           .setLoopType(LinalgTilingLoopType::ParallelLoops),
73       LinalgFusionOptions().setIndicesToFuse({1}),
74       LinalgMarker(Identifier::get("rhs_fusion", context),
75                    Identifier::get("after_rhs_fusion", context)),
76       LinalgMarker(ArrayRef<Identifier>(),
77                    Identifier::get("after_rhs_fusion_producer", context)),
78       LinalgMarker(ArrayRef<Identifier>(),
79                    Identifier::get("after_rhs_fusion_original", context)));
80 
81   patterns.insert<LinalgTileAndFusePattern<MatmulOp>>(
82       context, dependenceGraph,
83       LinalgTilingOptions()
84           .setTileSizes({32, 64, 16})
85           .setLoopType(LinalgTilingLoopType::ParallelLoops),
86       LinalgFusionOptions().setIndicesToFuse({0, 2}),
87       LinalgMarker(Identifier::get("two_operand_fusion", context),
88                    Identifier::get("after_two_operand_fusion", context)),
89       LinalgMarker(
90           ArrayRef<Identifier>(),
91           Identifier::get("after_two_operand_fusion_producer", context)),
92       LinalgMarker(
93           ArrayRef<Identifier>(),
94           Identifier::get("after_two_operand_fusion_original", context)));
95 
96   patterns.insert<LinalgTileAndFusePattern<GenericOp>>(
97       context, dependenceGraph,
98       LinalgTilingOptions().setTileSizes({32, 64}).setLoopType(
99           LinalgTilingLoopType::ParallelLoops),
100       LinalgFusionOptions().setIndicesToFuse({0, 1}),
101       LinalgMarker(Identifier::get("transpose_fusion", context),
102                    Identifier::get("after_transpose_fusion", context)),
103       LinalgMarker(ArrayRef<Identifier>(),
104                    Identifier::get("after_transpose_fusion_producer", context)),
105       LinalgMarker(
106           ArrayRef<Identifier>(),
107           Identifier::get("after_transpose_fusion_original", context)));
108 }
109 
applyFusionPatterns(MLIRContext * context,FuncOp funcOp)110 static void applyFusionPatterns(MLIRContext *context, FuncOp funcOp) {
111   OwningRewritePatternList fusionPatterns;
112   Aliases alias;
113   LinalgDependenceGraph dependenceGraph =
114       LinalgDependenceGraph::buildDependenceGraph(alias, funcOp);
115   fillFusionPatterns(context, dependenceGraph, fusionPatterns);
116   applyPatternsAndFoldGreedily(funcOp, std::move(fusionPatterns));
117 }
118 
runOnFunction()119 void TestLinalgFusionTransforms::runOnFunction() {
120   applyFusionPatterns(&getContext(), getFunction());
121 }
122 
fuseLinalgOpsGreedily(FuncOp f)123 static LogicalResult fuseLinalgOpsGreedily(FuncOp f) {
124   OpBuilder b(f);
125   DenseSet<Operation *> eraseSet;
126 
127   // Save original Linalg ops, we only want to make a pass over those.
128   SmallVector<LinalgOp, 8> linalgOps;
129   f.walk([&](LinalgOp op) {
130     // TODO: support multi-results.
131     if (op->getNumResults() <= 1)
132       linalgOps.push_back(op);
133   });
134 
135   // Tile and Fuse for tensors inputs (TODO: all tensor operands).
136   bool changed = false;
137   for (LinalgOp linalgOp : llvm::reverse(linalgOps)) {
138     for (auto en : llvm::enumerate(linalgOp.getShapedOperands())) {
139       if (en.value().getType().isa<MemRefType>()) {
140         // TODO: LinalgDependenceGraph should be able to update itself.
141         // The current naive and expensive reconstruction of the graph should be
142         // removed.
143         linalg::Aliases aliases;
144         linalg::LinalgDependenceGraph graph(aliases, linalgOps);
145         if (auto info = fuseProducerOfBuffer(b, linalgOp, en.index(), graph)) {
146           auto *originalOp = info->originalProducer.getOperation();
147           eraseSet.insert(originalOp);
148           auto *originalOpInLinalgOpsVector =
149               std::find(linalgOps.begin(), linalgOps.end(), originalOp);
150           *originalOpInLinalgOpsVector = info->fusedProducer.getOperation();
151           changed = true;
152         }
153       } else {
154         assert(en.value().getType().isa<RankedTensorType>());
155         // Tile and Fuse tensor input (TODO: init_tensors too).
156         if (en.index() >= linalgOp.getNumInputs())
157           continue;
158         if (auto info = fuseProducerOfTensor(b, linalgOp, en.index())) {
159           auto *originalOp = info->originalProducer.getOperation();
160           auto *originalOpInLinalgOpsVector =
161               std::find(linalgOps.begin(), linalgOps.end(), originalOp);
162           *originalOpInLinalgOpsVector = info->fusedProducer.getOperation();
163           // Don't mark for erasure in the tensor case, let DCE handle this.
164           changed = true;
165         }
166       }
167     }
168   }
169   // The `fuseProducerOfBuffer` function performs structural checks and in
170   // particular that no covering read or write exist between the consumer and
171   // the producer. As a consequence, the only fusions that may occur preserve
172   // subsequent dependences and are guaranteed by construction to produce the
173   // whole view. We may thus erase the producer once it is fused.
174   for (auto *e : eraseSet)
175     e->erase();
176 
177   return changed ? success() : failure();
178 }
179 
180 namespace {
181 struct TestLinalgGreedyFusion
182     : public PassWrapper<TestLinalgGreedyFusion, FunctionPass> {
runOnFunction__anoneb0e8f600311::TestLinalgGreedyFusion183   void runOnFunction() override {
184     MLIRContext *context = &getContext();
185     OwningRewritePatternList patterns =
186         linalg::getLinalgTilingCanonicalizationPatterns(context);
187     patterns.insert<AffineMinSCFCanonicalizationPattern>(context);
188     FrozenRewritePatternList frozenPatterns(std::move(patterns));
189     while (succeeded(fuseLinalgOpsGreedily(getFunction()))) {
190       applyPatternsAndFoldGreedily(getFunction(), frozenPatterns);
191       PassManager pm(context);
192       pm.addPass(createLoopInvariantCodeMotionPass());
193       pm.addPass(createCanonicalizerPass());
194       pm.addPass(createCSEPass());
195       LogicalResult res = pm.run(getFunction()->getParentOfType<ModuleOp>());
196       if (failed(res))
197         this->signalPassFailure();
198     }
199   }
200 };
201 
202 /// Pass to test tile and fuse of sequence of operations. Intended only for
203 /// testing.
204 struct TestLinalgTileAndFuseSequencePass
205     : public PassWrapper<TestLinalgTileAndFuseSequencePass, FunctionPass> {
206   TestLinalgTileAndFuseSequencePass() = default;
TestLinalgTileAndFuseSequencePass__anoneb0e8f600311::TestLinalgTileAndFuseSequencePass207   TestLinalgTileAndFuseSequencePass(
208       const TestLinalgTileAndFuseSequencePass &pass){};
209 
210   ListOption<int64_t> tileSizes{
211       *this, "tile-sizes", llvm::cl::desc("Tile sizes to use for ops"),
212       llvm::cl::ZeroOrMore, llvm::cl::MiscFlags::CommaSeparated};
213 
getDependentDialects__anoneb0e8f600311::TestLinalgTileAndFuseSequencePass214   void getDependentDialects(DialectRegistry &registry) const override {
215     registry.insert<AffineDialect, linalg::LinalgDialect, scf::SCFDialect>();
216   }
217 
runOnFunction__anoneb0e8f600311::TestLinalgTileAndFuseSequencePass218   void runOnFunction() override {
219     FuncOp funcOp = getOperation();
220     auto &blocks = funcOp.getBody().getBlocks();
221     if (!llvm::hasSingleElement(blocks)) {
222       return;
223     }
224     SmallVector<LinalgOp, 2> linalgOps =
225         llvm::to_vector<2>(blocks.front().getOps<LinalgOp>());
226     Aliases aliases;
227     LinalgDependenceGraph dependenceGraph(aliases, linalgOps);
228     OpBuilder builder(funcOp.getContext());
229     Optional<TiledAndFusedLinalgOps> tileAndFuseOps = tileAndFuseLinalgOps(
230         builder, linalgOps, dependenceGraph,
231         LinalgTilingOptions().setTileSizes(tileSizes).setLoopType(
232             LinalgTilingLoopType::ParallelLoops));
233     if (!tileAndFuseOps)
234       return signalPassFailure();
235     for (auto op : linalgOps)
236       op.erase();
237   }
238 };
239 } // namespace
240 
241 namespace mlir {
242 namespace test {
registerTestLinalgFusionTransforms()243 void registerTestLinalgFusionTransforms() {
244   PassRegistration<TestLinalgFusionTransforms> testFusionTransformsPass(
245       "test-linalg-fusion-transform-patterns",
246       "Test Linalg fusion transformation patterns by applying them greedily.");
247 }
registerTestLinalgGreedyFusion()248 void registerTestLinalgGreedyFusion() {
249   PassRegistration<TestLinalgGreedyFusion> testFusionTransformsPass(
250       "test-linalg-greedy-fusion",
251       "Test Linalg fusion by applying a greedy test transformation.");
252 }
registerTestLinalgTileAndFuseSequencePass()253 void registerTestLinalgTileAndFuseSequencePass() {
254   PassRegistration<TestLinalgTileAndFuseSequencePass>
255       testTileAndFuseSequencePass(
256           "test-linalg-tile-and-fuse",
257           "Test Linalg tiling and fusion of a sequence of Linalg operations.");
258 }
259 
260 } // namespace test
261 } // namespace mlir
262