1 //===- AffineExpr.cpp - MLIR Affine Expr Classes --------------------------===//
2 //
3 // Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions.
4 // See https://llvm.org/LICENSE.txt for license information.
5 // SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception
6 //
7 //===----------------------------------------------------------------------===//
8 
9 #include "mlir/IR/AffineExpr.h"
10 #include "AffineExprDetail.h"
11 #include "mlir/IR/AffineExprVisitor.h"
12 #include "mlir/IR/AffineMap.h"
13 #include "mlir/IR/IntegerSet.h"
14 #include "mlir/Support/MathExtras.h"
15 #include "mlir/Support/TypeID.h"
16 #include "llvm/ADT/STLExtras.h"
17 
18 using namespace mlir;
19 using namespace mlir::detail;
20 
getContext() const21 MLIRContext *AffineExpr::getContext() const { return expr->context; }
22 
getKind() const23 AffineExprKind AffineExpr::getKind() const { return expr->kind; }
24 
25 /// Walk all of the AffineExprs in this subgraph in postorder.
walk(std::function<void (AffineExpr)> callback) const26 void AffineExpr::walk(std::function<void(AffineExpr)> callback) const {
27   struct AffineExprWalker : public AffineExprVisitor<AffineExprWalker> {
28     std::function<void(AffineExpr)> callback;
29 
30     AffineExprWalker(std::function<void(AffineExpr)> callback)
31         : callback(callback) {}
32 
33     void visitAffineBinaryOpExpr(AffineBinaryOpExpr expr) { callback(expr); }
34     void visitConstantExpr(AffineConstantExpr expr) { callback(expr); }
35     void visitDimExpr(AffineDimExpr expr) { callback(expr); }
36     void visitSymbolExpr(AffineSymbolExpr expr) { callback(expr); }
37   };
38 
39   AffineExprWalker(callback).walkPostOrder(*this);
40 }
41 
42 // Dispatch affine expression construction based on kind.
getAffineBinaryOpExpr(AffineExprKind kind,AffineExpr lhs,AffineExpr rhs)43 AffineExpr mlir::getAffineBinaryOpExpr(AffineExprKind kind, AffineExpr lhs,
44                                        AffineExpr rhs) {
45   if (kind == AffineExprKind::Add)
46     return lhs + rhs;
47   if (kind == AffineExprKind::Mul)
48     return lhs * rhs;
49   if (kind == AffineExprKind::FloorDiv)
50     return lhs.floorDiv(rhs);
51   if (kind == AffineExprKind::CeilDiv)
52     return lhs.ceilDiv(rhs);
53   if (kind == AffineExprKind::Mod)
54     return lhs % rhs;
55 
56   llvm_unreachable("unknown binary operation on affine expressions");
57 }
58 
59 /// This method substitutes any uses of dimensions and symbols (e.g.
60 /// dim#0 with dimReplacements[0]) and returns the modified expression tree.
61 AffineExpr
replaceDimsAndSymbols(ArrayRef<AffineExpr> dimReplacements,ArrayRef<AffineExpr> symReplacements) const62 AffineExpr::replaceDimsAndSymbols(ArrayRef<AffineExpr> dimReplacements,
63                                   ArrayRef<AffineExpr> symReplacements) const {
64   switch (getKind()) {
65   case AffineExprKind::Constant:
66     return *this;
67   case AffineExprKind::DimId: {
68     unsigned dimId = cast<AffineDimExpr>().getPosition();
69     if (dimId >= dimReplacements.size())
70       return *this;
71     return dimReplacements[dimId];
72   }
73   case AffineExprKind::SymbolId: {
74     unsigned symId = cast<AffineSymbolExpr>().getPosition();
75     if (symId >= symReplacements.size())
76       return *this;
77     return symReplacements[symId];
78   }
79   case AffineExprKind::Add:
80   case AffineExprKind::Mul:
81   case AffineExprKind::FloorDiv:
82   case AffineExprKind::CeilDiv:
83   case AffineExprKind::Mod:
84     auto binOp = cast<AffineBinaryOpExpr>();
85     auto lhs = binOp.getLHS(), rhs = binOp.getRHS();
86     auto newLHS = lhs.replaceDimsAndSymbols(dimReplacements, symReplacements);
87     auto newRHS = rhs.replaceDimsAndSymbols(dimReplacements, symReplacements);
88     if (newLHS == lhs && newRHS == rhs)
89       return *this;
90     return getAffineBinaryOpExpr(getKind(), newLHS, newRHS);
91   }
92   llvm_unreachable("Unknown AffineExpr");
93 }
94 
95 /// Replace symbols[0 .. numDims - 1] by symbols[shift .. shift + numDims - 1].
shiftSymbols(unsigned numSymbols,unsigned shift) const96 AffineExpr AffineExpr::shiftSymbols(unsigned numSymbols, unsigned shift) const {
97   SmallVector<AffineExpr, 4> symbols;
98   for (unsigned idx = 0; idx < numSymbols; ++idx)
99     symbols.push_back(getAffineSymbolExpr(idx + shift, getContext()));
100   return replaceDimsAndSymbols({}, symbols);
101 }
102 
103 /// Sparse replace method. Return the modified expression tree.
104 AffineExpr
replace(const DenseMap<AffineExpr,AffineExpr> & map) const105 AffineExpr::replace(const DenseMap<AffineExpr, AffineExpr> &map) const {
106   auto it = map.find(*this);
107   if (it != map.end())
108     return it->second;
109   switch (getKind()) {
110   default:
111     return *this;
112   case AffineExprKind::Add:
113   case AffineExprKind::Mul:
114   case AffineExprKind::FloorDiv:
115   case AffineExprKind::CeilDiv:
116   case AffineExprKind::Mod:
117     auto binOp = cast<AffineBinaryOpExpr>();
118     auto lhs = binOp.getLHS(), rhs = binOp.getRHS();
119     auto newLHS = lhs.replace(map);
120     auto newRHS = rhs.replace(map);
121     if (newLHS == lhs && newRHS == rhs)
122       return *this;
123     return getAffineBinaryOpExpr(getKind(), newLHS, newRHS);
124   }
125   llvm_unreachable("Unknown AffineExpr");
126 }
127 
128 /// Sparse replace method. Return the modified expression tree.
replace(AffineExpr expr,AffineExpr replacement) const129 AffineExpr AffineExpr::replace(AffineExpr expr, AffineExpr replacement) const {
130   DenseMap<AffineExpr, AffineExpr> map;
131   map.insert(std::make_pair(expr, replacement));
132   return replace(map);
133 }
134 /// Returns true if this expression is made out of only symbols and
135 /// constants (no dimensional identifiers).
isSymbolicOrConstant() const136 bool AffineExpr::isSymbolicOrConstant() const {
137   switch (getKind()) {
138   case AffineExprKind::Constant:
139     return true;
140   case AffineExprKind::DimId:
141     return false;
142   case AffineExprKind::SymbolId:
143     return true;
144 
145   case AffineExprKind::Add:
146   case AffineExprKind::Mul:
147   case AffineExprKind::FloorDiv:
148   case AffineExprKind::CeilDiv:
149   case AffineExprKind::Mod: {
150     auto expr = this->cast<AffineBinaryOpExpr>();
151     return expr.getLHS().isSymbolicOrConstant() &&
152            expr.getRHS().isSymbolicOrConstant();
153   }
154   }
155   llvm_unreachable("Unknown AffineExpr");
156 }
157 
158 /// Returns true if this is a pure affine expression, i.e., multiplication,
159 /// floordiv, ceildiv, and mod is only allowed w.r.t constants.
isPureAffine() const160 bool AffineExpr::isPureAffine() const {
161   switch (getKind()) {
162   case AffineExprKind::SymbolId:
163   case AffineExprKind::DimId:
164   case AffineExprKind::Constant:
165     return true;
166   case AffineExprKind::Add: {
167     auto op = cast<AffineBinaryOpExpr>();
168     return op.getLHS().isPureAffine() && op.getRHS().isPureAffine();
169   }
170 
171   case AffineExprKind::Mul: {
172     // TODO: Canonicalize the constants in binary operators to the RHS when
173     // possible, allowing this to merge into the next case.
174     auto op = cast<AffineBinaryOpExpr>();
175     return op.getLHS().isPureAffine() && op.getRHS().isPureAffine() &&
176            (op.getLHS().template isa<AffineConstantExpr>() ||
177             op.getRHS().template isa<AffineConstantExpr>());
178   }
179   case AffineExprKind::FloorDiv:
180   case AffineExprKind::CeilDiv:
181   case AffineExprKind::Mod: {
182     auto op = cast<AffineBinaryOpExpr>();
183     return op.getLHS().isPureAffine() &&
184            op.getRHS().template isa<AffineConstantExpr>();
185   }
186   }
187   llvm_unreachable("Unknown AffineExpr");
188 }
189 
190 // Returns the greatest known integral divisor of this affine expression.
getLargestKnownDivisor() const191 int64_t AffineExpr::getLargestKnownDivisor() const {
192   AffineBinaryOpExpr binExpr(nullptr);
193   switch (getKind()) {
194   case AffineExprKind::SymbolId:
195     LLVM_FALLTHROUGH;
196   case AffineExprKind::DimId:
197     return 1;
198   case AffineExprKind::Constant:
199     return std::abs(this->cast<AffineConstantExpr>().getValue());
200   case AffineExprKind::Mul: {
201     binExpr = this->cast<AffineBinaryOpExpr>();
202     return binExpr.getLHS().getLargestKnownDivisor() *
203            binExpr.getRHS().getLargestKnownDivisor();
204   }
205   case AffineExprKind::Add:
206     LLVM_FALLTHROUGH;
207   case AffineExprKind::FloorDiv:
208   case AffineExprKind::CeilDiv:
209   case AffineExprKind::Mod: {
210     binExpr = cast<AffineBinaryOpExpr>();
211     return llvm::GreatestCommonDivisor64(
212         binExpr.getLHS().getLargestKnownDivisor(),
213         binExpr.getRHS().getLargestKnownDivisor());
214   }
215   }
216   llvm_unreachable("Unknown AffineExpr");
217 }
218 
isMultipleOf(int64_t factor) const219 bool AffineExpr::isMultipleOf(int64_t factor) const {
220   AffineBinaryOpExpr binExpr(nullptr);
221   uint64_t l, u;
222   switch (getKind()) {
223   case AffineExprKind::SymbolId:
224     LLVM_FALLTHROUGH;
225   case AffineExprKind::DimId:
226     return factor * factor == 1;
227   case AffineExprKind::Constant:
228     return cast<AffineConstantExpr>().getValue() % factor == 0;
229   case AffineExprKind::Mul: {
230     binExpr = cast<AffineBinaryOpExpr>();
231     // It's probably not worth optimizing this further (to not traverse the
232     // whole sub-tree under - it that would require a version of isMultipleOf
233     // that on a 'false' return also returns the largest known divisor).
234     return (l = binExpr.getLHS().getLargestKnownDivisor()) % factor == 0 ||
235            (u = binExpr.getRHS().getLargestKnownDivisor()) % factor == 0 ||
236            (l * u) % factor == 0;
237   }
238   case AffineExprKind::Add:
239   case AffineExprKind::FloorDiv:
240   case AffineExprKind::CeilDiv:
241   case AffineExprKind::Mod: {
242     binExpr = cast<AffineBinaryOpExpr>();
243     return llvm::GreatestCommonDivisor64(
244                binExpr.getLHS().getLargestKnownDivisor(),
245                binExpr.getRHS().getLargestKnownDivisor()) %
246                factor ==
247            0;
248   }
249   }
250   llvm_unreachable("Unknown AffineExpr");
251 }
252 
isFunctionOfDim(unsigned position) const253 bool AffineExpr::isFunctionOfDim(unsigned position) const {
254   if (getKind() == AffineExprKind::DimId) {
255     return *this == mlir::getAffineDimExpr(position, getContext());
256   }
257   if (auto expr = this->dyn_cast<AffineBinaryOpExpr>()) {
258     return expr.getLHS().isFunctionOfDim(position) ||
259            expr.getRHS().isFunctionOfDim(position);
260   }
261   return false;
262 }
263 
AffineBinaryOpExpr(AffineExpr::ImplType * ptr)264 AffineBinaryOpExpr::AffineBinaryOpExpr(AffineExpr::ImplType *ptr)
265     : AffineExpr(ptr) {}
getLHS() const266 AffineExpr AffineBinaryOpExpr::getLHS() const {
267   return static_cast<ImplType *>(expr)->lhs;
268 }
getRHS() const269 AffineExpr AffineBinaryOpExpr::getRHS() const {
270   return static_cast<ImplType *>(expr)->rhs;
271 }
272 
AffineDimExpr(AffineExpr::ImplType * ptr)273 AffineDimExpr::AffineDimExpr(AffineExpr::ImplType *ptr) : AffineExpr(ptr) {}
getPosition() const274 unsigned AffineDimExpr::getPosition() const {
275   return static_cast<ImplType *>(expr)->position;
276 }
277 
278 /// Returns true if the expression is divisible by the given symbol with
279 /// position `symbolPos`. The argument `opKind` specifies here what kind of
280 /// division or mod operation called this division. It helps in implementing the
281 /// commutative property of the floordiv and ceildiv operations. If the argument
282 ///`exprKind` is floordiv and `expr` is also a binary expression of a floordiv
283 /// operation, then the commutative property can be used otherwise, the floordiv
284 /// operation is not divisible. The same argument holds for ceildiv operation.
isDivisibleBySymbol(AffineExpr expr,unsigned symbolPos,AffineExprKind opKind)285 static bool isDivisibleBySymbol(AffineExpr expr, unsigned symbolPos,
286                                 AffineExprKind opKind) {
287   // The argument `opKind` can either be Modulo, Floordiv or Ceildiv only.
288   assert((opKind == AffineExprKind::Mod || opKind == AffineExprKind::FloorDiv ||
289           opKind == AffineExprKind::CeilDiv) &&
290          "unexpected opKind");
291   switch (expr.getKind()) {
292   case AffineExprKind::Constant:
293     if (expr.cast<AffineConstantExpr>().getValue())
294       return false;
295     return true;
296   case AffineExprKind::DimId:
297     return false;
298   case AffineExprKind::SymbolId:
299     return (expr.cast<AffineSymbolExpr>().getPosition() == symbolPos);
300   // Checks divisibility by the given symbol for both operands.
301   case AffineExprKind::Add: {
302     AffineBinaryOpExpr binaryExpr = expr.cast<AffineBinaryOpExpr>();
303     return isDivisibleBySymbol(binaryExpr.getLHS(), symbolPos, opKind) &&
304            isDivisibleBySymbol(binaryExpr.getRHS(), symbolPos, opKind);
305   }
306   // Checks divisibility by the given symbol for both operands. Consider the
307   // expression `(((s1*s0) floordiv w) mod ((s1 * s2) floordiv p)) floordiv s1`,
308   // this is a division by s1 and both the operands of modulo are divisible by
309   // s1 but it is not divisible by s1 always. The third argument is
310   // `AffineExprKind::Mod` for this reason.
311   case AffineExprKind::Mod: {
312     AffineBinaryOpExpr binaryExpr = expr.cast<AffineBinaryOpExpr>();
313     return isDivisibleBySymbol(binaryExpr.getLHS(), symbolPos,
314                                AffineExprKind::Mod) &&
315            isDivisibleBySymbol(binaryExpr.getRHS(), symbolPos,
316                                AffineExprKind::Mod);
317   }
318   // Checks if any of the operand divisible by the given symbol.
319   case AffineExprKind::Mul: {
320     AffineBinaryOpExpr binaryExpr = expr.cast<AffineBinaryOpExpr>();
321     return isDivisibleBySymbol(binaryExpr.getLHS(), symbolPos, opKind) ||
322            isDivisibleBySymbol(binaryExpr.getRHS(), symbolPos, opKind);
323   }
324   // Floordiv and ceildiv are divisible by the given symbol when the first
325   // operand is divisible, and the affine expression kind of the argument expr
326   // is same as the argument `opKind`. This can be inferred from commutative
327   // property of floordiv and ceildiv operations and are as follow:
328   // (exp1 floordiv exp2) floordiv exp3 = (exp1 floordiv exp3) floordiv exp2
329   // (exp1 ceildiv exp2) ceildiv exp3 = (exp1 ceildiv exp3) ceildiv expr2
330   // It will fail if operations are not same. For example:
331   // (exps1 ceildiv exp2) floordiv exp3 can not be simplified.
332   case AffineExprKind::FloorDiv:
333   case AffineExprKind::CeilDiv: {
334     AffineBinaryOpExpr binaryExpr = expr.cast<AffineBinaryOpExpr>();
335     if (opKind != expr.getKind())
336       return false;
337     return isDivisibleBySymbol(binaryExpr.getLHS(), symbolPos, expr.getKind());
338   }
339   }
340   llvm_unreachable("Unknown AffineExpr");
341 }
342 
343 /// Divides the given expression by the given symbol at position `symbolPos`. It
344 /// considers the divisibility condition is checked before calling itself. A
345 /// null expression is returned whenever the divisibility condition fails.
symbolicDivide(AffineExpr expr,unsigned symbolPos,AffineExprKind opKind)346 static AffineExpr symbolicDivide(AffineExpr expr, unsigned symbolPos,
347                                  AffineExprKind opKind) {
348   // THe argument `opKind` can either be Modulo, Floordiv or Ceildiv only.
349   assert((opKind == AffineExprKind::Mod || opKind == AffineExprKind::FloorDiv ||
350           opKind == AffineExprKind::CeilDiv) &&
351          "unexpected opKind");
352   switch (expr.getKind()) {
353   case AffineExprKind::Constant:
354     if (expr.cast<AffineConstantExpr>().getValue() != 0)
355       return nullptr;
356     return getAffineConstantExpr(0, expr.getContext());
357   case AffineExprKind::DimId:
358     return nullptr;
359   case AffineExprKind::SymbolId:
360     return getAffineConstantExpr(1, expr.getContext());
361   // Dividing both operands by the given symbol.
362   case AffineExprKind::Add: {
363     AffineBinaryOpExpr binaryExpr = expr.cast<AffineBinaryOpExpr>();
364     return getAffineBinaryOpExpr(
365         expr.getKind(), symbolicDivide(binaryExpr.getLHS(), symbolPos, opKind),
366         symbolicDivide(binaryExpr.getRHS(), symbolPos, opKind));
367   }
368   // Dividing both operands by the given symbol.
369   case AffineExprKind::Mod: {
370     AffineBinaryOpExpr binaryExpr = expr.cast<AffineBinaryOpExpr>();
371     return getAffineBinaryOpExpr(
372         expr.getKind(),
373         symbolicDivide(binaryExpr.getLHS(), symbolPos, expr.getKind()),
374         symbolicDivide(binaryExpr.getRHS(), symbolPos, expr.getKind()));
375   }
376   // Dividing any of the operand by the given symbol.
377   case AffineExprKind::Mul: {
378     AffineBinaryOpExpr binaryExpr = expr.cast<AffineBinaryOpExpr>();
379     if (!isDivisibleBySymbol(binaryExpr.getLHS(), symbolPos, opKind))
380       return binaryExpr.getLHS() *
381              symbolicDivide(binaryExpr.getRHS(), symbolPos, opKind);
382     return symbolicDivide(binaryExpr.getLHS(), symbolPos, opKind) *
383            binaryExpr.getRHS();
384   }
385   // Dividing first operand only by the given symbol.
386   case AffineExprKind::FloorDiv:
387   case AffineExprKind::CeilDiv: {
388     AffineBinaryOpExpr binaryExpr = expr.cast<AffineBinaryOpExpr>();
389     return getAffineBinaryOpExpr(
390         expr.getKind(),
391         symbolicDivide(binaryExpr.getLHS(), symbolPos, expr.getKind()),
392         binaryExpr.getRHS());
393   }
394   }
395   llvm_unreachable("Unknown AffineExpr");
396 }
397 
398 /// Simplify a semi-affine expression by handling modulo, floordiv, or ceildiv
399 /// operations when the second operand simplifies to a symbol and the first
400 /// operand is divisible by that symbol. It can be applied to any semi-affine
401 /// expression. Returned expression can either be a semi-affine or pure affine
402 /// expression.
simplifySemiAffine(AffineExpr expr)403 static AffineExpr simplifySemiAffine(AffineExpr expr) {
404   switch (expr.getKind()) {
405   case AffineExprKind::Constant:
406   case AffineExprKind::DimId:
407   case AffineExprKind::SymbolId:
408     return expr;
409   case AffineExprKind::Add:
410   case AffineExprKind::Mul: {
411     AffineBinaryOpExpr binaryExpr = expr.cast<AffineBinaryOpExpr>();
412     return getAffineBinaryOpExpr(expr.getKind(),
413                                  simplifySemiAffine(binaryExpr.getLHS()),
414                                  simplifySemiAffine(binaryExpr.getRHS()));
415   }
416   // Check if the simplification of the second operand is a symbol, and the
417   // first operand is divisible by it. If the operation is a modulo, a constant
418   // zero expression is returned. In the case of floordiv and ceildiv, the
419   // symbol from the simplification of the second operand divides the first
420   // operand. Otherwise, simplification is not possible.
421   case AffineExprKind::FloorDiv:
422   case AffineExprKind::CeilDiv:
423   case AffineExprKind::Mod: {
424     AffineBinaryOpExpr binaryExpr = expr.cast<AffineBinaryOpExpr>();
425     AffineExpr sLHS = simplifySemiAffine(binaryExpr.getLHS());
426     AffineExpr sRHS = simplifySemiAffine(binaryExpr.getRHS());
427     AffineSymbolExpr symbolExpr =
428         simplifySemiAffine(binaryExpr.getRHS()).dyn_cast<AffineSymbolExpr>();
429     if (!symbolExpr)
430       return getAffineBinaryOpExpr(expr.getKind(), sLHS, sRHS);
431     unsigned symbolPos = symbolExpr.getPosition();
432     if (!isDivisibleBySymbol(binaryExpr.getLHS(), symbolPos, expr.getKind()))
433       return getAffineBinaryOpExpr(expr.getKind(), sLHS, sRHS);
434     if (expr.getKind() == AffineExprKind::Mod)
435       return getAffineConstantExpr(0, expr.getContext());
436     return symbolicDivide(sLHS, symbolPos, expr.getKind());
437   }
438   }
439   llvm_unreachable("Unknown AffineExpr");
440 }
441 
getAffineDimOrSymbol(AffineExprKind kind,unsigned position,MLIRContext * context)442 static AffineExpr getAffineDimOrSymbol(AffineExprKind kind, unsigned position,
443                                        MLIRContext *context) {
444   auto assignCtx = [context](AffineDimExprStorage *storage) {
445     storage->context = context;
446   };
447 
448   StorageUniquer &uniquer = context->getAffineUniquer();
449   return uniquer.get<AffineDimExprStorage>(
450       assignCtx, static_cast<unsigned>(kind), position);
451 }
452 
getAffineDimExpr(unsigned position,MLIRContext * context)453 AffineExpr mlir::getAffineDimExpr(unsigned position, MLIRContext *context) {
454   return getAffineDimOrSymbol(AffineExprKind::DimId, position, context);
455 }
456 
AffineSymbolExpr(AffineExpr::ImplType * ptr)457 AffineSymbolExpr::AffineSymbolExpr(AffineExpr::ImplType *ptr)
458     : AffineExpr(ptr) {}
getPosition() const459 unsigned AffineSymbolExpr::getPosition() const {
460   return static_cast<ImplType *>(expr)->position;
461 }
462 
getAffineSymbolExpr(unsigned position,MLIRContext * context)463 AffineExpr mlir::getAffineSymbolExpr(unsigned position, MLIRContext *context) {
464   return getAffineDimOrSymbol(AffineExprKind::SymbolId, position, context);
465   ;
466 }
467 
AffineConstantExpr(AffineExpr::ImplType * ptr)468 AffineConstantExpr::AffineConstantExpr(AffineExpr::ImplType *ptr)
469     : AffineExpr(ptr) {}
getValue() const470 int64_t AffineConstantExpr::getValue() const {
471   return static_cast<ImplType *>(expr)->constant;
472 }
473 
operator ==(int64_t v) const474 bool AffineExpr::operator==(int64_t v) const {
475   return *this == getAffineConstantExpr(v, getContext());
476 }
477 
getAffineConstantExpr(int64_t constant,MLIRContext * context)478 AffineExpr mlir::getAffineConstantExpr(int64_t constant, MLIRContext *context) {
479   auto assignCtx = [context](AffineConstantExprStorage *storage) {
480     storage->context = context;
481   };
482 
483   StorageUniquer &uniquer = context->getAffineUniquer();
484   return uniquer.get<AffineConstantExprStorage>(assignCtx, constant);
485 }
486 
487 /// Simplify add expression. Return nullptr if it can't be simplified.
simplifyAdd(AffineExpr lhs,AffineExpr rhs)488 static AffineExpr simplifyAdd(AffineExpr lhs, AffineExpr rhs) {
489   auto lhsConst = lhs.dyn_cast<AffineConstantExpr>();
490   auto rhsConst = rhs.dyn_cast<AffineConstantExpr>();
491   // Fold if both LHS, RHS are a constant.
492   if (lhsConst && rhsConst)
493     return getAffineConstantExpr(lhsConst.getValue() + rhsConst.getValue(),
494                                  lhs.getContext());
495 
496   // Canonicalize so that only the RHS is a constant. (4 + d0 becomes d0 + 4).
497   // If only one of them is a symbolic expressions, make it the RHS.
498   if (lhs.isa<AffineConstantExpr>() ||
499       (lhs.isSymbolicOrConstant() && !rhs.isSymbolicOrConstant())) {
500     return rhs + lhs;
501   }
502 
503   // At this point, if there was a constant, it would be on the right.
504 
505   // Addition with a zero is a noop, return the other input.
506   if (rhsConst) {
507     if (rhsConst.getValue() == 0)
508       return lhs;
509   }
510   // Fold successive additions like (d0 + 2) + 3 into d0 + 5.
511   auto lBin = lhs.dyn_cast<AffineBinaryOpExpr>();
512   if (lBin && rhsConst && lBin.getKind() == AffineExprKind::Add) {
513     if (auto lrhs = lBin.getRHS().dyn_cast<AffineConstantExpr>())
514       return lBin.getLHS() + (lrhs.getValue() + rhsConst.getValue());
515   }
516 
517   // Detect "c1 * expr + c_2 * expr" as "(c1 + c2) * expr".
518   // c1 is rRhsConst, c2 is rLhsConst; firstExpr, secondExpr are their
519   // respective multiplicands.
520   Optional<int64_t> rLhsConst, rRhsConst;
521   AffineExpr firstExpr, secondExpr;
522   AffineConstantExpr rLhsConstExpr;
523   auto lBinOpExpr = lhs.dyn_cast<AffineBinaryOpExpr>();
524   if (lBinOpExpr && lBinOpExpr.getKind() == AffineExprKind::Mul &&
525       (rLhsConstExpr = lBinOpExpr.getRHS().dyn_cast<AffineConstantExpr>())) {
526     rLhsConst = rLhsConstExpr.getValue();
527     firstExpr = lBinOpExpr.getLHS();
528   } else {
529     rLhsConst = 1;
530     firstExpr = lhs;
531   }
532 
533   auto rBinOpExpr = rhs.dyn_cast<AffineBinaryOpExpr>();
534   AffineConstantExpr rRhsConstExpr;
535   if (rBinOpExpr && rBinOpExpr.getKind() == AffineExprKind::Mul &&
536       (rRhsConstExpr = rBinOpExpr.getRHS().dyn_cast<AffineConstantExpr>())) {
537     rRhsConst = rRhsConstExpr.getValue();
538     secondExpr = rBinOpExpr.getLHS();
539   } else {
540     rRhsConst = 1;
541     secondExpr = rhs;
542   }
543 
544   if (rLhsConst && rRhsConst && firstExpr == secondExpr)
545     return getAffineBinaryOpExpr(
546         AffineExprKind::Mul, firstExpr,
547         getAffineConstantExpr(rLhsConst.getValue() + rRhsConst.getValue(),
548                               lhs.getContext()));
549 
550   // When doing successive additions, bring constant to the right: turn (d0 + 2)
551   // + d1 into (d0 + d1) + 2.
552   if (lBin && lBin.getKind() == AffineExprKind::Add) {
553     if (auto lrhs = lBin.getRHS().dyn_cast<AffineConstantExpr>()) {
554       return lBin.getLHS() + rhs + lrhs;
555     }
556   }
557 
558   // Detect and transform "expr - c * (expr floordiv c)" to "expr mod c". This
559   // leads to a much more efficient form when 'c' is a power of two, and in
560   // general a more compact and readable form.
561 
562   // Process '(expr floordiv c) * (-c)'.
563   if (!rBinOpExpr)
564     return nullptr;
565 
566   auto lrhs = rBinOpExpr.getLHS();
567   auto rrhs = rBinOpExpr.getRHS();
568 
569   // Process lrhs, which is 'expr floordiv c'.
570   AffineBinaryOpExpr lrBinOpExpr = lrhs.dyn_cast<AffineBinaryOpExpr>();
571   if (!lrBinOpExpr || lrBinOpExpr.getKind() != AffineExprKind::FloorDiv)
572     return nullptr;
573 
574   auto llrhs = lrBinOpExpr.getLHS();
575   auto rlrhs = lrBinOpExpr.getRHS();
576 
577   if (lhs == llrhs && rlrhs == -rrhs) {
578     return lhs % rlrhs;
579   }
580   return nullptr;
581 }
582 
operator +(int64_t v) const583 AffineExpr AffineExpr::operator+(int64_t v) const {
584   return *this + getAffineConstantExpr(v, getContext());
585 }
operator +(AffineExpr other) const586 AffineExpr AffineExpr::operator+(AffineExpr other) const {
587   if (auto simplified = simplifyAdd(*this, other))
588     return simplified;
589 
590   StorageUniquer &uniquer = getContext()->getAffineUniquer();
591   return uniquer.get<AffineBinaryOpExprStorage>(
592       /*initFn=*/{}, static_cast<unsigned>(AffineExprKind::Add), *this, other);
593 }
594 
595 /// Simplify a multiply expression. Return nullptr if it can't be simplified.
simplifyMul(AffineExpr lhs,AffineExpr rhs)596 static AffineExpr simplifyMul(AffineExpr lhs, AffineExpr rhs) {
597   auto lhsConst = lhs.dyn_cast<AffineConstantExpr>();
598   auto rhsConst = rhs.dyn_cast<AffineConstantExpr>();
599 
600   if (lhsConst && rhsConst)
601     return getAffineConstantExpr(lhsConst.getValue() * rhsConst.getValue(),
602                                  lhs.getContext());
603 
604   assert(lhs.isSymbolicOrConstant() || rhs.isSymbolicOrConstant());
605 
606   // Canonicalize the mul expression so that the constant/symbolic term is the
607   // RHS. If both the lhs and rhs are symbolic, swap them if the lhs is a
608   // constant. (Note that a constant is trivially symbolic).
609   if (!rhs.isSymbolicOrConstant() || lhs.isa<AffineConstantExpr>()) {
610     // At least one of them has to be symbolic.
611     return rhs * lhs;
612   }
613 
614   // At this point, if there was a constant, it would be on the right.
615 
616   // Multiplication with a one is a noop, return the other input.
617   if (rhsConst) {
618     if (rhsConst.getValue() == 1)
619       return lhs;
620     // Multiplication with zero.
621     if (rhsConst.getValue() == 0)
622       return rhsConst;
623   }
624 
625   // Fold successive multiplications: eg: (d0 * 2) * 3 into d0 * 6.
626   auto lBin = lhs.dyn_cast<AffineBinaryOpExpr>();
627   if (lBin && rhsConst && lBin.getKind() == AffineExprKind::Mul) {
628     if (auto lrhs = lBin.getRHS().dyn_cast<AffineConstantExpr>())
629       return lBin.getLHS() * (lrhs.getValue() * rhsConst.getValue());
630   }
631 
632   // When doing successive multiplication, bring constant to the right: turn (d0
633   // * 2) * d1 into (d0 * d1) * 2.
634   if (lBin && lBin.getKind() == AffineExprKind::Mul) {
635     if (auto lrhs = lBin.getRHS().dyn_cast<AffineConstantExpr>()) {
636       return (lBin.getLHS() * rhs) * lrhs;
637     }
638   }
639 
640   return nullptr;
641 }
642 
operator *(int64_t v) const643 AffineExpr AffineExpr::operator*(int64_t v) const {
644   return *this * getAffineConstantExpr(v, getContext());
645 }
operator *(AffineExpr other) const646 AffineExpr AffineExpr::operator*(AffineExpr other) const {
647   if (auto simplified = simplifyMul(*this, other))
648     return simplified;
649 
650   StorageUniquer &uniquer = getContext()->getAffineUniquer();
651   return uniquer.get<AffineBinaryOpExprStorage>(
652       /*initFn=*/{}, static_cast<unsigned>(AffineExprKind::Mul), *this, other);
653 }
654 
655 // Unary minus, delegate to operator*.
operator -() const656 AffineExpr AffineExpr::operator-() const {
657   return *this * getAffineConstantExpr(-1, getContext());
658 }
659 
660 // Delegate to operator+.
operator -(int64_t v) const661 AffineExpr AffineExpr::operator-(int64_t v) const { return *this + (-v); }
operator -(AffineExpr other) const662 AffineExpr AffineExpr::operator-(AffineExpr other) const {
663   return *this + (-other);
664 }
665 
simplifyFloorDiv(AffineExpr lhs,AffineExpr rhs)666 static AffineExpr simplifyFloorDiv(AffineExpr lhs, AffineExpr rhs) {
667   auto lhsConst = lhs.dyn_cast<AffineConstantExpr>();
668   auto rhsConst = rhs.dyn_cast<AffineConstantExpr>();
669 
670   // mlir floordiv by zero or negative numbers is undefined and preserved as is.
671   if (!rhsConst || rhsConst.getValue() < 1)
672     return nullptr;
673 
674   if (lhsConst)
675     return getAffineConstantExpr(
676         floorDiv(lhsConst.getValue(), rhsConst.getValue()), lhs.getContext());
677 
678   // Fold floordiv of a multiply with a constant that is a multiple of the
679   // divisor. Eg: (i * 128) floordiv 64 = i * 2.
680   if (rhsConst == 1)
681     return lhs;
682 
683   // Simplify (expr * const) floordiv divConst when expr is known to be a
684   // multiple of divConst.
685   auto lBin = lhs.dyn_cast<AffineBinaryOpExpr>();
686   if (lBin && lBin.getKind() == AffineExprKind::Mul) {
687     if (auto lrhs = lBin.getRHS().dyn_cast<AffineConstantExpr>()) {
688       // rhsConst is known to be a positive constant.
689       if (lrhs.getValue() % rhsConst.getValue() == 0)
690         return lBin.getLHS() * (lrhs.getValue() / rhsConst.getValue());
691     }
692   }
693 
694   // Simplify (expr1 + expr2) floordiv divConst when either expr1 or expr2 is
695   // known to be a multiple of divConst.
696   if (lBin && lBin.getKind() == AffineExprKind::Add) {
697     int64_t llhsDiv = lBin.getLHS().getLargestKnownDivisor();
698     int64_t lrhsDiv = lBin.getRHS().getLargestKnownDivisor();
699     // rhsConst is known to be a positive constant.
700     if (llhsDiv % rhsConst.getValue() == 0 ||
701         lrhsDiv % rhsConst.getValue() == 0)
702       return lBin.getLHS().floorDiv(rhsConst.getValue()) +
703              lBin.getRHS().floorDiv(rhsConst.getValue());
704   }
705 
706   return nullptr;
707 }
708 
floorDiv(uint64_t v) const709 AffineExpr AffineExpr::floorDiv(uint64_t v) const {
710   return floorDiv(getAffineConstantExpr(v, getContext()));
711 }
floorDiv(AffineExpr other) const712 AffineExpr AffineExpr::floorDiv(AffineExpr other) const {
713   if (auto simplified = simplifyFloorDiv(*this, other))
714     return simplified;
715 
716   StorageUniquer &uniquer = getContext()->getAffineUniquer();
717   return uniquer.get<AffineBinaryOpExprStorage>(
718       /*initFn=*/{}, static_cast<unsigned>(AffineExprKind::FloorDiv), *this,
719       other);
720 }
721 
simplifyCeilDiv(AffineExpr lhs,AffineExpr rhs)722 static AffineExpr simplifyCeilDiv(AffineExpr lhs, AffineExpr rhs) {
723   auto lhsConst = lhs.dyn_cast<AffineConstantExpr>();
724   auto rhsConst = rhs.dyn_cast<AffineConstantExpr>();
725 
726   if (!rhsConst || rhsConst.getValue() < 1)
727     return nullptr;
728 
729   if (lhsConst)
730     return getAffineConstantExpr(
731         ceilDiv(lhsConst.getValue(), rhsConst.getValue()), lhs.getContext());
732 
733   // Fold ceildiv of a multiply with a constant that is a multiple of the
734   // divisor. Eg: (i * 128) ceildiv 64 = i * 2.
735   if (rhsConst.getValue() == 1)
736     return lhs;
737 
738   // Simplify (expr * const) ceildiv divConst when const is known to be a
739   // multiple of divConst.
740   auto lBin = lhs.dyn_cast<AffineBinaryOpExpr>();
741   if (lBin && lBin.getKind() == AffineExprKind::Mul) {
742     if (auto lrhs = lBin.getRHS().dyn_cast<AffineConstantExpr>()) {
743       // rhsConst is known to be a positive constant.
744       if (lrhs.getValue() % rhsConst.getValue() == 0)
745         return lBin.getLHS() * (lrhs.getValue() / rhsConst.getValue());
746     }
747   }
748 
749   return nullptr;
750 }
751 
ceilDiv(uint64_t v) const752 AffineExpr AffineExpr::ceilDiv(uint64_t v) const {
753   return ceilDiv(getAffineConstantExpr(v, getContext()));
754 }
ceilDiv(AffineExpr other) const755 AffineExpr AffineExpr::ceilDiv(AffineExpr other) const {
756   if (auto simplified = simplifyCeilDiv(*this, other))
757     return simplified;
758 
759   StorageUniquer &uniquer = getContext()->getAffineUniquer();
760   return uniquer.get<AffineBinaryOpExprStorage>(
761       /*initFn=*/{}, static_cast<unsigned>(AffineExprKind::CeilDiv), *this,
762       other);
763 }
764 
simplifyMod(AffineExpr lhs,AffineExpr rhs)765 static AffineExpr simplifyMod(AffineExpr lhs, AffineExpr rhs) {
766   auto lhsConst = lhs.dyn_cast<AffineConstantExpr>();
767   auto rhsConst = rhs.dyn_cast<AffineConstantExpr>();
768 
769   // mod w.r.t zero or negative numbers is undefined and preserved as is.
770   if (!rhsConst || rhsConst.getValue() < 1)
771     return nullptr;
772 
773   if (lhsConst)
774     return getAffineConstantExpr(mod(lhsConst.getValue(), rhsConst.getValue()),
775                                  lhs.getContext());
776 
777   // Fold modulo of an expression that is known to be a multiple of a constant
778   // to zero if that constant is a multiple of the modulo factor. Eg: (i * 128)
779   // mod 64 is folded to 0, and less trivially, (i*(j*4*(k*32))) mod 128 = 0.
780   if (lhs.getLargestKnownDivisor() % rhsConst.getValue() == 0)
781     return getAffineConstantExpr(0, lhs.getContext());
782 
783   // Simplify (expr1 + expr2) mod divConst when either expr1 or expr2 is
784   // known to be a multiple of divConst.
785   auto lBin = lhs.dyn_cast<AffineBinaryOpExpr>();
786   if (lBin && lBin.getKind() == AffineExprKind::Add) {
787     int64_t llhsDiv = lBin.getLHS().getLargestKnownDivisor();
788     int64_t lrhsDiv = lBin.getRHS().getLargestKnownDivisor();
789     // rhsConst is known to be a positive constant.
790     if (llhsDiv % rhsConst.getValue() == 0)
791       return lBin.getRHS() % rhsConst.getValue();
792     if (lrhsDiv % rhsConst.getValue() == 0)
793       return lBin.getLHS() % rhsConst.getValue();
794   }
795 
796   return nullptr;
797 }
798 
operator %(uint64_t v) const799 AffineExpr AffineExpr::operator%(uint64_t v) const {
800   return *this % getAffineConstantExpr(v, getContext());
801 }
operator %(AffineExpr other) const802 AffineExpr AffineExpr::operator%(AffineExpr other) const {
803   if (auto simplified = simplifyMod(*this, other))
804     return simplified;
805 
806   StorageUniquer &uniquer = getContext()->getAffineUniquer();
807   return uniquer.get<AffineBinaryOpExprStorage>(
808       /*initFn=*/{}, static_cast<unsigned>(AffineExprKind::Mod), *this, other);
809 }
810 
compose(AffineMap map) const811 AffineExpr AffineExpr::compose(AffineMap map) const {
812   SmallVector<AffineExpr, 8> dimReplacements(map.getResults().begin(),
813                                              map.getResults().end());
814   return replaceDimsAndSymbols(dimReplacements, {});
815 }
operator <<(raw_ostream & os,AffineExpr expr)816 raw_ostream &mlir::operator<<(raw_ostream &os, AffineExpr expr) {
817   expr.print(os);
818   return os;
819 }
820 
821 /// Constructs an affine expression from a flat ArrayRef. If there are local
822 /// identifiers (neither dimensional nor symbolic) that appear in the sum of
823 /// products expression, `localExprs` is expected to have the AffineExpr
824 /// for it, and is substituted into. The ArrayRef `flatExprs` is expected to be
825 /// in the format [dims, symbols, locals, constant term].
getAffineExprFromFlatForm(ArrayRef<int64_t> flatExprs,unsigned numDims,unsigned numSymbols,ArrayRef<AffineExpr> localExprs,MLIRContext * context)826 AffineExpr mlir::getAffineExprFromFlatForm(ArrayRef<int64_t> flatExprs,
827                                            unsigned numDims,
828                                            unsigned numSymbols,
829                                            ArrayRef<AffineExpr> localExprs,
830                                            MLIRContext *context) {
831   // Assert expected numLocals = flatExprs.size() - numDims - numSymbols - 1.
832   assert(flatExprs.size() - numDims - numSymbols - 1 == localExprs.size() &&
833          "unexpected number of local expressions");
834 
835   auto expr = getAffineConstantExpr(0, context);
836   // Dimensions and symbols.
837   for (unsigned j = 0; j < numDims + numSymbols; j++) {
838     if (flatExprs[j] == 0)
839       continue;
840     auto id = j < numDims ? getAffineDimExpr(j, context)
841                           : getAffineSymbolExpr(j - numDims, context);
842     expr = expr + id * flatExprs[j];
843   }
844 
845   // Local identifiers.
846   for (unsigned j = numDims + numSymbols, e = flatExprs.size() - 1; j < e;
847        j++) {
848     if (flatExprs[j] == 0)
849       continue;
850     auto term = localExprs[j - numDims - numSymbols] * flatExprs[j];
851     expr = expr + term;
852   }
853 
854   // Constant term.
855   int64_t constTerm = flatExprs[flatExprs.size() - 1];
856   if (constTerm != 0)
857     expr = expr + constTerm;
858   return expr;
859 }
860 
SimpleAffineExprFlattener(unsigned numDims,unsigned numSymbols)861 SimpleAffineExprFlattener::SimpleAffineExprFlattener(unsigned numDims,
862                                                      unsigned numSymbols)
863     : numDims(numDims), numSymbols(numSymbols), numLocals(0) {
864   operandExprStack.reserve(8);
865 }
866 
visitMulExpr(AffineBinaryOpExpr expr)867 void SimpleAffineExprFlattener::visitMulExpr(AffineBinaryOpExpr expr) {
868   assert(operandExprStack.size() >= 2);
869   // This is a pure affine expr; the RHS will be a constant.
870   assert(expr.getRHS().isa<AffineConstantExpr>());
871   // Get the RHS constant.
872   auto rhsConst = operandExprStack.back()[getConstantIndex()];
873   operandExprStack.pop_back();
874   // Update the LHS in place instead of pop and push.
875   auto &lhs = operandExprStack.back();
876   for (unsigned i = 0, e = lhs.size(); i < e; i++) {
877     lhs[i] *= rhsConst;
878   }
879 }
880 
visitAddExpr(AffineBinaryOpExpr expr)881 void SimpleAffineExprFlattener::visitAddExpr(AffineBinaryOpExpr expr) {
882   assert(operandExprStack.size() >= 2);
883   const auto &rhs = operandExprStack.back();
884   auto &lhs = operandExprStack[operandExprStack.size() - 2];
885   assert(lhs.size() == rhs.size());
886   // Update the LHS in place.
887   for (unsigned i = 0, e = rhs.size(); i < e; i++) {
888     lhs[i] += rhs[i];
889   }
890   // Pop off the RHS.
891   operandExprStack.pop_back();
892 }
893 
894 //
895 // t = expr mod c   <=>  t = expr - c*q and c*q <= expr <= c*q + c - 1
896 //
897 // A mod expression "expr mod c" is thus flattened by introducing a new local
898 // variable q (= expr floordiv c), such that expr mod c is replaced with
899 // 'expr - c * q' and c * q <= expr <= c * q + c - 1 are added to localVarCst.
visitModExpr(AffineBinaryOpExpr expr)900 void SimpleAffineExprFlattener::visitModExpr(AffineBinaryOpExpr expr) {
901   assert(operandExprStack.size() >= 2);
902   // This is a pure affine expr; the RHS will be a constant.
903   assert(expr.getRHS().isa<AffineConstantExpr>());
904   auto rhsConst = operandExprStack.back()[getConstantIndex()];
905   operandExprStack.pop_back();
906   auto &lhs = operandExprStack.back();
907   // TODO: handle modulo by zero case when this issue is fixed
908   // at the other places in the IR.
909   assert(rhsConst > 0 && "RHS constant has to be positive");
910 
911   // Check if the LHS expression is a multiple of modulo factor.
912   unsigned i, e;
913   for (i = 0, e = lhs.size(); i < e; i++)
914     if (lhs[i] % rhsConst != 0)
915       break;
916   // If yes, modulo expression here simplifies to zero.
917   if (i == lhs.size()) {
918     std::fill(lhs.begin(), lhs.end(), 0);
919     return;
920   }
921 
922   // Add a local variable for the quotient, i.e., expr % c is replaced by
923   // (expr - q * c) where q = expr floordiv c. Do this while canceling out
924   // the GCD of expr and c.
925   SmallVector<int64_t, 8> floorDividend(lhs);
926   uint64_t gcd = rhsConst;
927   for (unsigned i = 0, e = lhs.size(); i < e; i++)
928     gcd = llvm::GreatestCommonDivisor64(gcd, std::abs(lhs[i]));
929   // Simplify the numerator and the denominator.
930   if (gcd != 1) {
931     for (unsigned i = 0, e = floorDividend.size(); i < e; i++)
932       floorDividend[i] = floorDividend[i] / static_cast<int64_t>(gcd);
933   }
934   int64_t floorDivisor = rhsConst / static_cast<int64_t>(gcd);
935 
936   // Construct the AffineExpr form of the floordiv to store in localExprs.
937   MLIRContext *context = expr.getContext();
938   auto dividendExpr = getAffineExprFromFlatForm(
939       floorDividend, numDims, numSymbols, localExprs, context);
940   auto divisorExpr = getAffineConstantExpr(floorDivisor, context);
941   auto floorDivExpr = dividendExpr.floorDiv(divisorExpr);
942   int loc;
943   if ((loc = findLocalId(floorDivExpr)) == -1) {
944     addLocalFloorDivId(floorDividend, floorDivisor, floorDivExpr);
945     // Set result at top of stack to "lhs - rhsConst * q".
946     lhs[getLocalVarStartIndex() + numLocals - 1] = -rhsConst;
947   } else {
948     // Reuse the existing local id.
949     lhs[getLocalVarStartIndex() + loc] = -rhsConst;
950   }
951 }
952 
visitCeilDivExpr(AffineBinaryOpExpr expr)953 void SimpleAffineExprFlattener::visitCeilDivExpr(AffineBinaryOpExpr expr) {
954   visitDivExpr(expr, /*isCeil=*/true);
955 }
visitFloorDivExpr(AffineBinaryOpExpr expr)956 void SimpleAffineExprFlattener::visitFloorDivExpr(AffineBinaryOpExpr expr) {
957   visitDivExpr(expr, /*isCeil=*/false);
958 }
959 
visitDimExpr(AffineDimExpr expr)960 void SimpleAffineExprFlattener::visitDimExpr(AffineDimExpr expr) {
961   operandExprStack.emplace_back(SmallVector<int64_t, 32>(getNumCols(), 0));
962   auto &eq = operandExprStack.back();
963   assert(expr.getPosition() < numDims && "Inconsistent number of dims");
964   eq[getDimStartIndex() + expr.getPosition()] = 1;
965 }
966 
visitSymbolExpr(AffineSymbolExpr expr)967 void SimpleAffineExprFlattener::visitSymbolExpr(AffineSymbolExpr expr) {
968   operandExprStack.emplace_back(SmallVector<int64_t, 32>(getNumCols(), 0));
969   auto &eq = operandExprStack.back();
970   assert(expr.getPosition() < numSymbols && "inconsistent number of symbols");
971   eq[getSymbolStartIndex() + expr.getPosition()] = 1;
972 }
973 
visitConstantExpr(AffineConstantExpr expr)974 void SimpleAffineExprFlattener::visitConstantExpr(AffineConstantExpr expr) {
975   operandExprStack.emplace_back(SmallVector<int64_t, 32>(getNumCols(), 0));
976   auto &eq = operandExprStack.back();
977   eq[getConstantIndex()] = expr.getValue();
978 }
979 
980 // t = expr floordiv c   <=> t = q, c * q <= expr <= c * q + c - 1
981 // A floordiv is thus flattened by introducing a new local variable q, and
982 // replacing that expression with 'q' while adding the constraints
983 // c * q <= expr <= c * q + c - 1 to localVarCst (done by
984 // FlatAffineConstraints::addLocalFloorDiv).
985 //
986 // A ceildiv is similarly flattened:
987 // t = expr ceildiv c   <=> t =  (expr + c - 1) floordiv c
visitDivExpr(AffineBinaryOpExpr expr,bool isCeil)988 void SimpleAffineExprFlattener::visitDivExpr(AffineBinaryOpExpr expr,
989                                              bool isCeil) {
990   assert(operandExprStack.size() >= 2);
991   assert(expr.getRHS().isa<AffineConstantExpr>());
992 
993   // This is a pure affine expr; the RHS is a positive constant.
994   int64_t rhsConst = operandExprStack.back()[getConstantIndex()];
995   // TODO: handle division by zero at the same time the issue is
996   // fixed at other places.
997   assert(rhsConst > 0 && "RHS constant has to be positive");
998   operandExprStack.pop_back();
999   auto &lhs = operandExprStack.back();
1000 
1001   // Simplify the floordiv, ceildiv if possible by canceling out the greatest
1002   // common divisors of the numerator and denominator.
1003   uint64_t gcd = std::abs(rhsConst);
1004   for (unsigned i = 0, e = lhs.size(); i < e; i++)
1005     gcd = llvm::GreatestCommonDivisor64(gcd, std::abs(lhs[i]));
1006   // Simplify the numerator and the denominator.
1007   if (gcd != 1) {
1008     for (unsigned i = 0, e = lhs.size(); i < e; i++)
1009       lhs[i] = lhs[i] / static_cast<int64_t>(gcd);
1010   }
1011   int64_t divisor = rhsConst / static_cast<int64_t>(gcd);
1012   // If the divisor becomes 1, the updated LHS is the result. (The
1013   // divisor can't be negative since rhsConst is positive).
1014   if (divisor == 1)
1015     return;
1016 
1017   // If the divisor cannot be simplified to one, we will have to retain
1018   // the ceil/floor expr (simplified up until here). Add an existential
1019   // quantifier to express its result, i.e., expr1 div expr2 is replaced
1020   // by a new identifier, q.
1021   MLIRContext *context = expr.getContext();
1022   auto a =
1023       getAffineExprFromFlatForm(lhs, numDims, numSymbols, localExprs, context);
1024   auto b = getAffineConstantExpr(divisor, context);
1025 
1026   int loc;
1027   auto divExpr = isCeil ? a.ceilDiv(b) : a.floorDiv(b);
1028   if ((loc = findLocalId(divExpr)) == -1) {
1029     if (!isCeil) {
1030       SmallVector<int64_t, 8> dividend(lhs);
1031       addLocalFloorDivId(dividend, divisor, divExpr);
1032     } else {
1033       // lhs ceildiv c <=>  (lhs + c - 1) floordiv c
1034       SmallVector<int64_t, 8> dividend(lhs);
1035       dividend.back() += divisor - 1;
1036       addLocalFloorDivId(dividend, divisor, divExpr);
1037     }
1038   }
1039   // Set the expression on stack to the local var introduced to capture the
1040   // result of the division (floor or ceil).
1041   std::fill(lhs.begin(), lhs.end(), 0);
1042   if (loc == -1)
1043     lhs[getLocalVarStartIndex() + numLocals - 1] = 1;
1044   else
1045     lhs[getLocalVarStartIndex() + loc] = 1;
1046 }
1047 
1048 // Add a local identifier (needed to flatten a mod, floordiv, ceildiv expr).
1049 // The local identifier added is always a floordiv of a pure add/mul affine
1050 // function of other identifiers, coefficients of which are specified in
1051 // dividend and with respect to a positive constant divisor. localExpr is the
1052 // simplified tree expression (AffineExpr) corresponding to the quantifier.
addLocalFloorDivId(ArrayRef<int64_t> dividend,int64_t divisor,AffineExpr localExpr)1053 void SimpleAffineExprFlattener::addLocalFloorDivId(ArrayRef<int64_t> dividend,
1054                                                    int64_t divisor,
1055                                                    AffineExpr localExpr) {
1056   assert(divisor > 0 && "positive constant divisor expected");
1057   for (auto &subExpr : operandExprStack)
1058     subExpr.insert(subExpr.begin() + getLocalVarStartIndex() + numLocals, 0);
1059   localExprs.push_back(localExpr);
1060   numLocals++;
1061   // dividend and divisor are not used here; an override of this method uses it.
1062 }
1063 
findLocalId(AffineExpr localExpr)1064 int SimpleAffineExprFlattener::findLocalId(AffineExpr localExpr) {
1065   SmallVectorImpl<AffineExpr>::iterator it;
1066   if ((it = llvm::find(localExprs, localExpr)) == localExprs.end())
1067     return -1;
1068   return it - localExprs.begin();
1069 }
1070 
1071 /// Simplify the affine expression by flattening it and reconstructing it.
simplifyAffineExpr(AffineExpr expr,unsigned numDims,unsigned numSymbols)1072 AffineExpr mlir::simplifyAffineExpr(AffineExpr expr, unsigned numDims,
1073                                     unsigned numSymbols) {
1074   // Simplify semi-affine expressions separately.
1075   if (!expr.isPureAffine())
1076     expr = simplifySemiAffine(expr);
1077   if (!expr.isPureAffine())
1078     return expr;
1079 
1080   SimpleAffineExprFlattener flattener(numDims, numSymbols);
1081   flattener.walkPostOrder(expr);
1082   ArrayRef<int64_t> flattenedExpr = flattener.operandExprStack.back();
1083   auto simplifiedExpr =
1084       getAffineExprFromFlatForm(flattenedExpr, numDims, numSymbols,
1085                                 flattener.localExprs, expr.getContext());
1086   flattener.operandExprStack.pop_back();
1087   assert(flattener.operandExprStack.empty());
1088 
1089   return simplifiedExpr;
1090 }
1091