1 //===- TestLinalgTransforms.cpp - Test Linalg transformation patterns -----===//
2 //
3 // Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions.
4 // See https://llvm.org/LICENSE.txt for license information.
5 // SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception
6 //
7 //===----------------------------------------------------------------------===//
8 //
9 // This file implements logic for testing Linalg transformations.
10 //
11 //===----------------------------------------------------------------------===//
12 
13 #include "mlir/Dialect/Affine/IR/AffineOps.h"
14 #include "mlir/Dialect/GPU/GPUDialect.h"
15 #include "mlir/Dialect/Linalg/IR/LinalgOps.h"
16 #include "mlir/Dialect/Linalg/Transforms/Transforms.h"
17 #include "mlir/Dialect/Linalg/Utils/Utils.h"
18 #include "mlir/Dialect/StandardOps/IR/Ops.h"
19 #include "mlir/Dialect/Vector/VectorOps.h"
20 #include "mlir/Pass/Pass.h"
21 #include "mlir/Transforms/GreedyPatternRewriteDriver.h"
22 
23 #include "llvm/ADT/SetVector.h"
24 
25 using namespace mlir;
26 using namespace mlir::linalg;
27 
28 namespace {
29 struct TestLinalgTransforms
30     : public PassWrapper<TestLinalgTransforms, FunctionPass> {
31   TestLinalgTransforms() = default;
TestLinalgTransforms__anona231df6c0111::TestLinalgTransforms32   TestLinalgTransforms(const TestLinalgTransforms &pass) {}
33 
getDependentDialects__anona231df6c0111::TestLinalgTransforms34   void getDependentDialects(DialectRegistry &registry) const override {
35     // clang-format off
36     registry.insert<AffineDialect,
37                     scf::SCFDialect,
38                     StandardOpsDialect,
39                     vector::VectorDialect,
40                     gpu::GPUDialect>();
41     // clang-format on
42   }
43 
44   void runOnFunction() override;
45 
46   Option<bool> testPatterns{*this, "test-patterns",
47                             llvm::cl::desc("Test a mixed set of patterns"),
48                             llvm::cl::init(false)};
49   Option<bool> testMatmulToVectorPatterns1dTiling{
50       *this, "test-matmul-to-vector-patterns-tile-1d",
51       llvm::cl::desc(
52           "Test a fused pass that applies patterns from matmul to vectors via "
53           "1-d tiling"),
54       llvm::cl::init(false)};
55   Option<bool> testMatmulToVectorPatterns2dTiling{
56       *this, "test-matmul-to-vector-patterns-tile-2d",
57       llvm::cl::desc(
58           "Test a fused pass that applies patterns from matmul to vectors via "
59           "2-d tiling"),
60       llvm::cl::init(false)};
61   Option<bool> testPromotionOptions{*this, "test-linalg-promotion-options",
62                                     llvm::cl::desc("Test promotion options"),
63                                     llvm::cl::init(false)};
64   Option<bool> testTileAndDistributionOptions{
65       *this, "test-tile-and-distribute-options",
66       llvm::cl::desc("Test tile and distribute options"),
67       llvm::cl::init(false)};
68   Option<bool> testVectorTransferForwardingPatterns{
69       *this, "test-vector-transfer-forwarding-patterns",
70       llvm::cl::desc(
71           "Test a fused pass that forwards linalg.copy to vector.transfer"),
72       llvm::cl::init(false)};
73   Option<bool> testGenericToVectorPattern{
74       *this, "test-linalg-to-vector-patterns",
75       llvm::cl::desc("Test a set of patterns that rewrite a linalg contraction "
76                      "in vector.contract form"),
77       llvm::cl::init(false)};
78   Option<bool> testAffineMinSCFCanonicalizationPatterns{
79       *this, "test-affine-min-scf-canonicalization-patterns",
80       llvm::cl::desc("Test affine-min + scf canonicalization patterns."),
81       llvm::cl::init(false)};
82 };
83 } // end anonymous namespace
84 
applyPatterns(FuncOp funcOp)85 static void applyPatterns(FuncOp funcOp) {
86   MLIRContext *ctx = funcOp.getContext();
87   OwningRewritePatternList patterns;
88 
89   //===--------------------------------------------------------------------===//
90   // Linalg tiling patterns.
91   //===--------------------------------------------------------------------===//
92   patterns.insert<LinalgTilingPattern<MatmulOp>>(
93       ctx, LinalgTilingOptions().setTileSizes({2000, 3000, 4000}),
94       LinalgMarker(Identifier::get("MEM", ctx), Identifier::get("L3", ctx)));
95   patterns.insert<LinalgTilingPattern<MatmulOp>>(
96       ctx, LinalgTilingOptions().setTileSizes({200, 300, 400}),
97       LinalgMarker(Identifier::get("L3", ctx), Identifier::get("L2", ctx)));
98   patterns.insert<LinalgTilingPattern<MatmulOp>>(
99       ctx, LinalgTilingOptions().setTileSizes({20, 30, 40}),
100       LinalgMarker(Identifier::get("L2", ctx), Identifier::get("L1", ctx)));
101   patterns.insert<LinalgTilingPattern<MatmulOp>>(
102       ctx, LinalgTilingOptions().setTileSizes({2, 3, 4}),
103       LinalgMarker(Identifier::get("L1", ctx), Identifier::get("REG", ctx)));
104 
105   patterns.insert<LinalgTilingPattern<MatvecOp>>(
106       ctx,
107       LinalgTilingOptions().setTileSizes({5, 6}).setLoopType(
108           LinalgTilingLoopType::ParallelLoops),
109       LinalgMarker({}, Identifier::get("L1", ctx)));
110 
111   patterns.insert<LinalgTilingPattern<DotOp>>(
112       ctx, LinalgTilingOptions().setTileSizes(8000),
113       LinalgMarker(ArrayRef<Identifier>{Identifier::get("MEM", ctx),
114                                         Identifier::get("L3", ctx),
115                                         Identifier::get("L2", ctx)},
116                    Identifier::get("REG", ctx)));
117 
118   //===--------------------------------------------------------------------===//
119   // Linalg tiling and permutation patterns.
120   //===--------------------------------------------------------------------===//
121   patterns.insert<LinalgTilingPattern<MatmulOp>>(
122       ctx,
123       LinalgTilingOptions()
124           .setTileSizes({2000, 3000, 4000})
125           .setInterchange({1, 2, 0}),
126       LinalgMarker(Identifier::get("__with_perm__", ctx),
127                    Identifier::get("L2__with_perm__", ctx)));
128   patterns.insert<LinalgTilingPattern<MatmulOp>>(
129       ctx,
130       LinalgTilingOptions()
131           .setTileSizes({200, 300, 400})
132           .setInterchange({1, 0, 2}),
133       LinalgMarker(Identifier::get("L2__with_perm__", ctx),
134                    Identifier::get("L1__with_perm__", ctx)));
135   patterns.insert<LinalgTilingPattern<MatmulOp>>(
136       ctx, LinalgTilingOptions().setTileSizes({20, 30, 40}),
137       LinalgMarker(Identifier::get("L1__with_perm__", ctx),
138                    Identifier::get("REG__with_perm__", ctx)));
139 
140   patterns.insert<LinalgTilingPattern<MatvecOp>>(
141       ctx, LinalgTilingOptions().setTileSizes({5, 6}).setInterchange({1, 0}),
142       LinalgMarker(Identifier::get("__with_perm__", ctx),
143                    Identifier::get("L1__with_perm__", ctx)));
144 
145   patterns.insert<LinalgTilingPattern<MatmulOp>>(
146       ctx,
147       LinalgTilingOptions()
148           .setTileSizes({16, 8, 4})
149           .setInterchange({1, 2, 0})
150           .setLoopType(LinalgTilingLoopType::ParallelLoops),
151       LinalgMarker(Identifier::get("par__with_perm__", ctx),
152                    Identifier::get("after_par__with_perm__", ctx)));
153 
154   //===--------------------------------------------------------------------===//
155   // Linalg to loops patterns.
156   //===--------------------------------------------------------------------===//
157   patterns.insert<LinalgLoweringPattern<DotOp>>(
158       ctx,
159       /*loweringType=*/LinalgLoweringType::Loops,
160       LinalgMarker(Identifier::get("REG", ctx)));
161 
162   //===--------------------------------------------------------------------===//
163   // Linalg distribution patterns.
164   //===--------------------------------------------------------------------===//
165   LinalgLoopDistributionOptions distributionOptions;
166 
167   //===--------------------------------------------------------------------===//
168   // Linalg to vector contraction patterns.
169   //===--------------------------------------------------------------------===//
170   patterns.insert<LinalgVectorizationPattern<MatmulOp>,
171                   LinalgVectorizationPattern<FillOp>,
172                   LinalgVectorizationPattern<CopyOp>,
173                   LinalgVectorizationPattern<GenericOp>>(
174       ctx, LinalgMarker(Identifier::get("VECTORIZE", ctx)));
175 
176   //===--------------------------------------------------------------------===//
177   // Linalg generic permutation patterns.
178   //===--------------------------------------------------------------------===//
179   patterns.insert<LinalgInterchangePattern<GenericOp>>(
180       ctx,
181       /*interchangeVector=*/ArrayRef<unsigned>{1, 2, 0},
182       LinalgMarker({}, Identifier::get("PERMUTED", ctx)));
183   patterns.insert<LinalgInterchangePattern<IndexedGenericOp>>(
184       ctx,
185       /*interchangeVector=*/ArrayRef<unsigned>{1, 2, 0},
186       LinalgMarker({}, Identifier::get("PERMUTED", ctx)));
187 
188   //===--------------------------------------------------------------------===//
189   // Linalg subview operands promotion.
190   //===--------------------------------------------------------------------===//
191   patterns.insert<LinalgPromotionPattern<MatmulOp>>(
192       ctx, LinalgPromotionOptions().setUseFullTileBuffersByDefault(true),
193       LinalgMarker(Identifier::get("_promote_views_", ctx),
194                    Identifier::get("_views_promoted_", ctx)));
195   patterns.insert<LinalgPromotionPattern<MatmulOp>>(
196       ctx,
197       LinalgPromotionOptions()
198           .setOperandsToPromote({0})
199           .setUseFullTileBuffersByDefault(true),
200       LinalgMarker(Identifier::get("_promote_first_view_", ctx),
201                    Identifier::get("_first_view_promoted_", ctx)));
202   patterns.insert<LinalgPromotionPattern<FillOp>>(
203       ctx,
204       LinalgPromotionOptions()
205           .setOperandsToPromote({0})
206           .setUseFullTileBuffers({true})
207           .setAlignment(32),
208       LinalgMarker(Identifier::get("_promote_views_aligned_", ctx),
209                    Identifier::get("_views_aligned_promoted_", ctx)));
210 
211   applyPatternsAndFoldGreedily(funcOp, std::move(patterns));
212 
213   // Drop the marker.
214   funcOp.walk([](LinalgOp op) {
215     op.removeAttr(LinalgTransforms::kLinalgTransformMarker);
216   });
217 }
218 
fillL1TilingAndMatmulToVectorPatterns(FuncOp funcOp,StringRef startMarker,SmallVectorImpl<OwningRewritePatternList> & patternsVector)219 static void fillL1TilingAndMatmulToVectorPatterns(
220     FuncOp funcOp, StringRef startMarker,
221     SmallVectorImpl<OwningRewritePatternList> &patternsVector) {
222   MLIRContext *ctx = funcOp.getContext();
223   patternsVector.emplace_back(std::make_unique<LinalgTilingPattern<MatmulOp>>(
224       ctx,
225       LinalgTilingOptions().setTileSizes({8, 12, 16}).setInterchange({1, 0, 2}),
226       LinalgMarker(Identifier::get(startMarker, ctx),
227                    Identifier::get("L1", ctx))));
228 
229   patternsVector.emplace_back(
230       std::make_unique<LinalgPromotionPattern<MatmulOp>>(
231           ctx, LinalgPromotionOptions().setUseFullTileBuffersByDefault(true),
232           LinalgMarker(Identifier::get("L1", ctx),
233                        Identifier::get("VEC", ctx))));
234 
235   patternsVector.emplace_back(
236       std::make_unique<LinalgVectorizationPattern<MatmulOp>>(
237           ctx, LinalgMarker(Identifier::get("VEC", ctx))));
238   patternsVector.back()
239       .insert<LinalgVectorizationPattern<FillOp>,
240               LinalgVectorizationPattern<CopyOp>>(ctx);
241 }
242 
243 //===----------------------------------------------------------------------===//
244 // Test promotion callbacks
245 //===----------------------------------------------------------------------===//
246 
247 // Allocation call back
allocCallBackFn(OpBuilder & b,SubViewOp subView,ArrayRef<Value> boundingSubViewSize,OperationFolder * folder)248 static Optional<Value> allocCallBackFn(OpBuilder &b, SubViewOp subView,
249                                        ArrayRef<Value> boundingSubViewSize,
250                                        OperationFolder *folder) {
251   SmallVector<int64_t, 4> shape(boundingSubViewSize.size(), -1);
252   return b
253       .create<AllocOp>(subView.getLoc(),
254                        MemRefType::get(shape,
255                                        subView.getType().getElementType(),
256                                        /*affineMapComposition =*/{}, 3),
257                        boundingSubViewSize)
258       .getResult();
259 }
260 
261 // Deallocation callback
deallocCallBackFn(OpBuilder & b,Value buffer)262 static LogicalResult deallocCallBackFn(OpBuilder &b, Value buffer) {
263   b.create<DeallocOp>(buffer.getLoc(), buffer);
264   return success();
265 }
266 
267 // Copy in call back
copyCallBackFn(OpBuilder & b,Value src,Value dst,bool isOutput)268 static LogicalResult copyCallBackFn(OpBuilder &b, Value src, Value dst,
269                                     bool isOutput) {
270   auto floatType = src.getType().cast<MemRefType>().getElementType();
271   if (!floatType.isa<FloatType>())
272     return failure();
273   if (!isOutput)
274     b.create<FillOp>(
275         src.getLoc(), dst,
276         b.create<ConstantOp>(src.getLoc(), FloatAttr::get(floatType, 42.0)));
277   b.create<CopyOp>(src.getLoc(), src, dst);
278   return success();
279 }
280 
fillPromotionCallBackPatterns(MLIRContext * ctx,OwningRewritePatternList & patterns)281 static void fillPromotionCallBackPatterns(MLIRContext *ctx,
282                                           OwningRewritePatternList &patterns) {
283   patterns.insert<LinalgTilingPattern<MatmulOp>>(
284       ctx, LinalgTilingOptions().setTileSizes({16, 16, 16}),
285       LinalgMarker(Identifier::get("START", ctx),
286                    Identifier::get("PROMOTE", ctx)));
287   patterns.insert<LinalgPromotionPattern<MatmulOp>>(
288       ctx,
289       LinalgPromotionOptions()
290           .setOperandsToPromote({0, 2})
291           .setUseFullTileBuffers({false, false})
292           .setAllocationDeallocationFns(allocCallBackFn, deallocCallBackFn)
293           .setCopyInOutFns(
294               [](OpBuilder &b, Value src, Value dst) -> LogicalResult {
295                 copyCallBackFn(b, src, dst, false);
296                 return success();
297               },
298               [](OpBuilder &b, Value src, Value dst) -> LogicalResult {
299                 copyCallBackFn(b, src, dst, true);
300                 return success();
301               }),
302       LinalgMarker(Identifier::get("PROMOTE", ctx)));
303 }
304 
305 template <typename IdOp, typename NProcsOp>
306 static SmallVector<ProcInfo, 2>
getGpuProcIds(OpBuilder & b,Location loc,ArrayRef<Range> parallelLoopRanges)307 getGpuProcIds(OpBuilder &b, Location loc, ArrayRef<Range> parallelLoopRanges) {
308   Type indexType = b.getIndexType();
309   SmallVector<ProcInfo, 2> procInfo(2);
310   procInfo[0] = {b.create<IdOp>(loc, indexType, b.getStringAttr("y")),
311                  b.create<NProcsOp>(loc, indexType, b.getStringAttr("y"))};
312   procInfo[1] = {b.create<IdOp>(loc, indexType, b.getStringAttr("x")),
313                  b.create<NProcsOp>(loc, indexType, b.getStringAttr("x"))};
314   return procInfo;
315 }
316 
fillTileAndDistributePatterns(MLIRContext * context,OwningRewritePatternList & patterns)317 static void fillTileAndDistributePatterns(MLIRContext *context,
318                                           OwningRewritePatternList &patterns) {
319   {
320     LinalgLoopDistributionOptions cyclicNprocsEqNiters;
321     cyclicNprocsEqNiters.distributionMethod.resize(
322         2, DistributionMethod::CyclicNumProcsEqNumIters);
323     cyclicNprocsEqNiters.procInfo =
324         getGpuProcIds<gpu::BlockIdOp, gpu::GridDimOp>;
325     patterns.insert<LinalgTilingPattern<MatmulOp>>(
326         context,
327         LinalgTilingOptions()
328             .setTileSizes({8, 8, 4})
329             .setLoopType(LinalgTilingLoopType::ParallelLoops)
330             .setDistributionOptions(cyclicNprocsEqNiters),
331         LinalgMarker(Identifier::get("distribute1", context),
332                      Identifier::get("after_distribute1", context)));
333   }
334 
335   {
336     LinalgLoopDistributionOptions cyclicNprocsGeNiters;
337     cyclicNprocsGeNiters.distributionMethod.resize(
338         2, DistributionMethod::CyclicNumProcsGeNumIters);
339     cyclicNprocsGeNiters.procInfo =
340         getGpuProcIds<gpu::BlockIdOp, gpu::GridDimOp>;
341     patterns.insert<LinalgTilingPattern<MatmulOp>>(
342         context,
343         LinalgTilingOptions()
344             .setTileSizes({8, 8, 4})
345             .setLoopType(LinalgTilingLoopType::ParallelLoops)
346             .setDistributionOptions(cyclicNprocsGeNiters),
347         LinalgMarker(Identifier::get("distribute2", context),
348                      Identifier::get("after_distribute2", context)));
349   }
350 
351   {
352     LinalgLoopDistributionOptions cyclicNprocsDefault;
353     cyclicNprocsDefault.distributionMethod.resize(2,
354                                                   DistributionMethod::Cyclic);
355     cyclicNprocsDefault.procInfo =
356         getGpuProcIds<gpu::BlockIdOp, gpu::GridDimOp>;
357     patterns.insert<LinalgTilingPattern<MatmulOp>>(
358         context,
359         LinalgTilingOptions()
360             .setTileSizes({8, 8, 4})
361             .setLoopType(LinalgTilingLoopType::ParallelLoops)
362             .setDistributionOptions(cyclicNprocsDefault),
363         LinalgMarker(Identifier::get("distribute3", context),
364                      Identifier::get("after_distribute3", context)));
365   }
366 
367   {
368     LinalgLoopDistributionOptions cyclicNprocsMixed1;
369     cyclicNprocsMixed1.distributionMethod = {
370         DistributionMethod::CyclicNumProcsEqNumIters,
371         DistributionMethod::CyclicNumProcsGeNumIters};
372     cyclicNprocsMixed1.procInfo = getGpuProcIds<gpu::BlockIdOp, gpu::GridDimOp>;
373     patterns.insert<LinalgTilingPattern<MatmulOp>>(
374         context,
375         LinalgTilingOptions()
376             .setTileSizes({8, 8, 4})
377             .setLoopType(LinalgTilingLoopType::ParallelLoops)
378             .setDistributionOptions(cyclicNprocsMixed1),
379         LinalgMarker(Identifier::get("distribute4", context),
380                      Identifier::get("after_distribute4", context)));
381   }
382 
383   {
384     LinalgLoopDistributionOptions cyclicNprocsMixed2;
385     cyclicNprocsMixed2.distributionMethod = {
386         DistributionMethod::CyclicNumProcsGeNumIters,
387         DistributionMethod::Cyclic};
388     cyclicNprocsMixed2.procInfo = getGpuProcIds<gpu::BlockIdOp, gpu::GridDimOp>;
389     patterns.insert<LinalgTilingPattern<MatmulOp>>(
390         context,
391         LinalgTilingOptions()
392             .setTileSizes({8, 8, 4})
393             .setLoopType(LinalgTilingLoopType::ParallelLoops)
394             .setDistributionOptions(cyclicNprocsMixed2),
395         LinalgMarker(Identifier::get("distribute5", context),
396                      Identifier::get("after_distribute5", context)));
397   }
398 
399   {
400     LinalgLoopDistributionOptions cyclicNprocsMixed3;
401     cyclicNprocsMixed3.distributionMethod = {
402         DistributionMethod::Cyclic,
403         DistributionMethod::CyclicNumProcsEqNumIters};
404     cyclicNprocsMixed3.procInfo = getGpuProcIds<gpu::BlockIdOp, gpu::GridDimOp>;
405 
406     patterns.insert<LinalgTilingPattern<MatmulOp>>(
407         context,
408         LinalgTilingOptions()
409             .setTileSizes({8, 8, 4})
410             .setLoopType(LinalgTilingLoopType::ParallelLoops)
411             .setDistributionOptions(cyclicNprocsMixed3),
412         LinalgMarker(Identifier::get("distribute6", context),
413                      Identifier::get("after_distribute6", context)));
414   }
415 
416   {
417     LinalgLoopDistributionOptions cyclicNprocsEqNiters;
418     cyclicNprocsEqNiters.distributionMethod.resize(2,
419                                                    DistributionMethod::Cyclic);
420     cyclicNprocsEqNiters.procInfo =
421         getGpuProcIds<gpu::BlockIdOp, gpu::GridDimOp>;
422     patterns.insert<LinalgTilingPattern<MatmulOp>>(
423         context,
424         LinalgTilingOptions()
425             .setTileSizes({8, 8, 4})
426             .setLoopType(LinalgTilingLoopType::Loops)
427             .setDistributionOptions(cyclicNprocsEqNiters),
428         LinalgMarker(Identifier::get("tensors_distribute1", context),
429                      Identifier::get("tensors_after_distribute1", context)));
430   }
431 }
432 
433 static void
applyMatmulToVectorPatterns(FuncOp funcOp,bool testMatmulToVectorPatterns1dTiling,bool testMatmulToVectorPatterns2dTiling)434 applyMatmulToVectorPatterns(FuncOp funcOp,
435                             bool testMatmulToVectorPatterns1dTiling,
436                             bool testMatmulToVectorPatterns2dTiling) {
437   MLIRContext *ctx = funcOp.getContext();
438   SmallVector<OwningRewritePatternList, 4> stage1Patterns;
439   if (testMatmulToVectorPatterns1dTiling) {
440     fillL1TilingAndMatmulToVectorPatterns(funcOp, Identifier::get("START", ctx),
441                                           stage1Patterns);
442   } else if (testMatmulToVectorPatterns2dTiling) {
443     stage1Patterns.emplace_back(std::make_unique<LinalgTilingPattern<MatmulOp>>(
444         ctx,
445         LinalgTilingOptions()
446             .setTileSizes({768, 264, 768})
447             .setInterchange({1, 2, 0}),
448         LinalgMarker(Identifier::get("START", ctx),
449                      Identifier::get("L2", ctx))));
450     fillL1TilingAndMatmulToVectorPatterns(funcOp, Identifier::get("L2", ctx),
451                                           stage1Patterns);
452   }
453   SmallVector<FrozenRewritePatternList, 4> frozenStage1Patterns;
454   llvm::move(stage1Patterns, std::back_inserter(frozenStage1Patterns));
455   FrozenRewritePatternList stage2Patterns =
456       getLinalgTilingCanonicalizationPatterns(ctx);
457   applyStagedPatterns(funcOp, frozenStage1Patterns, std::move(stage2Patterns));
458 }
459 
applyVectorTransferForwardingPatterns(FuncOp funcOp)460 static void applyVectorTransferForwardingPatterns(FuncOp funcOp) {
461   OwningRewritePatternList forwardPattern;
462   forwardPattern.insert<LinalgCopyVTRForwardingPattern>(funcOp.getContext());
463   forwardPattern.insert<LinalgCopyVTWForwardingPattern>(funcOp.getContext());
464   applyPatternsAndFoldGreedily(funcOp, std::move(forwardPattern));
465 }
466 
applyLinalgToVectorPatterns(FuncOp funcOp)467 static void applyLinalgToVectorPatterns(FuncOp funcOp) {
468   OwningRewritePatternList patterns;
469   patterns.insert<
470       LinalgVectorizationPattern<BatchMatmulOp>,
471       LinalgVectorizationPattern<MatmulOp>,
472       LinalgVectorizationPattern<MatvecOp>,
473       LinalgVectorizationPattern<VecmatOp>, LinalgVectorizationPattern<DotOp>,
474       LinalgVectorizationPattern<FillOp>, LinalgVectorizationPattern<CopyOp>,
475       LinalgVectorizationPattern<GenericOp>>(funcOp.getContext());
476   applyPatternsAndFoldGreedily(funcOp, std::move(patterns));
477 }
478 
applyAffineMinSCFCanonicalizationPatterns(FuncOp funcOp)479 static void applyAffineMinSCFCanonicalizationPatterns(FuncOp funcOp) {
480   OwningRewritePatternList foldPattern;
481   foldPattern.insert<AffineMinSCFCanonicalizationPattern>(funcOp.getContext());
482   FrozenRewritePatternList frozenPatterns(std::move(foldPattern));
483 
484   // Explicitly walk and apply the pattern locally to avoid more general folding
485   // on the rest of the IR.
486   funcOp.walk([&frozenPatterns](AffineMinOp minOp) {
487     applyOpPatternsAndFold(minOp, frozenPatterns);
488   });
489 }
490 /// Apply transformations specified as patterns.
runOnFunction()491 void TestLinalgTransforms::runOnFunction() {
492   auto lambda = [&](void *) {
493     getFunction().walk([](LinalgOp op) {
494       op.removeAttr(LinalgTransforms::kLinalgTransformMarker);
495     });
496   };
497   std::unique_ptr<void, decltype(lambda)> cleanupGuard{(void *)1, lambda};
498 
499   if (testPromotionOptions) {
500     OwningRewritePatternList patterns;
501     fillPromotionCallBackPatterns(&getContext(), patterns);
502     applyPatternsAndFoldGreedily(getFunction(), std::move(patterns));
503     return;
504   }
505   if (testTileAndDistributionOptions) {
506     OwningRewritePatternList patterns;
507     fillTileAndDistributePatterns(&getContext(), patterns);
508     applyPatternsAndFoldGreedily(getFunction(), std::move(patterns));
509     return;
510   }
511   if (testPatterns)
512     return applyPatterns(getFunction());
513   if (testMatmulToVectorPatterns1dTiling || testMatmulToVectorPatterns2dTiling)
514     return applyMatmulToVectorPatterns(getFunction(),
515                                        testMatmulToVectorPatterns1dTiling,
516                                        testMatmulToVectorPatterns2dTiling);
517   if (testVectorTransferForwardingPatterns)
518     return applyVectorTransferForwardingPatterns(getFunction());
519   if (testGenericToVectorPattern)
520     return applyLinalgToVectorPatterns(getFunction());
521   if (testAffineMinSCFCanonicalizationPatterns)
522     return applyAffineMinSCFCanonicalizationPatterns(getFunction());
523 }
524 
525 namespace mlir {
526 namespace test {
registerTestLinalgTransforms()527 void registerTestLinalgTransforms() {
528   PassRegistration<TestLinalgTransforms> testTransformPatternsPass(
529       "test-linalg-transform-patterns",
530       "Test Linalg transformation patterns by applying them greedily.");
531 }
532 } // namespace test
533 } // namespace mlir
534