1 //===- sdbm-api-test.cpp - Tests for SDBM expression APIs -----------------===//
2 //
3 // Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions.
4 // See https://llvm.org/LICENSE.txt for license information.
5 // SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception
6 //
7 //===----------------------------------------------------------------------===//
8
9 // RUN: mlir-sdbm-api-test | FileCheck %s
10
11 #include "mlir/Dialect/SDBM/SDBM.h"
12 #include "mlir/Dialect/SDBM/SDBMDialect.h"
13 #include "mlir/Dialect/SDBM/SDBMExpr.h"
14 #include "mlir/IR/MLIRContext.h"
15
16 #include "llvm/Support/raw_ostream.h"
17
18 #include "APITest.h"
19
20 using namespace mlir;
21
22
ctx()23 static MLIRContext *ctx() {
24 static thread_local MLIRContext context;
25 static thread_local bool once =
26 (context.getOrLoadDialect<SDBMDialect>(), true);
27 (void)once;
28 return &context;
29 }
30
dialect()31 static SDBMDialect *dialect() {
32 static thread_local SDBMDialect *d = nullptr;
33 if (!d) {
34 d = ctx()->getOrLoadDialect<SDBMDialect>();
35 }
36 return d;
37 }
38
dim(unsigned pos)39 static SDBMExpr dim(unsigned pos) { return SDBMDimExpr::get(dialect(), pos); }
40
symb(unsigned pos)41 static SDBMExpr symb(unsigned pos) {
42 return SDBMSymbolExpr::get(dialect(), pos);
43 }
44
namespace {

// Each test below builds an SDBM (or lists of SDBM expressions) and prints the
// result to llvm::outs(); the expected text is given by the FileCheck
// directives in the comments of each test, so the order and content of the
// printed output matter.
//
// Notation used in the comments and in the printed output: `x # k` is the
// stripe expression built by `stripe(x, k)`. SDBM::get takes the inequalities
// (each read as `expr <= 0`) as its first argument and the equalities (each
// read as `expr = 0`) as its second.
//
// NOTE(review): the expected matrices below are consistent with the entry in
// row `r`, column `c` being an upper bound on `c - r`, with "inf" meaning no
// bound (e.g. `d0 <= 3` appears as 3 in row `cst`, column `d0`). This is
// inferred from the expectations, not from SDBM.h — confirm before relying on
// it.

// Presumably brings the SDBMExpr arithmetic operators and `stripe` used below
// into scope (they are declared outside this file — confirm in SDBMExpr.h).
using namespace mlir::ops_assertions;

TEST_FUNC(SDBM_SingleConstraint) {
  // Build an SDBM defined by
  // d0 - 3 <= 0 <=> d0 <= 3.
  auto sdbm = SDBM::get(dim(0) - 3, llvm::None);

  // CHECK: cst d0
  // CHECK-NEXT: cst inf 3
  // CHECK-NEXT: d0 inf inf
  sdbm.print(llvm::outs());
}

TEST_FUNC(SDBM_Equality) {
  // Build an SDBM defined by
  //
  // d0 - d1 - 3 = 0
  // <=> {d0 - d1 - 3 <= 0 and d0 - d1 - 3 >= 0}
  // <=> {d0 - d1 <= 3 and d1 - d0 <= -3}.
  auto sdbm = SDBM::get(llvm::None, dim(0) - dim(1) - 3);

  // CHECK: cst d0 d1
  // CHECK-NEXT: cst inf inf inf
  // CHECK-NEXT: d0 inf inf -3
  // CHECK-NEXT: d1 inf 3 inf
  sdbm.print(llvm::outs());
}

TEST_FUNC(SDBM_TrivialSimplification) {
  // Build an SDBM defined by
  //
  // d0 - 3 <= 0 <=> d0 <= 3
  // d0 - 5 <= 0 <=> d0 <= 5
  //
  // which should get simplified on construction to only the former.
  auto sdbm = SDBM::get({dim(0) - 3, dim(0) - 5}, llvm::None);

  // CHECK: cst d0
  // CHECK-NEXT: cst inf 3
  // CHECK-NEXT: d0 inf inf
  sdbm.print(llvm::outs());
}

TEST_FUNC(SDBM_StripeInducedIneqs) {
  // Build an SDBM defined by d1 = d0 # 3, which induces the constraints
  //
  // d1 - d0 <= 0
  // d0 - d1 <= 3 - 1 = 2
  auto sdbm = SDBM::get(llvm::None, dim(1) - stripe(dim(0), 3));

  // The stripe equality itself is printed after the matrix.
  // CHECK: cst d0 d1
  // CHECK-NEXT: cst inf inf inf
  // CHECK-NEXT: d0 inf inf 0
  // CHECK-NEXT: d1 inf 2 0
  // CHECK-NEXT: d1 = d0 # 3
  sdbm.print(llvm::outs());
}

TEST_FUNC(SDBM_StripeTemporaries) {
  // Build an SDBM defined by d0 # 3 <= 0, which creates a temporary
  // t0 = d0 # 3 leading to a constraint t0 <= 0 and the stripe-induced
  // constraints
  //
  // t0 - d0 <= 0
  // d0 - t0 <= 3 - 1 = 2
  auto sdbm = SDBM::get(stripe(dim(0), 3), llvm::None);

  // CHECK: cst d0 t0
  // CHECK-NEXT: cst inf inf 0
  // CHECK-NEXT: d0 inf inf 0
  // CHECK-NEXT: t0 inf 2 inf
  // CHECK-NEXT: t0 = d0 # 3
  sdbm.print(llvm::outs());
}

TEST_FUNC(SDBM_ElideInducedInequalities) {
  // Build an SDBM defined by a single stripe equality d0 = s0 # 3 and make sure
  // the induced inequalities are not present after converting the SDBM back
  // into lists of expressions.
  auto sdbm = SDBM::get(llvm::None, {dim(0) - stripe(symb(0), 3)});

  // Convert back; the output parameters are passed as (inequalities,
  // equalities).
  SmallVector<SDBMExpr, 4> eqs, ineqs;
  sdbm.getSDBMExpressions(dialect(), ineqs, eqs);
  // No inequality is expected here, so only an empty line should be printed.
  // CHECK-EMPTY:
  for (auto ineq : ineqs)
    ineq.print(llvm::outs() << '\n');
  llvm::outs() << "\n";

  // CHECK: d0 - s0 # 3
  // CHECK-EMPTY:
  for (auto eq : eqs)
    eq.print(llvm::outs() << '\n');
  llvm::outs() << "\n\n";
}

TEST_FUNC(SDBM_StripeTightening) {
  // Build an SDBM defined by
  //
  // d0 = s0 # 3 # 5
  // s0 # 3 # 5 - d1 + 42 = 0
  // s0 # 3 - d0 <= 2
  //
  // where the last inequality is tighter than that induced by the first stripe
  // equality (s0 # 3 - d0 <= 5 - 1 = 4). Check that the conversion from SDBM
  // back to the lists of constraints conserves both the stripe equality and the
  // tighter inequality.
  auto s = stripe(stripe(symb(0), 3), 5);
  auto tight = stripe(symb(0), 3) - dim(0) - 2;
  auto sdbm = SDBM::get({tight}, {s - dim(0), s - dim(1) + 42});

  SmallVector<SDBMExpr, 4> eqs, ineqs;
  sdbm.getSDBMExpressions(dialect(), ineqs, eqs);
  // CHECK: s0 # 3 + -2 - d0
  // CHECK-EMPTY:
  for (auto ineq : ineqs)
    ineq.print(llvm::outs() << '\n');
  llvm::outs() << "\n";

  // The two equalities are matched order-insensitively, since the order in
  // which they come back is not pinned down by this test.
  // CHECK-DAG: d1 + -42 - d0
  // CHECK-DAG: d0 - s0 # 3 # 5
  for (auto eq : eqs)
    eq.print(llvm::outs() << '\n');
  llvm::outs() << "\n\n";
}

TEST_FUNC(SDBM_StripeTransitive) {
  // Build an SDBM defined by
  //
  // d0 = d1 # 3
  // d0 = d2 # 7
  //
  // where the same dimension is declared equal to two stripe expressions over
  // different variables. This is practically handled by introducing a
  // temporary variable for the second stripe expression and adding an equality
  // constraint between this variable and the original dimension variable.
  auto sdbm = SDBM::get(
      llvm::None, {stripe(dim(1), 3) - dim(0), stripe(dim(2), 7) - dim(0)});

  // CHECK: cst d0 d1 d2 t0
  // CHECK-NEXT: cst inf inf inf inf inf
  // CHECK-NEXT: d0 inf 0 2 inf 0
  // CHECK-NEXT: d1 inf 0 inf inf inf
  // CHECK-NEXT: d2 inf inf inf inf 0
  // CHECK-NEXT: t0 inf 0 inf 6 inf
  // CHECK-NEXT: t0 = d2 # 7
  // CHECK-NEXT: d0 = d1 # 3
  sdbm.print(llvm::outs());
}

} // end namespace
197
main()198 int main() {
199 RUN_TESTS();
200 return 0;
201 }
202