1 //===- SDBMTest.cpp - SDBM expression unit tests --------------------------===//
2 //
3 // Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions.
4 // See https://llvm.org/LICENSE.txt for license information.
5 // SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception
6 //
7 //===----------------------------------------------------------------------===//
8 
9 #include "mlir/Dialect/SDBM/SDBM.h"
10 #include "mlir/Dialect/SDBM/SDBMDialect.h"
11 #include "mlir/Dialect/SDBM/SDBMExpr.h"
12 #include "mlir/IR/AffineExpr.h"
13 #include "mlir/IR/MLIRContext.h"
14 #include "gtest/gtest.h"
15 
16 #include "llvm/ADT/DenseSet.h"
17 
18 using namespace mlir;
19 
20 
ctx()21 static MLIRContext *ctx() {
22   static thread_local MLIRContext context;
23   context.getOrLoadDialect<SDBMDialect>();
24   return &context;
25 }
26 
dialect()27 static SDBMDialect *dialect() {
28   static thread_local SDBMDialect *d = nullptr;
29   if (!d) {
30     d = ctx()->getOrLoadDialect<SDBMDialect>();
31   }
32   return d;
33 }
34 
dim(unsigned pos)35 static SDBMExpr dim(unsigned pos) { return SDBMDimExpr::get(dialect(), pos); }
36 
symb(unsigned pos)37 static SDBMExpr symb(unsigned pos) {
38   return SDBMSymbolExpr::get(dialect(), pos);
39 }
40 
41 namespace {
42 
43 using namespace mlir::ops_assertions;
44 
TEST(SDBMOperators, Add) {
  // Adding an integer to a dimension yields a sum whose left operand is the
  // dimension and whose constant operand holds the integer.
  auto d0Plus42 = dim(0) + 42;
  auto asSum = d0Plus42.dyn_cast<SDBMSumExpr>();
  ASSERT_TRUE(asSum);
  EXPECT_EQ(asSum.getLHS(), dim(0));
  EXPECT_EQ(asSum.getRHS().getValue(), 42);
}
52 
TEST(SDBMOperators, AddFolding) {
  // Adding an integer to a constant folds into a single constant.
  auto constant = SDBMConstantExpr::get(dialect(), 2) + 42;
  auto constantExpr = constant.dyn_cast<SDBMConstantExpr>();
  ASSERT_TRUE(constantExpr);
  EXPECT_EQ(constantExpr.getValue(), 44);

  // Repeated integer additions to a dimension fold into a single sum.
  auto expr = (dim(0) + 10) + 32;
  auto sumExpr = expr.dyn_cast<SDBMSumExpr>();
  ASSERT_TRUE(sumExpr);
  EXPECT_EQ(sumExpr.getRHS().getValue(), 42);

  // Adding a negated dimension produces a difference, whichever side the
  // negation is on.
  expr = dim(0) + SDBMNegExpr::get(SDBMDimExpr::get(dialect(), 1));
  auto diffExpr = expr.dyn_cast<SDBMDiffExpr>();
  ASSERT_TRUE(diffExpr);
  EXPECT_EQ(diffExpr.getLHS(), dim(0));
  EXPECT_EQ(diffExpr.getRHS(), dim(1));

  auto inverted = SDBMNegExpr::get(SDBMDimExpr::get(dialect(), 1)) + dim(0);
  EXPECT_EQ(inverted, expr);

  // Check that opposite values cancel each other, and that we elide the zero
  // constant.
  expr = dim(0) + 42;
  auto onlyDim = expr - 42;
  EXPECT_EQ(onlyDim, dim(0));

  // Check that we can sink a constant under a negation:
  // -(d0 + 2) + 10 == -(d0 - 8).
  expr = -(dim(0) + 2);
  auto negatedSum = (expr + 10).dyn_cast<SDBMNegExpr>();
  ASSERT_TRUE(negatedSum);
  auto sum = negatedSum.getVar().dyn_cast<SDBMSumExpr>();
  ASSERT_TRUE(sum);
  EXPECT_EQ(sum.getRHS().getValue(), -8);

  // Sum with zero is the same as the original expression.
  EXPECT_EQ(dim(0) + 0, dim(0));

  // Sum of opposite differences is zero.
  auto diffOfDiffs =
      ((dim(0) - dim(1)) + (dim(1) - dim(0))).dyn_cast<SDBMConstantExpr>();
  // Verify the cast succeeded before calling getValue(), as is done for every
  // other dyn_cast in this file; calling getValue() on a null handle is not a
  // reportable test failure.
  ASSERT_TRUE(diffOfDiffs);
  EXPECT_EQ(diffOfDiffs.getValue(), 0);
}
95 
TEST(SDBMOperators, AddNegativeTerms) {
  const int64_t cstA = 7;
  const int64_t cstB = -5;
  auto d0 = SDBMDimExpr::get(dialect(), 0);
  auto d1 = SDBMDimExpr::get(dialect(), 1);

  // In each sum below, one of the two dimensions appears once with each sign.
  // It must cancel, and the remaining expression must still be an SDBM.
  EXPECT_EQ(-(d0 + cstA) + ((d0 + cstB) - d1), -(d1 + (cstA - cstB)));
  EXPECT_EQ((d0 + cstA) + ((d1 + cstB) - d0), (d1 + cstB) + cstA);
  EXPECT_EQ(((d0 + cstA) - d1) + (-(d0 + cstB)), -(d1 + (cstB - cstA)));
  EXPECT_EQ(((d0 + cstA) - d1) + (d1 + cstB), (d0 + cstA) + cstB);
}
109 
TEST(SDBMOperators, Diff) {
  // Subtracting one dimension from another yields a difference expression
  // that keeps both operands in order.
  auto d0MinusD1 = (dim(0) - dim(1)).dyn_cast<SDBMDiffExpr>();
  ASSERT_TRUE(d0MinusD1);
  EXPECT_EQ(d0MinusD1.getLHS(), dim(0));
  EXPECT_EQ(d0MinusD1.getRHS(), dim(1));
}
117 
TEST(SDBMOperators, DiffFolding) {
  // Subtracting an integer from a constant folds into a single constant.
  auto folded =
      (SDBMConstantExpr::get(dialect(), 10) - 3).dyn_cast<SDBMConstantExpr>();
  ASSERT_TRUE(folded);
  EXPECT_EQ(folded.getValue(), 7);

  // Subtracting an integer from a dimension gives a sum with a negative
  // constant.
  auto shifted = (dim(0) - 3).dyn_cast<SDBMSumExpr>();
  ASSERT_TRUE(shifted);
  EXPECT_EQ(shifted.getRHS().getValue(), -3);

  // A dimension minus itself is the zero constant.
  auto zeroCst = (dim(0) - dim(0)).dyn_cast<SDBMConstantExpr>();
  ASSERT_TRUE(zeroCst);
  EXPECT_EQ(zeroCst.getValue(), 0);

  // The constant parts of a difference of sums are merged on the left:
  // (d0 - 3) - (d1 - 5) = (d0 + 2) - d1
  auto diffOfSums = ((dim(0) - 3) - (dim(1) - 5)).dyn_cast<SDBMDiffExpr>();
  ASSERT_TRUE(diffOfSums);
  auto mergedLHS = diffOfSums.getLHS().dyn_cast<SDBMSumExpr>();
  ASSERT_TRUE(mergedLHS);
  EXPECT_EQ(mergedLHS.getLHS(), dim(0));
  EXPECT_EQ(mergedLHS.getRHS().getValue(), 2);
  EXPECT_EQ(diffOfSums.getRHS(), dim(1));

  // The same dimension with opposite signs cancels, leaving only the constant.
  auto leftover = ((dim(0) + 42) - dim(0)).dyn_cast<SDBMConstantExpr>();
  ASSERT_TRUE(leftover);
  EXPECT_EQ(leftover.getValue(), 42);

  // When a difference is added to a term that matches one of its operands
  // with the opposite sign, the pair cancels.
  EXPECT_EQ(-dim(0) + (dim(0) - dim(1)), -dim(1));
  EXPECT_EQ((dim(0) - dim(1)) + (-dim(0)), -dim(1));
  EXPECT_EQ((dim(0) - dim(1)) + dim(1), dim(0));
  EXPECT_EQ(dim(0) + (dim(1) - dim(0)), dim(1));

  // Full cancellation may produce a zero constant at the top level.
  auto topZero = (-symb(1) + symb(1)).dyn_cast<SDBMConstantExpr>();
  ASSERT_TRUE(topZero);
  EXPECT_EQ(topZero.getValue(), 0);
}
164 
TEST(SDBMOperators, Negate) {
  auto d0Plus3 = dim(0) + 3;
  // Unary minus wraps its operand in a negation expression.
  auto negation = (-d0Plus3).dyn_cast<SDBMNegExpr>();
  ASSERT_TRUE(negation);
  EXPECT_EQ(negation.getVar(), d0Plus3);
}
171 
TEST(SDBMOperators, Stripe) {
  // stripe(x, B) builds the stripe expression x # B.
  auto striped = stripe(dim(0), 3).dyn_cast<SDBMStripeExpr>();
  ASSERT_TRUE(striped);
  EXPECT_EQ(striped.getLHS(), dim(0));
  EXPECT_EQ(striped.getStripeFactor().getValue(), 3);
}
179 
TEST(SDBM, RoundTripEqs) {
  // Build an SDBM defined by
  //
  //   d0 = s0 # 3 # 5
  //   s0 # 3 # 5 - d1 + 42 = 0
  //
  // and perform a double round-trip between the "list of equalities" and SDBM
  // representation.  After the first round-trip, the equalities may be
  // different due to simplification or equivalent substitutions (e.g., the
  // second equality may become d0 - d1 + 42 = 0).  However, there should not
  // be any further simplification after the second round-trip.

  // Build the SDBM from a pair of equalities and extract back the lists of
  // inequalities and equalities.  Check that all equalities are properly
  // detected and none of them decayed into inequalities.
  auto s = stripe(stripe(symb(0), 3), 5);
  auto sdbm = SDBM::get(llvm::None, {s - dim(0), s - dim(1) + 42});
  SmallVector<SDBMExpr, 4> eqs, ineqs;
  sdbm.getSDBMExpressions(dialect(), ineqs, eqs);
  ASSERT_TRUE(ineqs.empty());

  // Do the second round-trip.  It starts from equalities only, so it must not
  // produce inequalities either; previously this was never checked.
  auto sdbm2 = SDBM::get(llvm::None, eqs);
  SmallVector<SDBMExpr, 4> eqs2, ineqs2;
  sdbm2.getSDBMExpressions(dialect(), ineqs2, eqs2);
  ASSERT_TRUE(ineqs2.empty());
  ASSERT_EQ(eqs.size(), eqs2.size());

  // Check that the sets of equalities are equal, their order is not relevant.
  llvm::DenseSet<SDBMExpr> eqSet, eq2Set;
  eqSet.insert(eqs.begin(), eqs.end());
  eq2Set.insert(eqs2.begin(), eqs2.end());
  EXPECT_EQ(eqSet, eq2Set);
}
213 
TEST(SDBMExpr, Constant) {
  // A constant reports the value it was built with.
  auto fortyTwo = SDBMConstantExpr::get(dialect(), 42);
  EXPECT_EQ(fortyTwo.getValue(), 42);

  // Building the same constant a second time gives an equal expression.
  auto fortyTwoAgain = SDBMConstantExpr::get(dialect(), 42);
  EXPECT_EQ(fortyTwo, fortyTwoAgain);

  // Converting to the base SDBMExpr keeps the dynamic kind.
  SDBMExpr asBase = fortyTwo;
  EXPECT_TRUE(asBase.isa<SDBMConstantExpr>());
}
227 
TEST(SDBMExpr, Dim) {
  // A dimension expression reports its position.
  auto firstDim = SDBMDimExpr::get(dialect(), 0);
  EXPECT_EQ(firstDim.getPosition(), 0u);

  // Building the same dimension a second time gives an equal expression.
  auto firstDimAgain = SDBMDimExpr::get(dialect(), 0);
  EXPECT_EQ(firstDim, firstDimAgain);

  // Seen through the base SDBMExpr, a dimension matches every class in its
  // hierarchy.
  SDBMExpr asBase = firstDim;
  EXPECT_TRUE(asBase.isa<SDBMDimExpr>());
  EXPECT_TRUE(asBase.isa<SDBMInputExpr>());
  EXPECT_TRUE(asBase.isa<SDBMTermExpr>());
  EXPECT_TRUE(asBase.isa<SDBMDirectExpr>());
  EXPECT_TRUE(asBase.isa<SDBMVaryingExpr>());

  // A symbol at the same position is a different expression.
  auto firstSymbol = SDBMSymbolExpr::get(dialect(), 0);
  EXPECT_NE(firstDim, firstSymbol);
  EXPECT_FALSE(firstDim.isa<SDBMSymbolExpr>());
}
251 
TEST(SDBMExpr, Symbol) {
  // We can create symbol expressions and query them.
  auto expr = SDBMSymbolExpr::get(dialect(), 0);
  EXPECT_EQ(expr.getPosition(), 0u);

  // Two separately created symbols with the same position are trivially equal.
  auto expr2 = SDBMSymbolExpr::get(dialect(), 0);
  EXPECT_EQ(expr, expr2);

  // Hierarchy is okay.
  auto generic = static_cast<SDBMExpr>(expr);
  EXPECT_TRUE(generic.isa<SDBMSymbolExpr>());
  EXPECT_TRUE(generic.isa<SDBMInputExpr>());
  EXPECT_TRUE(generic.isa<SDBMTermExpr>());
  EXPECT_TRUE(generic.isa<SDBMDirectExpr>());
  EXPECT_TRUE(generic.isa<SDBMVaryingExpr>());

  // Symbols are not dimensions, even when they share a position.  The
  // variable is named for what it holds (a dimension expression).
  auto dimension = SDBMDimExpr::get(dialect(), 0);
  EXPECT_NE(expr, dimension);
  EXPECT_FALSE(expr.isa<SDBMDimExpr>());
}
274 
TEST(SDBMExpr, Stripe) {
  auto two = SDBMConstantExpr::get(dialect(), 2);
  auto zero = SDBMConstantExpr::get(dialect(), 0);
  auto s0 = SDBMSymbolExpr::get(dialect(), 0);

  // A stripe exposes its striped operand and its stripe factor.
  auto s0Stripe2 = SDBMStripeExpr::get(s0, two);
  EXPECT_EQ(s0Stripe2.getLHS(), s0);
  EXPECT_EQ(s0Stripe2.getStripeFactor(), two);

  // Rebuilding the stripe from a freshly created, equal operand gives an
  // equal expression.
  auto rebuilt = SDBMStripeExpr::get(SDBMSymbolExpr::get(dialect(), 0), two);
  EXPECT_EQ(s0Stripe2, rebuilt);

  // A stripe may itself be striped; this only checks that construction works.
  (void)SDBMStripeExpr::get(s0Stripe2, SDBMConstantExpr::get(dialect(), 4));

  // A stripe factor that is not positive must abort with a diagnostic.
  EXPECT_DEATH(SDBMStripeExpr::get(s0, zero), "non-positive");

  // The striped operand may be a sum; again only construction is checked.
  (void)SDBMStripeExpr::get(SDBMSumExpr::get(s0, two), two);

  // Seen through the base SDBMExpr, a stripe matches every class in its
  // hierarchy.
  SDBMExpr asBase = s0Stripe2;
  EXPECT_TRUE(asBase.isa<SDBMStripeExpr>());
  EXPECT_TRUE(asBase.isa<SDBMTermExpr>());
  EXPECT_TRUE(asBase.isa<SDBMDirectExpr>());
  EXPECT_TRUE(asBase.isa<SDBMVaryingExpr>());
}
306 
TEST(SDBMExpr, Neg) {
  auto two = SDBMConstantExpr::get(dialect(), 2);
  auto s0 = SDBMSymbolExpr::get(dialect(), 0);
  auto s0Stripe2 = SDBMStripeExpr::get(s0, two);

  // A negation can wrap either a plain symbol or a stripe, and it exposes the
  // wrapped expression.
  auto negS0 = SDBMNegExpr::get(s0);
  EXPECT_EQ(negS0.getVar(), s0);
  auto negStripe = SDBMNegExpr::get(s0Stripe2);
  EXPECT_EQ(negStripe.getVar(), s0Stripe2);

  // Negations of the same operand compare equal.
  EXPECT_EQ(negS0, SDBMNegExpr::get(s0));

  // Converting to the base SDBMExpr keeps the dynamic kind.
  SDBMExpr asBase = negS0;
  EXPECT_TRUE(asBase.isa<SDBMNegExpr>());
  EXPECT_TRUE(asBase.isa<SDBMVaryingExpr>());
}
326 
TEST(SDBMExpr, Sum) {
  auto two = SDBMConstantExpr::get(dialect(), 2);
  auto s0 = SDBMSymbolExpr::get(dialect(), 0);
  auto s0Stripe2 = SDBMStripeExpr::get(s0, two);

  // A sum of an expression and a constant exposes both operands, whether the
  // left operand is a symbol or a stripe.
  auto s0Plus2 = SDBMSumExpr::get(s0, two);
  EXPECT_EQ(s0Plus2.getLHS(), s0);
  EXPECT_EQ(s0Plus2.getRHS(), two);
  auto stripePlus2 = SDBMSumExpr::get(s0Stripe2, two);
  EXPECT_EQ(stripePlus2.getLHS(), s0Stripe2);
  EXPECT_EQ(stripePlus2.getRHS(), two);

  // Sums built from the same operands compare equal.
  EXPECT_EQ(s0Plus2, SDBMSumExpr::get(s0, two));

  // Seen through the base SDBMExpr, a sum matches every class in its
  // hierarchy.
  SDBMExpr asBase = s0Plus2;
  EXPECT_TRUE(asBase.isa<SDBMSumExpr>());
  EXPECT_TRUE(asBase.isa<SDBMDirectExpr>());
  EXPECT_TRUE(asBase.isa<SDBMVaryingExpr>());
}
349 
TEST(SDBMExpr, Diff) {
  auto cst2 = SDBMConstantExpr::get(dialect(), 2);
  auto var = SDBMSymbolExpr::get(dialect(), 0);
  auto stripe = SDBMStripeExpr::get(var, cst2);

  // We can create difference expressions and query them, with either operand
  // order.
  auto expr = SDBMDiffExpr::get(var, stripe);
  EXPECT_EQ(expr.getLHS(), var);
  EXPECT_EQ(expr.getRHS(), stripe);
  auto expr2 = SDBMDiffExpr::get(stripe, var);
  EXPECT_EQ(expr2.getLHS(), stripe);
  EXPECT_EQ(expr2.getRHS(), var);

  // Difference expressions are trivially comparable.
  EXPECT_EQ(expr, SDBMDiffExpr::get(var, stripe));

  // Hierarchy is okay.
  auto generic = static_cast<SDBMExpr>(expr);
  EXPECT_TRUE(generic.isa<SDBMDiffExpr>());
  EXPECT_TRUE(generic.isa<SDBMVaryingExpr>());
}
371 
TEST(SDBMExpr, AffineRoundTrip) {
  // Build an expression (s0 - s0 # 2)
  auto cst2 = SDBMConstantExpr::get(dialect(), 2);
  auto var = SDBMSymbolExpr::get(dialect(), 0);
  auto stripe = SDBMStripeExpr::get(var, cst2);
  auto expr = SDBMDiffExpr::get(var, stripe);

  // Check that it can be converted to AffineExpr and back, i.e. stripe
  // detection works correctly.
  Optional<SDBMExpr> roundtripped =
      SDBMExpr::tryConvertAffineExpr(expr.getAsAffineExpr());
  ASSERT_TRUE(roundtripped.hasValue());
  EXPECT_EQ(roundtripped, static_cast<SDBMExpr>(expr));

  // Check that (s0 # 2 # 5) can be converted to AffineExpr, i.e. stripe
  // detection supports nested expressions.
  auto cst5 = SDBMConstantExpr::get(dialect(), 5);
  auto outerStripe = SDBMStripeExpr::get(stripe, cst5);
  roundtripped = SDBMExpr::tryConvertAffineExpr(outerStripe.getAsAffineExpr());
  ASSERT_TRUE(roundtripped.hasValue());
  EXPECT_EQ(roundtripped, static_cast<SDBMExpr>(outerStripe));

  // Check that ((s0 + 2) # 5) can be round-tripped through AffineExpr, i.e.
  // stripe detection supports sum expressions.
  auto inner = SDBMSumExpr::get(var, cst2);
  auto stripeSum = SDBMStripeExpr::get(inner, cst5);
  roundtripped = SDBMExpr::tryConvertAffineExpr(stripeSum.getAsAffineExpr());
  ASSERT_TRUE(roundtripped.hasValue());
  EXPECT_EQ(roundtripped, static_cast<SDBMExpr>(stripeSum));

  // Check that (s0 # 2 # 5 + 2) - s0 # 2 can be converted as an example of a
  // deeper expression tree.
  auto sum = SDBMSumExpr::get(outerStripe, cst2);
  auto diff = SDBMDiffExpr::get(sum, stripe);
  roundtripped = SDBMExpr::tryConvertAffineExpr(diff.getAsAffineExpr());
  ASSERT_TRUE(roundtripped.hasValue());
  EXPECT_EQ(roundtripped, static_cast<SDBMExpr>(diff));

  // Check a nested stripe-sum combination:
  // ((s0 + 2) # 5 + 2) # 7 - s0 # 2.
  auto cst7 = SDBMConstantExpr::get(dialect(), 7);
  auto nestedStripe =
      SDBMStripeExpr::get(SDBMSumExpr::get(stripeSum, cst2), cst7);
  diff = SDBMDiffExpr::get(nestedStripe, stripe);
  roundtripped = SDBMExpr::tryConvertAffineExpr(diff.getAsAffineExpr());
  ASSERT_TRUE(roundtripped.hasValue());
  EXPECT_EQ(roundtripped, static_cast<SDBMExpr>(diff));
}
419 
TEST(SDBMExpr, MatchStripeMulPattern) {
  // Make sure conversion from AffineExpr recognizes the multiplicative stripe
  // pattern (x floordiv B) * B == x # B.  The product is built with the
  // constant as the left operand, B * (x floordiv B), with x = d0, B = 42.
  auto cst = getAffineConstantExpr(42, ctx());
  // Named `d0` rather than `dim` so it does not shadow the file-level dim()
  // helper, which is used below to check the result.
  auto d0 = getAffineDimExpr(0, ctx());
  auto floor = d0.floorDiv(cst);
  auto mul = cst * floor;
  Optional<SDBMExpr> converted = SDBMStripeExpr::tryConvertAffineExpr(mul);
  ASSERT_TRUE(converted.hasValue());
  auto stripeExpr = converted->dyn_cast<SDBMStripeExpr>();
  ASSERT_TRUE(stripeExpr);
  // Check the operands of the recognized stripe, not only its kind.
  EXPECT_EQ(stripeExpr.getLHS(), dim(0));
  EXPECT_EQ(stripeExpr.getStripeFactor().getValue(), 42);
}
431 
TEST(SDBMExpr, NonSDBM) {
  auto d0 = getAffineDimExpr(0, ctx());
  auto d1 = getAffineDimExpr(1, ctx());
  auto twoVarSum = d0 + d1;
  auto two = getAffineConstantExpr(2, ctx());
  auto scaled = d0 * two;
  auto ceiled = d1.ceilDiv(two);

  // None of these affine expressions has an SDBM counterpart:
  // d0 + d1 adds two variables,
  EXPECT_FALSE(SDBMExpr::tryConvertAffineExpr(twoVarSum).hasValue());
  // d0 * 2 scales a variable by a coefficient other than 1 or -1,
  EXPECT_FALSE(SDBMExpr::tryConvertAffineExpr(scaled).hasValue());
  // and ceildiv is not supported.
  EXPECT_FALSE(SDBMExpr::tryConvertAffineExpr(ceiled).hasValue());
}
448 
449 } // end namespace
450