1 //===- TestVectorToVectorConversion.cpp - Test VectorTransfers lowering ---===//
2 //
3 // Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions.
4 // See https://llvm.org/LICENSE.txt for license information.
5 // SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception
6 //
7 //===----------------------------------------------------------------------===//
8
9 #include <type_traits>
10
11 #include "mlir/Analysis/SliceAnalysis.h"
12 #include "mlir/Dialect/Affine/IR/AffineOps.h"
13 #include "mlir/Dialect/Linalg/IR/LinalgOps.h"
14 #include "mlir/Dialect/SCF/SCF.h"
15 #include "mlir/Dialect/StandardOps/IR/Ops.h"
16 #include "mlir/Dialect/Vector/VectorOps.h"
17 #include "mlir/Dialect/Vector/VectorTransforms.h"
18 #include "mlir/Pass/Pass.h"
19 #include "mlir/Transforms/GreedyPatternRewriteDriver.h"
20
21 using namespace mlir;
22 using namespace mlir::vector;
23 namespace {
24
25 struct TestVectorToVectorConversion
26 : public PassWrapper<TestVectorToVectorConversion, FunctionPass> {
runOnFunction__anonac398dc80111::TestVectorToVectorConversion27 void runOnFunction() override {
28 OwningRewritePatternList patterns;
29 auto *ctx = &getContext();
30 patterns.insert<UnrollVectorPattern>(
31 ctx, UnrollVectorOptions().setNativeShapeFn(getShape));
32 populateVectorToVectorCanonicalizationPatterns(patterns, ctx);
33 populateVectorToVectorTransformationPatterns(patterns, ctx);
34 applyPatternsAndFoldGreedily(getFunction(), std::move(patterns));
35 }
36
37 private:
38 // Return the target shape based on op type.
getShape__anonac398dc80111::TestVectorToVectorConversion39 static Optional<SmallVector<int64_t, 4>> getShape(Operation *op) {
40 if (isa<AddFOp>(op))
41 return SmallVector<int64_t, 4>(2, 2);
42 if (isa<vector::ContractionOp>(op))
43 return SmallVector<int64_t, 4>(3, 2);
44 return llvm::None;
45 }
46 };
47
48 struct TestVectorSlicesConversion
49 : public PassWrapper<TestVectorSlicesConversion, FunctionPass> {
runOnFunction__anonac398dc80111::TestVectorSlicesConversion50 void runOnFunction() override {
51 OwningRewritePatternList patterns;
52 populateVectorSlicesLoweringPatterns(patterns, &getContext());
53 applyPatternsAndFoldGreedily(getFunction(), std::move(patterns));
54 }
55 };
56
57 struct TestVectorContractionConversion
58 : public PassWrapper<TestVectorContractionConversion, FunctionPass> {
59 TestVectorContractionConversion() = default;
TestVectorContractionConversion__anonac398dc80111::TestVectorContractionConversion60 TestVectorContractionConversion(const TestVectorContractionConversion &pass) {
61 }
62
63 Option<bool> lowerToFlatMatrix{
64 *this, "vector-lower-matrix-intrinsics",
65 llvm::cl::desc("Lower vector.contract to llvm.intr.matrix.multiply"),
66 llvm::cl::init(false)};
67 Option<bool> lowerToFlatTranspose{
68 *this, "vector-flat-transpose",
69 llvm::cl::desc("Lower 2-D vector.transpose to vector.flat_transpose"),
70 llvm::cl::init(false)};
71 Option<bool> lowerToOuterProduct{
72 *this, "vector-outerproduct",
73 llvm::cl::desc("Lower vector.contract to vector.outerproduct"),
74 llvm::cl::init(false)};
75 Option<bool> lowerToFilterOuterProduct{
76 *this, "vector-filter-outerproduct",
77 llvm::cl::desc("Lower vector.contract to vector.outerproduct but not for "
78 "vectors of size 4."),
79 llvm::cl::init(false)};
80
runOnFunction__anonac398dc80111::TestVectorContractionConversion81 void runOnFunction() override {
82 OwningRewritePatternList patterns;
83
84 // Test on one pattern in isolation.
85 if (lowerToOuterProduct) {
86 VectorContractLowering lowering = VectorContractLowering::OuterProduct;
87 VectorTransformsOptions options{lowering};
88 patterns.insert<ContractionOpToOuterProductOpLowering>(options,
89 &getContext());
90 applyPatternsAndFoldGreedily(getFunction(), std::move(patterns));
91 return;
92 }
93
94 // Test on one pattern in isolation.
95 if (lowerToFilterOuterProduct) {
96 VectorContractLowering lowering = VectorContractLowering::OuterProduct;
97 VectorTransformsOptions options{lowering};
98 patterns.insert<ContractionOpToOuterProductOpLowering>(
99 options, &getContext(), [](vector::ContractionOp op) {
100 // Only lowers vector.contract where the lhs as a type vector<MxNx?>
101 // where M is not 4.
102 if (op.getRhsType().getShape()[0] == 4)
103 return failure();
104 return success();
105 });
106 applyPatternsAndFoldGreedily(getFunction(), std::move(patterns));
107 return;
108 }
109
110 // Test on all contract lowering patterns.
111 VectorContractLowering contractLowering = VectorContractLowering::Dot;
112 if (lowerToFlatMatrix)
113 contractLowering = VectorContractLowering::Matmul;
114 VectorTransposeLowering transposeLowering =
115 VectorTransposeLowering::EltWise;
116 if (lowerToFlatTranspose)
117 transposeLowering = VectorTransposeLowering::Flat;
118 VectorTransformsOptions options{contractLowering, transposeLowering};
119 populateVectorContractLoweringPatterns(patterns, &getContext(), options);
120 applyPatternsAndFoldGreedily(getFunction(), std::move(patterns));
121 }
122 };
123
124 struct TestVectorUnrollingPatterns
125 : public PassWrapper<TestVectorUnrollingPatterns, FunctionPass> {
126 TestVectorUnrollingPatterns() = default;
TestVectorUnrollingPatterns__anonac398dc80111::TestVectorUnrollingPatterns127 TestVectorUnrollingPatterns(const TestVectorUnrollingPatterns &pass) {}
runOnFunction__anonac398dc80111::TestVectorUnrollingPatterns128 void runOnFunction() override {
129 MLIRContext *ctx = &getContext();
130 OwningRewritePatternList patterns;
131 patterns.insert<UnrollVectorPattern>(
132 ctx, UnrollVectorOptions()
133 .setNativeShape(ArrayRef<int64_t>{2, 2})
134 .setFilterConstraint(
135 [](Operation *op) { return success(isa<AddFOp>(op)); }));
136
137 if (unrollBasedOnType) {
138 UnrollVectorOptions::NativeShapeFnType nativeShapeFn =
139 [](Operation *op) -> Optional<SmallVector<int64_t, 4>> {
140 vector::ContractionOp contractOp = cast<vector::ContractionOp>(op);
141 SmallVector<int64_t, 4> nativeShape = {4, 4, 2};
142 if (auto floatType = contractOp.getLhsType()
143 .getElementType()
144 .dyn_cast<FloatType>()) {
145 if (floatType.getWidth() == 16) {
146 nativeShape[2] = 4;
147 }
148 }
149 return nativeShape;
150 };
151 patterns.insert<UnrollVectorPattern>(
152 ctx, UnrollVectorOptions()
153 .setNativeShapeFn(nativeShapeFn)
154 .setFilterConstraint([](Operation *op) {
155 return success(isa<ContractionOp>(op));
156 }));
157 } else {
158 patterns.insert<UnrollVectorPattern>(
159 ctx, UnrollVectorOptions()
160 .setNativeShape(ArrayRef<int64_t>{2, 2, 2})
161 .setFilterConstraint([](Operation *op) {
162 return success(isa<ContractionOp>(op));
163 }));
164 }
165 populateVectorToVectorCanonicalizationPatterns(patterns, ctx);
166 populateVectorToVectorTransformationPatterns(patterns, ctx);
167 applyPatternsAndFoldGreedily(getFunction(), std::move(patterns));
168 }
169
170 Option<bool> unrollBasedOnType{
171 *this, "unroll-based-on-type",
172 llvm::cl::desc("Set the unroll factor based on type of the operation"),
173 llvm::cl::init(false)};
174 };
175
176 struct TestVectorDistributePatterns
177 : public PassWrapper<TestVectorDistributePatterns, FunctionPass> {
178 TestVectorDistributePatterns() = default;
TestVectorDistributePatterns__anonac398dc80111::TestVectorDistributePatterns179 TestVectorDistributePatterns(const TestVectorDistributePatterns &pass) {}
getDependentDialects__anonac398dc80111::TestVectorDistributePatterns180 void getDependentDialects(DialectRegistry ®istry) const override {
181 registry.insert<VectorDialect>();
182 registry.insert<AffineDialect>();
183 }
184 ListOption<int32_t> multiplicity{
185 *this, "distribution-multiplicity", llvm::cl::MiscFlags::CommaSeparated,
186 llvm::cl::desc("Set the multiplicity used for distributing vector")};
187
runOnFunction__anonac398dc80111::TestVectorDistributePatterns188 void runOnFunction() override {
189 MLIRContext *ctx = &getContext();
190 OwningRewritePatternList patterns;
191 FuncOp func = getFunction();
192 func.walk([&](AddFOp op) {
193 OpBuilder builder(op);
194 if (auto vecType = op.getType().dyn_cast<VectorType>()) {
195 SmallVector<int64_t, 2> mul;
196 SmallVector<AffineExpr, 2> perm;
197 SmallVector<Value, 2> ids;
198 unsigned count = 0;
199 // Remove the multiplicity of 1 and calculate the affine map based on
200 // the multiplicity.
201 SmallVector<int32_t, 4> m(multiplicity.begin(), multiplicity.end());
202 for (unsigned i = 0, e = vecType.getRank(); i < e; i++) {
203 if (i < m.size() && m[i] != 1 && vecType.getDimSize(i) % m[i] == 0) {
204 mul.push_back(m[i]);
205 ids.push_back(func.getArgument(count++));
206 perm.push_back(getAffineDimExpr(i, ctx));
207 }
208 }
209 auto map = AffineMap::get(op.getType().cast<VectorType>().getRank(), 0,
210 perm, ctx);
211 Optional<mlir::vector::DistributeOps> ops = distributPointwiseVectorOp(
212 builder, op.getOperation(), ids, mul, map);
213 if (ops.hasValue()) {
214 SmallPtrSet<Operation *, 1> extractOp({ops->extract, ops->insert});
215 op.getResult().replaceAllUsesExcept(ops->insert.getResult(),
216 extractOp);
217 }
218 }
219 });
220 patterns.insert<PointwiseExtractPattern>(ctx);
221 populateVectorToVectorTransformationPatterns(patterns, ctx);
222 applyPatternsAndFoldGreedily(getFunction(), std::move(patterns));
223 }
224 };
225
226 struct TestVectorToLoopPatterns
227 : public PassWrapper<TestVectorToLoopPatterns, FunctionPass> {
228 TestVectorToLoopPatterns() = default;
TestVectorToLoopPatterns__anonac398dc80111::TestVectorToLoopPatterns229 TestVectorToLoopPatterns(const TestVectorToLoopPatterns &pass) {}
getDependentDialects__anonac398dc80111::TestVectorToLoopPatterns230 void getDependentDialects(DialectRegistry ®istry) const override {
231 registry.insert<VectorDialect>();
232 registry.insert<AffineDialect>();
233 }
234 Option<int32_t> multiplicity{
235 *this, "distribution-multiplicity",
236 llvm::cl::desc("Set the multiplicity used for distributing vector"),
237 llvm::cl::init(32)};
runOnFunction__anonac398dc80111::TestVectorToLoopPatterns238 void runOnFunction() override {
239 MLIRContext *ctx = &getContext();
240 OwningRewritePatternList patterns;
241 FuncOp func = getFunction();
242 func.walk([&](AddFOp op) {
243 // Check that the operation type can be broken down into a loop.
244 VectorType type = op.getType().dyn_cast<VectorType>();
245 if (!type || type.getRank() != 1 ||
246 type.getNumElements() % multiplicity != 0)
247 return mlir::WalkResult::advance();
248 auto filterAlloc = [](Operation *op) {
249 if (isa<ConstantOp, AllocOp, CallOp>(op))
250 return false;
251 return true;
252 };
253 auto dependentOps = getSlice(op, filterAlloc);
254 // Create a loop and move instructions from the Op slice into the loop.
255 OpBuilder builder(op);
256 auto zero = builder.create<ConstantOp>(
257 op.getLoc(), builder.getIndexType(),
258 builder.getIntegerAttr(builder.getIndexType(), 0));
259 auto one = builder.create<ConstantOp>(
260 op.getLoc(), builder.getIndexType(),
261 builder.getIntegerAttr(builder.getIndexType(), 1));
262 auto numIter = builder.create<ConstantOp>(
263 op.getLoc(), builder.getIndexType(),
264 builder.getIntegerAttr(builder.getIndexType(), multiplicity));
265 auto forOp = builder.create<scf::ForOp>(op.getLoc(), zero, numIter, one);
266 for (Operation *it : dependentOps) {
267 it->moveBefore(forOp.getBody()->getTerminator());
268 }
269 auto map = AffineMap::getMultiDimIdentityMap(1, ctx);
270 // break up the original op and let the patterns propagate.
271 Optional<mlir::vector::DistributeOps> ops = distributPointwiseVectorOp(
272 builder, op.getOperation(), {forOp.getInductionVar()}, {multiplicity},
273 map);
274 if (ops.hasValue()) {
275 SmallPtrSet<Operation *, 1> extractOp({ops->extract, ops->insert});
276 op.getResult().replaceAllUsesExcept(ops->insert.getResult(), extractOp);
277 }
278 return mlir::WalkResult::interrupt();
279 });
280 patterns.insert<PointwiseExtractPattern>(ctx);
281 populateVectorToVectorTransformationPatterns(patterns, ctx);
282 applyPatternsAndFoldGreedily(getFunction(), std::move(patterns));
283 }
284 };
285
286 struct TestVectorTransferUnrollingPatterns
287 : public PassWrapper<TestVectorTransferUnrollingPatterns, FunctionPass> {
getDependentDialects__anonac398dc80111::TestVectorTransferUnrollingPatterns288 void getDependentDialects(DialectRegistry ®istry) const override {
289 registry.insert<AffineDialect>();
290 }
runOnFunction__anonac398dc80111::TestVectorTransferUnrollingPatterns291 void runOnFunction() override {
292 MLIRContext *ctx = &getContext();
293 OwningRewritePatternList patterns;
294 patterns.insert<UnrollVectorPattern>(
295 ctx,
296 UnrollVectorOptions()
297 .setNativeShape(ArrayRef<int64_t>{2, 2})
298 .setFilterConstraint([](Operation *op) {
299 return success(
300 isa<vector::TransferReadOp, vector::TransferWriteOp>(op));
301 }));
302 populateVectorToVectorCanonicalizationPatterns(patterns, ctx);
303 populateVectorToVectorTransformationPatterns(patterns, ctx);
304 applyPatternsAndFoldGreedily(getFunction(), std::move(patterns));
305 }
306 };
307
308 struct TestVectorTransferFullPartialSplitPatterns
309 : public PassWrapper<TestVectorTransferFullPartialSplitPatterns,
310 FunctionPass> {
311 TestVectorTransferFullPartialSplitPatterns() = default;
TestVectorTransferFullPartialSplitPatterns__anonac398dc80111::TestVectorTransferFullPartialSplitPatterns312 TestVectorTransferFullPartialSplitPatterns(
313 const TestVectorTransferFullPartialSplitPatterns &pass) {}
314
getDependentDialects__anonac398dc80111::TestVectorTransferFullPartialSplitPatterns315 void getDependentDialects(DialectRegistry ®istry) const override {
316 registry.insert<AffineDialect, linalg::LinalgDialect, scf::SCFDialect>();
317 }
318
319 Option<bool> useLinalgOps{
320 *this, "use-linalg-copy",
321 llvm::cl::desc("Split using a unmasked vector.transfer + linalg.fill + "
322 "linalg.copy operations."),
323 llvm::cl::init(false)};
runOnFunction__anonac398dc80111::TestVectorTransferFullPartialSplitPatterns324 void runOnFunction() override {
325 MLIRContext *ctx = &getContext();
326 OwningRewritePatternList patterns;
327 VectorTransformsOptions options;
328 if (useLinalgOps)
329 options.setVectorTransferSplit(VectorTransferSplit::LinalgCopy);
330 else
331 options.setVectorTransferSplit(VectorTransferSplit::VectorTransfer);
332 patterns.insert<VectorTransferFullPartialRewriter>(ctx, options);
333 applyPatternsAndFoldGreedily(getFunction(), std::move(patterns));
334 }
335 };
336
337 struct TestVectorTransferOpt
338 : public PassWrapper<TestVectorTransferOpt, FunctionPass> {
runOnFunction__anonac398dc80111::TestVectorTransferOpt339 void runOnFunction() override { transferOpflowOpt(getFunction()); }
340 };
341
342 } // end anonymous namespace
343
344 namespace mlir {
345 namespace test {
registerTestVectorConversions()346 void registerTestVectorConversions() {
347 PassRegistration<TestVectorToVectorConversion> vectorToVectorPass(
348 "test-vector-to-vector-conversion",
349 "Test conversion patterns between ops in the vector dialect");
350
351 PassRegistration<TestVectorSlicesConversion> slicesPass(
352 "test-vector-slices-conversion",
353 "Test conversion patterns that lower slices ops in the vector dialect");
354
355 PassRegistration<TestVectorContractionConversion> contractionPass(
356 "test-vector-contraction-conversion",
357 "Test conversion patterns that lower contract ops in the vector dialect");
358
359 PassRegistration<TestVectorUnrollingPatterns> contractionUnrollingPass(
360 "test-vector-unrolling-patterns",
361 "Test conversion patterns to unroll contract ops in the vector dialect");
362
363 PassRegistration<TestVectorTransferUnrollingPatterns> transferOpUnrollingPass(
364 "test-vector-transfer-unrolling-patterns",
365 "Test conversion patterns to unroll transfer ops in the vector dialect");
366
367 PassRegistration<TestVectorTransferFullPartialSplitPatterns>
368 vectorTransformFullPartialPass("test-vector-transfer-full-partial-split",
369 "Test conversion patterns to split "
370 "transfer ops via scf.if + linalg ops");
371 PassRegistration<TestVectorDistributePatterns> distributePass(
372 "test-vector-distribute-patterns",
373 "Test conversion patterns to distribute vector ops in the vector "
374 "dialect");
375 PassRegistration<TestVectorToLoopPatterns> vectorToForLoop(
376 "test-vector-to-forloop",
377 "Test conversion patterns to break up a vector op into a for loop");
378 PassRegistration<TestVectorTransferOpt> transferOpOpt(
379 "test-vector-transferop-opt",
380 "Test optimization transformations for transfer ops");
381 }
382 } // namespace test
383 } // namespace mlir
384