1 //===- TestSparsification.cpp - Test sparsification of tensors ------------===//
2 //
3 // Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions.
4 // See https://llvm.org/LICENSE.txt for license information.
5 // SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception
6 //
7 //===----------------------------------------------------------------------===//
8
9 #include "mlir/Dialect/Linalg/Transforms/Transforms.h"
10 #include "mlir/Pass/Pass.h"
11 #include "mlir/Transforms/GreedyPatternRewriteDriver.h"
12
13 using namespace mlir;
14
15 namespace {
16
17 struct TestSparsification
18 : public PassWrapper<TestSparsification, FunctionPass> {
19
20 TestSparsification() = default;
TestSparsification__anon14aa56c50111::TestSparsification21 TestSparsification(const TestSparsification &pass) {}
22
23 Option<int32_t> parallelization{
24 *this, "parallelization-strategy",
25 llvm::cl::desc("Set the parallelization strategy"), llvm::cl::init(0)};
26
27 Option<int32_t> vectorization{
28 *this, "vectorization-strategy",
29 llvm::cl::desc("Set the vectorization strategy"), llvm::cl::init(0)};
30
31 Option<int32_t> vectorLength{
32 *this, "vl", llvm::cl::desc("Set the vector length"), llvm::cl::init(1)};
33
34 Option<int32_t> ptrType{*this, "ptr-type",
35 llvm::cl::desc("Set the pointer type"),
36 llvm::cl::init(0)};
37
38 Option<int32_t> indType{*this, "ind-type",
39 llvm::cl::desc("Set the index type"),
40 llvm::cl::init(0)};
41
42 /// Registers all dialects required by testing.
getDependentDialects__anon14aa56c50111::TestSparsification43 void getDependentDialects(DialectRegistry ®istry) const override {
44 registry.insert<scf::SCFDialect, vector::VectorDialect>();
45 }
46
47 /// Returns parallelization strategy given on command line.
parallelOption__anon14aa56c50111::TestSparsification48 linalg::SparseParallelizationStrategy parallelOption() {
49 switch (parallelization) {
50 default:
51 return linalg::SparseParallelizationStrategy::kNone;
52 case 1:
53 return linalg::SparseParallelizationStrategy::kDenseOuterLoop;
54 case 2:
55 return linalg::SparseParallelizationStrategy::kAnyStorageOuterLoop;
56 case 3:
57 return linalg::SparseParallelizationStrategy::kDenseAnyLoop;
58 case 4:
59 return linalg::SparseParallelizationStrategy::kAnyStorageAnyLoop;
60 }
61 }
62
63 /// Returns vectorization strategy given on command line.
vectorOption__anon14aa56c50111::TestSparsification64 linalg::SparseVectorizationStrategy vectorOption() {
65 switch (vectorization) {
66 default:
67 return linalg::SparseVectorizationStrategy::kNone;
68 case 1:
69 return linalg::SparseVectorizationStrategy::kDenseInnerLoop;
70 case 2:
71 return linalg::SparseVectorizationStrategy::kAnyStorageInnerLoop;
72 }
73 }
74
75 /// Returns the requested integer type.
typeOption__anon14aa56c50111::TestSparsification76 linalg::SparseIntType typeOption(int32_t option) {
77 switch (option) {
78 default:
79 return linalg::SparseIntType::kNative;
80 case 1:
81 return linalg::SparseIntType::kI64;
82 case 2:
83 return linalg::SparseIntType::kI32;
84 }
85 }
86
87 /// Runs the test on a function.
runOnFunction__anon14aa56c50111::TestSparsification88 void runOnFunction() override {
89 auto *ctx = &getContext();
90 OwningRewritePatternList patterns;
91 // Translate strategy flags to strategy options.
92 linalg::SparsificationOptions options(parallelOption(), vectorOption(),
93 vectorLength, typeOption(ptrType),
94 typeOption(indType));
95 // Apply rewriting.
96 linalg::populateSparsificationPatterns(ctx, patterns, options);
97 applyPatternsAndFoldGreedily(getFunction(), std::move(patterns));
98 }
99 };
100
101 } // end anonymous namespace
102
103 namespace mlir {
104 namespace test {
105
registerTestSparsification()106 void registerTestSparsification() {
107 PassRegistration<TestSparsification> sparsificationPass(
108 "test-sparsification",
109 "Test automatic geneneration of sparse tensor code");
110 }
111
112 } // namespace test
113 } // namespace mlir
114