1 //===- TestConvVectorization.cpp - Vectorization of Conv ops --------------===//
2 //
3 // Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions.
4 // See https://llvm.org/LICENSE.txt for license information.
5 // SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception
6 //
7 //===----------------------------------------------------------------------===//
8
9 #include "mlir/Conversion/VectorToSCF/VectorToSCF.h"
10 #include "mlir/Dialect/Linalg/Passes.h"
11 #include "mlir/Dialect/Linalg/Transforms/Hoisting.h"
12 #include "mlir/Dialect/Linalg/Transforms/Transforms.h"
13 #include "mlir/Dialect/Vector/VectorTransforms.h"
14 #include "mlir/Pass/Pass.h"
15 #include "mlir/Pass/PassManager.h"
16 #include "mlir/Transforms/DialectConversion.h"
17 #include "mlir/Transforms/GreedyPatternRewriteDriver.h"
18 #include "mlir/Transforms/LoopUtils.h"
19 #include "mlir/Transforms/Passes.h"
20
21 using namespace mlir;
22 using namespace vector;
23
24 namespace {
25 /// A pass converting MLIR Linalg ops into Vector ops.
26 class TestConvVectorization
27 : public PassWrapper<TestConvVectorization, OperationPass<ModuleOp>> {
28 public:
29 TestConvVectorization() = default;
TestConvVectorization(const TestConvVectorization &)30 TestConvVectorization(const TestConvVectorization &) {}
TestConvVectorization(ArrayRef<int64_t> tileSizesParam)31 explicit TestConvVectorization(ArrayRef<int64_t> tileSizesParam) {
32 tileSizes = tileSizesParam;
33 }
34
35 void runOnOperation() override;
36
getDependentDialects(DialectRegistry & registry) const37 void getDependentDialects(DialectRegistry ®istry) const override {
38 registry.insert<VectorDialect>();
39 registry.insert<linalg::LinalgDialect>();
40 registry.insert<scf::SCFDialect>();
41 registry.insert<AffineDialect>();
42 registry.insert<StandardOpsDialect>();
43 }
44
45 ListOption<int64_t> tileSizes{
46 *this, "tile-sizes", llvm::cl::desc("Vectorization sizes."),
47 llvm::cl::ZeroOrMore, llvm::cl::MiscFlags::CommaSeparated};
48 };
49 } // namespace
50
runOnOperation()51 void TestConvVectorization::runOnOperation() {
52 MLIRContext *context = &getContext();
53 ModuleOp module = getOperation();
54
55 ConversionTarget target(*context);
56 target.addLegalDialect<AffineDialect, scf::SCFDialect, StandardOpsDialect,
57 VectorDialect>();
58 target.addLegalOp<ModuleOp, FuncOp, ModuleTerminatorOp, ReturnOp>();
59 target.addLegalOp<linalg::FillOp, linalg::YieldOp>();
60
61 SmallVector<OwningRewritePatternList, 4> stage1Patterns;
62 linalg::populateConvVectorizationPatterns(context, stage1Patterns, tileSizes);
63 SmallVector<FrozenRewritePatternList, 4> frozenStage1Patterns;
64 llvm::move(stage1Patterns, std::back_inserter(frozenStage1Patterns));
65
66 OwningRewritePatternList stage2Patterns =
67 linalg::getLinalgTilingCanonicalizationPatterns(context);
68 stage2Patterns.insert<linalg::AffineMinSCFCanonicalizationPattern>(context);
69
70 auto stage3Transforms = [](Operation *op) {
71 PassManager pm(op->getContext());
72 pm.addPass(createLoopInvariantCodeMotionPass());
73 if (failed(pm.run(cast<ModuleOp>(op))))
74 llvm_unreachable("Unexpected failure in cleanup pass pipeline.");
75 op->walk([](FuncOp func) {
76 promoteSingleIterationLoops(func);
77 linalg::hoistViewAllocOps(func);
78 linalg::hoistRedundantVectorTransfers(func);
79 });
80 return success();
81 };
82
83 linalg::applyStagedPatterns(module, frozenStage1Patterns,
84 std::move(stage2Patterns), stage3Transforms);
85
86 //===--------------------------------------------------------------------===//
87 // Post staged patterns transforms
88 //===--------------------------------------------------------------------===//
89
90 VectorTransformsOptions vectorTransformsOptions{
91 VectorContractLowering::Dot, VectorTransposeLowering::EltWise};
92
93 OwningRewritePatternList vectorTransferPatterns;
94 // Pattern is not applied because rank-reducing vector transfer is not yet
95 // supported as can be seen in splitFullAndPartialTransferPrecondition,
96 // VectorTransforms.cpp
97 vectorTransferPatterns.insert<VectorTransferFullPartialRewriter>(
98 context, vectorTransformsOptions);
99 applyPatternsAndFoldGreedily(module, std::move(vectorTransferPatterns));
100
101 // Programmatic controlled lowering of linalg.copy and linalg.fill.
102 PassManager pm(context);
103 pm.addNestedPass<FuncOp>(createConvertLinalgToLoopsPass());
104 if (failed(pm.run(module)))
105 llvm_unreachable("Unexpected failure in linalg to loops pass.");
106
107 // Programmatic controlled lowering of vector.contract only.
108 OwningRewritePatternList vectorContractLoweringPatterns;
109 populateVectorContractLoweringPatterns(vectorContractLoweringPatterns,
110 context, vectorTransformsOptions);
111 applyPatternsAndFoldGreedily(module,
112 std::move(vectorContractLoweringPatterns));
113
114 // Programmatic controlled lowering of vector.transfer only.
115 OwningRewritePatternList vectorToLoopsPatterns;
116 populateVectorToSCFConversionPatterns(vectorToLoopsPatterns, context,
117 VectorTransferToSCFOptions());
118 applyPatternsAndFoldGreedily(module, std::move(vectorToLoopsPatterns));
119
120 // Ensure we drop the marker in the end.
121 module.walk([](linalg::LinalgOp op) {
122 op.removeAttr(linalg::LinalgTransforms::kLinalgTransformMarker);
123 });
124 }
125
126 namespace mlir {
127 namespace test {
registerTestConvVectorization()128 void registerTestConvVectorization() {
129 PassRegistration<TestConvVectorization> testTransformPatternsPass(
130 "test-conv-vectorization", "Test vectorization of convolutions");
131 }
132 } // namespace test
133 } // namespace mlir
134