1 //===- TestLinalgTransforms.cpp - Test Linalg transformation patterns -----===//
2 //
3 // Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions.
4 // See https://llvm.org/LICENSE.txt for license information.
5 // SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception
6 //
7 //===----------------------------------------------------------------------===//
8 //
9 // This file implements logic for testing Linalg transformations.
10 //
11 //===----------------------------------------------------------------------===//
12
#include "mlir/Dialect/Affine/IR/AffineOps.h"
#include "mlir/Dialect/GPU/GPUDialect.h"
#include "mlir/Dialect/Linalg/IR/LinalgOps.h"
#include "mlir/Dialect/Linalg/Transforms/Transforms.h"
#include "mlir/Dialect/Linalg/Utils/Utils.h"
#include "mlir/Dialect/StandardOps/IR/Ops.h"
#include "mlir/Dialect/Vector/VectorOps.h"
#include "mlir/Pass/Pass.h"
#include "mlir/Transforms/GreedyPatternRewriteDriver.h"

#include "llvm/ADT/ScopeExit.h"
#include "llvm/ADT/SetVector.h"
24
25 using namespace mlir;
26 using namespace mlir::linalg;
27
namespace {
/// Test pass exercising Linalg transformation patterns. Each Option below
/// selects one pattern set; runOnFunction dispatches on whichever flag is set.
struct TestLinalgTransforms
    : public PassWrapper<TestLinalgTransforms, FunctionPass> {
  TestLinalgTransforms() = default;
  // NOTE(review): the copy constructor intentionally leaves the Option
  // members default-initialized — presumably because pass options are not
  // copyable; confirm against PassWrapper's cloning behavior.
  TestLinalgTransforms(const TestLinalgTransforms &pass) {}

  /// Declare every dialect the registered patterns may create operations for,
  /// so the context has them loaded before the pass runs.
  void getDependentDialects(DialectRegistry &registry) const override {
    // clang-format off
    registry.insert<AffineDialect,
                    scf::SCFDialect,
                    StandardOpsDialect,
                    vector::VectorDialect,
                    gpu::GPUDialect>();
    // clang-format on
  }

  void runOnFunction() override;

  // Command-line flags selecting which pattern set to apply (see
  // runOnFunction for the dispatch order).
  Option<bool> testPatterns{*this, "test-patterns",
                            llvm::cl::desc("Test a mixed set of patterns"),
                            llvm::cl::init(false)};
  Option<bool> testMatmulToVectorPatterns1dTiling{
      *this, "test-matmul-to-vector-patterns-tile-1d",
      llvm::cl::desc(
          "Test a fused pass that applies patterns from matmul to vectors via "
          "1-d tiling"),
      llvm::cl::init(false)};
  Option<bool> testMatmulToVectorPatterns2dTiling{
      *this, "test-matmul-to-vector-patterns-tile-2d",
      llvm::cl::desc(
          "Test a fused pass that applies patterns from matmul to vectors via "
          "2-d tiling"),
      llvm::cl::init(false)};
  Option<bool> testPromotionOptions{*this, "test-linalg-promotion-options",
                                    llvm::cl::desc("Test promotion options"),
                                    llvm::cl::init(false)};
  Option<bool> testTileAndDistributionOptions{
      *this, "test-tile-and-distribute-options",
      llvm::cl::desc("Test tile and distribute options"),
      llvm::cl::init(false)};
  Option<bool> testVectorTransferForwardingPatterns{
      *this, "test-vector-transfer-forwarding-patterns",
      llvm::cl::desc(
          "Test a fused pass that forwards linalg.copy to vector.transfer"),
      llvm::cl::init(false)};
  Option<bool> testGenericToVectorPattern{
      *this, "test-linalg-to-vector-patterns",
      llvm::cl::desc("Test a set of patterns that rewrite a linalg contraction "
                     "in vector.contract form"),
      llvm::cl::init(false)};
  Option<bool> testAffineMinSCFCanonicalizationPatterns{
      *this, "test-affine-min-scf-canonicalization-patterns",
      llvm::cl::desc("Test affine-min + scf canonicalization patterns."),
      llvm::cl::init(false)};
};
} // end anonymous namespace
84
applyPatterns(FuncOp funcOp)85 static void applyPatterns(FuncOp funcOp) {
86 MLIRContext *ctx = funcOp.getContext();
87 OwningRewritePatternList patterns;
88
89 //===--------------------------------------------------------------------===//
90 // Linalg tiling patterns.
91 //===--------------------------------------------------------------------===//
92 patterns.insert<LinalgTilingPattern<MatmulOp>>(
93 ctx, LinalgTilingOptions().setTileSizes({2000, 3000, 4000}),
94 LinalgMarker(Identifier::get("MEM", ctx), Identifier::get("L3", ctx)));
95 patterns.insert<LinalgTilingPattern<MatmulOp>>(
96 ctx, LinalgTilingOptions().setTileSizes({200, 300, 400}),
97 LinalgMarker(Identifier::get("L3", ctx), Identifier::get("L2", ctx)));
98 patterns.insert<LinalgTilingPattern<MatmulOp>>(
99 ctx, LinalgTilingOptions().setTileSizes({20, 30, 40}),
100 LinalgMarker(Identifier::get("L2", ctx), Identifier::get("L1", ctx)));
101 patterns.insert<LinalgTilingPattern<MatmulOp>>(
102 ctx, LinalgTilingOptions().setTileSizes({2, 3, 4}),
103 LinalgMarker(Identifier::get("L1", ctx), Identifier::get("REG", ctx)));
104
105 patterns.insert<LinalgTilingPattern<MatvecOp>>(
106 ctx,
107 LinalgTilingOptions().setTileSizes({5, 6}).setLoopType(
108 LinalgTilingLoopType::ParallelLoops),
109 LinalgMarker({}, Identifier::get("L1", ctx)));
110
111 patterns.insert<LinalgTilingPattern<DotOp>>(
112 ctx, LinalgTilingOptions().setTileSizes(8000),
113 LinalgMarker(ArrayRef<Identifier>{Identifier::get("MEM", ctx),
114 Identifier::get("L3", ctx),
115 Identifier::get("L2", ctx)},
116 Identifier::get("REG", ctx)));
117
118 //===--------------------------------------------------------------------===//
119 // Linalg tiling and permutation patterns.
120 //===--------------------------------------------------------------------===//
121 patterns.insert<LinalgTilingPattern<MatmulOp>>(
122 ctx,
123 LinalgTilingOptions()
124 .setTileSizes({2000, 3000, 4000})
125 .setInterchange({1, 2, 0}),
126 LinalgMarker(Identifier::get("__with_perm__", ctx),
127 Identifier::get("L2__with_perm__", ctx)));
128 patterns.insert<LinalgTilingPattern<MatmulOp>>(
129 ctx,
130 LinalgTilingOptions()
131 .setTileSizes({200, 300, 400})
132 .setInterchange({1, 0, 2}),
133 LinalgMarker(Identifier::get("L2__with_perm__", ctx),
134 Identifier::get("L1__with_perm__", ctx)));
135 patterns.insert<LinalgTilingPattern<MatmulOp>>(
136 ctx, LinalgTilingOptions().setTileSizes({20, 30, 40}),
137 LinalgMarker(Identifier::get("L1__with_perm__", ctx),
138 Identifier::get("REG__with_perm__", ctx)));
139
140 patterns.insert<LinalgTilingPattern<MatvecOp>>(
141 ctx, LinalgTilingOptions().setTileSizes({5, 6}).setInterchange({1, 0}),
142 LinalgMarker(Identifier::get("__with_perm__", ctx),
143 Identifier::get("L1__with_perm__", ctx)));
144
145 patterns.insert<LinalgTilingPattern<MatmulOp>>(
146 ctx,
147 LinalgTilingOptions()
148 .setTileSizes({16, 8, 4})
149 .setInterchange({1, 2, 0})
150 .setLoopType(LinalgTilingLoopType::ParallelLoops),
151 LinalgMarker(Identifier::get("par__with_perm__", ctx),
152 Identifier::get("after_par__with_perm__", ctx)));
153
154 //===--------------------------------------------------------------------===//
155 // Linalg to loops patterns.
156 //===--------------------------------------------------------------------===//
157 patterns.insert<LinalgLoweringPattern<DotOp>>(
158 ctx,
159 /*loweringType=*/LinalgLoweringType::Loops,
160 LinalgMarker(Identifier::get("REG", ctx)));
161
162 //===--------------------------------------------------------------------===//
163 // Linalg distribution patterns.
164 //===--------------------------------------------------------------------===//
165 LinalgLoopDistributionOptions distributionOptions;
166
167 //===--------------------------------------------------------------------===//
168 // Linalg to vector contraction patterns.
169 //===--------------------------------------------------------------------===//
170 patterns.insert<LinalgVectorizationPattern<MatmulOp>,
171 LinalgVectorizationPattern<FillOp>,
172 LinalgVectorizationPattern<CopyOp>,
173 LinalgVectorizationPattern<GenericOp>>(
174 ctx, LinalgMarker(Identifier::get("VECTORIZE", ctx)));
175
176 //===--------------------------------------------------------------------===//
177 // Linalg generic permutation patterns.
178 //===--------------------------------------------------------------------===//
179 patterns.insert<LinalgInterchangePattern<GenericOp>>(
180 ctx,
181 /*interchangeVector=*/ArrayRef<unsigned>{1, 2, 0},
182 LinalgMarker({}, Identifier::get("PERMUTED", ctx)));
183 patterns.insert<LinalgInterchangePattern<IndexedGenericOp>>(
184 ctx,
185 /*interchangeVector=*/ArrayRef<unsigned>{1, 2, 0},
186 LinalgMarker({}, Identifier::get("PERMUTED", ctx)));
187
188 //===--------------------------------------------------------------------===//
189 // Linalg subview operands promotion.
190 //===--------------------------------------------------------------------===//
191 patterns.insert<LinalgPromotionPattern<MatmulOp>>(
192 ctx, LinalgPromotionOptions().setUseFullTileBuffersByDefault(true),
193 LinalgMarker(Identifier::get("_promote_views_", ctx),
194 Identifier::get("_views_promoted_", ctx)));
195 patterns.insert<LinalgPromotionPattern<MatmulOp>>(
196 ctx,
197 LinalgPromotionOptions()
198 .setOperandsToPromote({0})
199 .setUseFullTileBuffersByDefault(true),
200 LinalgMarker(Identifier::get("_promote_first_view_", ctx),
201 Identifier::get("_first_view_promoted_", ctx)));
202 patterns.insert<LinalgPromotionPattern<FillOp>>(
203 ctx,
204 LinalgPromotionOptions()
205 .setOperandsToPromote({0})
206 .setUseFullTileBuffers({true})
207 .setAlignment(32),
208 LinalgMarker(Identifier::get("_promote_views_aligned_", ctx),
209 Identifier::get("_views_aligned_promoted_", ctx)));
210
211 applyPatternsAndFoldGreedily(funcOp, std::move(patterns));
212
213 // Drop the marker.
214 funcOp.walk([](LinalgOp op) {
215 op.removeAttr(LinalgTransforms::kLinalgTransformMarker);
216 });
217 }
218
/// Appends a three-stage matmul-to-vector pipeline to `patternsVector`:
///   1. tile MatmulOp by {8, 12, 16} with interchange, relabeling ops carrying
///      `startMarker` as "L1";
///   2. promote the tiled subviews to full-tile buffers, relabeling "L1" as
///      "VEC";
///   3. vectorize the promoted matmul, plus any fill/copy ops (which
///      promotion introduces), unconditionally in the same stage.
/// Each stage is one OwningRewritePatternList so the caller can apply them in
/// order via applyStagedPatterns.
static void fillL1TilingAndMatmulToVectorPatterns(
    FuncOp funcOp, StringRef startMarker,
    SmallVectorImpl<OwningRewritePatternList> &patternsVector) {
  MLIRContext *ctx = funcOp.getContext();
  patternsVector.emplace_back(std::make_unique<LinalgTilingPattern<MatmulOp>>(
      ctx,
      LinalgTilingOptions().setTileSizes({8, 12, 16}).setInterchange({1, 0, 2}),
      LinalgMarker(Identifier::get(startMarker, ctx),
                   Identifier::get("L1", ctx))));

  patternsVector.emplace_back(
      std::make_unique<LinalgPromotionPattern<MatmulOp>>(
          ctx, LinalgPromotionOptions().setUseFullTileBuffersByDefault(true),
          LinalgMarker(Identifier::get("L1", ctx),
                       Identifier::get("VEC", ctx))));

  patternsVector.emplace_back(
      std::make_unique<LinalgVectorizationPattern<MatmulOp>>(
          ctx, LinalgMarker(Identifier::get("VEC", ctx))));
  // Fill/copy vectorization has no marker: it applies to any such op left in
  // the function at this stage.
  patternsVector.back()
      .insert<LinalgVectorizationPattern<FillOp>,
              LinalgVectorizationPattern<CopyOp>>(ctx);
}
242
243 //===----------------------------------------------------------------------===//
244 // Test promotion callbacks
245 //===----------------------------------------------------------------------===//
246
247 // Allocation call back
allocCallBackFn(OpBuilder & b,SubViewOp subView,ArrayRef<Value> boundingSubViewSize,OperationFolder * folder)248 static Optional<Value> allocCallBackFn(OpBuilder &b, SubViewOp subView,
249 ArrayRef<Value> boundingSubViewSize,
250 OperationFolder *folder) {
251 SmallVector<int64_t, 4> shape(boundingSubViewSize.size(), -1);
252 return b
253 .create<AllocOp>(subView.getLoc(),
254 MemRefType::get(shape,
255 subView.getType().getElementType(),
256 /*affineMapComposition =*/{}, 3),
257 boundingSubViewSize)
258 .getResult();
259 }
260
// Deallocation callback: emits a DeallocOp releasing the promoted buffer.
static LogicalResult deallocCallBackFn(OpBuilder &b, Value buffer) {
  b.create<DeallocOp>(buffer.getLoc(), buffer);
  return success();
}
266
267 // Copy in call back
copyCallBackFn(OpBuilder & b,Value src,Value dst,bool isOutput)268 static LogicalResult copyCallBackFn(OpBuilder &b, Value src, Value dst,
269 bool isOutput) {
270 auto floatType = src.getType().cast<MemRefType>().getElementType();
271 if (!floatType.isa<FloatType>())
272 return failure();
273 if (!isOutput)
274 b.create<FillOp>(
275 src.getLoc(), dst,
276 b.create<ConstantOp>(src.getLoc(), FloatAttr::get(floatType, 42.0)));
277 b.create<CopyOp>(src.getLoc(), src, dst);
278 return success();
279 }
280
fillPromotionCallBackPatterns(MLIRContext * ctx,OwningRewritePatternList & patterns)281 static void fillPromotionCallBackPatterns(MLIRContext *ctx,
282 OwningRewritePatternList &patterns) {
283 patterns.insert<LinalgTilingPattern<MatmulOp>>(
284 ctx, LinalgTilingOptions().setTileSizes({16, 16, 16}),
285 LinalgMarker(Identifier::get("START", ctx),
286 Identifier::get("PROMOTE", ctx)));
287 patterns.insert<LinalgPromotionPattern<MatmulOp>>(
288 ctx,
289 LinalgPromotionOptions()
290 .setOperandsToPromote({0, 2})
291 .setUseFullTileBuffers({false, false})
292 .setAllocationDeallocationFns(allocCallBackFn, deallocCallBackFn)
293 .setCopyInOutFns(
294 [](OpBuilder &b, Value src, Value dst) -> LogicalResult {
295 copyCallBackFn(b, src, dst, false);
296 return success();
297 },
298 [](OpBuilder &b, Value src, Value dst) -> LogicalResult {
299 copyCallBackFn(b, src, dst, true);
300 return success();
301 }),
302 LinalgMarker(Identifier::get("PROMOTE", ctx)));
303 }
304
305 template <typename IdOp, typename NProcsOp>
306 static SmallVector<ProcInfo, 2>
getGpuProcIds(OpBuilder & b,Location loc,ArrayRef<Range> parallelLoopRanges)307 getGpuProcIds(OpBuilder &b, Location loc, ArrayRef<Range> parallelLoopRanges) {
308 Type indexType = b.getIndexType();
309 SmallVector<ProcInfo, 2> procInfo(2);
310 procInfo[0] = {b.create<IdOp>(loc, indexType, b.getStringAttr("y")),
311 b.create<NProcsOp>(loc, indexType, b.getStringAttr("y"))};
312 procInfo[1] = {b.create<IdOp>(loc, indexType, b.getStringAttr("x")),
313 b.create<NProcsOp>(loc, indexType, b.getStringAttr("x"))};
314 return procInfo;
315 }
316
fillTileAndDistributePatterns(MLIRContext * context,OwningRewritePatternList & patterns)317 static void fillTileAndDistributePatterns(MLIRContext *context,
318 OwningRewritePatternList &patterns) {
319 {
320 LinalgLoopDistributionOptions cyclicNprocsEqNiters;
321 cyclicNprocsEqNiters.distributionMethod.resize(
322 2, DistributionMethod::CyclicNumProcsEqNumIters);
323 cyclicNprocsEqNiters.procInfo =
324 getGpuProcIds<gpu::BlockIdOp, gpu::GridDimOp>;
325 patterns.insert<LinalgTilingPattern<MatmulOp>>(
326 context,
327 LinalgTilingOptions()
328 .setTileSizes({8, 8, 4})
329 .setLoopType(LinalgTilingLoopType::ParallelLoops)
330 .setDistributionOptions(cyclicNprocsEqNiters),
331 LinalgMarker(Identifier::get("distribute1", context),
332 Identifier::get("after_distribute1", context)));
333 }
334
335 {
336 LinalgLoopDistributionOptions cyclicNprocsGeNiters;
337 cyclicNprocsGeNiters.distributionMethod.resize(
338 2, DistributionMethod::CyclicNumProcsGeNumIters);
339 cyclicNprocsGeNiters.procInfo =
340 getGpuProcIds<gpu::BlockIdOp, gpu::GridDimOp>;
341 patterns.insert<LinalgTilingPattern<MatmulOp>>(
342 context,
343 LinalgTilingOptions()
344 .setTileSizes({8, 8, 4})
345 .setLoopType(LinalgTilingLoopType::ParallelLoops)
346 .setDistributionOptions(cyclicNprocsGeNiters),
347 LinalgMarker(Identifier::get("distribute2", context),
348 Identifier::get("after_distribute2", context)));
349 }
350
351 {
352 LinalgLoopDistributionOptions cyclicNprocsDefault;
353 cyclicNprocsDefault.distributionMethod.resize(2,
354 DistributionMethod::Cyclic);
355 cyclicNprocsDefault.procInfo =
356 getGpuProcIds<gpu::BlockIdOp, gpu::GridDimOp>;
357 patterns.insert<LinalgTilingPattern<MatmulOp>>(
358 context,
359 LinalgTilingOptions()
360 .setTileSizes({8, 8, 4})
361 .setLoopType(LinalgTilingLoopType::ParallelLoops)
362 .setDistributionOptions(cyclicNprocsDefault),
363 LinalgMarker(Identifier::get("distribute3", context),
364 Identifier::get("after_distribute3", context)));
365 }
366
367 {
368 LinalgLoopDistributionOptions cyclicNprocsMixed1;
369 cyclicNprocsMixed1.distributionMethod = {
370 DistributionMethod::CyclicNumProcsEqNumIters,
371 DistributionMethod::CyclicNumProcsGeNumIters};
372 cyclicNprocsMixed1.procInfo = getGpuProcIds<gpu::BlockIdOp, gpu::GridDimOp>;
373 patterns.insert<LinalgTilingPattern<MatmulOp>>(
374 context,
375 LinalgTilingOptions()
376 .setTileSizes({8, 8, 4})
377 .setLoopType(LinalgTilingLoopType::ParallelLoops)
378 .setDistributionOptions(cyclicNprocsMixed1),
379 LinalgMarker(Identifier::get("distribute4", context),
380 Identifier::get("after_distribute4", context)));
381 }
382
383 {
384 LinalgLoopDistributionOptions cyclicNprocsMixed2;
385 cyclicNprocsMixed2.distributionMethod = {
386 DistributionMethod::CyclicNumProcsGeNumIters,
387 DistributionMethod::Cyclic};
388 cyclicNprocsMixed2.procInfo = getGpuProcIds<gpu::BlockIdOp, gpu::GridDimOp>;
389 patterns.insert<LinalgTilingPattern<MatmulOp>>(
390 context,
391 LinalgTilingOptions()
392 .setTileSizes({8, 8, 4})
393 .setLoopType(LinalgTilingLoopType::ParallelLoops)
394 .setDistributionOptions(cyclicNprocsMixed2),
395 LinalgMarker(Identifier::get("distribute5", context),
396 Identifier::get("after_distribute5", context)));
397 }
398
399 {
400 LinalgLoopDistributionOptions cyclicNprocsMixed3;
401 cyclicNprocsMixed3.distributionMethod = {
402 DistributionMethod::Cyclic,
403 DistributionMethod::CyclicNumProcsEqNumIters};
404 cyclicNprocsMixed3.procInfo = getGpuProcIds<gpu::BlockIdOp, gpu::GridDimOp>;
405
406 patterns.insert<LinalgTilingPattern<MatmulOp>>(
407 context,
408 LinalgTilingOptions()
409 .setTileSizes({8, 8, 4})
410 .setLoopType(LinalgTilingLoopType::ParallelLoops)
411 .setDistributionOptions(cyclicNprocsMixed3),
412 LinalgMarker(Identifier::get("distribute6", context),
413 Identifier::get("after_distribute6", context)));
414 }
415
416 {
417 LinalgLoopDistributionOptions cyclicNprocsEqNiters;
418 cyclicNprocsEqNiters.distributionMethod.resize(2,
419 DistributionMethod::Cyclic);
420 cyclicNprocsEqNiters.procInfo =
421 getGpuProcIds<gpu::BlockIdOp, gpu::GridDimOp>;
422 patterns.insert<LinalgTilingPattern<MatmulOp>>(
423 context,
424 LinalgTilingOptions()
425 .setTileSizes({8, 8, 4})
426 .setLoopType(LinalgTilingLoopType::Loops)
427 .setDistributionOptions(cyclicNprocsEqNiters),
428 LinalgMarker(Identifier::get("tensors_distribute1", context),
429 Identifier::get("tensors_after_distribute1", context)));
430 }
431 }
432
/// Drives the matmul-to-vector staged pipeline on `funcOp`. With 1-d tiling
/// the L1 tiling/promotion/vectorization stages start directly from "START";
/// with 2-d tiling an extra {768, 264, 768} tiling stage relabels "START" as
/// "L2" first, and the L1 stages pick up from "L2". Stage-1 pattern lists are
/// applied in order, with the Linalg tiling canonicalizations run as stage 2
/// between them (see applyStagedPatterns).
static void
applyMatmulToVectorPatterns(FuncOp funcOp,
                            bool testMatmulToVectorPatterns1dTiling,
                            bool testMatmulToVectorPatterns2dTiling) {
  MLIRContext *ctx = funcOp.getContext();
  SmallVector<OwningRewritePatternList, 4> stage1Patterns;
  if (testMatmulToVectorPatterns1dTiling) {
    fillL1TilingAndMatmulToVectorPatterns(funcOp, Identifier::get("START", ctx),
                                          stage1Patterns);
  } else if (testMatmulToVectorPatterns2dTiling) {
    stage1Patterns.emplace_back(std::make_unique<LinalgTilingPattern<MatmulOp>>(
        ctx,
        LinalgTilingOptions()
            .setTileSizes({768, 264, 768})
            .setInterchange({1, 2, 0}),
        LinalgMarker(Identifier::get("START", ctx),
                     Identifier::get("L2", ctx))));
    fillL1TilingAndMatmulToVectorPatterns(funcOp, Identifier::get("L2", ctx),
                                          stage1Patterns);
  }
  // Freeze each stage-1 list before handing them to the staged driver.
  SmallVector<FrozenRewritePatternList, 4> frozenStage1Patterns;
  llvm::move(stage1Patterns, std::back_inserter(frozenStage1Patterns));
  FrozenRewritePatternList stage2Patterns =
      getLinalgTilingCanonicalizationPatterns(ctx);
  applyStagedPatterns(funcOp, frozenStage1Patterns, std::move(stage2Patterns));
}
459
applyVectorTransferForwardingPatterns(FuncOp funcOp)460 static void applyVectorTransferForwardingPatterns(FuncOp funcOp) {
461 OwningRewritePatternList forwardPattern;
462 forwardPattern.insert<LinalgCopyVTRForwardingPattern>(funcOp.getContext());
463 forwardPattern.insert<LinalgCopyVTWForwardingPattern>(funcOp.getContext());
464 applyPatternsAndFoldGreedily(funcOp, std::move(forwardPattern));
465 }
466
applyLinalgToVectorPatterns(FuncOp funcOp)467 static void applyLinalgToVectorPatterns(FuncOp funcOp) {
468 OwningRewritePatternList patterns;
469 patterns.insert<
470 LinalgVectorizationPattern<BatchMatmulOp>,
471 LinalgVectorizationPattern<MatmulOp>,
472 LinalgVectorizationPattern<MatvecOp>,
473 LinalgVectorizationPattern<VecmatOp>, LinalgVectorizationPattern<DotOp>,
474 LinalgVectorizationPattern<FillOp>, LinalgVectorizationPattern<CopyOp>,
475 LinalgVectorizationPattern<GenericOp>>(funcOp.getContext());
476 applyPatternsAndFoldGreedily(funcOp, std::move(patterns));
477 }
478
applyAffineMinSCFCanonicalizationPatterns(FuncOp funcOp)479 static void applyAffineMinSCFCanonicalizationPatterns(FuncOp funcOp) {
480 OwningRewritePatternList foldPattern;
481 foldPattern.insert<AffineMinSCFCanonicalizationPattern>(funcOp.getContext());
482 FrozenRewritePatternList frozenPatterns(std::move(foldPattern));
483
484 // Explicitly walk and apply the pattern locally to avoid more general folding
485 // on the rest of the IR.
486 funcOp.walk([&frozenPatterns](AffineMinOp minOp) {
487 applyOpPatternsAndFold(minOp, frozenPatterns);
488 });
489 }
490 /// Apply transformations specified as patterns.
runOnFunction()491 void TestLinalgTransforms::runOnFunction() {
492 auto lambda = [&](void *) {
493 getFunction().walk([](LinalgOp op) {
494 op.removeAttr(LinalgTransforms::kLinalgTransformMarker);
495 });
496 };
497 std::unique_ptr<void, decltype(lambda)> cleanupGuard{(void *)1, lambda};
498
499 if (testPromotionOptions) {
500 OwningRewritePatternList patterns;
501 fillPromotionCallBackPatterns(&getContext(), patterns);
502 applyPatternsAndFoldGreedily(getFunction(), std::move(patterns));
503 return;
504 }
505 if (testTileAndDistributionOptions) {
506 OwningRewritePatternList patterns;
507 fillTileAndDistributePatterns(&getContext(), patterns);
508 applyPatternsAndFoldGreedily(getFunction(), std::move(patterns));
509 return;
510 }
511 if (testPatterns)
512 return applyPatterns(getFunction());
513 if (testMatmulToVectorPatterns1dTiling || testMatmulToVectorPatterns2dTiling)
514 return applyMatmulToVectorPatterns(getFunction(),
515 testMatmulToVectorPatterns1dTiling,
516 testMatmulToVectorPatterns2dTiling);
517 if (testVectorTransferForwardingPatterns)
518 return applyVectorTransferForwardingPatterns(getFunction());
519 if (testGenericToVectorPattern)
520 return applyLinalgToVectorPatterns(getFunction());
521 if (testAffineMinSCFCanonicalizationPatterns)
522 return applyAffineMinSCFCanonicalizationPatterns(getFunction());
523 }
524
525 namespace mlir {
526 namespace test {
registerTestLinalgTransforms()527 void registerTestLinalgTransforms() {
528 PassRegistration<TestLinalgTransforms> testTransformPatternsPass(
529 "test-linalg-transform-patterns",
530 "Test Linalg transformation patterns by applying them greedily.");
531 }
532 } // namespace test
533 } // namespace mlir
534