1 //===- TestLinalgCodegenStrategy.cpp - Test Linalg codegen strategy -------===//
2 //
3 // Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions.
4 // See https://llvm.org/LICENSE.txt for license information.
5 // SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception
6 //
7 //===----------------------------------------------------------------------===//
8 //
9 // This file implements logic for testing the Linalg codegen strategy.
10 //
11 //===----------------------------------------------------------------------===//
12
#include "mlir/Dialect/Affine/IR/AffineOps.h"
#include "mlir/Dialect/GPU/GPUDialect.h"
#include "mlir/Dialect/Linalg/IR/LinalgOps.h"
#include "mlir/Dialect/Linalg/Transforms/CodegenStrategy.h"
#include "mlir/Dialect/Linalg/Utils/Utils.h"
#include "mlir/Dialect/StandardOps/IR/Ops.h"
#include "mlir/Dialect/Vector/VectorOps.h"
#include "mlir/IR/PatternMatch.h"
#include "mlir/Pass/Pass.h"

#include "llvm/ADT/Optional.h"
#include "llvm/ADT/SetVector.h"
#include "llvm/ADT/StringSwitch.h"
24
25 using namespace mlir;
26 using namespace mlir::linalg;
27
28 namespace {
29 struct TestLinalgCodegenStrategy
30 : public PassWrapper<TestLinalgCodegenStrategy, FunctionPass> {
31 TestLinalgCodegenStrategy() = default;
TestLinalgCodegenStrategy__anona12f4ea50111::TestLinalgCodegenStrategy32 TestLinalgCodegenStrategy(const TestLinalgCodegenStrategy &pass) {}
33
getDependentDialects__anona12f4ea50111::TestLinalgCodegenStrategy34 void getDependentDialects(DialectRegistry ®istry) const override {
35 // clang-format off
36 registry.insert<AffineDialect,
37 gpu::GPUDialect,
38 linalg::LinalgDialect,
39 scf::SCFDialect,
40 StandardOpsDialect,
41 vector::VectorDialect>();
42 // clang-format on
43 }
44
45 void runOnFunction() override;
46
47 ListOption<int64_t> tileSizes{*this, "tile-sizes",
48 llvm::cl::MiscFlags::CommaSeparated,
49 llvm::cl::desc("Specifies the tile sizes.")};
50 Option<bool> promote{
51 *this, "promote",
52 llvm::cl::desc("Promote the tile into a small aligned memory buffer."),
53 llvm::cl::init(false)};
54 Option<bool> promoteFullTile{
55 *this, "promote-full-tile-pad",
56 llvm::cl::desc("Pad the small aligned memory buffer to the tile sizes."),
57 llvm::cl::init(false)};
58 ListOption<int64_t> registerTileSizes{
59 *this, "register-tile-sizes", llvm::cl::MiscFlags::CommaSeparated,
60 llvm::cl::desc(
61 "Specifies the size of the register tile that will be used "
62 " to vectorize")};
63 Option<bool> registerPromote{
64 *this, "register-promote",
65 llvm::cl::desc(
66 "Promote the register tile into a small aligned memory buffer."),
67 llvm::cl::init(false)};
68 Option<bool> registerPromoteFullTile{
69 *this, "register-promote-full-tile-pad",
70 llvm::cl::desc("Pad the small aligned memory buffer to the tile sizes."),
71 llvm::cl::init(false)};
72 Option<bool> vectorize{
73 *this, "vectorize",
74 llvm::cl::desc("Rewrite the linalg op as a vector operation."),
75 llvm::cl::init(false)};
76 Option<std::string> splitVectorTransfersTo{
77 *this, "split-transfers",
78 llvm::cl::desc(
79 "Split vector transfers between slow (masked) and fast "
80 "(unmasked) variants. Possible options are:\n"
81 "\tnone: keep unsplit vector.transfer and pay the full price\n"
82 "\tlinalg-copy: use linalg.fill + linalg.copy for the slow path\n"
83 "\tvector-transfers: use extra small unmasked vector.transfer for"
84 " the slow path\n"),
85 llvm::cl::init("none")};
86 Option<std::string> vectorizeContractionTo{
87 *this, "vectorize-contraction-to",
88 llvm::cl::desc("the type of vector op to use for linalg contractions"),
89 llvm::cl::init("outerproduct")};
90 Option<bool> unrollVectorTransfers{
91 *this, "unroll-vector-transfers",
92 llvm::cl::desc("Enable full unrolling of vector.transfer operations"),
93 llvm::cl::init(false)};
94 };
95 } // end anonymous namespace
96
97 /// Apply transformations specified as patterns.
runOnFunction()98 void TestLinalgCodegenStrategy::runOnFunction() {
99 LinalgTilingOptions tilingOptions;
100 if (!tileSizes.empty())
101 tilingOptions = tilingOptions.setTileSizes(tileSizes);
102
103 LinalgTilingOptions registerTilingOptions;
104 if (!registerTileSizes.empty())
105 registerTilingOptions =
106 registerTilingOptions.setTileSizes(registerTileSizes);
107
108 vector::VectorContractLowering vectorContractLowering =
109 llvm::StringSwitch<vector::VectorContractLowering>(
110 vectorizeContractionTo.getValue())
111 .Case("matrixintrinsics", vector::VectorContractLowering::Matmul)
112 .Case("dot", vector::VectorContractLowering::Dot)
113 .Case("outerproduct", vector::VectorContractLowering::OuterProduct)
114 .Default(vector::VectorContractLowering::OuterProduct);
115 vector::VectorTransferSplit vectorTransferSplit =
116 llvm::StringSwitch<vector::VectorTransferSplit>(
117 splitVectorTransfersTo.getValue())
118 .Case("none", vector::VectorTransferSplit::None)
119 .Case("linalg-copy", vector::VectorTransferSplit::LinalgCopy)
120 .Case("vector-transfers", vector::VectorTransferSplit::VectorTransfer)
121 .Default(vector::VectorTransferSplit::None);
122
123 CodegenStrategy strategy;
124 strategy.tileIf<MatmulOp>(!tileSizes.empty(), tilingOptions)
125 .promoteIf<MatmulOp>(promote,
126 LinalgPromotionOptions()
127 .setAlignment(16)
128 .setUseFullTileBuffersByDefault(promoteFullTile))
129 .tileIf<MatmulOp>(!registerTileSizes.empty(), registerTilingOptions)
130 .promoteIf<MatmulOp>(registerPromote, LinalgPromotionOptions()
131 .setAlignment(16)
132 .setUseFullTileBuffersByDefault(
133 registerPromoteFullTile))
134 .vectorizeIf<MatmulOp>(vectorize)
135 .setVectorTransformsOptions(
136 vector::VectorTransformsOptions()
137 .setVectorTransformsOptions(vectorContractLowering)
138 .setVectorTransferSplit(vectorTransferSplit))
139 .setVectorTransferToSCFOptions(
140 VectorTransferToSCFOptions().setUnroll(unrollVectorTransfers));
141
142 strategy.transform(getFunction());
143 }
144
145 namespace mlir {
146 namespace test {
registerTestLinalgCodegenStrategy()147 void registerTestLinalgCodegenStrategy() {
148 PassRegistration<TestLinalgCodegenStrategy> testLinalgCodegenStrategyPass(
149 "test-linalg-codegen-strategy", "Test Linalg Codegen Strategy.");
150 }
151 } // namespace test
152 } // namespace mlir
153