1 //===- AffineAnalysis.cpp - Affine structures analysis routines -----------===//
2 //
3 // Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions.
4 // See https://llvm.org/LICENSE.txt for license information.
5 // SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception
6 //
7 //===----------------------------------------------------------------------===//
8 //
9 // This file implements miscellaneous analysis routines for affine structures
10 // (expressions, maps, sets), and other utilities relying on such analysis.
11 //
12 //===----------------------------------------------------------------------===//
13 
14 #include "mlir/Analysis/AffineAnalysis.h"
15 #include "mlir/Analysis/Utils.h"
16 #include "mlir/Dialect/Affine/IR/AffineOps.h"
17 #include "mlir/Dialect/Affine/IR/AffineValueMap.h"
18 #include "mlir/Dialect/StandardOps/IR/Ops.h"
19 #include "mlir/IR/AffineExprVisitor.h"
20 #include "mlir/IR/BuiltinOps.h"
21 #include "mlir/IR/IntegerSet.h"
22 #include "mlir/Support/MathExtras.h"
23 #include "llvm/ADT/DenseMap.h"
24 #include "llvm/Support/Debug.h"
25 #include "llvm/Support/raw_ostream.h"
26 
27 #define DEBUG_TYPE "affine-analysis"
28 
29 using namespace mlir;
30 
31 using llvm::dbgs;
32 
33 /// Returns the sequence of AffineApplyOp Operations operation in
34 /// 'affineApplyOps', which are reachable via a search starting from 'operands',
35 /// and ending at operands which are not defined by AffineApplyOps.
36 // TODO: Add a method to AffineApplyOp which forward substitutes the
37 // AffineApplyOp into any user AffineApplyOps.
getReachableAffineApplyOps(ArrayRef<Value> operands,SmallVectorImpl<Operation * > & affineApplyOps)38 void mlir::getReachableAffineApplyOps(
39     ArrayRef<Value> operands, SmallVectorImpl<Operation *> &affineApplyOps) {
40   struct State {
41     // The ssa value for this node in the DFS traversal.
42     Value value;
43     // The operand index of 'value' to explore next during DFS traversal.
44     unsigned operandIndex;
45   };
46   SmallVector<State, 4> worklist;
47   for (auto operand : operands) {
48     worklist.push_back({operand, 0});
49   }
50 
51   while (!worklist.empty()) {
52     State &state = worklist.back();
53     auto *opInst = state.value.getDefiningOp();
54     // Note: getDefiningOp will return nullptr if the operand is not an
55     // Operation (i.e. block argument), which is a terminator for the search.
56     if (!isa_and_nonnull<AffineApplyOp>(opInst)) {
57       worklist.pop_back();
58       continue;
59     }
60 
61     if (state.operandIndex == 0) {
62       // Pre-Visit: Add 'opInst' to reachable sequence.
63       affineApplyOps.push_back(opInst);
64     }
65     if (state.operandIndex < opInst->getNumOperands()) {
66       // Visit: Add next 'affineApplyOp' operand to worklist.
67       // Get next operand to visit at 'operandIndex'.
68       auto nextOperand = opInst->getOperand(state.operandIndex);
69       // Increment 'operandIndex' in 'state'.
70       ++state.operandIndex;
71       // Add 'nextOperand' to worklist.
72       worklist.push_back({nextOperand, 0});
73     } else {
74       // Post-visit: done visiting operands AffineApplyOp, pop off stack.
75       worklist.pop_back();
76     }
77   }
78 }
79 
80 // Builds a system of constraints with dimensional identifiers corresponding to
81 // the loop IVs of the forOps appearing in that order. Any symbols founds in
82 // the bound operands are added as symbols in the system. Returns failure for
83 // the yet unimplemented cases.
84 // TODO: Handle non-unit steps through local variables or stride information in
85 // FlatAffineConstraints. (For eg., by using iv - lb % step = 0 and/or by
86 // introducing a method in FlatAffineConstraints setExprStride(ArrayRef<int64_t>
87 // expr, int64_t stride)
getIndexSet(MutableArrayRef<Operation * > ops,FlatAffineConstraints * domain)88 LogicalResult mlir::getIndexSet(MutableArrayRef<Operation *> ops,
89                                 FlatAffineConstraints *domain) {
90   SmallVector<Value, 4> indices;
91   SmallVector<AffineForOp, 8> forOps;
92 
93   for (Operation *op : ops) {
94     assert((isa<AffineForOp, AffineIfOp>(op)) &&
95            "ops should have either AffineForOp or AffineIfOp");
96     if (AffineForOp forOp = dyn_cast<AffineForOp>(op))
97       forOps.push_back(forOp);
98   }
99   extractForInductionVars(forOps, &indices);
100   // Reset while associated Values in 'indices' to the domain.
101   domain->reset(forOps.size(), /*numSymbols=*/0, /*numLocals=*/0, indices);
102   for (Operation *op : ops) {
103     // Add constraints from forOp's bounds.
104     if (AffineForOp forOp = dyn_cast<AffineForOp>(op)) {
105       if (failed(domain->addAffineForOpDomain(forOp)))
106         return failure();
107     } else if (AffineIfOp ifOp = dyn_cast<AffineIfOp>(op)) {
108       domain->addAffineIfOpDomain(ifOp);
109     }
110   }
111   return success();
112 }
113 
114 /// Computes the iteration domain for 'op' and populates 'indexSet', which
115 /// encapsulates the constraints involving loops surrounding 'op' and
116 /// potentially involving any Function symbols. The dimensional identifiers in
117 /// 'indexSet' correspond to the loops surrounding 'op' from outermost to
118 /// innermost.
getOpIndexSet(Operation * op,FlatAffineConstraints * indexSet)119 static LogicalResult getOpIndexSet(Operation *op,
120                                    FlatAffineConstraints *indexSet) {
121   SmallVector<Operation *, 4> ops;
122   getEnclosingAffineForAndIfOps(*op, &ops);
123   return getIndexSet(ops, indexSet);
124 }
125 
126 namespace {
127 // ValuePositionMap manages the mapping from Values which represent dimension
128 // and symbol identifiers from 'src' and 'dst' access functions to positions
129 // in new space where some Values are kept separate (using addSrc/DstValue)
130 // and some Values are merged (addSymbolValue).
131 // Position lookups return the absolute position in the new space which
132 // has the following format:
133 //
134 //   [src-dim-identifiers] [dst-dim-identifiers] [symbol-identifiers]
135 //
136 // Note: access function non-IV dimension identifiers (that have 'dimension'
137 // positions in the access function position space) are assigned as symbols
138 // in the output position space. Convenience access functions which lookup
139 // an Value in multiple maps are provided (i.e. getSrcDimOrSymPos) to handle
140 // the common case of resolving positions for all access function operands.
141 //
142 // TODO: Generalize this: could take a template parameter for the number of maps
143 // (3 in the current case), and lookups could take indices of maps to check. So
144 // getSrcDimOrSymPos would be "getPos(value, {0, 2})".
145 class ValuePositionMap {
146 public:
addSrcValue(Value value)147   void addSrcValue(Value value) {
148     if (addValueAt(value, &srcDimPosMap, numSrcDims))
149       ++numSrcDims;
150   }
addDstValue(Value value)151   void addDstValue(Value value) {
152     if (addValueAt(value, &dstDimPosMap, numDstDims))
153       ++numDstDims;
154   }
addSymbolValue(Value value)155   void addSymbolValue(Value value) {
156     if (addValueAt(value, &symbolPosMap, numSymbols))
157       ++numSymbols;
158   }
getSrcDimOrSymPos(Value value) const159   unsigned getSrcDimOrSymPos(Value value) const {
160     return getDimOrSymPos(value, srcDimPosMap, 0);
161   }
getDstDimOrSymPos(Value value) const162   unsigned getDstDimOrSymPos(Value value) const {
163     return getDimOrSymPos(value, dstDimPosMap, numSrcDims);
164   }
getSymPos(Value value) const165   unsigned getSymPos(Value value) const {
166     auto it = symbolPosMap.find(value);
167     assert(it != symbolPosMap.end());
168     return numSrcDims + numDstDims + it->second;
169   }
170 
getNumSrcDims() const171   unsigned getNumSrcDims() const { return numSrcDims; }
getNumDstDims() const172   unsigned getNumDstDims() const { return numDstDims; }
getNumDims() const173   unsigned getNumDims() const { return numSrcDims + numDstDims; }
getNumSymbols() const174   unsigned getNumSymbols() const { return numSymbols; }
175 
176 private:
addValueAt(Value value,DenseMap<Value,unsigned> * posMap,unsigned position)177   bool addValueAt(Value value, DenseMap<Value, unsigned> *posMap,
178                   unsigned position) {
179     auto it = posMap->find(value);
180     if (it == posMap->end()) {
181       (*posMap)[value] = position;
182       return true;
183     }
184     return false;
185   }
getDimOrSymPos(Value value,const DenseMap<Value,unsigned> & dimPosMap,unsigned dimPosOffset) const186   unsigned getDimOrSymPos(Value value,
187                           const DenseMap<Value, unsigned> &dimPosMap,
188                           unsigned dimPosOffset) const {
189     auto it = dimPosMap.find(value);
190     if (it != dimPosMap.end()) {
191       return dimPosOffset + it->second;
192     }
193     it = symbolPosMap.find(value);
194     assert(it != symbolPosMap.end());
195     return numSrcDims + numDstDims + it->second;
196   }
197 
198   unsigned numSrcDims = 0;
199   unsigned numDstDims = 0;
200   unsigned numSymbols = 0;
201   DenseMap<Value, unsigned> srcDimPosMap;
202   DenseMap<Value, unsigned> dstDimPosMap;
203   DenseMap<Value, unsigned> symbolPosMap;
204 };
205 } // namespace
206 
207 // Builds a map from Value to identifier position in a new merged identifier
208 // list, which is the result of merging dim/symbol lists from src/dst
209 // iteration domains, the format of which is as follows:
210 //
211 //   [src-dim-identifiers, dst-dim-identifiers, symbol-identifiers, const_term]
212 //
213 // This method populates 'valuePosMap' with mappings from operand Values in
214 // 'srcAccessMap'/'dstAccessMap' (as well as those in 'srcDomain'/'dstDomain')
215 // to the position of these values in the merged list.
buildDimAndSymbolPositionMaps(const FlatAffineConstraints & srcDomain,const FlatAffineConstraints & dstDomain,const AffineValueMap & srcAccessMap,const AffineValueMap & dstAccessMap,ValuePositionMap * valuePosMap,FlatAffineConstraints * dependenceConstraints)216 static void buildDimAndSymbolPositionMaps(
217     const FlatAffineConstraints &srcDomain,
218     const FlatAffineConstraints &dstDomain, const AffineValueMap &srcAccessMap,
219     const AffineValueMap &dstAccessMap, ValuePositionMap *valuePosMap,
220     FlatAffineConstraints *dependenceConstraints) {
221 
222   // IsDimState is a tri-state boolean. It is used to distinguish three
223   // different cases of the values passed to updateValuePosMap.
224   // - When it is TRUE, we are certain that all values are dim values.
225   // - When it is FALSE, we are certain that all values are symbol values.
226   // - When it is UNKNOWN, we need to further check whether the value is from a
227   // loop IV to determine its type (dim or symbol).
228 
229   // We need this enumeration because sometimes we cannot determine whether a
230   // Value is a symbol or a dim by the information from the Value itself. If a
231   // Value appears in an affine map of a loop, we can determine whether it is a
232   // dim or not by the function `isForInductionVar`. But when a Value is in the
233   // affine set of an if-statement, there is no way to identify its category
234   // (dim/symbol) by itself. Fortunately, the Values to be inserted into
235   // `valuePosMap` come from `srcDomain` and `dstDomain`, and they hold such
236   // information of Value category: `srcDomain` and `dstDomain` organize Values
237   // by their category, such that the position of each Value stored in
238   // `srcDomain` and `dstDomain` marks which category that a Value belongs to.
239   // Therefore, we can separate Values into dim and symbol groups before passing
240   // them to the function `updateValuePosMap`. Specifically, when passing the
241   // dim group, we set IsDimState to TRUE; otherwise, we set it to FALSE.
242   // However, Values from the operands of `srcAccessMap` and `dstAccessMap` are
243   // not explicitly categorized into dim or symbol, and we have to rely on
244   // `isForInductionVar` to make the decision. IsDimState is set to UNKNOWN in
245   // this case.
246   enum IsDimState { TRUE, FALSE, UNKNOWN };
247 
248   // This function places each given Value (in `values`) under a respective
249   // category in `valuePosMap`. Specifically, the placement rules are:
250   // 1) If `isDim` is FALSE, then every value in `values` are inserted into
251   // `valuePosMap` as symbols.
252   // 2) If `isDim` is UNKNOWN and the value of the current iteration is NOT an
253   // induction variable of a for-loop, we treat it as symbol as well.
254   // 3) For other cases, we decide whether to add a value to the `src` or the
255   // `dst` section of the dim category simply by the boolean value `isSrc`.
256   auto updateValuePosMap = [&](ArrayRef<Value> values, bool isSrc,
257                                IsDimState isDim) {
258     for (unsigned i = 0, e = values.size(); i < e; ++i) {
259       auto value = values[i];
260       if (isDim == FALSE || (isDim == UNKNOWN && !isForInductionVar(value))) {
261         assert(isValidSymbol(value) &&
262                "access operand has to be either a loop IV or a symbol");
263         valuePosMap->addSymbolValue(value);
264       } else {
265         if (isSrc)
266           valuePosMap->addSrcValue(value);
267         else
268           valuePosMap->addDstValue(value);
269       }
270     }
271   };
272 
273   // Collect values from the src and dst domains. For each domain, we separate
274   // the collected values into dim and symbol parts.
275   SmallVector<Value, 4> srcDimValues, dstDimValues, srcSymbolValues,
276       dstSymbolValues;
277   srcDomain.getIdValues(0, srcDomain.getNumDimIds(), &srcDimValues);
278   dstDomain.getIdValues(0, dstDomain.getNumDimIds(), &dstDimValues);
279   srcDomain.getIdValues(srcDomain.getNumDimIds(),
280                         srcDomain.getNumDimAndSymbolIds(), &srcSymbolValues);
281   dstDomain.getIdValues(dstDomain.getNumDimIds(),
282                         dstDomain.getNumDimAndSymbolIds(), &dstSymbolValues);
283 
284   // Update value position map with dim values from src iteration domain.
285   updateValuePosMap(srcDimValues, /*isSrc=*/true, /*isDim=*/TRUE);
286   // Update value position map with dim values from dst iteration domain.
287   updateValuePosMap(dstDimValues, /*isSrc=*/false, /*isDim=*/TRUE);
288   // Update value position map with symbols from src iteration domain.
289   updateValuePosMap(srcSymbolValues, /*isSrc=*/true, /*isDim=*/FALSE);
290   // Update value position map with symbols from dst iteration domain.
291   updateValuePosMap(dstSymbolValues, /*isSrc=*/false, /*isDim=*/FALSE);
292   // Update value position map with identifiers from src access function.
293   updateValuePosMap(srcAccessMap.getOperands(), /*isSrc=*/true,
294                     /*isDim=*/UNKNOWN);
295   // Update value position map with identifiers from dst access function.
296   updateValuePosMap(dstAccessMap.getOperands(), /*isSrc=*/false,
297                     /*isDim=*/UNKNOWN);
298 }
299 
300 // Sets up dependence constraints columns appropriately, in the format:
301 // [src-dim-ids, dst-dim-ids, symbol-ids, local-ids, const_term]
initDependenceConstraints(const FlatAffineConstraints & srcDomain,const FlatAffineConstraints & dstDomain,const AffineValueMap & srcAccessMap,const AffineValueMap & dstAccessMap,const ValuePositionMap & valuePosMap,FlatAffineConstraints * dependenceConstraints)302 static void initDependenceConstraints(
303     const FlatAffineConstraints &srcDomain,
304     const FlatAffineConstraints &dstDomain, const AffineValueMap &srcAccessMap,
305     const AffineValueMap &dstAccessMap, const ValuePositionMap &valuePosMap,
306     FlatAffineConstraints *dependenceConstraints) {
307   // Calculate number of equalities/inequalities and columns required to
308   // initialize FlatAffineConstraints for 'dependenceDomain'.
309   unsigned numIneq =
310       srcDomain.getNumInequalities() + dstDomain.getNumInequalities();
311   AffineMap srcMap = srcAccessMap.getAffineMap();
312   assert(srcMap.getNumResults() == dstAccessMap.getAffineMap().getNumResults());
313   unsigned numEq = srcMap.getNumResults();
314   unsigned numDims = srcDomain.getNumDimIds() + dstDomain.getNumDimIds();
315   unsigned numSymbols = valuePosMap.getNumSymbols();
316   unsigned numLocals = srcDomain.getNumLocalIds() + dstDomain.getNumLocalIds();
317   unsigned numIds = numDims + numSymbols + numLocals;
318   unsigned numCols = numIds + 1;
319 
320   // Set flat affine constraints sizes and reserving space for constraints.
321   dependenceConstraints->reset(numIneq, numEq, numCols, numDims, numSymbols,
322                                numLocals);
323 
324   // Set values corresponding to dependence constraint identifiers.
325   SmallVector<Value, 4> srcLoopIVs, dstLoopIVs;
326   srcDomain.getIdValues(0, srcDomain.getNumDimIds(), &srcLoopIVs);
327   dstDomain.getIdValues(0, dstDomain.getNumDimIds(), &dstLoopIVs);
328 
329   dependenceConstraints->setIdValues(0, srcLoopIVs.size(), srcLoopIVs);
330   dependenceConstraints->setIdValues(
331       srcLoopIVs.size(), srcLoopIVs.size() + dstLoopIVs.size(), dstLoopIVs);
332 
333   // Set values for the symbolic identifier dimensions. `isSymbolDetermined`
334   // indicates whether we are certain that the `values` passed in are all
335   // symbols. If `isSymbolDetermined` is true, then we treat every Value in
336   // `values` as a symbol; otherwise, we let the function `isForInductionVar` to
337   // distinguish whether a Value in `values` is a symbol or not.
338   auto setSymbolIds = [&](ArrayRef<Value> values,
339                           bool isSymbolDetermined = true) {
340     for (auto value : values) {
341       if (isSymbolDetermined || !isForInductionVar(value)) {
342         assert(isValidSymbol(value) && "expected symbol");
343         dependenceConstraints->setIdValue(valuePosMap.getSymPos(value), value);
344       }
345     }
346   };
347 
348   // We are uncertain about whether all operands in `srcAccessMap` and
349   // `dstAccessMap` are symbols, so we set `isSymbolDetermined` to false.
350   setSymbolIds(srcAccessMap.getOperands(), /*isSymbolDetermined=*/false);
351   setSymbolIds(dstAccessMap.getOperands(), /*isSymbolDetermined=*/false);
352 
353   SmallVector<Value, 8> srcSymbolValues, dstSymbolValues;
354   srcDomain.getIdValues(srcDomain.getNumDimIds(),
355                         srcDomain.getNumDimAndSymbolIds(), &srcSymbolValues);
356   dstDomain.getIdValues(dstDomain.getNumDimIds(),
357                         dstDomain.getNumDimAndSymbolIds(), &dstSymbolValues);
358   // Since we only take symbol Values out of `srcDomain` and `dstDomain`,
359   // `isSymbolDetermined` is kept to its default value: true.
360   setSymbolIds(srcSymbolValues);
361   setSymbolIds(dstSymbolValues);
362 
363   for (unsigned i = 0, e = dependenceConstraints->getNumDimAndSymbolIds();
364        i < e; i++)
365     assert(dependenceConstraints->getIds()[i].hasValue());
366 }
367 
368 // Adds iteration domain constraints from 'srcDomain' and 'dstDomain' into
369 // 'dependenceDomain'.
370 // Uses 'valuePosMap' to determine the position in 'dependenceDomain' to which a
371 // srcDomain/dstDomain Value maps.
addDomainConstraints(const FlatAffineConstraints & srcDomain,const FlatAffineConstraints & dstDomain,const ValuePositionMap & valuePosMap,FlatAffineConstraints * dependenceDomain)372 static void addDomainConstraints(const FlatAffineConstraints &srcDomain,
373                                  const FlatAffineConstraints &dstDomain,
374                                  const ValuePositionMap &valuePosMap,
375                                  FlatAffineConstraints *dependenceDomain) {
376   unsigned depNumDimsAndSymbolIds = dependenceDomain->getNumDimAndSymbolIds();
377 
378   SmallVector<int64_t, 4> cst(dependenceDomain->getNumCols());
379 
380   auto addDomain = [&](bool isSrc, bool isEq, unsigned localOffset) {
381     const FlatAffineConstraints &domain = isSrc ? srcDomain : dstDomain;
382     unsigned numCsts =
383         isEq ? domain.getNumEqualities() : domain.getNumInequalities();
384     unsigned numDimAndSymbolIds = domain.getNumDimAndSymbolIds();
385     auto at = [&](unsigned i, unsigned j) -> int64_t {
386       return isEq ? domain.atEq(i, j) : domain.atIneq(i, j);
387     };
388     auto map = [&](unsigned i) -> int64_t {
389       return isSrc ? valuePosMap.getSrcDimOrSymPos(domain.getIdValue(i))
390                    : valuePosMap.getDstDimOrSymPos(domain.getIdValue(i));
391     };
392 
393     for (unsigned i = 0; i < numCsts; ++i) {
394       // Zero fill.
395       std::fill(cst.begin(), cst.end(), 0);
396       // Set coefficients for identifiers corresponding to domain.
397       for (unsigned j = 0; j < numDimAndSymbolIds; ++j)
398         cst[map(j)] = at(i, j);
399       // Local terms.
400       for (unsigned j = 0, e = domain.getNumLocalIds(); j < e; j++)
401         cst[depNumDimsAndSymbolIds + localOffset + j] =
402             at(i, numDimAndSymbolIds + j);
403       // Set constant term.
404       cst[cst.size() - 1] = at(i, domain.getNumCols() - 1);
405       // Add constraint.
406       if (isEq)
407         dependenceDomain->addEquality(cst);
408       else
409         dependenceDomain->addInequality(cst);
410     }
411   };
412 
413   // Add equalities from src domain.
414   addDomain(/*isSrc=*/true, /*isEq=*/true, /*localOffset=*/0);
415   // Add inequalities from src domain.
416   addDomain(/*isSrc=*/true, /*isEq=*/false, /*localOffset=*/0);
417   // Add equalities from dst domain.
418   addDomain(/*isSrc=*/false, /*isEq=*/true,
419             /*localOffset=*/srcDomain.getNumLocalIds());
420   // Add inequalities from dst domain.
421   addDomain(/*isSrc=*/false, /*isEq=*/false,
422             /*localOffset=*/srcDomain.getNumLocalIds());
423 }
424 
425 // Adds equality constraints that equate src and dst access functions
426 // represented by 'srcAccessMap' and 'dstAccessMap' for each result.
427 // Requires that 'srcAccessMap' and 'dstAccessMap' have the same results count.
428 // For example, given the following two accesses functions to a 2D memref:
429 //
430 //   Source access function:
431 //     (a0 * d0 + a1 * s0 + a2, b0 * d0 + b1 * s0 + b2)
432 //
433 //   Destination access function:
434 //     (c0 * d0 + c1 * s0 + c2, f0 * d0 + f1 * s0 + f2)
435 //
436 // This method constructs the following equality constraints in
437 // 'dependenceDomain', by equating the access functions for each result
438 // (i.e. each memref dim). Notice that 'd0' for the destination access function
439 // is mapped into 'd0' in the equality constraint:
440 //
441 //   d0      d1      s0         c
442 //   --      --      --         --
443 //   a0     -c0      (a1 - c1)  (a1 - c2) = 0
444 //   b0     -f0      (b1 - f1)  (b1 - f2) = 0
445 //
446 // Returns failure if any AffineExpr cannot be flattened (due to it being
447 // semi-affine). Returns success otherwise.
448 static LogicalResult
addMemRefAccessConstraints(const AffineValueMap & srcAccessMap,const AffineValueMap & dstAccessMap,const ValuePositionMap & valuePosMap,FlatAffineConstraints * dependenceDomain)449 addMemRefAccessConstraints(const AffineValueMap &srcAccessMap,
450                            const AffineValueMap &dstAccessMap,
451                            const ValuePositionMap &valuePosMap,
452                            FlatAffineConstraints *dependenceDomain) {
453   AffineMap srcMap = srcAccessMap.getAffineMap();
454   AffineMap dstMap = dstAccessMap.getAffineMap();
455   assert(srcMap.getNumResults() == dstMap.getNumResults());
456   unsigned numResults = srcMap.getNumResults();
457 
458   unsigned srcNumIds = srcMap.getNumDims() + srcMap.getNumSymbols();
459   ArrayRef<Value> srcOperands = srcAccessMap.getOperands();
460 
461   unsigned dstNumIds = dstMap.getNumDims() + dstMap.getNumSymbols();
462   ArrayRef<Value> dstOperands = dstAccessMap.getOperands();
463 
464   std::vector<SmallVector<int64_t, 8>> srcFlatExprs;
465   std::vector<SmallVector<int64_t, 8>> destFlatExprs;
466   FlatAffineConstraints srcLocalVarCst, destLocalVarCst;
467   // Get flattened expressions for the source destination maps.
468   if (failed(getFlattenedAffineExprs(srcMap, &srcFlatExprs, &srcLocalVarCst)) ||
469       failed(getFlattenedAffineExprs(dstMap, &destFlatExprs, &destLocalVarCst)))
470     return failure();
471 
472   unsigned domNumLocalIds = dependenceDomain->getNumLocalIds();
473   unsigned srcNumLocalIds = srcLocalVarCst.getNumLocalIds();
474   unsigned dstNumLocalIds = destLocalVarCst.getNumLocalIds();
475   unsigned numLocalIdsToAdd = srcNumLocalIds + dstNumLocalIds;
476   for (unsigned i = 0; i < numLocalIdsToAdd; i++) {
477     dependenceDomain->addLocalId(dependenceDomain->getNumLocalIds());
478   }
479 
480   unsigned numDims = dependenceDomain->getNumDimIds();
481   unsigned numSymbols = dependenceDomain->getNumSymbolIds();
482   unsigned numSrcLocalIds = srcLocalVarCst.getNumLocalIds();
483   unsigned newLocalIdOffset = numDims + numSymbols + domNumLocalIds;
484 
485   // Equality to add.
486   SmallVector<int64_t, 8> eq(dependenceDomain->getNumCols());
487   for (unsigned i = 0; i < numResults; ++i) {
488     // Zero fill.
489     std::fill(eq.begin(), eq.end(), 0);
490 
491     // Flattened AffineExpr for src result 'i'.
492     const auto &srcFlatExpr = srcFlatExprs[i];
493     // Set identifier coefficients from src access function.
494     for (unsigned j = 0, e = srcOperands.size(); j < e; ++j)
495       eq[valuePosMap.getSrcDimOrSymPos(srcOperands[j])] = srcFlatExpr[j];
496     // Local terms.
497     for (unsigned j = 0, e = srcNumLocalIds; j < e; j++)
498       eq[newLocalIdOffset + j] = srcFlatExpr[srcNumIds + j];
499     // Set constant term.
500     eq[eq.size() - 1] = srcFlatExpr[srcFlatExpr.size() - 1];
501 
502     // Flattened AffineExpr for dest result 'i'.
503     const auto &destFlatExpr = destFlatExprs[i];
504     // Set identifier coefficients from dst access function.
505     for (unsigned j = 0, e = dstOperands.size(); j < e; ++j)
506       eq[valuePosMap.getDstDimOrSymPos(dstOperands[j])] -= destFlatExpr[j];
507     // Local terms.
508     for (unsigned j = 0, e = dstNumLocalIds; j < e; j++)
509       eq[newLocalIdOffset + numSrcLocalIds + j] = -destFlatExpr[dstNumIds + j];
510     // Set constant term.
511     eq[eq.size() - 1] -= destFlatExpr[destFlatExpr.size() - 1];
512 
513     // Add equality constraint.
514     dependenceDomain->addEquality(eq);
515   }
516 
517   // Add equality constraints for any operands that are defined by constant ops.
518   auto addEqForConstOperands = [&](ArrayRef<Value> operands) {
519     for (unsigned i = 0, e = operands.size(); i < e; ++i) {
520       if (isForInductionVar(operands[i]))
521         continue;
522       auto symbol = operands[i];
523       assert(isValidSymbol(symbol));
524       // Check if the symbol is a constant.
525       if (auto cOp = symbol.getDefiningOp<ConstantIndexOp>())
526         dependenceDomain->setIdToConstant(valuePosMap.getSymPos(symbol),
527                                           cOp.getValue());
528     }
529   };
530 
531   // Add equality constraints for any src symbols defined by constant ops.
532   addEqForConstOperands(srcOperands);
533   // Add equality constraints for any dst symbols defined by constant ops.
534   addEqForConstOperands(dstOperands);
535 
536   // By construction (see flattener), local var constraints will not have any
537   // equalities.
538   assert(srcLocalVarCst.getNumEqualities() == 0 &&
539          destLocalVarCst.getNumEqualities() == 0);
540   // Add inequalities from srcLocalVarCst and destLocalVarCst into the
541   // dependence domain.
542   SmallVector<int64_t, 8> ineq(dependenceDomain->getNumCols());
543   for (unsigned r = 0, e = srcLocalVarCst.getNumInequalities(); r < e; r++) {
544     std::fill(ineq.begin(), ineq.end(), 0);
545 
546     // Set identifier coefficients from src local var constraints.
547     for (unsigned j = 0, e = srcOperands.size(); j < e; ++j)
548       ineq[valuePosMap.getSrcDimOrSymPos(srcOperands[j])] =
549           srcLocalVarCst.atIneq(r, j);
550     // Local terms.
551     for (unsigned j = 0, e = srcNumLocalIds; j < e; j++)
552       ineq[newLocalIdOffset + j] = srcLocalVarCst.atIneq(r, srcNumIds + j);
553     // Set constant term.
554     ineq[ineq.size() - 1] =
555         srcLocalVarCst.atIneq(r, srcLocalVarCst.getNumCols() - 1);
556     dependenceDomain->addInequality(ineq);
557   }
558 
559   for (unsigned r = 0, e = destLocalVarCst.getNumInequalities(); r < e; r++) {
560     std::fill(ineq.begin(), ineq.end(), 0);
561     // Set identifier coefficients from dest local var constraints.
562     for (unsigned j = 0, e = dstOperands.size(); j < e; ++j)
563       ineq[valuePosMap.getDstDimOrSymPos(dstOperands[j])] =
564           destLocalVarCst.atIneq(r, j);
565     // Local terms.
566     for (unsigned j = 0, e = dstNumLocalIds; j < e; j++)
567       ineq[newLocalIdOffset + numSrcLocalIds + j] =
568           destLocalVarCst.atIneq(r, dstNumIds + j);
569     // Set constant term.
570     ineq[ineq.size() - 1] =
571         destLocalVarCst.atIneq(r, destLocalVarCst.getNumCols() - 1);
572 
573     dependenceDomain->addInequality(ineq);
574   }
575   return success();
576 }
577 
578 // Returns the number of outer loop common to 'src/dstDomain'.
579 // Loops common to 'src/dst' domains are added to 'commonLoops' if non-null.
580 static unsigned
getNumCommonLoops(const FlatAffineConstraints & srcDomain,const FlatAffineConstraints & dstDomain,SmallVectorImpl<AffineForOp> * commonLoops=nullptr)581 getNumCommonLoops(const FlatAffineConstraints &srcDomain,
582                   const FlatAffineConstraints &dstDomain,
583                   SmallVectorImpl<AffineForOp> *commonLoops = nullptr) {
584   // Find the number of common loops shared by src and dst accesses.
585   unsigned minNumLoops =
586       std::min(srcDomain.getNumDimIds(), dstDomain.getNumDimIds());
587   unsigned numCommonLoops = 0;
588   for (unsigned i = 0; i < minNumLoops; ++i) {
589     if (!isForInductionVar(srcDomain.getIdValue(i)) ||
590         !isForInductionVar(dstDomain.getIdValue(i)) ||
591         srcDomain.getIdValue(i) != dstDomain.getIdValue(i))
592       break;
593     if (commonLoops != nullptr)
594       commonLoops->push_back(getForInductionVarOwner(srcDomain.getIdValue(i)));
595     ++numCommonLoops;
596   }
597   if (commonLoops != nullptr)
598     assert(commonLoops->size() == numCommonLoops);
599   return numCommonLoops;
600 }
601 
602 /// Returns Block common to 'srcAccess.opInst' and 'dstAccess.opInst'.
getCommonBlock(const MemRefAccess & srcAccess,const MemRefAccess & dstAccess,const FlatAffineConstraints & srcDomain,unsigned numCommonLoops)603 static Block *getCommonBlock(const MemRefAccess &srcAccess,
604                              const MemRefAccess &dstAccess,
605                              const FlatAffineConstraints &srcDomain,
606                              unsigned numCommonLoops) {
607   // Get the chain of ancestor blocks to the given `MemRefAccess` instance. The
608   // search terminates when either an op with the `AffineScope` trait or
609   // `endBlock` is reached.
610   auto getChainOfAncestorBlocks = [&](const MemRefAccess &access,
611                                       SmallVector<Block *, 4> &ancestorBlocks,
612                                       Block *endBlock = nullptr) {
613     Block *currBlock = access.opInst->getBlock();
614     // Loop terminates when the currBlock is nullptr or equals to the endBlock,
615     // or its parent operation holds an affine scope.
616     while (currBlock && currBlock != endBlock &&
617            !currBlock->getParentOp()->hasTrait<OpTrait::AffineScope>()) {
618       ancestorBlocks.push_back(currBlock);
619       currBlock = currBlock->getParentOp()->getBlock();
620     }
621   };
622 
623   if (numCommonLoops == 0) {
624     Block *block = srcAccess.opInst->getBlock();
625     while (!llvm::isa<FuncOp>(block->getParentOp())) {
626       block = block->getParentOp()->getBlock();
627     }
628     return block;
629   }
630   Value commonForIV = srcDomain.getIdValue(numCommonLoops - 1);
631   AffineForOp forOp = getForInductionVarOwner(commonForIV);
632   assert(forOp && "commonForValue was not an induction variable");
633 
634   // Find the closest common block including those in AffineIf.
635   SmallVector<Block *, 4> srcAncestorBlocks, dstAncestorBlocks;
636   getChainOfAncestorBlocks(srcAccess, srcAncestorBlocks, forOp.getBody());
637   getChainOfAncestorBlocks(dstAccess, dstAncestorBlocks, forOp.getBody());
638 
639   Block *commonBlock = forOp.getBody();
640   for (int i = srcAncestorBlocks.size() - 1, j = dstAncestorBlocks.size() - 1;
641        i >= 0 && j >= 0 && srcAncestorBlocks[i] == dstAncestorBlocks[j];
642        i--, j--)
643     commonBlock = srcAncestorBlocks[i];
644 
645   return commonBlock;
646 }
647 
648 // Returns true if the ancestor operation of 'srcAccess' appears before the
649 // ancestor operation of 'dstAccess' in the common ancestral block. Returns
650 // false otherwise.
651 // Note that because 'srcAccess' or 'dstAccess' may be nested in conditionals,
652 // the function is named 'srcAppearsBeforeDstInCommonBlock'. Note that
653 // 'numCommonLoops' is the number of contiguous surrounding outer loops.
srcAppearsBeforeDstInAncestralBlock(const MemRefAccess & srcAccess,const MemRefAccess & dstAccess,const FlatAffineConstraints & srcDomain,unsigned numCommonLoops)654 static bool srcAppearsBeforeDstInAncestralBlock(
655     const MemRefAccess &srcAccess, const MemRefAccess &dstAccess,
656     const FlatAffineConstraints &srcDomain, unsigned numCommonLoops) {
657   // Get Block common to 'srcAccess.opInst' and 'dstAccess.opInst'.
658   auto *commonBlock =
659       getCommonBlock(srcAccess, dstAccess, srcDomain, numCommonLoops);
660   // Check the dominance relationship between the respective ancestors of the
661   // src and dst in the Block of the innermost among the common loops.
662   auto *srcInst = commonBlock->findAncestorOpInBlock(*srcAccess.opInst);
663   assert(srcInst != nullptr);
664   auto *dstInst = commonBlock->findAncestorOpInBlock(*dstAccess.opInst);
665   assert(dstInst != nullptr);
666 
667   // Determine whether dstInst comes after srcInst.
668   return srcInst->isBeforeInBlock(dstInst);
669 }
670 
671 // Adds ordering constraints to 'dependenceDomain' based on number of loops
672 // common to 'src/dstDomain' and requested 'loopDepth'.
673 // Note that 'loopDepth' cannot exceed the number of common loops plus one.
674 // EX: Given a loop nest of depth 2 with IVs 'i' and 'j':
675 // *) If 'loopDepth == 1' then one constraint is added: i' >= i + 1
676 // *) If 'loopDepth == 2' then two constraints are added: i == i' and j' > j + 1
677 // *) If 'loopDepth == 3' then two constraints are added: i == i' and j == j'
addOrderingConstraints(const FlatAffineConstraints & srcDomain,const FlatAffineConstraints & dstDomain,unsigned loopDepth,FlatAffineConstraints * dependenceDomain)678 static void addOrderingConstraints(const FlatAffineConstraints &srcDomain,
679                                    const FlatAffineConstraints &dstDomain,
680                                    unsigned loopDepth,
681                                    FlatAffineConstraints *dependenceDomain) {
682   unsigned numCols = dependenceDomain->getNumCols();
683   SmallVector<int64_t, 4> eq(numCols);
684   unsigned numSrcDims = srcDomain.getNumDimIds();
685   unsigned numCommonLoops = getNumCommonLoops(srcDomain, dstDomain);
686   unsigned numCommonLoopConstraints = std::min(numCommonLoops, loopDepth);
687   for (unsigned i = 0; i < numCommonLoopConstraints; ++i) {
688     std::fill(eq.begin(), eq.end(), 0);
689     eq[i] = -1;
690     eq[i + numSrcDims] = 1;
691     if (i == loopDepth - 1) {
692       eq[numCols - 1] = -1;
693       dependenceDomain->addInequality(eq);
694     } else {
695       dependenceDomain->addEquality(eq);
696     }
697   }
698 }
699 
700 // Computes distance and direction vectors in 'dependences', by adding
701 // variables to 'dependenceDomain' which represent the difference of the IVs,
702 // eliminating all other variables, and reading off distance vectors from
703 // equality constraints (if possible), and direction vectors from inequalities.
computeDirectionVector(const FlatAffineConstraints & srcDomain,const FlatAffineConstraints & dstDomain,unsigned loopDepth,FlatAffineConstraints * dependenceDomain,SmallVector<DependenceComponent,2> * dependenceComponents)704 static void computeDirectionVector(
705     const FlatAffineConstraints &srcDomain,
706     const FlatAffineConstraints &dstDomain, unsigned loopDepth,
707     FlatAffineConstraints *dependenceDomain,
708     SmallVector<DependenceComponent, 2> *dependenceComponents) {
709   // Find the number of common loops shared by src and dst accesses.
710   SmallVector<AffineForOp, 4> commonLoops;
711   unsigned numCommonLoops =
712       getNumCommonLoops(srcDomain, dstDomain, &commonLoops);
713   if (numCommonLoops == 0)
714     return;
715   // Compute direction vectors for requested loop depth.
716   unsigned numIdsToEliminate = dependenceDomain->getNumIds();
717   // Add new variables to 'dependenceDomain' to represent the direction
718   // constraints for each shared loop.
719   for (unsigned j = 0; j < numCommonLoops; ++j) {
720     dependenceDomain->addDimId(j);
721   }
722 
723   // Add equality constraints for each common loop, setting newly introduced
724   // variable at column 'j' to the 'dst' IV minus the 'src IV.
725   SmallVector<int64_t, 4> eq;
726   eq.resize(dependenceDomain->getNumCols());
727   unsigned numSrcDims = srcDomain.getNumDimIds();
728   // Constraint variables format:
729   // [num-common-loops][num-src-dim-ids][num-dst-dim-ids][num-symbols][constant]
730   for (unsigned j = 0; j < numCommonLoops; ++j) {
731     std::fill(eq.begin(), eq.end(), 0);
732     eq[j] = 1;
733     eq[j + numCommonLoops] = 1;
734     eq[j + numCommonLoops + numSrcDims] = -1;
735     dependenceDomain->addEquality(eq);
736   }
737 
738   // Eliminate all variables other than the direction variables just added.
739   dependenceDomain->projectOut(numCommonLoops, numIdsToEliminate);
740 
741   // Scan each common loop variable column and set direction vectors based
742   // on eliminated constraint system.
743   dependenceComponents->resize(numCommonLoops);
744   for (unsigned j = 0; j < numCommonLoops; ++j) {
745     (*dependenceComponents)[j].op = commonLoops[j].getOperation();
746     auto lbConst = dependenceDomain->getConstantLowerBound(j);
747     (*dependenceComponents)[j].lb =
748         lbConst.getValueOr(std::numeric_limits<int64_t>::min());
749     auto ubConst = dependenceDomain->getConstantUpperBound(j);
750     (*dependenceComponents)[j].ub =
751         ubConst.getValueOr(std::numeric_limits<int64_t>::max());
752   }
753 }
754 
755 // Populates 'accessMap' with composition of AffineApplyOps reachable from
756 // indices of MemRefAccess.
getAccessMap(AffineValueMap * accessMap) const757 void MemRefAccess::getAccessMap(AffineValueMap *accessMap) const {
758   // Get affine map from AffineLoad/Store.
759   AffineMap map;
760   if (auto loadOp = dyn_cast<AffineReadOpInterface>(opInst))
761     map = loadOp.getAffineMap();
762   else
763     map = cast<AffineWriteOpInterface>(opInst).getAffineMap();
764 
765   SmallVector<Value, 8> operands(indices.begin(), indices.end());
766   fullyComposeAffineMapAndOperands(&map, &operands);
767   map = simplifyAffineMap(map);
768   canonicalizeMapAndOperands(&map, &operands);
769   accessMap->reset(map, operands);
770 }
771 
772 // Builds a flat affine constraint system to check if there exists a dependence
773 // between memref accesses 'srcAccess' and 'dstAccess'.
774 // Returns 'NoDependence' if the accesses can be definitively shown not to
775 // access the same element.
776 // Returns 'HasDependence' if the accesses do access the same element.
777 // Returns 'Failure' if an error or unsupported case was encountered.
778 // If a dependence exists, returns in 'dependenceComponents' a direction
779 // vector for the dependence, with a component for each loop IV in loops
780 // common to both accesses (see Dependence in AffineAnalysis.h for details).
781 //
782 // The memref access dependence check is comprised of the following steps:
783 // *) Compute access functions for each access. Access functions are computed
784 //    using AffineValueMaps initialized with the indices from an access, then
785 //    composed with AffineApplyOps reachable from operands of that access,
786 //    until operands of the AffineValueMap are loop IVs or symbols.
787 // *) Build iteration domain constraints for each access. Iteration domain
788 //    constraints are pairs of inequality constraints representing the
789 //    upper/lower loop bounds for each AffineForOp in the loop nest associated
790 //    with each access.
791 // *) Build dimension and symbol position maps for each access, which map
792 //    Values from access functions and iteration domains to their position
793 //    in the merged constraint system built by this method.
794 //
795 // This method builds a constraint system with the following column format:
796 //
797 //  [src-dim-identifiers, dst-dim-identifiers, symbols, constant]
798 //
799 // For example, given the following MLIR code with "source" and "destination"
800 // accesses to the same memref label, and symbols %M, %N, %K:
801 //
802 //   affine.for %i0 = 0 to 100 {
803 //     affine.for %i1 = 0 to 50 {
804 //       %a0 = affine.apply
805 //         (d0, d1) -> (d0 * 2 - d1 * 4 + s1, d1 * 3 - s0) (%i0, %i1)[%M, %N]
806 //       // Source memref access.
807 //       store %v0, %m[%a0#0, %a0#1] : memref<4x4xf32>
808 //     }
809 //   }
810 //
811 //   affine.for %i2 = 0 to 100 {
812 //     affine.for %i3 = 0 to 50 {
813 //       %a1 = affine.apply
814 //         (d0, d1) -> (d0 * 7 + d1 * 9 - s1, d1 * 11 + s0) (%i2, %i3)[%K, %M]
815 //       // Destination memref access.
816 //       %v1 = load %m[%a1#0, %a1#1] : memref<4x4xf32>
817 //     }
818 //   }
819 //
820 // The access functions would be the following:
821 //
822 //   src: (%i0 * 2 - %i1 * 4 + %N, %i1 * 3 - %M)
823 //   dst: (%i2 * 7 + %i3 * 9 - %M, %i3 * 11 - %K)
824 //
825 // The iteration domains for the src/dst accesses would be the following:
826 //
827 //   src: 0 <= %i0 <= 100, 0 <= %i1 <= 50
828 //   dst: 0 <= %i2 <= 100, 0 <= %i3 <= 50
829 //
830 // The symbols by both accesses would be assigned to a canonical position order
831 // which will be used in the dependence constraint system:
832 //
833 //   symbol name: %M  %N  %K
834 //   symbol  pos:  0   1   2
835 //
836 // Equality constraints are built by equating each result of src/destination
837 // access functions. For this example, the following two equality constraints
838 // will be added to the dependence constraint system:
839 //
840 //   [src_dim0, src_dim1, dst_dim0, dst_dim1, sym0, sym1, sym2, const]
841 //      2         -4        -7        -9       1      1     0     0    = 0
842 //      0          3         0        -11     -1      0     1     0    = 0
843 //
844 // Inequality constraints from the iteration domain will be meged into
845 // the dependence constraint system
846 //
847 //   [src_dim0, src_dim1, dst_dim0, dst_dim1, sym0, sym1, sym2, const]
848 //       1         0         0         0        0     0     0     0    >= 0
849 //      -1         0         0         0        0     0     0     100  >= 0
850 //       0         1         0         0        0     0     0     0    >= 0
851 //       0        -1         0         0        0     0     0     50   >= 0
852 //       0         0         1         0        0     0     0     0    >= 0
853 //       0         0        -1         0        0     0     0     100  >= 0
854 //       0         0         0         1        0     0     0     0    >= 0
855 //       0         0         0        -1        0     0     0     50   >= 0
856 //
857 //
858 // TODO: Support AffineExprs mod/floordiv/ceildiv.
checkMemrefAccessDependence(const MemRefAccess & srcAccess,const MemRefAccess & dstAccess,unsigned loopDepth,FlatAffineConstraints * dependenceConstraints,SmallVector<DependenceComponent,2> * dependenceComponents,bool allowRAR)859 DependenceResult mlir::checkMemrefAccessDependence(
860     const MemRefAccess &srcAccess, const MemRefAccess &dstAccess,
861     unsigned loopDepth, FlatAffineConstraints *dependenceConstraints,
862     SmallVector<DependenceComponent, 2> *dependenceComponents, bool allowRAR) {
863   LLVM_DEBUG(llvm::dbgs() << "Checking for dependence at depth: "
864                           << Twine(loopDepth) << " between:\n";);
865   LLVM_DEBUG(srcAccess.opInst->dump(););
866   LLVM_DEBUG(dstAccess.opInst->dump(););
867 
868   // Return 'NoDependence' if these accesses do not access the same memref.
869   if (srcAccess.memref != dstAccess.memref)
870     return DependenceResult::NoDependence;
871 
872   // Return 'NoDependence' if one of these accesses is not an
873   // AffineWriteOpInterface.
874   if (!allowRAR && !isa<AffineWriteOpInterface>(srcAccess.opInst) &&
875       !isa<AffineWriteOpInterface>(dstAccess.opInst))
876     return DependenceResult::NoDependence;
877 
878   // Get composed access function for 'srcAccess'.
879   AffineValueMap srcAccessMap;
880   srcAccess.getAccessMap(&srcAccessMap);
881 
882   // Get composed access function for 'dstAccess'.
883   AffineValueMap dstAccessMap;
884   dstAccess.getAccessMap(&dstAccessMap);
885 
886   // Get iteration domain for the 'srcAccess' operation.
887   FlatAffineConstraints srcDomain;
888   if (failed(getOpIndexSet(srcAccess.opInst, &srcDomain)))
889     return DependenceResult::Failure;
890 
891   // Get iteration domain for 'dstAccess' operation.
892   FlatAffineConstraints dstDomain;
893   if (failed(getOpIndexSet(dstAccess.opInst, &dstDomain)))
894     return DependenceResult::Failure;
895 
896   // Return 'NoDependence' if loopDepth > numCommonLoops and if the ancestor
897   // operation of 'srcAccess' does not properly dominate the ancestor
898   // operation of 'dstAccess' in the same common operation block.
899   // Note: this check is skipped if 'allowRAR' is true, because because RAR
900   // deps can exist irrespective of lexicographic ordering b/w src and dst.
901   unsigned numCommonLoops = getNumCommonLoops(srcDomain, dstDomain);
902   assert(loopDepth <= numCommonLoops + 1);
903   if (!allowRAR && loopDepth > numCommonLoops &&
904       !srcAppearsBeforeDstInAncestralBlock(srcAccess, dstAccess, srcDomain,
905                                            numCommonLoops)) {
906     return DependenceResult::NoDependence;
907   }
908   // Build dim and symbol position maps for each access from access operand
909   // Value to position in merged constraint system.
910   ValuePositionMap valuePosMap;
911   buildDimAndSymbolPositionMaps(srcDomain, dstDomain, srcAccessMap,
912                                 dstAccessMap, &valuePosMap,
913                                 dependenceConstraints);
914   initDependenceConstraints(srcDomain, dstDomain, srcAccessMap, dstAccessMap,
915                             valuePosMap, dependenceConstraints);
916 
917   assert(valuePosMap.getNumDims() ==
918          srcDomain.getNumDimIds() + dstDomain.getNumDimIds());
919 
920   // Create memref access constraint by equating src/dst access functions.
921   // Note that this check is conservative, and will fail in the future when
922   // local variables for mod/div exprs are supported.
923   if (failed(addMemRefAccessConstraints(srcAccessMap, dstAccessMap, valuePosMap,
924                                         dependenceConstraints)))
925     return DependenceResult::Failure;
926 
927   // Add 'src' happens before 'dst' ordering constraints.
928   addOrderingConstraints(srcDomain, dstDomain, loopDepth,
929                          dependenceConstraints);
930   // Add src and dst domain constraints.
931   addDomainConstraints(srcDomain, dstDomain, valuePosMap,
932                        dependenceConstraints);
933 
934   // Return 'NoDependence' if the solution space is empty: no dependence.
935   if (dependenceConstraints->isEmpty()) {
936     return DependenceResult::NoDependence;
937   }
938 
939   // Compute dependence direction vector and return true.
940   if (dependenceComponents != nullptr) {
941     computeDirectionVector(srcDomain, dstDomain, loopDepth,
942                            dependenceConstraints, dependenceComponents);
943   }
944 
945   LLVM_DEBUG(llvm::dbgs() << "Dependence polyhedron:\n");
946   LLVM_DEBUG(dependenceConstraints->dump());
947   return DependenceResult::HasDependence;
948 }
949 
950 /// Gathers dependence components for dependences between all ops in loop nest
951 /// rooted at 'forOp' at loop depths in range [1, maxLoopDepth].
getDependenceComponents(AffineForOp forOp,unsigned maxLoopDepth,std::vector<SmallVector<DependenceComponent,2>> * depCompsVec)952 void mlir::getDependenceComponents(
953     AffineForOp forOp, unsigned maxLoopDepth,
954     std::vector<SmallVector<DependenceComponent, 2>> *depCompsVec) {
955   // Collect all load and store ops in loop nest rooted at 'forOp'.
956   SmallVector<Operation *, 8> loadAndStoreOps;
957   forOp->walk([&](Operation *op) {
958     if (isa<AffineReadOpInterface, AffineWriteOpInterface>(op))
959       loadAndStoreOps.push_back(op);
960   });
961 
962   unsigned numOps = loadAndStoreOps.size();
963   for (unsigned d = 1; d <= maxLoopDepth; ++d) {
964     for (unsigned i = 0; i < numOps; ++i) {
965       auto *srcOp = loadAndStoreOps[i];
966       MemRefAccess srcAccess(srcOp);
967       for (unsigned j = 0; j < numOps; ++j) {
968         auto *dstOp = loadAndStoreOps[j];
969         MemRefAccess dstAccess(dstOp);
970 
971         FlatAffineConstraints dependenceConstraints;
972         SmallVector<DependenceComponent, 2> depComps;
973         // TODO: Explore whether it would be profitable to pre-compute and store
974         // deps instead of repeatedly checking.
975         DependenceResult result = checkMemrefAccessDependence(
976             srcAccess, dstAccess, d, &dependenceConstraints, &depComps);
977         if (hasDependence(result))
978           depCompsVec->push_back(depComps);
979       }
980     }
981   }
982 }
983