1 //===- TestVectorToVectorConversion.cpp - Test VectorTransfers lowering ---===//
2 //
3 // Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions.
4 // See https://llvm.org/LICENSE.txt for license information.
5 // SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception
6 //
7 //===----------------------------------------------------------------------===//
8 
9 #include <type_traits>
10 
11 #include "mlir/Analysis/SliceAnalysis.h"
12 #include "mlir/Dialect/Affine/IR/AffineOps.h"
13 #include "mlir/Dialect/Linalg/IR/LinalgOps.h"
14 #include "mlir/Dialect/SCF/SCF.h"
15 #include "mlir/Dialect/StandardOps/IR/Ops.h"
16 #include "mlir/Dialect/Vector/VectorOps.h"
17 #include "mlir/Dialect/Vector/VectorTransforms.h"
18 #include "mlir/Pass/Pass.h"
19 #include "mlir/Transforms/GreedyPatternRewriteDriver.h"
20 
21 using namespace mlir;
22 using namespace mlir::vector;
23 namespace {
24 
25 struct TestVectorToVectorConversion
26     : public PassWrapper<TestVectorToVectorConversion, FunctionPass> {
runOnFunction__anonac398dc80111::TestVectorToVectorConversion27   void runOnFunction() override {
28     OwningRewritePatternList patterns;
29     auto *ctx = &getContext();
30     patterns.insert<UnrollVectorPattern>(
31         ctx, UnrollVectorOptions().setNativeShapeFn(getShape));
32     populateVectorToVectorCanonicalizationPatterns(patterns, ctx);
33     populateVectorToVectorTransformationPatterns(patterns, ctx);
34     applyPatternsAndFoldGreedily(getFunction(), std::move(patterns));
35   }
36 
37 private:
38   // Return the target shape based on op type.
getShape__anonac398dc80111::TestVectorToVectorConversion39   static Optional<SmallVector<int64_t, 4>> getShape(Operation *op) {
40     if (isa<AddFOp>(op))
41       return SmallVector<int64_t, 4>(2, 2);
42     if (isa<vector::ContractionOp>(op))
43       return SmallVector<int64_t, 4>(3, 2);
44     return llvm::None;
45   }
46 };
47 
48 struct TestVectorSlicesConversion
49     : public PassWrapper<TestVectorSlicesConversion, FunctionPass> {
runOnFunction__anonac398dc80111::TestVectorSlicesConversion50   void runOnFunction() override {
51     OwningRewritePatternList patterns;
52     populateVectorSlicesLoweringPatterns(patterns, &getContext());
53     applyPatternsAndFoldGreedily(getFunction(), std::move(patterns));
54   }
55 };
56 
57 struct TestVectorContractionConversion
58     : public PassWrapper<TestVectorContractionConversion, FunctionPass> {
59   TestVectorContractionConversion() = default;
TestVectorContractionConversion__anonac398dc80111::TestVectorContractionConversion60   TestVectorContractionConversion(const TestVectorContractionConversion &pass) {
61   }
62 
63   Option<bool> lowerToFlatMatrix{
64       *this, "vector-lower-matrix-intrinsics",
65       llvm::cl::desc("Lower vector.contract to llvm.intr.matrix.multiply"),
66       llvm::cl::init(false)};
67   Option<bool> lowerToFlatTranspose{
68       *this, "vector-flat-transpose",
69       llvm::cl::desc("Lower 2-D vector.transpose to vector.flat_transpose"),
70       llvm::cl::init(false)};
71   Option<bool> lowerToOuterProduct{
72       *this, "vector-outerproduct",
73       llvm::cl::desc("Lower vector.contract to vector.outerproduct"),
74       llvm::cl::init(false)};
75   Option<bool> lowerToFilterOuterProduct{
76       *this, "vector-filter-outerproduct",
77       llvm::cl::desc("Lower vector.contract to vector.outerproduct but not for "
78                      "vectors of size 4."),
79       llvm::cl::init(false)};
80 
runOnFunction__anonac398dc80111::TestVectorContractionConversion81   void runOnFunction() override {
82     OwningRewritePatternList patterns;
83 
84     // Test on one pattern in isolation.
85     if (lowerToOuterProduct) {
86       VectorContractLowering lowering = VectorContractLowering::OuterProduct;
87       VectorTransformsOptions options{lowering};
88       patterns.insert<ContractionOpToOuterProductOpLowering>(options,
89                                                              &getContext());
90       applyPatternsAndFoldGreedily(getFunction(), std::move(patterns));
91       return;
92     }
93 
94     // Test on one pattern in isolation.
95     if (lowerToFilterOuterProduct) {
96       VectorContractLowering lowering = VectorContractLowering::OuterProduct;
97       VectorTransformsOptions options{lowering};
98       patterns.insert<ContractionOpToOuterProductOpLowering>(
99           options, &getContext(), [](vector::ContractionOp op) {
100             // Only lowers vector.contract where the lhs as a type vector<MxNx?>
101             // where M is not 4.
102             if (op.getRhsType().getShape()[0] == 4)
103               return failure();
104             return success();
105           });
106       applyPatternsAndFoldGreedily(getFunction(), std::move(patterns));
107       return;
108     }
109 
110     // Test on all contract lowering patterns.
111     VectorContractLowering contractLowering = VectorContractLowering::Dot;
112     if (lowerToFlatMatrix)
113       contractLowering = VectorContractLowering::Matmul;
114     VectorTransposeLowering transposeLowering =
115         VectorTransposeLowering::EltWise;
116     if (lowerToFlatTranspose)
117       transposeLowering = VectorTransposeLowering::Flat;
118     VectorTransformsOptions options{contractLowering, transposeLowering};
119     populateVectorContractLoweringPatterns(patterns, &getContext(), options);
120     applyPatternsAndFoldGreedily(getFunction(), std::move(patterns));
121   }
122 };
123 
124 struct TestVectorUnrollingPatterns
125     : public PassWrapper<TestVectorUnrollingPatterns, FunctionPass> {
126   TestVectorUnrollingPatterns() = default;
TestVectorUnrollingPatterns__anonac398dc80111::TestVectorUnrollingPatterns127   TestVectorUnrollingPatterns(const TestVectorUnrollingPatterns &pass) {}
runOnFunction__anonac398dc80111::TestVectorUnrollingPatterns128   void runOnFunction() override {
129     MLIRContext *ctx = &getContext();
130     OwningRewritePatternList patterns;
131     patterns.insert<UnrollVectorPattern>(
132         ctx, UnrollVectorOptions()
133                  .setNativeShape(ArrayRef<int64_t>{2, 2})
134                  .setFilterConstraint(
135                      [](Operation *op) { return success(isa<AddFOp>(op)); }));
136 
137     if (unrollBasedOnType) {
138       UnrollVectorOptions::NativeShapeFnType nativeShapeFn =
139           [](Operation *op) -> Optional<SmallVector<int64_t, 4>> {
140         vector::ContractionOp contractOp = cast<vector::ContractionOp>(op);
141         SmallVector<int64_t, 4> nativeShape = {4, 4, 2};
142         if (auto floatType = contractOp.getLhsType()
143                                  .getElementType()
144                                  .dyn_cast<FloatType>()) {
145           if (floatType.getWidth() == 16) {
146             nativeShape[2] = 4;
147           }
148         }
149         return nativeShape;
150       };
151       patterns.insert<UnrollVectorPattern>(
152           ctx, UnrollVectorOptions()
153                    .setNativeShapeFn(nativeShapeFn)
154                    .setFilterConstraint([](Operation *op) {
155                      return success(isa<ContractionOp>(op));
156                    }));
157     } else {
158       patterns.insert<UnrollVectorPattern>(
159           ctx, UnrollVectorOptions()
160                    .setNativeShape(ArrayRef<int64_t>{2, 2, 2})
161                    .setFilterConstraint([](Operation *op) {
162                      return success(isa<ContractionOp>(op));
163                    }));
164     }
165     populateVectorToVectorCanonicalizationPatterns(patterns, ctx);
166     populateVectorToVectorTransformationPatterns(patterns, ctx);
167     applyPatternsAndFoldGreedily(getFunction(), std::move(patterns));
168   }
169 
170   Option<bool> unrollBasedOnType{
171       *this, "unroll-based-on-type",
172       llvm::cl::desc("Set the unroll factor based on type of the operation"),
173       llvm::cl::init(false)};
174 };
175 
176 struct TestVectorDistributePatterns
177     : public PassWrapper<TestVectorDistributePatterns, FunctionPass> {
178   TestVectorDistributePatterns() = default;
TestVectorDistributePatterns__anonac398dc80111::TestVectorDistributePatterns179   TestVectorDistributePatterns(const TestVectorDistributePatterns &pass) {}
getDependentDialects__anonac398dc80111::TestVectorDistributePatterns180   void getDependentDialects(DialectRegistry &registry) const override {
181     registry.insert<VectorDialect>();
182     registry.insert<AffineDialect>();
183   }
184   ListOption<int32_t> multiplicity{
185       *this, "distribution-multiplicity", llvm::cl::MiscFlags::CommaSeparated,
186       llvm::cl::desc("Set the multiplicity used for distributing vector")};
187 
runOnFunction__anonac398dc80111::TestVectorDistributePatterns188   void runOnFunction() override {
189     MLIRContext *ctx = &getContext();
190     OwningRewritePatternList patterns;
191     FuncOp func = getFunction();
192     func.walk([&](AddFOp op) {
193       OpBuilder builder(op);
194       if (auto vecType = op.getType().dyn_cast<VectorType>()) {
195         SmallVector<int64_t, 2> mul;
196         SmallVector<AffineExpr, 2> perm;
197         SmallVector<Value, 2> ids;
198         unsigned count = 0;
199         // Remove the multiplicity of 1 and calculate the affine map based on
200         // the multiplicity.
201         SmallVector<int32_t, 4> m(multiplicity.begin(), multiplicity.end());
202         for (unsigned i = 0, e = vecType.getRank(); i < e; i++) {
203           if (i < m.size() && m[i] != 1 && vecType.getDimSize(i) % m[i] == 0) {
204             mul.push_back(m[i]);
205             ids.push_back(func.getArgument(count++));
206             perm.push_back(getAffineDimExpr(i, ctx));
207           }
208         }
209         auto map = AffineMap::get(op.getType().cast<VectorType>().getRank(), 0,
210                                   perm, ctx);
211         Optional<mlir::vector::DistributeOps> ops = distributPointwiseVectorOp(
212             builder, op.getOperation(), ids, mul, map);
213         if (ops.hasValue()) {
214           SmallPtrSet<Operation *, 1> extractOp({ops->extract, ops->insert});
215           op.getResult().replaceAllUsesExcept(ops->insert.getResult(),
216                                               extractOp);
217         }
218       }
219     });
220     patterns.insert<PointwiseExtractPattern>(ctx);
221     populateVectorToVectorTransformationPatterns(patterns, ctx);
222     applyPatternsAndFoldGreedily(getFunction(), std::move(patterns));
223   }
224 };
225 
226 struct TestVectorToLoopPatterns
227     : public PassWrapper<TestVectorToLoopPatterns, FunctionPass> {
228   TestVectorToLoopPatterns() = default;
TestVectorToLoopPatterns__anonac398dc80111::TestVectorToLoopPatterns229   TestVectorToLoopPatterns(const TestVectorToLoopPatterns &pass) {}
getDependentDialects__anonac398dc80111::TestVectorToLoopPatterns230   void getDependentDialects(DialectRegistry &registry) const override {
231     registry.insert<VectorDialect>();
232     registry.insert<AffineDialect>();
233   }
234   Option<int32_t> multiplicity{
235       *this, "distribution-multiplicity",
236       llvm::cl::desc("Set the multiplicity used for distributing vector"),
237       llvm::cl::init(32)};
runOnFunction__anonac398dc80111::TestVectorToLoopPatterns238   void runOnFunction() override {
239     MLIRContext *ctx = &getContext();
240     OwningRewritePatternList patterns;
241     FuncOp func = getFunction();
242     func.walk([&](AddFOp op) {
243       // Check that the operation type can be broken down into a loop.
244       VectorType type = op.getType().dyn_cast<VectorType>();
245       if (!type || type.getRank() != 1 ||
246           type.getNumElements() % multiplicity != 0)
247         return mlir::WalkResult::advance();
248       auto filterAlloc = [](Operation *op) {
249         if (isa<ConstantOp, AllocOp, CallOp>(op))
250           return false;
251         return true;
252       };
253       auto dependentOps = getSlice(op, filterAlloc);
254       // Create a loop and move instructions from the Op slice into the loop.
255       OpBuilder builder(op);
256       auto zero = builder.create<ConstantOp>(
257           op.getLoc(), builder.getIndexType(),
258           builder.getIntegerAttr(builder.getIndexType(), 0));
259       auto one = builder.create<ConstantOp>(
260           op.getLoc(), builder.getIndexType(),
261           builder.getIntegerAttr(builder.getIndexType(), 1));
262       auto numIter = builder.create<ConstantOp>(
263           op.getLoc(), builder.getIndexType(),
264           builder.getIntegerAttr(builder.getIndexType(), multiplicity));
265       auto forOp = builder.create<scf::ForOp>(op.getLoc(), zero, numIter, one);
266       for (Operation *it : dependentOps) {
267         it->moveBefore(forOp.getBody()->getTerminator());
268       }
269       auto map = AffineMap::getMultiDimIdentityMap(1, ctx);
270       // break up the original op and let the patterns propagate.
271       Optional<mlir::vector::DistributeOps> ops = distributPointwiseVectorOp(
272           builder, op.getOperation(), {forOp.getInductionVar()}, {multiplicity},
273           map);
274       if (ops.hasValue()) {
275         SmallPtrSet<Operation *, 1> extractOp({ops->extract, ops->insert});
276         op.getResult().replaceAllUsesExcept(ops->insert.getResult(), extractOp);
277       }
278       return mlir::WalkResult::interrupt();
279     });
280     patterns.insert<PointwiseExtractPattern>(ctx);
281     populateVectorToVectorTransformationPatterns(patterns, ctx);
282     applyPatternsAndFoldGreedily(getFunction(), std::move(patterns));
283   }
284 };
285 
286 struct TestVectorTransferUnrollingPatterns
287     : public PassWrapper<TestVectorTransferUnrollingPatterns, FunctionPass> {
getDependentDialects__anonac398dc80111::TestVectorTransferUnrollingPatterns288   void getDependentDialects(DialectRegistry &registry) const override {
289     registry.insert<AffineDialect>();
290   }
runOnFunction__anonac398dc80111::TestVectorTransferUnrollingPatterns291   void runOnFunction() override {
292     MLIRContext *ctx = &getContext();
293     OwningRewritePatternList patterns;
294     patterns.insert<UnrollVectorPattern>(
295         ctx,
296         UnrollVectorOptions()
297             .setNativeShape(ArrayRef<int64_t>{2, 2})
298             .setFilterConstraint([](Operation *op) {
299               return success(
300                   isa<vector::TransferReadOp, vector::TransferWriteOp>(op));
301             }));
302     populateVectorToVectorCanonicalizationPatterns(patterns, ctx);
303     populateVectorToVectorTransformationPatterns(patterns, ctx);
304     applyPatternsAndFoldGreedily(getFunction(), std::move(patterns));
305   }
306 };
307 
308 struct TestVectorTransferFullPartialSplitPatterns
309     : public PassWrapper<TestVectorTransferFullPartialSplitPatterns,
310                          FunctionPass> {
311   TestVectorTransferFullPartialSplitPatterns() = default;
TestVectorTransferFullPartialSplitPatterns__anonac398dc80111::TestVectorTransferFullPartialSplitPatterns312   TestVectorTransferFullPartialSplitPatterns(
313       const TestVectorTransferFullPartialSplitPatterns &pass) {}
314 
getDependentDialects__anonac398dc80111::TestVectorTransferFullPartialSplitPatterns315   void getDependentDialects(DialectRegistry &registry) const override {
316     registry.insert<AffineDialect, linalg::LinalgDialect, scf::SCFDialect>();
317   }
318 
319   Option<bool> useLinalgOps{
320       *this, "use-linalg-copy",
321       llvm::cl::desc("Split using a unmasked vector.transfer + linalg.fill + "
322                      "linalg.copy operations."),
323       llvm::cl::init(false)};
runOnFunction__anonac398dc80111::TestVectorTransferFullPartialSplitPatterns324   void runOnFunction() override {
325     MLIRContext *ctx = &getContext();
326     OwningRewritePatternList patterns;
327     VectorTransformsOptions options;
328     if (useLinalgOps)
329       options.setVectorTransferSplit(VectorTransferSplit::LinalgCopy);
330     else
331       options.setVectorTransferSplit(VectorTransferSplit::VectorTransfer);
332     patterns.insert<VectorTransferFullPartialRewriter>(ctx, options);
333     applyPatternsAndFoldGreedily(getFunction(), std::move(patterns));
334   }
335 };
336 
337 struct TestVectorTransferOpt
338     : public PassWrapper<TestVectorTransferOpt, FunctionPass> {
runOnFunction__anonac398dc80111::TestVectorTransferOpt339   void runOnFunction() override { transferOpflowOpt(getFunction()); }
340 };
341 
342 } // end anonymous namespace
343 
344 namespace mlir {
345 namespace test {
registerTestVectorConversions()346 void registerTestVectorConversions() {
347   PassRegistration<TestVectorToVectorConversion> vectorToVectorPass(
348       "test-vector-to-vector-conversion",
349       "Test conversion patterns between ops in the vector dialect");
350 
351   PassRegistration<TestVectorSlicesConversion> slicesPass(
352       "test-vector-slices-conversion",
353       "Test conversion patterns that lower slices ops in the vector dialect");
354 
355   PassRegistration<TestVectorContractionConversion> contractionPass(
356       "test-vector-contraction-conversion",
357       "Test conversion patterns that lower contract ops in the vector dialect");
358 
359   PassRegistration<TestVectorUnrollingPatterns> contractionUnrollingPass(
360       "test-vector-unrolling-patterns",
361       "Test conversion patterns to unroll contract ops in the vector dialect");
362 
363   PassRegistration<TestVectorTransferUnrollingPatterns> transferOpUnrollingPass(
364       "test-vector-transfer-unrolling-patterns",
365       "Test conversion patterns to unroll transfer ops in the vector dialect");
366 
367   PassRegistration<TestVectorTransferFullPartialSplitPatterns>
368       vectorTransformFullPartialPass("test-vector-transfer-full-partial-split",
369                                      "Test conversion patterns to split "
370                                      "transfer ops via scf.if + linalg ops");
371   PassRegistration<TestVectorDistributePatterns> distributePass(
372       "test-vector-distribute-patterns",
373       "Test conversion patterns to distribute vector ops in the vector "
374       "dialect");
375   PassRegistration<TestVectorToLoopPatterns> vectorToForLoop(
376       "test-vector-to-forloop",
377       "Test conversion patterns to break up a vector op into a for loop");
378   PassRegistration<TestVectorTransferOpt> transferOpOpt(
379       "test-vector-transferop-opt",
380       "Test optimization transformations for transfer ops");
381 }
382 } // namespace test
383 } // namespace mlir
384