//===- Utils.cpp ---- Misc utilities for code and data transformation -----===//
//
// Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions.
// See https://llvm.org/LICENSE.txt for license information.
// SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception
//
//===----------------------------------------------------------------------===//
//
// This file implements miscellaneous transformation routines for non-loop IR
// structures.
//
//===----------------------------------------------------------------------===//

#include "mlir/Transforms/Utils.h"

#include "mlir/Analysis/AffineAnalysis.h"
#include "mlir/Analysis/AffineStructures.h"
#include "mlir/Analysis/Utils.h"
#include "mlir/Dialect/Affine/IR/AffineOps.h"
#include "mlir/IR/Builders.h"
#include "mlir/IR/BuiltinOps.h"
#include "mlir/IR/Dominance.h"
#include "mlir/Support/MathExtras.h"
#include "llvm/ADT/DenseMap.h"
#include "llvm/ADT/TypeSwitch.h"

using namespace mlir;

/// Return true if this operation dereferences one or more memref's.
// Temporary utility: will be replaced when this is modeled through
// side-effects/op traits. TODO
static bool isMemRefDereferencingOp(Operation &op) {
  return isa<AffineReadOpInterface, AffineWriteOpInterface, AffineDmaStartOp,
             AffineDmaWaitOp>(op);
}

/// Return the AffineMapAttr associated with the memory op 'op' on 'memref'.
static NamedAttribute getAffineMapAttrForMemRef(Operation *op, Value memref) {
  return TypeSwitch<Operation *, NamedAttribute>(op)
      .Case<AffineDmaStartOp, AffineReadOpInterface, AffinePrefetchOp,
            AffineWriteOpInterface, AffineDmaWaitOp>(
          [=](auto op) { return op.getAffineMapAttrForMemRef(memref); });
}

// Perform the replacement of `oldMemRef` with `newMemRef` in `op`, rewriting
// its access indices if `op` dereferences the memref.
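//
// As a purely illustrative sketch (not an excerpt from a test), assume a
// hypothetical %old of type memref<16xf32>, a replacement %new of type
// memref<4x4xf32>, and indexRemap = (d0) -> (d0 floordiv 4, d0 mod 4), with
// no extra indices/operands/symbols. A dereferencing use such as
//
//   %v = affine.load %old[%i] : memref<16xf32>
//
// would be recreated on the new memref roughly as
//
//   %v = affine.load %new[%i floordiv 4, %i mod 4] : memref<4x4xf32>
//
// with the remapped indices produced by single-result affine.apply ops that
// are then fully composed and canonicalized below.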
LogicalResult mlir::replaceAllMemRefUsesWith(Value oldMemRef, Value newMemRef,
                                             Operation *op,
                                             ArrayRef<Value> extraIndices,
                                             AffineMap indexRemap,
                                             ArrayRef<Value> extraOperands,
                                             ArrayRef<Value> symbolOperands,
                                             bool allowNonDereferencingOps) {
  unsigned newMemRefRank = newMemRef.getType().cast<MemRefType>().getRank();
  (void)newMemRefRank; // unused in opt mode
  unsigned oldMemRefRank = oldMemRef.getType().cast<MemRefType>().getRank();
  (void)oldMemRefRank; // unused in opt mode
  if (indexRemap) {
    assert(indexRemap.getNumSymbols() == symbolOperands.size() &&
           "symbolic operand count mismatch");
    assert(indexRemap.getNumInputs() ==
           extraOperands.size() + oldMemRefRank + symbolOperands.size());
    assert(indexRemap.getNumResults() + extraIndices.size() == newMemRefRank);
  } else {
    assert(oldMemRefRank + extraIndices.size() == newMemRefRank);
  }

  // Assert same elemental type.
  assert(oldMemRef.getType().cast<MemRefType>().getElementType() ==
         newMemRef.getType().cast<MemRefType>().getElementType());

  SmallVector<unsigned, 2> usePositions;
  for (const auto &opEntry : llvm::enumerate(op->getOperands())) {
    if (opEntry.value() == oldMemRef)
      usePositions.push_back(opEntry.index());
  }

  // If memref doesn't appear, nothing to do.
  if (usePositions.empty())
    return success();

  if (usePositions.size() > 1) {
    // TODO: extend it for this case when needed (rare).
    assert(false && "multiple dereferencing uses in a single op not supported");
    return failure();
  }

  unsigned memRefOperandPos = usePositions.front();

  OpBuilder builder(op);
  // The following checks if op is dereferencing memref and performs the access
  // index rewrites.
  if (!isMemRefDereferencingOp(*op)) {
    if (!allowNonDereferencingOps)
      // Failure: memref used in a non-dereferencing context (potentially
      // escapes); no replacement in these cases unless allowNonDereferencingOps
      // is set.
      return failure();
    op->setOperand(memRefOperandPos, newMemRef);
    return success();
  }
  // Perform index rewrites for the dereferencing op and then replace the op.
  NamedAttribute oldMapAttrPair = getAffineMapAttrForMemRef(op, oldMemRef);
  AffineMap oldMap = oldMapAttrPair.second.cast<AffineMapAttr>().getValue();
  unsigned oldMapNumInputs = oldMap.getNumInputs();
  SmallVector<Value, 4> oldMapOperands(
      op->operand_begin() + memRefOperandPos + 1,
      op->operand_begin() + memRefOperandPos + 1 + oldMapNumInputs);

  // Apply 'oldMemRefOperands = oldMap(oldMapOperands)'.
  SmallVector<Value, 4> oldMemRefOperands;
  SmallVector<Value, 4> affineApplyOps;
  oldMemRefOperands.reserve(oldMemRefRank);
  if (oldMap != builder.getMultiDimIdentityMap(oldMap.getNumDims())) {
    for (auto resultExpr : oldMap.getResults()) {
      auto singleResMap = AffineMap::get(oldMap.getNumDims(),
                                         oldMap.getNumSymbols(), resultExpr);
      auto afOp = builder.create<AffineApplyOp>(op->getLoc(), singleResMap,
                                                oldMapOperands);
      oldMemRefOperands.push_back(afOp);
      affineApplyOps.push_back(afOp);
    }
  } else {
    oldMemRefOperands.assign(oldMapOperands.begin(), oldMapOperands.end());
  }

  // Construct new indices as a remap of the old ones if a remapping has been
  // provided. The indices of a memref come right after it, i.e.,
  // at position memRefOperandPos + 1.
  SmallVector<Value, 4> remapOperands;
  remapOperands.reserve(extraOperands.size() + oldMemRefRank +
                        symbolOperands.size());
  remapOperands.append(extraOperands.begin(), extraOperands.end());
  remapOperands.append(oldMemRefOperands.begin(), oldMemRefOperands.end());
  remapOperands.append(symbolOperands.begin(), symbolOperands.end());

  SmallVector<Value, 4> remapOutputs;
  remapOutputs.reserve(oldMemRefRank);

  if (indexRemap &&
      indexRemap != builder.getMultiDimIdentityMap(indexRemap.getNumDims())) {
    // Remapped indices.
    for (auto resultExpr : indexRemap.getResults()) {
      auto singleResMap = AffineMap::get(
          indexRemap.getNumDims(), indexRemap.getNumSymbols(), resultExpr);
      auto afOp = builder.create<AffineApplyOp>(op->getLoc(), singleResMap,
                                                remapOperands);
      remapOutputs.push_back(afOp);
      affineApplyOps.push_back(afOp);
    }
  } else {
    // No remapping specified.
    remapOutputs.assign(remapOperands.begin(), remapOperands.end());
  }

  SmallVector<Value, 4> newMapOperands;
  newMapOperands.reserve(newMemRefRank);

  // Prepend 'extraIndices' in 'newMapOperands'.
  for (Value extraIndex : extraIndices) {
    assert(extraIndex.getDefiningOp()->getNumResults() == 1 &&
           "single result op's expected to generate these indices");
    assert((isValidDim(extraIndex) || isValidSymbol(extraIndex)) &&
           "invalid memory op index");
    newMapOperands.push_back(extraIndex);
  }

  // Append 'remapOutputs' to 'newMapOperands'.
  newMapOperands.append(remapOutputs.begin(), remapOutputs.end());

  // Create new fully composed AffineMap for new op to be created.
  assert(newMapOperands.size() == newMemRefRank);
  auto newMap = builder.getMultiDimIdentityMap(newMemRefRank);
  // TODO: Avoid creating/deleting temporary AffineApplyOps here.
  fullyComposeAffineMapAndOperands(&newMap, &newMapOperands);
  newMap = simplifyAffineMap(newMap);
  canonicalizeMapAndOperands(&newMap, &newMapOperands);
  // Remove any affine.apply's that became dead as a result of composition.
  for (Value value : affineApplyOps)
    if (value.use_empty())
      value.getDefiningOp()->erase();

  OperationState state(op->getLoc(), op->getName());
  // Construct the new operation using this memref.
  state.operands.reserve(op->getNumOperands() + extraIndices.size());
  // Insert the non-memref operands.
  state.operands.append(op->operand_begin(),
                        op->operand_begin() + memRefOperandPos);
  // Insert the new memref value.
  state.operands.push_back(newMemRef);

  // Insert the new memref map operands.
  state.operands.append(newMapOperands.begin(), newMapOperands.end());

  // Insert the remaining operands unmodified.
  state.operands.append(op->operand_begin() + memRefOperandPos + 1 +
                            oldMapNumInputs,
                        op->operand_end());

  // Result types don't change. Both memref's are of the same elemental type.
  state.types.reserve(op->getNumResults());
  for (auto result : op->getResults())
    state.types.push_back(result.getType());

  // Add attribute for 'newMap', other Attributes do not change.
  auto newMapAttr = AffineMapAttr::get(newMap);
  for (auto namedAttr : op->getAttrs()) {
    if (namedAttr.first == oldMapAttrPair.first)
      state.attributes.push_back({namedAttr.first, newMapAttr});
    else
      state.attributes.push_back(namedAttr);
  }

  // Create the new operation.
  auto *repOp = builder.createOperation(state);
  op->replaceAllUsesWith(repOp);
  op->erase();

  return success();
}

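// Entry point that performs the replacement across all uses of `oldMemRef`,
// optionally restricted to uses dominated by `domInstFilter` and/or
// post-dominated by `postDomInstFilter`. Candidate ops are collected first
// and then rewritten one by one via the single-op overload above, since the
// rewrite erases the op holding the use.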
LogicalResult mlir::replaceAllMemRefUsesWith(
    Value oldMemRef, Value newMemRef, ArrayRef<Value> extraIndices,
    AffineMap indexRemap, ArrayRef<Value> extraOperands,
    ArrayRef<Value> symbolOperands, Operation *domInstFilter,
    Operation *postDomInstFilter, bool allowNonDereferencingOps,
    bool replaceInDeallocOp) {
  unsigned newMemRefRank = newMemRef.getType().cast<MemRefType>().getRank();
  (void)newMemRefRank; // unused in opt mode
  unsigned oldMemRefRank = oldMemRef.getType().cast<MemRefType>().getRank();
  (void)oldMemRefRank;
  if (indexRemap) {
    assert(indexRemap.getNumSymbols() == symbolOperands.size() &&
           "symbol operand count mismatch");
    assert(indexRemap.getNumInputs() ==
           extraOperands.size() + oldMemRefRank + symbolOperands.size());
    assert(indexRemap.getNumResults() + extraIndices.size() == newMemRefRank);
  } else {
    assert(oldMemRefRank + extraIndices.size() == newMemRefRank);
  }

  // Assert same elemental type.
  assert(oldMemRef.getType().cast<MemRefType>().getElementType() ==
         newMemRef.getType().cast<MemRefType>().getElementType());

  std::unique_ptr<DominanceInfo> domInfo;
  std::unique_ptr<PostDominanceInfo> postDomInfo;
  if (domInstFilter)
    domInfo = std::make_unique<DominanceInfo>(
        domInstFilter->getParentOfType<FuncOp>());

  if (postDomInstFilter)
    postDomInfo = std::make_unique<PostDominanceInfo>(
        postDomInstFilter->getParentOfType<FuncOp>());

  // Walk all uses of old memref; collect ops to perform replacement. We use a
  // DenseSet since an operation could potentially have multiple uses of a
  // memref (although rare), and the replacement later is going to erase ops.
  DenseSet<Operation *> opsToReplace;
  for (auto *op : oldMemRef.getUsers()) {
    // Skip this use if it's not dominated by domInstFilter.
    if (domInstFilter && !domInfo->dominates(domInstFilter, op))
      continue;

    // Skip this use if it's not post-dominated by postDomInstFilter.
    if (postDomInstFilter && !postDomInfo->postDominates(postDomInstFilter, op))
      continue;

    // Skip dealloc's - no replacement is necessary, and a memref replacement
    // at other uses doesn't hurt these dealloc's.
    if (isa<DeallocOp>(op) && !replaceInDeallocOp)
      continue;

    // Check if the memref was used in a non-dereferencing context. It is fine
    // for the memref to be used in a non-dereferencing way outside of the
    // region where this replacement is happening.
    if (!isMemRefDereferencingOp(*op)) {
      if (!allowNonDereferencingOps)
        return failure();
      // Currently, only the following non-dereferencing ops are candidates
      // for replacement: DeallocOp, CallOp, and ReturnOp.
      // TODO: Add support for other kinds of ops.
      if (!op->hasTrait<OpTrait::MemRefsNormalizable>())
        return failure();
    }

    // We'll first collect and then replace --- since replacement erases the op
    // that has the use, and that op could be postDomFilter or domFilter itself!
    opsToReplace.insert(op);
  }

  for (auto *op : opsToReplace) {
    if (failed(replaceAllMemRefUsesWith(
            oldMemRef, newMemRef, op, extraIndices, indexRemap, extraOperands,
            symbolOperands, allowNonDereferencingOps)))
      llvm_unreachable("memref replacement guaranteed to succeed here");
  }

  return success();
}

/// Given an operation, inserts one or more single-result affine apply
/// operations, results of which are exclusively used by this operation.
/// The operands of these newly created affine apply ops are guaranteed to be
/// loop iterators or terminal symbols of a function.
///
/// Before
///
///   affine.for %i = 0 to #map(%N)
///     %idx = affine.apply (d0) -> (d0 mod 2) (%i)
///     "send"(%idx, %A, ...)
///     "compute"(%idx)
///
/// After
///
///   affine.for %i = 0 to #map(%N)
///     %idx = affine.apply (d0) -> (d0 mod 2) (%i)
///     "send"(%idx, %A, ...)
///     %idx_ = affine.apply (d0) -> (d0 mod 2) (%i)
///     "compute"(%idx_)
///
/// This allows applying different transformations on send and compute (e.g.,
/// different shifts/delays).
///
/// Returns the created affine.apply operations in the output argument
/// `sliceOps`. Leaves `sliceOps` empty either if none of opInst's operands
/// were the result of an affine.apply (and thus there was no affine
/// computation slice to create), or if all the affine.apply ops supplying
/// operands to opInst are used only by opInst itself.
void mlir::createAffineComputationSlice(
    Operation *opInst, SmallVectorImpl<AffineApplyOp> *sliceOps) {
  // Collect all operands that are results of affine apply ops.
  SmallVector<Value, 4> subOperands;
  subOperands.reserve(opInst->getNumOperands());
  for (auto operand : opInst->getOperands())
    if (isa_and_nonnull<AffineApplyOp>(operand.getDefiningOp()))
      subOperands.push_back(operand);

  // Gather sequence of AffineApplyOps reachable from 'subOperands'.
  SmallVector<Operation *, 4> affineApplyOps;
  getReachableAffineApplyOps(subOperands, affineApplyOps);
  // Skip transforming if there are no affine maps to compose.
  if (affineApplyOps.empty())
    return;

  // Check if all uses of the affine apply ops lie only in this op, in which
  // case there would be nothing to do.
  bool localized = true;
  for (auto *op : affineApplyOps) {
    for (auto result : op->getResults()) {
      for (auto *user : result.getUsers()) {
        if (user != opInst) {
          localized = false;
          break;
        }
      }
    }
  }
  if (localized)
    return;

  OpBuilder builder(opInst);
  SmallVector<Value, 4> composedOpOperands(subOperands);
  auto composedMap = builder.getMultiDimIdentityMap(composedOpOperands.size());
  fullyComposeAffineMapAndOperands(&composedMap, &composedOpOperands);

  // Create an affine.apply for each of the map results.
  sliceOps->reserve(composedMap.getNumResults());
  for (auto resultExpr : composedMap.getResults()) {
    auto singleResMap = AffineMap::get(composedMap.getNumDims(),
                                       composedMap.getNumSymbols(), resultExpr);
    sliceOps->push_back(builder.create<AffineApplyOp>(
        opInst->getLoc(), singleResMap, composedOpOperands));
  }

  // Construct the new operands: these are the same as opInst's operands
  // except that each operand appearing in 'subOperands' is replaced by the
  // result of the corresponding affine.apply op from 'sliceOps'.
  SmallVector<Value, 4> newOperands(opInst->getOperands());
  for (unsigned i = 0, e = newOperands.size(); i < e; i++) {
    // Replace the subOperands from among the new operands.
    unsigned j, f;
    for (j = 0, f = subOperands.size(); j < f; j++) {
      if (newOperands[i] == subOperands[j])
        break;
    }
    if (j < subOperands.size()) {
      newOperands[i] = (*sliceOps)[j];
    }
  }
  for (unsigned idx = 0, e = newOperands.size(); idx < e; idx++) {
    opInst->setOperand(idx, newOperands[idx]);
  }
}

// TODO: Currently works for static memrefs with a single layout map.
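//
// As an illustrative example (assumed layout, not taken from a test), with
//   #tile = (d0) -> (d0 floordiv 4, d0 mod 4)
// an allocation
//
//   %0 = alloc() : memref<16xf64, #tile>
//
// would be normalized to an identity-layout memref whose shape absorbs the
// layout map:
//
//   %0 = alloc() : memref<4x4xf64>
//
// with all dereferencing uses of the old memref rewritten through
// replaceAllMemRefUsesWith, using #tile as the index remapping.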
LogicalResult mlir::normalizeMemRef(AllocOp allocOp) {
  MemRefType memrefType = allocOp.getType();
  OpBuilder b(allocOp);

  // Fetch a new memref type after normalizing the old memref to have an
  // identity map layout.
  MemRefType newMemRefType =
      normalizeMemRefType(memrefType, b, allocOp.symbolOperands().size());
  if (newMemRefType == memrefType)
    // Either memrefType already had an identity map or the map couldn't be
    // transformed to an identity map.
    return failure();

  Value oldMemRef = allocOp.getResult();

  SmallVector<Value, 4> symbolOperands(allocOp.symbolOperands());
  AllocOp newAlloc = b.create<AllocOp>(allocOp.getLoc(), newMemRefType,
                                       allocOp.alignmentAttr());
  AffineMap layoutMap = memrefType.getAffineMaps().front();
  // Replace all uses of the old memref.
  if (failed(replaceAllMemRefUsesWith(oldMemRef, /*newMemRef=*/newAlloc,
                                      /*extraIndices=*/{},
                                      /*indexRemap=*/layoutMap,
                                      /*extraOperands=*/{},
                                      /*symbolOperands=*/symbolOperands,
                                      /*domInstFilter=*/nullptr,
                                      /*postDomInstFilter=*/nullptr,
                                      /*allowNonDereferencingOps=*/true))) {
    // If it failed (due to escapes for example), bail out.
    newAlloc.erase();
    return failure();
  }
  // Replace any remaining uses of the original alloc op and erase it. All
  // remaining uses have to be dealloc's; the replaceAllMemRefUsesWith call
  // above would've failed otherwise.
  assert(llvm::all_of(oldMemRef.getUsers(),
                      [](Operation *op) { return isa<DeallocOp>(op); }));
  oldMemRef.replaceAllUsesWith(newAlloc);
  allocOp.erase();
  return success();
}

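/// Normalizes `memrefType` by folding its (single, non-identity) layout map
/// into the shape, producing a type with an identity layout. Returns
/// `memrefType` unchanged when normalization isn't possible (rank zero, no or
/// identity layout map, dynamic shape, or a map that can't be composed or
/// bounded).
///
/// As a hypothetical example, memref<64xf32, (d0) -> (d0 mod 8, d0 floordiv 8)>
/// would normalize to memref<8x8xf32>: each result expression of the layout
/// map becomes a dimension whose extent is one plus the constant upper bound
/// of that expression over the original index space [0, 64).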
MemRefType mlir::normalizeMemRefType(MemRefType memrefType, OpBuilder b,
                                     unsigned numSymbolicOperands) {
  unsigned rank = memrefType.getRank();
  if (rank == 0)
    return memrefType;

  ArrayRef<AffineMap> layoutMaps = memrefType.getAffineMaps();
  if (layoutMaps.empty() ||
      layoutMaps.front() == b.getMultiDimIdentityMap(rank)) {
    // Either no map is associated with this memref or this memref has
    // a trivial (identity) map.
    return memrefType;
  }

  // We don't do any checks for one-to-one'ness; we assume that it is
  // one-to-one.

  // TODO: Only for static memref's for now.
  if (memrefType.getNumDynamicDims() > 0)
    return memrefType;

  // We have a single map that is not an identity map. Create a new memref
  // with the right shape and an identity layout map.
  ArrayRef<int64_t> shape = memrefType.getShape();
  // FlatAffineConstraints may later on use symbolicOperands.
  FlatAffineConstraints fac(rank, numSymbolicOperands);
  for (unsigned d = 0; d < rank; ++d) {
    fac.addConstantLowerBound(d, 0);
    fac.addConstantUpperBound(d, shape[d] - 1);
  }
  // We compose this map with the original index (logical) space to derive
  // the upper bounds for the new index space.
  AffineMap layoutMap = layoutMaps.front();
  unsigned newRank = layoutMap.getNumResults();
  if (failed(fac.composeMatchingMap(layoutMap)))
    return memrefType;
  // TODO: Handle semi-affine maps.
  // Project out the old data dimensions.
  fac.projectOut(newRank, fac.getNumIds() - newRank - fac.getNumLocalIds());
  SmallVector<int64_t, 4> newShape(newRank);
  for (unsigned d = 0; d < newRank; ++d) {
    // The lower bound for the shape is always zero.
    auto ubConst = fac.getConstantUpperBound(d);
    // For a static memref and an affine map with no symbols, this is
    // always bounded.
    assert(ubConst.hasValue() && "should always have an upper bound");
    if (ubConst.getValue() < 0)
      // This is due to an invalid map that maps to a negative space.
      return memrefType;
    newShape[d] = ubConst.getValue() + 1;
  }

  // Create the new memref type after trivializing the old layout map.
  MemRefType newMemRefType =
      MemRefType::Builder(memrefType)
          .setShape(newShape)
          .setAffineMaps(b.getMultiDimIdentityMap(newRank));

  return newMemRefType;
}