1 //===- DependenceAnalysis.cpp - Dependence analysis on SSA views ----------===//
2 //
3 // Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions.
4 // See https://llvm.org/LICENSE.txt for license information.
5 // SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception
6 //
7 //===----------------------------------------------------------------------===//
8 //
9 // This file implements view-based alias and dependence analyses.
10 //
11 //===----------------------------------------------------------------------===//
12 
13 #include "mlir/Dialect/Linalg/Analysis/DependenceAnalysis.h"
14 #include "mlir/Dialect/Linalg/IR/LinalgOps.h"
15 #include "mlir/Dialect/StandardOps/IR/Ops.h"
16 
17 #include "llvm/Support/CommandLine.h"
18 #include "llvm/Support/Debug.h"
19 
20 #define DEBUG_TYPE "linalg-dependence-analysis"
21 
22 using namespace mlir;
23 using namespace mlir::linalg;
24 
25 using llvm::dbgs;
26 
find(Value v)27 Value Aliases::find(Value v) {
28   if (v.isa<BlockArgument>())
29     return v;
30 
31   auto it = aliases.find(v);
32   if (it != aliases.end()) {
33     assert(it->getSecond().getType().isa<BaseMemRefType>() &&
34            "Memref expected");
35     return it->getSecond();
36   }
37 
38   while (true) {
39     if (v.isa<BlockArgument>())
40       return v;
41 
42     Operation *defOp = v.getDefiningOp();
43     if (!defOp)
44       return v;
45 
46     if (isa<TensorToMemrefOp>(defOp))
47       return v;
48 
49     if (auto memEffect = dyn_cast<MemoryEffectOpInterface>(defOp)) {
50       // Collect all memory effects on `v`.
51       SmallVector<MemoryEffects::EffectInstance, 1> effects;
52       memEffect.getEffectsOnValue(v, effects);
53 
54       // If we have the 'Allocate' memory effect on `v`, then `v` should be the
55       // original buffer.
56       if (llvm::any_of(
57               effects, [](const MemoryEffects::EffectInstance &instance) {
58                 return isa<MemoryEffects::Allocate>(instance.getEffect());
59               }))
60         return v;
61     }
62 
63     if (auto viewLikeOp = dyn_cast<ViewLikeOpInterface>(defOp)) {
64       auto it =
65           aliases.insert(std::make_pair(v, find(viewLikeOp.getViewSource())));
66       return it.first->second;
67     }
68 
69     llvm::errs() << "View alias analysis reduces to: " << v << "\n";
70     llvm_unreachable("unsupported view alias case");
71   }
72 }
73 
getDependenceTypeStr(DependenceType depType)74 StringRef LinalgDependenceGraph::getDependenceTypeStr(DependenceType depType) {
75   switch (depType) {
76   case LinalgDependenceGraph::DependenceType::RAW:
77     return "RAW";
78   case LinalgDependenceGraph::DependenceType::RAR:
79     return "RAR";
80   case LinalgDependenceGraph::DependenceType::WAR:
81     return "WAR";
82   case LinalgDependenceGraph::DependenceType::WAW:
83     return "WAW";
84   default:
85     break;
86   }
87   llvm_unreachable("Unexpected DependenceType");
88 }
89 
90 LinalgDependenceGraph
buildDependenceGraph(Aliases & aliases,FuncOp f)91 LinalgDependenceGraph::buildDependenceGraph(Aliases &aliases, FuncOp f) {
92   SmallVector<LinalgOp, 8> linalgOps;
93   f.walk([&](LinalgOp op) { linalgOps.push_back(op); });
94   return LinalgDependenceGraph(aliases, linalgOps);
95 }
96 
LinalgDependenceGraph(Aliases & aliases,ArrayRef<LinalgOp> ops)97 LinalgDependenceGraph::LinalgDependenceGraph(Aliases &aliases,
98                                              ArrayRef<LinalgOp> ops)
99     : aliases(aliases), linalgOps(ops.begin(), ops.end()) {
100   for (auto en : llvm::enumerate(linalgOps)) {
101     linalgOpPositions.insert(
102         std::make_pair(en.value().getOperation(), en.index()));
103   }
104   for (unsigned i = 0, e = ops.size(); i < e; ++i) {
105     for (unsigned j = i + 1; j < e; ++j) {
106       addDependencesBetween(ops[i], ops[j]);
107     }
108   }
109 }
110 
addDependenceElem(DependenceType dt,LinalgOpView indexingOpView,LinalgOpView dependentOpView)111 void LinalgDependenceGraph::addDependenceElem(DependenceType dt,
112                                               LinalgOpView indexingOpView,
113                                               LinalgOpView dependentOpView) {
114   LLVM_DEBUG(dbgs() << "\nAdd dep type " << getDependenceTypeStr(dt) << ":\t ("
115                     << *indexingOpView.op << ", " << indexingOpView.operandIndex
116                     << ") -> \n\t\t(" << *dependentOpView.op << ", "
117                     << dependentOpView.operandIndex << ")");
118   dependencesFromGraphs[dt][indexingOpView.op].push_back(
119       LinalgDependenceGraphElem{dependentOpView, indexingOpView, dt});
120   dependencesIntoGraphs[dt][dependentOpView.op].push_back(
121       LinalgDependenceGraphElem{indexingOpView, dependentOpView, dt});
122 }
123 
124 LinalgDependenceGraph::dependence_range
getDependencesFrom(LinalgOp src,LinalgDependenceGraph::DependenceType dt) const125 LinalgDependenceGraph::getDependencesFrom(
126     LinalgOp src, LinalgDependenceGraph::DependenceType dt) const {
127   return getDependencesFrom(src.getOperation(), dt);
128 }
129 
130 LinalgDependenceGraph::dependence_range
getDependencesFrom(Operation * src,LinalgDependenceGraph::DependenceType dt) const131 LinalgDependenceGraph::getDependencesFrom(
132     Operation *src, LinalgDependenceGraph::DependenceType dt) const {
133   auto iter = dependencesFromGraphs[dt].find(src);
134   if (iter == dependencesFromGraphs[dt].end())
135     return llvm::make_range(nullptr, nullptr);
136   return llvm::make_range(iter->second.begin(), iter->second.end());
137 }
138 
139 LinalgDependenceGraph::dependence_range
getDependencesInto(LinalgOp dst,LinalgDependenceGraph::DependenceType dt) const140 LinalgDependenceGraph::getDependencesInto(
141     LinalgOp dst, LinalgDependenceGraph::DependenceType dt) const {
142   return getDependencesInto(dst.getOperation(), dt);
143 }
144 
145 LinalgDependenceGraph::dependence_range
getDependencesInto(Operation * dst,LinalgDependenceGraph::DependenceType dt) const146 LinalgDependenceGraph::getDependencesInto(
147     Operation *dst, LinalgDependenceGraph::DependenceType dt) const {
148   auto iter = dependencesIntoGraphs[dt].find(dst);
149   if (iter == dependencesIntoGraphs[dt].end())
150     return llvm::make_range(nullptr, nullptr);
151   return llvm::make_range(iter->second.begin(), iter->second.end());
152 }
153 
addDependencesBetween(LinalgOp src,LinalgOp dst)154 void LinalgDependenceGraph::addDependencesBetween(LinalgOp src, LinalgOp dst) {
155   for (auto srcView : llvm::enumerate(src.getOutputBuffers())) { // W
156     unsigned srcIndex =
157         src.getOperandIndexForOutputIndex(srcView.index()).getValue();
158     // RAW graph
159     for (auto dstView : llvm::enumerate(dst.getInputBuffers())) { // R
160       if (aliases.alias(srcView.value(),
161                         dstView.value())) { // if alias, fill RAW
162         unsigned dstIndex =
163             dst.getOperandIndexForInputIndex(dstView.index()).getValue();
164         addDependenceElem(DependenceType::RAW,
165                           LinalgOpView{src.getOperation(), srcIndex},
166                           LinalgOpView{dst.getOperation(), dstIndex});
167       }
168     }
169     // WAW graph
170     for (auto dstView : llvm::enumerate(dst.getOutputBuffers())) { // W
171       if (aliases.alias(srcView.value(),
172                         dstView.value())) { // if alias, fill WAW
173         unsigned dstIndex =
174             dst.getOperandIndexForOutputIndex(dstView.index()).getValue();
175         addDependenceElem(DependenceType::WAW,
176                           LinalgOpView{src.getOperation(), srcIndex},
177                           LinalgOpView{dst.getOperation(), dstIndex});
178       }
179     }
180   }
181   for (auto srcView : llvm::enumerate(src.getInputBuffers())) { // R
182     unsigned srcIndex =
183         src.getOperandIndexForInputIndex(srcView.index()).getValue();
184     // RAR graph
185     for (auto dstView : llvm::enumerate(dst.getInputBuffers())) { // R
186       if (aliases.alias(srcView.value(),
187                         dstView.value())) { // if alias, fill RAR
188         unsigned dstIndex =
189             dst.getOperandIndexForInputIndex(dstView.index()).getValue();
190         addDependenceElem(DependenceType::RAR,
191                           LinalgOpView{src.getOperation(), srcIndex},
192                           LinalgOpView{dst.getOperation(), dstIndex});
193       }
194     }
195     // WAR graph
196     for (auto dstView : llvm::enumerate(dst.getOutputBuffers())) { // W
197       if (aliases.alias(srcView.value(),
198                         dstView.value())) { // if alias, fill WAR
199         unsigned dstIndex =
200             dst.getOperandIndexForOutputIndex(dstView.index()).getValue();
201         addDependenceElem(DependenceType::WAR,
202                           LinalgOpView{src.getOperation(), srcIndex},
203                           LinalgOpView{dst.getOperation(), dstIndex});
204       }
205     }
206   }
207 }
208 
209 SmallVector<Operation *, 8>
findCoveringDependences(LinalgOp srcLinalgOp,LinalgOp dstLinalgOp) const210 LinalgDependenceGraph::findCoveringDependences(LinalgOp srcLinalgOp,
211                                                LinalgOp dstLinalgOp) const {
212   return findOperationsWithCoveringDependences(
213       srcLinalgOp, dstLinalgOp, nullptr,
214       {DependenceType::WAW, DependenceType::WAR, DependenceType::RAW});
215 }
216 
findCoveringWrites(LinalgOp srcLinalgOp,LinalgOp dstLinalgOp,Value view) const217 SmallVector<Operation *, 8> LinalgDependenceGraph::findCoveringWrites(
218     LinalgOp srcLinalgOp, LinalgOp dstLinalgOp, Value view) const {
219   return findOperationsWithCoveringDependences(
220       srcLinalgOp, dstLinalgOp, view,
221       {DependenceType::WAW, DependenceType::WAR});
222 }
223 
findCoveringReads(LinalgOp srcLinalgOp,LinalgOp dstLinalgOp,Value view) const224 SmallVector<Operation *, 8> LinalgDependenceGraph::findCoveringReads(
225     LinalgOp srcLinalgOp, LinalgOp dstLinalgOp, Value view) const {
226   return findOperationsWithCoveringDependences(
227       srcLinalgOp, dstLinalgOp, view,
228       {DependenceType::RAR, DependenceType::RAW});
229 }
230 
231 SmallVector<Operation *, 8>
findOperationsWithCoveringDependences(LinalgOp srcLinalgOp,LinalgOp dstLinalgOp,Value view,ArrayRef<DependenceType> types) const232 LinalgDependenceGraph::findOperationsWithCoveringDependences(
233     LinalgOp srcLinalgOp, LinalgOp dstLinalgOp, Value view,
234     ArrayRef<DependenceType> types) const {
235   auto *src = srcLinalgOp.getOperation();
236   auto *dst = dstLinalgOp.getOperation();
237   auto srcPos = linalgOpPositions.lookup(src);
238   auto dstPos = linalgOpPositions.lookup(dst);
239   assert(srcPos < dstPos && "expected dst after src in IR traversal order");
240 
241   SmallVector<Operation *, 8> res;
242   // Consider an intermediate interleaved `interim` op, look for any dependence
243   // to an aliasing view on a src -> op -> dst path.
244   // TODO: we are not considering paths yet, just interleaved positions.
245   for (auto dt : types) {
246     for (auto dependence : getDependencesFrom(src, dt)) {
247       auto interimPos = linalgOpPositions.lookup(dependence.dependentOpView.op);
248       // Skip if not interleaved.
249       if (interimPos >= dstPos || interimPos <= srcPos)
250         continue;
251       linalg::LinalgOp consumer =
252           cast<linalg::LinalgOp>(dependence.indexingOpView.op);
253       Value consumerView =
254           consumer.getShapedOperand(dependence.indexingOpView.operandIndex);
255       if (view && !aliases.alias(view, consumerView))
256         continue;
257       auto *op = dependence.dependentOpView.op;
258       LLVM_DEBUG(dbgs() << "\n***Found covering dependence of type "
259                         << getDependenceTypeStr(dt) << ": " << *src << " -> "
260                         << *op << " on " << consumerView);
261       res.push_back(op);
262     }
263   }
264   return res;
265 }
266 
hasDependenceFrom(LinalgOp srcLinalgOp,LinalgOp dstLinalgOp,ArrayRef<LinalgDependenceGraph::DependenceType> depTypes) const267 bool LinalgDependenceGraph::hasDependenceFrom(
268     LinalgOp srcLinalgOp, LinalgOp dstLinalgOp,
269     ArrayRef<LinalgDependenceGraph::DependenceType> depTypes) const {
270   for (auto dep : depTypes) {
271     for (auto dependence : getDependencesInto(dstLinalgOp, dep)) {
272       if (dependence.dependentOpView.op == srcLinalgOp)
273         return true;
274     }
275   }
276   return false;
277 }
278 
hasDependentOperationsFrom(LinalgOp linalgOp,ArrayRef<LinalgDependenceGraph::DependenceType> depTypes) const279 bool LinalgDependenceGraph::hasDependentOperationsFrom(
280     LinalgOp linalgOp,
281     ArrayRef<LinalgDependenceGraph::DependenceType> depTypes) const {
282   for (auto dep : depTypes) {
283     if (!getDependencesFrom(linalgOp, dep).empty())
284       return true;
285   }
286   return false;
287 }
288 
hasDependentOperationsInto(LinalgOp linalgOp,ArrayRef<LinalgDependenceGraph::DependenceType> depTypes) const289 bool LinalgDependenceGraph::hasDependentOperationsInto(
290     LinalgOp linalgOp,
291     ArrayRef<LinalgDependenceGraph::DependenceType> depTypes) const {
292   for (auto dep : depTypes) {
293     if (!getDependencesInto(linalgOp, dep).empty())
294       return true;
295   }
296   return false;
297 }
298 
hasDependentOperations(LinalgOp linalgOp,ArrayRef<DependenceType> depTypes) const299 bool LinalgDependenceGraph::hasDependentOperations(
300     LinalgOp linalgOp, ArrayRef<DependenceType> depTypes) const {
301   return hasDependentOperationsInto(linalgOp, depTypes) ||
302          hasDependentOperationsFrom(linalgOp, depTypes);
303 }
304 
305 SmallVector<LinalgDependenceGraph::LinalgDependenceGraphElem, 2>
getDependentOperationsInto(LinalgOp linalgOp,ArrayRef<DependenceType> depTypes) const306 LinalgDependenceGraph::getDependentOperationsInto(
307     LinalgOp linalgOp, ArrayRef<DependenceType> depTypes) const {
308   SmallVector<LinalgDependenceGraph::LinalgDependenceGraphElem, 2>
309       dependentOperations;
310   for (auto dependenceType : depTypes) {
311     auto dependencies = getDependencesInto(linalgOp, dependenceType);
312     dependentOperations.append(dependencies.begin(), dependencies.end());
313   }
314   return dependentOperations;
315 }
316 
317 SmallVector<LinalgDependenceGraph::LinalgDependenceGraphElem, 2>
getDependentOperationsFrom(LinalgOp linalgOp,ArrayRef<DependenceType> depTypes) const318 LinalgDependenceGraph::getDependentOperationsFrom(
319     LinalgOp linalgOp, ArrayRef<DependenceType> depTypes) const {
320   SmallVector<LinalgDependenceGraph::LinalgDependenceGraphElem, 2>
321       dependentOperations;
322   for (auto dependenceType : depTypes) {
323     auto dependencies = getDependencesFrom(linalgOp, dependenceType);
324     dependentOperations.append(dependencies.begin(), dependencies.end());
325   }
326   return dependentOperations;
327 }
328 
329 /// Returns all dependent operations (into and from) given `operation`.
330 SmallVector<LinalgDependenceGraph::LinalgDependenceGraphElem, 2>
getDependentOperations(LinalgOp linalgOp,ArrayRef<DependenceType> depTypes) const331 LinalgDependenceGraph::getDependentOperations(
332     LinalgOp linalgOp, ArrayRef<DependenceType> depTypes) const {
333   SmallVector<LinalgDependenceGraphElem, 2> dependentOperations =
334       getDependentOperationsInto(linalgOp, depTypes);
335   SmallVector<LinalgDependenceGraphElem, 2> t =
336       getDependentOperationsFrom(linalgOp, depTypes);
337   dependentOperations.append(t.begin(), t.end());
338   return dependentOperations;
339 }
340