1 //===- AffineStructures.cpp - MLIR Affine Structures Class-----------------===//
2 //
3 // Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions.
4 // See https://llvm.org/LICENSE.txt for license information.
5 // SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception
6 //
7 //===----------------------------------------------------------------------===//
8 //
9 // Structures for affine/polyhedral analysis of affine dialect ops.
10 //
11 //===----------------------------------------------------------------------===//
12
13 #include "mlir/Analysis/AffineStructures.h"
14 #include "mlir/Analysis/Presburger/Simplex.h"
15 #include "mlir/Dialect/Affine/IR/AffineOps.h"
16 #include "mlir/Dialect/Affine/IR/AffineValueMap.h"
17 #include "mlir/Dialect/StandardOps/IR/Ops.h"
18 #include "mlir/IR/AffineExprVisitor.h"
19 #include "mlir/IR/IntegerSet.h"
20 #include "mlir/Support/LLVM.h"
21 #include "mlir/Support/MathExtras.h"
22 #include "llvm/ADT/SmallPtrSet.h"
23 #include "llvm/Support/Debug.h"
24 #include "llvm/Support/raw_ostream.h"
25
26 #define DEBUG_TYPE "affine-structures"
27
28 using namespace mlir;
29 using llvm::SmallDenseMap;
30 using llvm::SmallDenseSet;
31
32 namespace {
33
34 // See comments for SimpleAffineExprFlattener.
35 // An AffineExprFlattener extends a SimpleAffineExprFlattener by recording
36 // constraint information associated with mod's, floordiv's, and ceildiv's
37 // in FlatAffineConstraints 'localVarCst'.
38 struct AffineExprFlattener : public SimpleAffineExprFlattener {
39 public:
40 // Constraints connecting newly introduced local variables (for mod's and
41 // div's) to existing (dimensional and symbolic) ones. These are always
42 // inequalities.
43 FlatAffineConstraints localVarCst;
44
AffineExprFlattener__anon1dd723590111::AffineExprFlattener45 AffineExprFlattener(unsigned nDims, unsigned nSymbols, MLIRContext *ctx)
46 : SimpleAffineExprFlattener(nDims, nSymbols) {
47 localVarCst.reset(nDims, nSymbols, /*numLocals=*/0);
48 }
49
50 private:
51 // Add a local identifier (needed to flatten a mod, floordiv, ceildiv expr).
52 // The local identifier added is always a floordiv of a pure add/mul affine
53 // function of other identifiers, coefficients of which are specified in
54 // `dividend' and with respect to the positive constant `divisor'. localExpr
55 // is the simplified tree expression (AffineExpr) corresponding to the
56 // quantifier.
addLocalFloorDivId__anon1dd723590111::AffineExprFlattener57 void addLocalFloorDivId(ArrayRef<int64_t> dividend, int64_t divisor,
58 AffineExpr localExpr) override {
59 SimpleAffineExprFlattener::addLocalFloorDivId(dividend, divisor, localExpr);
60 // Update localVarCst.
61 localVarCst.addLocalFloorDiv(dividend, divisor);
62 }
63 };
64
65 } // end anonymous namespace
66
67 // Flattens the expressions in map. Returns failure if 'expr' was unable to be
68 // flattened (i.e., semi-affine expressions not handled yet).
69 static LogicalResult
getFlattenedAffineExprs(ArrayRef<AffineExpr> exprs,unsigned numDims,unsigned numSymbols,std::vector<SmallVector<int64_t,8>> * flattenedExprs,FlatAffineConstraints * localVarCst)70 getFlattenedAffineExprs(ArrayRef<AffineExpr> exprs, unsigned numDims,
71 unsigned numSymbols,
72 std::vector<SmallVector<int64_t, 8>> *flattenedExprs,
73 FlatAffineConstraints *localVarCst) {
74 if (exprs.empty()) {
75 localVarCst->reset(numDims, numSymbols);
76 return success();
77 }
78
79 AffineExprFlattener flattener(numDims, numSymbols, exprs[0].getContext());
80 // Use the same flattener to simplify each expression successively. This way
81 // local identifiers / expressions are shared.
82 for (auto expr : exprs) {
83 if (!expr.isPureAffine())
84 return failure();
85
86 flattener.walkPostOrder(expr);
87 }
88
89 assert(flattener.operandExprStack.size() == exprs.size());
90 flattenedExprs->clear();
91 flattenedExprs->assign(flattener.operandExprStack.begin(),
92 flattener.operandExprStack.end());
93
94 if (localVarCst)
95 localVarCst->clearAndCopyFrom(flattener.localVarCst);
96
97 return success();
98 }
99
100 // Flattens 'expr' into 'flattenedExpr'. Returns failure if 'expr' was unable to
101 // be flattened (semi-affine expressions not handled yet).
102 LogicalResult
getFlattenedAffineExpr(AffineExpr expr,unsigned numDims,unsigned numSymbols,SmallVectorImpl<int64_t> * flattenedExpr,FlatAffineConstraints * localVarCst)103 mlir::getFlattenedAffineExpr(AffineExpr expr, unsigned numDims,
104 unsigned numSymbols,
105 SmallVectorImpl<int64_t> *flattenedExpr,
106 FlatAffineConstraints *localVarCst) {
107 std::vector<SmallVector<int64_t, 8>> flattenedExprs;
108 LogicalResult ret = ::getFlattenedAffineExprs({expr}, numDims, numSymbols,
109 &flattenedExprs, localVarCst);
110 *flattenedExpr = flattenedExprs[0];
111 return ret;
112 }
113
114 /// Flattens the expressions in map. Returns failure if 'expr' was unable to be
115 /// flattened (i.e., semi-affine expressions not handled yet).
getFlattenedAffineExprs(AffineMap map,std::vector<SmallVector<int64_t,8>> * flattenedExprs,FlatAffineConstraints * localVarCst)116 LogicalResult mlir::getFlattenedAffineExprs(
117 AffineMap map, std::vector<SmallVector<int64_t, 8>> *flattenedExprs,
118 FlatAffineConstraints *localVarCst) {
119 if (map.getNumResults() == 0) {
120 localVarCst->reset(map.getNumDims(), map.getNumSymbols());
121 return success();
122 }
123 return ::getFlattenedAffineExprs(map.getResults(), map.getNumDims(),
124 map.getNumSymbols(), flattenedExprs,
125 localVarCst);
126 }
127
getFlattenedAffineExprs(IntegerSet set,std::vector<SmallVector<int64_t,8>> * flattenedExprs,FlatAffineConstraints * localVarCst)128 LogicalResult mlir::getFlattenedAffineExprs(
129 IntegerSet set, std::vector<SmallVector<int64_t, 8>> *flattenedExprs,
130 FlatAffineConstraints *localVarCst) {
131 if (set.getNumConstraints() == 0) {
132 localVarCst->reset(set.getNumDims(), set.getNumSymbols());
133 return success();
134 }
135 return ::getFlattenedAffineExprs(set.getConstraints(), set.getNumDims(),
136 set.getNumSymbols(), flattenedExprs,
137 localVarCst);
138 }
139
140 //===----------------------------------------------------------------------===//
141 // FlatAffineConstraints.
142 //===----------------------------------------------------------------------===//
143
144 // Copy constructor.
FlatAffineConstraints(const FlatAffineConstraints & other)145 FlatAffineConstraints::FlatAffineConstraints(
146 const FlatAffineConstraints &other) {
147 numReservedCols = other.numReservedCols;
148 numDims = other.getNumDimIds();
149 numSymbols = other.getNumSymbolIds();
150 numIds = other.getNumIds();
151
152 auto otherIds = other.getIds();
153 ids.reserve(numReservedCols);
154 ids.append(otherIds.begin(), otherIds.end());
155
156 unsigned numReservedEqualities = other.getNumReservedEqualities();
157 unsigned numReservedInequalities = other.getNumReservedInequalities();
158
159 equalities.reserve(numReservedEqualities * numReservedCols);
160 inequalities.reserve(numReservedInequalities * numReservedCols);
161
162 for (unsigned r = 0, e = other.getNumInequalities(); r < e; r++) {
163 addInequality(other.getInequality(r));
164 }
165 for (unsigned r = 0, e = other.getNumEqualities(); r < e; r++) {
166 addEquality(other.getEquality(r));
167 }
168 }
169
170 // Clones this object.
clone() const171 std::unique_ptr<FlatAffineConstraints> FlatAffineConstraints::clone() const {
172 return std::make_unique<FlatAffineConstraints>(*this);
173 }
174
175 // Construct from an IntegerSet.
FlatAffineConstraints(IntegerSet set)176 FlatAffineConstraints::FlatAffineConstraints(IntegerSet set)
177 : numReservedCols(set.getNumInputs() + 1),
178 numIds(set.getNumDims() + set.getNumSymbols()), numDims(set.getNumDims()),
179 numSymbols(set.getNumSymbols()) {
180 equalities.reserve(set.getNumEqualities() * numReservedCols);
181 inequalities.reserve(set.getNumInequalities() * numReservedCols);
182 ids.resize(numIds, None);
183
184 // Flatten expressions and add them to the constraint system.
185 std::vector<SmallVector<int64_t, 8>> flatExprs;
186 FlatAffineConstraints localVarCst;
187 if (failed(getFlattenedAffineExprs(set, &flatExprs, &localVarCst))) {
188 assert(false && "flattening unimplemented for semi-affine integer sets");
189 return;
190 }
191 assert(flatExprs.size() == set.getNumConstraints());
192 for (unsigned l = 0, e = localVarCst.getNumLocalIds(); l < e; l++) {
193 addLocalId(getNumLocalIds());
194 }
195
196 for (unsigned i = 0, e = flatExprs.size(); i < e; ++i) {
197 const auto &flatExpr = flatExprs[i];
198 assert(flatExpr.size() == getNumCols());
199 if (set.getEqFlags()[i]) {
200 addEquality(flatExpr);
201 } else {
202 addInequality(flatExpr);
203 }
204 }
205 // Add the other constraints involving local id's from flattening.
206 append(localVarCst);
207 }
208
reset(unsigned numReservedInequalities,unsigned numReservedEqualities,unsigned newNumReservedCols,unsigned newNumDims,unsigned newNumSymbols,unsigned newNumLocals,ArrayRef<Value> idArgs)209 void FlatAffineConstraints::reset(unsigned numReservedInequalities,
210 unsigned numReservedEqualities,
211 unsigned newNumReservedCols,
212 unsigned newNumDims, unsigned newNumSymbols,
213 unsigned newNumLocals,
214 ArrayRef<Value> idArgs) {
215 assert(newNumReservedCols >= newNumDims + newNumSymbols + newNumLocals + 1 &&
216 "minimum 1 column");
217 numReservedCols = newNumReservedCols;
218 numDims = newNumDims;
219 numSymbols = newNumSymbols;
220 numIds = numDims + numSymbols + newNumLocals;
221 assert(idArgs.empty() || idArgs.size() == numIds);
222
223 clearConstraints();
224 if (numReservedEqualities >= 1)
225 equalities.reserve(newNumReservedCols * numReservedEqualities);
226 if (numReservedInequalities >= 1)
227 inequalities.reserve(newNumReservedCols * numReservedInequalities);
228 if (idArgs.empty()) {
229 ids.resize(numIds, None);
230 } else {
231 ids.assign(idArgs.begin(), idArgs.end());
232 }
233 }
234
reset(unsigned newNumDims,unsigned newNumSymbols,unsigned newNumLocals,ArrayRef<Value> idArgs)235 void FlatAffineConstraints::reset(unsigned newNumDims, unsigned newNumSymbols,
236 unsigned newNumLocals,
237 ArrayRef<Value> idArgs) {
238 reset(0, 0, newNumDims + newNumSymbols + newNumLocals + 1, newNumDims,
239 newNumSymbols, newNumLocals, idArgs);
240 }
241
append(const FlatAffineConstraints & other)242 void FlatAffineConstraints::append(const FlatAffineConstraints &other) {
243 assert(other.getNumCols() == getNumCols());
244 assert(other.getNumDimIds() == getNumDimIds());
245 assert(other.getNumSymbolIds() == getNumSymbolIds());
246
247 inequalities.reserve(inequalities.size() +
248 other.getNumInequalities() * numReservedCols);
249 equalities.reserve(equalities.size() +
250 other.getNumEqualities() * numReservedCols);
251
252 for (unsigned r = 0, e = other.getNumInequalities(); r < e; r++) {
253 addInequality(other.getInequality(r));
254 }
255 for (unsigned r = 0, e = other.getNumEqualities(); r < e; r++) {
256 addEquality(other.getEquality(r));
257 }
258 }
259
addLocalId(unsigned pos)260 void FlatAffineConstraints::addLocalId(unsigned pos) {
261 addId(IdKind::Local, pos);
262 }
263
addDimId(unsigned pos,Value id)264 void FlatAffineConstraints::addDimId(unsigned pos, Value id) {
265 addId(IdKind::Dimension, pos, id);
266 }
267
addSymbolId(unsigned pos,Value id)268 void FlatAffineConstraints::addSymbolId(unsigned pos, Value id) {
269 addId(IdKind::Symbol, pos, id);
270 }
271
/// Inserts a new identifier of the given 'kind' (dimension, symbol, or local)
/// at position 'pos'. 'pos' is relative to the start of that kind's range and
/// may equal the current count of that kind, which appends. In every equality
/// and inequality, the columns at or after the insertion point move right by
/// one, and the new column is initialized to zero. If 'id' is non-null it is
/// associated with the new identifier; otherwise the identifier gets None.
void FlatAffineConstraints::addId(IdKind kind, unsigned pos, Value id) {
  if (kind == IdKind::Dimension)
    assert(pos <= getNumDimIds());
  else if (kind == IdKind::Symbol)
    assert(pos <= getNumSymbolIds());
  else
    assert(pos <= getNumLocalIds());

  // Row stride of the coefficient buffers before any growth; the existing
  // coefficients are read back with this stride below.
  unsigned oldNumReservedCols = numReservedCols;

  // Check if a resize is necessary: when no spare reserved column exists,
  // grow each buffer by one column per row and widen the stride by one.
  if (getNumCols() + 1 > numReservedCols) {
    equalities.resize(getNumEqualities() * (getNumCols() + 1));
    inequalities.resize(getNumInequalities() * (getNumCols() + 1));
    numReservedCols++;
  }

  // Absolute column index of the new identifier.
  int absolutePos;

  // Convert the kind-relative 'pos' into an absolute column and bump the
  // counter of the matching kind (locals have no dedicated counter).
  if (kind == IdKind::Dimension) {
    absolutePos = pos;
    numDims++;
  } else if (kind == IdKind::Symbol) {
    absolutePos = pos + getNumDimIds();
    numSymbols++;
  } else {
    absolutePos = pos + getNumDimIds() + getNumSymbolIds();
  }
  numIds++;

  // Note that getNumCols() now will already return the new size, which will be
  // at least one.
  int numInequalities = static_cast<int>(getNumInequalities());
  int numEqualities = static_cast<int>(getNumEqualities());
  int numCols = static_cast<int>(getNumCols());
  // Move the coefficients in place, iterating over the old columns
  // (0 .. numCols - 2, the old constant term being last) from the last row
  // and column backwards. Every destination index is at least its source
  // index, so walking backwards never overwrites a coefficient before it has
  // been read.
  for (int r = numInequalities - 1; r >= 0; r--) {
    for (int c = numCols - 2; c >= 0; c--) {
      if (c < absolutePos)
        atIneq(r, c) = inequalities[r * oldNumReservedCols + c];
      else
        atIneq(r, c + 1) = inequalities[r * oldNumReservedCols + c];
    }
    // Initialize the new identifier's column to zero.
    atIneq(r, absolutePos) = 0;
  }

  for (int r = numEqualities - 1; r >= 0; r--) {
    for (int c = numCols - 2; c >= 0; c--) {
      // Columns before absolutePos keep the same column index in the 2-d
      // view of the coefficient buffer.
      if (c < absolutePos)
        atEq(r, c) = equalities[r * oldNumReservedCols + c];
      else
        // Columns at or after absolutePos move one column to the right.
        atEq(r, c + 1) = equalities[r * oldNumReservedCols + c];
    }
    // Initialize the new identifier's column to zero.
    atEq(r, absolutePos) = 0;
  }

  // If an 'id' is provided, insert it; otherwise use None.
  if (id)
    ids.insert(ids.begin() + absolutePos, id);
  else
    ids.insert(ids.begin() + absolutePos, None);
  assert(ids.size() == getNumIds());
}
341
342 /// Checks if two constraint systems are in the same space, i.e., if they are
343 /// associated with the same set of identifiers, appearing in the same order.
areIdsAligned(const FlatAffineConstraints & A,const FlatAffineConstraints & B)344 static bool areIdsAligned(const FlatAffineConstraints &A,
345 const FlatAffineConstraints &B) {
346 return A.getNumDimIds() == B.getNumDimIds() &&
347 A.getNumSymbolIds() == B.getNumSymbolIds() &&
348 A.getNumIds() == B.getNumIds() && A.getIds().equals(B.getIds());
349 }
350
351 /// Calls areIdsAligned to check if two constraint systems have the same set
352 /// of identifiers in the same order.
areIdsAlignedWithOther(const FlatAffineConstraints & other)353 bool FlatAffineConstraints::areIdsAlignedWithOther(
354 const FlatAffineConstraints &other) {
355 return areIdsAligned(*this, other);
356 }
357
358 /// Checks if the SSA values associated with `cst''s identifiers are unique.
359 static bool LLVM_ATTRIBUTE_UNUSED
areIdsUnique(const FlatAffineConstraints & cst)360 areIdsUnique(const FlatAffineConstraints &cst) {
361 SmallPtrSet<Value, 8> uniqueIds;
362 for (auto id : cst.getIds()) {
363 if (id.hasValue() && !uniqueIds.insert(id.getValue()).second)
364 return false;
365 }
366 return true;
367 }
368
/// Merge and align the identifiers of A and B starting at 'offset', so that
/// both constraint systems get the union of the contained identifiers that is
/// dimension-wise and symbol-wise unique; both constraint systems are updated
/// so that they have the union of all identifiers, with A's original
/// identifiers appearing first followed by any of B's identifiers that didn't
/// appear in A. Local identifiers of each system are by design separate/local
/// and are placed one after other (A's followed by B's).
///
/// Preconditions (asserted): 'offset' does not exceed either system's number
/// of dimensions, each system's identifier values are unique, and every
/// dim/symbol identifier at or after 'offset' has an associated value.
// Eg: Input: A has ((%i, %j) [%M, %N]) and B has ((%k, %j) [%P, %N, %M])
//     Output: both A, B have (%i, %j, %k) [%M, %N, %P]
//
static void mergeAndAlignIds(unsigned offset, FlatAffineConstraints *A,
                             FlatAffineConstraints *B) {
  assert(offset <= A->getNumDimIds() && offset <= B->getNumDimIds());
  // A merge/align isn't meaningful if a cst's ids aren't distinct.
  assert(areIdsUnique(*A) && "A's id values aren't unique");
  assert(areIdsUnique(*B) && "B's id values aren't unique");

  assert(std::all_of(A->getIds().begin() + offset,
                     A->getIds().begin() + A->getNumDimAndSymbolIds(),
                     [](Optional<Value> id) { return id.hasValue(); }));

  assert(std::all_of(B->getIds().begin() + offset,
                     B->getIds().begin() + B->getNumDimAndSymbolIds(),
                     [](Optional<Value> id) { return id.hasValue(); }));

  // Place A's local ids before B's local ids in both systems: B gets one new
  // local per A-local inserted at the front of its locals...
  for (unsigned l = 0, e = A->getNumLocalIds(); l < e; l++) {
    B->addLocalId(0);
  }
  // ...and A gets B's original locals appended after its own. B's local count
  // already includes the A-locals just inserted, so the difference below is
  // B's original number of locals.
  for (unsigned t = 0, e = B->getNumLocalIds() - A->getNumLocalIds(); t < e;
       t++) {
    A->addLocalId(A->getNumLocalIds());
  }

  SmallVector<Value, 4> aDimValues, aSymValues;
  A->getIdValues(offset, A->getNumDimIds(), &aDimValues);
  A->getIdValues(A->getNumDimIds(), A->getNumDimAndSymbolIds(), &aSymValues);
  {
    // Merge dims from A into B: walking A's dims (from 'offset') in order,
    // position 'd' of B receives the same value, either by swapping it in
    // from where B already has it or by inserting a new dim.
    unsigned d = offset;
    for (auto aDimValue : aDimValues) {
      unsigned loc;
      if (B->findId(aDimValue, &loc)) {
        assert(loc >= offset && "A's dim appears in B's aligned range");
        assert(loc < B->getNumDimIds() &&
               "A's dim appears in B's non-dim position");
        B->swapId(d, loc);
      } else {
        B->addDimId(d);
        B->setIdValue(d, aDimValue);
      }
      d++;
    }

    // Dimensions that are in B, but not in A, are added at the end.
    for (unsigned t = A->getNumDimIds(), e = B->getNumDimIds(); t < e; t++) {
      A->addDimId(A->getNumDimIds());
      A->setIdValue(A->getNumDimIds() - 1, B->getIdValue(t));
    }
  }
  {
    // Merge symbols: merge A's symbols into B first. B's symbols start right
    // after its dims; position 's' receives A's symbols in order.
    unsigned s = B->getNumDimIds();
    for (auto aSymValue : aSymValues) {
      unsigned loc;
      if (B->findId(aSymValue, &loc)) {
        assert(loc >= B->getNumDimIds() && loc < B->getNumDimAndSymbolIds() &&
               "A's symbol appears in B's non-symbol position");
        B->swapId(s, loc);
      } else {
        B->addSymbolId(s - B->getNumDimIds());
        B->setIdValue(s, aSymValue);
      }
      s++;
    }
    // Symbols that are in B, but not in A, are added at the end.
    for (unsigned t = A->getNumDimAndSymbolIds(),
                  e = B->getNumDimAndSymbolIds();
         t < e; t++) {
      A->addSymbolId(A->getNumSymbolIds());
      A->setIdValue(A->getNumDimAndSymbolIds() - 1, B->getIdValue(t));
    }
  }
  assert(areIdsAligned(*A, *B) && "IDs expected to be aligned");
}
454
455 // Call 'mergeAndAlignIds' to align constraint systems of 'this' and 'other'.
mergeAndAlignIdsWithOther(unsigned offset,FlatAffineConstraints * other)456 void FlatAffineConstraints::mergeAndAlignIdsWithOther(
457 unsigned offset, FlatAffineConstraints *other) {
458 mergeAndAlignIds(offset, this, other);
459 }
460
461 // This routine may add additional local variables if the flattened expression
462 // corresponding to the map has such variables due to mod's, ceildiv's, and
463 // floordiv's in it.
composeMap(const AffineValueMap * vMap)464 LogicalResult FlatAffineConstraints::composeMap(const AffineValueMap *vMap) {
465 std::vector<SmallVector<int64_t, 8>> flatExprs;
466 FlatAffineConstraints localCst;
467 if (failed(getFlattenedAffineExprs(vMap->getAffineMap(), &flatExprs,
468 &localCst))) {
469 LLVM_DEBUG(llvm::dbgs()
470 << "composition unimplemented for semi-affine maps\n");
471 return failure();
472 }
473 assert(flatExprs.size() == vMap->getNumResults());
474
475 // Add localCst information.
476 if (localCst.getNumLocalIds() > 0) {
477 localCst.setIdValues(0, /*end=*/localCst.getNumDimAndSymbolIds(),
478 /*values=*/vMap->getOperands());
479 // Align localCst and this.
480 mergeAndAlignIds(/*offset=*/0, &localCst, this);
481 // Finally, append localCst to this constraint set.
482 append(localCst);
483 }
484
485 // Add dimensions corresponding to the map's results.
486 for (unsigned t = 0, e = vMap->getNumResults(); t < e; t++) {
487 // TODO: Consider using a batched version to add a range of IDs.
488 addDimId(0);
489 }
490
491 // We add one equality for each result connecting the result dim of the map to
492 // the other identifiers.
493 // For eg: if the expression is 16*i0 + i1, and this is the r^th
494 // iteration/result of the value map, we are adding the equality:
495 // d_r - 16*i0 - i1 = 0. Hence, when flattening say (i0 + 1, i0 + 8*i2), we
496 // add two equalities overall: d_0 - i0 - 1 == 0, d1 - i0 - 8*i2 == 0.
497 for (unsigned r = 0, e = flatExprs.size(); r < e; r++) {
498 const auto &flatExpr = flatExprs[r];
499 assert(flatExpr.size() >= vMap->getNumOperands() + 1);
500
501 // eqToAdd is the equality corresponding to the flattened affine expression.
502 SmallVector<int64_t, 8> eqToAdd(getNumCols(), 0);
503 // Set the coefficient for this result to one.
504 eqToAdd[r] = 1;
505
506 // Dims and symbols.
507 for (unsigned i = 0, e = vMap->getNumOperands(); i < e; i++) {
508 unsigned loc;
509 bool ret = findId(vMap->getOperand(i), &loc);
510 assert(ret && "value map's id can't be found");
511 (void)ret;
512 // Negate 'eq[r]' since the newly added dimension will be set to this one.
513 eqToAdd[loc] = -flatExpr[i];
514 }
515 // Local vars common to eq and localCst are at the beginning.
516 unsigned j = getNumDimIds() + getNumSymbolIds();
517 unsigned end = flatExpr.size() - 1;
518 for (unsigned i = vMap->getNumOperands(); i < end; i++, j++) {
519 eqToAdd[j] = -flatExpr[i];
520 }
521
522 // Constant term.
523 eqToAdd[getNumCols() - 1] = -flatExpr[flatExpr.size() - 1];
524
525 // Add the equality connecting the result of the map to this constraint set.
526 addEquality(eqToAdd);
527 }
528
529 return success();
530 }
531
532 // Similar to composeMap except that no Value's need be associated with the
533 // constraint system nor are they looked at -- since the dimensions and
534 // symbols of 'other' are expected to correspond 1:1 to 'this' system. It
535 // is thus not convenient to share code with composeMap.
composeMatchingMap(AffineMap other)536 LogicalResult FlatAffineConstraints::composeMatchingMap(AffineMap other) {
537 assert(other.getNumDims() == getNumDimIds() && "dim mismatch");
538 assert(other.getNumSymbols() == getNumSymbolIds() && "symbol mismatch");
539
540 std::vector<SmallVector<int64_t, 8>> flatExprs;
541 FlatAffineConstraints localCst;
542 if (failed(getFlattenedAffineExprs(other, &flatExprs, &localCst))) {
543 LLVM_DEBUG(llvm::dbgs()
544 << "composition unimplemented for semi-affine maps\n");
545 return failure();
546 }
547 assert(flatExprs.size() == other.getNumResults());
548
549 // Add localCst information.
550 if (localCst.getNumLocalIds() > 0) {
551 // Place local id's of A after local id's of B.
552 for (unsigned l = 0, e = localCst.getNumLocalIds(); l < e; l++) {
553 addLocalId(0);
554 }
555 // Finally, append localCst to this constraint set.
556 append(localCst);
557 }
558
559 // Add dimensions corresponding to the map's results.
560 for (unsigned t = 0, e = other.getNumResults(); t < e; t++) {
561 addDimId(0);
562 }
563
564 // We add one equality for each result connecting the result dim of the map to
565 // the other identifiers.
566 // For eg: if the expression is 16*i0 + i1, and this is the r^th
567 // iteration/result of the value map, we are adding the equality:
568 // d_r - 16*i0 - i1 = 0. Hence, when flattening say (i0 + 1, i0 + 8*i2), we
569 // add two equalities overall: d_0 - i0 - 1 == 0, d1 - i0 - 8*i2 == 0.
570 for (unsigned r = 0, e = flatExprs.size(); r < e; r++) {
571 const auto &flatExpr = flatExprs[r];
572 assert(flatExpr.size() >= other.getNumInputs() + 1);
573
574 // eqToAdd is the equality corresponding to the flattened affine expression.
575 SmallVector<int64_t, 8> eqToAdd(getNumCols(), 0);
576 // Set the coefficient for this result to one.
577 eqToAdd[r] = 1;
578
579 // Dims and symbols.
580 for (unsigned i = 0, f = other.getNumInputs(); i < f; i++) {
581 // Negate 'eq[r]' since the newly added dimension will be set to this one.
582 eqToAdd[e + i] = -flatExpr[i];
583 }
584 // Local vars common to eq and localCst are at the beginning.
585 unsigned j = getNumDimIds() + getNumSymbolIds();
586 unsigned end = flatExpr.size() - 1;
587 for (unsigned i = other.getNumInputs(); i < end; i++, j++) {
588 eqToAdd[j] = -flatExpr[i];
589 }
590
591 // Constant term.
592 eqToAdd[getNumCols() - 1] = -flatExpr[flatExpr.size() - 1];
593
594 // Add the equality connecting the result of the map to this constraint set.
595 addEquality(eqToAdd);
596 }
597
598 return success();
599 }
600
601 // Turn a dimension into a symbol.
turnDimIntoSymbol(FlatAffineConstraints * cst,Value id)602 static void turnDimIntoSymbol(FlatAffineConstraints *cst, Value id) {
603 unsigned pos;
604 if (cst->findId(id, &pos) && pos < cst->getNumDimIds()) {
605 cst->swapId(pos, cst->getNumDimIds() - 1);
606 cst->setDimSymbolSeparation(cst->getNumSymbolIds() + 1);
607 }
608 }
609
610 // Turn a symbol into a dimension.
turnSymbolIntoDim(FlatAffineConstraints * cst,Value id)611 static void turnSymbolIntoDim(FlatAffineConstraints *cst, Value id) {
612 unsigned pos;
613 if (cst->findId(id, &pos) && pos >= cst->getNumDimIds() &&
614 pos < cst->getNumDimAndSymbolIds()) {
615 cst->swapId(pos, cst->getNumDimIds());
616 cst->setDimSymbolSeparation(cst->getNumSymbolIds() - 1);
617 }
618 }
619
620 // Changes all symbol identifiers which are loop IVs to dim identifiers.
convertLoopIVSymbolsToDims()621 void FlatAffineConstraints::convertLoopIVSymbolsToDims() {
622 // Gather all symbols which are loop IVs.
623 SmallVector<Value, 4> loopIVs;
624 for (unsigned i = getNumDimIds(), e = getNumDimAndSymbolIds(); i < e; i++) {
625 if (ids[i].hasValue() && getForInductionVarOwner(ids[i].getValue()))
626 loopIVs.push_back(ids[i].getValue());
627 }
628 // Turn each symbol in 'loopIVs' into a dim identifier.
629 for (auto iv : loopIVs) {
630 turnSymbolIntoDim(this, iv);
631 }
632 }
633
addInductionVarOrTerminalSymbol(Value id)634 void FlatAffineConstraints::addInductionVarOrTerminalSymbol(Value id) {
635 if (containsId(id))
636 return;
637
638 // Caller is expected to fully compose map/operands if necessary.
639 assert((isTopLevelValue(id) || isForInductionVar(id)) &&
640 "non-terminal symbol / loop IV expected");
641 // Outer loop IVs could be used in forOp's bounds.
642 if (auto loop = getForInductionVarOwner(id)) {
643 addDimId(getNumDimIds(), id);
644 if (failed(this->addAffineForOpDomain(loop)))
645 LLVM_DEBUG(
646 loop.emitWarning("failed to add domain info to constraint system"));
647 return;
648 }
649 // Add top level symbol.
650 addSymbolId(getNumSymbolIds(), id);
651 // Check if the symbol is a constant.
652 if (auto constOp = id.getDefiningOp<ConstantIndexOp>())
653 setIdToConstant(id, constOp.getValue());
654 }
655
addAffineForOpDomain(AffineForOp forOp)656 LogicalResult FlatAffineConstraints::addAffineForOpDomain(AffineForOp forOp) {
657 unsigned pos;
658 // Pre-condition for this method.
659 if (!findId(forOp.getInductionVar(), &pos)) {
660 assert(false && "Value not found");
661 return failure();
662 }
663
664 int64_t step = forOp.getStep();
665 if (step != 1) {
666 if (!forOp.hasConstantLowerBound())
667 forOp.emitWarning("domain conservatively approximated");
668 else {
669 // Add constraints for the stride.
670 // (iv - lb) % step = 0 can be written as:
671 // (iv - lb) - step * q = 0 where q = (iv - lb) / step.
672 // Add local variable 'q' and add the above equality.
673 // The first constraint is q = (iv - lb) floordiv step
674 SmallVector<int64_t, 8> dividend(getNumCols(), 0);
675 int64_t lb = forOp.getConstantLowerBound();
676 dividend[pos] = 1;
677 dividend.back() -= lb;
678 addLocalFloorDiv(dividend, step);
679 // Second constraint: (iv - lb) - step * q = 0.
680 SmallVector<int64_t, 8> eq(getNumCols(), 0);
681 eq[pos] = 1;
682 eq.back() -= lb;
683 // For the local var just added above.
684 eq[getNumCols() - 2] = -step;
685 addEquality(eq);
686 }
687 }
688
689 if (forOp.hasConstantLowerBound()) {
690 addConstantLowerBound(pos, forOp.getConstantLowerBound());
691 } else {
692 // Non-constant lower bound case.
693 if (failed(addLowerOrUpperBound(pos, forOp.getLowerBoundMap(),
694 forOp.getLowerBoundOperands(),
695 /*eq=*/false, /*lower=*/true)))
696 return failure();
697 }
698
699 if (forOp.hasConstantUpperBound()) {
700 addConstantUpperBound(pos, forOp.getConstantUpperBound() - 1);
701 return success();
702 }
703 // Non-constant upper bound case.
704 return addLowerOrUpperBound(pos, forOp.getUpperBoundMap(),
705 forOp.getUpperBoundOperands(),
706 /*eq=*/false, /*lower=*/false);
707 }
708
addAffineIfOpDomain(AffineIfOp ifOp)709 void FlatAffineConstraints::addAffineIfOpDomain(AffineIfOp ifOp) {
710 // Create the base constraints from the integer set attached to ifOp.
711 FlatAffineConstraints cst(ifOp.getIntegerSet());
712
713 // Bind ids in the constraints to ifOp operands.
714 SmallVector<Value, 4> operands = ifOp.getOperands();
715 cst.setIdValues(0, cst.getNumDimAndSymbolIds(), operands);
716
717 // Merge the constraints from ifOp to the current domain. We need first merge
718 // and align the IDs from both constraints, and then append the constraints
719 // from the ifOp into the current one.
720 mergeAndAlignIdsWithOther(0, &cst);
721 append(cst);
722 }
723
724 // Searches for a constraint with a non-zero coefficient at 'colIdx' in
725 // equality (isEq=true) or inequality (isEq=false) constraints.
726 // Returns true and sets row found in search in 'rowIdx'.
727 // Returns false otherwise.
findConstraintWithNonZeroAt(const FlatAffineConstraints & cst,unsigned colIdx,bool isEq,unsigned * rowIdx)728 static bool findConstraintWithNonZeroAt(const FlatAffineConstraints &cst,
729 unsigned colIdx, bool isEq,
730 unsigned *rowIdx) {
731 assert(colIdx < cst.getNumCols() && "position out of bounds");
732 auto at = [&](unsigned rowIdx) -> int64_t {
733 return isEq ? cst.atEq(rowIdx, colIdx) : cst.atIneq(rowIdx, colIdx);
734 };
735 unsigned e = isEq ? cst.getNumEqualities() : cst.getNumInequalities();
736 for (*rowIdx = 0; *rowIdx < e; ++(*rowIdx)) {
737 if (at(*rowIdx) != 0) {
738 return true;
739 }
740 }
741 return false;
742 }
743
744 // Normalizes the coefficient values across all columns in 'rowIDx' by their
745 // GCD in equality or inequality constraints as specified by 'isEq'.
746 template <bool isEq>
normalizeConstraintByGCD(FlatAffineConstraints * constraints,unsigned rowIdx)747 static void normalizeConstraintByGCD(FlatAffineConstraints *constraints,
748 unsigned rowIdx) {
749 auto at = [&](unsigned colIdx) -> int64_t {
750 return isEq ? constraints->atEq(rowIdx, colIdx)
751 : constraints->atIneq(rowIdx, colIdx);
752 };
753 uint64_t gcd = std::abs(at(0));
754 for (unsigned j = 1, e = constraints->getNumCols(); j < e; ++j) {
755 gcd = llvm::GreatestCommonDivisor64(gcd, std::abs(at(j)));
756 }
757 if (gcd > 0 && gcd != 1) {
758 for (unsigned j = 0, e = constraints->getNumCols(); j < e; ++j) {
759 int64_t v = at(j) / static_cast<int64_t>(gcd);
760 isEq ? constraints->atEq(rowIdx, j) = v
761 : constraints->atIneq(rowIdx, j) = v;
762 }
763 }
764 }
765
normalizeConstraintsByGCD()766 void FlatAffineConstraints::normalizeConstraintsByGCD() {
767 for (unsigned i = 0, e = getNumEqualities(); i < e; ++i) {
768 normalizeConstraintByGCD</*isEq=*/true>(this, i);
769 }
770 for (unsigned i = 0, e = getNumInequalities(); i < e; ++i) {
771 normalizeConstraintByGCD</*isEq=*/false>(this, i);
772 }
773 }
774
hasConsistentState() const775 bool FlatAffineConstraints::hasConsistentState() const {
776 if (inequalities.size() != getNumInequalities() * numReservedCols)
777 return false;
778 if (equalities.size() != getNumEqualities() * numReservedCols)
779 return false;
780 if (ids.size() != getNumIds())
781 return false;
782
783 // Catches errors where numDims, numSymbols, numIds aren't consistent.
784 if (numDims > numIds || numSymbols > numIds || numDims + numSymbols > numIds)
785 return false;
786
787 return true;
788 }
789
790 /// Checks all rows of equality/inequality constraints for trivial
791 /// contradictions (for example: 1 == 0, 0 >= 1), which may have surfaced
792 /// after elimination. Returns 'true' if an invalid constraint is found;
793 /// 'false' otherwise.
hasInvalidConstraint() const794 bool FlatAffineConstraints::hasInvalidConstraint() const {
795 assert(hasConsistentState());
796 auto check = [&](bool isEq) -> bool {
797 unsigned numCols = getNumCols();
798 unsigned numRows = isEq ? getNumEqualities() : getNumInequalities();
799 for (unsigned i = 0, e = numRows; i < e; ++i) {
800 unsigned j;
801 for (j = 0; j < numCols - 1; ++j) {
802 int64_t v = isEq ? atEq(i, j) : atIneq(i, j);
803 // Skip rows with non-zero variable coefficients.
804 if (v != 0)
805 break;
806 }
807 if (j < numCols - 1) {
808 continue;
809 }
810 // Check validity of constant term at 'numCols - 1' w.r.t 'isEq'.
811 // Example invalid constraints include: '1 == 0' or '-1 >= 0'
812 int64_t v = isEq ? atEq(i, numCols - 1) : atIneq(i, numCols - 1);
813 if ((isEq && v != 0) || (!isEq && v < 0)) {
814 return true;
815 }
816 }
817 return false;
818 };
819 if (check(/*isEq=*/true))
820 return true;
821 return check(/*isEq=*/false);
822 }
823
824 // Eliminate identifier from constraint at 'rowIdx' based on coefficient at
825 // pivotRow, pivotCol. Columns in range [elimColStart, pivotCol) will not be
826 // updated as they have already been eliminated.
eliminateFromConstraint(FlatAffineConstraints * constraints,unsigned rowIdx,unsigned pivotRow,unsigned pivotCol,unsigned elimColStart,bool isEq)827 static void eliminateFromConstraint(FlatAffineConstraints *constraints,
828 unsigned rowIdx, unsigned pivotRow,
829 unsigned pivotCol, unsigned elimColStart,
830 bool isEq) {
831 // Skip if equality 'rowIdx' if same as 'pivotRow'.
832 if (isEq && rowIdx == pivotRow)
833 return;
834 auto at = [&](unsigned i, unsigned j) -> int64_t {
835 return isEq ? constraints->atEq(i, j) : constraints->atIneq(i, j);
836 };
837 int64_t leadCoeff = at(rowIdx, pivotCol);
838 // Skip if leading coefficient at 'rowIdx' is already zero.
839 if (leadCoeff == 0)
840 return;
841 int64_t pivotCoeff = constraints->atEq(pivotRow, pivotCol);
842 int64_t sign = (leadCoeff * pivotCoeff > 0) ? -1 : 1;
843 int64_t lcm = mlir::lcm(pivotCoeff, leadCoeff);
844 int64_t pivotMultiplier = sign * (lcm / std::abs(pivotCoeff));
845 int64_t rowMultiplier = lcm / std::abs(leadCoeff);
846
847 unsigned numCols = constraints->getNumCols();
848 for (unsigned j = 0; j < numCols; ++j) {
849 // Skip updating column 'j' if it was just eliminated.
850 if (j >= elimColStart && j < pivotCol)
851 continue;
852 int64_t v = pivotMultiplier * constraints->atEq(pivotRow, j) +
853 rowMultiplier * at(rowIdx, j);
854 isEq ? constraints->atEq(rowIdx, j) = v
855 : constraints->atIneq(rowIdx, j) = v;
856 }
857 }
858
859 // Remove coefficients in column range [colStart, colLimit) in place.
860 // This removes in data in the specified column range, and copies any
861 // remaining valid data into place.
shiftColumnsToLeft(FlatAffineConstraints * constraints,unsigned colStart,unsigned colLimit,bool isEq)862 static void shiftColumnsToLeft(FlatAffineConstraints *constraints,
863 unsigned colStart, unsigned colLimit,
864 bool isEq) {
865 assert(colLimit <= constraints->getNumIds());
866 if (colLimit <= colStart)
867 return;
868
869 unsigned numCols = constraints->getNumCols();
870 unsigned numRows = isEq ? constraints->getNumEqualities()
871 : constraints->getNumInequalities();
872 unsigned numToEliminate = colLimit - colStart;
873 for (unsigned r = 0, e = numRows; r < e; ++r) {
874 for (unsigned c = colLimit; c < numCols; ++c) {
875 if (isEq) {
876 constraints->atEq(r, c - numToEliminate) = constraints->atEq(r, c);
877 } else {
878 constraints->atIneq(r, c - numToEliminate) = constraints->atIneq(r, c);
879 }
880 }
881 }
882 }
883
884 // Removes identifiers in column range [idStart, idLimit), and copies any
885 // remaining valid data into place, and updates member variables.
removeIdRange(unsigned idStart,unsigned idLimit)886 void FlatAffineConstraints::removeIdRange(unsigned idStart, unsigned idLimit) {
887 assert(idLimit < getNumCols() && "invalid id limit");
888
889 if (idStart >= idLimit)
890 return;
891
892 // We are going to be removing one or more identifiers from the range.
893 assert(idStart < numIds && "invalid idStart position");
894
895 // TODO: Make 'removeIdRange' a lambda called from here.
896 // Remove eliminated identifiers from equalities.
897 shiftColumnsToLeft(this, idStart, idLimit, /*isEq=*/true);
898
899 // Remove eliminated identifiers from inequalities.
900 shiftColumnsToLeft(this, idStart, idLimit, /*isEq=*/false);
901
902 // Update members numDims, numSymbols and numIds.
903 unsigned numDimsEliminated = 0;
904 unsigned numLocalsEliminated = 0;
905 unsigned numColsEliminated = idLimit - idStart;
906 if (idStart < numDims) {
907 numDimsEliminated = std::min(numDims, idLimit) - idStart;
908 }
909 // Check how many local id's were removed. Note that our identifier order is
910 // [dims, symbols, locals]. Local id start at position numDims + numSymbols.
911 if (idLimit > numDims + numSymbols) {
912 numLocalsEliminated = std::min(
913 idLimit - std::max(idStart, numDims + numSymbols), getNumLocalIds());
914 }
915 unsigned numSymbolsEliminated =
916 numColsEliminated - numDimsEliminated - numLocalsEliminated;
917
918 numDims -= numDimsEliminated;
919 numSymbols -= numSymbolsEliminated;
920 numIds = numIds - numColsEliminated;
921
922 ids.erase(ids.begin() + idStart, ids.begin() + idLimit);
923
924 // No resize necessary. numReservedCols remains the same.
925 }
926
927 /// Returns the position of the identifier that has the minimum <number of lower
928 /// bounds> times <number of upper bounds> from the specified range of
929 /// identifiers [start, end). It is often best to eliminate in the increasing
930 /// order of these counts when doing Fourier-Motzkin elimination since FM adds
931 /// that many new constraints.
getBestIdToEliminate(const FlatAffineConstraints & cst,unsigned start,unsigned end)932 static unsigned getBestIdToEliminate(const FlatAffineConstraints &cst,
933 unsigned start, unsigned end) {
934 assert(start < cst.getNumIds() && end < cst.getNumIds() + 1);
935
936 auto getProductOfNumLowerUpperBounds = [&](unsigned pos) {
937 unsigned numLb = 0;
938 unsigned numUb = 0;
939 for (unsigned r = 0, e = cst.getNumInequalities(); r < e; r++) {
940 if (cst.atIneq(r, pos) > 0) {
941 ++numLb;
942 } else if (cst.atIneq(r, pos) < 0) {
943 ++numUb;
944 }
945 }
946 return numLb * numUb;
947 };
948
949 unsigned minLoc = start;
950 unsigned min = getProductOfNumLowerUpperBounds(start);
951 for (unsigned c = start + 1; c < end; c++) {
952 unsigned numLbUbProduct = getProductOfNumLowerUpperBounds(c);
953 if (numLbUbProduct < min) {
954 min = numLbUbProduct;
955 minLoc = c;
956 }
957 }
958 return minLoc;
959 }
960
961 // Checks for emptiness of the set by eliminating identifiers successively and
962 // using the GCD test (on all equality constraints) and checking for trivially
963 // invalid constraints. Returns 'true' if the constraint system is found to be
964 // empty; false otherwise.
isEmpty() const965 bool FlatAffineConstraints::isEmpty() const {
966 if (isEmptyByGCDTest() || hasInvalidConstraint())
967 return true;
968
969 // First, eliminate as many identifiers as possible using Gaussian
970 // elimination.
971 FlatAffineConstraints tmpCst(*this);
972 unsigned currentPos = 0;
973 while (currentPos < tmpCst.getNumIds()) {
974 tmpCst.gaussianEliminateIds(currentPos, tmpCst.getNumIds());
975 ++currentPos;
976 // We check emptiness through trivial checks after eliminating each ID to
977 // detect emptiness early. Since the checks isEmptyByGCDTest() and
978 // hasInvalidConstraint() are linear time and single sweep on the constraint
979 // buffer, this appears reasonable - but can optimize in the future.
980 if (tmpCst.hasInvalidConstraint() || tmpCst.isEmptyByGCDTest())
981 return true;
982 }
983
984 // Eliminate the remaining using FM.
985 for (unsigned i = 0, e = tmpCst.getNumIds(); i < e; i++) {
986 tmpCst.FourierMotzkinEliminate(
987 getBestIdToEliminate(tmpCst, 0, tmpCst.getNumIds()));
988 // Check for a constraint explosion. This rarely happens in practice, but
989 // this check exists as a safeguard against improperly constructed
990 // constraint systems or artificially created arbitrarily complex systems
991 // that aren't the intended use case for FlatAffineConstraints. This is
992 // needed since FM has a worst case exponential complexity in theory.
993 if (tmpCst.getNumConstraints() >= kExplosionFactor * getNumIds()) {
994 LLVM_DEBUG(llvm::dbgs() << "FM constraint explosion detected\n");
995 return false;
996 }
997
998 // FM wouldn't have modified the equalities in any way. So no need to again
999 // run GCD test. Check for trivial invalid constraints.
1000 if (tmpCst.hasInvalidConstraint())
1001 return true;
1002 }
1003 return false;
1004 }
1005
1006 // Runs the GCD test on all equality constraints. Returns 'true' if this test
1007 // fails on any equality. Returns 'false' otherwise.
1008 // This test can be used to disprove the existence of a solution. If it returns
1009 // true, no integer solution to the equality constraints can exist.
1010 //
1011 // GCD test definition:
1012 //
1013 // The equality constraint:
1014 //
1015 // c_1*x_1 + c_2*x_2 + ... + c_n*x_n = c_0
1016 //
1017 // has an integer solution iff:
1018 //
1019 // GCD of c_1, c_2, ..., c_n divides c_0.
1020 //
isEmptyByGCDTest() const1021 bool FlatAffineConstraints::isEmptyByGCDTest() const {
1022 assert(hasConsistentState());
1023 unsigned numCols = getNumCols();
1024 for (unsigned i = 0, e = getNumEqualities(); i < e; ++i) {
1025 uint64_t gcd = std::abs(atEq(i, 0));
1026 for (unsigned j = 1; j < numCols - 1; ++j) {
1027 gcd = llvm::GreatestCommonDivisor64(gcd, std::abs(atEq(i, j)));
1028 }
1029 int64_t v = std::abs(atEq(i, numCols - 1));
1030 if (gcd > 0 && (v % gcd != 0)) {
1031 return true;
1032 }
1033 }
1034 return false;
1035 }
1036
1037 // First, try the GCD test heuristic.
1038 //
1039 // If that doesn't find the set empty, check if the set is unbounded. If it is,
1040 // we cannot use the GBR algorithm and we conservatively return false.
1041 //
1042 // If the set is bounded, we use the complete emptiness check for this case
1043 // provided by Simplex::findIntegerSample(), which gives a definitive answer.
isIntegerEmpty() const1044 bool FlatAffineConstraints::isIntegerEmpty() const {
1045 if (isEmptyByGCDTest())
1046 return true;
1047
1048 Simplex simplex(*this);
1049 if (simplex.isUnbounded())
1050 return false;
1051 return !simplex.findIntegerSample().hasValue();
1052 }
1053
1054 Optional<SmallVector<int64_t, 8>>
findIntegerSample() const1055 FlatAffineConstraints::findIntegerSample() const {
1056 return Simplex(*this).findIntegerSample();
1057 }
1058
1059 /// Helper to evaluate an affine expression at a point.
1060 /// The expression is a list of coefficients for the dimensions followed by the
1061 /// constant term.
valueAt(ArrayRef<int64_t> expr,ArrayRef<int64_t> point)1062 static int64_t valueAt(ArrayRef<int64_t> expr, ArrayRef<int64_t> point) {
1063 assert(expr.size() == 1 + point.size() &&
1064 "Dimensionalities of point and expresion don't match!");
1065 int64_t value = expr.back();
1066 for (unsigned i = 0; i < point.size(); ++i)
1067 value += expr[i] * point[i];
1068 return value;
1069 }
1070
1071 /// A point satisfies an equality iff the value of the equality at the
1072 /// expression is zero, and it satisfies an inequality iff the value of the
1073 /// inequality at that point is non-negative.
containsPoint(ArrayRef<int64_t> point) const1074 bool FlatAffineConstraints::containsPoint(ArrayRef<int64_t> point) const {
1075 for (unsigned i = 0, e = getNumEqualities(); i < e; ++i) {
1076 if (valueAt(getEquality(i), point) != 0)
1077 return false;
1078 }
1079 for (unsigned i = 0, e = getNumInequalities(); i < e; ++i) {
1080 if (valueAt(getInequality(i), point) < 0)
1081 return false;
1082 }
1083 return true;
1084 }
1085
1086 /// Tightens inequalities given that we are dealing with integer spaces. This is
1087 /// analogous to the GCD test but applied to inequalities. The constant term can
1088 /// be reduced to the preceding multiple of the GCD of the coefficients, i.e.,
1089 /// 64*i - 100 >= 0 => 64*i - 128 >= 0 (since 'i' is an integer). This is a
1090 /// fast method - linear in the number of coefficients.
1091 // Example on how this affects practical cases: consider the scenario:
1092 // 64*i >= 100, j = 64*i; without a tightening, elimination of i would yield
1093 // j >= 100 instead of the tighter (exact) j >= 128.
GCDTightenInequalities()1094 void FlatAffineConstraints::GCDTightenInequalities() {
1095 unsigned numCols = getNumCols();
1096 for (unsigned i = 0, e = getNumInequalities(); i < e; ++i) {
1097 uint64_t gcd = std::abs(atIneq(i, 0));
1098 for (unsigned j = 1; j < numCols - 1; ++j) {
1099 gcd = llvm::GreatestCommonDivisor64(gcd, std::abs(atIneq(i, j)));
1100 }
1101 if (gcd > 0 && gcd != 1) {
1102 int64_t gcdI = static_cast<int64_t>(gcd);
1103 // Tighten the constant term and normalize the constraint by the GCD.
1104 atIneq(i, numCols - 1) = mlir::floorDiv(atIneq(i, numCols - 1), gcdI);
1105 for (unsigned j = 0, e = numCols - 1; j < e; ++j)
1106 atIneq(i, j) /= gcdI;
1107 }
1108 }
1109 }
1110
1111 // Eliminates all identifier variables in column range [posStart, posLimit).
1112 // Returns the number of variables eliminated.
gaussianEliminateIds(unsigned posStart,unsigned posLimit)1113 unsigned FlatAffineConstraints::gaussianEliminateIds(unsigned posStart,
1114 unsigned posLimit) {
1115 // Return if identifier positions to eliminate are out of range.
1116 assert(posLimit <= numIds);
1117 assert(hasConsistentState());
1118
1119 if (posStart >= posLimit)
1120 return 0;
1121
1122 GCDTightenInequalities();
1123
1124 unsigned pivotCol = 0;
1125 for (pivotCol = posStart; pivotCol < posLimit; ++pivotCol) {
1126 // Find a row which has a non-zero coefficient in column 'j'.
1127 unsigned pivotRow;
1128 if (!findConstraintWithNonZeroAt(*this, pivotCol, /*isEq=*/true,
1129 &pivotRow)) {
1130 // No pivot row in equalities with non-zero at 'pivotCol'.
1131 if (!findConstraintWithNonZeroAt(*this, pivotCol, /*isEq=*/false,
1132 &pivotRow)) {
1133 // If inequalities are also non-zero in 'pivotCol', it can be
1134 // eliminated.
1135 continue;
1136 }
1137 break;
1138 }
1139
1140 // Eliminate identifier at 'pivotCol' from each equality row.
1141 for (unsigned i = 0, e = getNumEqualities(); i < e; ++i) {
1142 eliminateFromConstraint(this, i, pivotRow, pivotCol, posStart,
1143 /*isEq=*/true);
1144 normalizeConstraintByGCD</*isEq=*/true>(this, i);
1145 }
1146
1147 // Eliminate identifier at 'pivotCol' from each inequality row.
1148 for (unsigned i = 0, e = getNumInequalities(); i < e; ++i) {
1149 eliminateFromConstraint(this, i, pivotRow, pivotCol, posStart,
1150 /*isEq=*/false);
1151 normalizeConstraintByGCD</*isEq=*/false>(this, i);
1152 }
1153 removeEquality(pivotRow);
1154 GCDTightenInequalities();
1155 }
1156 // Update position limit based on number eliminated.
1157 posLimit = pivotCol;
1158 // Remove eliminated columns from all constraints.
1159 removeIdRange(posStart, posLimit);
1160 return posLimit - posStart;
1161 }
1162
1163 // Detect the identifier at 'pos' (say id_r) as modulo of another identifier
1164 // (say id_n) w.r.t a constant. When this happens, another identifier (say id_q)
1165 // could be detected as the floordiv of n. For eg:
1166 // id_n - 4*id_q - id_r = 0, 0 <= id_r <= 3 <=>
1167 // id_r = id_n mod 4, id_q = id_n floordiv 4.
1168 // lbConst and ubConst are the constant lower and upper bounds for 'pos' -
1169 // pre-detected at the caller.
detectAsMod(const FlatAffineConstraints & cst,unsigned pos,int64_t lbConst,int64_t ubConst,SmallVectorImpl<AffineExpr> * memo)1170 static bool detectAsMod(const FlatAffineConstraints &cst, unsigned pos,
1171 int64_t lbConst, int64_t ubConst,
1172 SmallVectorImpl<AffineExpr> *memo) {
1173 assert(pos < cst.getNumIds() && "invalid position");
1174
1175 // Check if 0 <= id_r <= divisor - 1 and if id_r is equal to
1176 // id_n - divisor * id_q. If these are true, then id_n becomes the dividend
1177 // and id_q the quotient when dividing id_n by the divisor.
1178
1179 if (lbConst != 0 || ubConst < 1)
1180 return false;
1181
1182 int64_t divisor = ubConst + 1;
1183
1184 // Now check for: id_r = id_n - divisor * id_q. As an example, we
1185 // are looking r = d - 4q, i.e., either r - d + 4q = 0 or -r + d - 4q = 0.
1186 unsigned seenQuotient = 0, seenDividend = 0;
1187 int quotientPos = -1, dividendPos = -1;
1188 for (unsigned r = 0, e = cst.getNumEqualities(); r < e; r++) {
1189 // id_n should have coeff 1 or -1.
1190 if (std::abs(cst.atEq(r, pos)) != 1)
1191 continue;
1192 // constant term should be 0.
1193 if (cst.atEq(r, cst.getNumCols() - 1) != 0)
1194 continue;
1195 unsigned c, f;
1196 int quotientSign = 1, dividendSign = 1;
1197 for (c = 0, f = cst.getNumDimAndSymbolIds(); c < f; c++) {
1198 if (c == pos)
1199 continue;
1200 // The coefficient of the quotient should be +/-divisor.
1201 // TODO: could be extended to detect an affine function for the quotient
1202 // (i.e., the coeff could be a non-zero multiple of divisor).
1203 int64_t v = cst.atEq(r, c) * cst.atEq(r, pos);
1204 if (v == divisor || v == -divisor) {
1205 seenQuotient++;
1206 quotientPos = c;
1207 quotientSign = v > 0 ? 1 : -1;
1208 }
1209 // The coefficient of the dividend should be +/-1.
1210 // TODO: could be extended to detect an affine function of the other
1211 // identifiers as the dividend.
1212 else if (v == -1 || v == 1) {
1213 seenDividend++;
1214 dividendPos = c;
1215 dividendSign = v < 0 ? 1 : -1;
1216 } else if (cst.atEq(r, c) != 0) {
1217 // Cannot be inferred as a mod since the constraint has a coefficient
1218 // for an identifier that's neither a unit nor the divisor (see TODOs
1219 // above).
1220 break;
1221 }
1222 }
1223 if (c < f)
1224 // Cannot be inferred as a mod since the constraint has a coefficient for
1225 // an identifier that's neither a unit nor the divisor (see TODOs above).
1226 continue;
1227
1228 // We are looking for exactly one identifier as the dividend.
1229 if (seenDividend == 1 && seenQuotient >= 1) {
1230 if (!(*memo)[dividendPos])
1231 return false;
1232 // Successfully detected a mod.
1233 (*memo)[pos] = (*memo)[dividendPos] % divisor * dividendSign;
1234 auto ub = cst.getConstantUpperBound(dividendPos);
1235 if (ub.hasValue() && ub.getValue() < divisor)
1236 // The mod can be optimized away.
1237 (*memo)[pos] = (*memo)[dividendPos] * dividendSign;
1238 else
1239 (*memo)[pos] = (*memo)[dividendPos] % divisor * dividendSign;
1240
1241 if (seenQuotient == 1 && !(*memo)[quotientPos])
1242 // Successfully detected a floordiv as well.
1243 (*memo)[quotientPos] =
1244 (*memo)[dividendPos].floorDiv(divisor) * quotientSign;
1245 return true;
1246 }
1247 }
1248 return false;
1249 }
1250
1251 /// Gather all lower and upper bounds of the identifier at `pos`, and
1252 /// optionally any equalities on it. In addition, the bounds are to be
1253 /// independent of identifiers in position range [`offset`, `offset` + `num`).
getLowerAndUpperBoundIndices(unsigned pos,SmallVectorImpl<unsigned> * lbIndices,SmallVectorImpl<unsigned> * ubIndices,SmallVectorImpl<unsigned> * eqIndices,unsigned offset,unsigned num) const1254 void FlatAffineConstraints::getLowerAndUpperBoundIndices(
1255 unsigned pos, SmallVectorImpl<unsigned> *lbIndices,
1256 SmallVectorImpl<unsigned> *ubIndices, SmallVectorImpl<unsigned> *eqIndices,
1257 unsigned offset, unsigned num) const {
1258 assert(pos < getNumIds() && "invalid position");
1259 assert(offset + num < getNumCols() && "invalid range");
1260
1261 // Checks for a constraint that has a non-zero coeff for the identifiers in
1262 // the position range [offset, offset + num) while ignoring `pos`.
1263 auto containsConstraintDependentOnRange = [&](unsigned r, bool isEq) {
1264 unsigned c, f;
1265 auto cst = isEq ? getEquality(r) : getInequality(r);
1266 for (c = offset, f = offset + num; c < f; ++c) {
1267 if (c == pos)
1268 continue;
1269 if (cst[c] != 0)
1270 break;
1271 }
1272 return c < f;
1273 };
1274
1275 // Gather all lower bounds and upper bounds of the variable. Since the
1276 // canonical form c_1*x_1 + c_2*x_2 + ... + c_0 >= 0, a constraint is a lower
1277 // bound for x_i if c_i >= 1, and an upper bound if c_i <= -1.
1278 for (unsigned r = 0, e = getNumInequalities(); r < e; r++) {
1279 // The bounds are to be independent of [offset, offset + num) columns.
1280 if (containsConstraintDependentOnRange(r, /*isEq=*/false))
1281 continue;
1282 if (atIneq(r, pos) >= 1) {
1283 // Lower bound.
1284 lbIndices->push_back(r);
1285 } else if (atIneq(r, pos) <= -1) {
1286 // Upper bound.
1287 ubIndices->push_back(r);
1288 }
1289 }
1290
1291 // An equality is both a lower and upper bound. Record any equalities
1292 // involving the pos^th identifier.
1293 if (!eqIndices)
1294 return;
1295
1296 for (unsigned r = 0, e = getNumEqualities(); r < e; r++) {
1297 if (atEq(r, pos) == 0)
1298 continue;
1299 if (containsConstraintDependentOnRange(r, /*isEq=*/true))
1300 continue;
1301 eqIndices->push_back(r);
1302 }
1303 }
1304
1305 /// Check if the pos^th identifier can be expressed as a floordiv of an affine
1306 /// function of other identifiers (where the divisor is a positive constant)
1307 /// given the initial set of expressions in `exprs`. If it can be, the
1308 /// corresponding position in `exprs` is set as the detected affine expr. For
1309 /// eg: 4q <= i + j <= 4q + 3 <=> q = (i + j) floordiv 4. An equality can
1310 /// also yield a floordiv: eg. 4q = i + j <=> q = (i + j) floordiv 4. 32q + 28
1311 /// <= i <= 32q + 31 => q = i floordiv 32.
detectAsFloorDiv(const FlatAffineConstraints & cst,unsigned pos,MLIRContext * context,SmallVectorImpl<AffineExpr> & exprs)1312 static bool detectAsFloorDiv(const FlatAffineConstraints &cst, unsigned pos,
1313 MLIRContext *context,
1314 SmallVectorImpl<AffineExpr> &exprs) {
1315 assert(pos < cst.getNumIds() && "invalid position");
1316
1317 SmallVector<unsigned, 4> lbIndices, ubIndices;
1318 cst.getLowerAndUpperBoundIndices(pos, &lbIndices, &ubIndices);
1319
1320 // Check if any lower bound, upper bound pair is of the form:
1321 // divisor * id >= expr - (divisor - 1) <-- Lower bound for 'id'
1322 // divisor * id <= expr <-- Upper bound for 'id'
1323 // Then, 'id' is equivalent to 'expr floordiv divisor'. (where divisor > 1).
1324 //
1325 // For example, if -32*k + 16*i + j >= 0
1326 // 32*k - 16*i - j + 31 >= 0 <=>
1327 // k = ( 16*i + j ) floordiv 32
1328 unsigned seenDividends = 0;
1329 for (auto ubPos : ubIndices) {
1330 for (auto lbPos : lbIndices) {
1331 // Check if the lower bound's constant term is divisor - 1. The
1332 // 'divisor' here is cst.atIneq(lbPos, pos) and we already know that it's
1333 // positive (since cst.Ineq(lbPos, ...) is a lower bound expr for 'pos'.
1334 int64_t divisor = cst.atIneq(lbPos, pos);
1335 int64_t lbConstTerm = cst.atIneq(lbPos, cst.getNumCols() - 1);
1336 if (lbConstTerm != divisor - 1)
1337 continue;
1338 // Check if upper bound's constant term is 0.
1339 if (cst.atIneq(ubPos, cst.getNumCols() - 1) != 0)
1340 continue;
1341 // For the remaining part, check if the lower bound expr's coeff's are
1342 // negations of corresponding upper bound ones'.
1343 unsigned c, f;
1344 for (c = 0, f = cst.getNumCols() - 1; c < f; c++) {
1345 if (cst.atIneq(lbPos, c) != -cst.atIneq(ubPos, c))
1346 break;
1347 if (c != pos && cst.atIneq(lbPos, c) != 0)
1348 seenDividends++;
1349 }
1350 // Lb coeff's aren't negative of ub coeff's (for the non constant term
1351 // part).
1352 if (c < f)
1353 continue;
1354 if (seenDividends >= 1) {
1355 // Construct the dividend expression.
1356 auto dividendExpr = getAffineConstantExpr(0, context);
1357 unsigned c, f;
1358 for (c = 0, f = cst.getNumCols() - 1; c < f; c++) {
1359 if (c == pos)
1360 continue;
1361 int64_t ubVal = cst.atIneq(ubPos, c);
1362 if (ubVal == 0)
1363 continue;
1364 if (!exprs[c])
1365 break;
1366 dividendExpr = dividendExpr + ubVal * exprs[c];
1367 }
1368 // Expression can't be constructed as it depends on a yet unknown
1369 // identifier.
1370 // TODO: Visit/compute the identifiers in an order so that this doesn't
1371 // happen. More complex but much more efficient.
1372 if (c < f)
1373 continue;
1374 // Successfully detected the floordiv.
1375 exprs[pos] = dividendExpr.floorDiv(divisor);
1376 return true;
1377 }
1378 }
1379 }
1380 return false;
1381 }
1382
1383 // Fills an inequality row with the value 'val'.
fillInequality(FlatAffineConstraints * cst,unsigned r,int64_t val)1384 static inline void fillInequality(FlatAffineConstraints *cst, unsigned r,
1385 int64_t val) {
1386 for (unsigned c = 0, f = cst->getNumCols(); c < f; c++) {
1387 cst->atIneq(r, c) = val;
1388 }
1389 }
1390
1391 // Negates an inequality.
negateInequality(FlatAffineConstraints * cst,unsigned r)1392 static inline void negateInequality(FlatAffineConstraints *cst, unsigned r) {
1393 for (unsigned c = 0, f = cst->getNumCols(); c < f; c++) {
1394 cst->atIneq(r, c) = -cst->atIneq(r, c);
1395 }
1396 }
1397
1398 // A more complex check to eliminate redundant inequalities. Uses FourierMotzkin
1399 // to check if a constraint is redundant.
removeRedundantInequalities()1400 void FlatAffineConstraints::removeRedundantInequalities() {
1401 SmallVector<bool, 32> redun(getNumInequalities(), false);
1402 // To check if an inequality is redundant, we replace the inequality by its
1403 // complement (for eg., i - 1 >= 0 by i <= 0), and check if the resulting
1404 // system is empty. If it is, the inequality is redundant.
1405 FlatAffineConstraints tmpCst(*this);
1406 for (unsigned r = 0, e = getNumInequalities(); r < e; r++) {
1407 // Change the inequality to its complement.
1408 negateInequality(&tmpCst, r);
1409 tmpCst.atIneq(r, tmpCst.getNumCols() - 1)--;
1410 if (tmpCst.isEmpty()) {
1411 redun[r] = true;
1412 // Zero fill the redundant inequality.
1413 fillInequality(this, r, /*val=*/0);
1414 fillInequality(&tmpCst, r, /*val=*/0);
1415 } else {
1416 // Reverse the change (to avoid recreating tmpCst each time).
1417 tmpCst.atIneq(r, tmpCst.getNumCols() - 1)++;
1418 negateInequality(&tmpCst, r);
1419 }
1420 }
1421
1422 // Scan to get rid of all rows marked redundant, in-place.
1423 auto copyRow = [&](unsigned src, unsigned dest) {
1424 if (src == dest)
1425 return;
1426 for (unsigned c = 0, e = getNumCols(); c < e; c++) {
1427 atIneq(dest, c) = atIneq(src, c);
1428 }
1429 };
1430 unsigned pos = 0;
1431 for (unsigned r = 0, e = getNumInequalities(); r < e; r++) {
1432 if (!redun[r])
1433 copyRow(r, pos++);
1434 }
1435 inequalities.resize(numReservedCols * pos);
1436 }
1437
1438 // A more complex check to eliminate redundant inequalities and equalities. Uses
1439 // Simplex to check if a constraint is redundant.
removeRedundantConstraints()1440 void FlatAffineConstraints::removeRedundantConstraints() {
1441 // First, we run GCDTightenInequalities. This allows us to catch some
1442 // constraints which are not redundant when considering rational solutions
1443 // but are redundant in terms of integer solutions.
1444 GCDTightenInequalities();
1445 Simplex simplex(*this);
1446 simplex.detectRedundant();
1447
1448 auto copyInequality = [&](unsigned src, unsigned dest) {
1449 if (src == dest)
1450 return;
1451 for (unsigned c = 0, e = getNumCols(); c < e; c++)
1452 atIneq(dest, c) = atIneq(src, c);
1453 };
1454 unsigned pos = 0;
1455 unsigned numIneqs = getNumInequalities();
1456 // Scan to get rid of all inequalities marked redundant, in-place. In Simplex,
1457 // the first constraints added are the inequalities.
1458 for (unsigned r = 0; r < numIneqs; r++) {
1459 if (!simplex.isMarkedRedundant(r))
1460 copyInequality(r, pos++);
1461 }
1462 inequalities.resize(numReservedCols * pos);
1463
1464 // Scan to get rid of all equalities marked redundant, in-place. In Simplex,
1465 // after the inequalities, a pair of constraints for each equality is added.
1466 // An equality is redundant if both the inequalities in its pair are
1467 // redundant.
1468 auto copyEquality = [&](unsigned src, unsigned dest) {
1469 if (src == dest)
1470 return;
1471 for (unsigned c = 0, e = getNumCols(); c < e; c++)
1472 atEq(dest, c) = atEq(src, c);
1473 };
1474 pos = 0;
1475 for (unsigned r = 0, e = getNumEqualities(); r < e; r++) {
1476 if (!(simplex.isMarkedRedundant(numIneqs + 2 * r) &&
1477 simplex.isMarkedRedundant(numIneqs + 2 * r + 1)))
1478 copyEquality(r, pos++);
1479 }
1480 equalities.resize(numReservedCols * pos);
1481 }
1482
getLowerAndUpperBound(unsigned pos,unsigned offset,unsigned num,unsigned symStartPos,ArrayRef<AffineExpr> localExprs,MLIRContext * context) const1483 std::pair<AffineMap, AffineMap> FlatAffineConstraints::getLowerAndUpperBound(
1484 unsigned pos, unsigned offset, unsigned num, unsigned symStartPos,
1485 ArrayRef<AffineExpr> localExprs, MLIRContext *context) const {
1486 assert(pos + offset < getNumDimIds() && "invalid dim start pos");
1487 assert(symStartPos >= (pos + offset) && "invalid sym start pos");
1488 assert(getNumLocalIds() == localExprs.size() &&
1489 "incorrect local exprs count");
1490
1491 SmallVector<unsigned, 4> lbIndices, ubIndices, eqIndices;
1492 getLowerAndUpperBoundIndices(pos + offset, &lbIndices, &ubIndices, &eqIndices,
1493 offset, num);
1494
1495 /// Add to 'b' from 'a' in set [0, offset) U [offset + num, symbStartPos).
1496 auto addCoeffs = [&](ArrayRef<int64_t> a, SmallVectorImpl<int64_t> &b) {
1497 b.clear();
1498 for (unsigned i = 0, e = a.size(); i < e; ++i) {
1499 if (i < offset || i >= offset + num)
1500 b.push_back(a[i]);
1501 }
1502 };
1503
1504 SmallVector<int64_t, 8> lb, ub;
1505 SmallVector<AffineExpr, 4> lbExprs;
1506 unsigned dimCount = symStartPos - num;
1507 unsigned symCount = getNumDimAndSymbolIds() - symStartPos;
1508 lbExprs.reserve(lbIndices.size() + eqIndices.size());
1509 // Lower bound expressions.
1510 for (auto idx : lbIndices) {
1511 auto ineq = getInequality(idx);
1512 // Extract the lower bound (in terms of other coeff's + const), i.e., if
1513 // i - j + 1 >= 0 is the constraint, 'pos' is for i the lower bound is j
1514 // - 1.
1515 addCoeffs(ineq, lb);
1516 std::transform(lb.begin(), lb.end(), lb.begin(), std::negate<int64_t>());
1517 auto expr =
1518 getAffineExprFromFlatForm(lb, dimCount, symCount, localExprs, context);
1519 // expr ceildiv divisor is (expr + divisor - 1) floordiv divisor
1520 int64_t divisor = std::abs(ineq[pos + offset]);
1521 expr = (expr + divisor - 1).floorDiv(divisor);
1522 lbExprs.push_back(expr);
1523 }
1524
1525 SmallVector<AffineExpr, 4> ubExprs;
1526 ubExprs.reserve(ubIndices.size() + eqIndices.size());
1527 // Upper bound expressions.
1528 for (auto idx : ubIndices) {
1529 auto ineq = getInequality(idx);
1530 // Extract the upper bound (in terms of other coeff's + const).
1531 addCoeffs(ineq, ub);
1532 auto expr =
1533 getAffineExprFromFlatForm(ub, dimCount, symCount, localExprs, context);
1534 expr = expr.floorDiv(std::abs(ineq[pos + offset]));
1535 // Upper bound is exclusive.
1536 ubExprs.push_back(expr + 1);
1537 }
1538
1539 // Equalities. It's both a lower and a upper bound.
1540 SmallVector<int64_t, 4> b;
1541 for (auto idx : eqIndices) {
1542 auto eq = getEquality(idx);
1543 addCoeffs(eq, b);
1544 if (eq[pos + offset] > 0)
1545 std::transform(b.begin(), b.end(), b.begin(), std::negate<int64_t>());
1546
1547 // Extract the upper bound (in terms of other coeff's + const).
1548 auto expr =
1549 getAffineExprFromFlatForm(b, dimCount, symCount, localExprs, context);
1550 expr = expr.floorDiv(std::abs(eq[pos + offset]));
1551 // Upper bound is exclusive.
1552 ubExprs.push_back(expr + 1);
1553 // Lower bound.
1554 expr =
1555 getAffineExprFromFlatForm(b, dimCount, symCount, localExprs, context);
1556 expr = expr.ceilDiv(std::abs(eq[pos + offset]));
1557 lbExprs.push_back(expr);
1558 }
1559
1560 auto lbMap = AffineMap::get(dimCount, symCount, lbExprs, context);
1561 auto ubMap = AffineMap::get(dimCount, symCount, ubExprs, context);
1562
1563 return {lbMap, ubMap};
1564 }
1565
1566 /// Computes the lower and upper bounds of the first 'num' dimensional
1567 /// identifiers (starting at 'offset') as affine maps of the remaining
1568 /// identifiers (dimensional and symbolic identifiers). Local identifiers are
1569 /// themselves explicitly computed as affine functions of other identifiers in
1570 /// this process if needed.
getSliceBounds(unsigned offset,unsigned num,MLIRContext * context,SmallVectorImpl<AffineMap> * lbMaps,SmallVectorImpl<AffineMap> * ubMaps)1571 void FlatAffineConstraints::getSliceBounds(unsigned offset, unsigned num,
1572 MLIRContext *context,
1573 SmallVectorImpl<AffineMap> *lbMaps,
1574 SmallVectorImpl<AffineMap> *ubMaps) {
1575 assert(num < getNumDimIds() && "invalid range");
1576
1577 // Basic simplification.
1578 normalizeConstraintsByGCD();
1579
1580 LLVM_DEBUG(llvm::dbgs() << "getSliceBounds for first " << num
1581 << " identifiers\n");
1582 LLVM_DEBUG(dump());
1583
1584 // Record computed/detected identifiers.
1585 SmallVector<AffineExpr, 8> memo(getNumIds());
1586 // Initialize dimensional and symbolic identifiers.
1587 for (unsigned i = 0, e = getNumDimIds(); i < e; i++) {
1588 if (i < offset)
1589 memo[i] = getAffineDimExpr(i, context);
1590 else if (i >= offset + num)
1591 memo[i] = getAffineDimExpr(i - num, context);
1592 }
1593 for (unsigned i = getNumDimIds(), e = getNumDimAndSymbolIds(); i < e; i++)
1594 memo[i] = getAffineSymbolExpr(i - getNumDimIds(), context);
1595
1596 bool changed;
1597 do {
1598 changed = false;
1599 // Identify yet unknown identifiers as constants or mod's / floordiv's of
1600 // other identifiers if possible.
1601 for (unsigned pos = 0; pos < getNumIds(); pos++) {
1602 if (memo[pos])
1603 continue;
1604
1605 auto lbConst = getConstantLowerBound(pos);
1606 auto ubConst = getConstantUpperBound(pos);
1607 if (lbConst.hasValue() && ubConst.hasValue()) {
1608 // Detect equality to a constant.
1609 if (lbConst.getValue() == ubConst.getValue()) {
1610 memo[pos] = getAffineConstantExpr(lbConst.getValue(), context);
1611 changed = true;
1612 continue;
1613 }
1614
1615 // Detect an identifier as modulo of another identifier w.r.t a
1616 // constant.
1617 if (detectAsMod(*this, pos, lbConst.getValue(), ubConst.getValue(),
1618 &memo)) {
1619 changed = true;
1620 continue;
1621 }
1622 }
1623
1624 // Detect an identifier as a floordiv of an affine function of other
1625 // identifiers (divisor is a positive constant).
1626 if (detectAsFloorDiv(*this, pos, context, memo)) {
1627 changed = true;
1628 continue;
1629 }
1630
1631 // Detect an identifier as an expression of other identifiers.
1632 unsigned idx;
1633 if (!findConstraintWithNonZeroAt(*this, pos, /*isEq=*/true, &idx)) {
1634 continue;
1635 }
1636
1637 // Build AffineExpr solving for identifier 'pos' in terms of all others.
1638 auto expr = getAffineConstantExpr(0, context);
1639 unsigned j, e;
1640 for (j = 0, e = getNumIds(); j < e; ++j) {
1641 if (j == pos)
1642 continue;
1643 int64_t c = atEq(idx, j);
1644 if (c == 0)
1645 continue;
1646 // If any of the involved IDs hasn't been found yet, we can't proceed.
1647 if (!memo[j])
1648 break;
1649 expr = expr + memo[j] * c;
1650 }
1651 if (j < e)
1652 // Can't construct expression as it depends on a yet uncomputed
1653 // identifier.
1654 continue;
1655
1656 // Add constant term to AffineExpr.
1657 expr = expr + atEq(idx, getNumIds());
1658 int64_t vPos = atEq(idx, pos);
1659 assert(vPos != 0 && "expected non-zero here");
1660 if (vPos > 0)
1661 expr = (-expr).floorDiv(vPos);
1662 else
1663 // vPos < 0.
1664 expr = expr.floorDiv(-vPos);
1665 // Successfully constructed expression.
1666 memo[pos] = expr;
1667 changed = true;
1668 }
1669 // This loop is guaranteed to reach a fixed point - since once an
1670 // identifier's explicit form is computed (in memo[pos]), it's not updated
1671 // again.
1672 } while (changed);
1673
1674 // Set the lower and upper bound maps for all the identifiers that were
1675 // computed as affine expressions of the rest as the "detected expr" and
1676 // "detected expr + 1" respectively; set the undetected ones to null.
1677 Optional<FlatAffineConstraints> tmpClone;
1678 for (unsigned pos = 0; pos < num; pos++) {
1679 unsigned numMapDims = getNumDimIds() - num;
1680 unsigned numMapSymbols = getNumSymbolIds();
1681 AffineExpr expr = memo[pos + offset];
1682 if (expr)
1683 expr = simplifyAffineExpr(expr, numMapDims, numMapSymbols);
1684
1685 AffineMap &lbMap = (*lbMaps)[pos];
1686 AffineMap &ubMap = (*ubMaps)[pos];
1687
1688 if (expr) {
1689 lbMap = AffineMap::get(numMapDims, numMapSymbols, expr);
1690 ubMap = AffineMap::get(numMapDims, numMapSymbols, expr + 1);
1691 } else {
1692 // TODO: Whenever there are local identifiers in the dependence
1693 // constraints, we'll conservatively over-approximate, since we don't
1694 // always explicitly compute them above (in the while loop).
1695 if (getNumLocalIds() == 0) {
1696 // Work on a copy so that we don't update this constraint system.
1697 if (!tmpClone) {
1698 tmpClone.emplace(FlatAffineConstraints(*this));
1699 // Removing redundant inequalities is necessary so that we don't get
1700 // redundant loop bounds.
1701 tmpClone->removeRedundantInequalities();
1702 }
1703 std::tie(lbMap, ubMap) = tmpClone->getLowerAndUpperBound(
1704 pos, offset, num, getNumDimIds(), /*localExprs=*/{}, context);
1705 }
1706
1707 // If the above fails, we'll just use the constant lower bound and the
1708 // constant upper bound (if they exist) as the slice bounds.
1709 // TODO: being conservative for the moment in cases that
1710 // lead to multiple bounds - until getConstDifference in LoopFusion.cpp is
1711 // fixed (b/126426796).
1712 if (!lbMap || lbMap.getNumResults() > 1) {
1713 LLVM_DEBUG(llvm::dbgs()
1714 << "WARNING: Potentially over-approximating slice lb\n");
1715 auto lbConst = getConstantLowerBound(pos + offset);
1716 if (lbConst.hasValue()) {
1717 lbMap = AffineMap::get(
1718 numMapDims, numMapSymbols,
1719 getAffineConstantExpr(lbConst.getValue(), context));
1720 }
1721 }
1722 if (!ubMap || ubMap.getNumResults() > 1) {
1723 LLVM_DEBUG(llvm::dbgs()
1724 << "WARNING: Potentially over-approximating slice ub\n");
1725 auto ubConst = getConstantUpperBound(pos + offset);
1726 if (ubConst.hasValue()) {
1727 (ubMap) = AffineMap::get(
1728 numMapDims, numMapSymbols,
1729 getAffineConstantExpr(ubConst.getValue() + 1, context));
1730 }
1731 }
1732 }
1733 LLVM_DEBUG(llvm::dbgs()
1734 << "lb map for pos = " << Twine(pos + offset) << ", expr: ");
1735 LLVM_DEBUG(lbMap.dump(););
1736 LLVM_DEBUG(llvm::dbgs()
1737 << "ub map for pos = " << Twine(pos + offset) << ", expr: ");
1738 LLVM_DEBUG(ubMap.dump(););
1739 }
1740 }
1741
1742 LogicalResult
addLowerOrUpperBound(unsigned pos,AffineMap boundMap,ValueRange boundOperands,bool eq,bool lower)1743 FlatAffineConstraints::addLowerOrUpperBound(unsigned pos, AffineMap boundMap,
1744 ValueRange boundOperands, bool eq,
1745 bool lower) {
1746 assert(pos < getNumDimAndSymbolIds() && "invalid position");
1747 // Equality follows the logic of lower bound except that we add an equality
1748 // instead of an inequality.
1749 assert((!eq || boundMap.getNumResults() == 1) && "single result expected");
1750 if (eq)
1751 lower = true;
1752
1753 // Fully compose map and operands; canonicalize and simplify so that we
1754 // transitively get to terminal symbols or loop IVs.
1755 auto map = boundMap;
1756 SmallVector<Value, 4> operands(boundOperands.begin(), boundOperands.end());
1757 fullyComposeAffineMapAndOperands(&map, &operands);
1758 map = simplifyAffineMap(map);
1759 canonicalizeMapAndOperands(&map, &operands);
1760 for (auto operand : operands)
1761 addInductionVarOrTerminalSymbol(operand);
1762
1763 FlatAffineConstraints localVarCst;
1764 std::vector<SmallVector<int64_t, 8>> flatExprs;
1765 if (failed(getFlattenedAffineExprs(map, &flatExprs, &localVarCst))) {
1766 LLVM_DEBUG(llvm::dbgs() << "semi-affine expressions not yet supported\n");
1767 return failure();
1768 }
1769
1770 // Merge and align with localVarCst.
1771 if (localVarCst.getNumLocalIds() > 0) {
1772 // Set values for localVarCst.
1773 localVarCst.setIdValues(0, localVarCst.getNumDimAndSymbolIds(), operands);
1774 for (auto operand : operands) {
1775 unsigned pos;
1776 if (findId(operand, &pos)) {
1777 if (pos >= getNumDimIds() && pos < getNumDimAndSymbolIds()) {
1778 // If the local var cst has this as a dim, turn it into its symbol.
1779 turnDimIntoSymbol(&localVarCst, operand);
1780 } else if (pos < getNumDimIds()) {
1781 // Or vice versa.
1782 turnSymbolIntoDim(&localVarCst, operand);
1783 }
1784 }
1785 }
1786 mergeAndAlignIds(/*offset=*/0, this, &localVarCst);
1787 append(localVarCst);
1788 }
1789
1790 // Record positions of the operands in the constraint system. Need to do
1791 // this here since the constraint system changes after a bound is added.
1792 SmallVector<unsigned, 8> positions;
1793 unsigned numOperands = operands.size();
1794 for (auto operand : operands) {
1795 unsigned pos;
1796 if (!findId(operand, &pos))
1797 assert(0 && "expected to be found");
1798 positions.push_back(pos);
1799 }
1800
1801 for (const auto &flatExpr : flatExprs) {
1802 SmallVector<int64_t, 4> ineq(getNumCols(), 0);
1803 ineq[pos] = lower ? 1 : -1;
1804 // Dims and symbols.
1805 for (unsigned j = 0, e = map.getNumInputs(); j < e; j++) {
1806 ineq[positions[j]] = lower ? -flatExpr[j] : flatExpr[j];
1807 }
1808 // Copy over the local id coefficients.
1809 unsigned numLocalIds = flatExpr.size() - 1 - numOperands;
1810 for (unsigned jj = 0, j = getNumIds() - numLocalIds; jj < numLocalIds;
1811 jj++, j++) {
1812 ineq[j] =
1813 lower ? -flatExpr[numOperands + jj] : flatExpr[numOperands + jj];
1814 }
1815 // Constant term.
1816 ineq[getNumCols() - 1] =
1817 lower ? -flatExpr[flatExpr.size() - 1]
1818 // Upper bound in flattenedExpr is an exclusive one.
1819 : flatExpr[flatExpr.size() - 1] - 1;
1820 eq ? addEquality(ineq) : addInequality(ineq);
1821 }
1822 return success();
1823 }
1824
1825 // Adds slice lower bounds represented by lower bounds in 'lbMaps' and upper
1826 // bounds in 'ubMaps' to each value in `values' that appears in the constraint
1827 // system. Note that both lower/upper bounds share the same operand list
1828 // 'operands'.
1829 // This function assumes 'values.size' == 'lbMaps.size' == 'ubMaps.size', and
1830 // skips any null AffineMaps in 'lbMaps' or 'ubMaps'.
1831 // Note that both lower/upper bounds use operands from 'operands'.
1832 // Returns failure for unimplemented cases such as semi-affine expressions or
1833 // expressions with mod/floordiv.
addSliceBounds(ArrayRef<Value> values,ArrayRef<AffineMap> lbMaps,ArrayRef<AffineMap> ubMaps,ArrayRef<Value> operands)1834 LogicalResult FlatAffineConstraints::addSliceBounds(ArrayRef<Value> values,
1835 ArrayRef<AffineMap> lbMaps,
1836 ArrayRef<AffineMap> ubMaps,
1837 ArrayRef<Value> operands) {
1838 assert(values.size() == lbMaps.size());
1839 assert(lbMaps.size() == ubMaps.size());
1840
1841 for (unsigned i = 0, e = lbMaps.size(); i < e; ++i) {
1842 unsigned pos;
1843 if (!findId(values[i], &pos))
1844 continue;
1845
1846 AffineMap lbMap = lbMaps[i];
1847 AffineMap ubMap = ubMaps[i];
1848 assert(!lbMap || lbMap.getNumInputs() == operands.size());
1849 assert(!ubMap || ubMap.getNumInputs() == operands.size());
1850
1851 // Check if this slice is just an equality along this dimension.
1852 if (lbMap && ubMap && lbMap.getNumResults() == 1 &&
1853 ubMap.getNumResults() == 1 &&
1854 lbMap.getResult(0) + 1 == ubMap.getResult(0)) {
1855 if (failed(addLowerOrUpperBound(pos, lbMap, operands, /*eq=*/true,
1856 /*lower=*/true)))
1857 return failure();
1858 continue;
1859 }
1860
1861 if (lbMap && failed(addLowerOrUpperBound(pos, lbMap, operands, /*eq=*/false,
1862 /*lower=*/true)))
1863 return failure();
1864
1865 if (ubMap && failed(addLowerOrUpperBound(pos, ubMap, operands, /*eq=*/false,
1866 /*lower=*/false)))
1867 return failure();
1868 }
1869 return success();
1870 }
1871
addEquality(ArrayRef<int64_t> eq)1872 void FlatAffineConstraints::addEquality(ArrayRef<int64_t> eq) {
1873 assert(eq.size() == getNumCols());
1874 unsigned offset = equalities.size();
1875 equalities.resize(equalities.size() + numReservedCols);
1876 std::copy(eq.begin(), eq.end(), equalities.begin() + offset);
1877 }
1878
addInequality(ArrayRef<int64_t> inEq)1879 void FlatAffineConstraints::addInequality(ArrayRef<int64_t> inEq) {
1880 assert(inEq.size() == getNumCols());
1881 unsigned offset = inequalities.size();
1882 inequalities.resize(inequalities.size() + numReservedCols);
1883 std::copy(inEq.begin(), inEq.end(), inequalities.begin() + offset);
1884 }
1885
addConstantLowerBound(unsigned pos,int64_t lb)1886 void FlatAffineConstraints::addConstantLowerBound(unsigned pos, int64_t lb) {
1887 assert(pos < getNumCols());
1888 unsigned offset = inequalities.size();
1889 inequalities.resize(inequalities.size() + numReservedCols);
1890 std::fill(inequalities.begin() + offset,
1891 inequalities.begin() + offset + getNumCols(), 0);
1892 inequalities[offset + pos] = 1;
1893 inequalities[offset + getNumCols() - 1] = -lb;
1894 }
1895
addConstantUpperBound(unsigned pos,int64_t ub)1896 void FlatAffineConstraints::addConstantUpperBound(unsigned pos, int64_t ub) {
1897 assert(pos < getNumCols());
1898 unsigned offset = inequalities.size();
1899 inequalities.resize(inequalities.size() + numReservedCols);
1900 std::fill(inequalities.begin() + offset,
1901 inequalities.begin() + offset + getNumCols(), 0);
1902 inequalities[offset + pos] = -1;
1903 inequalities[offset + getNumCols() - 1] = ub;
1904 }
1905
addConstantLowerBound(ArrayRef<int64_t> expr,int64_t lb)1906 void FlatAffineConstraints::addConstantLowerBound(ArrayRef<int64_t> expr,
1907 int64_t lb) {
1908 assert(expr.size() == getNumCols());
1909 unsigned offset = inequalities.size();
1910 inequalities.resize(inequalities.size() + numReservedCols);
1911 std::fill(inequalities.begin() + offset,
1912 inequalities.begin() + offset + getNumCols(), 0);
1913 std::copy(expr.begin(), expr.end(), inequalities.begin() + offset);
1914 inequalities[offset + getNumCols() - 1] += -lb;
1915 }
1916
addConstantUpperBound(ArrayRef<int64_t> expr,int64_t ub)1917 void FlatAffineConstraints::addConstantUpperBound(ArrayRef<int64_t> expr,
1918 int64_t ub) {
1919 assert(expr.size() == getNumCols());
1920 unsigned offset = inequalities.size();
1921 inequalities.resize(inequalities.size() + numReservedCols);
1922 std::fill(inequalities.begin() + offset,
1923 inequalities.begin() + offset + getNumCols(), 0);
1924 for (unsigned i = 0, e = getNumCols(); i < e; i++) {
1925 inequalities[offset + i] = -expr[i];
1926 }
1927 inequalities[offset + getNumCols() - 1] += ub;
1928 }
1929
1930 /// Adds a new local identifier as the floordiv of an affine function of other
1931 /// identifiers, the coefficients of which are provided in 'dividend' and with
1932 /// respect to a positive constant 'divisor'. Two constraints are added to the
1933 /// system to capture equivalence with the floordiv.
1934 /// q = expr floordiv c <=> c*q <= expr <= c*q + c - 1.
addLocalFloorDiv(ArrayRef<int64_t> dividend,int64_t divisor)1935 void FlatAffineConstraints::addLocalFloorDiv(ArrayRef<int64_t> dividend,
1936 int64_t divisor) {
1937 assert(dividend.size() == getNumCols() && "incorrect dividend size");
1938 assert(divisor > 0 && "positive divisor expected");
1939
1940 addLocalId(getNumLocalIds());
1941
1942 // Add two constraints for this new identifier 'q'.
1943 SmallVector<int64_t, 8> bound(dividend.size() + 1);
1944
1945 // dividend - q * divisor >= 0
1946 std::copy(dividend.begin(), dividend.begin() + dividend.size() - 1,
1947 bound.begin());
1948 bound.back() = dividend.back();
1949 bound[getNumIds() - 1] = -divisor;
1950 addInequality(bound);
1951
1952 // -dividend +qdivisor * q + divisor - 1 >= 0
1953 std::transform(bound.begin(), bound.end(), bound.begin(),
1954 std::negate<int64_t>());
1955 bound[bound.size() - 1] += divisor - 1;
1956 addInequality(bound);
1957 }
1958
findId(Value id,unsigned * pos) const1959 bool FlatAffineConstraints::findId(Value id, unsigned *pos) const {
1960 unsigned i = 0;
1961 for (const auto &mayBeId : ids) {
1962 if (mayBeId.hasValue() && mayBeId.getValue() == id) {
1963 *pos = i;
1964 return true;
1965 }
1966 i++;
1967 }
1968 return false;
1969 }
1970
containsId(Value id) const1971 bool FlatAffineConstraints::containsId(Value id) const {
1972 return llvm::any_of(ids, [&](const Optional<Value> &mayBeId) {
1973 return mayBeId.hasValue() && mayBeId.getValue() == id;
1974 });
1975 }
1976
swapId(unsigned posA,unsigned posB)1977 void FlatAffineConstraints::swapId(unsigned posA, unsigned posB) {
1978 assert(posA < getNumIds() && "invalid position A");
1979 assert(posB < getNumIds() && "invalid position B");
1980
1981 if (posA == posB)
1982 return;
1983
1984 for (unsigned r = 0, e = getNumInequalities(); r < e; r++)
1985 std::swap(atIneq(r, posA), atIneq(r, posB));
1986 for (unsigned r = 0, e = getNumEqualities(); r < e; r++)
1987 std::swap(atEq(r, posA), atEq(r, posB));
1988 std::swap(getId(posA), getId(posB));
1989 }
1990
setDimSymbolSeparation(unsigned newSymbolCount)1991 void FlatAffineConstraints::setDimSymbolSeparation(unsigned newSymbolCount) {
1992 assert(newSymbolCount <= numDims + numSymbols &&
1993 "invalid separation position");
1994 numDims = numDims + numSymbols - newSymbolCount;
1995 numSymbols = newSymbolCount;
1996 }
1997
1998 /// Sets the specified identifier to a constant value.
setIdToConstant(unsigned pos,int64_t val)1999 void FlatAffineConstraints::setIdToConstant(unsigned pos, int64_t val) {
2000 unsigned offset = equalities.size();
2001 equalities.resize(equalities.size() + numReservedCols);
2002 std::fill(equalities.begin() + offset,
2003 equalities.begin() + offset + getNumCols(), 0);
2004 equalities[offset + pos] = 1;
2005 equalities[offset + getNumCols() - 1] = -val;
2006 }
2007
2008 /// Sets the specified identifier to a constant value; asserts if the id is not
2009 /// found.
setIdToConstant(Value id,int64_t val)2010 void FlatAffineConstraints::setIdToConstant(Value id, int64_t val) {
2011 unsigned pos;
2012 if (!findId(id, &pos))
2013 // This is a pre-condition for this method.
2014 assert(0 && "id not found");
2015 setIdToConstant(pos, val);
2016 }
2017
removeEquality(unsigned pos)2018 void FlatAffineConstraints::removeEquality(unsigned pos) {
2019 unsigned numEqualities = getNumEqualities();
2020 assert(pos < numEqualities);
2021 unsigned outputIndex = pos * numReservedCols;
2022 unsigned inputIndex = (pos + 1) * numReservedCols;
2023 unsigned numElemsToCopy = (numEqualities - pos - 1) * numReservedCols;
2024 std::copy(equalities.begin() + inputIndex,
2025 equalities.begin() + inputIndex + numElemsToCopy,
2026 equalities.begin() + outputIndex);
2027 assert(equalities.size() >= numReservedCols);
2028 equalities.resize(equalities.size() - numReservedCols);
2029 }
2030
removeInequality(unsigned pos)2031 void FlatAffineConstraints::removeInequality(unsigned pos) {
2032 unsigned numInequalities = getNumInequalities();
2033 assert(pos < numInequalities && "invalid position");
2034 unsigned outputIndex = pos * numReservedCols;
2035 unsigned inputIndex = (pos + 1) * numReservedCols;
2036 unsigned numElemsToCopy = (numInequalities - pos - 1) * numReservedCols;
2037 std::copy(inequalities.begin() + inputIndex,
2038 inequalities.begin() + inputIndex + numElemsToCopy,
2039 inequalities.begin() + outputIndex);
2040 assert(inequalities.size() >= numReservedCols);
2041 inequalities.resize(inequalities.size() - numReservedCols);
2042 }
2043
2044 /// Finds an equality that equates the specified identifier to a constant.
2045 /// Returns the position of the equality row. If 'symbolic' is set to true,
2046 /// symbols are also treated like a constant, i.e., an affine function of the
2047 /// symbols is also treated like a constant. Returns -1 if such an equality
2048 /// could not be found.
findEqualityToConstant(const FlatAffineConstraints & cst,unsigned pos,bool symbolic=false)2049 static int findEqualityToConstant(const FlatAffineConstraints &cst,
2050 unsigned pos, bool symbolic = false) {
2051 assert(pos < cst.getNumIds() && "invalid position");
2052 for (unsigned r = 0, e = cst.getNumEqualities(); r < e; r++) {
2053 int64_t v = cst.atEq(r, pos);
2054 if (v * v != 1)
2055 continue;
2056 unsigned c;
2057 unsigned f = symbolic ? cst.getNumDimIds() : cst.getNumIds();
2058 // This checks for zeros in all positions other than 'pos' in [0, f)
2059 for (c = 0; c < f; c++) {
2060 if (c == pos)
2061 continue;
2062 if (cst.atEq(r, c) != 0) {
2063 // Dependent on another identifier.
2064 break;
2065 }
2066 }
2067 if (c == f)
2068 // Equality is free of other identifiers.
2069 return r;
2070 }
2071 return -1;
2072 }
2073
setAndEliminate(unsigned pos,int64_t constVal)2074 void FlatAffineConstraints::setAndEliminate(unsigned pos, int64_t constVal) {
2075 assert(pos < getNumIds() && "invalid position");
2076 for (unsigned r = 0, e = getNumInequalities(); r < e; r++) {
2077 atIneq(r, getNumCols() - 1) += atIneq(r, pos) * constVal;
2078 }
2079 for (unsigned r = 0, e = getNumEqualities(); r < e; r++) {
2080 atEq(r, getNumCols() - 1) += atEq(r, pos) * constVal;
2081 }
2082 removeId(pos);
2083 }
2084
constantFoldId(unsigned pos)2085 LogicalResult FlatAffineConstraints::constantFoldId(unsigned pos) {
2086 assert(pos < getNumIds() && "invalid position");
2087 int rowIdx;
2088 if ((rowIdx = findEqualityToConstant(*this, pos)) == -1)
2089 return failure();
2090
2091 // atEq(rowIdx, pos) is either -1 or 1.
2092 assert(atEq(rowIdx, pos) * atEq(rowIdx, pos) == 1);
2093 int64_t constVal = -atEq(rowIdx, getNumCols() - 1) / atEq(rowIdx, pos);
2094 setAndEliminate(pos, constVal);
2095 return success();
2096 }
2097
constantFoldIdRange(unsigned pos,unsigned num)2098 void FlatAffineConstraints::constantFoldIdRange(unsigned pos, unsigned num) {
2099 for (unsigned s = pos, t = pos, e = pos + num; s < e; s++) {
2100 if (failed(constantFoldId(t)))
2101 t++;
2102 }
2103 }
2104
2105 /// Returns the extent (upper bound - lower bound) of the specified
2106 /// identifier if it is found to be a constant; returns None if it's not a
2107 /// constant. This methods treats symbolic identifiers specially, i.e.,
2108 /// it looks for constant differences between affine expressions involving
2109 /// only the symbolic identifiers. See comments at function definition for
2110 /// example. 'lb', if provided, is set to the lower bound associated with the
2111 /// constant difference. Note that 'lb' is purely symbolic and thus will contain
2112 /// the coefficients of the symbolic identifiers and the constant coefficient.
2113 // Egs: 0 <= i <= 15, return 16.
2114 // s0 + 2 <= i <= s0 + 17, returns 16. (s0 has to be a symbol)
2115 // s0 + s1 + 16 <= d0 <= s0 + s1 + 31, returns 16.
2116 // s0 - 7 <= 8*j <= s0 returns 1 with lb = s0, lbDivisor = 8 (since lb =
2117 // ceil(s0 - 7 / 8) = floor(s0 / 8)).
getConstantBoundOnDimSize(unsigned pos,SmallVectorImpl<int64_t> * lb,int64_t * boundFloorDivisor,SmallVectorImpl<int64_t> * ub,unsigned * minLbPos,unsigned * minUbPos) const2118 Optional<int64_t> FlatAffineConstraints::getConstantBoundOnDimSize(
2119 unsigned pos, SmallVectorImpl<int64_t> *lb, int64_t *boundFloorDivisor,
2120 SmallVectorImpl<int64_t> *ub, unsigned *minLbPos,
2121 unsigned *minUbPos) const {
2122 assert(pos < getNumDimIds() && "Invalid identifier position");
2123
2124 // Find an equality for 'pos'^th identifier that equates it to some function
2125 // of the symbolic identifiers (+ constant).
2126 int eqPos = findEqualityToConstant(*this, pos, /*symbolic=*/true);
2127 if (eqPos != -1) {
2128 auto eq = getEquality(eqPos);
2129 // If the equality involves a local var, punt for now.
2130 // TODO: this can be handled in the future by using the explicit
2131 // representation of the local vars.
2132 if (!std::all_of(eq.begin() + getNumDimAndSymbolIds(), eq.end() - 1,
2133 [](int64_t coeff) { return coeff == 0; }))
2134 return None;
2135
2136 // This identifier can only take a single value.
2137 if (lb) {
2138 // Set lb to that symbolic value.
2139 lb->resize(getNumSymbolIds() + 1);
2140 if (ub)
2141 ub->resize(getNumSymbolIds() + 1);
2142 for (unsigned c = 0, f = getNumSymbolIds() + 1; c < f; c++) {
2143 int64_t v = atEq(eqPos, pos);
2144 // atEq(eqRow, pos) is either -1 or 1.
2145 assert(v * v == 1);
2146 (*lb)[c] = v < 0 ? atEq(eqPos, getNumDimIds() + c) / -v
2147 : -atEq(eqPos, getNumDimIds() + c) / v;
2148 // Since this is an equality, ub = lb.
2149 if (ub)
2150 (*ub)[c] = (*lb)[c];
2151 }
2152 assert(boundFloorDivisor &&
2153 "both lb and divisor or none should be provided");
2154 *boundFloorDivisor = 1;
2155 }
2156 if (minLbPos)
2157 *minLbPos = eqPos;
2158 if (minUbPos)
2159 *minUbPos = eqPos;
2160 return 1;
2161 }
2162
2163 // Check if the identifier appears at all in any of the inequalities.
2164 unsigned r, e;
2165 for (r = 0, e = getNumInequalities(); r < e; r++) {
2166 if (atIneq(r, pos) != 0)
2167 break;
2168 }
2169 if (r == e)
2170 // If it doesn't, there isn't a bound on it.
2171 return None;
2172
2173 // Positions of constraints that are lower/upper bounds on the variable.
2174 SmallVector<unsigned, 4> lbIndices, ubIndices;
2175
2176 // Gather all symbolic lower bounds and upper bounds of the variable, i.e.,
2177 // the bounds can only involve symbolic (and local) identifiers. Since the
2178 // canonical form c_1*x_1 + c_2*x_2 + ... + c_0 >= 0, a constraint is a lower
2179 // bound for x_i if c_i >= 1, and an upper bound if c_i <= -1.
2180 getLowerAndUpperBoundIndices(pos, &lbIndices, &ubIndices,
2181 /*eqIndices=*/nullptr, /*offset=*/0,
2182 /*num=*/getNumDimIds());
2183
2184 Optional<int64_t> minDiff = None;
2185 unsigned minLbPosition = 0, minUbPosition = 0;
2186 for (auto ubPos : ubIndices) {
2187 for (auto lbPos : lbIndices) {
2188 // Look for a lower bound and an upper bound that only differ by a
2189 // constant, i.e., pairs of the form 0 <= c_pos - f(c_i's) <= diffConst.
2190 // For example, if ii is the pos^th variable, we are looking for
2191 // constraints like ii >= i, ii <= ii + 50, 50 being the difference. The
2192 // minimum among all such constant differences is kept since that's the
2193 // constant bounding the extent of the pos^th variable.
2194 unsigned j, e;
2195 for (j = 0, e = getNumCols() - 1; j < e; j++)
2196 if (atIneq(ubPos, j) != -atIneq(lbPos, j)) {
2197 break;
2198 }
2199 if (j < getNumCols() - 1)
2200 continue;
2201 int64_t diff = ceilDiv(atIneq(ubPos, getNumCols() - 1) +
2202 atIneq(lbPos, getNumCols() - 1) + 1,
2203 atIneq(lbPos, pos));
2204 if (minDiff == None || diff < minDiff) {
2205 minDiff = diff;
2206 minLbPosition = lbPos;
2207 minUbPosition = ubPos;
2208 }
2209 }
2210 }
2211 if (lb && minDiff.hasValue()) {
2212 // Set lb to the symbolic lower bound.
2213 lb->resize(getNumSymbolIds() + 1);
2214 if (ub)
2215 ub->resize(getNumSymbolIds() + 1);
2216 // The lower bound is the ceildiv of the lb constraint over the coefficient
2217 // of the variable at 'pos'. We express the ceildiv equivalently as a floor
2218 // for uniformity. For eg., if the lower bound constraint was: 32*d0 - N +
2219 // 31 >= 0, the lower bound for d0 is ceil(N - 31, 32), i.e., floor(N, 32).
2220 *boundFloorDivisor = atIneq(minLbPosition, pos);
2221 assert(*boundFloorDivisor == -atIneq(minUbPosition, pos));
2222 for (unsigned c = 0, e = getNumSymbolIds() + 1; c < e; c++) {
2223 (*lb)[c] = -atIneq(minLbPosition, getNumDimIds() + c);
2224 }
2225 if (ub) {
2226 for (unsigned c = 0, e = getNumSymbolIds() + 1; c < e; c++)
2227 (*ub)[c] = atIneq(minUbPosition, getNumDimIds() + c);
2228 }
2229 // The lower bound leads to a ceildiv while the upper bound is a floordiv
2230 // whenever the coefficient at pos != 1. ceildiv (val / d) = floordiv (val +
2231 // d - 1 / d); hence, the addition of 'atIneq(minLbPosition, pos) - 1' to
2232 // the constant term for the lower bound.
2233 (*lb)[getNumSymbolIds()] += atIneq(minLbPosition, pos) - 1;
2234 }
2235 if (minLbPos)
2236 *minLbPos = minLbPosition;
2237 if (minUbPos)
2238 *minUbPos = minUbPosition;
2239 return minDiff;
2240 }
2241
2242 template <bool isLower>
2243 Optional<int64_t>
computeConstantLowerOrUpperBound(unsigned pos)2244 FlatAffineConstraints::computeConstantLowerOrUpperBound(unsigned pos) {
2245 assert(pos < getNumIds() && "invalid position");
2246 // Project to 'pos'.
2247 projectOut(0, pos);
2248 projectOut(1, getNumIds() - 1);
2249 // Check if there's an equality equating the '0'^th identifier to a constant.
2250 int eqRowIdx = findEqualityToConstant(*this, 0, /*symbolic=*/false);
2251 if (eqRowIdx != -1)
2252 // atEq(rowIdx, 0) is either -1 or 1.
2253 return -atEq(eqRowIdx, getNumCols() - 1) / atEq(eqRowIdx, 0);
2254
2255 // Check if the identifier appears at all in any of the inequalities.
2256 unsigned r, e;
2257 for (r = 0, e = getNumInequalities(); r < e; r++) {
2258 if (atIneq(r, 0) != 0)
2259 break;
2260 }
2261 if (r == e)
2262 // If it doesn't, there isn't a bound on it.
2263 return None;
2264
2265 Optional<int64_t> minOrMaxConst = None;
2266
2267 // Take the max across all const lower bounds (or min across all constant
2268 // upper bounds).
2269 for (unsigned r = 0, e = getNumInequalities(); r < e; r++) {
2270 if (isLower) {
2271 if (atIneq(r, 0) <= 0)
2272 // Not a lower bound.
2273 continue;
2274 } else if (atIneq(r, 0) >= 0) {
2275 // Not an upper bound.
2276 continue;
2277 }
2278 unsigned c, f;
2279 for (c = 0, f = getNumCols() - 1; c < f; c++)
2280 if (c != 0 && atIneq(r, c) != 0)
2281 break;
2282 if (c < getNumCols() - 1)
2283 // Not a constant bound.
2284 continue;
2285
2286 int64_t boundConst =
2287 isLower ? mlir::ceilDiv(-atIneq(r, getNumCols() - 1), atIneq(r, 0))
2288 : mlir::floorDiv(atIneq(r, getNumCols() - 1), -atIneq(r, 0));
2289 if (isLower) {
2290 if (minOrMaxConst == None || boundConst > minOrMaxConst)
2291 minOrMaxConst = boundConst;
2292 } else {
2293 if (minOrMaxConst == None || boundConst < minOrMaxConst)
2294 minOrMaxConst = boundConst;
2295 }
2296 }
2297 return minOrMaxConst;
2298 }
2299
2300 Optional<int64_t>
getConstantLowerBound(unsigned pos) const2301 FlatAffineConstraints::getConstantLowerBound(unsigned pos) const {
2302 FlatAffineConstraints tmpCst(*this);
2303 return tmpCst.computeConstantLowerOrUpperBound</*isLower=*/true>(pos);
2304 }
2305
2306 Optional<int64_t>
getConstantUpperBound(unsigned pos) const2307 FlatAffineConstraints::getConstantUpperBound(unsigned pos) const {
2308 FlatAffineConstraints tmpCst(*this);
2309 return tmpCst.computeConstantLowerOrUpperBound</*isLower=*/false>(pos);
2310 }
2311
2312 // A simple (naive and conservative) check for hyper-rectangularity.
isHyperRectangular(unsigned pos,unsigned num) const2313 bool FlatAffineConstraints::isHyperRectangular(unsigned pos,
2314 unsigned num) const {
2315 assert(pos < getNumCols() - 1);
2316 // Check for two non-zero coefficients in the range [pos, pos + sum).
2317 for (unsigned r = 0, e = getNumInequalities(); r < e; r++) {
2318 unsigned sum = 0;
2319 for (unsigned c = pos; c < pos + num; c++) {
2320 if (atIneq(r, c) != 0)
2321 sum++;
2322 }
2323 if (sum > 1)
2324 return false;
2325 }
2326 for (unsigned r = 0, e = getNumEqualities(); r < e; r++) {
2327 unsigned sum = 0;
2328 for (unsigned c = pos; c < pos + num; c++) {
2329 if (atEq(r, c) != 0)
2330 sum++;
2331 }
2332 if (sum > 1)
2333 return false;
2334 }
2335 return true;
2336 }
2337
print(raw_ostream & os) const2338 void FlatAffineConstraints::print(raw_ostream &os) const {
2339 assert(hasConsistentState());
2340 os << "\nConstraints (" << getNumDimIds() << " dims, " << getNumSymbolIds()
2341 << " symbols, " << getNumLocalIds() << " locals), (" << getNumConstraints()
2342 << " constraints)\n";
2343 os << "(";
2344 for (unsigned i = 0, e = getNumIds(); i < e; i++) {
2345 if (ids[i] == None)
2346 os << "None ";
2347 else
2348 os << "Value ";
2349 }
2350 os << " const)\n";
2351 for (unsigned i = 0, e = getNumEqualities(); i < e; ++i) {
2352 for (unsigned j = 0, f = getNumCols(); j < f; ++j) {
2353 os << atEq(i, j) << " ";
2354 }
2355 os << "= 0\n";
2356 }
2357 for (unsigned i = 0, e = getNumInequalities(); i < e; ++i) {
2358 for (unsigned j = 0, f = getNumCols(); j < f; ++j) {
2359 os << atIneq(i, j) << " ";
2360 }
2361 os << ">= 0\n";
2362 }
2363 os << '\n';
2364 }
2365
dump() const2366 void FlatAffineConstraints::dump() const { print(llvm::errs()); }
2367
2368 /// Removes duplicate constraints, trivially true constraints, and constraints
2369 /// that can be detected as redundant as a result of differing only in their
2370 /// constant term part. A constraint of the form <non-negative constant> >= 0 is
2371 /// considered trivially true.
2372 // Uses a DenseSet to hash and detect duplicates followed by a linear scan to
2373 // remove duplicates in place.
removeTrivialRedundancy()2374 void FlatAffineConstraints::removeTrivialRedundancy() {
2375 GCDTightenInequalities();
2376 normalizeConstraintsByGCD();
2377
2378 // A map used to detect redundancy stemming from constraints that only differ
2379 // in their constant term. The value stored is <row position, const term>
2380 // for a given row.
2381 SmallDenseMap<ArrayRef<int64_t>, std::pair<unsigned, int64_t>>
2382 rowsWithoutConstTerm;
2383 // To unique rows.
2384 SmallDenseSet<ArrayRef<int64_t>, 8> rowSet;
2385
2386 // Check if constraint is of the form <non-negative-constant> >= 0.
2387 auto isTriviallyValid = [&](unsigned r) -> bool {
2388 for (unsigned c = 0, e = getNumCols() - 1; c < e; c++) {
2389 if (atIneq(r, c) != 0)
2390 return false;
2391 }
2392 return atIneq(r, getNumCols() - 1) >= 0;
2393 };
2394
2395 // Detect and mark redundant constraints.
2396 SmallVector<bool, 256> redunIneq(getNumInequalities(), false);
2397 for (unsigned r = 0, e = getNumInequalities(); r < e; r++) {
2398 int64_t *rowStart = inequalities.data() + numReservedCols * r;
2399 auto row = ArrayRef<int64_t>(rowStart, getNumCols());
2400 if (isTriviallyValid(r) || !rowSet.insert(row).second) {
2401 redunIneq[r] = true;
2402 continue;
2403 }
2404
2405 // Among constraints that only differ in the constant term part, mark
2406 // everything other than the one with the smallest constant term redundant.
2407 // (eg: among i - 16j - 5 >= 0, i - 16j - 1 >=0, i - 16j - 7 >= 0, the
2408 // former two are redundant).
2409 int64_t constTerm = atIneq(r, getNumCols() - 1);
2410 auto rowWithoutConstTerm = ArrayRef<int64_t>(rowStart, getNumCols() - 1);
2411 const auto &ret =
2412 rowsWithoutConstTerm.insert({rowWithoutConstTerm, {r, constTerm}});
2413 if (!ret.second) {
2414 // Check if the other constraint has a higher constant term.
2415 auto &val = ret.first->second;
2416 if (val.second > constTerm) {
2417 // The stored row is redundant. Mark it so, and update with this one.
2418 redunIneq[val.first] = true;
2419 val = {r, constTerm};
2420 } else {
2421 // The one stored makes this one redundant.
2422 redunIneq[r] = true;
2423 }
2424 }
2425 }
2426
2427 auto copyRow = [&](unsigned src, unsigned dest) {
2428 if (src == dest)
2429 return;
2430 for (unsigned c = 0, e = getNumCols(); c < e; c++) {
2431 atIneq(dest, c) = atIneq(src, c);
2432 }
2433 };
2434
2435 // Scan to get rid of all rows marked redundant, in-place.
2436 unsigned pos = 0;
2437 for (unsigned r = 0, e = getNumInequalities(); r < e; r++) {
2438 if (!redunIneq[r])
2439 copyRow(r, pos++);
2440 }
2441 inequalities.resize(numReservedCols * pos);
2442
2443 // TODO: consider doing this for equalities as well, but probably not worth
2444 // the savings.
2445 }
2446
clearAndCopyFrom(const FlatAffineConstraints & other)2447 void FlatAffineConstraints::clearAndCopyFrom(
2448 const FlatAffineConstraints &other) {
2449 FlatAffineConstraints copy(other);
2450 std::swap(*this, copy);
2451 assert(copy.getNumIds() == copy.getIds().size());
2452 }
2453
removeId(unsigned pos)2454 void FlatAffineConstraints::removeId(unsigned pos) {
2455 removeIdRange(pos, pos + 1);
2456 }
2457
2458 static std::pair<unsigned, unsigned>
getNewNumDimsSymbols(unsigned pos,const FlatAffineConstraints & cst)2459 getNewNumDimsSymbols(unsigned pos, const FlatAffineConstraints &cst) {
2460 unsigned numDims = cst.getNumDimIds();
2461 unsigned numSymbols = cst.getNumSymbolIds();
2462 unsigned newNumDims, newNumSymbols;
2463 if (pos < numDims) {
2464 newNumDims = numDims - 1;
2465 newNumSymbols = numSymbols;
2466 } else if (pos < numDims + numSymbols) {
2467 assert(numSymbols >= 1);
2468 newNumDims = numDims;
2469 newNumSymbols = numSymbols - 1;
2470 } else {
2471 newNumDims = numDims;
2472 newNumSymbols = numSymbols;
2473 }
2474 return {newNumDims, newNumSymbols};
2475 }
2476
2477 #undef DEBUG_TYPE
2478 #define DEBUG_TYPE "fm"
2479
2480 /// Eliminates identifier at the specified position using Fourier-Motzkin
2481 /// variable elimination. This technique is exact for rational spaces but
2482 /// conservative (in "rare" cases) for integer spaces. The operation corresponds
2483 /// to a projection operation yielding the (convex) set of integer points
2484 /// contained in the rational shadow of the set. An emptiness test that relies
2485 /// on this method will guarantee emptiness, i.e., it disproves the existence of
2486 /// a solution if it says it's empty.
2487 /// If a non-null isResultIntegerExact is passed, it is set to true if the
2488 /// result is also integer exact. If it's set to false, the obtained solution
2489 /// *may* not be exact, i.e., it may contain integer points that do not have an
2490 /// integer pre-image in the original set.
2491 ///
2492 /// Eg:
2493 /// j >= 0, j <= i + 1
2494 /// i >= 0, i <= N + 1
2495 /// Eliminating i yields,
2496 /// j >= 0, 0 <= N + 1, j - 1 <= N + 1
2497 ///
2498 /// If darkShadow = true, this method computes the dark shadow on elimination;
2499 /// the dark shadow is a convex integer subset of the exact integer shadow. A
2500 /// non-empty dark shadow proves the existence of an integer solution. The
2501 /// elimination in such a case could however be an under-approximation, and thus
2502 /// should not be used for scanning sets or used by itself for dependence
2503 /// checking.
2504 ///
2505 /// Eg: 2-d set, * represents grid points, 'o' represents a point in the set.
2506 /// ^
2507 /// |
2508 /// | * * * * o o
2509 /// i | * * o o o o
2510 /// | o * * * * *
2511 /// --------------->
2512 /// j ->
2513 ///
2514 /// Eliminating i from this system (projecting on the j dimension):
2515 /// rational shadow / integer light shadow: 1 <= j <= 6
2516 /// dark shadow: 3 <= j <= 6
2517 /// exact integer shadow: j = 1 \union 3 <= j <= 6
2518 /// holes/splinters: j = 2
2519 ///
2520 /// darkShadow = false, isResultIntegerExact = nullptr are default values.
2521 // TODO: a slight modification to yield dark shadow version of FM (tightened),
2522 // which can prove the existence of a solution if there is one.
FourierMotzkinEliminate(unsigned pos,bool darkShadow,bool * isResultIntegerExact)2523 void FlatAffineConstraints::FourierMotzkinEliminate(
2524 unsigned pos, bool darkShadow, bool *isResultIntegerExact) {
2525 LLVM_DEBUG(llvm::dbgs() << "FM input (eliminate pos " << pos << "):\n");
2526 LLVM_DEBUG(dump());
2527 assert(pos < getNumIds() && "invalid position");
2528 assert(hasConsistentState());
2529
2530 // Check if this identifier can be eliminated through a substitution.
2531 for (unsigned r = 0, e = getNumEqualities(); r < e; r++) {
2532 if (atEq(r, pos) != 0) {
2533 // Use Gaussian elimination here (since we have an equality).
2534 LogicalResult ret = gaussianEliminateId(pos);
2535 (void)ret;
2536 assert(succeeded(ret) && "Gaussian elimination guaranteed to succeed");
2537 LLVM_DEBUG(llvm::dbgs() << "FM output (through Gaussian elimination):\n");
2538 LLVM_DEBUG(dump());
2539 return;
2540 }
2541 }
2542
2543 // A fast linear time tightening.
2544 GCDTightenInequalities();
2545
2546 // Check if the identifier appears at all in any of the inequalities.
2547 unsigned r, e;
2548 for (r = 0, e = getNumInequalities(); r < e; r++) {
2549 if (atIneq(r, pos) != 0)
2550 break;
2551 }
2552 if (r == getNumInequalities()) {
2553 // If it doesn't appear, just remove the column and return.
2554 // TODO: refactor removeColumns to use it from here.
2555 removeId(pos);
2556 LLVM_DEBUG(llvm::dbgs() << "FM output:\n");
2557 LLVM_DEBUG(dump());
2558 return;
2559 }
2560
2561 // Positions of constraints that are lower bounds on the variable.
2562 SmallVector<unsigned, 4> lbIndices;
2563 // Positions of constraints that are lower bounds on the variable.
2564 SmallVector<unsigned, 4> ubIndices;
2565 // Positions of constraints that do not involve the variable.
2566 std::vector<unsigned> nbIndices;
2567 nbIndices.reserve(getNumInequalities());
2568
2569 // Gather all lower bounds and upper bounds of the variable. Since the
2570 // canonical form c_1*x_1 + c_2*x_2 + ... + c_0 >= 0, a constraint is a lower
2571 // bound for x_i if c_i >= 1, and an upper bound if c_i <= -1.
2572 for (unsigned r = 0, e = getNumInequalities(); r < e; r++) {
2573 if (atIneq(r, pos) == 0) {
2574 // Id does not appear in bound.
2575 nbIndices.push_back(r);
2576 } else if (atIneq(r, pos) >= 1) {
2577 // Lower bound.
2578 lbIndices.push_back(r);
2579 } else {
2580 // Upper bound.
2581 ubIndices.push_back(r);
2582 }
2583 }
2584
2585 // Set the number of dimensions, symbols in the resulting system.
2586 const auto &dimsSymbols = getNewNumDimsSymbols(pos, *this);
2587 unsigned newNumDims = dimsSymbols.first;
2588 unsigned newNumSymbols = dimsSymbols.second;
2589
2590 SmallVector<Optional<Value>, 8> newIds;
2591 newIds.reserve(numIds - 1);
2592 newIds.append(ids.begin(), ids.begin() + pos);
2593 newIds.append(ids.begin() + pos + 1, ids.end());
2594
2595 /// Create the new system which has one identifier less.
2596 FlatAffineConstraints newFac(
2597 lbIndices.size() * ubIndices.size() + nbIndices.size(),
2598 getNumEqualities(), getNumCols() - 1, newNumDims, newNumSymbols,
2599 /*numLocals=*/getNumIds() - 1 - newNumDims - newNumSymbols, newIds);
2600
2601 assert(newFac.getIds().size() == newFac.getNumIds());
2602
2603 // This will be used to check if the elimination was integer exact.
2604 unsigned lcmProducts = 1;
2605
2606 // Let x be the variable we are eliminating.
2607 // For each lower bound, lb <= c_l*x, and each upper bound c_u*x <= ub, (note
2608 // that c_l, c_u >= 1) we have:
2609 // lb*lcm(c_l, c_u)/c_l <= lcm(c_l, c_u)*x <= ub*lcm(c_l, c_u)/c_u
2610 // We thus generate a constraint:
2611 // lcm(c_l, c_u)/c_l*lb <= lcm(c_l, c_u)/c_u*ub.
2612 // Note if c_l = c_u = 1, all integer points captured by the resulting
2613 // constraint correspond to integer points in the original system (i.e., they
2614 // have integer pre-images). Hence, if the lcm's are all 1, the elimination is
2615 // integer exact.
2616 for (auto ubPos : ubIndices) {
2617 for (auto lbPos : lbIndices) {
2618 SmallVector<int64_t, 4> ineq;
2619 ineq.reserve(newFac.getNumCols());
2620 int64_t lbCoeff = atIneq(lbPos, pos);
2621 // Note that in the comments above, ubCoeff is the negation of the
2622 // coefficient in the canonical form as the view taken here is that of the
2623 // term being moved to the other size of '>='.
2624 int64_t ubCoeff = -atIneq(ubPos, pos);
2625 // TODO: refactor this loop to avoid all branches inside.
2626 for (unsigned l = 0, e = getNumCols(); l < e; l++) {
2627 if (l == pos)
2628 continue;
2629 assert(lbCoeff >= 1 && ubCoeff >= 1 && "bounds wrongly identified");
2630 int64_t lcm = mlir::lcm(lbCoeff, ubCoeff);
2631 ineq.push_back(atIneq(ubPos, l) * (lcm / ubCoeff) +
2632 atIneq(lbPos, l) * (lcm / lbCoeff));
2633 lcmProducts *= lcm;
2634 }
2635 if (darkShadow) {
2636 // The dark shadow is a convex subset of the exact integer shadow. If
2637 // there is a point here, it proves the existence of a solution.
2638 ineq[ineq.size() - 1] += lbCoeff * ubCoeff - lbCoeff - ubCoeff + 1;
2639 }
2640 // TODO: we need to have a way to add inequalities in-place in
2641 // FlatAffineConstraints instead of creating and copying over.
2642 newFac.addInequality(ineq);
2643 }
2644 }
2645
2646 LLVM_DEBUG(llvm::dbgs() << "FM isResultIntegerExact: " << (lcmProducts == 1)
2647 << "\n");
2648 if (lcmProducts == 1 && isResultIntegerExact)
2649 *isResultIntegerExact = true;
2650
2651 // Copy over the constraints not involving this variable.
2652 for (auto nbPos : nbIndices) {
2653 SmallVector<int64_t, 4> ineq;
2654 ineq.reserve(getNumCols() - 1);
2655 for (unsigned l = 0, e = getNumCols(); l < e; l++) {
2656 if (l == pos)
2657 continue;
2658 ineq.push_back(atIneq(nbPos, l));
2659 }
2660 newFac.addInequality(ineq);
2661 }
2662
2663 assert(newFac.getNumConstraints() ==
2664 lbIndices.size() * ubIndices.size() + nbIndices.size());
2665
2666 // Copy over the equalities.
2667 for (unsigned r = 0, e = getNumEqualities(); r < e; r++) {
2668 SmallVector<int64_t, 4> eq;
2669 eq.reserve(newFac.getNumCols());
2670 for (unsigned l = 0, e = getNumCols(); l < e; l++) {
2671 if (l == pos)
2672 continue;
2673 eq.push_back(atEq(r, l));
2674 }
2675 newFac.addEquality(eq);
2676 }
2677
2678 // GCD tightening and normalization allows detection of more trivially
2679 // redundant constraints.
2680 newFac.GCDTightenInequalities();
2681 newFac.normalizeConstraintsByGCD();
2682 newFac.removeTrivialRedundancy();
2683 clearAndCopyFrom(newFac);
2684 LLVM_DEBUG(llvm::dbgs() << "FM output:\n");
2685 LLVM_DEBUG(dump());
2686 }
2687
2688 #undef DEBUG_TYPE
2689 #define DEBUG_TYPE "affine-structures"
2690
projectOut(unsigned pos,unsigned num)2691 void FlatAffineConstraints::projectOut(unsigned pos, unsigned num) {
2692 if (num == 0)
2693 return;
2694
2695 // 'pos' can be at most getNumCols() - 2 if num > 0.
2696 assert((getNumCols() < 2 || pos <= getNumCols() - 2) && "invalid position");
2697 assert(pos + num < getNumCols() && "invalid range");
2698
2699 // Eliminate as many identifiers as possible using Gaussian elimination.
2700 unsigned currentPos = pos;
2701 unsigned numToEliminate = num;
2702 unsigned numGaussianEliminated = 0;
2703
2704 while (currentPos < getNumIds()) {
2705 unsigned curNumEliminated =
2706 gaussianEliminateIds(currentPos, currentPos + numToEliminate);
2707 ++currentPos;
2708 numToEliminate -= curNumEliminated + 1;
2709 numGaussianEliminated += curNumEliminated;
2710 }
2711
2712 // Eliminate the remaining using Fourier-Motzkin.
2713 for (unsigned i = 0; i < num - numGaussianEliminated; i++) {
2714 unsigned numToEliminate = num - numGaussianEliminated - i;
2715 FourierMotzkinEliminate(
2716 getBestIdToEliminate(*this, pos, pos + numToEliminate));
2717 }
2718
2719 // Fast/trivial simplifications.
2720 GCDTightenInequalities();
2721 // Normalize constraints after tightening since the latter impacts this, but
2722 // not the other way round.
2723 normalizeConstraintsByGCD();
2724 }
2725
projectOut(Value id)2726 void FlatAffineConstraints::projectOut(Value id) {
2727 unsigned pos;
2728 bool ret = findId(id, &pos);
2729 assert(ret);
2730 (void)ret;
2731 FourierMotzkinEliminate(pos);
2732 }
2733
clearConstraints()2734 void FlatAffineConstraints::clearConstraints() {
2735 equalities.clear();
2736 inequalities.clear();
2737 }
2738
2739 namespace {
2740
2741 enum BoundCmpResult { Greater, Less, Equal, Unknown };
2742
2743 /// Compares two affine bounds whose coefficients are provided in 'first' and
2744 /// 'second'. The last coefficient is the constant term.
compareBounds(ArrayRef<int64_t> a,ArrayRef<int64_t> b)2745 static BoundCmpResult compareBounds(ArrayRef<int64_t> a, ArrayRef<int64_t> b) {
2746 assert(a.size() == b.size());
2747
2748 // For the bounds to be comparable, their corresponding identifier
2749 // coefficients should be equal; the constant terms are then compared to
2750 // determine less/greater/equal.
2751
2752 if (!std::equal(a.begin(), a.end() - 1, b.begin()))
2753 return Unknown;
2754
2755 if (a.back() == b.back())
2756 return Equal;
2757
2758 return a.back() < b.back() ? Less : Greater;
2759 }
2760 } // namespace
2761
2762 // Returns constraints that are common to both A & B.
getCommonConstraints(const FlatAffineConstraints & A,const FlatAffineConstraints & B,FlatAffineConstraints & C)2763 static void getCommonConstraints(const FlatAffineConstraints &A,
2764 const FlatAffineConstraints &B,
2765 FlatAffineConstraints &C) {
2766 C.reset(A.getNumDimIds(), A.getNumSymbolIds(), A.getNumLocalIds());
2767 // A naive O(n^2) check should be enough here given the input sizes.
2768 for (unsigned r = 0, e = A.getNumInequalities(); r < e; ++r) {
2769 for (unsigned s = 0, f = B.getNumInequalities(); s < f; ++s) {
2770 if (A.getInequality(r) == B.getInequality(s)) {
2771 C.addInequality(A.getInequality(r));
2772 break;
2773 }
2774 }
2775 }
2776 for (unsigned r = 0, e = A.getNumEqualities(); r < e; ++r) {
2777 for (unsigned s = 0, f = B.getNumEqualities(); s < f; ++s) {
2778 if (A.getEquality(r) == B.getEquality(s)) {
2779 C.addEquality(A.getEquality(r));
2780 break;
2781 }
2782 }
2783 }
2784 }
2785
2786 // Computes the bounding box with respect to 'other' by finding the min of the
2787 // lower bounds and the max of the upper bounds along each of the dimensions.
2788 LogicalResult
unionBoundingBox(const FlatAffineConstraints & otherCst)2789 FlatAffineConstraints::unionBoundingBox(const FlatAffineConstraints &otherCst) {
2790 assert(otherCst.getNumDimIds() == numDims && "dims mismatch");
2791 assert(otherCst.getIds()
2792 .slice(0, getNumDimIds())
2793 .equals(getIds().slice(0, getNumDimIds())) &&
2794 "dim values mismatch");
2795 assert(otherCst.getNumLocalIds() == 0 && "local ids not supported here");
2796 assert(getNumLocalIds() == 0 && "local ids not supported yet here");
2797
2798 // Align `other` to this.
2799 Optional<FlatAffineConstraints> otherCopy;
2800 if (!areIdsAligned(*this, otherCst)) {
2801 otherCopy.emplace(FlatAffineConstraints(otherCst));
2802 mergeAndAlignIds(/*offset=*/numDims, this, &otherCopy.getValue());
2803 }
2804
2805 const auto &otherAligned = otherCopy ? *otherCopy : otherCst;
2806
2807 // Get the constraints common to both systems; these will be added as is to
2808 // the union.
2809 FlatAffineConstraints commonCst;
2810 getCommonConstraints(*this, otherAligned, commonCst);
2811
2812 std::vector<SmallVector<int64_t, 8>> boundingLbs;
2813 std::vector<SmallVector<int64_t, 8>> boundingUbs;
2814 boundingLbs.reserve(2 * getNumDimIds());
2815 boundingUbs.reserve(2 * getNumDimIds());
2816
2817 // To hold lower and upper bounds for each dimension.
2818 SmallVector<int64_t, 4> lb, otherLb, ub, otherUb;
2819 // To compute min of lower bounds and max of upper bounds for each dimension.
2820 SmallVector<int64_t, 4> minLb(getNumSymbolIds() + 1);
2821 SmallVector<int64_t, 4> maxUb(getNumSymbolIds() + 1);
2822 // To compute final new lower and upper bounds for the union.
2823 SmallVector<int64_t, 8> newLb(getNumCols()), newUb(getNumCols());
2824
2825 int64_t lbFloorDivisor, otherLbFloorDivisor;
2826 for (unsigned d = 0, e = getNumDimIds(); d < e; ++d) {
2827 auto extent = getConstantBoundOnDimSize(d, &lb, &lbFloorDivisor, &ub);
2828 if (!extent.hasValue())
2829 // TODO: symbolic extents when necessary.
2830 // TODO: handle union if a dimension is unbounded.
2831 return failure();
2832
2833 auto otherExtent = otherAligned.getConstantBoundOnDimSize(
2834 d, &otherLb, &otherLbFloorDivisor, &otherUb);
2835 if (!otherExtent.hasValue() || lbFloorDivisor != otherLbFloorDivisor)
2836 // TODO: symbolic extents when necessary.
2837 return failure();
2838
2839 assert(lbFloorDivisor > 0 && "divisor always expected to be positive");
2840
2841 auto res = compareBounds(lb, otherLb);
2842 // Identify min.
2843 if (res == BoundCmpResult::Less || res == BoundCmpResult::Equal) {
2844 minLb = lb;
2845 // Since the divisor is for a floordiv, we need to convert to ceildiv,
2846 // i.e., i >= expr floordiv div <=> i >= (expr - div + 1) ceildiv div <=>
2847 // div * i >= expr - div + 1.
2848 minLb.back() -= lbFloorDivisor - 1;
2849 } else if (res == BoundCmpResult::Greater) {
2850 minLb = otherLb;
2851 minLb.back() -= otherLbFloorDivisor - 1;
2852 } else {
2853 // Uncomparable - check for constant lower/upper bounds.
2854 auto constLb = getConstantLowerBound(d);
2855 auto constOtherLb = otherAligned.getConstantLowerBound(d);
2856 if (!constLb.hasValue() || !constOtherLb.hasValue())
2857 return failure();
2858 std::fill(minLb.begin(), minLb.end(), 0);
2859 minLb.back() = std::min(constLb.getValue(), constOtherLb.getValue());
2860 }
2861
2862 // Do the same for ub's but max of upper bounds. Identify max.
2863 auto uRes = compareBounds(ub, otherUb);
2864 if (uRes == BoundCmpResult::Greater || uRes == BoundCmpResult::Equal) {
2865 maxUb = ub;
2866 } else if (uRes == BoundCmpResult::Less) {
2867 maxUb = otherUb;
2868 } else {
2869 // Uncomparable - check for constant lower/upper bounds.
2870 auto constUb = getConstantUpperBound(d);
2871 auto constOtherUb = otherAligned.getConstantUpperBound(d);
2872 if (!constUb.hasValue() || !constOtherUb.hasValue())
2873 return failure();
2874 std::fill(maxUb.begin(), maxUb.end(), 0);
2875 maxUb.back() = std::max(constUb.getValue(), constOtherUb.getValue());
2876 }
2877
2878 std::fill(newLb.begin(), newLb.end(), 0);
2879 std::fill(newUb.begin(), newUb.end(), 0);
2880
2881 // The divisor for lb, ub, otherLb, otherUb at this point is lbDivisor,
2882 // and so it's the divisor for newLb and newUb as well.
2883 newLb[d] = lbFloorDivisor;
2884 newUb[d] = -lbFloorDivisor;
2885 // Copy over the symbolic part + constant term.
2886 std::copy(minLb.begin(), minLb.end(), newLb.begin() + getNumDimIds());
2887 std::transform(newLb.begin() + getNumDimIds(), newLb.end(),
2888 newLb.begin() + getNumDimIds(), std::negate<int64_t>());
2889 std::copy(maxUb.begin(), maxUb.end(), newUb.begin() + getNumDimIds());
2890
2891 boundingLbs.push_back(newLb);
2892 boundingUbs.push_back(newUb);
2893 }
2894
2895 // Clear all constraints and add the lower/upper bounds for the bounding box.
2896 clearConstraints();
2897 for (unsigned d = 0, e = getNumDimIds(); d < e; ++d) {
2898 addInequality(boundingLbs[d]);
2899 addInequality(boundingUbs[d]);
2900 }
2901
2902 // Add the constraints that were common to both systems.
2903 append(commonCst);
2904 removeTrivialRedundancy();
2905
2906 // TODO: copy over pure symbolic constraints from this and 'other' over to the
2907 // union (since the above are just the union along dimensions); we shouldn't
2908 // be discarding any other constraints on the symbols.
2909
2910 return success();
2911 }
2912
2913 /// Compute an explicit representation for local vars. For all systems coming
2914 /// from MLIR integer sets, maps, or expressions where local vars were
2915 /// introduced to model floordivs and mods, this always succeeds.
computeLocalVars(const FlatAffineConstraints & cst,SmallVectorImpl<AffineExpr> & memo,MLIRContext * context)2916 static LogicalResult computeLocalVars(const FlatAffineConstraints &cst,
2917 SmallVectorImpl<AffineExpr> &memo,
2918 MLIRContext *context) {
2919 unsigned numDims = cst.getNumDimIds();
2920 unsigned numSyms = cst.getNumSymbolIds();
2921
2922 // Initialize dimensional and symbolic identifiers.
2923 for (unsigned i = 0; i < numDims; i++)
2924 memo[i] = getAffineDimExpr(i, context);
2925 for (unsigned i = numDims, e = numDims + numSyms; i < e; i++)
2926 memo[i] = getAffineSymbolExpr(i - numDims, context);
2927
2928 bool changed;
2929 do {
2930 // Each time `changed` is true at the end of this iteration, one or more
2931 // local vars would have been detected as floordivs and set in memo; so the
2932 // number of null entries in memo[...] strictly reduces; so this converges.
2933 changed = false;
2934 for (unsigned i = 0, e = cst.getNumLocalIds(); i < e; ++i)
2935 if (!memo[numDims + numSyms + i] &&
2936 detectAsFloorDiv(cst, /*pos=*/numDims + numSyms + i, context, memo))
2937 changed = true;
2938 } while (changed);
2939
2940 ArrayRef<AffineExpr> localExprs =
2941 ArrayRef<AffineExpr>(memo).take_back(cst.getNumLocalIds());
2942 return success(
2943 llvm::all_of(localExprs, [](AffineExpr expr) { return expr; }));
2944 }
2945
getIneqAsAffineValueMap(unsigned pos,unsigned ineqPos,AffineValueMap & vmap,MLIRContext * context) const2946 void FlatAffineConstraints::getIneqAsAffineValueMap(
2947 unsigned pos, unsigned ineqPos, AffineValueMap &vmap,
2948 MLIRContext *context) const {
2949 unsigned numDims = getNumDimIds();
2950 unsigned numSyms = getNumSymbolIds();
2951
2952 assert(pos < numDims && "invalid position");
2953 assert(ineqPos < getNumInequalities() && "invalid inequality position");
2954
2955 // Get expressions for local vars.
2956 SmallVector<AffineExpr, 8> memo(getNumIds(), AffineExpr());
2957 if (failed(computeLocalVars(*this, memo, context)))
2958 assert(false &&
2959 "one or more local exprs do not have an explicit representation");
2960 auto localExprs = ArrayRef<AffineExpr>(memo).take_back(getNumLocalIds());
2961
2962 // Compute the AffineExpr lower/upper bound for this inequality.
2963 ArrayRef<int64_t> inequality = getInequality(ineqPos);
2964 SmallVector<int64_t, 8> bound;
2965 bound.reserve(getNumCols() - 1);
2966 // Everything other than the coefficient at `pos`.
2967 bound.append(inequality.begin(), inequality.begin() + pos);
2968 bound.append(inequality.begin() + pos + 1, inequality.end());
2969
2970 if (inequality[pos] > 0)
2971 // Lower bound.
2972 std::transform(bound.begin(), bound.end(), bound.begin(),
2973 std::negate<int64_t>());
2974 else
2975 // Upper bound (which is exclusive).
2976 bound.back() += 1;
2977
2978 // Convert to AffineExpr (tree) form.
2979 auto boundExpr = getAffineExprFromFlatForm(bound, numDims - 1, numSyms,
2980 localExprs, context);
2981
2982 // Get the values to bind to this affine expr (all dims and symbols).
2983 SmallVector<Value, 4> operands;
2984 getIdValues(0, pos, &operands);
2985 SmallVector<Value, 4> trailingOperands;
2986 getIdValues(pos + 1, getNumDimAndSymbolIds(), &trailingOperands);
2987 operands.append(trailingOperands.begin(), trailingOperands.end());
2988 vmap.reset(AffineMap::get(numDims - 1, numSyms, boundExpr), operands);
2989 }
2990
2991 /// Returns true if the pos^th column is all zero for both inequalities and
2992 /// equalities..
isColZero(const FlatAffineConstraints & cst,unsigned pos)2993 static bool isColZero(const FlatAffineConstraints &cst, unsigned pos) {
2994 unsigned rowPos;
2995 return !findConstraintWithNonZeroAt(cst, pos, /*isEq=*/false, &rowPos) &&
2996 !findConstraintWithNonZeroAt(cst, pos, /*isEq=*/true, &rowPos);
2997 }
2998
getAsIntegerSet(MLIRContext * context) const2999 IntegerSet FlatAffineConstraints::getAsIntegerSet(MLIRContext *context) const {
3000 if (getNumConstraints() == 0)
3001 // Return universal set (always true): 0 == 0.
3002 return IntegerSet::get(getNumDimIds(), getNumSymbolIds(),
3003 getAffineConstantExpr(/*constant=*/0, context),
3004 /*eqFlags=*/true);
3005
3006 // Construct local references.
3007 SmallVector<AffineExpr, 8> memo(getNumIds(), AffineExpr());
3008
3009 if (failed(computeLocalVars(*this, memo, context))) {
3010 // Check if the local variables without an explicit representation have
3011 // zero coefficients everywhere.
3012 for (unsigned i = getNumDimAndSymbolIds(), e = getNumIds(); i < e; ++i) {
3013 if (!memo[i] && !isColZero(*this, /*pos=*/i)) {
3014 LLVM_DEBUG(llvm::dbgs() << "one or more local exprs do not have an "
3015 "explicit representation");
3016 return IntegerSet();
3017 }
3018 }
3019 }
3020
3021 ArrayRef<AffineExpr> localExprs =
3022 ArrayRef<AffineExpr>(memo).take_back(getNumLocalIds());
3023
3024 // Construct the IntegerSet from the equalities/inequalities.
3025 unsigned numDims = getNumDimIds();
3026 unsigned numSyms = getNumSymbolIds();
3027
3028 SmallVector<bool, 16> eqFlags(getNumConstraints());
3029 std::fill(eqFlags.begin(), eqFlags.begin() + getNumEqualities(), true);
3030 std::fill(eqFlags.begin() + getNumEqualities(), eqFlags.end(), false);
3031
3032 SmallVector<AffineExpr, 8> exprs;
3033 exprs.reserve(getNumConstraints());
3034
3035 for (unsigned i = 0, e = getNumEqualities(); i < e; ++i)
3036 exprs.push_back(getAffineExprFromFlatForm(getEquality(i), numDims, numSyms,
3037 localExprs, context));
3038 for (unsigned i = 0, e = getNumInequalities(); i < e; ++i)
3039 exprs.push_back(getAffineExprFromFlatForm(getInequality(i), numDims,
3040 numSyms, localExprs, context));
3041 return IntegerSet::get(numDims, numSyms, exprs, eqFlags);
3042 }
3043
3044 /// Find positions of inequalities and equalities that do not have a coefficient
3045 /// for [pos, pos + num) identifiers.
getIndependentConstraints(const FlatAffineConstraints & cst,unsigned pos,unsigned num,SmallVectorImpl<unsigned> & nbIneqIndices,SmallVectorImpl<unsigned> & nbEqIndices)3046 static void getIndependentConstraints(const FlatAffineConstraints &cst,
3047 unsigned pos, unsigned num,
3048 SmallVectorImpl<unsigned> &nbIneqIndices,
3049 SmallVectorImpl<unsigned> &nbEqIndices) {
3050 assert(pos < cst.getNumIds() && "invalid start position");
3051 assert(pos + num <= cst.getNumIds() && "invalid limit");
3052
3053 for (unsigned r = 0, e = cst.getNumInequalities(); r < e; r++) {
3054 // The bounds are to be independent of [offset, offset + num) columns.
3055 unsigned c;
3056 for (c = pos; c < pos + num; ++c) {
3057 if (cst.atIneq(r, c) != 0)
3058 break;
3059 }
3060 if (c == pos + num)
3061 nbIneqIndices.push_back(r);
3062 }
3063
3064 for (unsigned r = 0, e = cst.getNumEqualities(); r < e; r++) {
3065 // The bounds are to be independent of [offset, offset + num) columns.
3066 unsigned c;
3067 for (c = pos; c < pos + num; ++c) {
3068 if (cst.atEq(r, c) != 0)
3069 break;
3070 }
3071 if (c == pos + num)
3072 nbEqIndices.push_back(r);
3073 }
3074 }
3075
removeIndependentConstraints(unsigned pos,unsigned num)3076 void FlatAffineConstraints::removeIndependentConstraints(unsigned pos,
3077 unsigned num) {
3078 assert(pos + num <= getNumIds() && "invalid range");
3079
3080 // Remove constraints that are independent of these identifiers.
3081 SmallVector<unsigned, 4> nbIneqIndices, nbEqIndices;
3082 getIndependentConstraints(*this, /*pos=*/0, num, nbIneqIndices, nbEqIndices);
3083
3084 // Iterate in reverse so that indices don't have to be updated.
3085 // TODO: This method can be made more efficient (because removal of each
3086 // inequality leads to much shifting/copying in the underlying buffer).
3087 for (auto nbIndex : llvm::reverse(nbIneqIndices))
3088 removeInequality(nbIndex);
3089 for (auto nbIndex : llvm::reverse(nbEqIndices))
3090 removeEquality(nbIndex);
3091 }
3092