1 //===- Set.cpp - MLIR PresburgerSet Class ---------------------------------===//
2 //
3 // Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions.
4 // See https://llvm.org/LICENSE.txt for license information.
5 // SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception
6 //
7 //===----------------------------------------------------------------------===//
8 
9 #include "mlir/Analysis/PresburgerSet.h"
10 #include "mlir/Analysis/Presburger/Simplex.h"
11 #include "llvm/ADT/STLExtras.h"
12 #include "llvm/ADT/SmallBitVector.h"
13 
14 using namespace mlir;
15 
PresburgerSet(const FlatAffineConstraints & fac)16 PresburgerSet::PresburgerSet(const FlatAffineConstraints &fac)
17     : nDim(fac.getNumDimIds()), nSym(fac.getNumSymbolIds()) {
18   unionFACInPlace(fac);
19 }
20 
getNumFACs() const21 unsigned PresburgerSet::getNumFACs() const {
22   return flatAffineConstraints.size();
23 }
24 
getNumDims() const25 unsigned PresburgerSet::getNumDims() const { return nDim; }
26 
getNumSyms() const27 unsigned PresburgerSet::getNumSyms() const { return nSym; }
28 
29 ArrayRef<FlatAffineConstraints>
getAllFlatAffineConstraints() const30 PresburgerSet::getAllFlatAffineConstraints() const {
31   return flatAffineConstraints;
32 }
33 
34 const FlatAffineConstraints &
getFlatAffineConstraints(unsigned index) const35 PresburgerSet::getFlatAffineConstraints(unsigned index) const {
36   assert(index < flatAffineConstraints.size() && "index out of bounds!");
37   return flatAffineConstraints[index];
38 }
39 
40 /// Assert that the FlatAffineConstraints and PresburgerSet live in
41 /// compatible spaces.
assertDimensionsCompatible(const FlatAffineConstraints & fac,const PresburgerSet & set)42 static void assertDimensionsCompatible(const FlatAffineConstraints &fac,
43                                        const PresburgerSet &set) {
44   assert(fac.getNumDimIds() == set.getNumDims() &&
45          "Number of dimensions of the FlatAffineConstraints and PresburgerSet"
46          "do not match!");
47   assert(fac.getNumSymbolIds() == set.getNumSyms() &&
48          "Number of symbols of the FlatAffineConstraints and PresburgerSet"
49          "do not match!");
50 }
51 
52 /// Assert that the two PresburgerSets live in compatible spaces.
assertDimensionsCompatible(const PresburgerSet & setA,const PresburgerSet & setB)53 static void assertDimensionsCompatible(const PresburgerSet &setA,
54                                        const PresburgerSet &setB) {
55   assert(setA.getNumDims() == setB.getNumDims() &&
56          "Number of dimensions of the PresburgerSets do not match!");
57   assert(setA.getNumSyms() == setB.getNumSyms() &&
58          "Number of symbols of the PresburgerSets do not match!");
59 }
60 
61 /// Mutate this set, turning it into the union of this set and the given
62 /// FlatAffineConstraints.
unionFACInPlace(const FlatAffineConstraints & fac)63 void PresburgerSet::unionFACInPlace(const FlatAffineConstraints &fac) {
64   assertDimensionsCompatible(fac, *this);
65   flatAffineConstraints.push_back(fac);
66 }
67 
68 /// Mutate this set, turning it into the union of this set and the given set.
69 ///
70 /// This is accomplished by simply adding all the FACs of the given set to this
71 /// set.
unionSetInPlace(const PresburgerSet & set)72 void PresburgerSet::unionSetInPlace(const PresburgerSet &set) {
73   assertDimensionsCompatible(set, *this);
74   for (const FlatAffineConstraints &fac : set.flatAffineConstraints)
75     unionFACInPlace(fac);
76 }
77 
78 /// Return the union of this set and the given set.
unionSet(const PresburgerSet & set) const79 PresburgerSet PresburgerSet::unionSet(const PresburgerSet &set) const {
80   assertDimensionsCompatible(set, *this);
81   PresburgerSet result = *this;
82   result.unionSetInPlace(set);
83   return result;
84 }
85 
86 /// A point is contained in the union iff any of the parts contain the point.
containsPoint(ArrayRef<int64_t> point) const87 bool PresburgerSet::containsPoint(ArrayRef<int64_t> point) const {
88   for (const FlatAffineConstraints &fac : flatAffineConstraints) {
89     if (fac.containsPoint(point))
90       return true;
91   }
92   return false;
93 }
94 
getUniverse(unsigned nDim,unsigned nSym)95 PresburgerSet PresburgerSet::getUniverse(unsigned nDim, unsigned nSym) {
96   PresburgerSet result(nDim, nSym);
97   result.unionFACInPlace(FlatAffineConstraints::getUniverse(nDim, nSym));
98   return result;
99 }
100 
getEmptySet(unsigned nDim,unsigned nSym)101 PresburgerSet PresburgerSet::getEmptySet(unsigned nDim, unsigned nSym) {
102   return PresburgerSet(nDim, nSym);
103 }
104 
105 // Return the intersection of this set with the given set.
106 //
107 // We directly compute (S_1 or S_2 ...) and (T_1 or T_2 ...)
108 // as (S_1 and T_1) or (S_1 and T_2) or ...
intersect(const PresburgerSet & set) const109 PresburgerSet PresburgerSet::intersect(const PresburgerSet &set) const {
110   assertDimensionsCompatible(set, *this);
111 
112   PresburgerSet result(nDim, nSym);
113   for (const FlatAffineConstraints &csA : flatAffineConstraints) {
114     for (const FlatAffineConstraints &csB : set.flatAffineConstraints) {
115       FlatAffineConstraints intersection(csA);
116       intersection.append(csB);
117       if (!intersection.isEmpty())
118         result.unionFACInPlace(std::move(intersection));
119     }
120   }
121   return result;
122 }
123 
124 /// Return `coeffs` with all the elements negated.
getNegatedCoeffs(ArrayRef<int64_t> coeffs)125 static SmallVector<int64_t, 8> getNegatedCoeffs(ArrayRef<int64_t> coeffs) {
126   SmallVector<int64_t, 8> negatedCoeffs;
127   negatedCoeffs.reserve(coeffs.size());
128   for (int64_t coeff : coeffs)
129     negatedCoeffs.emplace_back(-coeff);
130   return negatedCoeffs;
131 }
132 
133 /// Return the complement of the given inequality.
134 ///
135 /// The complement of a_1 x_1 + ... + a_n x_ + c >= 0 is
136 /// a_1 x_1 + ... + a_n x_ + c < 0, i.e., -a_1 x_1 - ... - a_n x_ - c - 1 >= 0.
getComplementIneq(ArrayRef<int64_t> ineq)137 static SmallVector<int64_t, 8> getComplementIneq(ArrayRef<int64_t> ineq) {
138   SmallVector<int64_t, 8> coeffs;
139   coeffs.reserve(ineq.size());
140   for (int64_t coeff : ineq)
141     coeffs.emplace_back(-coeff);
142   --coeffs.back();
143   return coeffs;
144 }
145 
146 /// Return the set difference b \ s and accumulate the result into `result`.
147 /// `simplex` must correspond to b.
148 ///
149 /// In the following, V denotes union, ^ denotes intersection, \ denotes set
150 /// difference and ~ denotes complement.
151 /// Let b be the FlatAffineConstraints and s = (V_i s_i) be the set. We want
152 /// b \ (V_i s_i).
153 ///
154 /// Let s_i = ^_j s_ij, where each s_ij is a single inequality. To compute
155 /// b \ s_i = b ^ ~s_i, we partition s_i based on the first violated inequality:
156 /// ~s_i = (~s_i1) V (s_i1 ^ ~s_i2) V (s_i1 ^ s_i2 ^ ~s_i3) V ...
157 /// And the required result is (b ^ ~s_i1) V (b ^ s_i1 ^ ~s_i2) V ...
158 /// We recurse by subtracting V_{j > i} S_j from each of these parts and
159 /// returning the union of the results. Each equality is handled as a
160 /// conjunction of two inequalities.
161 ///
162 /// As a heuristic, we try adding all the constraints and check if simplex
163 /// says that the intersection is empty. Also, in the process we find out that
164 /// some constraints are redundant. These redundant constraints are ignored.
subtractRecursively(FlatAffineConstraints & b,Simplex & simplex,const PresburgerSet & s,unsigned i,PresburgerSet & result)165 static void subtractRecursively(FlatAffineConstraints &b, Simplex &simplex,
166                                 const PresburgerSet &s, unsigned i,
167                                 PresburgerSet &result) {
168   if (i == s.getNumFACs()) {
169     result.unionFACInPlace(b);
170     return;
171   }
172   const FlatAffineConstraints &sI = s.getFlatAffineConstraints(i);
173   unsigned initialSnapshot = simplex.getSnapshot();
174   unsigned offset = simplex.numConstraints();
175   simplex.intersectFlatAffineConstraints(sI);
176 
177   if (simplex.isEmpty()) {
178     /// b ^ s_i is empty, so b \ s_i = b. We move directly to i + 1.
179     simplex.rollback(initialSnapshot);
180     subtractRecursively(b, simplex, s, i + 1, result);
181     return;
182   }
183 
184   simplex.detectRedundant();
185   llvm::SmallBitVector isMarkedRedundant;
186   for (unsigned j = 0; j < 2 * sI.getNumEqualities() + sI.getNumInequalities();
187        j++)
188     isMarkedRedundant.push_back(simplex.isMarkedRedundant(offset + j));
189 
190   simplex.rollback(initialSnapshot);
191 
192   // Recurse with the part b ^ ~ineq. Note that b is modified throughout
193   // subtractRecursively. At the time this function is called, the current b is
194   // actually equal to b ^ s_i1 ^ s_i2 ^ ... ^ s_ij, and ineq is the next
195   // inequality, s_{i,j+1}. This function recurses into the next level i + 1
196   // with the part b ^ s_i1 ^ s_i2 ^ ... ^ s_ij ^ ~s_{i,j+1}.
197   auto recurseWithInequality = [&, i](ArrayRef<int64_t> ineq) {
198     size_t snapshot = simplex.getSnapshot();
199     b.addInequality(ineq);
200     simplex.addInequality(ineq);
201     subtractRecursively(b, simplex, s, i + 1, result);
202     b.removeInequality(b.getNumInequalities() - 1);
203     simplex.rollback(snapshot);
204   };
205 
206   // For each inequality ineq, we first recurse with the part where ineq
207   // is not satisfied, and then add the ineq to b and simplex because
208   // ineq must be satisfied by all later parts.
209   auto processInequality = [&](ArrayRef<int64_t> ineq) {
210     recurseWithInequality(getComplementIneq(ineq));
211     b.addInequality(ineq);
212     simplex.addInequality(ineq);
213   };
214 
215   // processInequality appends some additional constraints to b. We want to
216   // rollback b to its initial state before returning, which we will do by
217   // removing all constraints beyond the original number of inequalities
218   // and equalities, so we store these counts first.
219   unsigned originalNumIneqs = b.getNumInequalities();
220   unsigned originalNumEqs = b.getNumEqualities();
221 
222   for (unsigned j = 0, e = sI.getNumInequalities(); j < e; j++) {
223     if (isMarkedRedundant[j])
224       continue;
225     processInequality(sI.getInequality(j));
226   }
227 
228   offset = sI.getNumInequalities();
229   for (unsigned j = 0, e = sI.getNumEqualities(); j < e; ++j) {
230     const ArrayRef<int64_t> &coeffs = sI.getEquality(j);
231     // Same as the above loop for inequalities, done once each for the positive
232     // and negative inequalities that make up this equality.
233     if (!isMarkedRedundant[offset + 2 * j])
234       processInequality(coeffs);
235     if (!isMarkedRedundant[offset + 2 * j + 1])
236       processInequality(getNegatedCoeffs(coeffs));
237   }
238 
239   // Rollback b and simplex to their initial states.
240   for (unsigned i = b.getNumInequalities(); i > originalNumIneqs; --i)
241     b.removeInequality(i - 1);
242 
243   for (unsigned i = b.getNumEqualities(); i > originalNumEqs; --i)
244     b.removeEquality(i - 1);
245 
246   simplex.rollback(initialSnapshot);
247 }
248 
249 /// Return the set difference fac \ set.
250 ///
251 /// The FAC here is modified in subtractRecursively, so it cannot be a const
252 /// reference even though it is restored to its original state before returning
253 /// from that function.
getSetDifference(FlatAffineConstraints fac,const PresburgerSet & set)254 PresburgerSet PresburgerSet::getSetDifference(FlatAffineConstraints fac,
255                                               const PresburgerSet &set) {
256   assertDimensionsCompatible(fac, set);
257   if (fac.isEmptyByGCDTest())
258     return PresburgerSet::getEmptySet(fac.getNumDimIds(),
259                                       fac.getNumSymbolIds());
260 
261   PresburgerSet result(fac.getNumDimIds(), fac.getNumSymbolIds());
262   Simplex simplex(fac);
263   subtractRecursively(fac, simplex, set, 0, result);
264   return result;
265 }
266 
267 /// Return the complement of this set.
complement() const268 PresburgerSet PresburgerSet::complement() const {
269   return getSetDifference(
270       FlatAffineConstraints::getUniverse(getNumDims(), getNumSyms()), *this);
271 }
272 
273 /// Return the result of subtract the given set from this set, i.e.,
274 /// return `this \ set`.
subtract(const PresburgerSet & set) const275 PresburgerSet PresburgerSet::subtract(const PresburgerSet &set) const {
276   assertDimensionsCompatible(set, *this);
277   PresburgerSet result(nDim, nSym);
278   // We compute (V_i t_i) \ (V_i set_i) as V_i (t_i \ V_i set_i).
279   for (const FlatAffineConstraints &fac : flatAffineConstraints)
280     result.unionSetInPlace(getSetDifference(fac, set));
281   return result;
282 }
283 
284 /// Return true if all the sets in the union are known to be integer empty,
285 /// false otherwise.
isIntegerEmpty() const286 bool PresburgerSet::isIntegerEmpty() const {
287   assert(nSym == 0 && "isIntegerEmpty is intended for non-symbolic sets");
288   // The set is empty iff all of the disjuncts are empty.
289   for (const FlatAffineConstraints &fac : flatAffineConstraints) {
290     if (!fac.isIntegerEmpty())
291       return false;
292   }
293   return true;
294 }
295 
findIntegerSample(SmallVectorImpl<int64_t> & sample)296 bool PresburgerSet::findIntegerSample(SmallVectorImpl<int64_t> &sample) {
297   assert(nSym == 0 && "findIntegerSample is intended for non-symbolic sets");
298   // A sample exists iff any of the disjuncts contains a sample.
299   for (const FlatAffineConstraints &fac : flatAffineConstraints) {
300     if (Optional<SmallVector<int64_t, 8>> opt = fac.findIntegerSample()) {
301       sample = std::move(*opt);
302       return true;
303     }
304   }
305   return false;
306 }
307 
print(raw_ostream & os) const308 void PresburgerSet::print(raw_ostream &os) const {
309   os << getNumFACs() << " FlatAffineConstraints:\n";
310   for (const FlatAffineConstraints &fac : flatAffineConstraints) {
311     fac.print(os);
312     os << '\n';
313   }
314 }
315 
dump() const316 void PresburgerSet::dump() const { print(llvm::errs()); }
317