1 //===- LoopUtils.cpp ---- Misc utilities for loop transformation ----------===//
2 //
3 // Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions.
4 // See https://llvm.org/LICENSE.txt for license information.
5 // SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception
6 //
7 //===----------------------------------------------------------------------===//
8 //
9 // This file implements miscellaneous loop transformation routines.
10 //
11 //===----------------------------------------------------------------------===//
12 
13 #include "mlir/Transforms/LoopUtils.h"
14 
15 #include "mlir/Analysis/AffineAnalysis.h"
16 #include "mlir/Analysis/LoopAnalysis.h"
17 #include "mlir/Analysis/SliceAnalysis.h"
18 #include "mlir/Analysis/Utils.h"
19 #include "mlir/Dialect/Affine/IR/AffineOps.h"
20 #include "mlir/Dialect/Affine/IR/AffineValueMap.h"
21 #include "mlir/Dialect/SCF/SCF.h"
22 #include "mlir/IR/AffineMap.h"
23 #include "mlir/IR/BlockAndValueMapping.h"
24 #include "mlir/IR/BuiltinOps.h"
25 #include "mlir/IR/IntegerSet.h"
26 #include "mlir/Support/MathExtras.h"
27 #include "mlir/Transforms/GreedyPatternRewriteDriver.h"
28 #include "mlir/Transforms/RegionUtils.h"
29 #include "mlir/Transforms/Utils.h"
30 #include "llvm/ADT/DenseMap.h"
31 #include "llvm/ADT/MapVector.h"
32 #include "llvm/ADT/SetVector.h"
33 #include "llvm/ADT/SmallPtrSet.h"
34 #include "llvm/Support/Debug.h"
35 #include "llvm/Support/raw_ostream.h"
36 
37 #define DEBUG_TYPE "LoopUtils"
38 
39 using namespace mlir;
40 using llvm::SetVector;
41 using llvm::SmallMapVector;
42 
43 namespace {
44 // This structure is to pass and return sets of loop parameters without
45 // confusing the order.
46 struct LoopParams {
47   Value lowerBound;
48   Value upperBound;
49   Value step;
50 };
51 } // namespace
52 
53 /// Computes the cleanup loop lower bound of the loop being unrolled with
54 /// the specified unroll factor; this bound will also be upper bound of the main
55 /// part of the unrolled loop. Computes the bound as an AffineMap with its
56 /// operands or a null map when the trip count can't be expressed as an affine
57 /// expression.
getCleanupLoopLowerBound(AffineForOp forOp,unsigned unrollFactor,AffineMap & map,SmallVectorImpl<Value> & operands)58 static void getCleanupLoopLowerBound(AffineForOp forOp, unsigned unrollFactor,
59                                      AffineMap &map,
60                                      SmallVectorImpl<Value> &operands) {
61   auto lbMap = forOp.getLowerBoundMap();
62 
63   // Single result lower bound map only.
64   if (lbMap.getNumResults() != 1) {
65     map = AffineMap();
66     return;
67   }
68 
69   AffineMap tripCountMap;
70   SmallVector<Value, 4> tripCountOperands;
71   buildTripCountMapAndOperands(forOp, &tripCountMap, &tripCountOperands);
72 
73   // Sometimes the trip count cannot be expressed as an affine expression.
74   if (!tripCountMap) {
75     map = AffineMap();
76     return;
77   }
78 
79   OpBuilder b(forOp);
80   auto lb = b.create<AffineApplyOp>(forOp.getLoc(), lbMap,
81                                     forOp.getLowerBoundOperands());
82 
83   // For each upper bound expr, get the range.
84   // Eg: affine.for %i = lb to min (ub1, ub2),
85   // where tripCountExprs yield (tr1, tr2), we create affine.apply's:
86   // lb + tr1 - tr1 % ufactor, lb + tr2 - tr2 % ufactor; the results of all
87   // these affine.apply's make up the cleanup loop lower bound.
88   SmallVector<AffineExpr, 4> bumpExprs(tripCountMap.getNumResults());
89   SmallVector<Value, 4> bumpValues(tripCountMap.getNumResults());
90   int64_t step = forOp.getStep();
91   for (unsigned i = 0, e = tripCountMap.getNumResults(); i < e; i++) {
92     auto tripCountExpr = tripCountMap.getResult(i);
93     bumpExprs[i] = (tripCountExpr - tripCountExpr % unrollFactor) * step;
94     auto bumpMap = AffineMap::get(tripCountMap.getNumDims(),
95                                   tripCountMap.getNumSymbols(), bumpExprs[i]);
96     bumpValues[i] =
97         b.create<AffineApplyOp>(forOp.getLoc(), bumpMap, tripCountOperands);
98   }
99 
100   SmallVector<AffineExpr, 4> newUbExprs(tripCountMap.getNumResults());
101   for (unsigned i = 0, e = bumpExprs.size(); i < e; i++)
102     newUbExprs[i] = b.getAffineDimExpr(0) + b.getAffineDimExpr(i + 1);
103 
104   operands.clear();
105   operands.push_back(lb);
106   operands.append(bumpValues.begin(), bumpValues.end());
107   map = AffineMap::get(1 + tripCountMap.getNumResults(), 0, newUbExprs,
108                        b.getContext());
109   // Simplify the map + operands.
110   fullyComposeAffineMapAndOperands(&map, &operands);
111   map = simplifyAffineMap(map);
112   canonicalizeMapAndOperands(&map, &operands);
113   // Remove any affine.apply's that became dead from the simplification above.
114   for (auto v : bumpValues)
115     if (v.use_empty())
116       v.getDefiningOp()->erase();
117 
118   if (lb.use_empty())
119     lb.erase();
120 }
121 
122 // Build the IR that performs ceil division of a positive value by a constant:
123 //    ceildiv(a, B) = divis(a + (B-1), B)
124 // where divis is rounding-to-zero division.
ceilDivPositive(OpBuilder & builder,Location loc,Value dividend,int64_t divisor)125 static Value ceilDivPositive(OpBuilder &builder, Location loc, Value dividend,
126                              int64_t divisor) {
127   assert(divisor > 0 && "expected positive divisor");
128   assert(dividend.getType().isIndex() && "expected index-typed value");
129 
130   Value divisorMinusOneCst = builder.create<ConstantIndexOp>(loc, divisor - 1);
131   Value divisorCst = builder.create<ConstantIndexOp>(loc, divisor);
132   Value sum = builder.create<AddIOp>(loc, dividend, divisorMinusOneCst);
133   return builder.create<SignedDivIOp>(loc, sum, divisorCst);
134 }
135 
136 // Build the IR that performs ceil division of a positive value by another
137 // positive value:
138 //    ceildiv(a, b) = divis(a + (b - 1), b)
139 // where divis is rounding-to-zero division.
ceilDivPositive(OpBuilder & builder,Location loc,Value dividend,Value divisor)140 static Value ceilDivPositive(OpBuilder &builder, Location loc, Value dividend,
141                              Value divisor) {
142   assert(dividend.getType().isIndex() && "expected index-typed value");
143 
144   Value cstOne = builder.create<ConstantIndexOp>(loc, 1);
145   Value divisorMinusOne = builder.create<SubIOp>(loc, divisor, cstOne);
146   Value sum = builder.create<AddIOp>(loc, dividend, divisorMinusOne);
147   return builder.create<SignedDivIOp>(loc, sum, divisor);
148 }
149 
150 /// Promotes the loop body of a forOp to its containing block if the forOp
151 /// was known to have a single iteration.
152 // TODO: extend this for arbitrary affine bounds.
promoteIfSingleIteration(AffineForOp forOp)153 LogicalResult mlir::promoteIfSingleIteration(AffineForOp forOp) {
154   Optional<uint64_t> tripCount = getConstantTripCount(forOp);
155   if (!tripCount || tripCount.getValue() != 1)
156     return failure();
157 
158   if (forOp.getLowerBoundMap().getNumResults() != 1)
159     return failure();
160 
161   // Replaces all IV uses to its single iteration value.
162   auto iv = forOp.getInductionVar();
163   auto *parentBlock = forOp->getBlock();
164   if (!iv.use_empty()) {
165     if (forOp.hasConstantLowerBound()) {
166       OpBuilder topBuilder(forOp->getParentOfType<FuncOp>().getBody());
167       auto constOp = topBuilder.create<ConstantIndexOp>(
168           forOp.getLoc(), forOp.getConstantLowerBound());
169       iv.replaceAllUsesWith(constOp);
170     } else {
171       auto lbOperands = forOp.getLowerBoundOperands();
172       auto lbMap = forOp.getLowerBoundMap();
173       OpBuilder builder(parentBlock, Block::iterator(forOp));
174       if (lbMap == builder.getDimIdentityMap()) {
175         // No need of generating an affine.apply.
176         iv.replaceAllUsesWith(lbOperands[0]);
177       } else {
178         auto affineApplyOp =
179             builder.create<AffineApplyOp>(forOp.getLoc(), lbMap, lbOperands);
180         iv.replaceAllUsesWith(affineApplyOp);
181       }
182     }
183   }
184   // Move the loop body operations, except for its terminator, to the loop's
185   // containing block.
186   forOp.getBody()->back().erase();
187   parentBlock->getOperations().splice(Block::iterator(forOp),
188                                       forOp.getBody()->getOperations());
189   forOp.erase();
190   return success();
191 }
192 
193 /// Promotes the loop body of a forOp to its containing block if the forOp
194 /// it can be determined that the loop has a single iteration.
promoteIfSingleIteration(scf::ForOp forOp)195 LogicalResult mlir::promoteIfSingleIteration(scf::ForOp forOp) {
196   auto lbCstOp = forOp.lowerBound().getDefiningOp<ConstantIndexOp>();
197   auto ubCstOp = forOp.upperBound().getDefiningOp<ConstantIndexOp>();
198   auto stepCstOp = forOp.step().getDefiningOp<ConstantIndexOp>();
199   if (!lbCstOp || !ubCstOp || !stepCstOp || lbCstOp.getValue() < 0 ||
200       ubCstOp.getValue() < 0 || stepCstOp.getValue() < 0)
201     return failure();
202   int64_t tripCount = mlir::ceilDiv(ubCstOp.getValue() - lbCstOp.getValue(),
203                                     stepCstOp.getValue());
204   if (tripCount != 1)
205     return failure();
206   auto iv = forOp.getInductionVar();
207   iv.replaceAllUsesWith(lbCstOp);
208 
209   // Replace uses of iterArgs with iterOperands.
210   auto iterOperands = forOp.getIterOperands();
211   auto iterArgs = forOp.getRegionIterArgs();
212   for (auto e : llvm::zip(iterOperands, iterArgs))
213     std::get<1>(e).replaceAllUsesWith(std::get<0>(e));
214 
215   // Replace uses of loop results with the values yielded by the loop.
216   auto outerResults = forOp.getResults();
217   auto innerResults = forOp.getBody()->getTerminator()->getOperands();
218   for (auto e : llvm::zip(outerResults, innerResults))
219     std::get<0>(e).replaceAllUsesWith(std::get<1>(e));
220 
221   // Move the loop body operations, except for its terminator, to the loop's
222   // containing block.
223   auto *parentBlock = forOp->getBlock();
224   forOp.getBody()->getTerminator()->erase();
225   parentBlock->getOperations().splice(Block::iterator(forOp),
226                                       forOp.getBody()->getOperations());
227   forOp.erase();
228   return success();
229 }
230 
231 /// Promotes all single iteration 'for' ops in `f`, i.e., moves
232 /// their body into the containing Block.
promoteSingleIterationLoops(FuncOp f)233 void mlir::promoteSingleIterationLoops(FuncOp f) {
234   // Gathers all innermost loops through a post order pruned walk.
235   f.walk([](Operation *op) {
236     if (auto forOp = dyn_cast<AffineForOp>(op))
237       promoteIfSingleIteration(forOp);
238     else if (auto forOp = dyn_cast<scf::ForOp>(op))
239       promoteIfSingleIteration(forOp);
240   });
241 }
242 
243 /// Generates an affine.for op with the specified lower and upper bounds
244 /// while generating the right IV remappings to realize shifts for operations in
245 /// its body. The operations that go into the loop body are specified in
246 /// opGroupQueue starting from the specified offset, and in that order. The
247 /// first element of the pair specifies the shift applied to that group of
248 /// operations; the shift is multiplied by the loop step before being applied.
249 /// Returns nullptr if the generated loop simplifies to a single iteration one.
generateShiftedLoop(AffineMap lbMap,AffineMap ubMap,const std::vector<std::pair<uint64_t,ArrayRef<Operation * >>> & opGroupQueue,unsigned offset,AffineForOp srcForOp,OpBuilder b)250 static AffineForOp generateShiftedLoop(
251     AffineMap lbMap, AffineMap ubMap,
252     const std::vector<std::pair<uint64_t, ArrayRef<Operation *>>> &opGroupQueue,
253     unsigned offset, AffineForOp srcForOp, OpBuilder b) {
254   auto lbOperands = srcForOp.getLowerBoundOperands();
255   auto ubOperands = srcForOp.getUpperBoundOperands();
256 
257   assert(lbMap.getNumInputs() == lbOperands.size());
258   assert(ubMap.getNumInputs() == ubOperands.size());
259 
260   auto loopChunk = b.create<AffineForOp>(srcForOp.getLoc(), lbOperands, lbMap,
261                                          ubOperands, ubMap, srcForOp.getStep());
262   auto loopChunkIV = loopChunk.getInductionVar();
263   auto srcIV = srcForOp.getInductionVar();
264 
265   BlockAndValueMapping operandMap;
266 
267   auto bodyBuilder = OpBuilder::atBlockTerminator(loopChunk.getBody());
268   for (auto it = opGroupQueue.begin() + offset, e = opGroupQueue.end(); it != e;
269        ++it) {
270     uint64_t shift = it->first;
271     auto ops = it->second;
272     // All 'same shift' operations get added with their operands being
273     // remapped to results of cloned operations, and their IV used remapped.
274     // Generate the remapping if the shift is not zero: remappedIV = newIV -
275     // shift.
276     if (!srcIV.use_empty() && shift != 0) {
277       auto ivRemap = bodyBuilder.create<AffineApplyOp>(
278           srcForOp.getLoc(),
279           bodyBuilder.getSingleDimShiftAffineMap(
280               -static_cast<int64_t>(srcForOp.getStep() * shift)),
281           loopChunkIV);
282       operandMap.map(srcIV, ivRemap);
283     } else {
284       operandMap.map(srcIV, loopChunkIV);
285     }
286     for (auto *op : ops)
287       bodyBuilder.clone(*op, operandMap);
288   };
289   if (succeeded(promoteIfSingleIteration(loopChunk)))
290     return AffineForOp();
291   return loopChunk;
292 }
293 
294 // The skewing of operations with respect to one another can be used for
295 // example to allow overlap of asynchronous operations (such as DMA
296 // communication) with computation, or just relative shifting of operations
297 // for better register reuse, locality or parallelism. As such, the shifts are
298 // typically expected to be at most of the order of the number of operations.
299 // This method should not be used as a substitute for loop distribution/fission.
300 // This method uses an algorithm// in time linear in the number of operations
301 // in the body of the for loop - (using the 'sweep line' paradigm). This method
302 // asserts preservation of SSA dominance. A check for that as well as that for
303 // memory-based dependence preservation check rests with the users of this
304 // method.
affineForOpBodySkew(AffineForOp forOp,ArrayRef<uint64_t> shifts,bool unrollPrologueEpilogue)305 LogicalResult mlir::affineForOpBodySkew(AffineForOp forOp,
306                                         ArrayRef<uint64_t> shifts,
307                                         bool unrollPrologueEpilogue) {
308   assert(forOp.getBody()->getOperations().size() == shifts.size() &&
309          "too few/many shifts");
310   if (forOp.getBody()->begin() == std::prev(forOp.getBody()->end()))
311     return success();
312 
313   // If the trip counts aren't constant, we would need versioning and
314   // conditional guards (or context information to prevent such versioning). The
315   // better way to pipeline for such loops is to first tile them and extract
316   // constant trip count "full tiles" before applying this.
317   auto mayBeConstTripCount = getConstantTripCount(forOp);
318   if (!mayBeConstTripCount.hasValue()) {
319     LLVM_DEBUG(forOp.emitRemark("non-constant trip count loop not handled"));
320     return success();
321   }
322   uint64_t tripCount = mayBeConstTripCount.getValue();
323 
324   assert(isOpwiseShiftValid(forOp, shifts) &&
325          "shifts will lead to an invalid transformation\n");
326 
327   int64_t step = forOp.getStep();
328 
329   unsigned numChildOps = shifts.size();
330 
331   // Do a linear time (counting) sort for the shifts.
332   uint64_t maxShift = *std::max_element(shifts.begin(), shifts.end());
333   if (maxShift >= numChildOps) {
334     // Large shifts are not the typical use case.
335     forOp.emitWarning("not shifting because shifts are unrealistically large");
336     return success();
337   }
338 
339   // An array of operation groups sorted by shift amount; each group has all
340   // operations with the same shift in the order in which they appear in the
341   // body of the 'affine.for' op.
342   std::vector<std::vector<Operation *>> sortedOpGroups(maxShift + 1);
343   unsigned pos = 0;
344   for (auto &op : forOp.getBody()->without_terminator()) {
345     auto shift = shifts[pos++];
346     sortedOpGroups[shift].push_back(&op);
347   }
348 
349   // Unless the shifts have a specific pattern (which actually would be the
350   // common use case), prologue and epilogue are not meaningfully defined.
351   // Nevertheless, if 'unrollPrologueEpilogue' is set, we will treat the first
352   // loop generated as the prologue and the last as epilogue and unroll these
353   // fully.
354   AffineForOp prologue, epilogue;
355 
356   // Do a sweep over the sorted shifts while storing open groups in a
357   // vector, and generating loop portions as necessary during the sweep. A block
358   // of operations is paired with its shift.
359   std::vector<std::pair<uint64_t, ArrayRef<Operation *>>> opGroupQueue;
360 
361   auto origLbMap = forOp.getLowerBoundMap();
362   uint64_t lbShift = 0;
363   OpBuilder b(forOp);
364   for (uint64_t d = 0, e = sortedOpGroups.size(); d < e; ++d) {
365     // If nothing is shifted by d, continue.
366     if (sortedOpGroups[d].empty())
367       continue;
368     if (!opGroupQueue.empty()) {
369       assert(d > 0 &&
370              "Queue expected to be empty when the first block is found");
371       // The interval for which the loop needs to be generated here is:
372       // [lbShift, min(lbShift + tripCount, d)) and the body of the
373       // loop needs to have all operations in opQueue in that order.
374       AffineForOp res;
375       if (lbShift + tripCount * step < d * step) {
376         res = generateShiftedLoop(
377             b.getShiftedAffineMap(origLbMap, lbShift),
378             b.getShiftedAffineMap(origLbMap, lbShift + tripCount * step),
379             opGroupQueue, /*offset=*/0, forOp, b);
380         // Entire loop for the queued op groups generated, empty it.
381         opGroupQueue.clear();
382         lbShift += tripCount * step;
383       } else {
384         res = generateShiftedLoop(b.getShiftedAffineMap(origLbMap, lbShift),
385                                   b.getShiftedAffineMap(origLbMap, d),
386                                   opGroupQueue, /*offset=*/0, forOp, b);
387         lbShift = d * step;
388       }
389 
390       if (res) {
391         // Simplify/canonicalize the affine.for.
392         OwningRewritePatternList patterns;
393         AffineForOp::getCanonicalizationPatterns(patterns, res.getContext());
394         bool erased;
395         applyOpPatternsAndFold(res, std::move(patterns), &erased);
396 
397         if (!erased && !prologue)
398           prologue = res;
399         if (!erased)
400           epilogue = res;
401       }
402     } else {
403       // Start of first interval.
404       lbShift = d * step;
405     }
406     // Augment the list of operations that get into the current open interval.
407     opGroupQueue.push_back({d, sortedOpGroups[d]});
408   }
409 
410   // Those operations groups left in the queue now need to be processed (FIFO)
411   // and their loops completed.
412   for (unsigned i = 0, e = opGroupQueue.size(); i < e; ++i) {
413     uint64_t ubShift = (opGroupQueue[i].first + tripCount) * step;
414     epilogue = generateShiftedLoop(b.getShiftedAffineMap(origLbMap, lbShift),
415                                    b.getShiftedAffineMap(origLbMap, ubShift),
416                                    opGroupQueue, /*offset=*/i, forOp, b);
417     lbShift = ubShift;
418     if (!prologue)
419       prologue = epilogue;
420   }
421 
422   // Erase the original for op.
423   forOp.erase();
424 
425   if (unrollPrologueEpilogue && prologue)
426     loopUnrollFull(prologue);
427   if (unrollPrologueEpilogue && !epilogue && epilogue != prologue)
428     loopUnrollFull(epilogue);
429 
430   return success();
431 }
432 
433 /// Checks the legality of tiling of a hyper-rectangular loop nest by simply
434 /// checking if there is a 'negative' dependence in the memrefs present in
435 /// the loop nest. If yes then tiling is invalid.
436 static bool
checkTilingLegalityImpl(MutableArrayRef<mlir::AffineForOp> origLoops)437 checkTilingLegalityImpl(MutableArrayRef<mlir::AffineForOp> origLoops) {
438   assert(!origLoops.empty() && "no original loops provided");
439 
440   // We first find out all dependences we intend to check.
441   SmallVector<Operation *, 8> loadAndStoreOps;
442   origLoops[0]->walk([&](Operation *op) {
443     if (isa<AffineReadOpInterface, AffineWriteOpInterface>(op))
444       loadAndStoreOps.push_back(op);
445   });
446 
447   unsigned numOps = loadAndStoreOps.size();
448   unsigned numLoops = origLoops.size();
449   FlatAffineConstraints dependenceConstraints;
450   for (unsigned d = 1; d <= numLoops + 1; ++d) {
451     for (unsigned i = 0; i < numOps; ++i) {
452       Operation *srcOp = loadAndStoreOps[i];
453       MemRefAccess srcAccess(srcOp);
454       for (unsigned j = 0; j < numOps; ++j) {
455         Operation *dstOp = loadAndStoreOps[j];
456         MemRefAccess dstAccess(dstOp);
457 
458         SmallVector<DependenceComponent, 2> depComps;
459         dependenceConstraints.reset();
460         DependenceResult result = checkMemrefAccessDependence(
461             srcAccess, dstAccess, d, &dependenceConstraints, &depComps);
462 
463         // Skip if there is no dependence in this case.
464         if (!hasDependence(result))
465           continue;
466 
467         // Check whether there is any negative direction vector in the
468         // dependence components found above, which means that dependence is
469         // violated by the default hyper-rect tiling method.
470         LLVM_DEBUG(llvm::dbgs() << "Checking whether tiling legality violated "
471                                    "for dependence at depth: "
472                                 << Twine(d) << " between:\n";);
473         LLVM_DEBUG(srcAccess.opInst->dump(););
474         LLVM_DEBUG(dstAccess.opInst->dump(););
475         for (unsigned k = 0, e = depComps.size(); k < e; k++) {
476           DependenceComponent depComp = depComps[k];
477           if (depComp.lb.hasValue() && depComp.ub.hasValue() &&
478               depComp.lb.getValue() < depComp.ub.getValue() &&
479               depComp.ub.getValue() < 0) {
480             LLVM_DEBUG(llvm::dbgs()
481                        << "Dependence component lb = "
482                        << Twine(depComp.lb.getValue())
483                        << " ub = " << Twine(depComp.ub.getValue())
484                        << " is negative  at depth: " << Twine(d)
485                        << " and thus violates the legality rule.\n");
486             return false;
487           }
488         }
489       }
490     }
491   }
492 
493   return true;
494 }
495 
496 /// Checks whether hyper-rectangular loop tiling of the nest
497 /// represented by `origLoops` is valid. The validity condition is from Irigoin
498 /// and Triolet, which states that two tiles cannot depend on each other. We
499 /// simplify such condition to just checking whether there is any negative
500 /// dependence direction, since we have the prior knowledge that the tiling
501 /// results will be hyper-rectangles, which are scheduled in the
502 /// lexicographically increasing order on the vector of loop indices. This
503 /// function will return failure when any dependence component is negative along
504 /// any of `origLoops`.
505 LogicalResult
checkTilingLegality(MutableArrayRef<mlir::AffineForOp> origLoops)506 checkTilingLegality(MutableArrayRef<mlir::AffineForOp> origLoops) {
507   return success(checkTilingLegalityImpl(origLoops));
508 }
509 
510 /// Check if the input data is valid and wheter tiled code will be legal or not.
511 template <typename t>
performPreTilingChecks(MutableArrayRef<AffineForOp> input,ArrayRef<t> tileSizes)512 void performPreTilingChecks(MutableArrayRef<AffineForOp> input,
513                             ArrayRef<t> tileSizes) {
514   // Check if the supplied for op's are all successively nested.
515   assert(!input.empty() && "no loops in input band");
516   assert(input.size() == tileSizes.size() && "Too few/many tile sizes");
517 
518   assert(isPerfectlyNested(input) && "input loops not perfectly nested");
519 
520   // Perform tiling legality test.
521   if (failed(checkTilingLegality(input)))
522     input[0].emitRemark("tiled code is illegal due to dependences");
523 }
524 
525 /// Move the loop body of AffineForOp 'src' from 'src' into the specified
526 /// location in destination's body, ignoring the terminator.
moveLoopBodyImpl(AffineForOp src,AffineForOp dest,Block::iterator loc)527 static void moveLoopBodyImpl(AffineForOp src, AffineForOp dest,
528                              Block::iterator loc) {
529   auto &ops = src.getBody()->getOperations();
530   dest.getBody()->getOperations().splice(loc, ops, ops.begin(),
531                                          std::prev(ops.end()));
532 }
533 
534 /// Move the loop body of AffineForOp 'src' from 'src' to the start of dest
535 /// body.
moveLoopBody(AffineForOp src,AffineForOp dest)536 void moveLoopBody(AffineForOp src, AffineForOp dest) {
537   moveLoopBodyImpl(src, dest, dest.getBody()->begin());
538 }
539 
540 /// Constructs tiled loop nest, without setting the loop bounds and move the
541 /// body of the original loop nest to the tiled loop nest.
constructTiledLoopNest(MutableArrayRef<AffineForOp> origLoops,AffineForOp rootAffineForOp,unsigned width,MutableArrayRef<AffineForOp> tiledLoops)542 void constructTiledLoopNest(MutableArrayRef<AffineForOp> origLoops,
543                             AffineForOp rootAffineForOp, unsigned width,
544                             MutableArrayRef<AffineForOp> tiledLoops) {
545   Location loc = rootAffineForOp.getLoc();
546 
547   // The outermost among the loops as we add more..
548   Operation *topLoop = rootAffineForOp.getOperation();
549   AffineForOp innermostPointLoop;
550 
551   // Add intra-tile (or point) loops.
552   for (unsigned i = 0; i < width; i++) {
553     OpBuilder b(topLoop);
554     // Loop bounds will be set later.
555     AffineForOp pointLoop = b.create<AffineForOp>(loc, 0, 0);
556     pointLoop.getBody()->getOperations().splice(
557         pointLoop.getBody()->begin(), topLoop->getBlock()->getOperations(),
558         topLoop);
559     tiledLoops[2 * width - 1 - i] = pointLoop;
560     topLoop = pointLoop.getOperation();
561     if (i == 0)
562       innermostPointLoop = pointLoop;
563   }
564 
565   // Add tile space loops;
566   for (unsigned i = width; i < 2 * width; i++) {
567     OpBuilder b(topLoop);
568     // Loop bounds will be set later.
569     AffineForOp tileSpaceLoop = b.create<AffineForOp>(loc, 0, 0);
570     tileSpaceLoop.getBody()->getOperations().splice(
571         tileSpaceLoop.getBody()->begin(), topLoop->getBlock()->getOperations(),
572         topLoop);
573     tiledLoops[2 * width - i - 1] = tileSpaceLoop;
574     topLoop = tileSpaceLoop.getOperation();
575   }
576 
577   // Move the loop body of the original nest to the new one.
578   moveLoopBody(origLoops.back(), innermostPointLoop);
579 }
580 
581 /// Checks whether a loop nest is hyper-rectangular or not.
checkIfHyperRectangular(MutableArrayRef<AffineForOp> input,AffineForOp rootAffineForOp,unsigned width)582 LogicalResult checkIfHyperRectangular(MutableArrayRef<AffineForOp> input,
583                                       AffineForOp rootAffineForOp,
584                                       unsigned width) {
585   FlatAffineConstraints cst;
586   SmallVector<Operation *, 8> ops(input.begin(), input.end());
587   getIndexSet(ops, &cst);
588   if (!cst.isHyperRectangular(0, width)) {
589     rootAffineForOp.emitError("tiled code generation unimplemented for the "
590                               "non-hyperrectangular case");
591     return failure();
592   }
593   return success();
594 }
595 
596 /// Set lower and upper bounds of intra-tile loops for parametric tiling.
597 //  TODO: Handle non-constant lower bounds.
setIntraTileBoundsParametric(OpBuilder & b,AffineForOp origLoop,AffineForOp newInterTileLoop,AffineForOp newIntraTileLoop,Value tileSize)598 static void setIntraTileBoundsParametric(OpBuilder &b, AffineForOp origLoop,
599                                          AffineForOp newInterTileLoop,
600                                          AffineForOp newIntraTileLoop,
601                                          Value tileSize) {
602   // The lower bound for the intra-tile loop is represented by an affine map
603   // as (%i, %t0)->((%i - %origlb) * %t0 + %origlb). Similarly, the upper bound
604   // for the intra-tile loop is represented by an affine map as (%i, %t0)->((%i
605   // - %origlb) * %t0) + (%t0 * %origLoopStep) + %origlb), where %i is loop IV
606   // of the corresponding inter-tile loop, %t0 is the corresponding tiling
607   // parameter, %origlb is lower bound and %origLoopStep is the loop step of the
608   // corresponding inter-tile loop.
609 
610   assert(origLoop.hasConstantLowerBound() &&
611          "expected input loops to have constant lower bound.");
612 
613   // Get lower bound of original loop as an affine expression.
614   AffineExpr origLowerBoundExpr;
615   origLowerBoundExpr =
616       b.getAffineConstantExpr(origLoop.getConstantLowerBound());
617 
618   // Add dim operands from original lower/upper bound.
619   SmallVector<Value, 4> lbOperands, ubOperands;
620   AffineBound lb = origLoop.getLowerBound();
621   AffineBound ub = origLoop.getUpperBound();
622   lbOperands.reserve(lb.getNumOperands() + 2);
623   ubOperands.reserve(ub.getNumOperands() + 2);
624   AffineMap origLbMap = lb.getMap();
625   AffineMap origUbMap = ub.getMap();
626   for (unsigned j = 0, e = origLbMap.getNumDims(); j < e; ++j)
627     lbOperands.push_back(lb.getOperand(j));
628   for (unsigned j = 0, e = origUbMap.getNumDims(); j < e; ++j)
629     ubOperands.push_back(ub.getOperand(j));
630 
631   // Add a new dim operand in lb/ubOperands corresponding to the origLoop
632   // IV.
633   lbOperands.push_back(newInterTileLoop.getInductionVar());
634   ubOperands.push_back(newInterTileLoop.getInductionVar());
635 
636   // Get loop IV as an affine expression for lower/upper bound. Size of
637   // lb/ubOperands is guaranteed to be atleast one.
638   AffineExpr lbLoopIvExpr = b.getAffineDimExpr(lbOperands.size() - 1);
639   AffineExpr ubLoopIvExpr = b.getAffineDimExpr(ubOperands.size() - 1);
640 
641   // Add symbol operands from original lower/upper bound.
642   for (unsigned j = 0, e = origLbMap.getNumSymbols(); j < e; ++j)
643     lbOperands.push_back(lb.getOperand(origLbMap.getNumDims() + j));
644   for (unsigned j = 0, e = origUbMap.getNumSymbols(); j < e; ++j)
645     ubOperands.push_back(ub.getOperand(origUbMap.getNumDims() + j));
646 
647   // Add a new symbol operand which is the tile size for this loop.
648   lbOperands.push_back(tileSize);
649   ubOperands.push_back(tileSize);
650 
651   SmallVector<AffineExpr, 4> lbBoundExprs;
652   SmallVector<AffineExpr, 4> ubBoundExprs;
653   lbBoundExprs.reserve(origLbMap.getNumResults());
654   ubBoundExprs.reserve(origUbMap.getNumResults());
655 
656   // Get tiling parameter as an affine expression for lb/ub.
657   AffineExpr lbTileParameter = b.getAffineSymbolExpr(origLbMap.getNumSymbols());
658   AffineExpr ubTileParameter = b.getAffineSymbolExpr(origUbMap.getNumSymbols());
659 
660   // Insert lb as inter-tile ((loop IV - origlb) * tilingParameter) + origlb.
661   lbBoundExprs.push_back(
662       ((lbLoopIvExpr - origLowerBoundExpr) * lbTileParameter) +
663       origLowerBoundExpr);
664 
665   // Get the origLoopStep as an affine expression.
666   AffineExpr origLoopStep = b.getAffineConstantExpr(origLoop.getStep());
667 
668   // Insert ub as inter-tile ((loop IV - origlb) * tilingParameter) +
669   // (tilingParameter * origLoopStep) + origlb.
670   ubBoundExprs.push_back(
671       ((ubLoopIvExpr - origLowerBoundExpr) * ubTileParameter) +
672       (ubTileParameter * origLoopStep) + origLowerBoundExpr);
673 
674   ubBoundExprs.append(origUbMap.getResults().begin(),
675                       origUbMap.getResults().end());
676 
677   AffineMap lbMap =
678       AffineMap::get(origLbMap.getNumDims() + 1, origLbMap.getNumSymbols() + 1,
679                      lbBoundExprs, b.getContext());
680   newIntraTileLoop.setLowerBound(lbOperands, lbMap);
681 
682   AffineMap ubMap =
683       AffineMap::get(origUbMap.getNumDims() + 1, origUbMap.getNumSymbols() + 1,
684                      ubBoundExprs, b.getContext());
685   newIntraTileLoop.setUpperBound(ubOperands, ubMap);
686 
687   // Original loop step must be preserved.
688   newIntraTileLoop.setStep(origLoop.getStep());
689 }
690 
691 /// Set lower and upper bounds of inter-tile loops for parametric tiling.
692 //  TODO: Handle non-constant lower bounds.
setInterTileBoundsParametric(OpBuilder & b,AffineForOp origLoop,AffineForOp newLoop,Value tileSize)693 static void setInterTileBoundsParametric(OpBuilder &b, AffineForOp origLoop,
694                                          AffineForOp newLoop, Value tileSize) {
695   OperandRange newLbOperands = origLoop.getLowerBoundOperands();
696 
697   // The lower bounds for inter-tile loops are same as the corresponding lower
698   // bounds of original loops.
699   newLoop.setLowerBound(newLbOperands, origLoop.getLowerBoundMap());
700 
701   // The new upper bound map for inter-tile loops, assuming constant lower
702   // bounds, are now originalLowerBound + ceildiv((originalUpperBound -
703   // originalLowerBound), tiling parameter); where tiling parameter is the
704   // respective tile size for that loop. For e.g. if the original ubmap was
705   // ()->(1024), the new map will be
706   // ()[s0]->(ceildiv((1024 -lb) % s0)), where s0 is the tiling parameter.
707   // Therefore a new symbol operand is inserted in the map and the result
708   // expression is overwritten.
709 
710   assert(origLoop.hasConstantLowerBound() &&
711          "expected input loops to have constant lower bound.");
712 
713   // Get lower bound of original loop as an affine expression.
714   AffineExpr origLowerBoundExpr;
715   origLowerBoundExpr =
716       b.getAffineConstantExpr(origLoop.getConstantLowerBound());
717 
718   // Add dim operands from original upper bound.
719   SmallVector<Value, 4> ubOperands;
720   AffineBound ub = origLoop.getUpperBound();
721   ubOperands.reserve(ub.getNumOperands() + 1);
722   AffineMap origUbMap = ub.getMap();
723   for (unsigned j = 0, e = origUbMap.getNumDims(); j < e; ++j)
724     ubOperands.push_back(ub.getOperand(j));
725 
726   // Add symbol operands from original upper bound.
727   for (unsigned j = 0, e = origUbMap.getNumSymbols(); j < e; ++j)
728     ubOperands.push_back(ub.getOperand(origUbMap.getNumDims() + j));
729 
730   // Add a new symbol operand which is the tile size for this loop.
731   ubOperands.push_back(tileSize);
732 
733   // Get tiling parameter as an affine expression.
734   AffineExpr tileParameter = b.getAffineSymbolExpr(origUbMap.getNumSymbols());
735 
736   SmallVector<AffineExpr, 4> boundExprs;
737   boundExprs.reserve(origUbMap.getNumResults());
738   int64_t origUpperBound;
739   AffineExpr origUpperBoundExpr;
740 
741   // If upper bound for the original loop is constant, then the constant can
742   // be obtained as an affine expression straight away.
743   if (origLoop.hasConstantUpperBound()) {
744     origUpperBound = origLoop.getConstantUpperBound();
745 
746     // Get original constant upper bound as an affine expression.
747     origUpperBoundExpr = b.getAffineConstantExpr(origUpperBound);
748 
749     // Insert the bound as originalLowerBoundceildiv((originalUpperBound -
750     // originalLowerBound), tilingParameter).
751     boundExprs.push_back(
752         origLowerBoundExpr +
753         (origUpperBoundExpr - origLowerBoundExpr).ceilDiv(tileParameter));
754   } else {
755     // If upper bound for the original loop is not constant then two cases
756     // are possible, although there handeling is the same, 1.) The result of
757     // ubmap has only one result expression. For e.g.
758     //    affine.for %i = 5 to %ub
759     //
760     // A symbol operand is added which represents the tiling parameter. The
761     // new loop bounds here will be like ()[s0, s1] -> ((s0 - 5) ceildiv s1 + 5)
762     // where 's0' is the original upper bound and 's1' is the tiling
763     // parameter. 2.) When ubMap has more than one result expression. For e.g.
764     //    #map0 = affine_map<()[s0, s1] -> (s0, s1)
765     //    affine.for %i = 5 to min #map0()[%s0, %s1]
766     //
767     // A symbol operand is added which represents the tiling parameter. The
768     // new loop bounds will be like ()[s0, s1, s2] -> ((s0 - 5) ceildiv s2 + 5,
769     // (s1 -5) ceildiv s2 + 5), where s2 is the tiling parameter.
770 
771     // Insert the bounds as originalLowerBound + ceildiv((originalUpperBound -
772     // originalLowerBound), tilingParameter).
773     for (AffineExpr origUpperBoundExpr : origUbMap.getResults())
774       boundExprs.push_back(
775           origLowerBoundExpr +
776           (origUpperBoundExpr - origLowerBoundExpr).ceilDiv(tileParameter));
777   }
778 
779   AffineMap ubMap =
780       AffineMap::get(origUbMap.getNumDims(), origUbMap.getNumSymbols() + 1,
781                      boundExprs, b.getContext());
782   newLoop.setUpperBound(ubOperands, ubMap);
783 
784   // Original loop step must be preserved.
785   newLoop.setStep(origLoop.getStep());
786 }
787 
788 /// Constructs and sets new loop bounds after tiling for the case of
789 /// hyper-rectangular index sets, where the bounds of one dimension do not
790 /// depend on other dimensions and tiling parameters are captured from SSA
791 /// values. Bounds of each dimension can thus be treated independently,
792 /// and deriving the new bounds is much simpler and faster than for the case of
793 /// tiling arbitrary polyhedral shapes.
constructParametricallyTiledIndexSetHyperRect(MutableArrayRef<AffineForOp> origLoops,MutableArrayRef<AffineForOp> newLoops,ArrayRef<Value> tileSizes)794 static void constructParametricallyTiledIndexSetHyperRect(
795     MutableArrayRef<AffineForOp> origLoops,
796     MutableArrayRef<AffineForOp> newLoops, ArrayRef<Value> tileSizes) {
797   assert(!origLoops.empty() && "expected atleast one loop in band");
798   assert(origLoops.size() == tileSizes.size() &&
799          "expected tiling parameter for each loop in band.");
800 
801   OpBuilder b(origLoops[0].getOperation());
802   unsigned width = origLoops.size();
803 
804   // Set bounds for tile space loops.
805   for (unsigned i = 0; i < width; ++i) {
806     setInterTileBoundsParametric(b, origLoops[i], newLoops[i], tileSizes[i]);
807   }
808 
809   // Set bounds for intra-tile loops.
810   for (unsigned i = 0; i < width; ++i) {
811     setIntraTileBoundsParametric(b, origLoops[i], newLoops[i],
812                                  newLoops[i + width], tileSizes[i]);
813   }
814 }
815 
816 /// Constructs and sets new loop bounds after tiling for the case of
817 /// hyper-rectangular index sets, where the bounds of one dimension do not
818 /// depend on other dimensions. Bounds of each dimension can thus be treated
819 /// independently, and deriving the new bounds is much simpler and faster
820 /// than for the case of tiling arbitrary polyhedral shapes.
821 static void
constructTiledIndexSetHyperRect(MutableArrayRef<AffineForOp> origLoops,MutableArrayRef<AffineForOp> newLoops,ArrayRef<unsigned> tileSizes)822 constructTiledIndexSetHyperRect(MutableArrayRef<AffineForOp> origLoops,
823                                 MutableArrayRef<AffineForOp> newLoops,
824                                 ArrayRef<unsigned> tileSizes) {
825   assert(!origLoops.empty());
826   assert(origLoops.size() == tileSizes.size());
827 
828   OpBuilder b(origLoops[0].getOperation());
829   unsigned width = origLoops.size();
830 
831   // Bounds for tile space loops.
832   for (unsigned i = 0; i < width; i++) {
833     OperandRange newLbOperands = origLoops[i].getLowerBoundOperands();
834     OperandRange newUbOperands = origLoops[i].getUpperBoundOperands();
835     newLoops[i].setLowerBound(newLbOperands, origLoops[i].getLowerBoundMap());
836     newLoops[i].setUpperBound(newUbOperands, origLoops[i].getUpperBoundMap());
837     newLoops[i].setStep(tileSizes[i]);
838   }
839   // Bounds for intra-tile loops.
840   for (unsigned i = 0; i < width; i++) {
841     int64_t largestDiv = getLargestDivisorOfTripCount(origLoops[i]);
842     Optional<uint64_t> mayBeConstantCount = getConstantTripCount(origLoops[i]);
843     // The lower bound is just the tile-space loop.
844     AffineMap lbMap = b.getDimIdentityMap();
845     newLoops[width + i].setLowerBound(
846         /*operands=*/newLoops[i].getInductionVar(), lbMap);
847 
848     // Set the upper bound.
849     if (mayBeConstantCount && mayBeConstantCount.getValue() < tileSizes[i]) {
850       // Trip count is less than the tile size: upper bound is lower bound +
851       // trip count.
852       AffineMap ubMap =
853           b.getSingleDimShiftAffineMap(mayBeConstantCount.getValue());
854       newLoops[width + i].setUpperBound(
855           /*operands=*/newLoops[i].getInductionVar(), ubMap);
856     } else if (largestDiv % tileSizes[i] != 0) {
857       // Intra-tile loop ii goes from i to min(i + tileSize, ub_i).
858       // Construct the upper bound map; the operands are the original operands
859       // with 'i' (tile-space loop) appended to it. The new upper bound map is
860       // the original one with an additional expression i + tileSize appended.
861 
862       // Add dim operands from original upper bound.
863       SmallVector<Value, 4> ubOperands;
864       AffineBound ub = origLoops[i].getUpperBound();
865       ubOperands.reserve(ub.getNumOperands() + 1);
866       AffineMap origUbMap = ub.getMap();
867       for (unsigned j = 0, e = origUbMap.getNumDims(); j < e; ++j)
868         ubOperands.push_back(ub.getOperand(j));
869 
870       // Add dim operand for new loop upper bound.
871       ubOperands.push_back(newLoops[i].getInductionVar());
872 
873       // Add symbol operands from original upper bound.
874       for (unsigned j = 0, e = origUbMap.getNumSymbols(); j < e; ++j)
875         ubOperands.push_back(ub.getOperand(origUbMap.getNumDims() + j));
876 
877       SmallVector<AffineExpr, 4> boundExprs;
878       boundExprs.reserve(1 + origUbMap.getNumResults());
879       AffineExpr dim = b.getAffineDimExpr(origUbMap.getNumDims());
880       // The new upper bound map is the original one with an additional
881       // expression i + tileSize appended.
882       boundExprs.push_back(dim + tileSizes[i]);
883       boundExprs.append(origUbMap.getResults().begin(),
884                         origUbMap.getResults().end());
885       AffineMap ubMap =
886           AffineMap::get(origUbMap.getNumDims() + 1, origUbMap.getNumSymbols(),
887                          boundExprs, b.getContext());
888       newLoops[width + i].setUpperBound(/*operands=*/ubOperands, ubMap);
889     } else {
890       // No need of the min expression.
891       AffineExpr dim = b.getAffineDimExpr(0);
892       AffineMap ubMap = AffineMap::get(1, 0, dim + tileSizes[i]);
893       newLoops[width + i].setUpperBound(newLoops[i].getInductionVar(), ubMap);
894     }
895   }
896 }
897 
898 /// Tiles the specified band of perfectly nested loops creating tile-space loops
899 /// and intra-tile loops. A band is a contiguous set of loops.
900 //  TODO: handle non hyper-rectangular spaces.
901 LogicalResult
tilePerfectlyNested(MutableArrayRef<AffineForOp> input,ArrayRef<unsigned> tileSizes,SmallVectorImpl<AffineForOp> * tiledNest)902 mlir::tilePerfectlyNested(MutableArrayRef<AffineForOp> input,
903                           ArrayRef<unsigned> tileSizes,
904                           SmallVectorImpl<AffineForOp> *tiledNest) {
905   performPreTilingChecks(input, tileSizes);
906 
907   MutableArrayRef<AffineForOp> origLoops = input;
908   AffineForOp rootAffineForOp = origLoops[0];
909   // Note that width is at least one since band isn't empty.
910   unsigned width = input.size();
911   SmallVector<AffineForOp, 6> tiledLoops(2 * width);
912 
913   // Construct a tiled loop nest without setting their bounds. Bounds are
914   // set later.
915   constructTiledLoopNest(origLoops, rootAffineForOp, width, tiledLoops);
916 
917   SmallVector<Value, 8> origLoopIVs;
918   extractForInductionVars(input, &origLoopIVs);
919 
920   if (failed(checkIfHyperRectangular(input, rootAffineForOp, width)))
921     return failure();
922 
923   // Set loop bounds for the tiled loop nest.
924   constructTiledIndexSetHyperRect(origLoops, tiledLoops, tileSizes);
925 
926   // Replace original IVs with intra-tile loop IVs.
927   for (unsigned i = 0; i < width; i++)
928     origLoopIVs[i].replaceAllUsesWith(tiledLoops[i + width].getInductionVar());
929 
930   // Erase the old loop nest.
931   rootAffineForOp.erase();
932 
933   if (tiledNest)
934     *tiledNest = std::move(tiledLoops);
935 
936   return success();
937 }
938 
939 /// Tiles the specified band of perfectly nested loops creating tile-space
940 /// loops and intra-tile loops, using SSA values as tiling parameters. A band
941 /// is a contiguous set of loops.
942 //  TODO: handle non hyper-rectangular spaces.
943 LogicalResult
tilePerfectlyNestedParametric(MutableArrayRef<AffineForOp> input,ArrayRef<Value> tileSizes,SmallVectorImpl<AffineForOp> * tiledNest)944 mlir::tilePerfectlyNestedParametric(MutableArrayRef<AffineForOp> input,
945                                     ArrayRef<Value> tileSizes,
946                                     SmallVectorImpl<AffineForOp> *tiledNest) {
947   performPreTilingChecks(input, tileSizes);
948 
949   MutableArrayRef<AffineForOp> origLoops = input;
950   AffineForOp rootAffineForOp = origLoops[0];
951   // Note that width is at least one since band isn't empty.
952   unsigned width = input.size();
953   SmallVector<AffineForOp, 6> tiledLoops(2 * width);
954 
955   // Construct a tiled loop nest without setting their bounds. Bounds are
956   // set later.
957   constructTiledLoopNest(origLoops, rootAffineForOp, width, tiledLoops);
958 
959   SmallVector<Value, 8> origLoopIVs;
960   extractForInductionVars(input, &origLoopIVs);
961 
962   if (failed(checkIfHyperRectangular(input, rootAffineForOp, width)))
963     return failure();
964 
965   // Set loop bounds for the tiled loop nest.
966   constructParametricallyTiledIndexSetHyperRect(origLoops, tiledLoops,
967                                                 tileSizes);
968 
969   // Replace original IVs with intra-tile loop IVs.
970   for (unsigned i = 0; i < width; i++)
971     origLoopIVs[i].replaceAllUsesWith(tiledLoops[i + width].getInductionVar());
972 
973   // Erase the old loop nest.
974   rootAffineForOp.erase();
975 
976   if (tiledNest)
977     *tiledNest = std::move(tiledLoops);
978 
979   return success();
980 }
981 
982 /// Collect perfectly nested loops starting from `rootForOps`.  Loops are
983 /// perfectly nested if each loop is the first and only non-terminator operation
984 /// in the parent loop.  Collect at most `maxLoops` loops and append them to
985 /// `forOps`.
986 template <typename T>
getPerfectlyNestedLoopsImpl(SmallVectorImpl<T> & forOps,T rootForOp,unsigned maxLoops=std::numeric_limits<unsigned>::max ())987 static void getPerfectlyNestedLoopsImpl(
988     SmallVectorImpl<T> &forOps, T rootForOp,
989     unsigned maxLoops = std::numeric_limits<unsigned>::max()) {
990   for (unsigned i = 0; i < maxLoops; ++i) {
991     forOps.push_back(rootForOp);
992     Block &body = rootForOp.region().front();
993     if (body.begin() != std::prev(body.end(), 2))
994       return;
995 
996     rootForOp = dyn_cast<T>(&body.front());
997     if (!rootForOp)
998       return;
999   }
1000 }
1001 
1002 /// Get perfectly nested sequence of loops starting at root of loop nest
1003 /// (the first op being another AffineFor, and the second op - a terminator).
1004 /// A loop is perfectly nested iff: the first op in the loop's body is another
1005 /// AffineForOp, and the second op is a terminator).
getPerfectlyNestedLoops(SmallVectorImpl<AffineForOp> & nestedLoops,AffineForOp root)1006 void mlir::getPerfectlyNestedLoops(SmallVectorImpl<AffineForOp> &nestedLoops,
1007                                    AffineForOp root) {
1008   getPerfectlyNestedLoopsImpl(nestedLoops, root);
1009 }
1010 
getPerfectlyNestedLoops(SmallVectorImpl<scf::ForOp> & nestedLoops,scf::ForOp root)1011 void mlir::getPerfectlyNestedLoops(SmallVectorImpl<scf::ForOp> &nestedLoops,
1012                                    scf::ForOp root) {
1013   getPerfectlyNestedLoopsImpl(nestedLoops, root);
1014 }
1015 
1016 /// Identify valid and profitable bands of loops to tile. This is currently just
1017 /// a temporary placeholder to test the mechanics of tiled code generation.
1018 /// Returns all maximal outermost perfect loop nests to tile.
getTileableBands(FuncOp f,std::vector<SmallVector<AffineForOp,6>> * bands)1019 void mlir::getTileableBands(FuncOp f,
1020                             std::vector<SmallVector<AffineForOp, 6>> *bands) {
1021   // Get maximal perfect nest of 'affine.for' insts starting from root
1022   // (inclusive).
1023   for (AffineForOp forOp : f.getOps<AffineForOp>()) {
1024     SmallVector<AffineForOp, 6> band;
1025     getPerfectlyNestedLoops(band, forOp);
1026     bands->push_back(band);
1027   }
1028 }
1029 
1030 /// Unrolls this loop completely.
loopUnrollFull(AffineForOp forOp)1031 LogicalResult mlir::loopUnrollFull(AffineForOp forOp) {
1032   Optional<uint64_t> mayBeConstantTripCount = getConstantTripCount(forOp);
1033   if (mayBeConstantTripCount.hasValue()) {
1034     uint64_t tripCount = mayBeConstantTripCount.getValue();
1035     if (tripCount == 1)
1036       return promoteIfSingleIteration(forOp);
1037     return loopUnrollByFactor(forOp, tripCount);
1038   }
1039   return failure();
1040 }
1041 
1042 /// Unrolls this loop by the specified factor or by the trip count (if constant)
1043 /// whichever is lower.
loopUnrollUpToFactor(AffineForOp forOp,uint64_t unrollFactor)1044 LogicalResult mlir::loopUnrollUpToFactor(AffineForOp forOp,
1045                                          uint64_t unrollFactor) {
1046   Optional<uint64_t> mayBeConstantTripCount = getConstantTripCount(forOp);
1047   if (mayBeConstantTripCount.hasValue() &&
1048       mayBeConstantTripCount.getValue() < unrollFactor)
1049     return loopUnrollByFactor(forOp, mayBeConstantTripCount.getValue());
1050   return loopUnrollByFactor(forOp, unrollFactor);
1051 }
1052 
1053 /// Generates unrolled copies of AffineForOp or scf::ForOp 'loopBodyBlock', with
1054 /// associated 'forOpIV' by 'unrollFactor', calling 'ivRemapFn' to remap
1055 /// 'forOpIV' for each unrolled body.
1056 static void
generateUnrolledLoop(Block * loopBodyBlock,Value forOpIV,uint64_t unrollFactor,function_ref<Value (unsigned,Value,OpBuilder)> ivRemapFn,ValueRange iterArgs,ValueRange yieldedValues)1057 generateUnrolledLoop(Block *loopBodyBlock, Value forOpIV, uint64_t unrollFactor,
1058                      function_ref<Value(unsigned, Value, OpBuilder)> ivRemapFn,
1059                      ValueRange iterArgs, ValueRange yieldedValues) {
1060   // Builder to insert unrolled bodies just before the terminator of the body of
1061   // 'forOp'.
1062   auto builder = OpBuilder::atBlockTerminator(loopBodyBlock);
1063 
1064   // Keep a pointer to the last non-terminator operation in the original block
1065   // so that we know what to clone (since we are doing this in-place).
1066   Block::iterator srcBlockEnd = std::prev(loopBodyBlock->end(), 2);
1067 
1068   // Unroll the contents of 'forOp' (append unrollFactor - 1 additional copies).
1069   SmallVector<Value, 4> lastYielded(yieldedValues);
1070 
1071   for (unsigned i = 1; i < unrollFactor; i++) {
1072     BlockAndValueMapping operandMap;
1073 
1074     // Prepare operand map.
1075     operandMap.map(iterArgs, lastYielded);
1076 
1077     // If the induction variable is used, create a remapping to the value for
1078     // this unrolled instance.
1079     if (!forOpIV.use_empty()) {
1080       Value ivUnroll = ivRemapFn(i, forOpIV, builder);
1081       operandMap.map(forOpIV, ivUnroll);
1082     }
1083 
1084     // Clone the original body of 'forOp'.
1085     for (auto it = loopBodyBlock->begin(); it != std::next(srcBlockEnd); it++)
1086       builder.clone(*it, operandMap);
1087 
1088     // Update yielded values.
1089     for (unsigned i = 0, e = lastYielded.size(); i < e; i++)
1090       lastYielded[i] = operandMap.lookup(yieldedValues[i]);
1091   }
1092 
1093   // Update operands of the yield statement.
1094   loopBodyBlock->getTerminator()->setOperands(lastYielded);
1095 }
1096 
1097 /// Unrolls this loop by the specified factor. Returns success if the loop
1098 /// is successfully unrolled.
loopUnrollByFactor(AffineForOp forOp,uint64_t unrollFactor)1099 LogicalResult mlir::loopUnrollByFactor(AffineForOp forOp,
1100                                        uint64_t unrollFactor) {
1101   assert(unrollFactor > 0 && "unroll factor should be positive");
1102 
1103   if (unrollFactor == 1)
1104     return promoteIfSingleIteration(forOp);
1105 
1106   // Nothing in the loop body other than the terminator.
1107   if (llvm::hasSingleElement(forOp.getBody()->getOperations()))
1108     return success();
1109 
1110   // Loops where the lower bound is a max expression isn't supported for
1111   // unrolling since the trip count can be expressed as an affine function when
1112   // both the lower bound and the upper bound are multi-result maps. However,
1113   // one meaningful way to do such unrolling would be to specialize the loop for
1114   // the 'hotspot' case and unroll that hotspot.
1115   if (forOp.getLowerBoundMap().getNumResults() != 1)
1116     return failure();
1117 
1118   // If the trip count is lower than the unroll factor, no unrolled body.
1119   // TODO: option to specify cleanup loop unrolling.
1120   Optional<uint64_t> mayBeConstantTripCount = getConstantTripCount(forOp);
1121   if (mayBeConstantTripCount.hasValue() &&
1122       mayBeConstantTripCount.getValue() < unrollFactor)
1123     return failure();
1124 
1125   // Generate the cleanup loop if trip count isn't a multiple of unrollFactor.
1126   if (getLargestDivisorOfTripCount(forOp) % unrollFactor != 0) {
1127     OpBuilder builder(forOp->getBlock(), std::next(Block::iterator(forOp)));
1128     auto cleanupForOp = cast<AffineForOp>(builder.clone(*forOp));
1129     AffineMap cleanupMap;
1130     SmallVector<Value, 4> cleanupOperands;
1131     getCleanupLoopLowerBound(forOp, unrollFactor, cleanupMap, cleanupOperands);
1132     assert(cleanupMap &&
1133            "cleanup loop lower bound map for single result lower bound maps "
1134            "can always be determined");
1135     cleanupForOp.setLowerBound(cleanupOperands, cleanupMap);
1136     // Promote the loop body up if this has turned into a single iteration loop.
1137     promoteIfSingleIteration(cleanupForOp);
1138 
1139     // Adjust upper bound of the original loop; this is the same as the lower
1140     // bound of the cleanup loop.
1141     forOp.setUpperBound(cleanupOperands, cleanupMap);
1142   }
1143 
1144   // Scale the step of loop being unrolled by unroll factor.
1145   int64_t step = forOp.getStep();
1146   forOp.setStep(step * unrollFactor);
1147   generateUnrolledLoop(forOp.getBody(), forOp.getInductionVar(), unrollFactor,
1148                        [&](unsigned i, Value iv, OpBuilder b) {
1149                          // iv' = iv + i * step
1150                          auto d0 = b.getAffineDimExpr(0);
1151                          auto bumpMap = AffineMap::get(1, 0, d0 + i * step);
1152                          return b.create<AffineApplyOp>(forOp.getLoc(), bumpMap,
1153                                                         iv);
1154                        },
1155                        /*iterArgs=*/{}, /*yieldedValues=*/{});
1156 
1157   // Promote the loop body up if this has turned into a single iteration loop.
1158   promoteIfSingleIteration(forOp);
1159   return success();
1160 }
1161 
1162 /// Unrolls 'forOp' by 'unrollFactor', returns success if the loop is unrolled.
loopUnrollByFactor(scf::ForOp forOp,uint64_t unrollFactor)1163 LogicalResult mlir::loopUnrollByFactor(scf::ForOp forOp,
1164                                        uint64_t unrollFactor) {
1165   assert(unrollFactor > 0 && "expected positive unroll factor");
1166   if (unrollFactor == 1)
1167     return promoteIfSingleIteration(forOp);
1168 
1169   // Return if the loop body is empty.
1170   if (llvm::hasSingleElement(forOp.getBody()->getOperations()))
1171     return success();
1172 
1173   // Compute tripCount = ceilDiv((upperBound - lowerBound), step) and populate
1174   // 'upperBoundUnrolled' and 'stepUnrolled' for static and dynamic cases.
1175   OpBuilder boundsBuilder(forOp);
1176   auto loc = forOp.getLoc();
1177   auto step = forOp.step();
1178   Value upperBoundUnrolled;
1179   Value stepUnrolled;
1180   bool generateEpilogueLoop = true;
1181 
1182   auto lbCstOp = forOp.lowerBound().getDefiningOp<ConstantIndexOp>();
1183   auto ubCstOp = forOp.upperBound().getDefiningOp<ConstantIndexOp>();
1184   auto stepCstOp = forOp.step().getDefiningOp<ConstantIndexOp>();
1185   if (lbCstOp && ubCstOp && stepCstOp) {
1186     // Constant loop bounds computation.
1187     int64_t lbCst = lbCstOp.getValue();
1188     int64_t ubCst = ubCstOp.getValue();
1189     int64_t stepCst = stepCstOp.getValue();
1190     assert(lbCst >= 0 && ubCst >= 0 && stepCst >= 0 &&
1191            "expected positive loop bounds and step");
1192     int64_t tripCount = mlir::ceilDiv(ubCst - lbCst, stepCst);
1193     int64_t tripCountEvenMultiple = tripCount - (tripCount % unrollFactor);
1194     int64_t upperBoundUnrolledCst = lbCst + tripCountEvenMultiple * stepCst;
1195     assert(upperBoundUnrolledCst <= ubCst);
1196     int64_t stepUnrolledCst = stepCst * unrollFactor;
1197 
1198     // Create constant for 'upperBoundUnrolled' and set epilogue loop flag.
1199     generateEpilogueLoop = upperBoundUnrolledCst < ubCst;
1200     if (generateEpilogueLoop)
1201       upperBoundUnrolled =
1202           boundsBuilder.create<ConstantIndexOp>(loc, upperBoundUnrolledCst);
1203     else
1204       upperBoundUnrolled = ubCstOp;
1205 
1206     // Create constant for 'stepUnrolled'.
1207     stepUnrolled =
1208         stepCst == stepUnrolledCst
1209             ? step
1210             : boundsBuilder.create<ConstantIndexOp>(loc, stepUnrolledCst);
1211   } else {
1212     // Dynamic loop bounds computation.
1213     // TODO: Add dynamic asserts for negative lb/ub/step, or
1214     // consider using ceilDiv from AffineApplyExpander.
1215     auto lowerBound = forOp.lowerBound();
1216     auto upperBound = forOp.upperBound();
1217     Value diff = boundsBuilder.create<SubIOp>(loc, upperBound, lowerBound);
1218     Value tripCount = ceilDivPositive(boundsBuilder, loc, diff, step);
1219     Value unrollFactorCst =
1220         boundsBuilder.create<ConstantIndexOp>(loc, unrollFactor);
1221     Value tripCountRem =
1222         boundsBuilder.create<SignedRemIOp>(loc, tripCount, unrollFactorCst);
1223     // Compute tripCountEvenMultiple = tripCount - (tripCount % unrollFactor)
1224     Value tripCountEvenMultiple =
1225         boundsBuilder.create<SubIOp>(loc, tripCount, tripCountRem);
1226     // Compute upperBoundUnrolled = lowerBound + tripCountEvenMultiple * step
1227     upperBoundUnrolled = boundsBuilder.create<AddIOp>(
1228         loc, lowerBound,
1229         boundsBuilder.create<MulIOp>(loc, tripCountEvenMultiple, step));
1230     // Scale 'step' by 'unrollFactor'.
1231     stepUnrolled = boundsBuilder.create<MulIOp>(loc, step, unrollFactorCst);
1232   }
1233 
1234   // Create epilogue clean up loop starting at 'upperBoundUnrolled'.
1235   if (generateEpilogueLoop) {
1236     OpBuilder epilogueBuilder(forOp->getBlock(),
1237                               std::next(Block::iterator(forOp)));
1238     auto epilogueForOp = cast<scf::ForOp>(epilogueBuilder.clone(*forOp));
1239     epilogueForOp.setLowerBound(upperBoundUnrolled);
1240 
1241     // Update uses of loop results.
1242     auto results = forOp.getResults();
1243     auto epilogueResults = epilogueForOp.getResults();
1244     auto epilogueIterOperands = epilogueForOp.getIterOperands();
1245 
1246     for (auto e : llvm::zip(results, epilogueResults, epilogueIterOperands)) {
1247       std::get<0>(e).replaceAllUsesWith(std::get<1>(e));
1248       epilogueForOp->replaceUsesOfWith(std::get<2>(e), std::get<0>(e));
1249     }
1250     promoteIfSingleIteration(epilogueForOp);
1251   }
1252 
1253   // Create unrolled loop.
1254   forOp.setUpperBound(upperBoundUnrolled);
1255   forOp.setStep(stepUnrolled);
1256 
1257   auto iterArgs = ValueRange(forOp.getRegionIterArgs());
1258   auto yieldedValues = forOp.getBody()->getTerminator()->getOperands();
1259 
1260   generateUnrolledLoop(
1261       forOp.getBody(), forOp.getInductionVar(), unrollFactor,
1262       [&](unsigned i, Value iv, OpBuilder b) {
1263         // iv' = iv + step * i;
1264         auto stride =
1265             b.create<MulIOp>(loc, step, b.create<ConstantIndexOp>(loc, i));
1266         return b.create<AddIOp>(loc, iv, stride);
1267       },
1268       iterArgs, yieldedValues);
1269   // Promote the loop body up if this has turned into a single iteration loop.
1270   promoteIfSingleIteration(forOp);
1271   return success();
1272 }
1273 
loopUnrollJamUpToFactor(AffineForOp forOp,uint64_t unrollJamFactor)1274 LogicalResult mlir::loopUnrollJamUpToFactor(AffineForOp forOp,
1275                                             uint64_t unrollJamFactor) {
1276   Optional<uint64_t> mayBeConstantTripCount = getConstantTripCount(forOp);
1277   if (mayBeConstantTripCount.hasValue() &&
1278       mayBeConstantTripCount.getValue() < unrollJamFactor)
1279     return loopUnrollJamByFactor(forOp, mayBeConstantTripCount.getValue());
1280   return loopUnrollJamByFactor(forOp, unrollJamFactor);
1281 }
1282 
/// Unrolls and jams `forOp` by `unrollJamFactor`.
///
/// The step of `forOp` is multiplied by the factor. For every sub-block
/// gathered by JamBlockGatherer below, `unrollJamFactor - 1` copies are
/// inserted right after the sub-block. Copy `i` sees the induction variable as
/// `iv + i * step`, using the original step. When the trip count is not known
/// to be a multiple of the factor, a cleanup loop is cloned right after `forOp`
/// to run the leftover iterations.
///
/// Returns:
///  - the result of promoteIfSingleIteration when `unrollJamFactor` is 1;
///  - success, without changes, when the body holds only its terminator;
///  - failure when the lower bound map has more than one result, or when the
///    constant trip count is smaller than `unrollJamFactor`;
///  - success otherwise.
LogicalResult mlir::loopUnrollJamByFactor(AffineForOp forOp,
                                          uint64_t unrollJamFactor) {
  // Gathers all maximal sub-blocks of operations that do not themselves
  // include a for op (an operation could have a descendant for op though
  // in its tree). Ignore the block terminators. Only affine.for ops are
  // recursed into. Ops holding other kinds of regions stay whole inside a
  // sub-block.
  struct JamBlockGatherer {
    // Store iterators to the first and last op (both inclusive) of each
    // sub-block found.
    std::vector<std::pair<Block::iterator, Block::iterator>> subBlocks;

    // Visits every block of every region of `op`. This is a linear time walk.
    void walk(Operation *op) {
      for (auto &region : op->getRegions())
        for (auto &block : region)
          walk(block);
    }

    // Splits `block` (excluding its terminator) into alternating runs.
    // Runs of non-affine.for ops are recorded as sub-blocks. Runs of
    // affine.for ops are walked recursively.
    void walk(Block &block) {
      for (auto it = block.begin(), e = std::prev(block.end()); it != e;) {
        auto subBlockStart = it;
        while (it != e && !isa<AffineForOp>(&*it))
          ++it;
        if (it != subBlockStart)
          subBlocks.push_back({subBlockStart, std::prev(it)});
        // Process all for ops that appear next.
        while (it != e && isa<AffineForOp>(&*it))
          walk(&*it++);
      }
    }
  };

  assert(unrollJamFactor > 0 && "unroll jam factor should be positive");

  if (unrollJamFactor == 1)
    return promoteIfSingleIteration(forOp);

  // Nothing in the loop body other than the terminator.
  if (llvm::hasSingleElement(forOp.getBody()->getOperations()))
    return success();

  // Loops with a multi-result lower bound map won't be unroll-jammed. The
  // trip count, and hence the cleanup loop bound, can't be expressed as an
  // affine function in general. Only the lower bound map is checked here.
  // TODO: this may not be common, but we could support the case
  // where the lower bound is a multi-result map and the ub is a single result
  // one.
  if (forOp.getLowerBoundMap().getNumResults() != 1)
    return failure();

  Optional<uint64_t> mayBeConstantTripCount = getConstantTripCount(forOp);
  // If the trip count is lower than the unroll jam factor, no unroll jam.
  if (mayBeConstantTripCount.hasValue() &&
      mayBeConstantTripCount.getValue() < unrollJamFactor) {
    LLVM_DEBUG(llvm::dbgs() << "[failed] trip count < unroll-jam factor\n");
    return failure();
  }

  // Gather all sub-blocks to jam upon the loop being unrolled. This runs
  // before any cloning, so the recorded iterators only cover original ops.
  JamBlockGatherer jbg;
  jbg.walk(forOp);
  auto &subBlocks = jbg.subBlocks;

  // Generate the cleanup loop if trip count isn't a multiple of
  // unrollJamFactor.
  if (getLargestDivisorOfTripCount(forOp) % unrollJamFactor != 0) {
    // Insert the cleanup loop right after 'forOp', as a full clone of it
    // (body included).
    OpBuilder builder(forOp->getBlock(), std::next(Block::iterator(forOp)));
    auto cleanupAffineForOp = cast<AffineForOp>(builder.clone(*forOp));
    // Adjust the lower bound of the cleanup loop; its upper bound is the same
    // as the original loop's upper bound.
    AffineMap cleanupMap;
    SmallVector<Value, 4> cleanupOperands;
    getCleanupLoopLowerBound(forOp, unrollJamFactor, cleanupMap,
                             cleanupOperands);
    cleanupAffineForOp.setLowerBound(cleanupOperands, cleanupMap);

    // Promote the cleanup loop if it has turned into a single iteration loop.
    // The result of the promotion is not checked here.
    promoteIfSingleIteration(cleanupAffineForOp);

    // Adjust the upper bound of the original loop - it will be the same as the
    // cleanup loop's lower bound. Its lower bound remains unchanged.
    forOp.setUpperBound(cleanupOperands, cleanupMap);
  }

  // Scale the step of loop being unroll-jammed by the unroll-jam factor.
  // `step` keeps the original step, which is used below to offset the IV
  // copies.
  int64_t step = forOp.getStep();
  forOp.setStep(step * unrollJamFactor);

  auto forOpIV = forOp.getInductionVar();
  // Unroll and jam (appends unrollJamFactor - 1 additional copies). Copies are
  // generated from i = unrollJamFactor - 1 down to i = 1. Each one is inserted
  // immediately after the original last op of the sub-block, so the final
  // order is: original, copy 1, ..., copy unrollJamFactor - 1.
  for (unsigned i = unrollJamFactor - 1; i >= 1; --i) {
    // Operand map persists across all sub-blocks of this copy. Values defined
    // by an earlier sub-block's clone are therefore remapped in later ones.
    BlockAndValueMapping operandMap;
    for (auto &subBlock : subBlocks) {
      // Builder to insert unroll-jammed bodies. Insert right at the end of
      // sub-block.
      OpBuilder builder(subBlock.first->getBlock(), std::next(subBlock.second));

      // If the induction variable is used, create a remapping to the value for
      // this unrolled instance.
      if (!forOpIV.use_empty()) {
        // iv' = iv + i * step, i = 1 to unrollJamFactor-1.
        auto d0 = builder.getAffineDimExpr(0);
        auto bumpMap = AffineMap::get(1, 0, d0 + i * step);
        auto ivUnroll =
            builder.create<AffineApplyOp>(forOp.getLoc(), bumpMap, forOpIV);
        operandMap.map(forOpIV, ivUnroll);
      }
      // Clone the sub-block being unroll-jammed.
      for (auto it = subBlock.first; it != std::next(subBlock.second); ++it)
        builder.clone(*it, operandMap);
    }
  }

  // Promote the loop body up if this has turned into a single iteration loop.
  // Its result is not propagated; success is returned regardless.
  promoteIfSingleIteration(forOp);
  return success();
}
1401 
1402 /// Performs loop interchange on 'forOpA' and 'forOpB', where 'forOpB' is
1403 /// nested within 'forOpA' as the only non-terminator operation in its block.
interchangeLoops(AffineForOp forOpA,AffineForOp forOpB)1404 void mlir::interchangeLoops(AffineForOp forOpA, AffineForOp forOpB) {
1405   assert(&*forOpA.getBody()->begin() == forOpB.getOperation());
1406   auto &forOpABody = forOpA.getBody()->getOperations();
1407   auto &forOpBBody = forOpB.getBody()->getOperations();
1408 
1409   // 1) Splice forOpA's non-terminator operations (which is just forOpB) just
1410   // before forOpA (in ForOpA's parent's block) this should leave 'forOpA's
1411   // body containing only the terminator.
1412   forOpA->getBlock()->getOperations().splice(Block::iterator(forOpA),
1413                                              forOpABody, forOpABody.begin(),
1414                                              std::prev(forOpABody.end()));
1415   // 2) Splice forOpB's non-terminator operations into the beginning of forOpA's
1416   // body (this leaves forOpB's body containing only the terminator).
1417   forOpABody.splice(forOpABody.begin(), forOpBBody, forOpBBody.begin(),
1418                     std::prev(forOpBBody.end()));
1419   // 3) Splice forOpA into the beginning of forOpB's body.
1420   forOpBBody.splice(forOpBBody.begin(), forOpA->getBlock()->getOperations(),
1421                     Block::iterator(forOpA));
1422 }
1423 
1424 // Checks each dependence component against the permutation to see if the
1425 // desired loop interchange would violate dependences by making the
1426 // dependence component lexicographically negative.
checkLoopInterchangeDependences(const std::vector<SmallVector<DependenceComponent,2>> & depCompsVec,ArrayRef<AffineForOp> loops,ArrayRef<unsigned> loopPermMap)1427 static bool checkLoopInterchangeDependences(
1428     const std::vector<SmallVector<DependenceComponent, 2>> &depCompsVec,
1429     ArrayRef<AffineForOp> loops, ArrayRef<unsigned> loopPermMap) {
1430   // Invert permutation map.
1431   unsigned maxLoopDepth = loops.size();
1432   SmallVector<unsigned, 4> loopPermMapInv;
1433   loopPermMapInv.resize(maxLoopDepth);
1434   for (unsigned i = 0; i < maxLoopDepth; ++i)
1435     loopPermMapInv[loopPermMap[i]] = i;
1436 
1437   // Check each dependence component against the permutation to see if the
1438   // desired loop interchange permutation would make the dependence vectors
1439   // lexicographically negative.
1440   // Example 1: [-1, 1][0, 0]
1441   // Example 2: [0, 0][-1, 1]
1442   for (unsigned i = 0, e = depCompsVec.size(); i < e; ++i) {
1443     const SmallVector<DependenceComponent, 2> &depComps = depCompsVec[i];
1444     assert(depComps.size() >= maxLoopDepth);
1445     // Check if the first non-zero dependence component is positive.
1446     // This iterates through loops in the desired order.
1447     for (unsigned j = 0; j < maxLoopDepth; ++j) {
1448       unsigned permIndex = loopPermMapInv[j];
1449       assert(depComps[permIndex].lb.hasValue());
1450       int64_t depCompLb = depComps[permIndex].lb.getValue();
1451       if (depCompLb > 0)
1452         break;
1453       if (depCompLb < 0)
1454         return false;
1455     }
1456   }
1457   return true;
1458 }
1459 
1460 /// Checks if the loop interchange permutation 'loopPermMap' of the perfectly
1461 /// nested sequence of loops in 'loops' would violate dependences.
isValidLoopInterchangePermutation(ArrayRef<AffineForOp> loops,ArrayRef<unsigned> loopPermMap)1462 bool mlir::isValidLoopInterchangePermutation(ArrayRef<AffineForOp> loops,
1463                                              ArrayRef<unsigned> loopPermMap) {
1464   // Gather dependence components for dependences between all ops in loop nest
1465   // rooted at 'loops[0]', at loop depths in range [1, maxLoopDepth].
1466   assert(loopPermMap.size() == loops.size());
1467   unsigned maxLoopDepth = loops.size();
1468   std::vector<SmallVector<DependenceComponent, 2>> depCompsVec;
1469   getDependenceComponents(loops[0], maxLoopDepth, &depCompsVec);
1470   return checkLoopInterchangeDependences(depCompsVec, loops, loopPermMap);
1471 }
1472 
1473 /// Returns true if `loops` is a perfectly nested loop nest, where loops appear
1474 /// in it from outermost to innermost.
1475 bool LLVM_ATTRIBUTE_UNUSED
isPerfectlyNested(ArrayRef<AffineForOp> loops)1476 mlir::isPerfectlyNested(ArrayRef<AffineForOp> loops) {
1477   assert(!loops.empty() && "no loops provided");
1478 
1479   // We already know that the block can't be empty.
1480   auto hasTwoElements = [](Block *block) {
1481     auto secondOpIt = std::next(block->begin());
1482     return secondOpIt != block->end() && &*secondOpIt == &block->back();
1483   };
1484 
1485   auto enclosingLoop = loops.front();
1486   for (auto loop : loops.drop_front()) {
1487     auto parentForOp = dyn_cast<AffineForOp>(loop->getParentOp());
1488     // parentForOp's body should be just this loop and the terminator.
1489     if (parentForOp != enclosingLoop || !hasTwoElements(parentForOp.getBody()))
1490       return false;
1491     enclosingLoop = loop;
1492   }
1493   return true;
1494 }
1495 
/// Permutes the perfectly nested loops in `input` so that input[i] moves from
/// position i to position permMap[i], where position 0 is the outermost loop.
/// `permMap` must be a permutation of [0, input.size()). Loops and the
/// innermost body are moved by splicing only. Returns the index in `input` of
/// the loop that becomes the new outermost loop, or 0 when fewer than two
/// loops are given.
unsigned mlir::permuteLoops(MutableArrayRef<AffineForOp> input,
                            ArrayRef<unsigned> permMap) {
  assert(input.size() == permMap.size() && "invalid permutation map size");
  // Check whether the permutation spec is valid. This is a small vector - we'll
  // just sort and check if it's iota.
  // NOTE(review): the sort and any_of are not inside the assert itself, so
  // they are not compiled out together with it.
  SmallVector<unsigned, 4> checkPermMap(permMap.begin(), permMap.end());
  llvm::sort(checkPermMap);
  if (llvm::any_of(llvm::enumerate(checkPermMap),
                   [](const auto &en) { return en.value() != en.index(); }))
    assert(false && "invalid permutation map");

  // Nothing to do.
  if (input.size() < 2)
    return 0;

  assert(isPerfectlyNested(input) && "input not perfectly nested");

  // Compute the inverse mapping, invPermMap: since input[i] goes to position
  // permMap[i], position i of the permuted nest is at input[invPermMap[i]].
  // Entries are {new position, original index}. After sorting by new
  // position, invPermMap[k].second is the index in `input` of the loop that
  // ends up at depth k.
  SmallVector<std::pair<unsigned, unsigned>, 4> invPermMap;
  for (unsigned i = 0, e = input.size(); i < e; ++i)
    invPermMap.push_back({permMap[i], i});
  llvm::sort(invPermMap);

  // Move the innermost loop body (all ops but its terminator) to the loop that
  // would be the innermost in the permuted nest (only if the innermost loop is
  // going to change).
  if (permMap.back() != input.size() - 1) {
    auto *destBody = input[invPermMap.back().second].getBody();
    auto *srcBody = input.back().getBody();
    destBody->getOperations().splice(destBody->begin(),
                                     srcBody->getOperations(), srcBody->begin(),
                                     std::prev(srcBody->end()));
  }

  // We'll move each loop in `input` in the reverse order so that its body is
  // empty when we are moving it; this incurs zero copies and no erasing.
  for (int i = input.size() - 1; i >= 0; --i) {
    // If this has to become the outermost loop after permutation, add it to the
    // parent block of the original root.
    if (permMap[i] == 0) {
      // If the root remains the same, nothing to do.
      if (i == 0)
        continue;
      // Make input[i] the new outermost loop by moving it into parentBlock,
      // immediately before the original root input[0].
      auto *parentBlock = input[0]->getBlock();
      parentBlock->getOperations().splice(Block::iterator(input[0]),
                                          input[i]->getBlock()->getOperations(),
                                          Block::iterator(input[i]));
      continue;
    }

    // input[parentPosInInput] is the loop at depth permMap[i] - 1 in the
    // permuted nest. If the parent in the permuted order is the same as in the
    // original, nothing to do.
    unsigned parentPosInInput = invPermMap[permMap[i] - 1].second;
    if (i > 0 && static_cast<unsigned>(i - 1) == parentPosInInput)
      continue;

    // Move input[i] to the front of the body of its surrounding loop in the
    // transformed nest.
    auto *destBody = input[parentPosInInput].getBody();
    destBody->getOperations().splice(destBody->begin(),
                                     input[i]->getBlock()->getOperations(),
                                     Block::iterator(input[i]));
  }

  // Index in `input` of the loop now at depth 0.
  return invPermMap[0].second;
}
1564 
1565 // Sinks all sequential loops to the innermost levels (while preserving
1566 // relative order among them) and moves all parallel loops to the
1567 // outermost (while again preserving relative order among them).
sinkSequentialLoops(AffineForOp forOp)1568 AffineForOp mlir::sinkSequentialLoops(AffineForOp forOp) {
1569   SmallVector<AffineForOp, 4> loops;
1570   getPerfectlyNestedLoops(loops, forOp);
1571   if (loops.size() < 2)
1572     return forOp;
1573 
1574   // Gather dependence components for dependences between all ops in loop nest
1575   // rooted at 'loops[0]', at loop depths in range [1, maxLoopDepth].
1576   unsigned maxLoopDepth = loops.size();
1577   std::vector<SmallVector<DependenceComponent, 2>> depCompsVec;
1578   getDependenceComponents(loops[0], maxLoopDepth, &depCompsVec);
1579 
1580   // Mark loops as either parallel or sequential.
1581   SmallVector<bool, 8> isParallelLoop(maxLoopDepth, true);
1582   for (unsigned i = 0, e = depCompsVec.size(); i < e; ++i) {
1583     SmallVector<DependenceComponent, 2> &depComps = depCompsVec[i];
1584     assert(depComps.size() >= maxLoopDepth);
1585     for (unsigned j = 0; j < maxLoopDepth; ++j) {
1586       DependenceComponent &depComp = depComps[j];
1587       assert(depComp.lb.hasValue() && depComp.ub.hasValue());
1588       if (depComp.lb.getValue() != 0 || depComp.ub.getValue() != 0)
1589         isParallelLoop[j] = false;
1590     }
1591   }
1592 
1593   // Count the number of parallel loops.
1594   unsigned numParallelLoops = 0;
1595   for (unsigned i = 0, e = isParallelLoop.size(); i < e; ++i)
1596     if (isParallelLoop[i])
1597       ++numParallelLoops;
1598 
1599   // Compute permutation of loops that sinks sequential loops (and thus raises
1600   // parallel loops) while preserving relative order.
1601   SmallVector<unsigned, 4> loopPermMap(maxLoopDepth);
1602   unsigned nextSequentialLoop = numParallelLoops;
1603   unsigned nextParallelLoop = 0;
1604   for (unsigned i = 0; i < maxLoopDepth; ++i) {
1605     if (isParallelLoop[i]) {
1606       loopPermMap[i] = nextParallelLoop++;
1607     } else {
1608       loopPermMap[i] = nextSequentialLoop++;
1609     }
1610   }
1611 
1612   // Check if permutation 'loopPermMap' would violate dependences.
1613   if (!checkLoopInterchangeDependences(depCompsVec, loops, loopPermMap))
1614     return forOp;
1615   // Perform loop interchange according to permutation 'loopPermMap'.
1616   unsigned loopNestRootIndex = permuteLoops(loops, loopPermMap);
1617   return loops[loopNestRootIndex];
1618 }
1619 
1620 // Factors out common behavior to add a new `iv` (resp. `iv` + `offset`) to the
1621 // lower (resp. upper) loop bound. When called for both the lower and upper
1622 // bounds, the resulting IR resembles:
1623 //
1624 // ```mlir
1625 //    affine.for %i = max (`iv, ...) to min (`iv` + `offset`) {
1626 //      ...
1627 //    }
1628 // ```
augmentMapAndBounds(OpBuilder & b,Value iv,AffineMap * map,SmallVector<Value,4> * operands,int64_t offset=0)1629 static void augmentMapAndBounds(OpBuilder &b, Value iv, AffineMap *map,
1630                                 SmallVector<Value, 4> *operands,
1631                                 int64_t offset = 0) {
1632   auto bounds = llvm::to_vector<4>(map->getResults());
1633   bounds.push_back(b.getAffineDimExpr(map->getNumDims()) + offset);
1634   operands->insert(operands->begin() + map->getNumDims(), iv);
1635   *map = AffineMap::get(map->getNumDims() + 1, map->getNumSymbols(), bounds,
1636                         b.getContext());
1637   canonicalizeMapAndOperands(map, operands);
1638 }
1639 
1640 // Stripmines `forOp` by `factor` and sinks it under each of the `targets`.
1641 // Stripmine-sink is a primitive building block for generalized tiling of
1642 // imperfectly nested loops.
1643 // This transformation is purely mechanical and does not check legality,
1644 // profitability or even structural correctness. It is the user's
1645 // responsibility to specify `targets` that are dominated by `forOp`.
1646 // Returns the new AffineForOps, one per `targets`, nested immediately under
1647 // each of the `targets`.
1648 static SmallVector<AffineForOp, 8>
stripmineSink(AffineForOp forOp,uint64_t factor,ArrayRef<AffineForOp> targets)1649 stripmineSink(AffineForOp forOp, uint64_t factor,
1650               ArrayRef<AffineForOp> targets) {
1651   auto originalStep = forOp.getStep();
1652   auto scaledStep = originalStep * factor;
1653   forOp.setStep(scaledStep);
1654 
1655   OpBuilder b(forOp->getBlock(), std::next(Block::iterator(forOp)));
1656 
1657   // Lower-bound map creation.
1658   auto lbMap = forOp.getLowerBoundMap();
1659   SmallVector<Value, 4> lbOperands(forOp.getLowerBoundOperands());
1660   augmentMapAndBounds(b, forOp.getInductionVar(), &lbMap, &lbOperands);
1661 
1662   // Upper-bound map creation.
1663   auto ubMap = forOp.getUpperBoundMap();
1664   SmallVector<Value, 4> ubOperands(forOp.getUpperBoundOperands());
1665   augmentMapAndBounds(b, forOp.getInductionVar(), &ubMap, &ubOperands,
1666                       /*offset=*/scaledStep);
1667 
1668   auto iv = forOp.getInductionVar();
1669   SmallVector<AffineForOp, 8> innerLoops;
1670   for (auto t : targets) {
1671     // Insert newForOp before the terminator of `t`.
1672     auto b = OpBuilder::atBlockTerminator(t.getBody());
1673     auto newForOp = b.create<AffineForOp>(t.getLoc(), lbOperands, lbMap,
1674                                           ubOperands, ubMap, originalStep);
1675     auto begin = t.getBody()->begin();
1676     // Skip terminator and `newForOp` which is just before the terminator.
1677     auto nOps = t.getBody()->getOperations().size() - 2;
1678     newForOp.getBody()->getOperations().splice(
1679         newForOp.getBody()->getOperations().begin(),
1680         t.getBody()->getOperations(), begin, std::next(begin, nOps));
1681     replaceAllUsesInRegionWith(iv, newForOp.getInductionVar(),
1682                                newForOp.region());
1683     innerLoops.push_back(newForOp);
1684   }
1685 
1686   return innerLoops;
1687 }
1688 
stripmineSink(scf::ForOp forOp,Value factor,ArrayRef<scf::ForOp> targets)1689 static Loops stripmineSink(scf::ForOp forOp, Value factor,
1690                            ArrayRef<scf::ForOp> targets) {
1691   auto originalStep = forOp.step();
1692   auto iv = forOp.getInductionVar();
1693 
1694   OpBuilder b(forOp);
1695   forOp.setStep(b.create<MulIOp>(forOp.getLoc(), originalStep, factor));
1696 
1697   Loops innerLoops;
1698   for (auto t : targets) {
1699     // Save information for splicing ops out of t when done
1700     auto begin = t.getBody()->begin();
1701     auto nOps = t.getBody()->getOperations().size();
1702 
1703     // Insert newForOp before the terminator of `t`.
1704     auto b = OpBuilder::atBlockTerminator((t.getBody()));
1705     Value stepped = b.create<AddIOp>(t.getLoc(), iv, forOp.step());
1706     Value less = b.create<CmpIOp>(t.getLoc(), CmpIPredicate::slt,
1707                                   forOp.upperBound(), stepped);
1708     Value ub =
1709         b.create<SelectOp>(t.getLoc(), less, forOp.upperBound(), stepped);
1710 
1711     // Splice [begin, begin + nOps - 1) into `newForOp` and replace uses.
1712     auto newForOp = b.create<scf::ForOp>(t.getLoc(), iv, ub, originalStep);
1713     newForOp.getBody()->getOperations().splice(
1714         newForOp.getBody()->getOperations().begin(),
1715         t.getBody()->getOperations(), begin, std::next(begin, nOps - 1));
1716     replaceAllUsesInRegionWith(iv, newForOp.getInductionVar(),
1717                                newForOp.region());
1718 
1719     innerLoops.push_back(newForOp);
1720   }
1721 
1722   return innerLoops;
1723 }
1724 
1725 // Stripmines a `forOp` by `factor` and sinks it under a single `target`.
1726 // Returns the new AffineForOps, nested immediately under `target`.
1727 template <typename ForType, typename SizeType>
stripmineSink(ForType forOp,SizeType factor,ForType target)1728 static ForType stripmineSink(ForType forOp, SizeType factor, ForType target) {
1729   // TODO: Use cheap structural assertions that targets are nested under
1730   // forOp and that targets are not nested under each other when DominanceInfo
1731   // exposes the capability. It seems overkill to construct a whole function
1732   // dominance tree at this point.
1733   auto res = stripmineSink(forOp, factor, ArrayRef<ForType>{target});
1734   assert(res.size() == 1 && "Expected 1 inner forOp");
1735   return res[0];
1736 }
1737 
1738 template <typename ForType, typename SizeType>
1739 static SmallVector<SmallVector<ForType, 8>, 8>
tileImpl(ArrayRef<ForType> forOps,ArrayRef<SizeType> sizes,ArrayRef<ForType> targets)1740 tileImpl(ArrayRef<ForType> forOps, ArrayRef<SizeType> sizes,
1741          ArrayRef<ForType> targets) {
1742   SmallVector<SmallVector<ForType, 8>, 8> res;
1743   SmallVector<ForType, 8> currentTargets(targets.begin(), targets.end());
1744   for (auto it : llvm::zip(forOps, sizes)) {
1745     auto step = stripmineSink(std::get<0>(it), std::get<1>(it), currentTargets);
1746     res.push_back(step);
1747     currentTargets = step;
1748   }
1749   return res;
1750 }
1751 
1752 SmallVector<SmallVector<AffineForOp, 8>, 8>
tile(ArrayRef<AffineForOp> forOps,ArrayRef<uint64_t> sizes,ArrayRef<AffineForOp> targets)1753 mlir::tile(ArrayRef<AffineForOp> forOps, ArrayRef<uint64_t> sizes,
1754            ArrayRef<AffineForOp> targets) {
1755   return tileImpl(forOps, sizes, targets);
1756 }
1757 
tile(ArrayRef<scf::ForOp> forOps,ArrayRef<Value> sizes,ArrayRef<scf::ForOp> targets)1758 SmallVector<Loops, 8> mlir::tile(ArrayRef<scf::ForOp> forOps,
1759                                  ArrayRef<Value> sizes,
1760                                  ArrayRef<scf::ForOp> targets) {
1761   return tileImpl(forOps, sizes, targets);
1762 }
1763 
1764 template <typename ForType, typename SizeType>
1765 static SmallVector<ForType, 8>
tileImpl(ArrayRef<ForType> forOps,ArrayRef<SizeType> sizes,ForType target)1766 tileImpl(ArrayRef<ForType> forOps, ArrayRef<SizeType> sizes, ForType target) {
1767   SmallVector<ForType, 8> res;
1768   for (auto loops : tile(forOps, sizes, ArrayRef<ForType>{target})) {
1769     assert(loops.size() == 1);
1770     res.push_back(loops[0]);
1771   }
1772   return res;
1773 }
1774 
tile(ArrayRef<AffineForOp> forOps,ArrayRef<uint64_t> sizes,AffineForOp target)1775 SmallVector<AffineForOp, 8> mlir::tile(ArrayRef<AffineForOp> forOps,
1776                                        ArrayRef<uint64_t> sizes,
1777                                        AffineForOp target) {
1778   return tileImpl(forOps, sizes, target);
1779 }
1780 
tile(ArrayRef<scf::ForOp> forOps,ArrayRef<Value> sizes,scf::ForOp target)1781 Loops mlir::tile(ArrayRef<scf::ForOp> forOps, ArrayRef<Value> sizes,
1782                  scf::ForOp target) {
1783   return tileImpl(forOps, sizes, target);
1784 }
1785 
tilePerfectlyNested(scf::ForOp rootForOp,ArrayRef<Value> sizes)1786 Loops mlir::tilePerfectlyNested(scf::ForOp rootForOp, ArrayRef<Value> sizes) {
1787   // Collect perfectly nested loops.  If more size values provided than nested
1788   // loops available, truncate `sizes`.
1789   SmallVector<scf::ForOp, 4> forOps;
1790   forOps.reserve(sizes.size());
1791   getPerfectlyNestedLoopsImpl(forOps, rootForOp, sizes.size());
1792   if (forOps.size() < sizes.size())
1793     sizes = sizes.take_front(forOps.size());
1794 
1795   return ::tile(forOps, sizes, forOps.back());
1796 }
1797 
1798 // Hoist the ops within `outer` that appear before `inner`.
1799 // Such ops include the ops that have been introduced by parametric tiling.
1800 // Ops that come from triangular loops (i.e. that belong to the program slice
1801 // rooted at `outer`) and ops that have side effects cannot be hoisted.
1802 // Return failure when any op fails to hoist.
hoistOpsBetween(scf::ForOp outer,scf::ForOp inner)1803 static LogicalResult hoistOpsBetween(scf::ForOp outer, scf::ForOp inner) {
1804   SetVector<Operation *> forwardSlice;
1805   getForwardSlice(outer.getOperation(), &forwardSlice, [&inner](Operation *op) {
1806     return op != inner.getOperation();
1807   });
1808   LogicalResult status = success();
1809   SmallVector<Operation *, 8> toHoist;
1810   for (auto &op : outer.getBody()->without_terminator()) {
1811     // Stop when encountering the inner loop.
1812     if (&op == inner.getOperation())
1813       break;
1814     // Skip over non-hoistable ops.
1815     if (forwardSlice.count(&op) > 0) {
1816       status = failure();
1817       continue;
1818     }
1819     // Skip scf::ForOp, these are not considered a failure.
1820     if (op.getNumRegions() > 0)
1821       continue;
1822     // Skip other ops with regions.
1823     if (op.getNumRegions() > 0) {
1824       status = failure();
1825       continue;
1826     }
1827     // Skip if op has side effects.
1828     // TODO: loads to immutable memory regions are ok.
1829     if (!MemoryEffectOpInterface::hasNoEffect(&op)) {
1830       status = failure();
1831       continue;
1832     }
1833     toHoist.push_back(&op);
1834   }
1835   auto *outerForOp = outer.getOperation();
1836   for (auto *op : toHoist)
1837     op->moveBefore(outerForOp);
1838   return status;
1839 }
1840 
1841 // Traverse the interTile and intraTile loops and try to hoist ops such that
1842 // bands of perfectly nested loops are isolated.
1843 // Return failure if either perfect interTile or perfect intraTile bands cannot
1844 // be formed.
tryIsolateBands(const TileLoops & tileLoops)1845 static LogicalResult tryIsolateBands(const TileLoops &tileLoops) {
1846   LogicalResult status = success();
1847   const Loops &interTile = tileLoops.first;
1848   const Loops &intraTile = tileLoops.second;
1849   auto size = interTile.size();
1850   assert(size == intraTile.size());
1851   if (size <= 1)
1852     return success();
1853   for (unsigned s = 1; s < size; ++s)
1854     status = succeeded(status) ? hoistOpsBetween(intraTile[0], intraTile[s])
1855                                : failure();
1856   for (unsigned s = 1; s < size; ++s)
1857     status = succeeded(status) ? hoistOpsBetween(interTile[0], interTile[s])
1858                                : failure();
1859   return status;
1860 }
1861 
extractFixedOuterLoops(scf::ForOp rootForOp,ArrayRef<int64_t> sizes)1862 TileLoops mlir::extractFixedOuterLoops(scf::ForOp rootForOp,
1863                                        ArrayRef<int64_t> sizes) {
1864   // Collect perfectly nested loops.  If more size values provided than nested
1865   // loops available, truncate `sizes`.
1866   SmallVector<scf::ForOp, 4> forOps;
1867   forOps.reserve(sizes.size());
1868   getPerfectlyNestedLoopsImpl(forOps, rootForOp, sizes.size());
1869   if (forOps.size() < sizes.size())
1870     sizes = sizes.take_front(forOps.size());
1871 
1872   // Compute the tile sizes such that i-th outer loop executes size[i]
1873   // iterations.  Given that the loop current executes
1874   //   numIterations = ceildiv((upperBound - lowerBound), step)
1875   // iterations, we need to tile with size ceildiv(numIterations, size[i]).
1876   SmallVector<Value, 4> tileSizes;
1877   tileSizes.reserve(sizes.size());
1878   for (unsigned i = 0, e = sizes.size(); i < e; ++i) {
1879     assert(sizes[i] > 0 && "expected strictly positive size for strip-mining");
1880 
1881     auto forOp = forOps[i];
1882     OpBuilder builder(forOp);
1883     auto loc = forOp.getLoc();
1884     Value diff =
1885         builder.create<SubIOp>(loc, forOp.upperBound(), forOp.lowerBound());
1886     Value numIterations = ceilDivPositive(builder, loc, diff, forOp.step());
1887     Value iterationsPerBlock =
1888         ceilDivPositive(builder, loc, numIterations, sizes[i]);
1889     tileSizes.push_back(iterationsPerBlock);
1890   }
1891 
1892   // Call parametric tiling with the given sizes.
1893   auto intraTile = tile(forOps, tileSizes, forOps.back());
1894   TileLoops tileLoops = std::make_pair(forOps, intraTile);
1895 
1896   // TODO: for now we just ignore the result of band isolation.
1897   // In the future, mapping decisions may be impacted by the ability to
1898   // isolate perfectly nested bands.
1899   tryIsolateBands(tileLoops);
1900 
1901   return tileLoops;
1902 }
1903 
1904 /// Return the new lower bound, upper bound, and step in that order. Insert any
1905 /// additional bounds calculations before the given builder and any additional
1906 /// conversion back to the original loop induction value inside the given Block.
normalizeLoop(OpBuilder & boundsBuilder,OpBuilder & insideLoopBuilder,Location loc,Value lowerBound,Value upperBound,Value step,Value inductionVar)1907 static LoopParams normalizeLoop(OpBuilder &boundsBuilder,
1908                                 OpBuilder &insideLoopBuilder, Location loc,
1909                                 Value lowerBound, Value upperBound, Value step,
1910                                 Value inductionVar) {
1911   // Check if the loop is already known to have a constant zero lower bound or
1912   // a constant one step.
1913   bool isZeroBased = false;
1914   if (auto ubCst = lowerBound.getDefiningOp<ConstantIndexOp>())
1915     isZeroBased = ubCst.getValue() == 0;
1916 
1917   bool isStepOne = false;
1918   if (auto stepCst = step.getDefiningOp<ConstantIndexOp>())
1919     isStepOne = stepCst.getValue() == 1;
1920 
1921   // Compute the number of iterations the loop executes: ceildiv(ub - lb, step)
1922   // assuming the step is strictly positive.  Update the bounds and the step
1923   // of the loop to go from 0 to the number of iterations, if necessary.
1924   // TODO: introduce support for negative steps or emit dynamic asserts
1925   // on step positivity, whatever gets implemented first.
1926   if (isZeroBased && isStepOne)
1927     return {/*lowerBound=*/lowerBound, /*upperBound=*/upperBound,
1928             /*step=*/step};
1929 
1930   Value diff = boundsBuilder.create<SubIOp>(loc, upperBound, lowerBound);
1931   Value newUpperBound = ceilDivPositive(boundsBuilder, loc, diff, step);
1932 
1933   Value newLowerBound =
1934       isZeroBased ? lowerBound : boundsBuilder.create<ConstantIndexOp>(loc, 0);
1935   Value newStep =
1936       isStepOne ? step : boundsBuilder.create<ConstantIndexOp>(loc, 1);
1937 
1938   // Insert code computing the value of the original loop induction variable
1939   // from the "normalized" one.
1940   Value scaled =
1941       isStepOne ? inductionVar
1942                 : insideLoopBuilder.create<MulIOp>(loc, inductionVar, step);
1943   Value shifted =
1944       isZeroBased ? scaled
1945                   : insideLoopBuilder.create<AddIOp>(loc, scaled, lowerBound);
1946 
1947   SmallPtrSet<Operation *, 2> preserve{scaled.getDefiningOp(),
1948                                        shifted.getDefiningOp()};
1949   inductionVar.replaceAllUsesExcept(shifted, preserve);
1950   return {/*lowerBound=*/newLowerBound, /*upperBound=*/newUpperBound,
1951           /*step=*/newStep};
1952 }
1953 
1954 /// Transform a loop with a strictly positive step
1955 ///   for %i = %lb to %ub step %s
1956 /// into a 0-based loop with step 1
1957 ///   for %ii = 0 to ceildiv(%ub - %lb, %s) step 1 {
1958 ///     %i = %ii * %s + %lb
1959 /// Insert the induction variable remapping in the body of `inner`, which is
1960 /// expected to be either `loop` or another loop perfectly nested under `loop`.
1961 /// Insert the definition of new bounds immediate before `outer`, which is
1962 /// expected to be either `loop` or its parent in the loop nest.
normalizeLoop(scf::ForOp loop,scf::ForOp outer,scf::ForOp inner)1963 static void normalizeLoop(scf::ForOp loop, scf::ForOp outer, scf::ForOp inner) {
1964   OpBuilder builder(outer);
1965   OpBuilder innerBuilder = OpBuilder::atBlockBegin(inner.getBody());
1966   auto loopPieces =
1967       normalizeLoop(builder, innerBuilder, loop.getLoc(), loop.lowerBound(),
1968                     loop.upperBound(), loop.step(), loop.getInductionVar());
1969 
1970   loop.setLowerBound(loopPieces.lowerBound);
1971   loop.setUpperBound(loopPieces.upperBound);
1972   loop.setStep(loopPieces.step);
1973 }
1974 
coalesceLoops(MutableArrayRef<scf::ForOp> loops)1975 void mlir::coalesceLoops(MutableArrayRef<scf::ForOp> loops) {
1976   if (loops.size() < 2)
1977     return;
1978 
1979   scf::ForOp innermost = loops.back();
1980   scf::ForOp outermost = loops.front();
1981 
1982   // 1. Make sure all loops iterate from 0 to upperBound with step 1.  This
1983   // allows the following code to assume upperBound is the number of iterations.
1984   for (auto loop : loops)
1985     normalizeLoop(loop, outermost, innermost);
1986 
1987   // 2. Emit code computing the upper bound of the coalesced loop as product
1988   // of the number of iterations of all loops.
1989   OpBuilder builder(outermost);
1990   Location loc = outermost.getLoc();
1991   Value upperBound = outermost.upperBound();
1992   for (auto loop : loops.drop_front())
1993     upperBound = builder.create<MulIOp>(loc, upperBound, loop.upperBound());
1994   outermost.setUpperBound(upperBound);
1995 
1996   builder.setInsertionPointToStart(outermost.getBody());
1997 
1998   // 3. Remap induction variables.  For each original loop, the value of the
1999   // induction variable can be obtained by dividing the induction variable of
2000   // the linearized loop by the total number of iterations of the loops nested
2001   // in it modulo the number of iterations in this loop (remove the values
2002   // related to the outer loops):
2003   //   iv_i = floordiv(iv_linear, product-of-loop-ranges-until-i) mod range_i.
2004   // Compute these iteratively from the innermost loop by creating a "running
2005   // quotient" of division by the range.
2006   Value previous = outermost.getInductionVar();
2007   for (unsigned i = 0, e = loops.size(); i < e; ++i) {
2008     unsigned idx = loops.size() - i - 1;
2009     if (i != 0)
2010       previous = builder.create<SignedDivIOp>(loc, previous,
2011                                               loops[idx + 1].upperBound());
2012 
2013     Value iv = (i == e - 1) ? previous
2014                             : builder.create<SignedRemIOp>(
2015                                   loc, previous, loops[idx].upperBound());
2016     replaceAllUsesInRegionWith(loops[idx].getInductionVar(), iv,
2017                                loops.back().region());
2018   }
2019 
2020   // 4. Move the operations from the innermost just above the second-outermost
2021   // loop, delete the extra terminator and the second-outermost loop.
2022   scf::ForOp second = loops[1];
2023   innermost.getBody()->back().erase();
2024   outermost.getBody()->getOperations().splice(
2025       Block::iterator(second.getOperation()),
2026       innermost.getBody()->getOperations());
2027   second.erase();
2028 }
2029 
collapseParallelLoops(scf::ParallelOp loops,ArrayRef<std::vector<unsigned>> combinedDimensions)2030 void mlir::collapseParallelLoops(
2031     scf::ParallelOp loops, ArrayRef<std::vector<unsigned>> combinedDimensions) {
2032   OpBuilder outsideBuilder(loops);
2033   Location loc = loops.getLoc();
2034 
2035   // Normalize ParallelOp's iteration pattern.
2036   SmallVector<Value, 3> normalizedLowerBounds;
2037   SmallVector<Value, 3> normalizedSteps;
2038   SmallVector<Value, 3> normalizedUpperBounds;
2039   for (unsigned i = 0, e = loops.getNumLoops(); i < e; ++i) {
2040     OpBuilder insideLoopBuilder = OpBuilder::atBlockBegin(loops.getBody());
2041     auto resultBounds =
2042         normalizeLoop(outsideBuilder, insideLoopBuilder, loc,
2043                       loops.lowerBound()[i], loops.upperBound()[i],
2044                       loops.step()[i], loops.getBody()->getArgument(i));
2045 
2046     normalizedLowerBounds.push_back(resultBounds.lowerBound);
2047     normalizedUpperBounds.push_back(resultBounds.upperBound);
2048     normalizedSteps.push_back(resultBounds.step);
2049   }
2050 
2051   // Combine iteration spaces.
2052   SmallVector<Value, 3> lowerBounds;
2053   SmallVector<Value, 3> steps;
2054   SmallVector<Value, 3> upperBounds;
2055   auto cst0 = outsideBuilder.create<ConstantIndexOp>(loc, 0);
2056   auto cst1 = outsideBuilder.create<ConstantIndexOp>(loc, 1);
2057   for (unsigned i = 0, e = combinedDimensions.size(); i < e; ++i) {
2058     Value newUpperBound = outsideBuilder.create<ConstantIndexOp>(loc, 1);
2059     for (auto idx : combinedDimensions[i]) {
2060       newUpperBound = outsideBuilder.create<MulIOp>(loc, newUpperBound,
2061                                                     normalizedUpperBounds[idx]);
2062     }
2063     lowerBounds.push_back(cst0);
2064     steps.push_back(cst1);
2065     upperBounds.push_back(newUpperBound);
2066   }
2067 
2068   // Create new ParallelLoop with conversions to the original induction values.
2069   // The loop below uses divisions to get the relevant range of values in the
2070   // new induction value that represent each range of the original induction
2071   // value. The remainders then determine based on that range, which iteration
2072   // of the original induction value this represents. This is a normalized value
2073   // that is un-normalized already by the previous logic.
2074   auto newPloop = outsideBuilder.create<scf::ParallelOp>(
2075       loc, lowerBounds, upperBounds, steps,
2076       [&](OpBuilder &insideBuilder, Location, ValueRange ploopIVs) {
2077         for (unsigned i = 0, e = combinedDimensions.size(); i < e; ++i) {
2078           Value previous = ploopIVs[i];
2079           unsigned numberCombinedDimensions = combinedDimensions[i].size();
2080           // Iterate over all except the last induction value.
2081           for (unsigned j = 0, e = numberCombinedDimensions - 1; j < e; ++j) {
2082             unsigned idx = combinedDimensions[i][j];
2083 
2084             // Determine the current induction value's current loop iteration
2085             Value iv = insideBuilder.create<SignedRemIOp>(
2086                 loc, previous, normalizedUpperBounds[idx]);
2087             replaceAllUsesInRegionWith(loops.getBody()->getArgument(idx), iv,
2088                                        loops.region());
2089 
2090             // Remove the effect of the current induction value to prepare for
2091             // the next value.
2092             previous = insideBuilder.create<SignedDivIOp>(
2093                 loc, previous, normalizedUpperBounds[idx]);
2094           }
2095 
2096           // The final induction value is just the remaining value.
2097           unsigned idx = combinedDimensions[i][numberCombinedDimensions - 1];
2098           replaceAllUsesInRegionWith(loops.getBody()->getArgument(idx),
2099                                      previous, loops.region());
2100         }
2101       });
2102 
2103   // Replace the old loop with the new loop.
2104   loops.getBody()->back().erase();
2105   newPloop.getBody()->getOperations().splice(
2106       Block::iterator(newPloop.getBody()->back()),
2107       loops.getBody()->getOperations());
2108   loops.erase();
2109 }
2110 
mapLoopToProcessorIds(scf::ForOp forOp,ArrayRef<Value> processorId,ArrayRef<Value> numProcessors)2111 void mlir::mapLoopToProcessorIds(scf::ForOp forOp, ArrayRef<Value> processorId,
2112                                  ArrayRef<Value> numProcessors) {
2113   assert(processorId.size() == numProcessors.size());
2114   if (processorId.empty())
2115     return;
2116 
2117   OpBuilder b(forOp);
2118   Location loc(forOp.getLoc());
2119   Value mul = processorId.front();
2120   for (unsigned i = 1, e = processorId.size(); i < e; ++i)
2121     mul = b.create<AddIOp>(loc, b.create<MulIOp>(loc, mul, numProcessors[i]),
2122                            processorId[i]);
2123   Value lb = b.create<AddIOp>(loc, forOp.lowerBound(),
2124                               b.create<MulIOp>(loc, forOp.step(), mul));
2125   forOp.setLowerBound(lb);
2126 
2127   Value step = forOp.step();
2128   for (auto numProcs : numProcessors)
2129     step = b.create<MulIOp>(loc, step, numProcs);
2130   forOp.setStep(step);
2131 }
2132 
2133 /// Given a memref region, determine the lowest depth at which transfers can be
2134 /// placed for it, and return the corresponding block, start and end positions
2135 /// in the block for placing incoming (read) and outgoing (write) copies
2136 /// respectively. The lowest depth depends on whether the region being accessed
2137 /// is hoistable with respect to one or more immediately surrounding loops.
2138 static void
findHighestBlockForPlacement(const MemRefRegion & region,Block & block,Block::iterator & begin,Block::iterator & end,Block ** copyPlacementBlock,Block::iterator * copyInPlacementStart,Block::iterator * copyOutPlacementStart)2139 findHighestBlockForPlacement(const MemRefRegion &region, Block &block,
2140                              Block::iterator &begin, Block::iterator &end,
2141                              Block **copyPlacementBlock,
2142                              Block::iterator *copyInPlacementStart,
2143                              Block::iterator *copyOutPlacementStart) {
2144   const auto *cst = region.getConstraints();
2145   SmallVector<Value, 4> symbols;
2146   cst->getIdValues(cst->getNumDimIds(), cst->getNumDimAndSymbolIds(), &symbols);
2147 
2148   SmallVector<AffineForOp, 4> enclosingFors;
2149   getLoopIVs(*block.begin(), &enclosingFors);
2150   // Walk up loop parents till we find an IV on which this region is
2151   // symbolic/variant.
2152   auto it = enclosingFors.rbegin();
2153   for (auto e = enclosingFors.rend(); it != e; ++it) {
2154     // TODO: also need to be checking this for regions symbols that
2155     // aren't loop IVs, whether we are within their resp. defs' dominance scope.
2156     if (llvm::is_contained(symbols, it->getInductionVar()))
2157       break;
2158   }
2159 
2160   if (it != enclosingFors.rbegin()) {
2161     auto lastInvariantIV = *std::prev(it);
2162     *copyInPlacementStart = Block::iterator(lastInvariantIV.getOperation());
2163     *copyOutPlacementStart = std::next(*copyInPlacementStart);
2164     *copyPlacementBlock = lastInvariantIV->getBlock();
2165   } else {
2166     *copyInPlacementStart = begin;
2167     *copyOutPlacementStart = end;
2168     *copyPlacementBlock = &block;
2169   }
2170 }
2171 
namespace {
/// Info comprising stride and number of elements transferred every stride.
/// Kept in an anonymous namespace so this file-local type does not get
/// external linkage, which avoids ODR clashes with other translation units.
struct StrideInfo {
  /// Distance, in elements of the original memref, between consecutive
  /// strided chunks.
  int64_t stride;
  /// Number of contiguous elements transferred at every stride.
  int64_t numEltPerStride;
};
} // namespace
2177 
2178 /// Returns striding information for a copy/transfer of this region with
2179 /// potentially multiple striding levels from outermost to innermost. For an
2180 /// n-dimensional region, there can be at most n-1 levels of striding
2181 /// successively nested.
2182 //  TODO: make this work with non-identity layout maps.
getMultiLevelStrides(const MemRefRegion & region,ArrayRef<int64_t> bufferShape,SmallVectorImpl<StrideInfo> * strideInfos)2183 static void getMultiLevelStrides(const MemRefRegion &region,
2184                                  ArrayRef<int64_t> bufferShape,
2185                                  SmallVectorImpl<StrideInfo> *strideInfos) {
2186   if (bufferShape.size() <= 1)
2187     return;
2188 
2189   int64_t numEltPerStride = 1;
2190   int64_t stride = 1;
2191   for (int d = bufferShape.size() - 1; d >= 1; d--) {
2192     int64_t dimSize = region.memref.getType().cast<MemRefType>().getDimSize(d);
2193     stride *= dimSize;
2194     numEltPerStride *= bufferShape[d];
2195     // A stride is needed only if the region has a shorter extent than the
2196     // memref along the dimension *and* has an extent greater than one along the
2197     // next major dimension.
2198     if (bufferShape[d] < dimSize && bufferShape[d - 1] > 1) {
2199       strideInfos->push_back({stride, numEltPerStride});
2200     }
2201   }
2202 }
2203 
2204 /// Generates a point-wise copy from/to `memref' to/from `fastMemRef' and
2205 /// returns the outermost AffineForOp of the copy loop nest. `lbMaps` and
2206 /// `ubMaps` along with `lbOperands` and `ubOperands` hold the lower and upper
2207 /// bound information for the copy loop nest. `fastBufOffsets` contain the
2208 /// expressions to be subtracted out from the respective copy loop iterators in
2209 /// order to index the fast buffer. If `copyOut' is true, generates a copy-out;
2210 /// otherwise a copy-in. Builder `b` should be set to the point the copy nest is
2211 /// inserted.
2212 //
2213 /// The copy-in nest is generated as follows as an example for a 2-d region:
2214 /// for x = ...
2215 ///   for y = ...
2216 ///     fast_buf[x - offset_x][y - offset_y] = memref[x][y]
2217 ///
2218 static AffineForOp
generatePointWiseCopy(Location loc,Value memref,Value fastMemRef,ArrayRef<AffineMap> lbMaps,ArrayRef<Value> lbOperands,ArrayRef<AffineMap> ubMaps,ArrayRef<Value> ubOperands,ArrayRef<AffineExpr> fastBufOffsets,bool isCopyOut,OpBuilder b)2219 generatePointWiseCopy(Location loc, Value memref, Value fastMemRef,
2220                       ArrayRef<AffineMap> lbMaps, ArrayRef<Value> lbOperands,
2221                       ArrayRef<AffineMap> ubMaps, ArrayRef<Value> ubOperands,
2222                       ArrayRef<AffineExpr> fastBufOffsets, bool isCopyOut,
2223                       OpBuilder b) {
2224   assert(llvm::all_of(lbMaps, [&](AffineMap lbMap) {
2225     return lbMap.getNumInputs() == lbOperands.size();
2226   }));
2227   assert(llvm::all_of(ubMaps, [&](AffineMap ubMap) {
2228     return ubMap.getNumInputs() == ubOperands.size();
2229   }));
2230 
2231   unsigned rank = memref.getType().cast<MemRefType>().getRank();
2232   assert(lbMaps.size() == rank && "wrong number of lb maps");
2233   assert(ubMaps.size() == rank && "wrong number of ub maps");
2234 
2235   SmallVector<Value, 4> memIndices;
2236   SmallVector<AffineExpr, 4> fastBufExprs;
2237   SmallVector<Value, 4> fastBufMapOperands;
2238   AffineForOp copyNestRoot;
2239   SmallVector<AffineApplyOp, 4> mayBeDeadApplys;
2240   for (unsigned d = 0; d < rank; ++d) {
2241     auto forOp = createCanonicalizedAffineForOp(b, loc, lbOperands, lbMaps[d],
2242                                                 ubOperands, ubMaps[d]);
2243     if (d == 0)
2244       copyNestRoot = forOp;
2245 
2246     b = OpBuilder::atBlockTerminator(forOp.getBody());
2247 
2248     auto fastBufOffsetMap =
2249         AffineMap::get(lbOperands.size(), 0, fastBufOffsets[d]);
2250     auto offset = b.create<AffineApplyOp>(loc, fastBufOffsetMap, lbOperands);
2251 
2252     // Construct the subscript for the fast memref being copied into/from:
2253     // x - offset_x.
2254     fastBufExprs.push_back(b.getAffineDimExpr(2 * d + 1) -
2255                            b.getAffineDimExpr(2 * d));
2256     fastBufMapOperands.push_back(offset);
2257     fastBufMapOperands.push_back(forOp.getInductionVar());
2258     mayBeDeadApplys.push_back(offset);
2259 
2260     // Subscript for the slow memref being copied.
2261     memIndices.push_back(forOp.getInductionVar());
2262   }
2263 
2264   auto fastBufMap =
2265       AffineMap::get(2 * rank, /*symbolCount=*/0, fastBufExprs, b.getContext());
2266   fullyComposeAffineMapAndOperands(&fastBufMap, &fastBufMapOperands);
2267   fastBufMap = simplifyAffineMap(fastBufMap);
2268   canonicalizeMapAndOperands(&fastBufMap, &fastBufMapOperands);
2269 
2270   // Drop any dead affine.applys.
2271   for (auto applyOp : mayBeDeadApplys)
2272     if (applyOp.use_empty())
2273       applyOp.erase();
2274 
2275   if (!isCopyOut) {
2276     // Copy in.
2277     auto load = b.create<AffineLoadOp>(loc, memref, memIndices);
2278     b.create<AffineStoreOp>(loc, load, fastMemRef, fastBufMap,
2279                             fastBufMapOperands);
2280     return copyNestRoot;
2281   }
2282 
2283   // Copy out.
2284   auto load =
2285       b.create<AffineLoadOp>(loc, fastMemRef, fastBufMap, fastBufMapOperands);
2286   b.create<AffineStoreOp>(loc, load, memref, memIndices);
2287   return copyNestRoot;
2288 }
2289 
2290 static InFlightDiagnostic LLVM_ATTRIBUTE_UNUSED
emitRemarkForBlock(Block & block)2291 emitRemarkForBlock(Block &block) {
2292   return block.getParentOp()->emitRemark();
2293 }
2294 
2295 /// Creates a buffer in the faster memory space for the specified memref region;
2296 /// generates a copy from the lower memory space to this one, and replaces all
2297 /// loads/stores in the block range [`begin', `end') of `block' to load/store
2298 /// from that buffer. Returns failure if copies could not be generated due to
2299 /// yet unimplemented cases. `copyInPlacementStart` and `copyOutPlacementStart`
2300 /// in copyPlacementBlock specify the insertion points where the incoming copies
2301 /// and outgoing copies, respectively, should be inserted (the insertion happens
2302 /// right before the insertion point). Since `begin` can itself be invalidated
2303 /// due to the memref rewriting done from this method, the output argument
2304 /// `nBegin` is set to its replacement (set to `begin` if no invalidation
2305 /// happens). Since outgoing copies could have  been inserted at `end`, the
2306 /// output argument `nEnd` is set to the new end. `sizeInBytes` is set to the
2307 /// size of the fast buffer allocated.
generateCopy(const MemRefRegion & region,Block * block,Block::iterator begin,Block::iterator end,Block * copyPlacementBlock,Block::iterator copyInPlacementStart,Block::iterator copyOutPlacementStart,AffineCopyOptions copyOptions,DenseMap<Value,Value> & fastBufferMap,DenseSet<Operation * > & copyNests,uint64_t * sizeInBytes,Block::iterator * nBegin,Block::iterator * nEnd)2308 static LogicalResult generateCopy(
2309     const MemRefRegion &region, Block *block, Block::iterator begin,
2310     Block::iterator end, Block *copyPlacementBlock,
2311     Block::iterator copyInPlacementStart, Block::iterator copyOutPlacementStart,
2312     AffineCopyOptions copyOptions, DenseMap<Value, Value> &fastBufferMap,
2313     DenseSet<Operation *> &copyNests, uint64_t *sizeInBytes,
2314     Block::iterator *nBegin, Block::iterator *nEnd) {
2315   *nBegin = begin;
2316   *nEnd = end;
2317 
2318   FuncOp f = begin->getParentOfType<FuncOp>();
2319   OpBuilder topBuilder(f.getBody());
2320   Value zeroIndex = topBuilder.create<ConstantIndexOp>(f.getLoc(), 0);
2321 
2322   if (begin == end)
2323     return success();
2324 
2325   // Is the copy out point at the end of the block where we are doing
2326   // explicit copying.
2327   bool isCopyOutAtEndOfBlock = (end == copyOutPlacementStart);
2328 
2329   // Copies for read regions are going to be inserted at 'begin'.
2330   OpBuilder prologue(copyPlacementBlock, copyInPlacementStart);
2331   // Copies for write regions are going to be inserted at 'end'.
2332   OpBuilder epilogue(copyPlacementBlock, copyOutPlacementStart);
2333   OpBuilder &b = region.isWrite() ? epilogue : prologue;
2334 
2335   // Builder to create constants at the top level.
2336   auto func = copyPlacementBlock->getParent()->getParentOfType<FuncOp>();
2337   OpBuilder top(func.getBody());
2338 
2339   auto loc = region.loc;
2340   auto memref = region.memref;
2341   auto memRefType = memref.getType().cast<MemRefType>();
2342 
2343   auto layoutMaps = memRefType.getAffineMaps();
2344   if (layoutMaps.size() > 1 ||
2345       (layoutMaps.size() == 1 && !layoutMaps[0].isIdentity())) {
2346     LLVM_DEBUG(llvm::dbgs() << "Non-identity layout map not yet supported\n");
2347     return failure();
2348   }
2349 
2350   // Indices to use for the copying.
2351   // Indices for the original memref being copied from/to.
2352   SmallVector<Value, 4> memIndices;
2353   // Indices for the faster buffer being copied into/from.
2354   SmallVector<Value, 4> bufIndices;
2355 
2356   unsigned rank = memRefType.getRank();
2357   SmallVector<int64_t, 4> fastBufferShape;
2358 
2359   // Compute the extents of the buffer.
2360   std::vector<SmallVector<int64_t, 4>> lbs;
2361   SmallVector<int64_t, 8> lbDivisors;
2362   lbs.reserve(rank);
2363   Optional<int64_t> numElements = region.getConstantBoundingSizeAndShape(
2364       &fastBufferShape, &lbs, &lbDivisors);
2365   if (!numElements.hasValue()) {
2366     LLVM_DEBUG(llvm::dbgs() << "Non-constant region size not supported\n");
2367     return failure();
2368   }
2369 
2370   if (numElements.getValue() == 0) {
2371     LLVM_DEBUG(llvm::dbgs() << "Nothing to copy\n");
2372     *sizeInBytes = 0;
2373     return success();
2374   }
2375 
2376   SmallVector<AffineMap, 4> lbMaps(rank), ubMaps(rank);
2377   for (unsigned i = 0; i < rank; ++i)
2378     region.getLowerAndUpperBound(i, lbMaps[i], ubMaps[i]);
2379 
2380   const FlatAffineConstraints *cst = region.getConstraints();
2381   // 'regionSymbols' hold values that this memory region is symbolic/parametric
2382   // on; these typically include loop IVs surrounding the level at which the
2383   // copy generation is being done or other valid symbols in MLIR.
2384   SmallVector<Value, 8> regionSymbols;
2385   cst->getIdValues(rank, cst->getNumIds(), &regionSymbols);
2386 
2387   // Construct the index expressions for the fast memory buffer. The index
2388   // expression for a particular dimension of the fast buffer is obtained by
2389   // subtracting out the lower bound on the original memref's data region
2390   // along the corresponding dimension.
2391 
2392   // Index start offsets for faster memory buffer relative to the original.
2393   SmallVector<AffineExpr, 4> fastBufOffsets;
2394   fastBufOffsets.reserve(rank);
2395   for (unsigned d = 0; d < rank; d++) {
2396     assert(lbs[d].size() == cst->getNumCols() - rank && "incorrect bound size");
2397 
2398     AffineExpr offset = top.getAffineConstantExpr(0);
2399     for (unsigned j = 0, e = cst->getNumCols() - rank - 1; j < e; j++)
2400       offset = offset + lbs[d][j] * top.getAffineDimExpr(j);
2401     assert(lbDivisors[d] > 0);
2402     offset =
2403         (offset + lbs[d][cst->getNumCols() - 1 - rank]).floorDiv(lbDivisors[d]);
2404 
2405     // Set copy start location for this dimension in the lower memory space
2406     // memref.
2407     if (auto caf = offset.dyn_cast<AffineConstantExpr>()) {
2408       auto indexVal = caf.getValue();
2409       if (indexVal == 0) {
2410         memIndices.push_back(zeroIndex);
2411       } else {
2412         memIndices.push_back(
2413             top.create<ConstantIndexOp>(loc, indexVal).getResult());
2414       }
2415     } else {
2416       // The coordinate for the start location is just the lower bound along the
2417       // corresponding dimension on the memory region (stored in 'offset').
2418       auto map = AffineMap::get(
2419           cst->getNumDimIds() + cst->getNumSymbolIds() - rank, 0, offset);
2420       memIndices.push_back(b.create<AffineApplyOp>(loc, map, regionSymbols));
2421     }
2422     // The fast buffer is copied into at location zero; addressing is relative.
2423     bufIndices.push_back(zeroIndex);
2424 
2425     // Record the offsets since they are needed to remap the memory accesses of
2426     // the original memref further below.
2427     fastBufOffsets.push_back(offset);
2428   }
2429 
2430   // The faster memory space buffer.
2431   Value fastMemRef;
2432 
2433   // Check if a buffer was already created.
2434   bool existingBuf = fastBufferMap.count(memref) > 0;
2435   if (!existingBuf) {
2436     AffineMap fastBufferLayout = b.getMultiDimIdentityMap(rank);
2437     auto fastMemRefType =
2438         MemRefType::get(fastBufferShape, memRefType.getElementType(),
2439                         fastBufferLayout, copyOptions.fastMemorySpace);
2440 
2441     // Create the fast memory space buffer just before the 'affine.for'
2442     // operation.
2443     fastMemRef = prologue.create<AllocOp>(loc, fastMemRefType).getResult();
2444     // Record it.
2445     fastBufferMap[memref] = fastMemRef;
2446     // fastMemRefType is a constant shaped memref.
2447     *sizeInBytes = getMemRefSizeInBytes(fastMemRefType).getValue();
2448     LLVM_DEBUG(emitRemarkForBlock(*block)
2449                << "Creating fast buffer of type " << fastMemRefType
2450                << " and size " << llvm::divideCeil(*sizeInBytes, 1024)
2451                << " KiB\n");
2452   } else {
2453     // Reuse the one already created.
2454     fastMemRef = fastBufferMap[memref];
2455     *sizeInBytes = 0;
2456   }
2457 
2458   auto numElementsSSA =
2459       top.create<ConstantIndexOp>(loc, numElements.getValue());
2460 
2461   Value dmaStride = nullptr;
2462   Value numEltPerDmaStride = nullptr;
2463   if (copyOptions.generateDma) {
2464     SmallVector<StrideInfo, 4> dmaStrideInfos;
2465     getMultiLevelStrides(region, fastBufferShape, &dmaStrideInfos);
2466 
2467     // TODO: use all stride levels once DmaStartOp is extended for
2468     // multi-level strides.
2469     if (dmaStrideInfos.size() > 1) {
2470       LLVM_DEBUG(llvm::dbgs() << "Only up to one level of stride supported\n");
2471       return failure();
2472     }
2473 
2474     if (!dmaStrideInfos.empty()) {
2475       dmaStride = top.create<ConstantIndexOp>(loc, dmaStrideInfos[0].stride);
2476       numEltPerDmaStride =
2477           top.create<ConstantIndexOp>(loc, dmaStrideInfos[0].numEltPerStride);
2478     }
2479   }
2480 
2481   // Record the last operation where we want the memref replacement to end. We
2482   // later do the memref replacement only in [begin, postDomFilter] so
2483   // that the original memref's used in the data movement code themselves don't
2484   // get replaced.
2485   auto postDomFilter = std::prev(end);
2486 
2487   // Create fully composed affine maps for each memref.
2488   auto memAffineMap = b.getMultiDimIdentityMap(memIndices.size());
2489   fullyComposeAffineMapAndOperands(&memAffineMap, &memIndices);
2490   auto bufAffineMap = b.getMultiDimIdentityMap(bufIndices.size());
2491   fullyComposeAffineMapAndOperands(&bufAffineMap, &bufIndices);
2492 
2493   if (!copyOptions.generateDma) {
2494     // Point-wise copy generation.
2495     auto copyNest =
2496         generatePointWiseCopy(loc, memref, fastMemRef, lbMaps,
2497                               /*lbOperands=*/regionSymbols, ubMaps,
2498                               /*ubOperands=*/regionSymbols, fastBufOffsets,
2499                               /*isCopyOut=*/region.isWrite(), b);
2500 
2501     // Record this so that we can skip it from yet another copy.
2502     copyNests.insert(copyNest);
2503 
2504     // Since new ops are being appended (for copy out's), adjust the end to
2505     // mark end of block range being processed if necessary.
2506     if (region.isWrite() && isCopyOutAtEndOfBlock)
2507       *nEnd = Block::iterator(copyNest.getOperation());
2508   } else {
2509     // DMA generation.
2510     // Create a tag (single element 1-d memref) for the DMA.
2511     auto tagMemRefType = MemRefType::get({1}, top.getIntegerType(32), {},
2512                                          copyOptions.tagMemorySpace);
2513     auto tagMemRef = prologue.create<AllocOp>(loc, tagMemRefType);
2514 
2515     SmallVector<Value, 4> tagIndices({zeroIndex});
2516     auto tagAffineMap = b.getMultiDimIdentityMap(tagIndices.size());
2517     fullyComposeAffineMapAndOperands(&tagAffineMap, &tagIndices);
2518     if (!region.isWrite()) {
2519       // DMA non-blocking read from original buffer to fast buffer.
2520       b.create<AffineDmaStartOp>(loc, memref, memAffineMap, memIndices,
2521                                  fastMemRef, bufAffineMap, bufIndices,
2522                                  tagMemRef, tagAffineMap, tagIndices,
2523                                  numElementsSSA, dmaStride, numEltPerDmaStride);
2524     } else {
2525       // DMA non-blocking write from fast buffer to the original memref.
2526       auto op = b.create<AffineDmaStartOp>(
2527           loc, fastMemRef, bufAffineMap, bufIndices, memref, memAffineMap,
2528           memIndices, tagMemRef, tagAffineMap, tagIndices, numElementsSSA,
2529           dmaStride, numEltPerDmaStride);
2530       // Since new ops may be appended at 'end' (for outgoing DMAs), adjust the
2531       // end to mark end of block range being processed.
2532       if (isCopyOutAtEndOfBlock)
2533         *nEnd = Block::iterator(op.getOperation());
2534     }
2535 
2536     // Matching DMA wait to block on completion; tag always has a 0 index.
2537     b.create<AffineDmaWaitOp>(loc, tagMemRef, tagAffineMap, zeroIndex,
2538                               numElementsSSA);
2539 
2540     // Generate dealloc for the tag.
2541     auto tagDeallocOp = epilogue.create<DeallocOp>(loc, tagMemRef);
2542     if (*nEnd == end && isCopyOutAtEndOfBlock)
2543       // Since new ops are being appended (for outgoing DMAs), adjust the end to
2544       // mark end of range of the original.
2545       *nEnd = Block::iterator(tagDeallocOp.getOperation());
2546   }
2547 
2548   // Generate dealloc for the buffer.
2549   if (!existingBuf) {
2550     auto bufDeallocOp = epilogue.create<DeallocOp>(loc, fastMemRef);
2551     // When generating pointwise copies, `nEnd' has to be set to deallocOp on
2552     // the fast buffer (since it marks the new end insertion point).
2553     if (!copyOptions.generateDma && *nEnd == end && isCopyOutAtEndOfBlock)
2554       *nEnd = Block::iterator(bufDeallocOp.getOperation());
2555   }
2556 
2557   // Replace all uses of the old memref with the faster one while remapping
2558   // access indices (subtracting out lower bound offsets for each dimension).
2559   // Ex: to replace load %A[%i, %j] with load %Abuf[%i - %iT, %j - %jT],
2560   // index remap will be (%i, %j) -> (%i - %iT, %j - %jT),
2561   // i.e., affine.apply (d0, d1, d2, d3) -> (d2-d0, d3-d1) (%iT, %jT, %i, %j),
2562   // and (%iT, %jT) will be the 'extraOperands' for 'rep all memref uses with'.
2563   // d2, d3 correspond to the original indices (%i, %j).
2564   SmallVector<AffineExpr, 4> remapExprs;
2565   remapExprs.reserve(rank);
2566   for (unsigned i = 0; i < rank; i++) {
2567     // The starting operands of indexRemap will be regionSymbols (the symbols on
2568     // which the memref region is parametric); then those corresponding to
2569     // the memref's original indices follow.
2570     auto dimExpr = b.getAffineDimExpr(regionSymbols.size() + i);
2571     remapExprs.push_back(dimExpr - fastBufOffsets[i]);
2572   }
2573   auto indexRemap = AffineMap::get(regionSymbols.size() + rank, 0, remapExprs,
2574                                    b.getContext());
2575 
2576   // Record the begin since it may be invalidated by memref replacement.
2577   Block::iterator prevOfBegin;
2578   bool isBeginAtStartOfBlock = (begin == block->begin());
2579   if (!isBeginAtStartOfBlock)
2580     prevOfBegin = std::prev(begin);
2581 
2582   // *Only* those uses within the range [begin, end) of 'block' are replaced.
2583   replaceAllMemRefUsesWith(memref, fastMemRef,
2584                            /*extraIndices=*/{}, indexRemap,
2585                            /*extraOperands=*/regionSymbols,
2586                            /*symbolOperands=*/{},
2587                            /*domInstFilter=*/&*begin,
2588                            /*postDomInstFilter=*/&*postDomFilter);
2589 
2590   *nBegin = isBeginAtStartOfBlock ? block->begin() : std::next(prevOfBegin);
2591 
2592   return success();
2593 }
2594 
2595 /// Construct the memref region to just include the entire memref. Returns false
2596 /// dynamic shaped memref's for now. `numParamLoopIVs` is the number of
2597 /// enclosing loop IVs of `op` (starting from the outermost) that the region
2598 /// is parametric on.
getFullMemRefAsRegion(Operation * op,unsigned numParamLoopIVs,MemRefRegion * region)2599 static bool getFullMemRefAsRegion(Operation *op, unsigned numParamLoopIVs,
2600                                   MemRefRegion *region) {
2601   unsigned rank;
2602   if (auto loadOp = dyn_cast<AffineLoadOp>(op)) {
2603     rank = loadOp.getMemRefType().getRank();
2604     region->memref = loadOp.getMemRef();
2605     region->setWrite(false);
2606   } else if (auto storeOp = dyn_cast<AffineStoreOp>(op)) {
2607     rank = storeOp.getMemRefType().getRank();
2608     region->memref = storeOp.getMemRef();
2609     region->setWrite(true);
2610   } else {
2611     assert(false && "expected load or store op");
2612     return false;
2613   }
2614   auto memRefType = region->memref.getType().cast<MemRefType>();
2615   if (!memRefType.hasStaticShape())
2616     return false;
2617 
2618   auto *regionCst = region->getConstraints();
2619 
2620   // Just get the first numSymbols IVs, which the memref region is parametric
2621   // on.
2622   SmallVector<AffineForOp, 4> ivs;
2623   getLoopIVs(*op, &ivs);
2624   ivs.resize(numParamLoopIVs);
2625   SmallVector<Value, 4> symbols;
2626   extractForInductionVars(ivs, &symbols);
2627   regionCst->reset(rank, numParamLoopIVs, 0);
2628   regionCst->setIdValues(rank, rank + numParamLoopIVs, symbols);
2629 
2630   // Memref dim sizes provide the bounds.
2631   for (unsigned d = 0; d < rank; d++) {
2632     auto dimSize = memRefType.getDimSize(d);
2633     assert(dimSize > 0 && "filtered dynamic shapes above");
2634     regionCst->addConstantLowerBound(d, 0);
2635     regionCst->addConstantUpperBound(d, dimSize - 1);
2636   }
2637   return true;
2638 }
2639 
2640 /// Performs explicit copying for the contiguous sequence of operations in the
2641 /// block iterator range [`begin', `end'), where `end' can't be past the
2642 /// terminator of the block (since additional operations are potentially
2643 /// inserted right before `end`. Returns the total size of fast memory space
2644 /// buffers used. `copyOptions` provides various parameters, and the output
2645 /// argument `copyNests` is the set of all copy nests inserted, each represented
2646 /// by its root affine.for. Since we generate alloc's and dealloc's for all fast
2647 /// buffers (before and after the range of operations resp. or at a hoisted
2648 /// position), all of the fast memory capacity is assumed to be available for
2649 /// processing this block range. When 'filterMemRef' is specified, copies are
2650 /// only generated for the provided MemRef.
affineDataCopyGenerate(Block::iterator begin,Block::iterator end,const AffineCopyOptions & copyOptions,Optional<Value> filterMemRef,DenseSet<Operation * > & copyNests)2651 uint64_t mlir::affineDataCopyGenerate(Block::iterator begin,
2652                                       Block::iterator end,
2653                                       const AffineCopyOptions &copyOptions,
2654                                       Optional<Value> filterMemRef,
2655                                       DenseSet<Operation *> &copyNests) {
2656   if (begin == end)
2657     return 0;
2658 
2659   assert(begin->getBlock() == std::prev(end)->getBlock() &&
2660          "Inconsistent block begin/end args");
2661   assert(end != end->getBlock()->end() && "end can't be the block terminator");
2662 
2663   Block *block = begin->getBlock();
2664 
2665   // Copies will be generated for this depth, i.e., symbolic in all loops
2666   // surrounding the this block range.
2667   unsigned copyDepth = getNestingDepth(&*begin);
2668 
2669   LLVM_DEBUG(llvm::dbgs() << "Generating copies at depth " << copyDepth
2670                           << "\n");
2671   LLVM_DEBUG(llvm::dbgs() << "from begin: " << *begin << "\n");
2672   LLVM_DEBUG(llvm::dbgs() << "to inclusive end: " << *std::prev(end) << "\n");
2673 
2674   // List of memory regions to copy for. We need a map vector to have a
2675   // guaranteed iteration order to write test cases. CHECK-DAG doesn't help here
2676   // since the alloc's for example are identical except for the SSA id.
2677   SmallMapVector<Value, std::unique_ptr<MemRefRegion>, 4> readRegions;
2678   SmallMapVector<Value, std::unique_ptr<MemRefRegion>, 4> writeRegions;
2679 
2680   // Map from original memref's to the fast buffers that their accesses are
2681   // replaced with.
2682   DenseMap<Value, Value> fastBufferMap;
2683 
2684   // To check for errors when walking the block.
2685   bool error = false;
2686 
2687   // Walk this range of operations  to gather all memory regions.
2688   block->walk(begin, end, [&](Operation *opInst) {
2689     // Gather regions to allocate to buffers in faster memory space.
2690     if (auto loadOp = dyn_cast<AffineLoadOp>(opInst)) {
2691       if ((filterMemRef.hasValue() && filterMemRef != loadOp.getMemRef()) ||
2692           (loadOp.getMemRefType().getMemorySpace() !=
2693            copyOptions.slowMemorySpace))
2694         return;
2695     } else if (auto storeOp = dyn_cast<AffineStoreOp>(opInst)) {
2696       if ((filterMemRef.hasValue() && filterMemRef != storeOp.getMemRef()) ||
2697           storeOp.getMemRefType().getMemorySpace() !=
2698               copyOptions.slowMemorySpace)
2699         return;
2700     } else {
2701       // Neither load nor a store op.
2702       return;
2703     }
2704 
2705     // Compute the MemRefRegion accessed.
2706     auto region = std::make_unique<MemRefRegion>(opInst->getLoc());
2707     if (failed(region->compute(opInst, copyDepth, /*sliceState=*/nullptr,
2708                                /*addMemRefDimBounds=*/false))) {
2709       LLVM_DEBUG(llvm::dbgs()
2710                  << "Error obtaining memory region: semi-affine maps?\n");
2711       LLVM_DEBUG(llvm::dbgs() << "over-approximating to the entire memref\n");
2712       if (!getFullMemRefAsRegion(opInst, copyDepth, region.get())) {
2713         LLVM_DEBUG(
2714             opInst->emitError("non-constant memref sizes not yet supported"));
2715         error = true;
2716         return;
2717       }
2718     }
2719 
2720     // Each memref has a single buffer associated with it irrespective of how
2721     // many load's and store's happen on it.
2722     // TODO: in the future, when regions don't intersect and satisfy
2723     // other properties (based on load/store regions), we could consider
2724     // multiple buffers per memref.
2725 
2726     // Add to the appropriate region if it's not already in it, or take a
2727     // bounding box union with the existing one if it's already in there.
2728     // Note that a memref may have both read and write regions - so update the
2729     // region in the other list if one exists (write in case of read and vice
2730     // versa) since there is a single bounding box for a memref across all reads
2731     // and writes that happen on it.
2732 
2733     // Attempts to update; returns true if 'region' exists in targetRegions.
2734     auto updateRegion =
2735         [&](const SmallMapVector<Value, std::unique_ptr<MemRefRegion>, 4>
2736                 &targetRegions) {
2737           const auto it = targetRegions.find(region->memref);
2738           if (it == targetRegions.end())
2739             return false;
2740 
2741           // Perform a union with the existing region.
2742           if (failed(it->second->unionBoundingBox(*region))) {
2743             LLVM_DEBUG(llvm::dbgs()
2744                        << "Memory region bounding box failed; "
2745                           "over-approximating to the entire memref\n");
2746             // If the union fails, we will overapproximate.
2747             if (!getFullMemRefAsRegion(opInst, copyDepth, region.get())) {
2748               LLVM_DEBUG(opInst->emitError(
2749                   "non-constant memref sizes not yet supported"));
2750               error = true;
2751               return true;
2752             }
2753             it->second->getConstraints()->clearAndCopyFrom(
2754                 *region->getConstraints());
2755           } else {
2756             // Union was computed and stored in 'it->second': copy to 'region'.
2757             region->getConstraints()->clearAndCopyFrom(
2758                 *it->second->getConstraints());
2759           }
2760           return true;
2761         };
2762 
2763     bool existsInRead = updateRegion(readRegions);
2764     if (error)
2765       return;
2766     bool existsInWrite = updateRegion(writeRegions);
2767     if (error)
2768       return;
2769 
2770     // Finally add it to the region list.
2771     if (region->isWrite() && !existsInWrite) {
2772       writeRegions[region->memref] = std::move(region);
2773     } else if (!region->isWrite() && !existsInRead) {
2774       readRegions[region->memref] = std::move(region);
2775     }
2776   });
2777 
2778   if (error) {
2779     begin->emitError(
2780         "copy generation failed for one or more memref's in this block\n");
2781     return 0;
2782   }
2783 
2784   uint64_t totalCopyBuffersSizeInBytes = 0;
2785   bool ret = true;
2786   auto processRegions =
2787       [&](const SmallMapVector<Value, std::unique_ptr<MemRefRegion>, 4>
2788               &regions) {
2789         for (const auto &regionEntry : regions) {
2790           // For each region, hoist copy in/out past all hoistable
2791           // 'affine.for's.
2792           Block::iterator copyInPlacementStart, copyOutPlacementStart;
2793           Block *copyPlacementBlock;
2794           findHighestBlockForPlacement(
2795               *regionEntry.second, *block, begin, end, &copyPlacementBlock,
2796               &copyInPlacementStart, &copyOutPlacementStart);
2797 
2798           uint64_t sizeInBytes;
2799           Block::iterator nBegin, nEnd;
2800           LogicalResult iRet = generateCopy(
2801               *regionEntry.second, block, begin, end, copyPlacementBlock,
2802               copyInPlacementStart, copyOutPlacementStart, copyOptions,
2803               fastBufferMap, copyNests, &sizeInBytes, &nBegin, &nEnd);
2804           if (succeeded(iRet)) {
2805             // begin/end could have been invalidated, and need update.
2806             begin = nBegin;
2807             end = nEnd;
2808             totalCopyBuffersSizeInBytes += sizeInBytes;
2809           }
2810           ret = ret & succeeded(iRet);
2811         }
2812       };
2813   processRegions(readRegions);
2814   processRegions(writeRegions);
2815 
2816   if (!ret) {
2817     begin->emitError(
2818         "copy generation failed for one or more memref's in this block\n");
2819     return totalCopyBuffersSizeInBytes;
2820   }
2821 
2822   // For a range of operations, a note will be emitted at the caller.
2823   AffineForOp forOp;
2824   uint64_t sizeInKib = llvm::divideCeil(totalCopyBuffersSizeInBytes, 1024);
2825   if (llvm::DebugFlag && (forOp = dyn_cast<AffineForOp>(&*begin))) {
2826     forOp.emitRemark()
2827         << sizeInKib
2828         << " KiB of copy buffers in fast memory space for this block\n";
2829   }
2830 
2831   if (totalCopyBuffersSizeInBytes > copyOptions.fastMemCapacityBytes) {
2832     StringRef str = "Total size of all copy buffers' for this block "
2833                     "exceeds fast memory capacity\n";
2834     block->getParentOp()->emitWarning(str);
2835   }
2836 
2837   return totalCopyBuffersSizeInBytes;
2838 }
2839 
2840 // A convenience version of affineDataCopyGenerate for all ops in the body of
2841 // an AffineForOp.
affineDataCopyGenerate(AffineForOp forOp,const AffineCopyOptions & copyOptions,Optional<Value> filterMemRef,DenseSet<Operation * > & copyNests)2842 uint64_t mlir::affineDataCopyGenerate(AffineForOp forOp,
2843                                       const AffineCopyOptions &copyOptions,
2844                                       Optional<Value> filterMemRef,
2845                                       DenseSet<Operation *> &copyNests) {
2846   return affineDataCopyGenerate(forOp.getBody()->begin(),
2847                                 std::prev(forOp.getBody()->end()), copyOptions,
2848                                 filterMemRef, copyNests);
2849 }
2850 
generateCopyForMemRegion(const MemRefRegion & memrefRegion,Operation * analyzedOp,const AffineCopyOptions & copyOptions,CopyGenerateResult & result)2851 LogicalResult mlir::generateCopyForMemRegion(
2852     const MemRefRegion &memrefRegion, Operation *analyzedOp,
2853     const AffineCopyOptions &copyOptions, CopyGenerateResult &result) {
2854   Block *block = analyzedOp->getBlock();
2855   auto begin = analyzedOp->getIterator();
2856   auto end = std::next(begin);
2857   DenseMap<Value, Value> fastBufferMap;
2858   DenseSet<Operation *> copyNests;
2859 
2860   auto err = generateCopy(memrefRegion, block, begin, end, block, begin, end,
2861                           copyOptions, fastBufferMap, copyNests,
2862                           &result.sizeInBytes, &begin, &end);
2863   if (failed(err))
2864     return err;
2865 
2866   result.alloc =
2867       fastBufferMap.find(memrefRegion.memref)->second.getDefiningOp();
2868   assert(copyNests.size() <= 1 && "At most one copy nest is expected.");
2869   result.copyNest = copyNests.empty() ? nullptr : *copyNests.begin();
2870   return success();
2871 }
2872 
2873 /// Gathers all AffineForOps in 'block' at 'currLoopDepth' in 'depthToLoops'.
2874 static void
gatherLoopsInBlock(Block * block,unsigned currLoopDepth,std::vector<SmallVector<AffineForOp,2>> & depthToLoops)2875 gatherLoopsInBlock(Block *block, unsigned currLoopDepth,
2876                    std::vector<SmallVector<AffineForOp, 2>> &depthToLoops) {
2877   // Add a new empty level to output if it doesn't exist level already.
2878   assert(currLoopDepth <= depthToLoops.size() && "Unexpected currLoopDepth");
2879   if (currLoopDepth == depthToLoops.size())
2880     depthToLoops.push_back(SmallVector<AffineForOp, 2>());
2881 
2882   for (auto &op : *block) {
2883     if (auto forOp = dyn_cast<AffineForOp>(op)) {
2884       depthToLoops[currLoopDepth].push_back(forOp);
2885       gatherLoopsInBlock(forOp.getBody(), currLoopDepth + 1, depthToLoops);
2886     }
2887   }
2888 }
2889 
2890 /// Gathers all AffineForOps in 'func' grouped by loop depth.
gatherLoops(FuncOp func,std::vector<SmallVector<AffineForOp,2>> & depthToLoops)2891 void mlir::gatherLoops(FuncOp func,
2892                        std::vector<SmallVector<AffineForOp, 2>> &depthToLoops) {
2893   for (auto &block : func)
2894     gatherLoopsInBlock(&block, /*currLoopDepth=*/0, depthToLoops);
2895 
2896   // Remove last loop level from output since it's empty.
2897   if (!depthToLoops.empty()) {
2898     assert(depthToLoops.back().empty() && "Last loop level is not empty?");
2899     depthToLoops.pop_back();
2900   }
2901 }
2902 
2903 // TODO: if necessary, this can be extended to also compose in any
2904 // affine.applys, fold to constant if all result dimensions of the map are
2905 // constant (canonicalizeMapAndOperands below already does this for single
2906 // result bound maps), and use simplifyMap to perform algebraic simplification.
createCanonicalizedAffineForOp(OpBuilder b,Location loc,ValueRange lbOperands,AffineMap lbMap,ValueRange ubOperands,AffineMap ubMap,int64_t step)2907 AffineForOp mlir::createCanonicalizedAffineForOp(
2908     OpBuilder b, Location loc, ValueRange lbOperands, AffineMap lbMap,
2909     ValueRange ubOperands, AffineMap ubMap, int64_t step) {
2910   SmallVector<Value, 4> lowerOperands(lbOperands);
2911   SmallVector<Value, 4> upperOperands(ubOperands);
2912 
2913   fullyComposeAffineMapAndOperands(&lbMap, &lowerOperands);
2914   canonicalizeMapAndOperands(&lbMap, &lowerOperands);
2915   lbMap = removeDuplicateExprs(lbMap);
2916   fullyComposeAffineMapAndOperands(&ubMap, &upperOperands);
2917   canonicalizeMapAndOperands(&ubMap, &upperOperands);
2918   ubMap = removeDuplicateExprs(ubMap);
2919 
2920   return b.create<AffineForOp>(loc, lowerOperands, lbMap, upperOperands, ubMap,
2921                                step);
2922 }
2923 
2924 /// Creates an AffineIfOp that encodes the conditional to choose between
2925 /// the constant trip count version and an unknown trip count version of this
2926 /// nest of loops. This is used to separate partial and full tiles if `loops`
2927 /// has the intra-tile loops. The affine.if op is inserted at the builder
2928 /// insertion point of `b`.
createSeparationCondition(MutableArrayRef<AffineForOp> loops,OpBuilder b)2929 static AffineIfOp createSeparationCondition(MutableArrayRef<AffineForOp> loops,
2930                                             OpBuilder b) {
2931   if (loops.empty())
2932     return nullptr;
2933 
2934   auto *context = loops[0].getContext();
2935 
2936   FlatAffineConstraints cst;
2937   SmallVector<Operation *, 8> ops;
2938   ops.reserve(loops.size());
2939   for (AffineForOp forOp : loops)
2940     ops.push_back(forOp);
2941   getIndexSet(ops, &cst);
2942 
2943   // Remove constraints that are independent of these loop IVs.
2944   cst.removeIndependentConstraints(/*pos=*/0, /*num=*/loops.size());
2945 
2946   // Construct the constraint set representing the guard for full tiles. The
2947   // lower bound (and upper bound) corresponding to the full tile should be
2948   // larger (and resp. smaller) than any other lower (or upper bound).
2949   SmallVector<int64_t, 8> fullTileLb, fullTileUb;
2950   for (auto loop : loops) {
2951     (void)loop;
2952     // TODO: Non-unit stride is not an issue to generalize to.
2953     assert(loop.getStep() == 1 && "point loop step expected to be one");
2954     // Mark everything symbols for the purpose of finding a constant diff pair.
2955     cst.setDimSymbolSeparation(/*newSymbolCount=*/cst.getNumDimAndSymbolIds() -
2956                                1);
2957     unsigned fullTileLbPos, fullTileUbPos;
2958     if (!cst.getConstantBoundOnDimSize(0, /*lb=*/nullptr,
2959                                        /*lbFloorDivisor=*/nullptr,
2960                                        /*ub=*/nullptr, &fullTileLbPos,
2961                                        &fullTileUbPos)) {
2962       LLVM_DEBUG(llvm::dbgs() << "Can't get constant diff pair for a loop\n");
2963       return nullptr;
2964     }
2965 
2966     SmallVector<unsigned, 4> lbIndices, ubIndices;
2967     cst.getLowerAndUpperBoundIndices(/*pos=*/0, &lbIndices, &ubIndices);
2968 
2969     auto fLb = cst.getInequality(fullTileLbPos);
2970     auto fUb = cst.getInequality(fullTileUbPos);
2971     fullTileLb.assign(fLb.begin(), fLb.end());
2972     fullTileUb.assign(fUb.begin(), fUb.end());
2973 
2974     // Full tile lower bound should be >= than any other lower bound.
2975     for (auto lbIndex : lbIndices)
2976       for (unsigned i = 0, e = cst.getNumCols(); i < e; ++i)
2977         cst.atIneq(lbIndex, i) = fullTileLb[i] - cst.atIneq(lbIndex, i);
2978 
2979     // Full tile upper bound should be <= any other upper bound.
2980     for (auto ubIndex : ubIndices)
2981       for (unsigned i = 0, e = cst.getNumCols(); i < e; ++i)
2982         cst.atIneq(ubIndex, i) -= fullTileUb[i];
2983 
2984     cst.removeId(0);
2985   }
2986 
2987   // The previous step leads to all zeros for the full tile lb and ub position
2988   // itself; remove those and any other duplicates / trivial redundancies.
2989   cst.removeTrivialRedundancy();
2990 
2991   // Turn everything into dims conservatively since we earlier turned all
2992   // trailing ids past point loop IV into symbols. Some of these could be outer
2993   // loop IVs; we'll canonicalize anyway.
2994   cst.setDimSymbolSeparation(0);
2995 
2996   IntegerSet ifCondSet = cst.getAsIntegerSet(context);
2997   // ifCondSet can be null if cst was empty -- this can happen if all loops
2998   // in the nest have constant trip counts.
2999   if (!ifCondSet)
3000     return nullptr;
3001 
3002   SmallVector<Value, 4> setOperands;
3003   cst.getIdValues(0, cst.getNumDimAndSymbolIds(), &setOperands);
3004   canonicalizeSetAndOperands(&ifCondSet, &setOperands);
3005   return b.create<AffineIfOp>(loops[0].getLoc(), ifCondSet, setOperands,
3006                               /*withElseRegion=*/true);
3007 }
3008 
3009 /// Create the full tile loop nest (along with its body).
3010 static LogicalResult
createFullTiles(MutableArrayRef<AffineForOp> inputNest,SmallVectorImpl<AffineForOp> & fullTileLoops,OpBuilder b)3011 createFullTiles(MutableArrayRef<AffineForOp> inputNest,
3012                 SmallVectorImpl<AffineForOp> &fullTileLoops, OpBuilder b) {
3013   fullTileLoops.reserve(inputNest.size());
3014 
3015   // For each loop in the original nest identify a lower/upper bound pair such
3016   // that their difference is a constant.
3017   FlatAffineConstraints cst;
3018   for (auto loop : inputNest) {
3019     // TODO: straightforward to generalize to a non-unit stride.
3020     if (loop.getStep() != 1) {
3021       LLVM_DEBUG(llvm::dbgs()
3022                  << "[tile separation] non-unit stride not implemented\n");
3023       return failure();
3024     }
3025     SmallVector<Operation *, 1> loopOp{loop.getOperation()};
3026     getIndexSet(loopOp, &cst);
3027     // We will mark everything other than this loop IV as symbol for getting a
3028     // pair of <lb, ub> with a constant difference.
3029     cst.setDimSymbolSeparation(cst.getNumDimAndSymbolIds() - 1);
3030     unsigned lbPos, ubPos;
3031     if (!cst.getConstantBoundOnDimSize(/*pos=*/0, /*lb=*/nullptr,
3032                                        /*lbDivisor=*/nullptr, /*ub=*/nullptr,
3033                                        &lbPos, &ubPos) ||
3034         lbPos == ubPos) {
3035       LLVM_DEBUG(llvm::dbgs() << "[tile separation] Can't get constant diff / "
3036                                  "equalities not yet handled\n");
3037       return failure();
3038     }
3039 
3040     // Set all identifiers as dimensions uniformly since some of those marked as
3041     // symbols above could be outer loop IVs (corresponding tile space IVs).
3042     cst.setDimSymbolSeparation(/*newSymbolCount=*/0);
3043 
3044     AffineValueMap lbVmap, ubVmap;
3045     cst.getIneqAsAffineValueMap(/*pos=*/0, lbPos, lbVmap, b.getContext());
3046     cst.getIneqAsAffineValueMap(/*pos=*/0, ubPos, ubVmap, b.getContext());
3047     AffineForOp fullTileLoop = createCanonicalizedAffineForOp(
3048         b, loop.getLoc(), lbVmap.getOperands(), lbVmap.getAffineMap(),
3049         ubVmap.getOperands(), ubVmap.getAffineMap());
3050     b = OpBuilder::atBlockTerminator(fullTileLoop.getBody());
3051     fullTileLoops.push_back(fullTileLoop);
3052   }
3053 
3054   // Add the body for the full tile loop nest.
3055   BlockAndValueMapping operandMap;
3056   for (auto loopEn : llvm::enumerate(inputNest))
3057     operandMap.map(loopEn.value().getInductionVar(),
3058                    fullTileLoops[loopEn.index()].getInductionVar());
3059   b = OpBuilder::atBlockTerminator(fullTileLoops.back().getBody());
3060   for (auto &op : inputNest.back().getBody()->without_terminator())
3061     b.clone(op, operandMap);
3062   return success();
3063 }
3064 
3065 LogicalResult
separateFullTiles(MutableArrayRef<AffineForOp> inputNest,SmallVectorImpl<AffineForOp> * fullTileNest)3066 mlir::separateFullTiles(MutableArrayRef<AffineForOp> inputNest,
3067                         SmallVectorImpl<AffineForOp> *fullTileNest) {
3068   if (inputNest.empty())
3069     return success();
3070 
3071   auto firstLoop = inputNest[0];
3072 
3073   // Each successive for op has to be nested in the other.
3074   auto prevLoop = firstLoop;
3075   for (auto loop : inputNest.drop_front(1)) {
3076     assert(loop->getParentOp() == prevLoop && "input not contiguously nested");
3077     prevLoop = loop;
3078   }
3079 
3080   // Create the full tile loop nest.
3081   SmallVector<AffineForOp, 4> fullTileLoops;
3082   OpBuilder b(firstLoop);
3083   if (failed(createFullTiles(inputNest, fullTileLoops, b))) {
3084     if (!fullTileLoops.empty())
3085       fullTileLoops.front().erase();
3086     return failure();
3087   }
3088 
3089   // Create and insert the version select right before the root of the nest.
3090   b = OpBuilder(firstLoop);
3091   AffineIfOp ifOp = createSeparationCondition(inputNest, b);
3092   if (!ifOp) {
3093     fullTileLoops.front().erase();
3094     LLVM_DEBUG(llvm::dbgs() << "All tiles are full tiles, or failure creating "
3095                                "separation condition\n");
3096     return failure();
3097   }
3098 
3099   // Move the full tile into the then block.
3100   Block *thenBlock = ifOp.getThenBlock();
3101   AffineForOp outermostFullTileLoop = fullTileLoops[0];
3102   thenBlock->getOperations().splice(
3103       std::prev(thenBlock->end()),
3104       outermostFullTileLoop->getBlock()->getOperations(),
3105       Block::iterator(outermostFullTileLoop));
3106 
3107   // Move the partial tile into the else block. The partial tile is the same as
3108   // the original loop nest.
3109   Block *elseBlock = ifOp.getElseBlock();
3110   elseBlock->getOperations().splice(std::prev(elseBlock->end()),
3111                                     firstLoop->getBlock()->getOperations(),
3112                                     Block::iterator(firstLoop));
3113 
3114   if (fullTileNest)
3115     *fullTileNest = std::move(fullTileLoops);
3116 
3117   return success();
3118 }
3119