1 //===- AffineMap.cpp - MLIR Affine Map Classes ----------------------------===//
2 //
3 // Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions.
4 // See https://llvm.org/LICENSE.txt for license information.
5 // SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception
6 //
7 //===----------------------------------------------------------------------===//
8 
9 #include "mlir/IR/AffineMap.h"
10 #include "AffineMapDetail.h"
11 #include "mlir/IR/BuiltinAttributes.h"
12 #include "mlir/IR/BuiltinTypes.h"
13 #include "mlir/Support/LogicalResult.h"
14 #include "mlir/Support/MathExtras.h"
15 #include "llvm/ADT/SmallSet.h"
16 #include "llvm/ADT/StringRef.h"
17 #include "llvm/Support/raw_ostream.h"
18 
19 using namespace mlir;
20 
21 namespace {
22 
23 // AffineExprConstantFolder evaluates an affine expression using constant
24 // operands passed in 'operandConsts'. Returns an IntegerAttr attribute
25 // representing the constant value of the affine expression evaluated on
26 // constant 'operandConsts', or nullptr if it can't be folded.
27 class AffineExprConstantFolder {
28 public:
AffineExprConstantFolder(unsigned numDims,ArrayRef<Attribute> operandConsts)29   AffineExprConstantFolder(unsigned numDims, ArrayRef<Attribute> operandConsts)
30       : numDims(numDims), operandConsts(operandConsts) {}
31 
32   /// Attempt to constant fold the specified affine expr, or return null on
33   /// failure.
constantFold(AffineExpr expr)34   IntegerAttr constantFold(AffineExpr expr) {
35     if (auto result = constantFoldImpl(expr))
36       return IntegerAttr::get(IndexType::get(expr.getContext()), *result);
37     return nullptr;
38   }
39 
40 private:
constantFoldImpl(AffineExpr expr)41   Optional<int64_t> constantFoldImpl(AffineExpr expr) {
42     switch (expr.getKind()) {
43     case AffineExprKind::Add:
44       return constantFoldBinExpr(
45           expr, [](int64_t lhs, int64_t rhs) { return lhs + rhs; });
46     case AffineExprKind::Mul:
47       return constantFoldBinExpr(
48           expr, [](int64_t lhs, int64_t rhs) { return lhs * rhs; });
49     case AffineExprKind::Mod:
50       return constantFoldBinExpr(
51           expr, [](int64_t lhs, int64_t rhs) { return mod(lhs, rhs); });
52     case AffineExprKind::FloorDiv:
53       return constantFoldBinExpr(
54           expr, [](int64_t lhs, int64_t rhs) { return floorDiv(lhs, rhs); });
55     case AffineExprKind::CeilDiv:
56       return constantFoldBinExpr(
57           expr, [](int64_t lhs, int64_t rhs) { return ceilDiv(lhs, rhs); });
58     case AffineExprKind::Constant:
59       return expr.cast<AffineConstantExpr>().getValue();
60     case AffineExprKind::DimId:
61       if (auto attr = operandConsts[expr.cast<AffineDimExpr>().getPosition()]
62                           .dyn_cast_or_null<IntegerAttr>())
63         return attr.getInt();
64       return llvm::None;
65     case AffineExprKind::SymbolId:
66       if (auto attr = operandConsts[numDims +
67                                     expr.cast<AffineSymbolExpr>().getPosition()]
68                           .dyn_cast_or_null<IntegerAttr>())
69         return attr.getInt();
70       return llvm::None;
71     }
72     llvm_unreachable("Unknown AffineExpr");
73   }
74 
75   // TODO: Change these to operate on APInts too.
constantFoldBinExpr(AffineExpr expr,int64_t (* op)(int64_t,int64_t))76   Optional<int64_t> constantFoldBinExpr(AffineExpr expr,
77                                         int64_t (*op)(int64_t, int64_t)) {
78     auto binOpExpr = expr.cast<AffineBinaryOpExpr>();
79     if (auto lhs = constantFoldImpl(binOpExpr.getLHS()))
80       if (auto rhs = constantFoldImpl(binOpExpr.getRHS()))
81         return op(*lhs, *rhs);
82     return llvm::None;
83   }
84 
85   // The number of dimension operands in AffineMap containing this expression.
86   unsigned numDims;
87   // The constant valued operands used to evaluate this AffineExpr.
88   ArrayRef<Attribute> operandConsts;
89 };
90 
91 } // end anonymous namespace
92 
93 /// Returns a single constant result affine map.
getConstantMap(int64_t val,MLIRContext * context)94 AffineMap AffineMap::getConstantMap(int64_t val, MLIRContext *context) {
95   return get(/*dimCount=*/0, /*symbolCount=*/0,
96              {getAffineConstantExpr(val, context)});
97 }
98 
99 /// Returns an identity affine map (d0, ..., dn) -> (dp, ..., dn) on the most
100 /// minor dimensions.
getMinorIdentityMap(unsigned dims,unsigned results,MLIRContext * context)101 AffineMap AffineMap::getMinorIdentityMap(unsigned dims, unsigned results,
102                                          MLIRContext *context) {
103   assert(dims >= results && "Dimension mismatch");
104   auto id = AffineMap::getMultiDimIdentityMap(dims, context);
105   return AffineMap::get(dims, 0, id.getResults().take_back(results), context);
106 }
107 
isMinorIdentity() const108 bool AffineMap::isMinorIdentity() const {
109   return *this ==
110          getMinorIdentityMap(getNumDims(), getNumResults(), getContext());
111 }
112 
113 /// Returns an AffineMap representing a permutation.
getPermutationMap(ArrayRef<unsigned> permutation,MLIRContext * context)114 AffineMap AffineMap::getPermutationMap(ArrayRef<unsigned> permutation,
115                                        MLIRContext *context) {
116   assert(!permutation.empty() &&
117          "Cannot create permutation map from empty permutation vector");
118   SmallVector<AffineExpr, 4> affExprs;
119   for (auto index : permutation)
120     affExprs.push_back(getAffineDimExpr(index, context));
121   auto m = std::max_element(permutation.begin(), permutation.end());
122   auto permutationMap = AffineMap::get(*m + 1, 0, affExprs, context);
123   assert(permutationMap.isPermutation() && "Invalid permutation vector");
124   return permutationMap;
125 }
126 
127 template <typename AffineExprContainer>
getMaxDimAndSymbol(ArrayRef<AffineExprContainer> exprsList,int64_t & maxDim,int64_t & maxSym)128 static void getMaxDimAndSymbol(ArrayRef<AffineExprContainer> exprsList,
129                                int64_t &maxDim, int64_t &maxSym) {
130   for (const auto &exprs : exprsList) {
131     for (auto expr : exprs) {
132       expr.walk([&maxDim, &maxSym](AffineExpr e) {
133         if (auto d = e.dyn_cast<AffineDimExpr>())
134           maxDim = std::max(maxDim, static_cast<int64_t>(d.getPosition()));
135         if (auto s = e.dyn_cast<AffineSymbolExpr>())
136           maxSym = std::max(maxSym, static_cast<int64_t>(s.getPosition()));
137       });
138     }
139   }
140 }
141 
142 template <typename AffineExprContainer>
143 static SmallVector<AffineMap, 4>
inferFromExprList(ArrayRef<AffineExprContainer> exprsList)144 inferFromExprList(ArrayRef<AffineExprContainer> exprsList) {
145   assert(!exprsList.empty());
146   assert(!exprsList[0].empty());
147   auto context = exprsList[0][0].getContext();
148   int64_t maxDim = -1, maxSym = -1;
149   getMaxDimAndSymbol(exprsList, maxDim, maxSym);
150   SmallVector<AffineMap, 4> maps;
151   maps.reserve(exprsList.size());
152   for (const auto &exprs : exprsList)
153     maps.push_back(AffineMap::get(/*dimCount=*/maxDim + 1,
154                                   /*symbolCount=*/maxSym + 1, exprs, context));
155   return maps;
156 }
157 
158 SmallVector<AffineMap, 4>
inferFromExprList(ArrayRef<ArrayRef<AffineExpr>> exprsList)159 AffineMap::inferFromExprList(ArrayRef<ArrayRef<AffineExpr>> exprsList) {
160   return ::inferFromExprList(exprsList);
161 }
162 
163 SmallVector<AffineMap, 4>
inferFromExprList(ArrayRef<SmallVector<AffineExpr,4>> exprsList)164 AffineMap::inferFromExprList(ArrayRef<SmallVector<AffineExpr, 4>> exprsList) {
165   return ::inferFromExprList(exprsList);
166 }
167 
getMultiDimIdentityMap(unsigned numDims,MLIRContext * context)168 AffineMap AffineMap::getMultiDimIdentityMap(unsigned numDims,
169                                             MLIRContext *context) {
170   SmallVector<AffineExpr, 4> dimExprs;
171   dimExprs.reserve(numDims);
172   for (unsigned i = 0; i < numDims; ++i)
173     dimExprs.push_back(mlir::getAffineDimExpr(i, context));
174   return get(/*dimCount=*/numDims, /*symbolCount=*/0, dimExprs, context);
175 }
176 
getContext() const177 MLIRContext *AffineMap::getContext() const { return map->context; }
178 
isIdentity() const179 bool AffineMap::isIdentity() const {
180   if (getNumDims() != getNumResults())
181     return false;
182   ArrayRef<AffineExpr> results = getResults();
183   for (unsigned i = 0, numDims = getNumDims(); i < numDims; ++i) {
184     auto expr = results[i].dyn_cast<AffineDimExpr>();
185     if (!expr || expr.getPosition() != i)
186       return false;
187   }
188   return true;
189 }
190 
isEmpty() const191 bool AffineMap::isEmpty() const {
192   return getNumDims() == 0 && getNumSymbols() == 0 && getNumResults() == 0;
193 }
194 
isSingleConstant() const195 bool AffineMap::isSingleConstant() const {
196   return getNumResults() == 1 && getResult(0).isa<AffineConstantExpr>();
197 }
198 
getSingleConstantResult() const199 int64_t AffineMap::getSingleConstantResult() const {
200   assert(isSingleConstant() && "map must have a single constant result");
201   return getResult(0).cast<AffineConstantExpr>().getValue();
202 }
203 
getNumDims() const204 unsigned AffineMap::getNumDims() const {
205   assert(map && "uninitialized map storage");
206   return map->numDims;
207 }
getNumSymbols() const208 unsigned AffineMap::getNumSymbols() const {
209   assert(map && "uninitialized map storage");
210   return map->numSymbols;
211 }
getNumResults() const212 unsigned AffineMap::getNumResults() const {
213   assert(map && "uninitialized map storage");
214   return map->results.size();
215 }
getNumInputs() const216 unsigned AffineMap::getNumInputs() const {
217   assert(map && "uninitialized map storage");
218   return map->numDims + map->numSymbols;
219 }
220 
getResults() const221 ArrayRef<AffineExpr> AffineMap::getResults() const {
222   assert(map && "uninitialized map storage");
223   return map->results;
224 }
getResult(unsigned idx) const225 AffineExpr AffineMap::getResult(unsigned idx) const {
226   assert(map && "uninitialized map storage");
227   return map->results[idx];
228 }
229 
getDimPosition(unsigned idx) const230 unsigned AffineMap::getDimPosition(unsigned idx) const {
231   return getResult(idx).cast<AffineDimExpr>().getPosition();
232 }
233 
234 /// Folds the results of the application of an affine map on the provided
235 /// operands to a constant if possible. Returns false if the folding happens,
236 /// true otherwise.
237 LogicalResult
constantFold(ArrayRef<Attribute> operandConstants,SmallVectorImpl<Attribute> & results) const238 AffineMap::constantFold(ArrayRef<Attribute> operandConstants,
239                         SmallVectorImpl<Attribute> &results) const {
240   // Attempt partial folding.
241   SmallVector<int64_t, 2> integers;
242   partialConstantFold(operandConstants, &integers);
243 
244   // If all expressions folded to a constant, populate results with attributes
245   // containing those constants.
246   if (integers.empty())
247     return failure();
248 
249   auto range = llvm::map_range(integers, [this](int64_t i) {
250     return IntegerAttr::get(IndexType::get(getContext()), i);
251   });
252   results.append(range.begin(), range.end());
253   return success();
254 }
255 
256 AffineMap
partialConstantFold(ArrayRef<Attribute> operandConstants,SmallVectorImpl<int64_t> * results) const257 AffineMap::partialConstantFold(ArrayRef<Attribute> operandConstants,
258                                SmallVectorImpl<int64_t> *results) const {
259   assert(getNumInputs() == operandConstants.size());
260 
261   // Fold each of the result expressions.
262   AffineExprConstantFolder exprFolder(getNumDims(), operandConstants);
263   SmallVector<AffineExpr, 4> exprs;
264   exprs.reserve(getNumResults());
265 
266   for (auto expr : getResults()) {
267     auto folded = exprFolder.constantFold(expr);
268     // If did not fold to a constant, keep the original expression, and clear
269     // the integer results vector.
270     if (folded) {
271       exprs.push_back(
272           getAffineConstantExpr(folded.getInt(), folded.getContext()));
273       if (results)
274         results->push_back(folded.getInt());
275     } else {
276       exprs.push_back(expr);
277       if (results) {
278         results->clear();
279         results = nullptr;
280       }
281     }
282   }
283 
284   return get(getNumDims(), getNumSymbols(), exprs, getContext());
285 }
286 
287 /// Walk all of the AffineExpr's in this mapping. Each node in an expression
288 /// tree is visited in postorder.
walkExprs(std::function<void (AffineExpr)> callback) const289 void AffineMap::walkExprs(std::function<void(AffineExpr)> callback) const {
290   for (auto expr : getResults())
291     expr.walk(callback);
292 }
293 
294 /// This method substitutes any uses of dimensions and symbols (e.g.
295 /// dim#0 with dimReplacements[0]) in subexpressions and returns the modified
296 /// expression mapping.  Because this can be used to eliminate dims and
297 /// symbols, the client needs to specify the number of dims and symbols in
298 /// the result.  The returned map always has the same number of results.
replaceDimsAndSymbols(ArrayRef<AffineExpr> dimReplacements,ArrayRef<AffineExpr> symReplacements,unsigned numResultDims,unsigned numResultSyms) const299 AffineMap AffineMap::replaceDimsAndSymbols(ArrayRef<AffineExpr> dimReplacements,
300                                            ArrayRef<AffineExpr> symReplacements,
301                                            unsigned numResultDims,
302                                            unsigned numResultSyms) const {
303   SmallVector<AffineExpr, 8> results;
304   results.reserve(getNumResults());
305   for (auto expr : getResults())
306     results.push_back(
307         expr.replaceDimsAndSymbols(dimReplacements, symReplacements));
308 
309   return get(numResultDims, numResultSyms, results, getContext());
310 }
311 
compose(AffineMap map)312 AffineMap AffineMap::compose(AffineMap map) {
313   assert(getNumDims() == map.getNumResults() && "Number of results mismatch");
314   // Prepare `map` by concatenating the symbols and rewriting its exprs.
315   unsigned numDims = map.getNumDims();
316   unsigned numSymbolsThisMap = getNumSymbols();
317   unsigned numSymbols = numSymbolsThisMap + map.getNumSymbols();
318   SmallVector<AffineExpr, 8> newDims(numDims);
319   for (unsigned idx = 0; idx < numDims; ++idx) {
320     newDims[idx] = getAffineDimExpr(idx, getContext());
321   }
322   SmallVector<AffineExpr, 8> newSymbols(numSymbols);
323   for (unsigned idx = numSymbolsThisMap; idx < numSymbols; ++idx) {
324     newSymbols[idx - numSymbolsThisMap] =
325         getAffineSymbolExpr(idx, getContext());
326   }
327   auto newMap =
328       map.replaceDimsAndSymbols(newDims, newSymbols, numDims, numSymbols);
329   SmallVector<AffineExpr, 8> exprs;
330   exprs.reserve(getResults().size());
331   for (auto expr : getResults())
332     exprs.push_back(expr.compose(newMap));
333   return AffineMap::get(numDims, numSymbols, exprs, map.getContext());
334 }
335 
compose(ArrayRef<int64_t> values)336 SmallVector<int64_t, 4> AffineMap::compose(ArrayRef<int64_t> values) {
337   assert(getNumSymbols() == 0 && "Expected symbol-less map");
338   SmallVector<AffineExpr, 4> exprs;
339   exprs.reserve(values.size());
340   MLIRContext *ctx = getContext();
341   for (auto v : values)
342     exprs.push_back(getAffineConstantExpr(v, ctx));
343   auto resMap = compose(AffineMap::get(0, 0, exprs, ctx));
344   SmallVector<int64_t, 4> res;
345   res.reserve(resMap.getNumResults());
346   for (auto e : resMap.getResults())
347     res.push_back(e.cast<AffineConstantExpr>().getValue());
348   return res;
349 }
350 
isProjectedPermutation()351 bool AffineMap::isProjectedPermutation() {
352   if (getNumSymbols() > 0)
353     return false;
354   SmallVector<bool, 8> seen(getNumInputs(), false);
355   for (auto expr : getResults()) {
356     if (auto dim = expr.dyn_cast<AffineDimExpr>()) {
357       if (seen[dim.getPosition()])
358         return false;
359       seen[dim.getPosition()] = true;
360       continue;
361     }
362     return false;
363   }
364   return true;
365 }
366 
isPermutation()367 bool AffineMap::isPermutation() {
368   if (getNumDims() != getNumResults())
369     return false;
370   return isProjectedPermutation();
371 }
372 
getSubMap(ArrayRef<unsigned> resultPos)373 AffineMap AffineMap::getSubMap(ArrayRef<unsigned> resultPos) {
374   SmallVector<AffineExpr, 4> exprs;
375   exprs.reserve(resultPos.size());
376   for (auto idx : resultPos)
377     exprs.push_back(getResult(idx));
378   return AffineMap::get(getNumDims(), getNumSymbols(), exprs, getContext());
379 }
380 
getMajorSubMap(unsigned numResults)381 AffineMap AffineMap::getMajorSubMap(unsigned numResults) {
382   if (numResults == 0)
383     return AffineMap();
384   if (numResults > getNumResults())
385     return *this;
386   return getSubMap(llvm::to_vector<4>(llvm::seq<unsigned>(0, numResults)));
387 }
388 
getMinorSubMap(unsigned numResults)389 AffineMap AffineMap::getMinorSubMap(unsigned numResults) {
390   if (numResults == 0)
391     return AffineMap();
392   if (numResults > getNumResults())
393     return *this;
394   return getSubMap(llvm::to_vector<4>(
395       llvm::seq<unsigned>(getNumResults() - numResults, getNumResults())));
396 }
397 
simplifyAffineMap(AffineMap map)398 AffineMap mlir::simplifyAffineMap(AffineMap map) {
399   SmallVector<AffineExpr, 8> exprs;
400   for (auto e : map.getResults()) {
401     exprs.push_back(
402         simplifyAffineExpr(e, map.getNumDims(), map.getNumSymbols()));
403   }
404   return AffineMap::get(map.getNumDims(), map.getNumSymbols(), exprs,
405                         map.getContext());
406 }
407 
removeDuplicateExprs(AffineMap map)408 AffineMap mlir::removeDuplicateExprs(AffineMap map) {
409   auto results = map.getResults();
410   SmallVector<AffineExpr, 4> uniqueExprs(results.begin(), results.end());
411   uniqueExprs.erase(std::unique(uniqueExprs.begin(), uniqueExprs.end()),
412                     uniqueExprs.end());
413   return AffineMap::get(map.getNumDims(), map.getNumSymbols(), uniqueExprs,
414                         map.getContext());
415 }
416 
inversePermutation(AffineMap map)417 AffineMap mlir::inversePermutation(AffineMap map) {
418   if (map.isEmpty())
419     return map;
420   assert(map.getNumSymbols() == 0 && "expected map without symbols");
421   SmallVector<AffineExpr, 4> exprs(map.getNumDims());
422   for (auto en : llvm::enumerate(map.getResults())) {
423     auto expr = en.value();
424     // Skip non-permutations.
425     if (auto d = expr.dyn_cast<AffineDimExpr>()) {
426       if (exprs[d.getPosition()])
427         continue;
428       exprs[d.getPosition()] = getAffineDimExpr(en.index(), d.getContext());
429     }
430   }
431   SmallVector<AffineExpr, 4> seenExprs;
432   seenExprs.reserve(map.getNumDims());
433   for (auto expr : exprs)
434     if (expr)
435       seenExprs.push_back(expr);
436   if (seenExprs.size() != map.getNumInputs())
437     return AffineMap();
438   return AffineMap::get(map.getNumResults(), 0, seenExprs, map.getContext());
439 }
440 
concatAffineMaps(ArrayRef<AffineMap> maps)441 AffineMap mlir::concatAffineMaps(ArrayRef<AffineMap> maps) {
442   unsigned numResults = 0, numDims = 0, numSymbols = 0;
443   for (auto m : maps)
444     numResults += m.getNumResults();
445   SmallVector<AffineExpr, 8> results;
446   results.reserve(numResults);
447   for (auto m : maps) {
448     for (auto res : m.getResults())
449       results.push_back(res.shiftSymbols(m.getNumSymbols(), numSymbols));
450 
451     numSymbols += m.getNumSymbols();
452     numDims = std::max(m.getNumDims(), numDims);
453   }
454   return AffineMap::get(numDims, numSymbols, results,
455                         maps.front().getContext());
456 }
457 
getProjectedMap(AffineMap map,ArrayRef<unsigned> projectedDimensions)458 AffineMap mlir::getProjectedMap(AffineMap map,
459                                 ArrayRef<unsigned> projectedDimensions) {
460   DenseSet<unsigned> projectedDims(projectedDimensions.begin(),
461                                    projectedDimensions.end());
462   MLIRContext *context = map.getContext();
463   SmallVector<AffineExpr, 4> resultExprs;
464   for (auto dim : enumerate(llvm::seq<unsigned>(0, map.getNumDims()))) {
465     if (!projectedDims.count(dim.value()))
466       resultExprs.push_back(getAffineDimExpr(dim.index(), context));
467     else
468       resultExprs.push_back(getAffineConstantExpr(0, context));
469   }
470   return map.compose(AffineMap::get(
471       map.getNumDims() - projectedDimensions.size(), 0, resultExprs, context));
472 }
473 
474 //===----------------------------------------------------------------------===//
475 // MutableAffineMap.
476 //===----------------------------------------------------------------------===//
477 
MutableAffineMap(AffineMap map)478 MutableAffineMap::MutableAffineMap(AffineMap map)
479     : numDims(map.getNumDims()), numSymbols(map.getNumSymbols()),
480       context(map.getContext()) {
481   for (auto result : map.getResults())
482     results.push_back(result);
483 }
484 
reset(AffineMap map)485 void MutableAffineMap::reset(AffineMap map) {
486   results.clear();
487   numDims = map.getNumDims();
488   numSymbols = map.getNumSymbols();
489   context = map.getContext();
490   for (auto result : map.getResults())
491     results.push_back(result);
492 }
493 
isMultipleOf(unsigned idx,int64_t factor) const494 bool MutableAffineMap::isMultipleOf(unsigned idx, int64_t factor) const {
495   if (results[idx].isMultipleOf(factor))
496     return true;
497 
498   // TODO: use simplifyAffineExpr and FlatAffineConstraints to
499   // complete this (for a more powerful analysis).
500   return false;
501 }
502 
503 // Simplifies the result affine expressions of this map. The expressions have to
504 // be pure for the simplification implemented.
simplify()505 void MutableAffineMap::simplify() {
506   // Simplify each of the results if possible.
507   // TODO: functional-style map
508   for (unsigned i = 0, e = getNumResults(); i < e; i++) {
509     results[i] = simplifyAffineExpr(getResult(i), numDims, numSymbols);
510   }
511 }
512 
getAffineMap() const513 AffineMap MutableAffineMap::getAffineMap() const {
514   return AffineMap::get(numDims, numSymbols, results, context);
515 }
516