1 //===- SDBMExpr.cpp - MLIR SDBM Expression implementation -----------------===//
2 //
3 // Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions.
4 // See https://llvm.org/LICENSE.txt for license information.
5 // SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception
6 //
7 //===----------------------------------------------------------------------===//
8 //
9 // A striped difference-bound matrix (SDBM) expression is a constant expression,
10 // an identifier, a binary expression with constant RHS and +, stripe operators
11 // or a difference expression between two identifiers.
12 //
13 //===----------------------------------------------------------------------===//
14 
15 #include "mlir/Dialect/SDBM/SDBMExpr.h"
16 #include "SDBMExprDetail.h"
17 #include "mlir/Dialect/SDBM/SDBMDialect.h"
18 #include "mlir/IR/AffineExpr.h"
19 #include "mlir/IR/AffineExprVisitor.h"
20 
21 #include "llvm/Support/raw_ostream.h"
22 
23 using namespace mlir;
24 
25 namespace {
26 /// A simple compositional matcher for AffineExpr
27 ///
28 /// Example usage:
29 ///
30 /// ```c++
31 ///    AffineExprMatcher x, C, m;
32 ///    AffineExprMatcher pattern1 = ((x % C) * m) + x;
33 ///    AffineExprMatcher pattern2 = x + ((x % C) * m);
34 ///    if (pattern1.match(expr) || pattern2.match(expr)) {
35 ///      ...
36 ///    }
37 /// ```
38 class AffineExprMatcherStorage;
39 class AffineExprMatcher {
40 public:
41   AffineExprMatcher();
42   AffineExprMatcher(const AffineExprMatcher &other);
43 
operator +(AffineExprMatcher other)44   AffineExprMatcher operator+(AffineExprMatcher other) {
45     return AffineExprMatcher(AffineExprKind::Add, *this, other);
46   }
operator *(AffineExprMatcher other)47   AffineExprMatcher operator*(AffineExprMatcher other) {
48     return AffineExprMatcher(AffineExprKind::Mul, *this, other);
49   }
floorDiv(AffineExprMatcher other)50   AffineExprMatcher floorDiv(AffineExprMatcher other) {
51     return AffineExprMatcher(AffineExprKind::FloorDiv, *this, other);
52   }
ceilDiv(AffineExprMatcher other)53   AffineExprMatcher ceilDiv(AffineExprMatcher other) {
54     return AffineExprMatcher(AffineExprKind::CeilDiv, *this, other);
55   }
operator %(AffineExprMatcher other)56   AffineExprMatcher operator%(AffineExprMatcher other) {
57     return AffineExprMatcher(AffineExprKind::Mod, *this, other);
58   }
59 
60   AffineExpr match(AffineExpr expr);
61   AffineExpr matched();
62   Optional<int> getMatchedConstantValue();
63 
64 private:
65   AffineExprMatcher(AffineExprKind k, AffineExprMatcher a, AffineExprMatcher b);
66   AffineExprKind kind; // only used to match in binary op cases.
67   // A shared_ptr allows multiple references to same matcher storage without
68   // worrying about ownership or dealing with an arena. To be cleaned up if we
69   // go with this.
70   std::shared_ptr<AffineExprMatcherStorage> storage;
71 };
72 
73 class AffineExprMatcherStorage {
74 public:
AffineExprMatcherStorage()75   AffineExprMatcherStorage() {}
AffineExprMatcherStorage(const AffineExprMatcherStorage & other)76   AffineExprMatcherStorage(const AffineExprMatcherStorage &other)
77       : subExprs(other.subExprs.begin(), other.subExprs.end()),
78         matched(other.matched) {}
AffineExprMatcherStorage(ArrayRef<AffineExprMatcher> exprs)79   AffineExprMatcherStorage(ArrayRef<AffineExprMatcher> exprs)
80       : subExprs(exprs.begin(), exprs.end()) {}
AffineExprMatcherStorage(AffineExprMatcher & a,AffineExprMatcher & b)81   AffineExprMatcherStorage(AffineExprMatcher &a, AffineExprMatcher &b)
82       : subExprs({a, b}) {}
83   SmallVector<AffineExprMatcher, 0> subExprs;
84   AffineExpr matched;
85 };
86 } // namespace
87 
AffineExprMatcher()88 AffineExprMatcher::AffineExprMatcher()
89     : kind(AffineExprKind::Constant), storage(new AffineExprMatcherStorage()) {}
90 
/// Copies the kind and shares the storage of `other`, so both matchers observe
/// the same bound expression. This is plain member-wise copying.
AffineExprMatcher::AffineExprMatcher(const AffineExprMatcher &other) = default;
93 
getMatchedConstantValue()94 Optional<int> AffineExprMatcher::getMatchedConstantValue() {
95   if (auto cst = storage->matched.dyn_cast<AffineConstantExpr>())
96     return cst.getValue();
97   return None;
98 }
99 
match(AffineExpr expr)100 AffineExpr AffineExprMatcher::match(AffineExpr expr) {
101   if (kind > AffineExprKind::LAST_AFFINE_BINARY_OP) {
102     if (storage->matched)
103       if (storage->matched != expr)
104         return AffineExpr();
105     storage->matched = expr;
106     return storage->matched;
107   }
108   if (kind != expr.getKind()) {
109     return AffineExpr();
110   }
111   if (auto bin = expr.dyn_cast<AffineBinaryOpExpr>()) {
112     if (!storage->subExprs.empty() &&
113         !storage->subExprs[0].match(bin.getLHS())) {
114       return AffineExpr();
115     }
116     if (!storage->subExprs.empty() &&
117         !storage->subExprs[1].match(bin.getRHS())) {
118       return AffineExpr();
119     }
120     if (storage->matched)
121       if (storage->matched != expr)
122         return AffineExpr();
123     storage->matched = expr;
124     return storage->matched;
125   }
126   llvm_unreachable("binary expected");
127 }
128 
matched()129 AffineExpr AffineExprMatcher::matched() { return storage->matched; }
130 
AffineExprMatcher(AffineExprKind k,AffineExprMatcher a,AffineExprMatcher b)131 AffineExprMatcher::AffineExprMatcher(AffineExprKind k, AffineExprMatcher a,
132                                      AffineExprMatcher b)
133     : kind(k), storage(new AffineExprMatcherStorage(a, b)) {
134   storage->subExprs.push_back(a);
135   storage->subExprs.push_back(b);
136 }
137 
138 //===----------------------------------------------------------------------===//
139 // SDBMExpr
140 //===----------------------------------------------------------------------===//
141 
getKind() const142 SDBMExprKind SDBMExpr::getKind() const { return impl->getKind(); }
143 
getContext() const144 MLIRContext *SDBMExpr::getContext() const {
145   return impl->dialect->getContext();
146 }
147 
getDialect() const148 SDBMDialect *SDBMExpr::getDialect() const { return impl->dialect; }
149 
print(raw_ostream & os) const150 void SDBMExpr::print(raw_ostream &os) const {
151   struct Printer : public SDBMVisitor<Printer> {
152     Printer(raw_ostream &ostream) : prn(ostream) {}
153 
154     void visitSum(SDBMSumExpr expr) {
155       visit(expr.getLHS());
156       prn << " + ";
157       visit(expr.getRHS());
158     }
159     void visitDiff(SDBMDiffExpr expr) {
160       visit(expr.getLHS());
161       prn << " - ";
162       visit(expr.getRHS());
163     }
164     void visitDim(SDBMDimExpr expr) { prn << 'd' << expr.getPosition(); }
165     void visitSymbol(SDBMSymbolExpr expr) { prn << 's' << expr.getPosition(); }
166     void visitStripe(SDBMStripeExpr expr) {
167       SDBMDirectExpr lhs = expr.getLHS();
168       bool isTerm = lhs.isa<SDBMTermExpr>();
169       if (!isTerm)
170         prn << '(';
171       visit(lhs);
172       if (!isTerm)
173         prn << ')';
174       prn << " # ";
175       visitConstant(expr.getStripeFactor());
176     }
177     void visitNeg(SDBMNegExpr expr) {
178       bool isSum = expr.getVar().isa<SDBMSumExpr>();
179       prn << '-';
180       if (isSum)
181         prn << '(';
182       visit(expr.getVar());
183       if (isSum)
184         prn << ')';
185     }
186     void visitConstant(SDBMConstantExpr expr) { prn << expr.getValue(); }
187 
188     raw_ostream &prn;
189   };
190   Printer printer(os);
191   printer.visit(*this);
192 }
193 
dump() const194 void SDBMExpr::dump() const {
195   print(llvm::errs());
196   llvm::errs() << '\n';
197 }
198 
199 namespace {
200 // Helper class to perform negation of an SDBM expression.
201 struct SDBMNegator : public SDBMVisitor<SDBMNegator, SDBMExpr> {
202   // Any term expression is wrapped into a negation expression.
203   //  -(x) = -x
visitDirect__anon65c8f2b80211::SDBMNegator204   SDBMExpr visitDirect(SDBMDirectExpr expr) { return SDBMNegExpr::get(expr); }
205   // A negation expression is unwrapped.
206   //  -(-x) = x
visitNeg__anon65c8f2b80211::SDBMNegator207   SDBMExpr visitNeg(SDBMNegExpr expr) { return expr.getVar(); }
208   // The value of the constant is negated.
visitConstant__anon65c8f2b80211::SDBMNegator209   SDBMExpr visitConstant(SDBMConstantExpr expr) {
210     return SDBMConstantExpr::get(expr.getDialect(), -expr.getValue());
211   }
212 
213   // Terms of a difference are interchanged. Since only the LHS of a diff
214   // expression is allowed to be a sum with a constant, we need to recreate the
215   // sum with the negated value:
216   //   -((x + C) - y) = (y - C) - x.
visitDiff__anon65c8f2b80211::SDBMNegator217   SDBMExpr visitDiff(SDBMDiffExpr expr) {
218     // If the LHS is just a term, we can do straightforward interchange.
219     if (auto term = expr.getLHS().dyn_cast<SDBMTermExpr>())
220       return SDBMDiffExpr::get(expr.getRHS(), term);
221 
222     auto sum = expr.getLHS().cast<SDBMSumExpr>();
223     auto cst = visitConstant(sum.getRHS()).cast<SDBMConstantExpr>();
224     return SDBMDiffExpr::get(SDBMSumExpr::get(expr.getRHS(), cst),
225                              sum.getLHS());
226   }
227 };
228 } // namespace
229 
operator -()230 SDBMExpr SDBMExpr::operator-() { return SDBMNegator().visit(*this); }
231 
232 //===----------------------------------------------------------------------===//
233 // SDBMSumExpr
234 //===----------------------------------------------------------------------===//
235 
get(SDBMTermExpr lhs,SDBMConstantExpr rhs)236 SDBMSumExpr SDBMSumExpr::get(SDBMTermExpr lhs, SDBMConstantExpr rhs) {
237   assert(lhs && "expected SDBM variable expression");
238   assert(rhs && "expected SDBM constant");
239 
240   // If LHS of a sum is another sum, fold the constant RHS parts.
241   if (auto lhsSum = lhs.dyn_cast<SDBMSumExpr>()) {
242     lhs = lhsSum.getLHS();
243     rhs = SDBMConstantExpr::get(rhs.getDialect(),
244                                 rhs.getValue() + lhsSum.getRHS().getValue());
245   }
246 
247   StorageUniquer &uniquer = lhs.getDialect()->getUniquer();
248   return uniquer.get<detail::SDBMBinaryExprStorage>(
249       /*initFn=*/{}, static_cast<unsigned>(SDBMExprKind::Add), lhs, rhs);
250 }
251 
getLHS() const252 SDBMTermExpr SDBMSumExpr::getLHS() const {
253   return static_cast<ImplType *>(impl)->lhs.cast<SDBMTermExpr>();
254 }
255 
getRHS() const256 SDBMConstantExpr SDBMSumExpr::getRHS() const {
257   return static_cast<ImplType *>(impl)->rhs;
258 }
259 
getAsAffineExpr() const260 AffineExpr SDBMExpr::getAsAffineExpr() const {
261   struct Converter : public SDBMVisitor<Converter, AffineExpr> {
262     AffineExpr visitSum(SDBMSumExpr expr) {
263       AffineExpr lhs = visit(expr.getLHS()), rhs = visit(expr.getRHS());
264       return lhs + rhs;
265     }
266 
267     AffineExpr visitStripe(SDBMStripeExpr expr) {
268       AffineExpr lhs = visit(expr.getLHS()),
269                  rhs = visit(expr.getStripeFactor());
270       return lhs - (lhs % rhs);
271     }
272 
273     AffineExpr visitDiff(SDBMDiffExpr expr) {
274       AffineExpr lhs = visit(expr.getLHS()), rhs = visit(expr.getRHS());
275       return lhs - rhs;
276     }
277 
278     AffineExpr visitDim(SDBMDimExpr expr) {
279       return getAffineDimExpr(expr.getPosition(), expr.getContext());
280     }
281 
282     AffineExpr visitSymbol(SDBMSymbolExpr expr) {
283       return getAffineSymbolExpr(expr.getPosition(), expr.getContext());
284     }
285 
286     AffineExpr visitNeg(SDBMNegExpr expr) {
287       return getAffineBinaryOpExpr(AffineExprKind::Mul,
288                                    getAffineConstantExpr(-1, expr.getContext()),
289                                    visit(expr.getVar()));
290     }
291 
292     AffineExpr visitConstant(SDBMConstantExpr expr) {
293       return getAffineConstantExpr(expr.getValue(), expr.getContext());
294     }
295   } converter;
296   return converter.visit(*this);
297 }
298 
299 // Given a direct expression `expr`, add the given constant to it and pass the
300 // resulting expression to `builder` before returning its result.  If the
301 // expression is already a sum expression, update its constant and extract the
302 // LHS if the constant becomes zero.  Otherwise, construct a sum expression.
303 template <typename Result>
addConstantAndSink(SDBMDirectExpr expr,int64_t constant,bool negated,function_ref<Result (SDBMDirectExpr)> builder)304 static Result addConstantAndSink(SDBMDirectExpr expr, int64_t constant,
305                                  bool negated,
306                                  function_ref<Result(SDBMDirectExpr)> builder) {
307   SDBMDialect *dialect = expr.getDialect();
308   if (auto sumExpr = expr.dyn_cast<SDBMSumExpr>()) {
309     if (negated)
310       constant = sumExpr.getRHS().getValue() - constant;
311     else
312       constant += sumExpr.getRHS().getValue();
313 
314     if (constant != 0) {
315       auto sum = SDBMSumExpr::get(sumExpr.getLHS(),
316                                   SDBMConstantExpr::get(dialect, constant));
317       return builder(sum);
318     } else {
319       return builder(sumExpr.getLHS());
320     }
321   }
322   if (constant != 0)
323     return builder(SDBMSumExpr::get(
324         expr.cast<SDBMTermExpr>(),
325         SDBMConstantExpr::get(dialect, negated ? -constant : constant)));
326   return expr;
327 }
328 
329 // Construct an expression lhs + constant while maintaining the canonical form
330 // of the SDBM expressions, in particular sink the constant expression to the
331 // nearest sum expression in the left subtree of the expression tree.
addConstant(SDBMVaryingExpr lhs,int64_t constant)332 static SDBMExpr addConstant(SDBMVaryingExpr lhs, int64_t constant) {
333   if (auto lhsDiff = lhs.dyn_cast<SDBMDiffExpr>())
334     return addConstantAndSink<SDBMExpr>(
335         lhsDiff.getLHS(), constant, /*negated=*/false,
336         [lhsDiff](SDBMDirectExpr e) {
337           return SDBMDiffExpr::get(e, lhsDiff.getRHS());
338         });
339   if (auto lhsNeg = lhs.dyn_cast<SDBMNegExpr>())
340     return addConstantAndSink<SDBMExpr>(
341         lhsNeg.getVar(), constant, /*negated=*/true,
342         [](SDBMDirectExpr e) { return SDBMNegExpr::get(e); });
343   if (auto lhsSum = lhs.dyn_cast<SDBMSumExpr>())
344     return addConstantAndSink<SDBMExpr>(lhsSum, constant, /*negated=*/false,
345                                         [](SDBMDirectExpr e) { return e; });
346   if (constant != 0)
347     return SDBMSumExpr::get(lhs.cast<SDBMTermExpr>(),
348                             SDBMConstantExpr::get(lhs.getDialect(), constant));
349   return lhs;
350 }
351 
352 // Build a difference expression given a direct expression and a negation
353 // expression.
buildDiffExpr(SDBMDirectExpr lhs,SDBMNegExpr rhs)354 static SDBMExpr buildDiffExpr(SDBMDirectExpr lhs, SDBMNegExpr rhs) {
355   // Fold (x + C) - (x + D) = C - D.
356   if (lhs.getTerm() == rhs.getVar().getTerm())
357     return SDBMConstantExpr::get(
358         lhs.getDialect(), lhs.getConstant() - rhs.getVar().getConstant());
359 
360   return SDBMDiffExpr::get(
361       addConstantAndSink<SDBMDirectExpr>(lhs, -rhs.getVar().getConstant(),
362                                          /*negated=*/false,
363                                          [](SDBMDirectExpr e) { return e; }),
364       rhs.getVar().getTerm());
365 }
366 
367 // Try folding an expression (lhs + rhs) where at least one of the operands
368 // contains a negated variable, i.e. is a negation or a difference expression.
foldSumDiff(SDBMExpr lhs,SDBMExpr rhs)369 static SDBMExpr foldSumDiff(SDBMExpr lhs, SDBMExpr rhs) {
370   // If exactly one of LHS, RHS is a negation expression, we can construct
371   // a difference expression, which is a special kind in SDBM.
372   auto lhsDirect = lhs.dyn_cast<SDBMDirectExpr>();
373   auto rhsDirect = rhs.dyn_cast<SDBMDirectExpr>();
374   auto lhsNeg = lhs.dyn_cast<SDBMNegExpr>();
375   auto rhsNeg = rhs.dyn_cast<SDBMNegExpr>();
376 
377   if (lhsDirect && rhsNeg)
378     return buildDiffExpr(lhsDirect, rhsNeg);
379   if (lhsNeg && rhsDirect)
380     return buildDiffExpr(rhsDirect, lhsNeg);
381 
382   // If a subexpression appears in a diff expression on the LHS(RHS) of a
383   // sum expression where it also appears on the RHS(LHS) with the opposite
384   // sign, we can simplify it away and obtain the SDBM form.
385   auto lhsDiff = lhs.dyn_cast<SDBMDiffExpr>();
386   auto rhsDiff = rhs.dyn_cast<SDBMDiffExpr>();
387 
388   // -(x + A) + ((x + B) - y) = -(y + (A - B))
389   if (lhsNeg && rhsDiff &&
390       lhsNeg.getVar().getTerm() == rhsDiff.getLHS().getTerm()) {
391     int64_t constant =
392         lhsNeg.getVar().getConstant() - rhsDiff.getLHS().getConstant();
393     // RHS of the diff is a term expression, its sum with a constant is a direct
394     // expression.
395     return SDBMNegExpr::get(
396         addConstant(rhsDiff.getRHS(), constant).cast<SDBMDirectExpr>());
397   }
398 
399   // (x + A) + ((y + B) - x) = (y + B) + A.
400   if (lhsDirect && rhsDiff && lhsDirect.getTerm() == rhsDiff.getRHS())
401     return addConstant(rhsDiff.getLHS(), lhsDirect.getConstant());
402 
403   // ((x + A) - y) + (-(x + B)) = -(y + (B - A)).
404   if (lhsDiff && rhsNeg &&
405       lhsDiff.getLHS().getTerm() == rhsNeg.getVar().getTerm()) {
406     int64_t constant =
407         rhsNeg.getVar().getConstant() - lhsDiff.getLHS().getConstant();
408     // RHS of the diff is a term expression, its sum with a constant is a direct
409     // expression.
410     return SDBMNegExpr::get(
411         addConstant(lhsDiff.getRHS(), constant).cast<SDBMDirectExpr>());
412   }
413 
414   // ((x + A) - y) + (y + B) = (x + A) + B.
415   if (rhsDirect && lhsDiff && rhsDirect.getTerm() == lhsDiff.getRHS())
416     return addConstant(lhsDiff.getLHS(), rhsDirect.getConstant());
417 
418   return {};
419 }
420 
tryConvertAffineExpr(AffineExpr affine)421 Optional<SDBMExpr> SDBMExpr::tryConvertAffineExpr(AffineExpr affine) {
422   struct Converter : public AffineExprVisitor<Converter, SDBMExpr> {
423     SDBMExpr visitAddExpr(AffineBinaryOpExpr expr) {
424       auto lhs = visit(expr.getLHS()), rhs = visit(expr.getRHS());
425       if (!lhs || !rhs)
426         return {};
427 
428       // In a "add" AffineExpr, the constant always appears on the right.  If
429       // there were two constants, they would have been folded away.
430       assert(!lhs.isa<SDBMConstantExpr>() && "non-canonical affine expression");
431 
432       // If RHS is a constant, we can always extend the SDBM expression to
433       // include it by sinking the constant into the nearest sum expression.
434       if (auto rhsConstant = rhs.dyn_cast<SDBMConstantExpr>()) {
435         int64_t constant = rhsConstant.getValue();
436         auto varying = lhs.dyn_cast<SDBMVaryingExpr>();
437         assert(varying && "unexpected uncanonicalized sum of constants");
438         return addConstant(varying, constant);
439       }
440 
441       // Try building a difference expression if one of the values is negated,
442       // or check if a difference on either hand side cancels out the outer term
443       // so as to remain correct within SDBM. Return null otherwise.
444       return foldSumDiff(lhs, rhs);
445     }
446 
447     SDBMExpr visitMulExpr(AffineBinaryOpExpr expr) {
448       // Attempt to recover a stripe expression "x # C = (x floordiv C) * C".
449       AffineExprMatcher x, C;
450       AffineExprMatcher pattern = (x.floorDiv(C)) * C;
451       if (pattern.match(expr)) {
452         if (SDBMExpr converted = visit(x.matched())) {
453           if (auto varConverted = converted.dyn_cast<SDBMTermExpr>())
454             // TODO: return varConverted.stripe(C.getConstantValue());
455             return SDBMStripeExpr::get(
456                 varConverted,
457                 SDBMConstantExpr::get(dialect,
458                                       C.getMatchedConstantValue().getValue()));
459         }
460       }
461 
462       auto lhs = visit(expr.getLHS()), rhs = visit(expr.getRHS());
463       if (!lhs || !rhs)
464         return {};
465 
466       // In a "mul" AffineExpr, the constant always appears on the right.  If
467       // there were two constants, they would have been folded away.
468       assert(!lhs.isa<SDBMConstantExpr>() && "non-canonical affine expression");
469       auto rhsConstant = rhs.dyn_cast<SDBMConstantExpr>();
470       if (!rhsConstant)
471         return {};
472 
473       // The only supported "multiplication" expression is an SDBM is dimension
474       // negation, that is a product of dimension and constant -1.
475       if (rhsConstant.getValue() != -1)
476         return {};
477 
478       if (auto lhsVar = lhs.dyn_cast<SDBMTermExpr>())
479         return SDBMNegExpr::get(lhsVar);
480       if (auto lhsDiff = lhs.dyn_cast<SDBMDiffExpr>())
481         return SDBMNegator().visitDiff(lhsDiff);
482 
483       // Other multiplications are not allowed in SDBM.
484       return {};
485     }
486 
487     SDBMExpr visitModExpr(AffineBinaryOpExpr expr) {
488       auto lhs = visit(expr.getLHS()), rhs = visit(expr.getRHS());
489       if (!lhs || !rhs)
490         return {};
491 
492       // 'mod' can only be converted to SDBM if its LHS is a direct expression
493       // and its RHS is a constant.  Then it `x mod c = x - x stripe c`.
494       auto rhsConstant = rhs.dyn_cast<SDBMConstantExpr>();
495       auto lhsVar = lhs.dyn_cast<SDBMDirectExpr>();
496       if (!lhsVar || !rhsConstant)
497         return {};
498       return SDBMDiffExpr::get(lhsVar,
499                                SDBMStripeExpr::get(lhsVar, rhsConstant));
500     }
501 
502     // `a floordiv b = (a stripe b) / b`, but we have no division in SDBM
503     SDBMExpr visitFloorDivExpr(AffineBinaryOpExpr expr) { return {}; }
504     SDBMExpr visitCeilDivExpr(AffineBinaryOpExpr expr) { return {}; }
505 
506     // Dimensions, symbols and constants are converted trivially.
507     SDBMExpr visitConstantExpr(AffineConstantExpr expr) {
508       return SDBMConstantExpr::get(dialect, expr.getValue());
509     }
510     SDBMExpr visitDimExpr(AffineDimExpr expr) {
511       return SDBMDimExpr::get(dialect, expr.getPosition());
512     }
513     SDBMExpr visitSymbolExpr(AffineSymbolExpr expr) {
514       return SDBMSymbolExpr::get(dialect, expr.getPosition());
515     }
516 
517     SDBMDialect *dialect;
518   } converter;
519   converter.dialect = affine.getContext()->getOrLoadDialect<SDBMDialect>();
520 
521   if (auto result = converter.visit(affine))
522     return result;
523   return None;
524 }
525 
526 //===----------------------------------------------------------------------===//
527 // SDBMDiffExpr
528 //===----------------------------------------------------------------------===//
529 
get(SDBMDirectExpr lhs,SDBMTermExpr rhs)530 SDBMDiffExpr SDBMDiffExpr::get(SDBMDirectExpr lhs, SDBMTermExpr rhs) {
531   assert(lhs && "expected SDBM dimension");
532   assert(rhs && "expected SDBM dimension");
533 
534   StorageUniquer &uniquer = lhs.getDialect()->getUniquer();
535   return uniquer.get<detail::SDBMDiffExprStorage>(/*initFn=*/{}, lhs, rhs);
536 }
537 
getLHS() const538 SDBMDirectExpr SDBMDiffExpr::getLHS() const {
539   return static_cast<ImplType *>(impl)->lhs;
540 }
541 
getRHS() const542 SDBMTermExpr SDBMDiffExpr::getRHS() const {
543   return static_cast<ImplType *>(impl)->rhs;
544 }
545 
546 //===----------------------------------------------------------------------===//
547 // SDBMDirectExpr
548 //===----------------------------------------------------------------------===//
549 
getTerm()550 SDBMTermExpr SDBMDirectExpr::getTerm() {
551   if (auto sum = dyn_cast<SDBMSumExpr>())
552     return sum.getLHS();
553   return cast<SDBMTermExpr>();
554 }
555 
getConstant()556 int64_t SDBMDirectExpr::getConstant() {
557   if (auto sum = dyn_cast<SDBMSumExpr>())
558     return sum.getRHS().getValue();
559   return 0;
560 }
561 
562 //===----------------------------------------------------------------------===//
563 // SDBMStripeExpr
564 //===----------------------------------------------------------------------===//
565 
get(SDBMDirectExpr var,SDBMConstantExpr stripeFactor)566 SDBMStripeExpr SDBMStripeExpr::get(SDBMDirectExpr var,
567                                    SDBMConstantExpr stripeFactor) {
568   assert(var && "expected SDBM variable expression");
569   assert(stripeFactor && "expected non-null stripe factor");
570   if (stripeFactor.getValue() <= 0)
571     llvm::report_fatal_error("non-positive stripe factor");
572 
573   StorageUniquer &uniquer = var.getDialect()->getUniquer();
574   return uniquer.get<detail::SDBMBinaryExprStorage>(
575       /*initFn=*/{}, static_cast<unsigned>(SDBMExprKind::Stripe), var,
576       stripeFactor);
577 }
578 
getLHS() const579 SDBMDirectExpr SDBMStripeExpr::getLHS() const {
580   if (SDBMVaryingExpr lhs = static_cast<ImplType *>(impl)->lhs)
581     return lhs.cast<SDBMDirectExpr>();
582   return {};
583 }
584 
getStripeFactor() const585 SDBMConstantExpr SDBMStripeExpr::getStripeFactor() const {
586   return static_cast<ImplType *>(impl)->rhs;
587 }
588 
589 //===----------------------------------------------------------------------===//
590 // SDBMInputExpr
591 //===----------------------------------------------------------------------===//
592 
getPosition() const593 unsigned SDBMInputExpr::getPosition() const {
594   return static_cast<ImplType *>(impl)->position;
595 }
596 
597 //===----------------------------------------------------------------------===//
598 // SDBMDimExpr
599 //===----------------------------------------------------------------------===//
600 
get(SDBMDialect * dialect,unsigned position)601 SDBMDimExpr SDBMDimExpr::get(SDBMDialect *dialect, unsigned position) {
602   assert(dialect && "expected non-null dialect");
603 
604   auto assignDialect = [dialect](detail::SDBMTermExprStorage *storage) {
605     storage->dialect = dialect;
606   };
607 
608   StorageUniquer &uniquer = dialect->getUniquer();
609   return uniquer.get<detail::SDBMTermExprStorage>(
610       assignDialect, static_cast<unsigned>(SDBMExprKind::DimId), position);
611 }
612 
613 //===----------------------------------------------------------------------===//
614 // SDBMSymbolExpr
615 //===----------------------------------------------------------------------===//
616 
get(SDBMDialect * dialect,unsigned position)617 SDBMSymbolExpr SDBMSymbolExpr::get(SDBMDialect *dialect, unsigned position) {
618   assert(dialect && "expected non-null dialect");
619 
620   auto assignDialect = [dialect](detail::SDBMTermExprStorage *storage) {
621     storage->dialect = dialect;
622   };
623 
624   StorageUniquer &uniquer = dialect->getUniquer();
625   return uniquer.get<detail::SDBMTermExprStorage>(
626       assignDialect, static_cast<unsigned>(SDBMExprKind::SymbolId), position);
627 }
628 
629 //===----------------------------------------------------------------------===//
630 // SDBMConstantExpr
631 //===----------------------------------------------------------------------===//
632 
get(SDBMDialect * dialect,int64_t value)633 SDBMConstantExpr SDBMConstantExpr::get(SDBMDialect *dialect, int64_t value) {
634   assert(dialect && "expected non-null dialect");
635 
636   auto assignCtx = [dialect](detail::SDBMConstantExprStorage *storage) {
637     storage->dialect = dialect;
638   };
639 
640   StorageUniquer &uniquer = dialect->getUniquer();
641   return uniquer.get<detail::SDBMConstantExprStorage>(assignCtx, value);
642 }
643 
getValue() const644 int64_t SDBMConstantExpr::getValue() const {
645   return static_cast<ImplType *>(impl)->constant;
646 }
647 
648 //===----------------------------------------------------------------------===//
649 // SDBMNegExpr
650 //===----------------------------------------------------------------------===//
651 
get(SDBMDirectExpr var)652 SDBMNegExpr SDBMNegExpr::get(SDBMDirectExpr var) {
653   assert(var && "expected non-null SDBM direct expression");
654 
655   StorageUniquer &uniquer = var.getDialect()->getUniquer();
656   return uniquer.get<detail::SDBMNegExprStorage>(/*initFn=*/{}, var);
657 }
658 
getVar() const659 SDBMDirectExpr SDBMNegExpr::getVar() const {
660   return static_cast<ImplType *>(impl)->expr;
661 }
662 
operator +(SDBMExpr lhs,SDBMExpr rhs)663 SDBMExpr mlir::ops_assertions::operator+(SDBMExpr lhs, SDBMExpr rhs) {
664   if (auto folded = foldSumDiff(lhs, rhs))
665     return folded;
666   assert(!(lhs.isa<SDBMNegExpr>() && rhs.isa<SDBMNegExpr>()) &&
667          "a sum of negated expressions is a negation of a sum of variables and "
668          "not a correct SDBM");
669 
670   // Fold (x - y) + (y - x) = 0.
671   auto lhsDiff = lhs.dyn_cast<SDBMDiffExpr>();
672   auto rhsDiff = rhs.dyn_cast<SDBMDiffExpr>();
673   if (lhsDiff && rhsDiff) {
674     if (lhsDiff.getLHS() == rhsDiff.getRHS() &&
675         lhsDiff.getRHS() == rhsDiff.getLHS())
676       return SDBMConstantExpr::get(lhs.getDialect(), 0);
677   }
678 
679   // If LHS is a constant and RHS is not, swap the order to get into a supported
680   // sum case.  From now on, RHS must be a constant.
681   auto lhsConstant = lhs.dyn_cast<SDBMConstantExpr>();
682   auto rhsConstant = rhs.dyn_cast<SDBMConstantExpr>();
683   if (!rhsConstant && lhsConstant) {
684     std::swap(lhs, rhs);
685     std::swap(lhsConstant, rhsConstant);
686   }
687   assert(rhsConstant && "at least one operand must be a constant");
688 
689   // Constant-fold if LHS is also a constant.
690   if (lhsConstant)
691     return SDBMConstantExpr::get(lhs.getDialect(), lhsConstant.getValue() +
692                                                        rhsConstant.getValue());
693   return addConstant(lhs.cast<SDBMVaryingExpr>(), rhsConstant.getValue());
694 }
695 
operator -(SDBMExpr lhs,SDBMExpr rhs)696 SDBMExpr mlir::ops_assertions::operator-(SDBMExpr lhs, SDBMExpr rhs) {
697   // Fold x - x == 0.
698   if (lhs == rhs)
699     return SDBMConstantExpr::get(lhs.getDialect(), 0);
700 
701   // LHS and RHS may be constants.
702   auto lhsConstant = lhs.dyn_cast<SDBMConstantExpr>();
703   auto rhsConstant = rhs.dyn_cast<SDBMConstantExpr>();
704 
705   // Constant fold if both LHS and RHS are constants.
706   if (lhsConstant && rhsConstant)
707     return SDBMConstantExpr::get(lhs.getDialect(), lhsConstant.getValue() -
708                                                        rhsConstant.getValue());
709 
710   // Replace a difference with a sum with a negated value if one of LHS and RHS
711   // is a constant:
712   //   x - C == x + (-C);
713   //   C - x == -x + C.
714   // This calls into operator+ for further simplification.
715   if (rhsConstant)
716     return lhs + (-rhsConstant);
717   if (lhsConstant)
718     return -rhs + lhsConstant;
719 
720   return buildDiffExpr(lhs.cast<SDBMDirectExpr>(), (-rhs).cast<SDBMNegExpr>());
721 }
722 
stripe(SDBMExpr expr,SDBMExpr factor)723 SDBMExpr mlir::ops_assertions::stripe(SDBMExpr expr, SDBMExpr factor) {
724   auto constantFactor = factor.cast<SDBMConstantExpr>();
725   assert(constantFactor.getValue() > 0 && "non-positive stripe");
726 
727   // Fold x # 1 = x.
728   if (constantFactor.getValue() == 1)
729     return expr;
730 
731   return SDBMStripeExpr::get(expr.cast<SDBMDirectExpr>(), constantFactor);
732 }
733