1 //===- SDBMExpr.cpp - MLIR SDBM Expression implementation -----------------===//
2 //
3 // Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions.
4 // See https://llvm.org/LICENSE.txt for license information.
5 // SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception
6 //
7 //===----------------------------------------------------------------------===//
8 //
9 // A striped difference-bound matrix (SDBM) expression is a constant expression,
10 // an identifier, a binary expression with constant RHS and +, stripe operators
11 // or a difference expression between two identifiers.
12 //
13 //===----------------------------------------------------------------------===//
14
15 #include "mlir/Dialect/SDBM/SDBMExpr.h"
16 #include "SDBMExprDetail.h"
17 #include "mlir/Dialect/SDBM/SDBMDialect.h"
18 #include "mlir/IR/AffineExpr.h"
19 #include "mlir/IR/AffineExprVisitor.h"
20
21 #include "llvm/Support/raw_ostream.h"
22
23 using namespace mlir;
24
25 namespace {
26 /// A simple compositional matcher for AffineExpr
27 ///
28 /// Example usage:
29 ///
30 /// ```c++
31 /// AffineExprMatcher x, C, m;
32 /// AffineExprMatcher pattern1 = ((x % C) * m) + x;
33 /// AffineExprMatcher pattern2 = x + ((x % C) * m);
34 /// if (pattern1.match(expr) || pattern2.match(expr)) {
35 /// ...
36 /// }
37 /// ```
38 class AffineExprMatcherStorage;
39 class AffineExprMatcher {
40 public:
41 AffineExprMatcher();
42 AffineExprMatcher(const AffineExprMatcher &other);
43
operator +(AffineExprMatcher other)44 AffineExprMatcher operator+(AffineExprMatcher other) {
45 return AffineExprMatcher(AffineExprKind::Add, *this, other);
46 }
operator *(AffineExprMatcher other)47 AffineExprMatcher operator*(AffineExprMatcher other) {
48 return AffineExprMatcher(AffineExprKind::Mul, *this, other);
49 }
floorDiv(AffineExprMatcher other)50 AffineExprMatcher floorDiv(AffineExprMatcher other) {
51 return AffineExprMatcher(AffineExprKind::FloorDiv, *this, other);
52 }
ceilDiv(AffineExprMatcher other)53 AffineExprMatcher ceilDiv(AffineExprMatcher other) {
54 return AffineExprMatcher(AffineExprKind::CeilDiv, *this, other);
55 }
operator %(AffineExprMatcher other)56 AffineExprMatcher operator%(AffineExprMatcher other) {
57 return AffineExprMatcher(AffineExprKind::Mod, *this, other);
58 }
59
60 AffineExpr match(AffineExpr expr);
61 AffineExpr matched();
62 Optional<int> getMatchedConstantValue();
63
64 private:
65 AffineExprMatcher(AffineExprKind k, AffineExprMatcher a, AffineExprMatcher b);
66 AffineExprKind kind; // only used to match in binary op cases.
67 // A shared_ptr allows multiple references to same matcher storage without
68 // worrying about ownership or dealing with an arena. To be cleaned up if we
69 // go with this.
70 std::shared_ptr<AffineExprMatcherStorage> storage;
71 };
72
73 class AffineExprMatcherStorage {
74 public:
AffineExprMatcherStorage()75 AffineExprMatcherStorage() {}
AffineExprMatcherStorage(const AffineExprMatcherStorage & other)76 AffineExprMatcherStorage(const AffineExprMatcherStorage &other)
77 : subExprs(other.subExprs.begin(), other.subExprs.end()),
78 matched(other.matched) {}
AffineExprMatcherStorage(ArrayRef<AffineExprMatcher> exprs)79 AffineExprMatcherStorage(ArrayRef<AffineExprMatcher> exprs)
80 : subExprs(exprs.begin(), exprs.end()) {}
AffineExprMatcherStorage(AffineExprMatcher & a,AffineExprMatcher & b)81 AffineExprMatcherStorage(AffineExprMatcher &a, AffineExprMatcher &b)
82 : subExprs({a, b}) {}
83 SmallVector<AffineExprMatcher, 0> subExprs;
84 AffineExpr matched;
85 };
86 } // namespace
87
AffineExprMatcher()88 AffineExprMatcher::AffineExprMatcher()
89 : kind(AffineExprKind::Constant), storage(new AffineExprMatcherStorage()) {}
90
AffineExprMatcher(const AffineExprMatcher & other)91 AffineExprMatcher::AffineExprMatcher(const AffineExprMatcher &other)
92 : kind(other.kind), storage(other.storage) {}
93
getMatchedConstantValue()94 Optional<int> AffineExprMatcher::getMatchedConstantValue() {
95 if (auto cst = storage->matched.dyn_cast<AffineConstantExpr>())
96 return cst.getValue();
97 return None;
98 }
99
match(AffineExpr expr)100 AffineExpr AffineExprMatcher::match(AffineExpr expr) {
101 if (kind > AffineExprKind::LAST_AFFINE_BINARY_OP) {
102 if (storage->matched)
103 if (storage->matched != expr)
104 return AffineExpr();
105 storage->matched = expr;
106 return storage->matched;
107 }
108 if (kind != expr.getKind()) {
109 return AffineExpr();
110 }
111 if (auto bin = expr.dyn_cast<AffineBinaryOpExpr>()) {
112 if (!storage->subExprs.empty() &&
113 !storage->subExprs[0].match(bin.getLHS())) {
114 return AffineExpr();
115 }
116 if (!storage->subExprs.empty() &&
117 !storage->subExprs[1].match(bin.getRHS())) {
118 return AffineExpr();
119 }
120 if (storage->matched)
121 if (storage->matched != expr)
122 return AffineExpr();
123 storage->matched = expr;
124 return storage->matched;
125 }
126 llvm_unreachable("binary expected");
127 }
128
matched()129 AffineExpr AffineExprMatcher::matched() { return storage->matched; }
130
AffineExprMatcher(AffineExprKind k,AffineExprMatcher a,AffineExprMatcher b)131 AffineExprMatcher::AffineExprMatcher(AffineExprKind k, AffineExprMatcher a,
132 AffineExprMatcher b)
133 : kind(k), storage(new AffineExprMatcherStorage(a, b)) {
134 storage->subExprs.push_back(a);
135 storage->subExprs.push_back(b);
136 }
137
138 //===----------------------------------------------------------------------===//
139 // SDBMExpr
140 //===----------------------------------------------------------------------===//
141
getKind() const142 SDBMExprKind SDBMExpr::getKind() const { return impl->getKind(); }
143
getContext() const144 MLIRContext *SDBMExpr::getContext() const {
145 return impl->dialect->getContext();
146 }
147
getDialect() const148 SDBMDialect *SDBMExpr::getDialect() const { return impl->dialect; }
149
print(raw_ostream & os) const150 void SDBMExpr::print(raw_ostream &os) const {
151 struct Printer : public SDBMVisitor<Printer> {
152 Printer(raw_ostream &ostream) : prn(ostream) {}
153
154 void visitSum(SDBMSumExpr expr) {
155 visit(expr.getLHS());
156 prn << " + ";
157 visit(expr.getRHS());
158 }
159 void visitDiff(SDBMDiffExpr expr) {
160 visit(expr.getLHS());
161 prn << " - ";
162 visit(expr.getRHS());
163 }
164 void visitDim(SDBMDimExpr expr) { prn << 'd' << expr.getPosition(); }
165 void visitSymbol(SDBMSymbolExpr expr) { prn << 's' << expr.getPosition(); }
166 void visitStripe(SDBMStripeExpr expr) {
167 SDBMDirectExpr lhs = expr.getLHS();
168 bool isTerm = lhs.isa<SDBMTermExpr>();
169 if (!isTerm)
170 prn << '(';
171 visit(lhs);
172 if (!isTerm)
173 prn << ')';
174 prn << " # ";
175 visitConstant(expr.getStripeFactor());
176 }
177 void visitNeg(SDBMNegExpr expr) {
178 bool isSum = expr.getVar().isa<SDBMSumExpr>();
179 prn << '-';
180 if (isSum)
181 prn << '(';
182 visit(expr.getVar());
183 if (isSum)
184 prn << ')';
185 }
186 void visitConstant(SDBMConstantExpr expr) { prn << expr.getValue(); }
187
188 raw_ostream &prn;
189 };
190 Printer printer(os);
191 printer.visit(*this);
192 }
193
dump() const194 void SDBMExpr::dump() const {
195 print(llvm::errs());
196 llvm::errs() << '\n';
197 }
198
199 namespace {
200 // Helper class to perform negation of an SDBM expression.
201 struct SDBMNegator : public SDBMVisitor<SDBMNegator, SDBMExpr> {
202 // Any term expression is wrapped into a negation expression.
203 // -(x) = -x
visitDirect__anon65c8f2b80211::SDBMNegator204 SDBMExpr visitDirect(SDBMDirectExpr expr) { return SDBMNegExpr::get(expr); }
205 // A negation expression is unwrapped.
206 // -(-x) = x
visitNeg__anon65c8f2b80211::SDBMNegator207 SDBMExpr visitNeg(SDBMNegExpr expr) { return expr.getVar(); }
208 // The value of the constant is negated.
visitConstant__anon65c8f2b80211::SDBMNegator209 SDBMExpr visitConstant(SDBMConstantExpr expr) {
210 return SDBMConstantExpr::get(expr.getDialect(), -expr.getValue());
211 }
212
213 // Terms of a difference are interchanged. Since only the LHS of a diff
214 // expression is allowed to be a sum with a constant, we need to recreate the
215 // sum with the negated value:
216 // -((x + C) - y) = (y - C) - x.
visitDiff__anon65c8f2b80211::SDBMNegator217 SDBMExpr visitDiff(SDBMDiffExpr expr) {
218 // If the LHS is just a term, we can do straightforward interchange.
219 if (auto term = expr.getLHS().dyn_cast<SDBMTermExpr>())
220 return SDBMDiffExpr::get(expr.getRHS(), term);
221
222 auto sum = expr.getLHS().cast<SDBMSumExpr>();
223 auto cst = visitConstant(sum.getRHS()).cast<SDBMConstantExpr>();
224 return SDBMDiffExpr::get(SDBMSumExpr::get(expr.getRHS(), cst),
225 sum.getLHS());
226 }
227 };
228 } // namespace
229
operator -()230 SDBMExpr SDBMExpr::operator-() { return SDBMNegator().visit(*this); }
231
232 //===----------------------------------------------------------------------===//
233 // SDBMSumExpr
234 //===----------------------------------------------------------------------===//
235
get(SDBMTermExpr lhs,SDBMConstantExpr rhs)236 SDBMSumExpr SDBMSumExpr::get(SDBMTermExpr lhs, SDBMConstantExpr rhs) {
237 assert(lhs && "expected SDBM variable expression");
238 assert(rhs && "expected SDBM constant");
239
240 // If LHS of a sum is another sum, fold the constant RHS parts.
241 if (auto lhsSum = lhs.dyn_cast<SDBMSumExpr>()) {
242 lhs = lhsSum.getLHS();
243 rhs = SDBMConstantExpr::get(rhs.getDialect(),
244 rhs.getValue() + lhsSum.getRHS().getValue());
245 }
246
247 StorageUniquer &uniquer = lhs.getDialect()->getUniquer();
248 return uniquer.get<detail::SDBMBinaryExprStorage>(
249 /*initFn=*/{}, static_cast<unsigned>(SDBMExprKind::Add), lhs, rhs);
250 }
251
getLHS() const252 SDBMTermExpr SDBMSumExpr::getLHS() const {
253 return static_cast<ImplType *>(impl)->lhs.cast<SDBMTermExpr>();
254 }
255
getRHS() const256 SDBMConstantExpr SDBMSumExpr::getRHS() const {
257 return static_cast<ImplType *>(impl)->rhs;
258 }
259
getAsAffineExpr() const260 AffineExpr SDBMExpr::getAsAffineExpr() const {
261 struct Converter : public SDBMVisitor<Converter, AffineExpr> {
262 AffineExpr visitSum(SDBMSumExpr expr) {
263 AffineExpr lhs = visit(expr.getLHS()), rhs = visit(expr.getRHS());
264 return lhs + rhs;
265 }
266
267 AffineExpr visitStripe(SDBMStripeExpr expr) {
268 AffineExpr lhs = visit(expr.getLHS()),
269 rhs = visit(expr.getStripeFactor());
270 return lhs - (lhs % rhs);
271 }
272
273 AffineExpr visitDiff(SDBMDiffExpr expr) {
274 AffineExpr lhs = visit(expr.getLHS()), rhs = visit(expr.getRHS());
275 return lhs - rhs;
276 }
277
278 AffineExpr visitDim(SDBMDimExpr expr) {
279 return getAffineDimExpr(expr.getPosition(), expr.getContext());
280 }
281
282 AffineExpr visitSymbol(SDBMSymbolExpr expr) {
283 return getAffineSymbolExpr(expr.getPosition(), expr.getContext());
284 }
285
286 AffineExpr visitNeg(SDBMNegExpr expr) {
287 return getAffineBinaryOpExpr(AffineExprKind::Mul,
288 getAffineConstantExpr(-1, expr.getContext()),
289 visit(expr.getVar()));
290 }
291
292 AffineExpr visitConstant(SDBMConstantExpr expr) {
293 return getAffineConstantExpr(expr.getValue(), expr.getContext());
294 }
295 } converter;
296 return converter.visit(*this);
297 }
298
299 // Given a direct expression `expr`, add the given constant to it and pass the
300 // resulting expression to `builder` before returning its result. If the
301 // expression is already a sum expression, update its constant and extract the
302 // LHS if the constant becomes zero. Otherwise, construct a sum expression.
303 template <typename Result>
addConstantAndSink(SDBMDirectExpr expr,int64_t constant,bool negated,function_ref<Result (SDBMDirectExpr)> builder)304 static Result addConstantAndSink(SDBMDirectExpr expr, int64_t constant,
305 bool negated,
306 function_ref<Result(SDBMDirectExpr)> builder) {
307 SDBMDialect *dialect = expr.getDialect();
308 if (auto sumExpr = expr.dyn_cast<SDBMSumExpr>()) {
309 if (negated)
310 constant = sumExpr.getRHS().getValue() - constant;
311 else
312 constant += sumExpr.getRHS().getValue();
313
314 if (constant != 0) {
315 auto sum = SDBMSumExpr::get(sumExpr.getLHS(),
316 SDBMConstantExpr::get(dialect, constant));
317 return builder(sum);
318 } else {
319 return builder(sumExpr.getLHS());
320 }
321 }
322 if (constant != 0)
323 return builder(SDBMSumExpr::get(
324 expr.cast<SDBMTermExpr>(),
325 SDBMConstantExpr::get(dialect, negated ? -constant : constant)));
326 return expr;
327 }
328
329 // Construct an expression lhs + constant while maintaining the canonical form
330 // of the SDBM expressions, in particular sink the constant expression to the
331 // nearest sum expression in the left subtree of the expression tree.
addConstant(SDBMVaryingExpr lhs,int64_t constant)332 static SDBMExpr addConstant(SDBMVaryingExpr lhs, int64_t constant) {
333 if (auto lhsDiff = lhs.dyn_cast<SDBMDiffExpr>())
334 return addConstantAndSink<SDBMExpr>(
335 lhsDiff.getLHS(), constant, /*negated=*/false,
336 [lhsDiff](SDBMDirectExpr e) {
337 return SDBMDiffExpr::get(e, lhsDiff.getRHS());
338 });
339 if (auto lhsNeg = lhs.dyn_cast<SDBMNegExpr>())
340 return addConstantAndSink<SDBMExpr>(
341 lhsNeg.getVar(), constant, /*negated=*/true,
342 [](SDBMDirectExpr e) { return SDBMNegExpr::get(e); });
343 if (auto lhsSum = lhs.dyn_cast<SDBMSumExpr>())
344 return addConstantAndSink<SDBMExpr>(lhsSum, constant, /*negated=*/false,
345 [](SDBMDirectExpr e) { return e; });
346 if (constant != 0)
347 return SDBMSumExpr::get(lhs.cast<SDBMTermExpr>(),
348 SDBMConstantExpr::get(lhs.getDialect(), constant));
349 return lhs;
350 }
351
352 // Build a difference expression given a direct expression and a negation
353 // expression.
buildDiffExpr(SDBMDirectExpr lhs,SDBMNegExpr rhs)354 static SDBMExpr buildDiffExpr(SDBMDirectExpr lhs, SDBMNegExpr rhs) {
355 // Fold (x + C) - (x + D) = C - D.
356 if (lhs.getTerm() == rhs.getVar().getTerm())
357 return SDBMConstantExpr::get(
358 lhs.getDialect(), lhs.getConstant() - rhs.getVar().getConstant());
359
360 return SDBMDiffExpr::get(
361 addConstantAndSink<SDBMDirectExpr>(lhs, -rhs.getVar().getConstant(),
362 /*negated=*/false,
363 [](SDBMDirectExpr e) { return e; }),
364 rhs.getVar().getTerm());
365 }
366
367 // Try folding an expression (lhs + rhs) where at least one of the operands
368 // contains a negated variable, i.e. is a negation or a difference expression.
foldSumDiff(SDBMExpr lhs,SDBMExpr rhs)369 static SDBMExpr foldSumDiff(SDBMExpr lhs, SDBMExpr rhs) {
370 // If exactly one of LHS, RHS is a negation expression, we can construct
371 // a difference expression, which is a special kind in SDBM.
372 auto lhsDirect = lhs.dyn_cast<SDBMDirectExpr>();
373 auto rhsDirect = rhs.dyn_cast<SDBMDirectExpr>();
374 auto lhsNeg = lhs.dyn_cast<SDBMNegExpr>();
375 auto rhsNeg = rhs.dyn_cast<SDBMNegExpr>();
376
377 if (lhsDirect && rhsNeg)
378 return buildDiffExpr(lhsDirect, rhsNeg);
379 if (lhsNeg && rhsDirect)
380 return buildDiffExpr(rhsDirect, lhsNeg);
381
382 // If a subexpression appears in a diff expression on the LHS(RHS) of a
383 // sum expression where it also appears on the RHS(LHS) with the opposite
384 // sign, we can simplify it away and obtain the SDBM form.
385 auto lhsDiff = lhs.dyn_cast<SDBMDiffExpr>();
386 auto rhsDiff = rhs.dyn_cast<SDBMDiffExpr>();
387
388 // -(x + A) + ((x + B) - y) = -(y + (A - B))
389 if (lhsNeg && rhsDiff &&
390 lhsNeg.getVar().getTerm() == rhsDiff.getLHS().getTerm()) {
391 int64_t constant =
392 lhsNeg.getVar().getConstant() - rhsDiff.getLHS().getConstant();
393 // RHS of the diff is a term expression, its sum with a constant is a direct
394 // expression.
395 return SDBMNegExpr::get(
396 addConstant(rhsDiff.getRHS(), constant).cast<SDBMDirectExpr>());
397 }
398
399 // (x + A) + ((y + B) - x) = (y + B) + A.
400 if (lhsDirect && rhsDiff && lhsDirect.getTerm() == rhsDiff.getRHS())
401 return addConstant(rhsDiff.getLHS(), lhsDirect.getConstant());
402
403 // ((x + A) - y) + (-(x + B)) = -(y + (B - A)).
404 if (lhsDiff && rhsNeg &&
405 lhsDiff.getLHS().getTerm() == rhsNeg.getVar().getTerm()) {
406 int64_t constant =
407 rhsNeg.getVar().getConstant() - lhsDiff.getLHS().getConstant();
408 // RHS of the diff is a term expression, its sum with a constant is a direct
409 // expression.
410 return SDBMNegExpr::get(
411 addConstant(lhsDiff.getRHS(), constant).cast<SDBMDirectExpr>());
412 }
413
414 // ((x + A) - y) + (y + B) = (x + A) + B.
415 if (rhsDirect && lhsDiff && rhsDirect.getTerm() == lhsDiff.getRHS())
416 return addConstant(lhsDiff.getLHS(), rhsDirect.getConstant());
417
418 return {};
419 }
420
tryConvertAffineExpr(AffineExpr affine)421 Optional<SDBMExpr> SDBMExpr::tryConvertAffineExpr(AffineExpr affine) {
422 struct Converter : public AffineExprVisitor<Converter, SDBMExpr> {
423 SDBMExpr visitAddExpr(AffineBinaryOpExpr expr) {
424 auto lhs = visit(expr.getLHS()), rhs = visit(expr.getRHS());
425 if (!lhs || !rhs)
426 return {};
427
428 // In a "add" AffineExpr, the constant always appears on the right. If
429 // there were two constants, they would have been folded away.
430 assert(!lhs.isa<SDBMConstantExpr>() && "non-canonical affine expression");
431
432 // If RHS is a constant, we can always extend the SDBM expression to
433 // include it by sinking the constant into the nearest sum expression.
434 if (auto rhsConstant = rhs.dyn_cast<SDBMConstantExpr>()) {
435 int64_t constant = rhsConstant.getValue();
436 auto varying = lhs.dyn_cast<SDBMVaryingExpr>();
437 assert(varying && "unexpected uncanonicalized sum of constants");
438 return addConstant(varying, constant);
439 }
440
441 // Try building a difference expression if one of the values is negated,
442 // or check if a difference on either hand side cancels out the outer term
443 // so as to remain correct within SDBM. Return null otherwise.
444 return foldSumDiff(lhs, rhs);
445 }
446
447 SDBMExpr visitMulExpr(AffineBinaryOpExpr expr) {
448 // Attempt to recover a stripe expression "x # C = (x floordiv C) * C".
449 AffineExprMatcher x, C;
450 AffineExprMatcher pattern = (x.floorDiv(C)) * C;
451 if (pattern.match(expr)) {
452 if (SDBMExpr converted = visit(x.matched())) {
453 if (auto varConverted = converted.dyn_cast<SDBMTermExpr>())
454 // TODO: return varConverted.stripe(C.getConstantValue());
455 return SDBMStripeExpr::get(
456 varConverted,
457 SDBMConstantExpr::get(dialect,
458 C.getMatchedConstantValue().getValue()));
459 }
460 }
461
462 auto lhs = visit(expr.getLHS()), rhs = visit(expr.getRHS());
463 if (!lhs || !rhs)
464 return {};
465
466 // In a "mul" AffineExpr, the constant always appears on the right. If
467 // there were two constants, they would have been folded away.
468 assert(!lhs.isa<SDBMConstantExpr>() && "non-canonical affine expression");
469 auto rhsConstant = rhs.dyn_cast<SDBMConstantExpr>();
470 if (!rhsConstant)
471 return {};
472
473 // The only supported "multiplication" expression is an SDBM is dimension
474 // negation, that is a product of dimension and constant -1.
475 if (rhsConstant.getValue() != -1)
476 return {};
477
478 if (auto lhsVar = lhs.dyn_cast<SDBMTermExpr>())
479 return SDBMNegExpr::get(lhsVar);
480 if (auto lhsDiff = lhs.dyn_cast<SDBMDiffExpr>())
481 return SDBMNegator().visitDiff(lhsDiff);
482
483 // Other multiplications are not allowed in SDBM.
484 return {};
485 }
486
487 SDBMExpr visitModExpr(AffineBinaryOpExpr expr) {
488 auto lhs = visit(expr.getLHS()), rhs = visit(expr.getRHS());
489 if (!lhs || !rhs)
490 return {};
491
492 // 'mod' can only be converted to SDBM if its LHS is a direct expression
493 // and its RHS is a constant. Then it `x mod c = x - x stripe c`.
494 auto rhsConstant = rhs.dyn_cast<SDBMConstantExpr>();
495 auto lhsVar = lhs.dyn_cast<SDBMDirectExpr>();
496 if (!lhsVar || !rhsConstant)
497 return {};
498 return SDBMDiffExpr::get(lhsVar,
499 SDBMStripeExpr::get(lhsVar, rhsConstant));
500 }
501
502 // `a floordiv b = (a stripe b) / b`, but we have no division in SDBM
503 SDBMExpr visitFloorDivExpr(AffineBinaryOpExpr expr) { return {}; }
504 SDBMExpr visitCeilDivExpr(AffineBinaryOpExpr expr) { return {}; }
505
506 // Dimensions, symbols and constants are converted trivially.
507 SDBMExpr visitConstantExpr(AffineConstantExpr expr) {
508 return SDBMConstantExpr::get(dialect, expr.getValue());
509 }
510 SDBMExpr visitDimExpr(AffineDimExpr expr) {
511 return SDBMDimExpr::get(dialect, expr.getPosition());
512 }
513 SDBMExpr visitSymbolExpr(AffineSymbolExpr expr) {
514 return SDBMSymbolExpr::get(dialect, expr.getPosition());
515 }
516
517 SDBMDialect *dialect;
518 } converter;
519 converter.dialect = affine.getContext()->getOrLoadDialect<SDBMDialect>();
520
521 if (auto result = converter.visit(affine))
522 return result;
523 return None;
524 }
525
526 //===----------------------------------------------------------------------===//
527 // SDBMDiffExpr
528 //===----------------------------------------------------------------------===//
529
get(SDBMDirectExpr lhs,SDBMTermExpr rhs)530 SDBMDiffExpr SDBMDiffExpr::get(SDBMDirectExpr lhs, SDBMTermExpr rhs) {
531 assert(lhs && "expected SDBM dimension");
532 assert(rhs && "expected SDBM dimension");
533
534 StorageUniquer &uniquer = lhs.getDialect()->getUniquer();
535 return uniquer.get<detail::SDBMDiffExprStorage>(/*initFn=*/{}, lhs, rhs);
536 }
537
getLHS() const538 SDBMDirectExpr SDBMDiffExpr::getLHS() const {
539 return static_cast<ImplType *>(impl)->lhs;
540 }
541
getRHS() const542 SDBMTermExpr SDBMDiffExpr::getRHS() const {
543 return static_cast<ImplType *>(impl)->rhs;
544 }
545
546 //===----------------------------------------------------------------------===//
547 // SDBMDirectExpr
548 //===----------------------------------------------------------------------===//
549
getTerm()550 SDBMTermExpr SDBMDirectExpr::getTerm() {
551 if (auto sum = dyn_cast<SDBMSumExpr>())
552 return sum.getLHS();
553 return cast<SDBMTermExpr>();
554 }
555
getConstant()556 int64_t SDBMDirectExpr::getConstant() {
557 if (auto sum = dyn_cast<SDBMSumExpr>())
558 return sum.getRHS().getValue();
559 return 0;
560 }
561
562 //===----------------------------------------------------------------------===//
563 // SDBMStripeExpr
564 //===----------------------------------------------------------------------===//
565
get(SDBMDirectExpr var,SDBMConstantExpr stripeFactor)566 SDBMStripeExpr SDBMStripeExpr::get(SDBMDirectExpr var,
567 SDBMConstantExpr stripeFactor) {
568 assert(var && "expected SDBM variable expression");
569 assert(stripeFactor && "expected non-null stripe factor");
570 if (stripeFactor.getValue() <= 0)
571 llvm::report_fatal_error("non-positive stripe factor");
572
573 StorageUniquer &uniquer = var.getDialect()->getUniquer();
574 return uniquer.get<detail::SDBMBinaryExprStorage>(
575 /*initFn=*/{}, static_cast<unsigned>(SDBMExprKind::Stripe), var,
576 stripeFactor);
577 }
578
getLHS() const579 SDBMDirectExpr SDBMStripeExpr::getLHS() const {
580 if (SDBMVaryingExpr lhs = static_cast<ImplType *>(impl)->lhs)
581 return lhs.cast<SDBMDirectExpr>();
582 return {};
583 }
584
getStripeFactor() const585 SDBMConstantExpr SDBMStripeExpr::getStripeFactor() const {
586 return static_cast<ImplType *>(impl)->rhs;
587 }
588
589 //===----------------------------------------------------------------------===//
590 // SDBMInputExpr
591 //===----------------------------------------------------------------------===//
592
getPosition() const593 unsigned SDBMInputExpr::getPosition() const {
594 return static_cast<ImplType *>(impl)->position;
595 }
596
597 //===----------------------------------------------------------------------===//
598 // SDBMDimExpr
599 //===----------------------------------------------------------------------===//
600
get(SDBMDialect * dialect,unsigned position)601 SDBMDimExpr SDBMDimExpr::get(SDBMDialect *dialect, unsigned position) {
602 assert(dialect && "expected non-null dialect");
603
604 auto assignDialect = [dialect](detail::SDBMTermExprStorage *storage) {
605 storage->dialect = dialect;
606 };
607
608 StorageUniquer &uniquer = dialect->getUniquer();
609 return uniquer.get<detail::SDBMTermExprStorage>(
610 assignDialect, static_cast<unsigned>(SDBMExprKind::DimId), position);
611 }
612
613 //===----------------------------------------------------------------------===//
614 // SDBMSymbolExpr
615 //===----------------------------------------------------------------------===//
616
get(SDBMDialect * dialect,unsigned position)617 SDBMSymbolExpr SDBMSymbolExpr::get(SDBMDialect *dialect, unsigned position) {
618 assert(dialect && "expected non-null dialect");
619
620 auto assignDialect = [dialect](detail::SDBMTermExprStorage *storage) {
621 storage->dialect = dialect;
622 };
623
624 StorageUniquer &uniquer = dialect->getUniquer();
625 return uniquer.get<detail::SDBMTermExprStorage>(
626 assignDialect, static_cast<unsigned>(SDBMExprKind::SymbolId), position);
627 }
628
629 //===----------------------------------------------------------------------===//
630 // SDBMConstantExpr
631 //===----------------------------------------------------------------------===//
632
get(SDBMDialect * dialect,int64_t value)633 SDBMConstantExpr SDBMConstantExpr::get(SDBMDialect *dialect, int64_t value) {
634 assert(dialect && "expected non-null dialect");
635
636 auto assignCtx = [dialect](detail::SDBMConstantExprStorage *storage) {
637 storage->dialect = dialect;
638 };
639
640 StorageUniquer &uniquer = dialect->getUniquer();
641 return uniquer.get<detail::SDBMConstantExprStorage>(assignCtx, value);
642 }
643
getValue() const644 int64_t SDBMConstantExpr::getValue() const {
645 return static_cast<ImplType *>(impl)->constant;
646 }
647
648 //===----------------------------------------------------------------------===//
649 // SDBMNegExpr
650 //===----------------------------------------------------------------------===//
651
get(SDBMDirectExpr var)652 SDBMNegExpr SDBMNegExpr::get(SDBMDirectExpr var) {
653 assert(var && "expected non-null SDBM direct expression");
654
655 StorageUniquer &uniquer = var.getDialect()->getUniquer();
656 return uniquer.get<detail::SDBMNegExprStorage>(/*initFn=*/{}, var);
657 }
658
getVar() const659 SDBMDirectExpr SDBMNegExpr::getVar() const {
660 return static_cast<ImplType *>(impl)->expr;
661 }
662
operator +(SDBMExpr lhs,SDBMExpr rhs)663 SDBMExpr mlir::ops_assertions::operator+(SDBMExpr lhs, SDBMExpr rhs) {
664 if (auto folded = foldSumDiff(lhs, rhs))
665 return folded;
666 assert(!(lhs.isa<SDBMNegExpr>() && rhs.isa<SDBMNegExpr>()) &&
667 "a sum of negated expressions is a negation of a sum of variables and "
668 "not a correct SDBM");
669
670 // Fold (x - y) + (y - x) = 0.
671 auto lhsDiff = lhs.dyn_cast<SDBMDiffExpr>();
672 auto rhsDiff = rhs.dyn_cast<SDBMDiffExpr>();
673 if (lhsDiff && rhsDiff) {
674 if (lhsDiff.getLHS() == rhsDiff.getRHS() &&
675 lhsDiff.getRHS() == rhsDiff.getLHS())
676 return SDBMConstantExpr::get(lhs.getDialect(), 0);
677 }
678
679 // If LHS is a constant and RHS is not, swap the order to get into a supported
680 // sum case. From now on, RHS must be a constant.
681 auto lhsConstant = lhs.dyn_cast<SDBMConstantExpr>();
682 auto rhsConstant = rhs.dyn_cast<SDBMConstantExpr>();
683 if (!rhsConstant && lhsConstant) {
684 std::swap(lhs, rhs);
685 std::swap(lhsConstant, rhsConstant);
686 }
687 assert(rhsConstant && "at least one operand must be a constant");
688
689 // Constant-fold if LHS is also a constant.
690 if (lhsConstant)
691 return SDBMConstantExpr::get(lhs.getDialect(), lhsConstant.getValue() +
692 rhsConstant.getValue());
693 return addConstant(lhs.cast<SDBMVaryingExpr>(), rhsConstant.getValue());
694 }
695
operator -(SDBMExpr lhs,SDBMExpr rhs)696 SDBMExpr mlir::ops_assertions::operator-(SDBMExpr lhs, SDBMExpr rhs) {
697 // Fold x - x == 0.
698 if (lhs == rhs)
699 return SDBMConstantExpr::get(lhs.getDialect(), 0);
700
701 // LHS and RHS may be constants.
702 auto lhsConstant = lhs.dyn_cast<SDBMConstantExpr>();
703 auto rhsConstant = rhs.dyn_cast<SDBMConstantExpr>();
704
705 // Constant fold if both LHS and RHS are constants.
706 if (lhsConstant && rhsConstant)
707 return SDBMConstantExpr::get(lhs.getDialect(), lhsConstant.getValue() -
708 rhsConstant.getValue());
709
710 // Replace a difference with a sum with a negated value if one of LHS and RHS
711 // is a constant:
712 // x - C == x + (-C);
713 // C - x == -x + C.
714 // This calls into operator+ for further simplification.
715 if (rhsConstant)
716 return lhs + (-rhsConstant);
717 if (lhsConstant)
718 return -rhs + lhsConstant;
719
720 return buildDiffExpr(lhs.cast<SDBMDirectExpr>(), (-rhs).cast<SDBMNegExpr>());
721 }
722
stripe(SDBMExpr expr,SDBMExpr factor)723 SDBMExpr mlir::ops_assertions::stripe(SDBMExpr expr, SDBMExpr factor) {
724 auto constantFactor = factor.cast<SDBMConstantExpr>();
725 assert(constantFactor.getValue() > 0 && "non-positive stripe");
726
727 // Fold x # 1 = x.
728 if (constantFactor.getValue() == 1)
729 return expr;
730
731 return SDBMStripeExpr::get(expr.cast<SDBMDirectExpr>(), constantFactor);
732 }
733