//===- Utils.cpp ---- Misc utilities for code and data transformation -----===//
//
// Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions.
// See https://llvm.org/LICENSE.txt for license information.
// SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception
//
//===----------------------------------------------------------------------===//
//
// This file implements miscellaneous transformation routines for non-loop IR
// structures.
//
//===----------------------------------------------------------------------===//

#include "mlir/Transforms/Utils.h"

#include "mlir/Analysis/AffineAnalysis.h"
#include "mlir/Analysis/AffineStructures.h"
#include "mlir/Analysis/Utils.h"
#include "mlir/Dialect/Affine/IR/AffineOps.h"
#include "mlir/IR/Builders.h"
#include "mlir/IR/BuiltinOps.h"
#include "mlir/IR/Dominance.h"
#include "mlir/Support/MathExtras.h"
#include "llvm/ADT/DenseMap.h"
#include "llvm/ADT/TypeSwitch.h"

using namespace mlir;

/// Return true if this operation dereferences one or more memref's.
// Temporary utility: will be replaced when this is modeled through
// side-effects/op traits. TODO
static bool isMemRefDereferencingOp(Operation &op) {
  return isa<AffineReadOpInterface, AffineWriteOpInterface, AffineDmaStartOp,
             AffineDmaWaitOp>(op);
}

/// Return the AffineMapAttr associated with memory 'op' on 'memref'.
static NamedAttribute getAffineMapAttrForMemRef(Operation *op, Value memref) {
  return TypeSwitch<Operation *, NamedAttribute>(op)
      .Case<AffineDmaStartOp, AffineReadOpInterface, AffinePrefetchOp,
            AffineWriteOpInterface, AffineDmaWaitOp>(
          [=](auto op) { return op.getAffineMapAttrForMemRef(memref); });
}

// Perform the replacement in `op`.
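// Illustrative example (the names, shapes, and maps below are hypothetical,
// not taken from this file; see the declaration in mlir/Transforms/Utils.h for
// the full contract): with no `extraIndices`, no `symbolOperands`, and
// `indexRemap = (d0) -> (d0 floordiv 4, d0 mod 4)`, a use such as
//
//   affine.load %oldMemRef[%i] : memref<16xf32>
//
// would be rewritten to
//
//   affine.load %newMemRef[%i floordiv 4, %i mod 4] : memref<4x4xf32>
//
// after the access map is composed, simplified, and canonicalized below.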
LogicalResult mlir::replaceAllMemRefUsesWith(Value oldMemRef, Value newMemRef,
                                             Operation *op,
                                             ArrayRef<Value> extraIndices,
                                             AffineMap indexRemap,
                                             ArrayRef<Value> extraOperands,
                                             ArrayRef<Value> symbolOperands,
                                             bool allowNonDereferencingOps) {
  unsigned newMemRefRank = newMemRef.getType().cast<MemRefType>().getRank();
  (void)newMemRefRank; // unused in opt mode
  unsigned oldMemRefRank = oldMemRef.getType().cast<MemRefType>().getRank();
  (void)oldMemRefRank; // unused in opt mode
  if (indexRemap) {
    assert(indexRemap.getNumSymbols() == symbolOperands.size() &&
           "symbolic operand count mismatch");
    assert(indexRemap.getNumInputs() ==
           extraOperands.size() + oldMemRefRank + symbolOperands.size());
    assert(indexRemap.getNumResults() + extraIndices.size() == newMemRefRank);
  } else {
    assert(oldMemRefRank + extraIndices.size() == newMemRefRank);
  }

  // Assert same elemental type.
  assert(oldMemRef.getType().cast<MemRefType>().getElementType() ==
         newMemRef.getType().cast<MemRefType>().getElementType());

  SmallVector<unsigned, 2> usePositions;
  for (const auto &opEntry : llvm::enumerate(op->getOperands())) {
    if (opEntry.value() == oldMemRef)
      usePositions.push_back(opEntry.index());
  }

  // If memref doesn't appear, nothing to do.
  if (usePositions.empty())
    return success();

  if (usePositions.size() > 1) {
    // TODO: extend it for this case when needed (rare).
    assert(false && "multiple dereferencing uses in a single op not supported");
    return failure();
  }

  unsigned memRefOperandPos = usePositions.front();

  OpBuilder builder(op);
  // The following checks if op is dereferencing memref and performs the access
  // index rewrites.
  if (!isMemRefDereferencingOp(*op)) {
    if (!allowNonDereferencingOps)
      // Failure: memref used in a non-dereferencing context (potentially
      // escapes); no replacement in these cases unless allowNonDereferencingOps
      // is set.
      return failure();
    op->setOperand(memRefOperandPos, newMemRef);
    return success();
  }
  // Perform index rewrites for the dereferencing op and then replace the op.
  NamedAttribute oldMapAttrPair = getAffineMapAttrForMemRef(op, oldMemRef);
  AffineMap oldMap = oldMapAttrPair.second.cast<AffineMapAttr>().getValue();
  unsigned oldMapNumInputs = oldMap.getNumInputs();
  SmallVector<Value, 4> oldMapOperands(
      op->operand_begin() + memRefOperandPos + 1,
      op->operand_begin() + memRefOperandPos + 1 + oldMapNumInputs);

  // Apply 'oldMemRefOperands = oldMap(oldMapOperands)'.
  SmallVector<Value, 4> oldMemRefOperands;
  SmallVector<Value, 4> affineApplyOps;
  oldMemRefOperands.reserve(oldMemRefRank);
  if (oldMap != builder.getMultiDimIdentityMap(oldMap.getNumDims())) {
    for (auto resultExpr : oldMap.getResults()) {
      auto singleResMap = AffineMap::get(oldMap.getNumDims(),
                                         oldMap.getNumSymbols(), resultExpr);
      auto afOp = builder.create<AffineApplyOp>(op->getLoc(), singleResMap,
                                                oldMapOperands);
      oldMemRefOperands.push_back(afOp);
      affineApplyOps.push_back(afOp);
    }
  } else {
    oldMemRefOperands.assign(oldMapOperands.begin(), oldMapOperands.end());
  }

  // Construct new indices as a remap of the old ones if a remapping has been
  // provided. The indices of a memref come right after it, i.e.,
  // at position memRefOperandPos + 1.
  SmallVector<Value, 4> remapOperands;
  remapOperands.reserve(extraOperands.size() + oldMemRefRank +
                        symbolOperands.size());
  remapOperands.append(extraOperands.begin(), extraOperands.end());
  remapOperands.append(oldMemRefOperands.begin(), oldMemRefOperands.end());
  remapOperands.append(symbolOperands.begin(), symbolOperands.end());

  SmallVector<Value, 4> remapOutputs;
  remapOutputs.reserve(oldMemRefRank);

  if (indexRemap &&
      indexRemap != builder.getMultiDimIdentityMap(indexRemap.getNumDims())) {
    // Remapped indices.
    for (auto resultExpr : indexRemap.getResults()) {
      auto singleResMap = AffineMap::get(
          indexRemap.getNumDims(), indexRemap.getNumSymbols(), resultExpr);
      auto afOp = builder.create<AffineApplyOp>(op->getLoc(), singleResMap,
                                                remapOperands);
      remapOutputs.push_back(afOp);
      affineApplyOps.push_back(afOp);
    }
  } else {
    // No remapping specified.
    remapOutputs.assign(remapOperands.begin(), remapOperands.end());
  }

  SmallVector<Value, 4> newMapOperands;
  newMapOperands.reserve(newMemRefRank);

  // Prepend 'extraIndices' in 'newMapOperands'.
  for (Value extraIndex : extraIndices) {
    assert(extraIndex.getDefiningOp()->getNumResults() == 1 &&
           "single result op's expected to generate these indices");
    assert((isValidDim(extraIndex) || isValidSymbol(extraIndex)) &&
           "invalid memory op index");
    newMapOperands.push_back(extraIndex);
  }

  // Append 'remapOutputs' to 'newMapOperands'.
  newMapOperands.append(remapOutputs.begin(), remapOutputs.end());

  // Create new fully composed AffineMap for new op to be created.
  assert(newMapOperands.size() == newMemRefRank);
  auto newMap = builder.getMultiDimIdentityMap(newMemRefRank);
  // TODO: Avoid creating/deleting temporary AffineApplyOps here.
  fullyComposeAffineMapAndOperands(&newMap, &newMapOperands);
  newMap = simplifyAffineMap(newMap);
  canonicalizeMapAndOperands(&newMap, &newMapOperands);
  // Remove any affine.apply's that became dead as a result of composition.
  for (Value value : affineApplyOps)
    if (value.use_empty())
      value.getDefiningOp()->erase();

  OperationState state(op->getLoc(), op->getName());
  // Construct the new operation using this memref.
  state.operands.reserve(op->getNumOperands() + extraIndices.size());
  // Insert the non-memref operands.
  state.operands.append(op->operand_begin(),
                        op->operand_begin() + memRefOperandPos);
  // Insert the new memref value.
  state.operands.push_back(newMemRef);

  // Insert the new memref map operands.
  state.operands.append(newMapOperands.begin(), newMapOperands.end());

  // Insert the remaining operands unmodified.
  state.operands.append(op->operand_begin() + memRefOperandPos + 1 +
                            oldMapNumInputs,
                        op->operand_end());

  // Result types don't change. Both memref's are of the same elemental type.
  state.types.reserve(op->getNumResults());
  for (auto result : op->getResults())
    state.types.push_back(result.getType());

  // Add attribute for 'newMap', other Attributes do not change.
  auto newMapAttr = AffineMapAttr::get(newMap);
  for (auto namedAttr : op->getAttrs()) {
    if (namedAttr.first == oldMapAttrPair.first)
      state.attributes.push_back({namedAttr.first, newMapAttr});
    else
      state.attributes.push_back(namedAttr);
  }

  // Create the new operation.
  auto *repOp = builder.createOperation(state);
  op->replaceAllUsesWith(repOp);
  op->erase();

  return success();
}

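// This overload walks all users of `oldMemRef` and applies the per-op
// replacement above to each use that is dominated by `domInstFilter` and
// post-dominated by `postDomInstFilter` (when those filters are provided).
// A hypothetical call (all names below are illustrative only, not from this
// file) for a double-buffering style rewrite restricted to a loop body:
//
//   if (failed(replaceAllMemRefUsesWith(
//           oldBuf, doubleBuf,
//           /*extraIndices=*/{ivMod2},
//           /*indexRemap=*/AffineMap(),
//           /*extraOperands=*/{},
//           /*symbolOperands=*/{},
//           /*domInstFilter=*/&*forOp.getBody()->begin())))
//     return failure();
//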
LogicalResult mlir::replaceAllMemRefUsesWith(
    Value oldMemRef, Value newMemRef, ArrayRef<Value> extraIndices,
    AffineMap indexRemap, ArrayRef<Value> extraOperands,
    ArrayRef<Value> symbolOperands, Operation *domInstFilter,
    Operation *postDomInstFilter, bool allowNonDereferencingOps,
    bool replaceInDeallocOp) {
  unsigned newMemRefRank = newMemRef.getType().cast<MemRefType>().getRank();
  (void)newMemRefRank; // unused in opt mode
  unsigned oldMemRefRank = oldMemRef.getType().cast<MemRefType>().getRank();
  (void)oldMemRefRank;
  if (indexRemap) {
    assert(indexRemap.getNumSymbols() == symbolOperands.size() &&
           "symbol operand count mismatch");
    assert(indexRemap.getNumInputs() ==
           extraOperands.size() + oldMemRefRank + symbolOperands.size());
    assert(indexRemap.getNumResults() + extraIndices.size() == newMemRefRank);
  } else {
    assert(oldMemRefRank + extraIndices.size() == newMemRefRank);
  }

  // Assert same elemental type.
  assert(oldMemRef.getType().cast<MemRefType>().getElementType() ==
         newMemRef.getType().cast<MemRefType>().getElementType());

  std::unique_ptr<DominanceInfo> domInfo;
  std::unique_ptr<PostDominanceInfo> postDomInfo;
  if (domInstFilter)
    domInfo = std::make_unique<DominanceInfo>(
        domInstFilter->getParentOfType<FuncOp>());

  if (postDomInstFilter)
    postDomInfo = std::make_unique<PostDominanceInfo>(
        postDomInstFilter->getParentOfType<FuncOp>());

  // Walk all uses of old memref; collect ops to perform replacement. We use a
  // DenseSet since an operation could potentially have multiple uses of a
  // memref (although rare), and the replacement later is going to erase ops.
  DenseSet<Operation *> opsToReplace;
  for (auto *op : oldMemRef.getUsers()) {
    // Skip this use if it's not dominated by domInstFilter.
    if (domInstFilter && !domInfo->dominates(domInstFilter, op))
      continue;

    // Skip this use if it's not post-dominated by postDomInstFilter.
    if (postDomInstFilter && !postDomInfo->postDominates(postDomInstFilter, op))
      continue;

    // Skip dealloc's - no replacement is necessary, and a memref replacement
    // at other uses doesn't hurt these dealloc's.
    if (isa<DeallocOp>(op) && !replaceInDeallocOp)
      continue;

    // Check if the memref was used in a non-dereferencing context. It is fine
    // for the memref to be used in a non-dereferencing way outside of the
    // region where this replacement is happening.
    if (!isMemRefDereferencingOp(*op)) {
      if (!allowNonDereferencingOps)
        return failure();
      // Currently, the only non-dereferencing ops supported as candidates for
      // replacement are Dealloc, CallOp, and ReturnOp.
      // TODO: Add support for other kinds of ops.
      if (!op->hasTrait<OpTrait::MemRefsNormalizable>())
        return failure();
    }

    // We'll first collect and then replace --- since replacement erases the op
    // that has the use, and that op could be postDomFilter or domFilter itself!
    opsToReplace.insert(op);
  }

  for (auto *op : opsToReplace) {
    if (failed(replaceAllMemRefUsesWith(
            oldMemRef, newMemRef, op, extraIndices, indexRemap, extraOperands,
            symbolOperands, allowNonDereferencingOps)))
      llvm_unreachable("memref replacement guaranteed to succeed here");
  }

  return success();
}

/// Given an operation, inserts one or more single-result affine apply
/// operations, the results of which are exclusively used by this operation.
/// The operands of these newly created affine apply ops are guaranteed to be
/// loop iterators or terminal symbols of a function.
///
/// Before
///
/// affine.for %i = 0 to #map(%N)
///   %idx = affine.apply (d0) -> (d0 mod 2) (%i)
///   "send"(%idx, %A, ...)
///   "compute"(%idx)
///
/// After
///
/// affine.for %i = 0 to #map(%N)
///   %idx = affine.apply (d0) -> (d0 mod 2) (%i)
///   "send"(%idx, %A, ...)
///   %idx_ = affine.apply (d0) -> (d0 mod 2) (%i)
///   "compute"(%idx_)
///
/// This allows applying different transformations on send and compute (e.g.,
/// different shifts/delays).
///
/// Leaves `sliceOps` empty either if none of opInst's operands were the result
/// of an affine.apply (and thus there was no affine computation slice to
/// create), or if all the affine.apply ops supplying operands to opInst had no
/// uses besides opInst itself; otherwise, the affine.apply operations created
/// are returned in the output argument `sliceOps`.
void mlir::createAffineComputationSlice(
    Operation *opInst, SmallVectorImpl<AffineApplyOp> *sliceOps) {
  // Collect all operands that are results of affine apply ops.
  SmallVector<Value, 4> subOperands;
  subOperands.reserve(opInst->getNumOperands());
  for (auto operand : opInst->getOperands())
    if (isa_and_nonnull<AffineApplyOp>(operand.getDefiningOp()))
      subOperands.push_back(operand);

  // Gather sequence of AffineApplyOps reachable from 'subOperands'.
  SmallVector<Operation *, 4> affineApplyOps;
  getReachableAffineApplyOps(subOperands, affineApplyOps);
  // Skip transforming if there are no affine maps to compose.
  if (affineApplyOps.empty())
    return;

  // Check if all uses of the affine apply ops lie only in this op, in
  // which case there would be nothing to do.
  bool localized = true;
  for (auto *op : affineApplyOps) {
    for (auto result : op->getResults()) {
      for (auto *user : result.getUsers()) {
        if (user != opInst) {
          localized = false;
          break;
        }
      }
    }
  }
  if (localized)
    return;

  OpBuilder builder(opInst);
  SmallVector<Value, 4> composedOpOperands(subOperands);
  auto composedMap = builder.getMultiDimIdentityMap(composedOpOperands.size());
  fullyComposeAffineMapAndOperands(&composedMap, &composedOpOperands);

  // Create an affine.apply for each of the map results.
  sliceOps->reserve(composedMap.getNumResults());
  for (auto resultExpr : composedMap.getResults()) {
    auto singleResMap = AffineMap::get(composedMap.getNumDims(),
                                       composedMap.getNumSymbols(), resultExpr);
    sliceOps->push_back(builder.create<AffineApplyOp>(
        opInst->getLoc(), singleResMap, composedOpOperands));
  }

  // Construct the new operands that include the results from the composed
  // affine apply op above instead of existing ones (subOperands). So, they
  // differ from opInst's operands only for those operands in 'subOperands', for
  // which they will be replaced by the corresponding one from 'sliceOps'.
  SmallVector<Value, 4> newOperands(opInst->getOperands());
  for (unsigned i = 0, e = newOperands.size(); i < e; i++) {
    // Replace the subOperands from among the new operands.
    unsigned j, f;
    for (j = 0, f = subOperands.size(); j < f; j++) {
      if (newOperands[i] == subOperands[j])
        break;
    }
    if (j < subOperands.size()) {
      newOperands[i] = (*sliceOps)[j];
    }
  }
  for (unsigned idx = 0, e = newOperands.size(); idx < e; idx++) {
    opInst->setOperand(idx, newOperands[idx]);
  }
}

// TODO: Currently works for static memrefs with a single layout map.
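// Illustrative (hypothetical) example of what normalization does: an
//   %A = alloc() : memref<8x16xf32, (d0, d1) -> (d1, d0)>
// would be replaced by
//   %A = alloc() : memref<16x8xf32>
// with every dereferencing use of %A remapped through the old layout map,
// e.g. an access at [%i, %j] becomes an access at [%j, %i].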
LogicalResult mlir::normalizeMemRef(AllocOp allocOp) {
  MemRefType memrefType = allocOp.getType();
  OpBuilder b(allocOp);

  // Fetch a new memref type after normalizing the old memref to have an
  // identity map layout.
  MemRefType newMemRefType =
      normalizeMemRefType(memrefType, b, allocOp.symbolOperands().size());
  if (newMemRefType == memrefType)
    // Either memrefType already had an identity map or the map couldn't be
    // transformed to an identity map.
    return failure();

  Value oldMemRef = allocOp.getResult();

  SmallVector<Value, 4> symbolOperands(allocOp.symbolOperands());
  AllocOp newAlloc = b.create<AllocOp>(allocOp.getLoc(), newMemRefType,
                                       allocOp.alignmentAttr());
  AffineMap layoutMap = memrefType.getAffineMaps().front();
  // Replace all uses of the old memref.
  if (failed(replaceAllMemRefUsesWith(oldMemRef, /*newMemRef=*/newAlloc,
                                      /*extraIndices=*/{},
                                      /*indexRemap=*/layoutMap,
                                      /*extraOperands=*/{},
                                      /*symbolOperands=*/symbolOperands,
                                      /*domInstFilter=*/nullptr,
                                      /*postDomInstFilter=*/nullptr,
                                      /*allowNonDereferencingOps=*/true))) {
    // If it failed (due to escapes for example), bail out.
    newAlloc.erase();
    return failure();
  }
  // Replace any remaining uses of the original alloc op and erase it. All
  // remaining uses have to be dealloc's; the replaceAllMemRefUsesWith call
  // above would've failed otherwise.
  assert(llvm::all_of(oldMemRef.getUsers(),
                      [](Operation *op) { return isa<DeallocOp>(op); }));
  oldMemRef.replaceAllUsesWith(newAlloc);
  allocOp.erase();
  return success();
}

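// Illustrative (hypothetical) example of the shape computation below: for
// memref<16xf32, (d0) -> (d0 floordiv 4, d0 mod 4)>, the constraint system
// {0 <= d0 <= 15} composed with the layout map gives 0 <= d0 floordiv 4 <= 3
// and 0 <= d0 mod 4 <= 3, so the normalized type is memref<4x4xf32>.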
MemRefType mlir::normalizeMemRefType(MemRefType memrefType, OpBuilder b,
                                     unsigned numSymbolicOperands) {
  unsigned rank = memrefType.getRank();
  if (rank == 0)
    return memrefType;

  ArrayRef<AffineMap> layoutMaps = memrefType.getAffineMaps();
  if (layoutMaps.empty() ||
      layoutMaps.front() == b.getMultiDimIdentityMap(rank)) {
    // Either no map is associated with this memref or this memref has
    // a trivial (identity) map.
    return memrefType;
  }

  // We don't do any checks for one-to-one'ness; we assume that it is
  // one-to-one.

  // TODO: Only for static memref's for now.
  if (memrefType.getNumDynamicDims() > 0)
    return memrefType;

  // We have a single map that is not an identity map. Create a new memref
  // with the right shape and an identity layout map.
  ArrayRef<int64_t> shape = memrefType.getShape();
  // FlatAffineConstraints may later on use symbolicOperands.
  FlatAffineConstraints fac(rank, numSymbolicOperands);
  for (unsigned d = 0; d < rank; ++d) {
    fac.addConstantLowerBound(d, 0);
    fac.addConstantUpperBound(d, shape[d] - 1);
  }
  // We compose this map with the original index (logical) space to derive
  // the upper bounds for the new index space.
  AffineMap layoutMap = layoutMaps.front();
  unsigned newRank = layoutMap.getNumResults();
  if (failed(fac.composeMatchingMap(layoutMap)))
    return memrefType;
  // TODO: Handle semi-affine maps.
  // Project out the old data dimensions.
  fac.projectOut(newRank, fac.getNumIds() - newRank - fac.getNumLocalIds());
  SmallVector<int64_t, 4> newShape(newRank);
  for (unsigned d = 0; d < newRank; ++d) {
    // The lower bound for the shape is always zero.
    auto ubConst = fac.getConstantUpperBound(d);
    // For a static memref and an affine map with no symbols, this is
    // always bounded.
    assert(ubConst.hasValue() && "should always have an upper bound");
    if (ubConst.getValue() < 0)
      // This is due to an invalid map that maps to a negative space.
      return memrefType;
    newShape[d] = ubConst.getValue() + 1;
  }

  // Create the new memref type after trivializing the old layout map.
  MemRefType newMemRefType =
      MemRefType::Builder(memrefType)
          .setShape(newShape)
          .setAffineMaps(b.getMultiDimIdentityMap(newRank));

  return newMemRefType;
}