1 //===- TestSparsification.cpp - Test sparsification of tensors ------------===//
2 //
3 // Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions.
4 // See https://llvm.org/LICENSE.txt for license information.
5 // SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception
6 //
7 //===----------------------------------------------------------------------===//
8 
9 #include "mlir/Dialect/Linalg/Transforms/Transforms.h"
10 #include "mlir/Pass/Pass.h"
11 #include "mlir/Transforms/GreedyPatternRewriteDriver.h"
12 
13 using namespace mlir;
14 
15 namespace {
16 
17 struct TestSparsification
18     : public PassWrapper<TestSparsification, FunctionPass> {
19 
20   TestSparsification() = default;
TestSparsification__anon14aa56c50111::TestSparsification21   TestSparsification(const TestSparsification &pass) {}
22 
23   Option<int32_t> parallelization{
24       *this, "parallelization-strategy",
25       llvm::cl::desc("Set the parallelization strategy"), llvm::cl::init(0)};
26 
27   Option<int32_t> vectorization{
28       *this, "vectorization-strategy",
29       llvm::cl::desc("Set the vectorization strategy"), llvm::cl::init(0)};
30 
31   Option<int32_t> vectorLength{
32       *this, "vl", llvm::cl::desc("Set the vector length"), llvm::cl::init(1)};
33 
34   Option<int32_t> ptrType{*this, "ptr-type",
35                           llvm::cl::desc("Set the pointer type"),
36                           llvm::cl::init(0)};
37 
38   Option<int32_t> indType{*this, "ind-type",
39                           llvm::cl::desc("Set the index type"),
40                           llvm::cl::init(0)};
41 
42   /// Registers all dialects required by testing.
getDependentDialects__anon14aa56c50111::TestSparsification43   void getDependentDialects(DialectRegistry &registry) const override {
44     registry.insert<scf::SCFDialect, vector::VectorDialect>();
45   }
46 
47   /// Returns parallelization strategy given on command line.
parallelOption__anon14aa56c50111::TestSparsification48   linalg::SparseParallelizationStrategy parallelOption() {
49     switch (parallelization) {
50     default:
51       return linalg::SparseParallelizationStrategy::kNone;
52     case 1:
53       return linalg::SparseParallelizationStrategy::kDenseOuterLoop;
54     case 2:
55       return linalg::SparseParallelizationStrategy::kAnyStorageOuterLoop;
56     case 3:
57       return linalg::SparseParallelizationStrategy::kDenseAnyLoop;
58     case 4:
59       return linalg::SparseParallelizationStrategy::kAnyStorageAnyLoop;
60     }
61   }
62 
63   /// Returns vectorization strategy given on command line.
vectorOption__anon14aa56c50111::TestSparsification64   linalg::SparseVectorizationStrategy vectorOption() {
65     switch (vectorization) {
66     default:
67       return linalg::SparseVectorizationStrategy::kNone;
68     case 1:
69       return linalg::SparseVectorizationStrategy::kDenseInnerLoop;
70     case 2:
71       return linalg::SparseVectorizationStrategy::kAnyStorageInnerLoop;
72     }
73   }
74 
75   /// Returns the requested integer type.
typeOption__anon14aa56c50111::TestSparsification76   linalg::SparseIntType typeOption(int32_t option) {
77     switch (option) {
78     default:
79       return linalg::SparseIntType::kNative;
80     case 1:
81       return linalg::SparseIntType::kI64;
82     case 2:
83       return linalg::SparseIntType::kI32;
84     }
85   }
86 
87   /// Runs the test on a function.
runOnFunction__anon14aa56c50111::TestSparsification88   void runOnFunction() override {
89     auto *ctx = &getContext();
90     OwningRewritePatternList patterns;
91     // Translate strategy flags to strategy options.
92     linalg::SparsificationOptions options(parallelOption(), vectorOption(),
93                                           vectorLength, typeOption(ptrType),
94                                           typeOption(indType));
95     // Apply rewriting.
96     linalg::populateSparsificationPatterns(ctx, patterns, options);
97     applyPatternsAndFoldGreedily(getFunction(), std::move(patterns));
98   }
99 };
100 
101 } // end anonymous namespace
102 
103 namespace mlir {
104 namespace test {
105 
registerTestSparsification()106 void registerTestSparsification() {
107   PassRegistration<TestSparsification> sparsificationPass(
108       "test-sparsification",
109       "Test automatic geneneration of sparse tensor code");
110 }
111 
112 } // namespace test
113 } // namespace mlir
114