1 //===- TestConvVectorization.cpp - Vectorization of Conv ops --------------===//
2 //
3 // Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions.
4 // See https://llvm.org/LICENSE.txt for license information.
5 // SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception
6 //
7 //===----------------------------------------------------------------------===//
8 
9 #include "mlir/Conversion/VectorToSCF/VectorToSCF.h"
10 #include "mlir/Dialect/Linalg/Passes.h"
11 #include "mlir/Dialect/Linalg/Transforms/Hoisting.h"
12 #include "mlir/Dialect/Linalg/Transforms/Transforms.h"
13 #include "mlir/Dialect/Vector/VectorTransforms.h"
14 #include "mlir/Pass/Pass.h"
15 #include "mlir/Pass/PassManager.h"
16 #include "mlir/Transforms/DialectConversion.h"
17 #include "mlir/Transforms/GreedyPatternRewriteDriver.h"
18 #include "mlir/Transforms/LoopUtils.h"
19 #include "mlir/Transforms/Passes.h"
20 
21 using namespace mlir;
22 using namespace vector;
23 
24 namespace {
25 /// A pass converting MLIR Linalg ops into Vector ops.
26 class TestConvVectorization
27     : public PassWrapper<TestConvVectorization, OperationPass<ModuleOp>> {
28 public:
29   TestConvVectorization() = default;
TestConvVectorization(const TestConvVectorization &)30   TestConvVectorization(const TestConvVectorization &) {}
TestConvVectorization(ArrayRef<int64_t> tileSizesParam)31   explicit TestConvVectorization(ArrayRef<int64_t> tileSizesParam) {
32     tileSizes = tileSizesParam;
33   }
34 
35   void runOnOperation() override;
36 
getDependentDialects(DialectRegistry & registry) const37   void getDependentDialects(DialectRegistry &registry) const override {
38     registry.insert<VectorDialect>();
39     registry.insert<linalg::LinalgDialect>();
40     registry.insert<scf::SCFDialect>();
41     registry.insert<AffineDialect>();
42     registry.insert<StandardOpsDialect>();
43   }
44 
45   ListOption<int64_t> tileSizes{
46       *this, "tile-sizes", llvm::cl::desc("Vectorization sizes."),
47       llvm::cl::ZeroOrMore, llvm::cl::MiscFlags::CommaSeparated};
48 };
49 } // namespace
50 
runOnOperation()51 void TestConvVectorization::runOnOperation() {
52   MLIRContext *context = &getContext();
53   ModuleOp module = getOperation();
54 
55   ConversionTarget target(*context);
56   target.addLegalDialect<AffineDialect, scf::SCFDialect, StandardOpsDialect,
57                          VectorDialect>();
58   target.addLegalOp<ModuleOp, FuncOp, ModuleTerminatorOp, ReturnOp>();
59   target.addLegalOp<linalg::FillOp, linalg::YieldOp>();
60 
61   SmallVector<OwningRewritePatternList, 4> stage1Patterns;
62   linalg::populateConvVectorizationPatterns(context, stage1Patterns, tileSizes);
63   SmallVector<FrozenRewritePatternList, 4> frozenStage1Patterns;
64   llvm::move(stage1Patterns, std::back_inserter(frozenStage1Patterns));
65 
66   OwningRewritePatternList stage2Patterns =
67       linalg::getLinalgTilingCanonicalizationPatterns(context);
68   stage2Patterns.insert<linalg::AffineMinSCFCanonicalizationPattern>(context);
69 
70   auto stage3Transforms = [](Operation *op) {
71     PassManager pm(op->getContext());
72     pm.addPass(createLoopInvariantCodeMotionPass());
73     if (failed(pm.run(cast<ModuleOp>(op))))
74       llvm_unreachable("Unexpected failure in cleanup pass pipeline.");
75     op->walk([](FuncOp func) {
76       promoteSingleIterationLoops(func);
77       linalg::hoistViewAllocOps(func);
78       linalg::hoistRedundantVectorTransfers(func);
79     });
80     return success();
81   };
82 
83   linalg::applyStagedPatterns(module, frozenStage1Patterns,
84                               std::move(stage2Patterns), stage3Transforms);
85 
86   //===--------------------------------------------------------------------===//
87   // Post staged patterns transforms
88   //===--------------------------------------------------------------------===//
89 
90   VectorTransformsOptions vectorTransformsOptions{
91       VectorContractLowering::Dot, VectorTransposeLowering::EltWise};
92 
93   OwningRewritePatternList vectorTransferPatterns;
94   // Pattern is not applied because rank-reducing vector transfer is not yet
95   // supported as can be seen in splitFullAndPartialTransferPrecondition,
96   // VectorTransforms.cpp
97   vectorTransferPatterns.insert<VectorTransferFullPartialRewriter>(
98       context, vectorTransformsOptions);
99   applyPatternsAndFoldGreedily(module, std::move(vectorTransferPatterns));
100 
101   // Programmatic controlled lowering of linalg.copy and linalg.fill.
102   PassManager pm(context);
103   pm.addNestedPass<FuncOp>(createConvertLinalgToLoopsPass());
104   if (failed(pm.run(module)))
105     llvm_unreachable("Unexpected failure in linalg to loops pass.");
106 
107   // Programmatic controlled lowering of vector.contract only.
108   OwningRewritePatternList vectorContractLoweringPatterns;
109   populateVectorContractLoweringPatterns(vectorContractLoweringPatterns,
110                                          context, vectorTransformsOptions);
111   applyPatternsAndFoldGreedily(module,
112                                std::move(vectorContractLoweringPatterns));
113 
114   // Programmatic controlled lowering of vector.transfer only.
115   OwningRewritePatternList vectorToLoopsPatterns;
116   populateVectorToSCFConversionPatterns(vectorToLoopsPatterns, context,
117                                         VectorTransferToSCFOptions());
118   applyPatternsAndFoldGreedily(module, std::move(vectorToLoopsPatterns));
119 
120   // Ensure we drop the marker in the end.
121   module.walk([](linalg::LinalgOp op) {
122     op.removeAttr(linalg::LinalgTransforms::kLinalgTransformMarker);
123   });
124 }
125 
126 namespace mlir {
127 namespace test {
registerTestConvVectorization()128 void registerTestConvVectorization() {
129   PassRegistration<TestConvVectorization> testTransformPatternsPass(
130       "test-conv-vectorization", "Test vectorization of convolutions");
131 }
132 } // namespace test
133 } // namespace mlir
134