1 //===- AffineOps.cpp - MLIR Affine Operations -----------------------------===//
2 //
3 // Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions.
4 // See https://llvm.org/LICENSE.txt for license information.
5 // SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception
6 //
7 //===----------------------------------------------------------------------===//
8
9 #include "mlir/Dialect/Affine/IR/AffineOps.h"
10 #include "mlir/Dialect/Affine/IR/AffineValueMap.h"
11 #include "mlir/Dialect/StandardOps/IR/Ops.h"
12 #include "mlir/IR/BuiltinOps.h"
13 #include "mlir/IR/IntegerSet.h"
14 #include "mlir/IR/Matchers.h"
15 #include "mlir/IR/OpImplementation.h"
16 #include "mlir/IR/PatternMatch.h"
17 #include "mlir/Transforms/InliningUtils.h"
18 #include "llvm/ADT/SetVector.h"
19 #include "llvm/ADT/SmallBitVector.h"
20 #include "llvm/ADT/TypeSwitch.h"
21 #include "llvm/Support/Debug.h"
22
23 using namespace mlir;
24 using llvm::dbgs;
25
26 #define DEBUG_TYPE "affine-analysis"
27
28 //===----------------------------------------------------------------------===//
29 // AffineDialect Interfaces
30 //===----------------------------------------------------------------------===//
31
32 namespace {
33 /// This class defines the interface for handling inlining with affine
34 /// operations.
35 struct AffineInlinerInterface : public DialectInlinerInterface {
36 using DialectInlinerInterface::DialectInlinerInterface;
37
38 //===--------------------------------------------------------------------===//
39 // Analysis Hooks
40 //===--------------------------------------------------------------------===//
41
42 /// Returns true if the given region 'src' can be inlined into the region
43 /// 'dest' that is attached to an operation registered to the current dialect.
isLegalToInline__anonb6f842fb0111::AffineInlinerInterface44 bool isLegalToInline(Region *dest, Region *src, bool wouldBeCloned,
45 BlockAndValueMapping &valueMapping) const final {
46 // Conservatively don't allow inlining into affine structures.
47 return false;
48 }
49
50 /// Returns true if the given operation 'op', that is registered to this
51 /// dialect, can be inlined into the given region, false otherwise.
isLegalToInline__anonb6f842fb0111::AffineInlinerInterface52 bool isLegalToInline(Operation *op, Region *region, bool wouldBeCloned,
53 BlockAndValueMapping &valueMapping) const final {
54 // Always allow inlining affine operations into the top-level region of a
55 // function. There are some edge cases when inlining *into* affine
56 // structures, but that is handled in the other 'isLegalToInline' hook
57 // above.
58 // TODO: We should be able to inline into other regions than functions.
59 return isa<FuncOp>(region->getParentOp());
60 }
61
62 /// Affine regions should be analyzed recursively.
shouldAnalyzeRecursively__anonb6f842fb0111::AffineInlinerInterface63 bool shouldAnalyzeRecursively(Operation *op) const final { return true; }
64 };
65 } // end anonymous namespace
66
67 //===----------------------------------------------------------------------===//
68 // AffineDialect
69 //===----------------------------------------------------------------------===//
70
initialize()71 void AffineDialect::initialize() {
72 addOperations<AffineDmaStartOp, AffineDmaWaitOp,
73 #define GET_OP_LIST
74 #include "mlir/Dialect/Affine/IR/AffineOps.cpp.inc"
75 >();
76 addInterfaces<AffineInlinerInterface>();
77 }
78
79 /// Materialize a single constant operation from a given attribute value with
80 /// the desired resultant type.
materializeConstant(OpBuilder & builder,Attribute value,Type type,Location loc)81 Operation *AffineDialect::materializeConstant(OpBuilder &builder,
82 Attribute value, Type type,
83 Location loc) {
84 return builder.create<ConstantOp>(loc, type, value);
85 }
86
87 /// A utility function to check if a value is defined at the top level of an
88 /// op with trait `AffineScope`. If the value is defined in an unlinked region,
89 /// conservatively assume it is not top-level. A value of index type defined at
90 /// the top level is always a valid symbol.
isTopLevelValue(Value value)91 bool mlir::isTopLevelValue(Value value) {
92 if (auto arg = value.dyn_cast<BlockArgument>()) {
93 // The block owning the argument may be unlinked, e.g. when the surrounding
94 // region has not yet been attached to an Op, at which point the parent Op
95 // is null.
96 Operation *parentOp = arg.getOwner()->getParentOp();
97 return parentOp && parentOp->hasTrait<OpTrait::AffineScope>();
98 }
99 // The defining Op may live in an unlinked block so its parent Op may be null.
100 Operation *parentOp = value.getDefiningOp()->getParentOp();
101 return parentOp && parentOp->hasTrait<OpTrait::AffineScope>();
102 }
103
104 /// A utility function to check if a value is defined at the top level of
105 /// `region` or is an argument of `region`. A value of index type defined at the
106 /// top level of a `AffineScope` region is always a valid symbol for all
107 /// uses in that region.
isTopLevelValue(Value value,Region * region)108 static bool isTopLevelValue(Value value, Region *region) {
109 if (auto arg = value.dyn_cast<BlockArgument>())
110 return arg.getParentRegion() == region;
111 return value.getDefiningOp()->getParentRegion() == region;
112 }
113
114 /// Returns the closest region enclosing `op` that is held by an operation with
115 /// trait `AffineScope`; `nullptr` if there is no such region.
116 // TODO: getAffineScope should be publicly exposed for affine passes/utilities.
getAffineScope(Operation * op)117 static Region *getAffineScope(Operation *op) {
118 auto *curOp = op;
119 while (auto *parentOp = curOp->getParentOp()) {
120 if (parentOp->hasTrait<OpTrait::AffineScope>())
121 return curOp->getParentRegion();
122 curOp = parentOp;
123 }
124 return nullptr;
125 }
126
127 // A Value can be used as a dimension id iff it meets one of the following
128 // conditions:
129 // *) It is valid as a symbol.
130 // *) It is an induction variable.
131 // *) It is the result of affine apply operation with dimension id arguments.
isValidDim(Value value)132 bool mlir::isValidDim(Value value) {
133 // The value must be an index type.
134 if (!value.getType().isIndex())
135 return false;
136
137 if (auto *defOp = value.getDefiningOp())
138 return isValidDim(value, getAffineScope(defOp));
139
140 // This value has to be a block argument for an op that has the
141 // `AffineScope` trait or for an affine.for or affine.parallel.
142 auto *parentOp = value.cast<BlockArgument>().getOwner()->getParentOp();
143 return parentOp && (parentOp->hasTrait<OpTrait::AffineScope>() ||
144 isa<AffineForOp, AffineParallelOp>(parentOp));
145 }
146
147 // Value can be used as a dimension id iff it meets one of the following
148 // conditions:
149 // *) It is valid as a symbol.
150 // *) It is an induction variable.
151 // *) It is the result of an affine apply operation with dimension id operands.
isValidDim(Value value,Region * region)152 bool mlir::isValidDim(Value value, Region *region) {
153 // The value must be an index type.
154 if (!value.getType().isIndex())
155 return false;
156
157 // All valid symbols are okay.
158 if (isValidSymbol(value, region))
159 return true;
160
161 auto *op = value.getDefiningOp();
162 if (!op) {
163 // This value has to be a block argument for an affine.for or an
164 // affine.parallel.
165 auto *parentOp = value.cast<BlockArgument>().getOwner()->getParentOp();
166 return isa<AffineForOp, AffineParallelOp>(parentOp);
167 }
168
169 // Affine apply operation is ok if all of its operands are ok.
170 if (auto applyOp = dyn_cast<AffineApplyOp>(op))
171 return applyOp.isValidDim(region);
172 // The dim op is okay if its operand memref/tensor is defined at the top
173 // level.
174 if (auto dimOp = dyn_cast<DimOp>(op))
175 return isTopLevelValue(dimOp.memrefOrTensor());
176 return false;
177 }
178
179 /// Returns true if the 'index' dimension of the `memref` defined by
180 /// `memrefDefOp` is a statically shaped one or defined using a valid symbol
181 /// for `region`.
182 template <typename AnyMemRefDefOp>
isMemRefSizeValidSymbol(AnyMemRefDefOp memrefDefOp,unsigned index,Region * region)183 static bool isMemRefSizeValidSymbol(AnyMemRefDefOp memrefDefOp, unsigned index,
184 Region *region) {
185 auto memRefType = memrefDefOp.getType();
186 // Statically shaped.
187 if (!memRefType.isDynamicDim(index))
188 return true;
189 // Get the position of the dimension among dynamic dimensions;
190 unsigned dynamicDimPos = memRefType.getDynamicDimIndex(index);
191 return isValidSymbol(*(memrefDefOp.getDynamicSizes().begin() + dynamicDimPos),
192 region);
193 }
194
195 /// Returns true if the result of the dim op is a valid symbol for `region`.
isDimOpValidSymbol(DimOp dimOp,Region * region)196 static bool isDimOpValidSymbol(DimOp dimOp, Region *region) {
197 // The dim op is okay if its operand memref/tensor is defined at the top
198 // level.
199 if (isTopLevelValue(dimOp.memrefOrTensor()))
200 return true;
201
202 // Conservatively handle remaining BlockArguments as non-valid symbols.
203 // E.g. scf.for iterArgs.
204 if (dimOp.memrefOrTensor().isa<BlockArgument>())
205 return false;
206
207 // The dim op is also okay if its operand memref/tensor is a view/subview
208 // whose corresponding size is a valid symbol.
209 Optional<int64_t> index = dimOp.getConstantIndex();
210 assert(index.hasValue() &&
211 "expect only `dim` operations with a constant index");
212 int64_t i = index.getValue();
213 return TypeSwitch<Operation *, bool>(dimOp.memrefOrTensor().getDefiningOp())
214 .Case<ViewOp, SubViewOp, AllocOp>(
215 [&](auto op) { return isMemRefSizeValidSymbol(op, i, region); })
216 .Default([](Operation *) { return false; });
217 }
218
219 // A value can be used as a symbol (at all its use sites) iff it meets one of
220 // the following conditions:
221 // *) It is a constant.
222 // *) Its defining op or block arg appearance is immediately enclosed by an op
223 // with `AffineScope` trait.
224 // *) It is the result of an affine.apply operation with symbol operands.
225 // *) It is a result of the dim op on a memref whose corresponding size is a
226 // valid symbol.
isValidSymbol(Value value)227 bool mlir::isValidSymbol(Value value) {
228 // The value must be an index type.
229 if (!value.getType().isIndex())
230 return false;
231
232 // Check that the value is a top level value.
233 if (isTopLevelValue(value))
234 return true;
235
236 if (auto *defOp = value.getDefiningOp())
237 return isValidSymbol(value, getAffineScope(defOp));
238
239 return false;
240 }
241
242 /// A value can be used as a symbol for `region` iff it meets onf of the the
243 /// following conditions:
244 /// *) It is a constant.
245 /// *) It is the result of an affine apply operation with symbol arguments.
246 /// *) It is a result of the dim op on a memref whose corresponding size is
247 /// a valid symbol.
248 /// *) It is defined at the top level of 'region' or is its argument.
249 /// *) It dominates `region`'s parent op.
250 /// If `region` is null, conservatively assume the symbol definition scope does
251 /// not exist and only accept the values that would be symbols regardless of
252 /// the surrounding region structure, i.e. the first three cases above.
isValidSymbol(Value value,Region * region)253 bool mlir::isValidSymbol(Value value, Region *region) {
254 // The value must be an index type.
255 if (!value.getType().isIndex())
256 return false;
257
258 // A top-level value is a valid symbol.
259 if (region && ::isTopLevelValue(value, region))
260 return true;
261
262 auto *defOp = value.getDefiningOp();
263 if (!defOp) {
264 // A block argument that is not a top-level value is a valid symbol if it
265 // dominates region's parent op.
266 if (region && !region->getParentOp()->isKnownIsolatedFromAbove())
267 if (auto *parentOpRegion = region->getParentOp()->getParentRegion())
268 return isValidSymbol(value, parentOpRegion);
269 return false;
270 }
271
272 // Constant operation is ok.
273 Attribute operandCst;
274 if (matchPattern(defOp, m_Constant(&operandCst)))
275 return true;
276
277 // Affine apply operation is ok if all of its operands are ok.
278 if (auto applyOp = dyn_cast<AffineApplyOp>(defOp))
279 return applyOp.isValidSymbol(region);
280
281 // Dim op results could be valid symbols at any level.
282 if (auto dimOp = dyn_cast<DimOp>(defOp))
283 return isDimOpValidSymbol(dimOp, region);
284
285 // Check for values dominating `region`'s parent op.
286 if (region && !region->getParentOp()->isKnownIsolatedFromAbove())
287 if (auto *parentRegion = region->getParentOp()->getParentRegion())
288 return isValidSymbol(value, parentRegion);
289
290 return false;
291 }
292
293 // Returns true if 'value' is a valid index to an affine operation (e.g.
294 // affine.load, affine.store, affine.dma_start, affine.dma_wait) where
295 // `region` provides the polyhedral symbol scope. Returns false otherwise.
isValidAffineIndexOperand(Value value,Region * region)296 static bool isValidAffineIndexOperand(Value value, Region *region) {
297 return isValidDim(value, region) || isValidSymbol(value, region);
298 }
299
300 /// Prints dimension and symbol list.
printDimAndSymbolList(Operation::operand_iterator begin,Operation::operand_iterator end,unsigned numDims,OpAsmPrinter & printer)301 static void printDimAndSymbolList(Operation::operand_iterator begin,
302 Operation::operand_iterator end,
303 unsigned numDims, OpAsmPrinter &printer) {
304 OperandRange operands(begin, end);
305 printer << '(' << operands.take_front(numDims) << ')';
306 if (operands.size() > numDims)
307 printer << '[' << operands.drop_front(numDims) << ']';
308 }
309
310 /// Parses dimension and symbol list and returns true if parsing failed.
parseDimAndSymbolList(OpAsmParser & parser,SmallVectorImpl<Value> & operands,unsigned & numDims)311 ParseResult mlir::parseDimAndSymbolList(OpAsmParser &parser,
312 SmallVectorImpl<Value> &operands,
313 unsigned &numDims) {
314 SmallVector<OpAsmParser::OperandType, 8> opInfos;
315 if (parser.parseOperandList(opInfos, OpAsmParser::Delimiter::Paren))
316 return failure();
317 // Store number of dimensions for validation by caller.
318 numDims = opInfos.size();
319
320 // Parse the optional symbol operands.
321 auto indexTy = parser.getBuilder().getIndexType();
322 return failure(parser.parseOperandList(
323 opInfos, OpAsmParser::Delimiter::OptionalSquare) ||
324 parser.resolveOperands(opInfos, indexTy, operands));
325 }
326
327 /// Utility function to verify that a set of operands are valid dimension and
328 /// symbol identifiers. The operands should be laid out such that the dimension
329 /// operands are before the symbol operands. This function returns failure if
330 /// there was an invalid operand. An operation is provided to emit any necessary
331 /// errors.
332 template <typename OpTy>
333 static LogicalResult
verifyDimAndSymbolIdentifiers(OpTy & op,Operation::operand_range operands,unsigned numDims)334 verifyDimAndSymbolIdentifiers(OpTy &op, Operation::operand_range operands,
335 unsigned numDims) {
336 unsigned opIt = 0;
337 for (auto operand : operands) {
338 if (opIt++ < numDims) {
339 if (!isValidDim(operand, getAffineScope(op)))
340 return op.emitOpError("operand cannot be used as a dimension id");
341 } else if (!isValidSymbol(operand, getAffineScope(op))) {
342 return op.emitOpError("operand cannot be used as a symbol");
343 }
344 }
345 return success();
346 }
347
348 //===----------------------------------------------------------------------===//
349 // AffineApplyOp
350 //===----------------------------------------------------------------------===//
351
getAffineValueMap()352 AffineValueMap AffineApplyOp::getAffineValueMap() {
353 return AffineValueMap(getAffineMap(), getOperands(), getResult());
354 }
355
parseAffineApplyOp(OpAsmParser & parser,OperationState & result)356 static ParseResult parseAffineApplyOp(OpAsmParser &parser,
357 OperationState &result) {
358 auto &builder = parser.getBuilder();
359 auto indexTy = builder.getIndexType();
360
361 AffineMapAttr mapAttr;
362 unsigned numDims;
363 if (parser.parseAttribute(mapAttr, "map", result.attributes) ||
364 parseDimAndSymbolList(parser, result.operands, numDims) ||
365 parser.parseOptionalAttrDict(result.attributes))
366 return failure();
367 auto map = mapAttr.getValue();
368
369 if (map.getNumDims() != numDims ||
370 numDims + map.getNumSymbols() != result.operands.size()) {
371 return parser.emitError(parser.getNameLoc(),
372 "dimension or symbol index mismatch");
373 }
374
375 result.types.append(map.getNumResults(), indexTy);
376 return success();
377 }
378
print(OpAsmPrinter & p,AffineApplyOp op)379 static void print(OpAsmPrinter &p, AffineApplyOp op) {
380 p << AffineApplyOp::getOperationName() << " " << op.mapAttr();
381 printDimAndSymbolList(op.operand_begin(), op.operand_end(),
382 op.getAffineMap().getNumDims(), p);
383 p.printOptionalAttrDict(op.getAttrs(), /*elidedAttrs=*/{"map"});
384 }
385
verify(AffineApplyOp op)386 static LogicalResult verify(AffineApplyOp op) {
387 // Check input and output dimensions match.
388 auto map = op.map();
389
390 // Verify that operand count matches affine map dimension and symbol count.
391 if (op.getNumOperands() != map.getNumDims() + map.getNumSymbols())
392 return op.emitOpError(
393 "operand count and affine map dimension and symbol count must match");
394
395 // Verify that the map only produces one result.
396 if (map.getNumResults() != 1)
397 return op.emitOpError("mapping must produce one value");
398
399 return success();
400 }
401
402 // The result of the affine apply operation can be used as a dimension id if all
403 // its operands are valid dimension ids.
isValidDim()404 bool AffineApplyOp::isValidDim() {
405 return llvm::all_of(getOperands(),
406 [](Value op) { return mlir::isValidDim(op); });
407 }
408
409 // The result of the affine apply operation can be used as a dimension id if all
410 // its operands are valid dimension ids with the parent operation of `region`
411 // defining the polyhedral scope for symbols.
isValidDim(Region * region)412 bool AffineApplyOp::isValidDim(Region *region) {
413 return llvm::all_of(getOperands(),
414 [&](Value op) { return ::isValidDim(op, region); });
415 }
416
417 // The result of the affine apply operation can be used as a symbol if all its
418 // operands are symbols.
isValidSymbol()419 bool AffineApplyOp::isValidSymbol() {
420 return llvm::all_of(getOperands(),
421 [](Value op) { return mlir::isValidSymbol(op); });
422 }
423
424 // The result of the affine apply operation can be used as a symbol in `region`
425 // if all its operands are symbols in `region`.
isValidSymbol(Region * region)426 bool AffineApplyOp::isValidSymbol(Region *region) {
427 return llvm::all_of(getOperands(), [&](Value operand) {
428 return mlir::isValidSymbol(operand, region);
429 });
430 }
431
fold(ArrayRef<Attribute> operands)432 OpFoldResult AffineApplyOp::fold(ArrayRef<Attribute> operands) {
433 auto map = getAffineMap();
434
435 // Fold dims and symbols to existing values.
436 auto expr = map.getResult(0);
437 if (auto dim = expr.dyn_cast<AffineDimExpr>())
438 return getOperand(dim.getPosition());
439 if (auto sym = expr.dyn_cast<AffineSymbolExpr>())
440 return getOperand(map.getNumDims() + sym.getPosition());
441
442 // Otherwise, default to folding the map.
443 SmallVector<Attribute, 1> result;
444 if (failed(map.constantFold(operands, result)))
445 return {};
446 return result[0];
447 }
448
renumberOneDim(Value v)449 AffineDimExpr AffineApplyNormalizer::renumberOneDim(Value v) {
450 DenseMap<Value, unsigned>::iterator iterPos;
451 bool inserted = false;
452 std::tie(iterPos, inserted) =
453 dimValueToPosition.insert(std::make_pair(v, dimValueToPosition.size()));
454 if (inserted) {
455 reorderedDims.push_back(v);
456 }
457 return getAffineDimExpr(iterPos->second, v.getContext())
458 .cast<AffineDimExpr>();
459 }
460
renumber(const AffineApplyNormalizer & other)461 AffineMap AffineApplyNormalizer::renumber(const AffineApplyNormalizer &other) {
462 SmallVector<AffineExpr, 8> dimRemapping;
463 for (auto v : other.reorderedDims) {
464 auto kvp = other.dimValueToPosition.find(v);
465 if (dimRemapping.size() <= kvp->second)
466 dimRemapping.resize(kvp->second + 1);
467 dimRemapping[kvp->second] = renumberOneDim(kvp->first);
468 }
469 unsigned numSymbols = concatenatedSymbols.size();
470 unsigned numOtherSymbols = other.concatenatedSymbols.size();
471 SmallVector<AffineExpr, 8> symRemapping(numOtherSymbols);
472 for (unsigned idx = 0; idx < numOtherSymbols; ++idx) {
473 symRemapping[idx] =
474 getAffineSymbolExpr(idx + numSymbols, other.affineMap.getContext());
475 }
476 concatenatedSymbols.insert(concatenatedSymbols.end(),
477 other.concatenatedSymbols.begin(),
478 other.concatenatedSymbols.end());
479 auto map = other.affineMap;
480 return map.replaceDimsAndSymbols(dimRemapping, symRemapping,
481 reorderedDims.size(),
482 concatenatedSymbols.size());
483 }
484
485 // Gather the positions of the operands that are produced by an AffineApplyOp.
486 static llvm::SetVector<unsigned>
indicesFromAffineApplyOp(ArrayRef<Value> operands)487 indicesFromAffineApplyOp(ArrayRef<Value> operands) {
488 llvm::SetVector<unsigned> res;
489 for (auto en : llvm::enumerate(operands))
490 if (isa_and_nonnull<AffineApplyOp>(en.value().getDefiningOp()))
491 res.insert(en.index());
492 return res;
493 }
494
495 // Support the special case of a symbol coming from an AffineApplyOp that needs
496 // to be composed into the current AffineApplyOp.
497 // This case is handled by rewriting all such symbols into dims for the purpose
498 // of allowing mathematical AffineMap composition.
499 // Returns an AffineMap where symbols that come from an AffineApplyOp have been
500 // rewritten as dims and are ordered after the original dims.
501 // TODO: This promotion makes AffineMap lose track of which
502 // symbols are represented as dims. This loss is static but can still be
503 // recovered dynamically (with `isValidSymbol`). Still this is annoying for the
504 // semi-affine map case. A dynamic canonicalization of all dims that are valid
505 // symbols (a.k.a `canonicalizePromotedSymbols`) into symbols helps and even
506 // results in better simplifications and foldings. But we should evaluate
507 // whether this behavior is what we really want after using more.
promoteComposedSymbolsAsDims(AffineMap map,ArrayRef<Value> symbols)508 static AffineMap promoteComposedSymbolsAsDims(AffineMap map,
509 ArrayRef<Value> symbols) {
510 if (symbols.empty()) {
511 return map;
512 }
513
514 // Sanity check on symbols.
515 for (auto sym : symbols) {
516 assert(isValidSymbol(sym) && "Expected only valid symbols");
517 (void)sym;
518 }
519
520 // Extract the symbol positions that come from an AffineApplyOp and
521 // needs to be rewritten as dims.
522 auto symPositions = indicesFromAffineApplyOp(symbols);
523 if (symPositions.empty()) {
524 return map;
525 }
526
527 // Create the new map by replacing each symbol at pos by the next new dim.
528 unsigned numDims = map.getNumDims();
529 unsigned numSymbols = map.getNumSymbols();
530 unsigned numNewDims = 0;
531 unsigned numNewSymbols = 0;
532 SmallVector<AffineExpr, 8> symReplacements(numSymbols);
533 for (unsigned i = 0; i < numSymbols; ++i) {
534 symReplacements[i] =
535 symPositions.count(i) > 0
536 ? getAffineDimExpr(numDims + numNewDims++, map.getContext())
537 : getAffineSymbolExpr(numNewSymbols++, map.getContext());
538 }
539 assert(numSymbols >= numNewDims);
540 AffineMap newMap = map.replaceDimsAndSymbols(
541 {}, symReplacements, numDims + numNewDims, numNewSymbols);
542
543 return newMap;
544 }
545
546 /// The AffineNormalizer composes AffineApplyOp recursively. Its purpose is to
547 /// keep a correspondence between the mathematical `map` and the `operands` of
548 /// a given AffineApplyOp. This correspondence is maintained by iterating over
549 /// the operands and forming an `auxiliaryMap` that can be composed
550 /// mathematically with `map`. To keep this correspondence in cases where
551 /// symbols are produced by affine.apply operations, we perform a local rewrite
552 /// of symbols as dims.
553 ///
554 /// Rationale for locally rewriting symbols as dims:
555 /// ================================================
556 /// The mathematical composition of AffineMap must always concatenate symbols
557 /// because it does not have enough information to do otherwise. For example,
558 /// composing `(d0)[s0] -> (d0 + s0)` with itself must produce
559 /// `(d0)[s0, s1] -> (d0 + s0 + s1)`.
560 ///
561 /// The result is only equivalent to `(d0)[s0] -> (d0 + 2 * s0)` when
562 /// applied to the same mlir::Value for both s0 and s1.
563 /// As a consequence mathematical composition of AffineMap always concatenates
564 /// symbols.
565 ///
566 /// When AffineMaps are used in AffineApplyOp however, they may specify
567 /// composition via symbols, which is ambiguous mathematically. This corner case
568 /// is handled by locally rewriting such symbols that come from AffineApplyOp
569 /// into dims and composing through dims.
570 /// TODO: Composition via symbols comes at a significant code
571 /// complexity. Alternatively we should investigate whether we want to
572 /// explicitly disallow symbols coming from affine.apply and instead force the
573 /// user to compose symbols beforehand. The annoyances may be small (i.e. 1 or 2
574 /// extra API calls for such uses, which haven't popped up until now) and the
575 /// benefit potentially big: simpler and more maintainable code for a
576 /// non-trivial, recursive, procedure.
AffineApplyNormalizer(AffineMap map,ArrayRef<Value> operands)577 AffineApplyNormalizer::AffineApplyNormalizer(AffineMap map,
578 ArrayRef<Value> operands)
579 : AffineApplyNormalizer() {
580 static_assert(kMaxAffineApplyDepth > 0, "kMaxAffineApplyDepth must be > 0");
581 assert(map.getNumInputs() == operands.size() &&
582 "number of operands does not match the number of map inputs");
583
584 LLVM_DEBUG(map.print(dbgs() << "\nInput map: "));
585
586 // Promote symbols that come from an AffineApplyOp to dims by rewriting the
587 // map to always refer to:
588 // (dims, symbols coming from AffineApplyOp, other symbols).
589 // The order of operands can remain unchanged.
590 // This is a simplification that relies on 2 ordering properties:
591 // 1. rewritten symbols always appear after the original dims in the map;
592 // 2. operands are traversed in order and either dispatched to:
593 // a. auxiliaryExprs (dims and symbols rewritten as dims);
594 // b. concatenatedSymbols (all other symbols)
595 // This allows operand order to remain unchanged.
596 unsigned numDimsBeforeRewrite = map.getNumDims();
597 map = promoteComposedSymbolsAsDims(map,
598 operands.take_back(map.getNumSymbols()));
599
600 LLVM_DEBUG(map.print(dbgs() << "\nRewritten map: "));
601
602 SmallVector<AffineExpr, 8> auxiliaryExprs;
603 bool furtherCompose = (affineApplyDepth() <= kMaxAffineApplyDepth);
604 // We fully spell out the 2 cases below. In this particular instance a little
605 // code duplication greatly improves readability.
606 // Note that the first branch would disappear if we only supported full
607 // composition (i.e. infinite kMaxAffineApplyDepth).
608 if (!furtherCompose) {
609 // 1. Only dispatch dims or symbols.
610 for (auto en : llvm::enumerate(operands)) {
611 auto t = en.value();
612 assert(t.getType().isIndex());
613 bool isDim = (en.index() < map.getNumDims());
614 if (isDim) {
615 // a. The mathematical composition of AffineMap composes dims.
616 auxiliaryExprs.push_back(renumberOneDim(t));
617 } else {
618 // b. The mathematical composition of AffineMap concatenates symbols.
619 // We do the same for symbol operands.
620 concatenatedSymbols.push_back(t);
621 }
622 }
623 } else {
624 assert(numDimsBeforeRewrite <= operands.size());
625 // 2. Compose AffineApplyOps and dispatch dims or symbols.
626 for (unsigned i = 0, e = operands.size(); i < e; ++i) {
627 auto t = operands[i];
628 auto affineApply = t.getDefiningOp<AffineApplyOp>();
629 if (affineApply) {
630 // a. Compose affine.apply operations.
631 LLVM_DEBUG(affineApply->print(
632 dbgs() << "\nCompose AffineApplyOp recursively: "));
633 AffineMap affineApplyMap = affineApply.getAffineMap();
634 SmallVector<Value, 8> affineApplyOperands(
635 affineApply.getOperands().begin(), affineApply.getOperands().end());
636 AffineApplyNormalizer normalizer(affineApplyMap, affineApplyOperands);
637
638 LLVM_DEBUG(normalizer.affineMap.print(
639 dbgs() << "\nRenumber into current normalizer: "));
640
641 auto renumberedMap = renumber(normalizer);
642
643 LLVM_DEBUG(
644 renumberedMap.print(dbgs() << "\nRecursive composition yields: "));
645
646 auxiliaryExprs.push_back(renumberedMap.getResult(0));
647 } else {
648 if (i < numDimsBeforeRewrite) {
649 // b. The mathematical composition of AffineMap composes dims.
650 auxiliaryExprs.push_back(renumberOneDim(t));
651 } else {
652 // c. The mathematical composition of AffineMap concatenates symbols.
653 // Note that the map composition will put symbols already present
654 // in the map before any symbols coming from the auxiliary map, so
655 // we insert them before any symbols that are due to renumbering,
656 // and after the proper symbols we have seen already.
657 concatenatedSymbols.insert(
658 std::next(concatenatedSymbols.begin(), numProperSymbols++), t);
659 }
660 }
661 }
662 }
663
664 // Early exit if `map` is already composed.
665 if (auxiliaryExprs.empty()) {
666 affineMap = map;
667 return;
668 }
669
670 assert(concatenatedSymbols.size() >= map.getNumSymbols() &&
671 "Unexpected number of concatenated symbols");
672 auto numDims = dimValueToPosition.size();
673 auto numSymbols = concatenatedSymbols.size() - map.getNumSymbols();
674 auto auxiliaryMap =
675 AffineMap::get(numDims, numSymbols, auxiliaryExprs, map.getContext());
676
677 LLVM_DEBUG(map.print(dbgs() << "\nCompose map: "));
678 LLVM_DEBUG(auxiliaryMap.print(dbgs() << "\nWith map: "));
679 LLVM_DEBUG(map.compose(auxiliaryMap).print(dbgs() << "\nResult: "));
680
681 // TODO: Disabling simplification results in major speed gains.
682 // Another option is to cache the results as it is expected a lot of redundant
683 // work is performed in practice.
684 affineMap = simplifyAffineMap(map.compose(auxiliaryMap));
685
686 LLVM_DEBUG(affineMap.print(dbgs() << "\nSimplified result: "));
687 LLVM_DEBUG(dbgs() << "\n");
688 }
689
normalize(AffineMap * otherMap,SmallVectorImpl<Value> * otherOperands)690 void AffineApplyNormalizer::normalize(AffineMap *otherMap,
691 SmallVectorImpl<Value> *otherOperands) {
692 AffineApplyNormalizer other(*otherMap, *otherOperands);
693 *otherMap = renumber(other);
694
695 otherOperands->reserve(reorderedDims.size() + concatenatedSymbols.size());
696 otherOperands->assign(reorderedDims.begin(), reorderedDims.end());
697 otherOperands->append(concatenatedSymbols.begin(), concatenatedSymbols.end());
698 }
699
700 /// Implements `map` and `operands` composition and simplification to support
701 /// `makeComposedAffineApply`. This can be called to achieve the same effects
702 /// on `map` and `operands` without creating an AffineApplyOp that needs to be
703 /// immediately deleted.
composeAffineMapAndOperands(AffineMap * map,SmallVectorImpl<Value> * operands)704 static void composeAffineMapAndOperands(AffineMap *map,
705 SmallVectorImpl<Value> *operands) {
706 AffineApplyNormalizer normalizer(*map, *operands);
707 auto normalizedMap = normalizer.getAffineMap();
708 auto normalizedOperands = normalizer.getOperands();
709 canonicalizeMapAndOperands(&normalizedMap, &normalizedOperands);
710 *map = normalizedMap;
711 *operands = normalizedOperands;
712 assert(*map);
713 }
714
fullyComposeAffineMapAndOperands(AffineMap * map,SmallVectorImpl<Value> * operands)715 void mlir::fullyComposeAffineMapAndOperands(AffineMap *map,
716 SmallVectorImpl<Value> *operands) {
717 while (llvm::any_of(*operands, [](Value v) {
718 return isa_and_nonnull<AffineApplyOp>(v.getDefiningOp());
719 })) {
720 composeAffineMapAndOperands(map, operands);
721 }
722 }
723
makeComposedAffineApply(OpBuilder & b,Location loc,AffineMap map,ArrayRef<Value> operands)724 AffineApplyOp mlir::makeComposedAffineApply(OpBuilder &b, Location loc,
725 AffineMap map,
726 ArrayRef<Value> operands) {
727 AffineMap normalizedMap = map;
728 SmallVector<Value, 8> normalizedOperands(operands.begin(), operands.end());
729 composeAffineMapAndOperands(&normalizedMap, &normalizedOperands);
730 assert(normalizedMap);
731 return b.create<AffineApplyOp>(loc, normalizedMap, normalizedOperands);
732 }
733
734 // A symbol may appear as a dim in affine.apply operations. This function
735 // canonicalizes dims that are valid symbols into actual symbols.
736 template <class MapOrSet>
canonicalizePromotedSymbols(MapOrSet * mapOrSet,SmallVectorImpl<Value> * operands)737 static void canonicalizePromotedSymbols(MapOrSet *mapOrSet,
738 SmallVectorImpl<Value> *operands) {
739 if (!mapOrSet || operands->empty())
740 return;
741
742 assert(mapOrSet->getNumInputs() == operands->size() &&
743 "map/set inputs must match number of operands");
744
745 auto *context = mapOrSet->getContext();
746 SmallVector<Value, 8> resultOperands;
747 resultOperands.reserve(operands->size());
748 SmallVector<Value, 8> remappedSymbols;
749 remappedSymbols.reserve(operands->size());
750 unsigned nextDim = 0;
751 unsigned nextSym = 0;
752 unsigned oldNumSyms = mapOrSet->getNumSymbols();
753 SmallVector<AffineExpr, 8> dimRemapping(mapOrSet->getNumDims());
754 for (unsigned i = 0, e = mapOrSet->getNumInputs(); i != e; ++i) {
755 if (i < mapOrSet->getNumDims()) {
756 if (isValidSymbol((*operands)[i])) {
757 // This is a valid symbol that appears as a dim, canonicalize it.
758 dimRemapping[i] = getAffineSymbolExpr(oldNumSyms + nextSym++, context);
759 remappedSymbols.push_back((*operands)[i]);
760 } else {
761 dimRemapping[i] = getAffineDimExpr(nextDim++, context);
762 resultOperands.push_back((*operands)[i]);
763 }
764 } else {
765 resultOperands.push_back((*operands)[i]);
766 }
767 }
768
769 resultOperands.append(remappedSymbols.begin(), remappedSymbols.end());
770 *operands = resultOperands;
771 *mapOrSet = mapOrSet->replaceDimsAndSymbols(dimRemapping, {}, nextDim,
772 oldNumSyms + nextSym);
773
774 assert(mapOrSet->getNumInputs() == operands->size() &&
775 "map/set inputs must match number of operands");
776 }
777
778 // Works for either an affine map or an integer set.
779 template <class MapOrSet>
canonicalizeMapOrSetAndOperands(MapOrSet * mapOrSet,SmallVectorImpl<Value> * operands)780 static void canonicalizeMapOrSetAndOperands(MapOrSet *mapOrSet,
781 SmallVectorImpl<Value> *operands) {
782 static_assert(llvm::is_one_of<MapOrSet, AffineMap, IntegerSet>::value,
783 "Argument must be either of AffineMap or IntegerSet type");
784
785 if (!mapOrSet || operands->empty())
786 return;
787
788 assert(mapOrSet->getNumInputs() == operands->size() &&
789 "map/set inputs must match number of operands");
790
791 canonicalizePromotedSymbols<MapOrSet>(mapOrSet, operands);
792
793 // Check to see what dims are used.
794 llvm::SmallBitVector usedDims(mapOrSet->getNumDims());
795 llvm::SmallBitVector usedSyms(mapOrSet->getNumSymbols());
796 mapOrSet->walkExprs([&](AffineExpr expr) {
797 if (auto dimExpr = expr.dyn_cast<AffineDimExpr>())
798 usedDims[dimExpr.getPosition()] = true;
799 else if (auto symExpr = expr.dyn_cast<AffineSymbolExpr>())
800 usedSyms[symExpr.getPosition()] = true;
801 });
802
803 auto *context = mapOrSet->getContext();
804
805 SmallVector<Value, 8> resultOperands;
806 resultOperands.reserve(operands->size());
807
808 llvm::SmallDenseMap<Value, AffineExpr, 8> seenDims;
809 SmallVector<AffineExpr, 8> dimRemapping(mapOrSet->getNumDims());
810 unsigned nextDim = 0;
811 for (unsigned i = 0, e = mapOrSet->getNumDims(); i != e; ++i) {
812 if (usedDims[i]) {
813 // Remap dim positions for duplicate operands.
814 auto it = seenDims.find((*operands)[i]);
815 if (it == seenDims.end()) {
816 dimRemapping[i] = getAffineDimExpr(nextDim++, context);
817 resultOperands.push_back((*operands)[i]);
818 seenDims.insert(std::make_pair((*operands)[i], dimRemapping[i]));
819 } else {
820 dimRemapping[i] = it->second;
821 }
822 }
823 }
824 llvm::SmallDenseMap<Value, AffineExpr, 8> seenSymbols;
825 SmallVector<AffineExpr, 8> symRemapping(mapOrSet->getNumSymbols());
826 unsigned nextSym = 0;
827 for (unsigned i = 0, e = mapOrSet->getNumSymbols(); i != e; ++i) {
828 if (!usedSyms[i])
829 continue;
830 // Handle constant operands (only needed for symbolic operands since
831 // constant operands in dimensional positions would have already been
832 // promoted to symbolic positions above).
833 IntegerAttr operandCst;
834 if (matchPattern((*operands)[i + mapOrSet->getNumDims()],
835 m_Constant(&operandCst))) {
836 symRemapping[i] =
837 getAffineConstantExpr(operandCst.getValue().getSExtValue(), context);
838 continue;
839 }
840 // Remap symbol positions for duplicate operands.
841 auto it = seenSymbols.find((*operands)[i + mapOrSet->getNumDims()]);
842 if (it == seenSymbols.end()) {
843 symRemapping[i] = getAffineSymbolExpr(nextSym++, context);
844 resultOperands.push_back((*operands)[i + mapOrSet->getNumDims()]);
845 seenSymbols.insert(std::make_pair((*operands)[i + mapOrSet->getNumDims()],
846 symRemapping[i]));
847 } else {
848 symRemapping[i] = it->second;
849 }
850 }
851 *mapOrSet = mapOrSet->replaceDimsAndSymbols(dimRemapping, symRemapping,
852 nextDim, nextSym);
853 *operands = resultOperands;
854 }
855
canonicalizeMapAndOperands(AffineMap * map,SmallVectorImpl<Value> * operands)856 void mlir::canonicalizeMapAndOperands(AffineMap *map,
857 SmallVectorImpl<Value> *operands) {
858 canonicalizeMapOrSetAndOperands<AffineMap>(map, operands);
859 }
860
canonicalizeSetAndOperands(IntegerSet * set,SmallVectorImpl<Value> * operands)861 void mlir::canonicalizeSetAndOperands(IntegerSet *set,
862 SmallVectorImpl<Value> *operands) {
863 canonicalizeMapOrSetAndOperands<IntegerSet>(set, operands);
864 }
865
866 namespace {
867 /// Simplify AffineApply, AffineLoad, and AffineStore operations by composing
868 /// maps that supply results into them.
869 ///
870 template <typename AffineOpTy>
871 struct SimplifyAffineOp : public OpRewritePattern<AffineOpTy> {
872 using OpRewritePattern<AffineOpTy>::OpRewritePattern;
873
874 /// Replace the affine op with another instance of it with the supplied
875 /// map and mapOperands.
876 void replaceAffineOp(PatternRewriter &rewriter, AffineOpTy affineOp,
877 AffineMap map, ArrayRef<Value> mapOperands) const;
878
matchAndRewrite__anonb6f842fb0a11::SimplifyAffineOp879 LogicalResult matchAndRewrite(AffineOpTy affineOp,
880 PatternRewriter &rewriter) const override {
881 static_assert(llvm::is_one_of<AffineOpTy, AffineLoadOp, AffinePrefetchOp,
882 AffineStoreOp, AffineApplyOp, AffineMinOp,
883 AffineMaxOp>::value,
884 "affine load/store/apply/prefetch/min/max op expected");
885 auto map = affineOp.getAffineMap();
886 AffineMap oldMap = map;
887 auto oldOperands = affineOp.getMapOperands();
888 SmallVector<Value, 8> resultOperands(oldOperands);
889 composeAffineMapAndOperands(&map, &resultOperands);
890 if (map == oldMap && std::equal(oldOperands.begin(), oldOperands.end(),
891 resultOperands.begin()))
892 return failure();
893
894 replaceAffineOp(rewriter, affineOp, map, resultOperands);
895 return success();
896 }
897 };
898
899 // Specialize the template to account for the different build signatures for
900 // affine load, store, and apply ops.
901 template <>
replaceAffineOp(PatternRewriter & rewriter,AffineLoadOp load,AffineMap map,ArrayRef<Value> mapOperands) const902 void SimplifyAffineOp<AffineLoadOp>::replaceAffineOp(
903 PatternRewriter &rewriter, AffineLoadOp load, AffineMap map,
904 ArrayRef<Value> mapOperands) const {
905 rewriter.replaceOpWithNewOp<AffineLoadOp>(load, load.getMemRef(), map,
906 mapOperands);
907 }
908 template <>
replaceAffineOp(PatternRewriter & rewriter,AffinePrefetchOp prefetch,AffineMap map,ArrayRef<Value> mapOperands) const909 void SimplifyAffineOp<AffinePrefetchOp>::replaceAffineOp(
910 PatternRewriter &rewriter, AffinePrefetchOp prefetch, AffineMap map,
911 ArrayRef<Value> mapOperands) const {
912 rewriter.replaceOpWithNewOp<AffinePrefetchOp>(
913 prefetch, prefetch.memref(), map, mapOperands, prefetch.localityHint(),
914 prefetch.isWrite(), prefetch.isDataCache());
915 }
916 template <>
replaceAffineOp(PatternRewriter & rewriter,AffineStoreOp store,AffineMap map,ArrayRef<Value> mapOperands) const917 void SimplifyAffineOp<AffineStoreOp>::replaceAffineOp(
918 PatternRewriter &rewriter, AffineStoreOp store, AffineMap map,
919 ArrayRef<Value> mapOperands) const {
920 rewriter.replaceOpWithNewOp<AffineStoreOp>(
921 store, store.getValueToStore(), store.getMemRef(), map, mapOperands);
922 }
923
924 // Generic version for ops that don't have extra operands.
925 template <typename AffineOpTy>
replaceAffineOp(PatternRewriter & rewriter,AffineOpTy op,AffineMap map,ArrayRef<Value> mapOperands) const926 void SimplifyAffineOp<AffineOpTy>::replaceAffineOp(
927 PatternRewriter &rewriter, AffineOpTy op, AffineMap map,
928 ArrayRef<Value> mapOperands) const {
929 rewriter.replaceOpWithNewOp<AffineOpTy>(op, map, mapOperands);
930 }
931 } // end anonymous namespace.
932
getCanonicalizationPatterns(OwningRewritePatternList & results,MLIRContext * context)933 void AffineApplyOp::getCanonicalizationPatterns(
934 OwningRewritePatternList &results, MLIRContext *context) {
935 results.insert<SimplifyAffineOp<AffineApplyOp>>(context);
936 }
937
938 //===----------------------------------------------------------------------===//
939 // Common canonicalization pattern support logic
940 //===----------------------------------------------------------------------===//
941
942 /// This is a common class used for patterns of the form
943 /// "someop(memrefcast) -> someop". It folds the source of any memref_cast
944 /// into the root operation directly.
foldMemRefCast(Operation * op)945 static LogicalResult foldMemRefCast(Operation *op) {
946 bool folded = false;
947 for (OpOperand &operand : op->getOpOperands()) {
948 auto cast = operand.get().getDefiningOp<MemRefCastOp>();
949 if (cast && !cast.getOperand().getType().isa<UnrankedMemRefType>()) {
950 operand.set(cast.getOperand());
951 folded = true;
952 }
953 }
954 return success(folded);
955 }
956
957 //===----------------------------------------------------------------------===//
958 // AffineDmaStartOp
959 //===----------------------------------------------------------------------===//
960
961 // TODO: Check that map operands are loop IVs or symbols.
build(OpBuilder & builder,OperationState & result,Value srcMemRef,AffineMap srcMap,ValueRange srcIndices,Value destMemRef,AffineMap dstMap,ValueRange destIndices,Value tagMemRef,AffineMap tagMap,ValueRange tagIndices,Value numElements,Value stride,Value elementsPerStride)962 void AffineDmaStartOp::build(OpBuilder &builder, OperationState &result,
963 Value srcMemRef, AffineMap srcMap,
964 ValueRange srcIndices, Value destMemRef,
965 AffineMap dstMap, ValueRange destIndices,
966 Value tagMemRef, AffineMap tagMap,
967 ValueRange tagIndices, Value numElements,
968 Value stride, Value elementsPerStride) {
969 result.addOperands(srcMemRef);
970 result.addAttribute(getSrcMapAttrName(), AffineMapAttr::get(srcMap));
971 result.addOperands(srcIndices);
972 result.addOperands(destMemRef);
973 result.addAttribute(getDstMapAttrName(), AffineMapAttr::get(dstMap));
974 result.addOperands(destIndices);
975 result.addOperands(tagMemRef);
976 result.addAttribute(getTagMapAttrName(), AffineMapAttr::get(tagMap));
977 result.addOperands(tagIndices);
978 result.addOperands(numElements);
979 if (stride) {
980 result.addOperands({stride, elementsPerStride});
981 }
982 }
983
print(OpAsmPrinter & p)984 void AffineDmaStartOp::print(OpAsmPrinter &p) {
985 p << "affine.dma_start " << getSrcMemRef() << '[';
986 p.printAffineMapOfSSAIds(getSrcMapAttr(), getSrcIndices());
987 p << "], " << getDstMemRef() << '[';
988 p.printAffineMapOfSSAIds(getDstMapAttr(), getDstIndices());
989 p << "], " << getTagMemRef() << '[';
990 p.printAffineMapOfSSAIds(getTagMapAttr(), getTagIndices());
991 p << "], " << getNumElements();
992 if (isStrided()) {
993 p << ", " << getStride();
994 p << ", " << getNumElementsPerStride();
995 }
996 p << " : " << getSrcMemRefType() << ", " << getDstMemRefType() << ", "
997 << getTagMemRefType();
998 }
999
1000 // Parse AffineDmaStartOp.
1001 // Ex:
1002 // affine.dma_start %src[%i, %j], %dst[%k, %l], %tag[%index], %size,
1003 // %stride, %num_elt_per_stride
1004 // : memref<3076 x f32, 0>, memref<1024 x f32, 2>, memref<1 x i32>
1005 //
parse(OpAsmParser & parser,OperationState & result)1006 ParseResult AffineDmaStartOp::parse(OpAsmParser &parser,
1007 OperationState &result) {
1008 OpAsmParser::OperandType srcMemRefInfo;
1009 AffineMapAttr srcMapAttr;
1010 SmallVector<OpAsmParser::OperandType, 4> srcMapOperands;
1011 OpAsmParser::OperandType dstMemRefInfo;
1012 AffineMapAttr dstMapAttr;
1013 SmallVector<OpAsmParser::OperandType, 4> dstMapOperands;
1014 OpAsmParser::OperandType tagMemRefInfo;
1015 AffineMapAttr tagMapAttr;
1016 SmallVector<OpAsmParser::OperandType, 4> tagMapOperands;
1017 OpAsmParser::OperandType numElementsInfo;
1018 SmallVector<OpAsmParser::OperandType, 2> strideInfo;
1019
1020 SmallVector<Type, 3> types;
1021 auto indexType = parser.getBuilder().getIndexType();
1022
1023 // Parse and resolve the following list of operands:
1024 // *) dst memref followed by its affine maps operands (in square brackets).
1025 // *) src memref followed by its affine map operands (in square brackets).
1026 // *) tag memref followed by its affine map operands (in square brackets).
1027 // *) number of elements transferred by DMA operation.
1028 if (parser.parseOperand(srcMemRefInfo) ||
1029 parser.parseAffineMapOfSSAIds(srcMapOperands, srcMapAttr,
1030 getSrcMapAttrName(), result.attributes) ||
1031 parser.parseComma() || parser.parseOperand(dstMemRefInfo) ||
1032 parser.parseAffineMapOfSSAIds(dstMapOperands, dstMapAttr,
1033 getDstMapAttrName(), result.attributes) ||
1034 parser.parseComma() || parser.parseOperand(tagMemRefInfo) ||
1035 parser.parseAffineMapOfSSAIds(tagMapOperands, tagMapAttr,
1036 getTagMapAttrName(), result.attributes) ||
1037 parser.parseComma() || parser.parseOperand(numElementsInfo))
1038 return failure();
1039
1040 // Parse optional stride and elements per stride.
1041 if (parser.parseTrailingOperandList(strideInfo)) {
1042 return failure();
1043 }
1044 if (!strideInfo.empty() && strideInfo.size() != 2) {
1045 return parser.emitError(parser.getNameLoc(),
1046 "expected two stride related operands");
1047 }
1048 bool isStrided = strideInfo.size() == 2;
1049
1050 if (parser.parseColonTypeList(types))
1051 return failure();
1052
1053 if (types.size() != 3)
1054 return parser.emitError(parser.getNameLoc(), "expected three types");
1055
1056 if (parser.resolveOperand(srcMemRefInfo, types[0], result.operands) ||
1057 parser.resolveOperands(srcMapOperands, indexType, result.operands) ||
1058 parser.resolveOperand(dstMemRefInfo, types[1], result.operands) ||
1059 parser.resolveOperands(dstMapOperands, indexType, result.operands) ||
1060 parser.resolveOperand(tagMemRefInfo, types[2], result.operands) ||
1061 parser.resolveOperands(tagMapOperands, indexType, result.operands) ||
1062 parser.resolveOperand(numElementsInfo, indexType, result.operands))
1063 return failure();
1064
1065 if (isStrided) {
1066 if (parser.resolveOperands(strideInfo, indexType, result.operands))
1067 return failure();
1068 }
1069
1070 // Check that src/dst/tag operand counts match their map.numInputs.
1071 if (srcMapOperands.size() != srcMapAttr.getValue().getNumInputs() ||
1072 dstMapOperands.size() != dstMapAttr.getValue().getNumInputs() ||
1073 tagMapOperands.size() != tagMapAttr.getValue().getNumInputs())
1074 return parser.emitError(parser.getNameLoc(),
1075 "memref operand count not equal to map.numInputs");
1076 return success();
1077 }
1078
verify()1079 LogicalResult AffineDmaStartOp::verify() {
1080 if (!getOperand(getSrcMemRefOperandIndex()).getType().isa<MemRefType>())
1081 return emitOpError("expected DMA source to be of memref type");
1082 if (!getOperand(getDstMemRefOperandIndex()).getType().isa<MemRefType>())
1083 return emitOpError("expected DMA destination to be of memref type");
1084 if (!getOperand(getTagMemRefOperandIndex()).getType().isa<MemRefType>())
1085 return emitOpError("expected DMA tag to be of memref type");
1086
1087 // DMAs from different memory spaces supported.
1088 if (getSrcMemorySpace() == getDstMemorySpace()) {
1089 return emitOpError("DMA should be between different memory spaces");
1090 }
1091 unsigned numInputsAllMaps = getSrcMap().getNumInputs() +
1092 getDstMap().getNumInputs() +
1093 getTagMap().getNumInputs();
1094 if (getNumOperands() != numInputsAllMaps + 3 + 1 &&
1095 getNumOperands() != numInputsAllMaps + 3 + 1 + 2) {
1096 return emitOpError("incorrect number of operands");
1097 }
1098
1099 Region *scope = getAffineScope(*this);
1100 for (auto idx : getSrcIndices()) {
1101 if (!idx.getType().isIndex())
1102 return emitOpError("src index to dma_start must have 'index' type");
1103 if (!isValidAffineIndexOperand(idx, scope))
1104 return emitOpError("src index must be a dimension or symbol identifier");
1105 }
1106 for (auto idx : getDstIndices()) {
1107 if (!idx.getType().isIndex())
1108 return emitOpError("dst index to dma_start must have 'index' type");
1109 if (!isValidAffineIndexOperand(idx, scope))
1110 return emitOpError("dst index must be a dimension or symbol identifier");
1111 }
1112 for (auto idx : getTagIndices()) {
1113 if (!idx.getType().isIndex())
1114 return emitOpError("tag index to dma_start must have 'index' type");
1115 if (!isValidAffineIndexOperand(idx, scope))
1116 return emitOpError("tag index must be a dimension or symbol identifier");
1117 }
1118 return success();
1119 }
1120
fold(ArrayRef<Attribute> cstOperands,SmallVectorImpl<OpFoldResult> & results)1121 LogicalResult AffineDmaStartOp::fold(ArrayRef<Attribute> cstOperands,
1122 SmallVectorImpl<OpFoldResult> &results) {
1123 /// dma_start(memrefcast) -> dma_start
1124 return foldMemRefCast(*this);
1125 }
1126
1127 //===----------------------------------------------------------------------===//
1128 // AffineDmaWaitOp
1129 //===----------------------------------------------------------------------===//
1130
1131 // TODO: Check that map operands are loop IVs or symbols.
build(OpBuilder & builder,OperationState & result,Value tagMemRef,AffineMap tagMap,ValueRange tagIndices,Value numElements)1132 void AffineDmaWaitOp::build(OpBuilder &builder, OperationState &result,
1133 Value tagMemRef, AffineMap tagMap,
1134 ValueRange tagIndices, Value numElements) {
1135 result.addOperands(tagMemRef);
1136 result.addAttribute(getTagMapAttrName(), AffineMapAttr::get(tagMap));
1137 result.addOperands(tagIndices);
1138 result.addOperands(numElements);
1139 }
1140
print(OpAsmPrinter & p)1141 void AffineDmaWaitOp::print(OpAsmPrinter &p) {
1142 p << "affine.dma_wait " << getTagMemRef() << '[';
1143 SmallVector<Value, 2> operands(getTagIndices());
1144 p.printAffineMapOfSSAIds(getTagMapAttr(), operands);
1145 p << "], ";
1146 p.printOperand(getNumElements());
1147 p << " : " << getTagMemRef().getType();
1148 }
1149
1150 // Parse AffineDmaWaitOp.
1151 // Eg:
1152 // affine.dma_wait %tag[%index], %num_elements
1153 // : memref<1 x i32, (d0) -> (d0), 4>
1154 //
parse(OpAsmParser & parser,OperationState & result)1155 ParseResult AffineDmaWaitOp::parse(OpAsmParser &parser,
1156 OperationState &result) {
1157 OpAsmParser::OperandType tagMemRefInfo;
1158 AffineMapAttr tagMapAttr;
1159 SmallVector<OpAsmParser::OperandType, 2> tagMapOperands;
1160 Type type;
1161 auto indexType = parser.getBuilder().getIndexType();
1162 OpAsmParser::OperandType numElementsInfo;
1163
1164 // Parse tag memref, its map operands, and dma size.
1165 if (parser.parseOperand(tagMemRefInfo) ||
1166 parser.parseAffineMapOfSSAIds(tagMapOperands, tagMapAttr,
1167 getTagMapAttrName(), result.attributes) ||
1168 parser.parseComma() || parser.parseOperand(numElementsInfo) ||
1169 parser.parseColonType(type) ||
1170 parser.resolveOperand(tagMemRefInfo, type, result.operands) ||
1171 parser.resolveOperands(tagMapOperands, indexType, result.operands) ||
1172 parser.resolveOperand(numElementsInfo, indexType, result.operands))
1173 return failure();
1174
1175 if (!type.isa<MemRefType>())
1176 return parser.emitError(parser.getNameLoc(),
1177 "expected tag to be of memref type");
1178
1179 if (tagMapOperands.size() != tagMapAttr.getValue().getNumInputs())
1180 return parser.emitError(parser.getNameLoc(),
1181 "tag memref operand count != to map.numInputs");
1182 return success();
1183 }
1184
verify()1185 LogicalResult AffineDmaWaitOp::verify() {
1186 if (!getOperand(0).getType().isa<MemRefType>())
1187 return emitOpError("expected DMA tag to be of memref type");
1188 Region *scope = getAffineScope(*this);
1189 for (auto idx : getTagIndices()) {
1190 if (!idx.getType().isIndex())
1191 return emitOpError("index to dma_wait must have 'index' type");
1192 if (!isValidAffineIndexOperand(idx, scope))
1193 return emitOpError("index must be a dimension or symbol identifier");
1194 }
1195 return success();
1196 }
1197
fold(ArrayRef<Attribute> cstOperands,SmallVectorImpl<OpFoldResult> & results)1198 LogicalResult AffineDmaWaitOp::fold(ArrayRef<Attribute> cstOperands,
1199 SmallVectorImpl<OpFoldResult> &results) {
1200 /// dma_wait(memrefcast) -> dma_wait
1201 return foldMemRefCast(*this);
1202 }
1203
1204 //===----------------------------------------------------------------------===//
1205 // AffineForOp
1206 //===----------------------------------------------------------------------===//
1207
1208 /// 'bodyBuilder' is used to build the body of affine.for. If iterArgs and
1209 /// bodyBuilder are empty/null, we include default terminator op.
build(OpBuilder & builder,OperationState & result,ValueRange lbOperands,AffineMap lbMap,ValueRange ubOperands,AffineMap ubMap,int64_t step,ValueRange iterArgs,BodyBuilderFn bodyBuilder)1210 void AffineForOp::build(OpBuilder &builder, OperationState &result,
1211 ValueRange lbOperands, AffineMap lbMap,
1212 ValueRange ubOperands, AffineMap ubMap, int64_t step,
1213 ValueRange iterArgs, BodyBuilderFn bodyBuilder) {
1214 assert(((!lbMap && lbOperands.empty()) ||
1215 lbOperands.size() == lbMap.getNumInputs()) &&
1216 "lower bound operand count does not match the affine map");
1217 assert(((!ubMap && ubOperands.empty()) ||
1218 ubOperands.size() == ubMap.getNumInputs()) &&
1219 "upper bound operand count does not match the affine map");
1220 assert(step > 0 && "step has to be a positive integer constant");
1221
1222 for (Value val : iterArgs)
1223 result.addTypes(val.getType());
1224
1225 // Add an attribute for the step.
1226 result.addAttribute(getStepAttrName(),
1227 builder.getIntegerAttr(builder.getIndexType(), step));
1228
1229 // Add the lower bound.
1230 result.addAttribute(getLowerBoundAttrName(), AffineMapAttr::get(lbMap));
1231 result.addOperands(lbOperands);
1232
1233 // Add the upper bound.
1234 result.addAttribute(getUpperBoundAttrName(), AffineMapAttr::get(ubMap));
1235 result.addOperands(ubOperands);
1236
1237 result.addOperands(iterArgs);
1238 // Create a region and a block for the body. The argument of the region is
1239 // the loop induction variable.
1240 Region *bodyRegion = result.addRegion();
1241 bodyRegion->push_back(new Block);
1242 Block &bodyBlock = bodyRegion->front();
1243 Value inductionVar = bodyBlock.addArgument(builder.getIndexType());
1244 for (Value val : iterArgs)
1245 bodyBlock.addArgument(val.getType());
1246
1247 // Create the default terminator if the builder is not provided and if the
1248 // iteration arguments are not provided. Otherwise, leave this to the caller
1249 // because we don't know which values to return from the loop.
1250 if (iterArgs.empty() && !bodyBuilder) {
1251 ensureTerminator(*bodyRegion, builder, result.location);
1252 } else if (bodyBuilder) {
1253 OpBuilder::InsertionGuard guard(builder);
1254 builder.setInsertionPointToStart(&bodyBlock);
1255 bodyBuilder(builder, result.location, inductionVar,
1256 bodyBlock.getArguments().drop_front());
1257 }
1258 }
1259
build(OpBuilder & builder,OperationState & result,int64_t lb,int64_t ub,int64_t step,ValueRange iterArgs,BodyBuilderFn bodyBuilder)1260 void AffineForOp::build(OpBuilder &builder, OperationState &result, int64_t lb,
1261 int64_t ub, int64_t step, ValueRange iterArgs,
1262 BodyBuilderFn bodyBuilder) {
1263 auto lbMap = AffineMap::getConstantMap(lb, builder.getContext());
1264 auto ubMap = AffineMap::getConstantMap(ub, builder.getContext());
1265 return build(builder, result, {}, lbMap, {}, ubMap, step, iterArgs,
1266 bodyBuilder);
1267 }
1268
verify(AffineForOp op)1269 static LogicalResult verify(AffineForOp op) {
1270 // Check that the body defines as single block argument for the induction
1271 // variable.
1272 auto *body = op.getBody();
1273 if (body->getNumArguments() == 0 || !body->getArgument(0).getType().isIndex())
1274 return op.emitOpError(
1275 "expected body to have a single index argument for the "
1276 "induction variable");
1277
1278 // Verify that the bound operands are valid dimension/symbols.
1279 /// Lower bound.
1280 if (op.getLowerBoundMap().getNumInputs() > 0)
1281 if (failed(
1282 verifyDimAndSymbolIdentifiers(op, op.getLowerBoundOperands(),
1283 op.getLowerBoundMap().getNumDims())))
1284 return failure();
1285 /// Upper bound.
1286 if (op.getUpperBoundMap().getNumInputs() > 0)
1287 if (failed(
1288 verifyDimAndSymbolIdentifiers(op, op.getUpperBoundOperands(),
1289 op.getUpperBoundMap().getNumDims())))
1290 return failure();
1291
1292 unsigned opNumResults = op.getNumResults();
1293 if (opNumResults == 0)
1294 return success();
1295
1296 // If ForOp defines values, check that the number and types of the defined
1297 // values match ForOp initial iter operands and backedge basic block
1298 // arguments.
1299 if (op.getNumIterOperands() != opNumResults)
1300 return op.emitOpError(
1301 "mismatch between the number of loop-carried values and results");
1302 if (op.getNumRegionIterArgs() != opNumResults)
1303 return op.emitOpError(
1304 "mismatch between the number of basic block args and results");
1305
1306 return success();
1307 }
1308
1309 /// Parse a for operation loop bounds.
parseBound(bool isLower,OperationState & result,OpAsmParser & p)1310 static ParseResult parseBound(bool isLower, OperationState &result,
1311 OpAsmParser &p) {
1312 // 'min' / 'max' prefixes are generally syntactic sugar, but are required if
1313 // the map has multiple results.
1314 bool failedToParsedMinMax =
1315 failed(p.parseOptionalKeyword(isLower ? "max" : "min"));
1316
1317 auto &builder = p.getBuilder();
1318 auto boundAttrName = isLower ? AffineForOp::getLowerBoundAttrName()
1319 : AffineForOp::getUpperBoundAttrName();
1320
1321 // Parse ssa-id as identity map.
1322 SmallVector<OpAsmParser::OperandType, 1> boundOpInfos;
1323 if (p.parseOperandList(boundOpInfos))
1324 return failure();
1325
1326 if (!boundOpInfos.empty()) {
1327 // Check that only one operand was parsed.
1328 if (boundOpInfos.size() > 1)
1329 return p.emitError(p.getNameLoc(),
1330 "expected only one loop bound operand");
1331
1332 // TODO: improve error message when SSA value is not of index type.
1333 // Currently it is 'use of value ... expects different type than prior uses'
1334 if (p.resolveOperand(boundOpInfos.front(), builder.getIndexType(),
1335 result.operands))
1336 return failure();
1337
1338 // Create an identity map using symbol id. This representation is optimized
1339 // for storage. Analysis passes may expand it into a multi-dimensional map
1340 // if desired.
1341 AffineMap map = builder.getSymbolIdentityMap();
1342 result.addAttribute(boundAttrName, AffineMapAttr::get(map));
1343 return success();
1344 }
1345
1346 // Get the attribute location.
1347 llvm::SMLoc attrLoc = p.getCurrentLocation();
1348
1349 Attribute boundAttr;
1350 if (p.parseAttribute(boundAttr, builder.getIndexType(), boundAttrName,
1351 result.attributes))
1352 return failure();
1353
1354 // Parse full form - affine map followed by dim and symbol list.
1355 if (auto affineMapAttr = boundAttr.dyn_cast<AffineMapAttr>()) {
1356 unsigned currentNumOperands = result.operands.size();
1357 unsigned numDims;
1358 if (parseDimAndSymbolList(p, result.operands, numDims))
1359 return failure();
1360
1361 auto map = affineMapAttr.getValue();
1362 if (map.getNumDims() != numDims)
1363 return p.emitError(
1364 p.getNameLoc(),
1365 "dim operand count and affine map dim count must match");
1366
1367 unsigned numDimAndSymbolOperands =
1368 result.operands.size() - currentNumOperands;
1369 if (numDims + map.getNumSymbols() != numDimAndSymbolOperands)
1370 return p.emitError(
1371 p.getNameLoc(),
1372 "symbol operand count and affine map symbol count must match");
1373
1374 // If the map has multiple results, make sure that we parsed the min/max
1375 // prefix.
1376 if (map.getNumResults() > 1 && failedToParsedMinMax) {
1377 if (isLower) {
1378 return p.emitError(attrLoc, "lower loop bound affine map with "
1379 "multiple results requires 'max' prefix");
1380 }
1381 return p.emitError(attrLoc, "upper loop bound affine map with multiple "
1382 "results requires 'min' prefix");
1383 }
1384 return success();
1385 }
1386
1387 // Parse custom assembly form.
1388 if (auto integerAttr = boundAttr.dyn_cast<IntegerAttr>()) {
1389 result.attributes.pop_back();
1390 result.addAttribute(
1391 boundAttrName,
1392 AffineMapAttr::get(builder.getConstantAffineMap(integerAttr.getInt())));
1393 return success();
1394 }
1395
1396 return p.emitError(
1397 p.getNameLoc(),
1398 "expected valid affine map representation for loop bounds");
1399 }
1400
parseAffineForOp(OpAsmParser & parser,OperationState & result)1401 static ParseResult parseAffineForOp(OpAsmParser &parser,
1402 OperationState &result) {
1403 auto &builder = parser.getBuilder();
1404 OpAsmParser::OperandType inductionVariable;
1405 // Parse the induction variable followed by '='.
1406 if (parser.parseRegionArgument(inductionVariable) || parser.parseEqual())
1407 return failure();
1408
1409 // Parse loop bounds.
1410 if (parseBound(/*isLower=*/true, result, parser) ||
1411 parser.parseKeyword("to", " between bounds") ||
1412 parseBound(/*isLower=*/false, result, parser))
1413 return failure();
1414
1415 // Parse the optional loop step, we default to 1 if one is not present.
1416 if (parser.parseOptionalKeyword("step")) {
1417 result.addAttribute(
1418 AffineForOp::getStepAttrName(),
1419 builder.getIntegerAttr(builder.getIndexType(), /*value=*/1));
1420 } else {
1421 llvm::SMLoc stepLoc = parser.getCurrentLocation();
1422 IntegerAttr stepAttr;
1423 if (parser.parseAttribute(stepAttr, builder.getIndexType(),
1424 AffineForOp::getStepAttrName().data(),
1425 result.attributes))
1426 return failure();
1427
1428 if (stepAttr.getValue().getSExtValue() < 0)
1429 return parser.emitError(
1430 stepLoc,
1431 "expected step to be representable as a positive signed integer");
1432 }
1433
1434 // Parse the optional initial iteration arguments.
1435 SmallVector<OpAsmParser::OperandType, 4> regionArgs, operands;
1436 SmallVector<Type, 4> argTypes;
1437 regionArgs.push_back(inductionVariable);
1438
1439 if (succeeded(parser.parseOptionalKeyword("iter_args"))) {
1440 // Parse assignment list and results type list.
1441 if (parser.parseAssignmentList(regionArgs, operands) ||
1442 parser.parseArrowTypeList(result.types))
1443 return failure();
1444 // Resolve input operands.
1445 for (auto operandType : llvm::zip(operands, result.types))
1446 if (parser.resolveOperand(std::get<0>(operandType),
1447 std::get<1>(operandType), result.operands))
1448 return failure();
1449 }
1450 // Induction variable.
1451 Type indexType = builder.getIndexType();
1452 argTypes.push_back(indexType);
1453 // Loop carried variables.
1454 argTypes.append(result.types.begin(), result.types.end());
1455 // Parse the body region.
1456 Region *body = result.addRegion();
1457 if (regionArgs.size() != argTypes.size())
1458 return parser.emitError(
1459 parser.getNameLoc(),
1460 "mismatch between the number of loop-carried values and results");
1461 if (parser.parseRegion(*body, regionArgs, argTypes))
1462 return failure();
1463
1464 AffineForOp::ensureTerminator(*body, builder, result.location);
1465
1466 // Parse the optional attribute list.
1467 return parser.parseOptionalAttrDict(result.attributes);
1468 }
1469
printBound(AffineMapAttr boundMap,Operation::operand_range boundOperands,const char * prefix,OpAsmPrinter & p)1470 static void printBound(AffineMapAttr boundMap,
1471 Operation::operand_range boundOperands,
1472 const char *prefix, OpAsmPrinter &p) {
1473 AffineMap map = boundMap.getValue();
1474
1475 // Check if this bound should be printed using custom assembly form.
1476 // The decision to restrict printing custom assembly form to trivial cases
1477 // comes from the will to roundtrip MLIR binary -> text -> binary in a
1478 // lossless way.
1479 // Therefore, custom assembly form parsing and printing is only supported for
1480 // zero-operand constant maps and single symbol operand identity maps.
1481 if (map.getNumResults() == 1) {
1482 AffineExpr expr = map.getResult(0);
1483
1484 // Print constant bound.
1485 if (map.getNumDims() == 0 && map.getNumSymbols() == 0) {
1486 if (auto constExpr = expr.dyn_cast<AffineConstantExpr>()) {
1487 p << constExpr.getValue();
1488 return;
1489 }
1490 }
1491
1492 // Print bound that consists of a single SSA symbol if the map is over a
1493 // single symbol.
1494 if (map.getNumDims() == 0 && map.getNumSymbols() == 1) {
1495 if (auto symExpr = expr.dyn_cast<AffineSymbolExpr>()) {
1496 p.printOperand(*boundOperands.begin());
1497 return;
1498 }
1499 }
1500 } else {
1501 // Map has multiple results. Print 'min' or 'max' prefix.
1502 p << prefix << ' ';
1503 }
1504
1505 // Print the map and its operands.
1506 p << boundMap;
1507 printDimAndSymbolList(boundOperands.begin(), boundOperands.end(),
1508 map.getNumDims(), p);
1509 }
1510
getNumIterOperands()1511 unsigned AffineForOp::getNumIterOperands() {
1512 AffineMap lbMap = getLowerBoundMapAttr().getValue();
1513 AffineMap ubMap = getUpperBoundMapAttr().getValue();
1514
1515 return getNumOperands() - lbMap.getNumInputs() - ubMap.getNumInputs();
1516 }
1517
print(OpAsmPrinter & p,AffineForOp op)1518 static void print(OpAsmPrinter &p, AffineForOp op) {
1519 p << op.getOperationName() << ' ';
1520 p.printOperand(op.getBody()->getArgument(0));
1521 p << " = ";
1522 printBound(op.getLowerBoundMapAttr(), op.getLowerBoundOperands(), "max", p);
1523 p << " to ";
1524 printBound(op.getUpperBoundMapAttr(), op.getUpperBoundOperands(), "min", p);
1525
1526 if (op.getStep() != 1)
1527 p << " step " << op.getStep();
1528
1529 bool printBlockTerminators = false;
1530 if (op.getNumIterOperands() > 0) {
1531 p << " iter_args(";
1532 auto regionArgs = op.getRegionIterArgs();
1533 auto operands = op.getIterOperands();
1534
1535 llvm::interleaveComma(llvm::zip(regionArgs, operands), p, [&](auto it) {
1536 p << std::get<0>(it) << " = " << std::get<1>(it);
1537 });
1538 p << ") -> (" << op.getResultTypes() << ")";
1539 printBlockTerminators = true;
1540 }
1541
1542 p.printRegion(op.region(),
1543 /*printEntryBlockArgs=*/false, printBlockTerminators);
1544 p.printOptionalAttrDict(op.getAttrs(),
1545 /*elidedAttrs=*/{op.getLowerBoundAttrName(),
1546 op.getUpperBoundAttrName(),
1547 op.getStepAttrName()});
1548 }
1549
1550 /// Fold the constant bounds of a loop.
foldLoopBounds(AffineForOp forOp)1551 static LogicalResult foldLoopBounds(AffineForOp forOp) {
1552 auto foldLowerOrUpperBound = [&forOp](bool lower) {
1553 // Check to see if each of the operands is the result of a constant. If
1554 // so, get the value. If not, ignore it.
1555 SmallVector<Attribute, 8> operandConstants;
1556 auto boundOperands =
1557 lower ? forOp.getLowerBoundOperands() : forOp.getUpperBoundOperands();
1558 for (auto operand : boundOperands) {
1559 Attribute operandCst;
1560 matchPattern(operand, m_Constant(&operandCst));
1561 operandConstants.push_back(operandCst);
1562 }
1563
1564 AffineMap boundMap =
1565 lower ? forOp.getLowerBoundMap() : forOp.getUpperBoundMap();
1566 assert(boundMap.getNumResults() >= 1 &&
1567 "bound maps should have at least one result");
1568 SmallVector<Attribute, 4> foldedResults;
1569 if (failed(boundMap.constantFold(operandConstants, foldedResults)))
1570 return failure();
1571
1572 // Compute the max or min as applicable over the results.
1573 assert(!foldedResults.empty() && "bounds should have at least one result");
1574 auto maxOrMin = foldedResults[0].cast<IntegerAttr>().getValue();
1575 for (unsigned i = 1, e = foldedResults.size(); i < e; i++) {
1576 auto foldedResult = foldedResults[i].cast<IntegerAttr>().getValue();
1577 maxOrMin = lower ? llvm::APIntOps::smax(maxOrMin, foldedResult)
1578 : llvm::APIntOps::smin(maxOrMin, foldedResult);
1579 }
1580 lower ? forOp.setConstantLowerBound(maxOrMin.getSExtValue())
1581 : forOp.setConstantUpperBound(maxOrMin.getSExtValue());
1582 return success();
1583 };
1584
1585 // Try to fold the lower bound.
1586 bool folded = false;
1587 if (!forOp.hasConstantLowerBound())
1588 folded |= succeeded(foldLowerOrUpperBound(/*lower=*/true));
1589
1590 // Try to fold the upper bound.
1591 if (!forOp.hasConstantUpperBound())
1592 folded |= succeeded(foldLowerOrUpperBound(/*lower=*/false));
1593 return success(folded);
1594 }
1595
1596 /// Canonicalize the bounds of the given loop.
canonicalizeLoopBounds(AffineForOp forOp)1597 static LogicalResult canonicalizeLoopBounds(AffineForOp forOp) {
1598 SmallVector<Value, 4> lbOperands(forOp.getLowerBoundOperands());
1599 SmallVector<Value, 4> ubOperands(forOp.getUpperBoundOperands());
1600
1601 auto lbMap = forOp.getLowerBoundMap();
1602 auto ubMap = forOp.getUpperBoundMap();
1603 auto prevLbMap = lbMap;
1604 auto prevUbMap = ubMap;
1605
1606 canonicalizeMapAndOperands(&lbMap, &lbOperands);
1607 lbMap = removeDuplicateExprs(lbMap);
1608
1609 canonicalizeMapAndOperands(&ubMap, &ubOperands);
1610 ubMap = removeDuplicateExprs(ubMap);
1611
1612 // Any canonicalization change always leads to updated map(s).
1613 if (lbMap == prevLbMap && ubMap == prevUbMap)
1614 return failure();
1615
1616 if (lbMap != prevLbMap)
1617 forOp.setLowerBound(lbOperands, lbMap);
1618 if (ubMap != prevUbMap)
1619 forOp.setUpperBound(ubOperands, ubMap);
1620 return success();
1621 }
1622
1623 namespace {
1624 /// This is a pattern to fold trivially empty loops.
1625 struct AffineForEmptyLoopFolder : public OpRewritePattern<AffineForOp> {
1626 using OpRewritePattern<AffineForOp>::OpRewritePattern;
1627
matchAndRewrite__anonb6f842fb0d11::AffineForEmptyLoopFolder1628 LogicalResult matchAndRewrite(AffineForOp forOp,
1629 PatternRewriter &rewriter) const override {
1630 // Check that the body only contains a yield.
1631 if (!llvm::hasSingleElement(*forOp.getBody()))
1632 return failure();
1633 rewriter.eraseOp(forOp);
1634 return success();
1635 }
1636 };
1637 } // end anonymous namespace
1638
getCanonicalizationPatterns(OwningRewritePatternList & results,MLIRContext * context)1639 void AffineForOp::getCanonicalizationPatterns(OwningRewritePatternList &results,
1640 MLIRContext *context) {
1641 results.insert<AffineForEmptyLoopFolder>(context);
1642 }
1643
fold(ArrayRef<Attribute> operands,SmallVectorImpl<OpFoldResult> & results)1644 LogicalResult AffineForOp::fold(ArrayRef<Attribute> operands,
1645 SmallVectorImpl<OpFoldResult> &results) {
1646 bool folded = succeeded(foldLoopBounds(*this));
1647 folded |= succeeded(canonicalizeLoopBounds(*this));
1648 return success(folded);
1649 }
1650
getLowerBound()1651 AffineBound AffineForOp::getLowerBound() {
1652 auto lbMap = getLowerBoundMap();
1653 return AffineBound(AffineForOp(*this), 0, lbMap.getNumInputs(), lbMap);
1654 }
1655
getUpperBound()1656 AffineBound AffineForOp::getUpperBound() {
1657 auto lbMap = getLowerBoundMap();
1658 auto ubMap = getUpperBoundMap();
1659 return AffineBound(AffineForOp(*this), lbMap.getNumInputs(),
1660 lbMap.getNumInputs() + ubMap.getNumInputs(), ubMap);
1661 }
1662
setLowerBound(ValueRange lbOperands,AffineMap map)1663 void AffineForOp::setLowerBound(ValueRange lbOperands, AffineMap map) {
1664 assert(lbOperands.size() == map.getNumInputs());
1665 assert(map.getNumResults() >= 1 && "bound map has at least one result");
1666
1667 SmallVector<Value, 4> newOperands(lbOperands.begin(), lbOperands.end());
1668
1669 auto ubOperands = getUpperBoundOperands();
1670 newOperands.append(ubOperands.begin(), ubOperands.end());
1671 auto iterOperands = getIterOperands();
1672 newOperands.append(iterOperands.begin(), iterOperands.end());
1673 (*this)->setOperands(newOperands);
1674
1675 setAttr(getLowerBoundAttrName(), AffineMapAttr::get(map));
1676 }
1677
setUpperBound(ValueRange ubOperands,AffineMap map)1678 void AffineForOp::setUpperBound(ValueRange ubOperands, AffineMap map) {
1679 assert(ubOperands.size() == map.getNumInputs());
1680 assert(map.getNumResults() >= 1 && "bound map has at least one result");
1681
1682 SmallVector<Value, 4> newOperands(getLowerBoundOperands());
1683 newOperands.append(ubOperands.begin(), ubOperands.end());
1684 auto iterOperands = getIterOperands();
1685 newOperands.append(iterOperands.begin(), iterOperands.end());
1686 (*this)->setOperands(newOperands);
1687
1688 setAttr(getUpperBoundAttrName(), AffineMapAttr::get(map));
1689 }
1690
setLowerBoundMap(AffineMap map)1691 void AffineForOp::setLowerBoundMap(AffineMap map) {
1692 auto lbMap = getLowerBoundMap();
1693 assert(lbMap.getNumDims() == map.getNumDims() &&
1694 lbMap.getNumSymbols() == map.getNumSymbols());
1695 assert(map.getNumResults() >= 1 && "bound map has at least one result");
1696 (void)lbMap;
1697 setAttr(getLowerBoundAttrName(), AffineMapAttr::get(map));
1698 }
1699
setUpperBoundMap(AffineMap map)1700 void AffineForOp::setUpperBoundMap(AffineMap map) {
1701 auto ubMap = getUpperBoundMap();
1702 assert(ubMap.getNumDims() == map.getNumDims() &&
1703 ubMap.getNumSymbols() == map.getNumSymbols());
1704 assert(map.getNumResults() >= 1 && "bound map has at least one result");
1705 (void)ubMap;
1706 setAttr(getUpperBoundAttrName(), AffineMapAttr::get(map));
1707 }
1708
hasConstantLowerBound()1709 bool AffineForOp::hasConstantLowerBound() {
1710 return getLowerBoundMap().isSingleConstant();
1711 }
1712
hasConstantUpperBound()1713 bool AffineForOp::hasConstantUpperBound() {
1714 return getUpperBoundMap().isSingleConstant();
1715 }
1716
getConstantLowerBound()1717 int64_t AffineForOp::getConstantLowerBound() {
1718 return getLowerBoundMap().getSingleConstantResult();
1719 }
1720
getConstantUpperBound()1721 int64_t AffineForOp::getConstantUpperBound() {
1722 return getUpperBoundMap().getSingleConstantResult();
1723 }
1724
setConstantLowerBound(int64_t value)1725 void AffineForOp::setConstantLowerBound(int64_t value) {
1726 setLowerBound({}, AffineMap::getConstantMap(value, getContext()));
1727 }
1728
setConstantUpperBound(int64_t value)1729 void AffineForOp::setConstantUpperBound(int64_t value) {
1730 setUpperBound({}, AffineMap::getConstantMap(value, getContext()));
1731 }
1732
getLowerBoundOperands()1733 AffineForOp::operand_range AffineForOp::getLowerBoundOperands() {
1734 return {operand_begin(), operand_begin() + getLowerBoundMap().getNumInputs()};
1735 }
1736
getUpperBoundOperands()1737 AffineForOp::operand_range AffineForOp::getUpperBoundOperands() {
1738 return {operand_begin() + getLowerBoundMap().getNumInputs(),
1739 operand_begin() + getLowerBoundMap().getNumInputs() +
1740 getUpperBoundMap().getNumInputs()};
1741 }
1742
matchingBoundOperandList()1743 bool AffineForOp::matchingBoundOperandList() {
1744 auto lbMap = getLowerBoundMap();
1745 auto ubMap = getUpperBoundMap();
1746 if (lbMap.getNumDims() != ubMap.getNumDims() ||
1747 lbMap.getNumSymbols() != ubMap.getNumSymbols())
1748 return false;
1749
1750 unsigned numOperands = lbMap.getNumInputs();
1751 for (unsigned i = 0, e = lbMap.getNumInputs(); i < e; i++) {
1752 // Compare Value 's.
1753 if (getOperand(i) != getOperand(numOperands + i))
1754 return false;
1755 }
1756 return true;
1757 }
1758
getLoopBody()1759 Region &AffineForOp::getLoopBody() { return region(); }
1760
isDefinedOutsideOfLoop(Value value)1761 bool AffineForOp::isDefinedOutsideOfLoop(Value value) {
1762 return !region().isAncestor(value.getParentRegion());
1763 }
1764
moveOutOfLoop(ArrayRef<Operation * > ops)1765 LogicalResult AffineForOp::moveOutOfLoop(ArrayRef<Operation *> ops) {
1766 for (auto *op : ops)
1767 op->moveBefore(*this);
1768 return success();
1769 }
1770
1771 /// Returns true if the provided value is the induction variable of a
1772 /// AffineForOp.
isForInductionVar(Value val)1773 bool mlir::isForInductionVar(Value val) {
1774 return getForInductionVarOwner(val) != AffineForOp();
1775 }
1776
1777 /// Returns the loop parent of an induction variable. If the provided value is
1778 /// not an induction variable, then return nullptr.
getForInductionVarOwner(Value val)1779 AffineForOp mlir::getForInductionVarOwner(Value val) {
1780 auto ivArg = val.dyn_cast<BlockArgument>();
1781 if (!ivArg || !ivArg.getOwner())
1782 return AffineForOp();
1783 auto *containingInst = ivArg.getOwner()->getParent()->getParentOp();
1784 return dyn_cast<AffineForOp>(containingInst);
1785 }
1786
1787 /// Extracts the induction variables from a list of AffineForOps and returns
1788 /// them.
extractForInductionVars(ArrayRef<AffineForOp> forInsts,SmallVectorImpl<Value> * ivs)1789 void mlir::extractForInductionVars(ArrayRef<AffineForOp> forInsts,
1790 SmallVectorImpl<Value> *ivs) {
1791 ivs->reserve(forInsts.size());
1792 for (auto forInst : forInsts)
1793 ivs->push_back(forInst.getInductionVar());
1794 }
1795
1796 /// Builds an affine loop nest, using "loopCreatorFn" to create individual loop
1797 /// operations.
1798 template <typename BoundListTy, typename LoopCreatorTy>
buildAffineLoopNestImpl(OpBuilder & builder,Location loc,BoundListTy lbs,BoundListTy ubs,ArrayRef<int64_t> steps,function_ref<void (OpBuilder &,Location,ValueRange)> bodyBuilderFn,LoopCreatorTy && loopCreatorFn)1799 static void buildAffineLoopNestImpl(
1800 OpBuilder &builder, Location loc, BoundListTy lbs, BoundListTy ubs,
1801 ArrayRef<int64_t> steps,
1802 function_ref<void(OpBuilder &, Location, ValueRange)> bodyBuilderFn,
1803 LoopCreatorTy &&loopCreatorFn) {
1804 assert(lbs.size() == ubs.size() && "Mismatch in number of arguments");
1805 assert(lbs.size() == steps.size() && "Mismatch in number of arguments");
1806
1807 // If there are no loops to be constructed, construct the body anyway.
1808 OpBuilder::InsertionGuard guard(builder);
1809 if (lbs.empty()) {
1810 if (bodyBuilderFn)
1811 bodyBuilderFn(builder, loc, ValueRange());
1812 return;
1813 }
1814
1815 // Create the loops iteratively and store the induction variables.
1816 SmallVector<Value, 4> ivs;
1817 ivs.reserve(lbs.size());
1818 for (unsigned i = 0, e = lbs.size(); i < e; ++i) {
1819 // Callback for creating the loop body, always creates the terminator.
1820 auto loopBody = [&](OpBuilder &nestedBuilder, Location nestedLoc, Value iv,
1821 ValueRange iterArgs) {
1822 ivs.push_back(iv);
1823 // In the innermost loop, call the body builder.
1824 if (i == e - 1 && bodyBuilderFn) {
1825 OpBuilder::InsertionGuard nestedGuard(nestedBuilder);
1826 bodyBuilderFn(nestedBuilder, nestedLoc, ivs);
1827 }
1828 nestedBuilder.create<AffineYieldOp>(nestedLoc);
1829 };
1830
1831 // Delegate actual loop creation to the callback in order to dispatch
1832 // between constant- and variable-bound loops.
1833 auto loop = loopCreatorFn(builder, loc, lbs[i], ubs[i], steps[i], loopBody);
1834 builder.setInsertionPointToStart(loop.getBody());
1835 }
1836 }
1837
1838 /// Creates an affine loop from the bounds known to be constants.
1839 static AffineForOp
buildAffineLoopFromConstants(OpBuilder & builder,Location loc,int64_t lb,int64_t ub,int64_t step,AffineForOp::BodyBuilderFn bodyBuilderFn)1840 buildAffineLoopFromConstants(OpBuilder &builder, Location loc, int64_t lb,
1841 int64_t ub, int64_t step,
1842 AffineForOp::BodyBuilderFn bodyBuilderFn) {
1843 return builder.create<AffineForOp>(loc, lb, ub, step, /*iterArgs=*/llvm::None,
1844 bodyBuilderFn);
1845 }
1846
1847 /// Creates an affine loop from the bounds that may or may not be constants.
1848 static AffineForOp
buildAffineLoopFromValues(OpBuilder & builder,Location loc,Value lb,Value ub,int64_t step,AffineForOp::BodyBuilderFn bodyBuilderFn)1849 buildAffineLoopFromValues(OpBuilder &builder, Location loc, Value lb, Value ub,
1850 int64_t step,
1851 AffineForOp::BodyBuilderFn bodyBuilderFn) {
1852 auto lbConst = lb.getDefiningOp<ConstantIndexOp>();
1853 auto ubConst = ub.getDefiningOp<ConstantIndexOp>();
1854 if (lbConst && ubConst)
1855 return buildAffineLoopFromConstants(builder, loc, lbConst.getValue(),
1856 ubConst.getValue(), step,
1857 bodyBuilderFn);
1858 return builder.create<AffineForOp>(loc, lb, builder.getDimIdentityMap(), ub,
1859 builder.getDimIdentityMap(), step,
1860 /*iterArgs=*/llvm::None, bodyBuilderFn);
1861 }
1862
buildAffineLoopNest(OpBuilder & builder,Location loc,ArrayRef<int64_t> lbs,ArrayRef<int64_t> ubs,ArrayRef<int64_t> steps,function_ref<void (OpBuilder &,Location,ValueRange)> bodyBuilderFn)1863 void mlir::buildAffineLoopNest(
1864 OpBuilder &builder, Location loc, ArrayRef<int64_t> lbs,
1865 ArrayRef<int64_t> ubs, ArrayRef<int64_t> steps,
1866 function_ref<void(OpBuilder &, Location, ValueRange)> bodyBuilderFn) {
1867 buildAffineLoopNestImpl(builder, loc, lbs, ubs, steps, bodyBuilderFn,
1868 buildAffineLoopFromConstants);
1869 }
1870
buildAffineLoopNest(OpBuilder & builder,Location loc,ValueRange lbs,ValueRange ubs,ArrayRef<int64_t> steps,function_ref<void (OpBuilder &,Location,ValueRange)> bodyBuilderFn)1871 void mlir::buildAffineLoopNest(
1872 OpBuilder &builder, Location loc, ValueRange lbs, ValueRange ubs,
1873 ArrayRef<int64_t> steps,
1874 function_ref<void(OpBuilder &, Location, ValueRange)> bodyBuilderFn) {
1875 buildAffineLoopNestImpl(builder, loc, lbs, ubs, steps, bodyBuilderFn,
1876 buildAffineLoopFromValues);
1877 }
1878
1879 //===----------------------------------------------------------------------===//
1880 // AffineIfOp
1881 //===----------------------------------------------------------------------===//
1882
1883 namespace {
1884 /// Remove else blocks that have nothing other than a zero value yield.
1885 struct SimplifyDeadElse : public OpRewritePattern<AffineIfOp> {
1886 using OpRewritePattern<AffineIfOp>::OpRewritePattern;
1887
matchAndRewrite__anonb6f842fb0f11::SimplifyDeadElse1888 LogicalResult matchAndRewrite(AffineIfOp ifOp,
1889 PatternRewriter &rewriter) const override {
1890 if (ifOp.elseRegion().empty() ||
1891 !llvm::hasSingleElement(*ifOp.getElseBlock()) || ifOp.getNumResults())
1892 return failure();
1893
1894 rewriter.startRootUpdate(ifOp);
1895 rewriter.eraseBlock(ifOp.getElseBlock());
1896 rewriter.finalizeRootUpdate(ifOp);
1897 return success();
1898 }
1899 };
1900 } // end anonymous namespace.
1901
verify(AffineIfOp op)1902 static LogicalResult verify(AffineIfOp op) {
1903 // Verify that we have a condition attribute.
1904 auto conditionAttr =
1905 op->getAttrOfType<IntegerSetAttr>(op.getConditionAttrName());
1906 if (!conditionAttr)
1907 return op.emitOpError(
1908 "requires an integer set attribute named 'condition'");
1909
1910 // Verify that there are enough operands for the condition.
1911 IntegerSet condition = conditionAttr.getValue();
1912 if (op.getNumOperands() != condition.getNumInputs())
1913 return op.emitOpError(
1914 "operand count and condition integer set dimension and "
1915 "symbol count must match");
1916
1917 // Verify that the operands are valid dimension/symbols.
1918 if (failed(verifyDimAndSymbolIdentifiers(op, op.getOperands(),
1919 condition.getNumDims())))
1920 return failure();
1921
1922 return success();
1923 }
1924
parseAffineIfOp(OpAsmParser & parser,OperationState & result)1925 static ParseResult parseAffineIfOp(OpAsmParser &parser,
1926 OperationState &result) {
1927 // Parse the condition attribute set.
1928 IntegerSetAttr conditionAttr;
1929 unsigned numDims;
1930 if (parser.parseAttribute(conditionAttr, AffineIfOp::getConditionAttrName(),
1931 result.attributes) ||
1932 parseDimAndSymbolList(parser, result.operands, numDims))
1933 return failure();
1934
1935 // Verify the condition operands.
1936 auto set = conditionAttr.getValue();
1937 if (set.getNumDims() != numDims)
1938 return parser.emitError(
1939 parser.getNameLoc(),
1940 "dim operand count and integer set dim count must match");
1941 if (numDims + set.getNumSymbols() != result.operands.size())
1942 return parser.emitError(
1943 parser.getNameLoc(),
1944 "symbol operand count and integer set symbol count must match");
1945
1946 if (parser.parseOptionalArrowTypeList(result.types))
1947 return failure();
1948
1949 // Create the regions for 'then' and 'else'. The latter must be created even
1950 // if it remains empty for the validity of the operation.
1951 result.regions.reserve(2);
1952 Region *thenRegion = result.addRegion();
1953 Region *elseRegion = result.addRegion();
1954
1955 // Parse the 'then' region.
1956 if (parser.parseRegion(*thenRegion, {}, {}))
1957 return failure();
1958 AffineIfOp::ensureTerminator(*thenRegion, parser.getBuilder(),
1959 result.location);
1960
1961 // If we find an 'else' keyword then parse the 'else' region.
1962 if (!parser.parseOptionalKeyword("else")) {
1963 if (parser.parseRegion(*elseRegion, {}, {}))
1964 return failure();
1965 AffineIfOp::ensureTerminator(*elseRegion, parser.getBuilder(),
1966 result.location);
1967 }
1968
1969 // Parse the optional attribute list.
1970 if (parser.parseOptionalAttrDict(result.attributes))
1971 return failure();
1972
1973 return success();
1974 }
1975
print(OpAsmPrinter & p,AffineIfOp op)1976 static void print(OpAsmPrinter &p, AffineIfOp op) {
1977 auto conditionAttr =
1978 op->getAttrOfType<IntegerSetAttr>(op.getConditionAttrName());
1979 p << "affine.if " << conditionAttr;
1980 printDimAndSymbolList(op.operand_begin(), op.operand_end(),
1981 conditionAttr.getValue().getNumDims(), p);
1982 p.printOptionalArrowTypeList(op.getResultTypes());
1983 p.printRegion(op.thenRegion(),
1984 /*printEntryBlockArgs=*/false,
1985 /*printBlockTerminators=*/op.getNumResults());
1986
1987 // Print the 'else' regions if it has any blocks.
1988 auto &elseRegion = op.elseRegion();
1989 if (!elseRegion.empty()) {
1990 p << " else";
1991 p.printRegion(elseRegion,
1992 /*printEntryBlockArgs=*/false,
1993 /*printBlockTerminators=*/op.getNumResults());
1994 }
1995
1996 // Print the attribute list.
1997 p.printOptionalAttrDict(op.getAttrs(),
1998 /*elidedAttrs=*/op.getConditionAttrName());
1999 }
2000
getIntegerSet()2001 IntegerSet AffineIfOp::getIntegerSet() {
2002 return (*this)
2003 ->getAttrOfType<IntegerSetAttr>(getConditionAttrName())
2004 .getValue();
2005 }
setIntegerSet(IntegerSet newSet)2006 void AffineIfOp::setIntegerSet(IntegerSet newSet) {
2007 setAttr(getConditionAttrName(), IntegerSetAttr::get(newSet));
2008 }
2009
setConditional(IntegerSet set,ValueRange operands)2010 void AffineIfOp::setConditional(IntegerSet set, ValueRange operands) {
2011 setIntegerSet(set);
2012 (*this)->setOperands(operands);
2013 }
2014
build(OpBuilder & builder,OperationState & result,TypeRange resultTypes,IntegerSet set,ValueRange args,bool withElseRegion)2015 void AffineIfOp::build(OpBuilder &builder, OperationState &result,
2016 TypeRange resultTypes, IntegerSet set, ValueRange args,
2017 bool withElseRegion) {
2018 assert(resultTypes.empty() || withElseRegion);
2019 result.addTypes(resultTypes);
2020 result.addOperands(args);
2021 result.addAttribute(getConditionAttrName(), IntegerSetAttr::get(set));
2022
2023 Region *thenRegion = result.addRegion();
2024 thenRegion->push_back(new Block());
2025 if (resultTypes.empty())
2026 AffineIfOp::ensureTerminator(*thenRegion, builder, result.location);
2027
2028 Region *elseRegion = result.addRegion();
2029 if (withElseRegion) {
2030 elseRegion->push_back(new Block());
2031 if (resultTypes.empty())
2032 AffineIfOp::ensureTerminator(*elseRegion, builder, result.location);
2033 }
2034 }
2035
build(OpBuilder & builder,OperationState & result,IntegerSet set,ValueRange args,bool withElseRegion)2036 void AffineIfOp::build(OpBuilder &builder, OperationState &result,
2037 IntegerSet set, ValueRange args, bool withElseRegion) {
2038 AffineIfOp::build(builder, result, /*resultTypes=*/{}, set, args,
2039 withElseRegion);
2040 }
2041
2042 /// Canonicalize an affine if op's conditional (integer set + operands).
fold(ArrayRef<Attribute>,SmallVectorImpl<OpFoldResult> &)2043 LogicalResult AffineIfOp::fold(ArrayRef<Attribute>,
2044 SmallVectorImpl<OpFoldResult> &) {
2045 auto set = getIntegerSet();
2046 SmallVector<Value, 4> operands(getOperands());
2047 canonicalizeSetAndOperands(&set, &operands);
2048
2049 // Any canonicalization change always leads to either a reduction in the
2050 // number of operands or a change in the number of symbolic operands
2051 // (promotion of dims to symbols).
2052 if (operands.size() < getIntegerSet().getNumInputs() ||
2053 set.getNumSymbols() > getIntegerSet().getNumSymbols()) {
2054 setConditional(set, operands);
2055 return success();
2056 }
2057
2058 return failure();
2059 }
2060
getCanonicalizationPatterns(OwningRewritePatternList & results,MLIRContext * context)2061 void AffineIfOp::getCanonicalizationPatterns(OwningRewritePatternList &results,
2062 MLIRContext *context) {
2063 results.insert<SimplifyDeadElse>(context);
2064 }
2065
2066 //===----------------------------------------------------------------------===//
2067 // AffineLoadOp
2068 //===----------------------------------------------------------------------===//
2069
build(OpBuilder & builder,OperationState & result,AffineMap map,ValueRange operands)2070 void AffineLoadOp::build(OpBuilder &builder, OperationState &result,
2071 AffineMap map, ValueRange operands) {
2072 assert(operands.size() == 1 + map.getNumInputs() && "inconsistent operands");
2073 result.addOperands(operands);
2074 if (map)
2075 result.addAttribute(getMapAttrName(), AffineMapAttr::get(map));
2076 auto memrefType = operands[0].getType().cast<MemRefType>();
2077 result.types.push_back(memrefType.getElementType());
2078 }
2079
build(OpBuilder & builder,OperationState & result,Value memref,AffineMap map,ValueRange mapOperands)2080 void AffineLoadOp::build(OpBuilder &builder, OperationState &result,
2081 Value memref, AffineMap map, ValueRange mapOperands) {
2082 assert(map.getNumInputs() == mapOperands.size() && "inconsistent index info");
2083 result.addOperands(memref);
2084 result.addOperands(mapOperands);
2085 auto memrefType = memref.getType().cast<MemRefType>();
2086 result.addAttribute(getMapAttrName(), AffineMapAttr::get(map));
2087 result.types.push_back(memrefType.getElementType());
2088 }
2089
build(OpBuilder & builder,OperationState & result,Value memref,ValueRange indices)2090 void AffineLoadOp::build(OpBuilder &builder, OperationState &result,
2091 Value memref, ValueRange indices) {
2092 auto memrefType = memref.getType().cast<MemRefType>();
2093 int64_t rank = memrefType.getRank();
2094 // Create identity map for memrefs with at least one dimension or () -> ()
2095 // for zero-dimensional memrefs.
2096 auto map =
2097 rank ? builder.getMultiDimIdentityMap(rank) : builder.getEmptyAffineMap();
2098 build(builder, result, memref, map, indices);
2099 }
2100
parseAffineLoadOp(OpAsmParser & parser,OperationState & result)2101 static ParseResult parseAffineLoadOp(OpAsmParser &parser,
2102 OperationState &result) {
2103 auto &builder = parser.getBuilder();
2104 auto indexTy = builder.getIndexType();
2105
2106 MemRefType type;
2107 OpAsmParser::OperandType memrefInfo;
2108 AffineMapAttr mapAttr;
2109 SmallVector<OpAsmParser::OperandType, 1> mapOperands;
2110 return failure(
2111 parser.parseOperand(memrefInfo) ||
2112 parser.parseAffineMapOfSSAIds(mapOperands, mapAttr,
2113 AffineLoadOp::getMapAttrName(),
2114 result.attributes) ||
2115 parser.parseOptionalAttrDict(result.attributes) ||
2116 parser.parseColonType(type) ||
2117 parser.resolveOperand(memrefInfo, type, result.operands) ||
2118 parser.resolveOperands(mapOperands, indexTy, result.operands) ||
2119 parser.addTypeToList(type.getElementType(), result.types));
2120 }
2121
print(OpAsmPrinter & p,AffineLoadOp op)2122 static void print(OpAsmPrinter &p, AffineLoadOp op) {
2123 p << "affine.load " << op.getMemRef() << '[';
2124 if (AffineMapAttr mapAttr =
2125 op->getAttrOfType<AffineMapAttr>(op.getMapAttrName()))
2126 p.printAffineMapOfSSAIds(mapAttr, op.getMapOperands());
2127 p << ']';
2128 p.printOptionalAttrDict(op.getAttrs(), /*elidedAttrs=*/{op.getMapAttrName()});
2129 p << " : " << op.getMemRefType();
2130 }
2131
2132 /// Verify common indexing invariants of affine.load, affine.store,
2133 /// affine.vector_load and affine.vector_store.
2134 static LogicalResult
verifyMemoryOpIndexing(Operation * op,AffineMapAttr mapAttr,Operation::operand_range mapOperands,MemRefType memrefType,unsigned numIndexOperands)2135 verifyMemoryOpIndexing(Operation *op, AffineMapAttr mapAttr,
2136 Operation::operand_range mapOperands,
2137 MemRefType memrefType, unsigned numIndexOperands) {
2138 if (mapAttr) {
2139 AffineMap map = mapAttr.getValue();
2140 if (map.getNumResults() != memrefType.getRank())
2141 return op->emitOpError("affine map num results must equal memref rank");
2142 if (map.getNumInputs() != numIndexOperands)
2143 return op->emitOpError("expects as many subscripts as affine map inputs");
2144 } else {
2145 if (memrefType.getRank() != numIndexOperands)
2146 return op->emitOpError(
2147 "expects the number of subscripts to be equal to memref rank");
2148 }
2149
2150 Region *scope = getAffineScope(op);
2151 for (auto idx : mapOperands) {
2152 if (!idx.getType().isIndex())
2153 return op->emitOpError("index to load must have 'index' type");
2154 if (!isValidAffineIndexOperand(idx, scope))
2155 return op->emitOpError("index must be a dimension or symbol identifier");
2156 }
2157
2158 return success();
2159 }
2160
verify(AffineLoadOp op)2161 LogicalResult verify(AffineLoadOp op) {
2162 auto memrefType = op.getMemRefType();
2163 if (op.getType() != memrefType.getElementType())
2164 return op.emitOpError("result type must match element type of memref");
2165
2166 if (failed(verifyMemoryOpIndexing(
2167 op.getOperation(),
2168 op->getAttrOfType<AffineMapAttr>(op.getMapAttrName()),
2169 op.getMapOperands(), memrefType,
2170 /*numIndexOperands=*/op.getNumOperands() - 1)))
2171 return failure();
2172
2173 return success();
2174 }
2175
getCanonicalizationPatterns(OwningRewritePatternList & results,MLIRContext * context)2176 void AffineLoadOp::getCanonicalizationPatterns(
2177 OwningRewritePatternList &results, MLIRContext *context) {
2178 results.insert<SimplifyAffineOp<AffineLoadOp>>(context);
2179 }
2180
fold(ArrayRef<Attribute> cstOperands)2181 OpFoldResult AffineLoadOp::fold(ArrayRef<Attribute> cstOperands) {
2182 /// load(memrefcast) -> load
2183 if (succeeded(foldMemRefCast(*this)))
2184 return getResult();
2185 return OpFoldResult();
2186 }
2187
2188 //===----------------------------------------------------------------------===//
2189 // AffineStoreOp
2190 //===----------------------------------------------------------------------===//
2191
build(OpBuilder & builder,OperationState & result,Value valueToStore,Value memref,AffineMap map,ValueRange mapOperands)2192 void AffineStoreOp::build(OpBuilder &builder, OperationState &result,
2193 Value valueToStore, Value memref, AffineMap map,
2194 ValueRange mapOperands) {
2195 assert(map.getNumInputs() == mapOperands.size() && "inconsistent index info");
2196 result.addOperands(valueToStore);
2197 result.addOperands(memref);
2198 result.addOperands(mapOperands);
2199 result.addAttribute(getMapAttrName(), AffineMapAttr::get(map));
2200 }
2201
2202 // Use identity map.
build(OpBuilder & builder,OperationState & result,Value valueToStore,Value memref,ValueRange indices)2203 void AffineStoreOp::build(OpBuilder &builder, OperationState &result,
2204 Value valueToStore, Value memref,
2205 ValueRange indices) {
2206 auto memrefType = memref.getType().cast<MemRefType>();
2207 int64_t rank = memrefType.getRank();
2208 // Create identity map for memrefs with at least one dimension or () -> ()
2209 // for zero-dimensional memrefs.
2210 auto map =
2211 rank ? builder.getMultiDimIdentityMap(rank) : builder.getEmptyAffineMap();
2212 build(builder, result, valueToStore, memref, map, indices);
2213 }
2214
parseAffineStoreOp(OpAsmParser & parser,OperationState & result)2215 static ParseResult parseAffineStoreOp(OpAsmParser &parser,
2216 OperationState &result) {
2217 auto indexTy = parser.getBuilder().getIndexType();
2218
2219 MemRefType type;
2220 OpAsmParser::OperandType storeValueInfo;
2221 OpAsmParser::OperandType memrefInfo;
2222 AffineMapAttr mapAttr;
2223 SmallVector<OpAsmParser::OperandType, 1> mapOperands;
2224 return failure(parser.parseOperand(storeValueInfo) || parser.parseComma() ||
2225 parser.parseOperand(memrefInfo) ||
2226 parser.parseAffineMapOfSSAIds(mapOperands, mapAttr,
2227 AffineStoreOp::getMapAttrName(),
2228 result.attributes) ||
2229 parser.parseOptionalAttrDict(result.attributes) ||
2230 parser.parseColonType(type) ||
2231 parser.resolveOperand(storeValueInfo, type.getElementType(),
2232 result.operands) ||
2233 parser.resolveOperand(memrefInfo, type, result.operands) ||
2234 parser.resolveOperands(mapOperands, indexTy, result.operands));
2235 }
2236
print(OpAsmPrinter & p,AffineStoreOp op)2237 static void print(OpAsmPrinter &p, AffineStoreOp op) {
2238 p << "affine.store " << op.getValueToStore();
2239 p << ", " << op.getMemRef() << '[';
2240 if (AffineMapAttr mapAttr =
2241 op->getAttrOfType<AffineMapAttr>(op.getMapAttrName()))
2242 p.printAffineMapOfSSAIds(mapAttr, op.getMapOperands());
2243 p << ']';
2244 p.printOptionalAttrDict(op.getAttrs(), /*elidedAttrs=*/{op.getMapAttrName()});
2245 p << " : " << op.getMemRefType();
2246 }
2247
verify(AffineStoreOp op)2248 LogicalResult verify(AffineStoreOp op) {
2249 // First operand must have same type as memref element type.
2250 auto memrefType = op.getMemRefType();
2251 if (op.getValueToStore().getType() != memrefType.getElementType())
2252 return op.emitOpError(
2253 "first operand must have same type memref element type");
2254
2255 if (failed(verifyMemoryOpIndexing(
2256 op.getOperation(),
2257 op->getAttrOfType<AffineMapAttr>(op.getMapAttrName()),
2258 op.getMapOperands(), memrefType,
2259 /*numIndexOperands=*/op.getNumOperands() - 2)))
2260 return failure();
2261
2262 return success();
2263 }
2264
getCanonicalizationPatterns(OwningRewritePatternList & results,MLIRContext * context)2265 void AffineStoreOp::getCanonicalizationPatterns(
2266 OwningRewritePatternList &results, MLIRContext *context) {
2267 results.insert<SimplifyAffineOp<AffineStoreOp>>(context);
2268 }
2269
fold(ArrayRef<Attribute> cstOperands,SmallVectorImpl<OpFoldResult> & results)2270 LogicalResult AffineStoreOp::fold(ArrayRef<Attribute> cstOperands,
2271 SmallVectorImpl<OpFoldResult> &results) {
2272 /// store(memrefcast) -> store
2273 return foldMemRefCast(*this);
2274 }
2275
2276 //===----------------------------------------------------------------------===//
2277 // AffineMinMaxOpBase
2278 //===----------------------------------------------------------------------===//
2279
2280 template <typename T>
verifyAffineMinMaxOp(T op)2281 static LogicalResult verifyAffineMinMaxOp(T op) {
2282 // Verify that operand count matches affine map dimension and symbol count.
2283 if (op.getNumOperands() != op.map().getNumDims() + op.map().getNumSymbols())
2284 return op.emitOpError(
2285 "operand count and affine map dimension and symbol count must match");
2286 return success();
2287 }
2288
2289 template <typename T>
printAffineMinMaxOp(OpAsmPrinter & p,T op)2290 static void printAffineMinMaxOp(OpAsmPrinter &p, T op) {
2291 p << op.getOperationName() << ' ' << op.getAttr(T::getMapAttrName());
2292 auto operands = op.getOperands();
2293 unsigned numDims = op.map().getNumDims();
2294 p << '(' << operands.take_front(numDims) << ')';
2295
2296 if (operands.size() != numDims)
2297 p << '[' << operands.drop_front(numDims) << ']';
2298 p.printOptionalAttrDict(op.getAttrs(),
2299 /*elidedAttrs=*/{T::getMapAttrName()});
2300 }
2301
2302 template <typename T>
parseAffineMinMaxOp(OpAsmParser & parser,OperationState & result)2303 static ParseResult parseAffineMinMaxOp(OpAsmParser &parser,
2304 OperationState &result) {
2305 auto &builder = parser.getBuilder();
2306 auto indexType = builder.getIndexType();
2307 SmallVector<OpAsmParser::OperandType, 8> dim_infos;
2308 SmallVector<OpAsmParser::OperandType, 8> sym_infos;
2309 AffineMapAttr mapAttr;
2310 return failure(
2311 parser.parseAttribute(mapAttr, T::getMapAttrName(), result.attributes) ||
2312 parser.parseOperandList(dim_infos, OpAsmParser::Delimiter::Paren) ||
2313 parser.parseOperandList(sym_infos,
2314 OpAsmParser::Delimiter::OptionalSquare) ||
2315 parser.parseOptionalAttrDict(result.attributes) ||
2316 parser.resolveOperands(dim_infos, indexType, result.operands) ||
2317 parser.resolveOperands(sym_infos, indexType, result.operands) ||
2318 parser.addTypeToList(indexType, result.types));
2319 }
2320
2321 /// Fold an affine min or max operation with the given operands. The operand
2322 /// list may contain nulls, which are interpreted as the operand not being a
2323 /// constant.
2324 template <typename T>
foldMinMaxOp(T op,ArrayRef<Attribute> operands)2325 static OpFoldResult foldMinMaxOp(T op, ArrayRef<Attribute> operands) {
2326 static_assert(llvm::is_one_of<T, AffineMinOp, AffineMaxOp>::value,
2327 "expected affine min or max op");
2328
2329 // Fold the affine map.
2330 // TODO: Fold more cases:
2331 // min(some_affine, some_affine + constant, ...), etc.
2332 SmallVector<int64_t, 2> results;
2333 auto foldedMap = op.map().partialConstantFold(operands, &results);
2334
2335 // If some of the map results are not constant, try changing the map in-place.
2336 if (results.empty()) {
2337 // If the map is the same, report that folding did not happen.
2338 if (foldedMap == op.map())
2339 return {};
2340 op.setAttr("map", AffineMapAttr::get(foldedMap));
2341 return op.getResult();
2342 }
2343
2344 // Otherwise, completely fold the op into a constant.
2345 auto resultIt = std::is_same<T, AffineMinOp>::value
2346 ? std::min_element(results.begin(), results.end())
2347 : std::max_element(results.begin(), results.end());
2348 if (resultIt == results.end())
2349 return {};
2350 return IntegerAttr::get(IndexType::get(op.getContext()), *resultIt);
2351 }
2352
2353 //===----------------------------------------------------------------------===//
2354 // AffineMinOp
2355 //===----------------------------------------------------------------------===//
2356 //
2357 // %0 = affine.min (d0) -> (1000, d0 + 512) (%i0)
2358 //
2359
fold(ArrayRef<Attribute> operands)2360 OpFoldResult AffineMinOp::fold(ArrayRef<Attribute> operands) {
2361 return foldMinMaxOp(*this, operands);
2362 }
2363
getCanonicalizationPatterns(OwningRewritePatternList & patterns,MLIRContext * context)2364 void AffineMinOp::getCanonicalizationPatterns(
2365 OwningRewritePatternList &patterns, MLIRContext *context) {
2366 patterns.insert<SimplifyAffineOp<AffineMinOp>>(context);
2367 }
2368
2369 //===----------------------------------------------------------------------===//
2370 // AffineMaxOp
2371 //===----------------------------------------------------------------------===//
2372 //
2373 // %0 = affine.max (d0) -> (1000, d0 + 512) (%i0)
2374 //
2375
fold(ArrayRef<Attribute> operands)2376 OpFoldResult AffineMaxOp::fold(ArrayRef<Attribute> operands) {
2377 return foldMinMaxOp(*this, operands);
2378 }
2379
getCanonicalizationPatterns(OwningRewritePatternList & patterns,MLIRContext * context)2380 void AffineMaxOp::getCanonicalizationPatterns(
2381 OwningRewritePatternList &patterns, MLIRContext *context) {
2382 patterns.insert<SimplifyAffineOp<AffineMaxOp>>(context);
2383 }
2384
2385 //===----------------------------------------------------------------------===//
2386 // AffinePrefetchOp
2387 //===----------------------------------------------------------------------===//
2388
2389 //
2390 // affine.prefetch %0[%i, %j + 5], read, locality<3>, data : memref<400x400xi32>
2391 //
parseAffinePrefetchOp(OpAsmParser & parser,OperationState & result)2392 static ParseResult parseAffinePrefetchOp(OpAsmParser &parser,
2393 OperationState &result) {
2394 auto &builder = parser.getBuilder();
2395 auto indexTy = builder.getIndexType();
2396
2397 MemRefType type;
2398 OpAsmParser::OperandType memrefInfo;
2399 IntegerAttr hintInfo;
2400 auto i32Type = parser.getBuilder().getIntegerType(32);
2401 StringRef readOrWrite, cacheType;
2402
2403 AffineMapAttr mapAttr;
2404 SmallVector<OpAsmParser::OperandType, 1> mapOperands;
2405 if (parser.parseOperand(memrefInfo) ||
2406 parser.parseAffineMapOfSSAIds(mapOperands, mapAttr,
2407 AffinePrefetchOp::getMapAttrName(),
2408 result.attributes) ||
2409 parser.parseComma() || parser.parseKeyword(&readOrWrite) ||
2410 parser.parseComma() || parser.parseKeyword("locality") ||
2411 parser.parseLess() ||
2412 parser.parseAttribute(hintInfo, i32Type,
2413 AffinePrefetchOp::getLocalityHintAttrName(),
2414 result.attributes) ||
2415 parser.parseGreater() || parser.parseComma() ||
2416 parser.parseKeyword(&cacheType) ||
2417 parser.parseOptionalAttrDict(result.attributes) ||
2418 parser.parseColonType(type) ||
2419 parser.resolveOperand(memrefInfo, type, result.operands) ||
2420 parser.resolveOperands(mapOperands, indexTy, result.operands))
2421 return failure();
2422
2423 if (!readOrWrite.equals("read") && !readOrWrite.equals("write"))
2424 return parser.emitError(parser.getNameLoc(),
2425 "rw specifier has to be 'read' or 'write'");
2426 result.addAttribute(
2427 AffinePrefetchOp::getIsWriteAttrName(),
2428 parser.getBuilder().getBoolAttr(readOrWrite.equals("write")));
2429
2430 if (!cacheType.equals("data") && !cacheType.equals("instr"))
2431 return parser.emitError(parser.getNameLoc(),
2432 "cache type has to be 'data' or 'instr'");
2433
2434 result.addAttribute(
2435 AffinePrefetchOp::getIsDataCacheAttrName(),
2436 parser.getBuilder().getBoolAttr(cacheType.equals("data")));
2437
2438 return success();
2439 }
2440
print(OpAsmPrinter & p,AffinePrefetchOp op)2441 static void print(OpAsmPrinter &p, AffinePrefetchOp op) {
2442 p << AffinePrefetchOp::getOperationName() << " " << op.memref() << '[';
2443 AffineMapAttr mapAttr = op->getAttrOfType<AffineMapAttr>(op.getMapAttrName());
2444 if (mapAttr) {
2445 SmallVector<Value, 2> operands(op.getMapOperands());
2446 p.printAffineMapOfSSAIds(mapAttr, operands);
2447 }
2448 p << ']' << ", " << (op.isWrite() ? "write" : "read") << ", "
2449 << "locality<" << op.localityHint() << ">, "
2450 << (op.isDataCache() ? "data" : "instr");
2451 p.printOptionalAttrDict(
2452 op.getAttrs(),
2453 /*elidedAttrs=*/{op.getMapAttrName(), op.getLocalityHintAttrName(),
2454 op.getIsDataCacheAttrName(), op.getIsWriteAttrName()});
2455 p << " : " << op.getMemRefType();
2456 }
2457
verify(AffinePrefetchOp op)2458 static LogicalResult verify(AffinePrefetchOp op) {
2459 auto mapAttr = op->getAttrOfType<AffineMapAttr>(op.getMapAttrName());
2460 if (mapAttr) {
2461 AffineMap map = mapAttr.getValue();
2462 if (map.getNumResults() != op.getMemRefType().getRank())
2463 return op.emitOpError("affine.prefetch affine map num results must equal"
2464 " memref rank");
2465 if (map.getNumInputs() + 1 != op.getNumOperands())
2466 return op.emitOpError("too few operands");
2467 } else {
2468 if (op.getNumOperands() != 1)
2469 return op.emitOpError("too few operands");
2470 }
2471
2472 Region *scope = getAffineScope(op);
2473 for (auto idx : op.getMapOperands()) {
2474 if (!isValidAffineIndexOperand(idx, scope))
2475 return op.emitOpError("index must be a dimension or symbol identifier");
2476 }
2477 return success();
2478 }
2479
getCanonicalizationPatterns(OwningRewritePatternList & results,MLIRContext * context)2480 void AffinePrefetchOp::getCanonicalizationPatterns(
2481 OwningRewritePatternList &results, MLIRContext *context) {
2482 // prefetch(memrefcast) -> prefetch
2483 results.insert<SimplifyAffineOp<AffinePrefetchOp>>(context);
2484 }
2485
fold(ArrayRef<Attribute> cstOperands,SmallVectorImpl<OpFoldResult> & results)2486 LogicalResult AffinePrefetchOp::fold(ArrayRef<Attribute> cstOperands,
2487 SmallVectorImpl<OpFoldResult> &results) {
2488 /// prefetch(memrefcast) -> prefetch
2489 return foldMemRefCast(*this);
2490 }
2491
2492 //===----------------------------------------------------------------------===//
2493 // AffineParallelOp
2494 //===----------------------------------------------------------------------===//
2495
build(OpBuilder & builder,OperationState & result,TypeRange resultTypes,ArrayRef<AtomicRMWKind> reductions,ArrayRef<int64_t> ranges)2496 void AffineParallelOp::build(OpBuilder &builder, OperationState &result,
2497 TypeRange resultTypes,
2498 ArrayRef<AtomicRMWKind> reductions,
2499 ArrayRef<int64_t> ranges) {
2500 SmallVector<AffineExpr, 8> lbExprs(ranges.size(),
2501 builder.getAffineConstantExpr(0));
2502 auto lbMap = AffineMap::get(0, 0, lbExprs, builder.getContext());
2503 SmallVector<AffineExpr, 8> ubExprs;
2504 for (int64_t range : ranges)
2505 ubExprs.push_back(builder.getAffineConstantExpr(range));
2506 auto ubMap = AffineMap::get(0, 0, ubExprs, builder.getContext());
2507 build(builder, result, resultTypes, reductions, lbMap, /*lbArgs=*/{}, ubMap,
2508 /*ubArgs=*/{});
2509 }
2510
build(OpBuilder & builder,OperationState & result,TypeRange resultTypes,ArrayRef<AtomicRMWKind> reductions,AffineMap lbMap,ValueRange lbArgs,AffineMap ubMap,ValueRange ubArgs)2511 void AffineParallelOp::build(OpBuilder &builder, OperationState &result,
2512 TypeRange resultTypes,
2513 ArrayRef<AtomicRMWKind> reductions,
2514 AffineMap lbMap, ValueRange lbArgs,
2515 AffineMap ubMap, ValueRange ubArgs) {
2516 auto numDims = lbMap.getNumResults();
2517 // Verify that the dimensionality of both maps are the same.
2518 assert(numDims == ubMap.getNumResults() &&
2519 "num dims and num results mismatch");
2520 // Make default step sizes of 1.
2521 SmallVector<int64_t, 8> steps(numDims, 1);
2522 build(builder, result, resultTypes, reductions, lbMap, lbArgs, ubMap, ubArgs,
2523 steps);
2524 }
2525
build(OpBuilder & builder,OperationState & result,TypeRange resultTypes,ArrayRef<AtomicRMWKind> reductions,AffineMap lbMap,ValueRange lbArgs,AffineMap ubMap,ValueRange ubArgs,ArrayRef<int64_t> steps)2526 void AffineParallelOp::build(OpBuilder &builder, OperationState &result,
2527 TypeRange resultTypes,
2528 ArrayRef<AtomicRMWKind> reductions,
2529 AffineMap lbMap, ValueRange lbArgs,
2530 AffineMap ubMap, ValueRange ubArgs,
2531 ArrayRef<int64_t> steps) {
2532 auto numDims = lbMap.getNumResults();
2533 // Verify that the dimensionality of the maps matches the number of steps.
2534 assert(numDims == ubMap.getNumResults() &&
2535 "num dims and num results mismatch");
2536 assert(numDims == steps.size() && "num dims and num steps mismatch");
2537
2538 result.addTypes(resultTypes);
2539 // Convert the reductions to integer attributes.
2540 SmallVector<Attribute, 4> reductionAttrs;
2541 for (AtomicRMWKind reduction : reductions)
2542 reductionAttrs.push_back(
2543 builder.getI64IntegerAttr(static_cast<int64_t>(reduction)));
2544 result.addAttribute(getReductionsAttrName(),
2545 builder.getArrayAttr(reductionAttrs));
2546 result.addAttribute(getLowerBoundsMapAttrName(), AffineMapAttr::get(lbMap));
2547 result.addAttribute(getUpperBoundsMapAttrName(), AffineMapAttr::get(ubMap));
2548 result.addAttribute(getStepsAttrName(), builder.getI64ArrayAttr(steps));
2549 result.addOperands(lbArgs);
2550 result.addOperands(ubArgs);
2551 // Create a region and a block for the body.
2552 auto bodyRegion = result.addRegion();
2553 auto body = new Block();
2554 // Add all the block arguments.
2555 for (unsigned i = 0; i < numDims; ++i)
2556 body->addArgument(IndexType::get(builder.getContext()));
2557 bodyRegion->push_back(body);
2558 if (resultTypes.empty())
2559 ensureTerminator(*bodyRegion, builder, result.location);
2560 }
2561
getLoopBody()2562 Region &AffineParallelOp::getLoopBody() { return region(); }
2563
isDefinedOutsideOfLoop(Value value)2564 bool AffineParallelOp::isDefinedOutsideOfLoop(Value value) {
2565 return !region().isAncestor(value.getParentRegion());
2566 }
2567
moveOutOfLoop(ArrayRef<Operation * > ops)2568 LogicalResult AffineParallelOp::moveOutOfLoop(ArrayRef<Operation *> ops) {
2569 for (Operation *op : ops)
2570 op->moveBefore(*this);
2571 return success();
2572 }
2573
getNumDims()2574 unsigned AffineParallelOp::getNumDims() { return steps().size(); }
2575
getLowerBoundsOperands()2576 AffineParallelOp::operand_range AffineParallelOp::getLowerBoundsOperands() {
2577 return getOperands().take_front(lowerBoundsMap().getNumInputs());
2578 }
2579
getUpperBoundsOperands()2580 AffineParallelOp::operand_range AffineParallelOp::getUpperBoundsOperands() {
2581 return getOperands().drop_front(lowerBoundsMap().getNumInputs());
2582 }
2583
getLowerBoundsValueMap()2584 AffineValueMap AffineParallelOp::getLowerBoundsValueMap() {
2585 return AffineValueMap(lowerBoundsMap(), getLowerBoundsOperands());
2586 }
2587
getUpperBoundsValueMap()2588 AffineValueMap AffineParallelOp::getUpperBoundsValueMap() {
2589 return AffineValueMap(upperBoundsMap(), getUpperBoundsOperands());
2590 }
2591
getRangesValueMap()2592 AffineValueMap AffineParallelOp::getRangesValueMap() {
2593 AffineValueMap out;
2594 AffineValueMap::difference(getUpperBoundsValueMap(), getLowerBoundsValueMap(),
2595 &out);
2596 return out;
2597 }
2598
getConstantRanges()2599 Optional<SmallVector<int64_t, 8>> AffineParallelOp::getConstantRanges() {
2600 // Try to convert all the ranges to constant expressions.
2601 SmallVector<int64_t, 8> out;
2602 AffineValueMap rangesValueMap = getRangesValueMap();
2603 out.reserve(rangesValueMap.getNumResults());
2604 for (unsigned i = 0, e = rangesValueMap.getNumResults(); i < e; ++i) {
2605 auto expr = rangesValueMap.getResult(i);
2606 auto cst = expr.dyn_cast<AffineConstantExpr>();
2607 if (!cst)
2608 return llvm::None;
2609 out.push_back(cst.getValue());
2610 }
2611 return out;
2612 }
2613
getBody()2614 Block *AffineParallelOp::getBody() { return ®ion().front(); }
2615
getBodyBuilder()2616 OpBuilder AffineParallelOp::getBodyBuilder() {
2617 return OpBuilder(getBody(), std::prev(getBody()->end()));
2618 }
2619
setLowerBounds(ValueRange lbOperands,AffineMap map)2620 void AffineParallelOp::setLowerBounds(ValueRange lbOperands, AffineMap map) {
2621 assert(lbOperands.size() == map.getNumInputs() &&
2622 "operands to map must match number of inputs");
2623 assert(map.getNumResults() >= 1 && "bounds map has at least one result");
2624
2625 auto ubOperands = getUpperBoundsOperands();
2626
2627 SmallVector<Value, 4> newOperands(lbOperands);
2628 newOperands.append(ubOperands.begin(), ubOperands.end());
2629 (*this)->setOperands(newOperands);
2630
2631 lowerBoundsMapAttr(AffineMapAttr::get(map));
2632 }
2633
setUpperBounds(ValueRange ubOperands,AffineMap map)2634 void AffineParallelOp::setUpperBounds(ValueRange ubOperands, AffineMap map) {
2635 assert(ubOperands.size() == map.getNumInputs() &&
2636 "operands to map must match number of inputs");
2637 assert(map.getNumResults() >= 1 && "bounds map has at least one result");
2638
2639 SmallVector<Value, 4> newOperands(getLowerBoundsOperands());
2640 newOperands.append(ubOperands.begin(), ubOperands.end());
2641 (*this)->setOperands(newOperands);
2642
2643 upperBoundsMapAttr(AffineMapAttr::get(map));
2644 }
2645
setLowerBoundsMap(AffineMap map)2646 void AffineParallelOp::setLowerBoundsMap(AffineMap map) {
2647 AffineMap lbMap = lowerBoundsMap();
2648 assert(lbMap.getNumDims() == map.getNumDims() &&
2649 lbMap.getNumSymbols() == map.getNumSymbols());
2650 (void)lbMap;
2651 lowerBoundsMapAttr(AffineMapAttr::get(map));
2652 }
2653
setUpperBoundsMap(AffineMap map)2654 void AffineParallelOp::setUpperBoundsMap(AffineMap map) {
2655 AffineMap ubMap = upperBoundsMap();
2656 assert(ubMap.getNumDims() == map.getNumDims() &&
2657 ubMap.getNumSymbols() == map.getNumSymbols());
2658 (void)ubMap;
2659 upperBoundsMapAttr(AffineMapAttr::get(map));
2660 }
2661
getSteps()2662 SmallVector<int64_t, 8> AffineParallelOp::getSteps() {
2663 SmallVector<int64_t, 8> result;
2664 for (Attribute attr : steps()) {
2665 result.push_back(attr.cast<IntegerAttr>().getInt());
2666 }
2667 return result;
2668 }
2669
setSteps(ArrayRef<int64_t> newSteps)2670 void AffineParallelOp::setSteps(ArrayRef<int64_t> newSteps) {
2671 stepsAttr(getBodyBuilder().getI64ArrayAttr(newSteps));
2672 }
2673
verify(AffineParallelOp op)2674 static LogicalResult verify(AffineParallelOp op) {
2675 auto numDims = op.getNumDims();
2676 if (op.lowerBoundsMap().getNumResults() != numDims ||
2677 op.upperBoundsMap().getNumResults() != numDims ||
2678 op.steps().size() != numDims ||
2679 op.getBody()->getNumArguments() != numDims)
2680 return op.emitOpError("region argument count and num results of upper "
2681 "bounds, lower bounds, and steps must all match");
2682
2683 if (op.reductions().size() != op.getNumResults())
2684 return op.emitOpError("a reduction must be specified for each output");
2685
2686 // Verify reduction ops are all valid
2687 for (Attribute attr : op.reductions()) {
2688 auto intAttr = attr.dyn_cast<IntegerAttr>();
2689 if (!intAttr || !symbolizeAtomicRMWKind(intAttr.getInt()))
2690 return op.emitOpError("invalid reduction attribute");
2691 }
2692
2693 // Verify that the bound operands are valid dimension/symbols.
2694 /// Lower bounds.
2695 if (failed(verifyDimAndSymbolIdentifiers(op, op.getLowerBoundsOperands(),
2696 op.lowerBoundsMap().getNumDims())))
2697 return failure();
2698 /// Upper bounds.
2699 if (failed(verifyDimAndSymbolIdentifiers(op, op.getUpperBoundsOperands(),
2700 op.upperBoundsMap().getNumDims())))
2701 return failure();
2702 return success();
2703 }
2704
canonicalize()2705 LogicalResult AffineValueMap::canonicalize() {
2706 SmallVector<Value, 4> newOperands{operands};
2707 auto newMap = getAffineMap();
2708 composeAffineMapAndOperands(&newMap, &newOperands);
2709 if (newMap == getAffineMap() && newOperands == operands)
2710 return failure();
2711 reset(newMap, newOperands);
2712 return success();
2713 }
2714
2715 /// Canonicalize the bounds of the given loop.
canonicalizeLoopBounds(AffineParallelOp op)2716 static LogicalResult canonicalizeLoopBounds(AffineParallelOp op) {
2717 AffineValueMap lb = op.getLowerBoundsValueMap();
2718 bool lbCanonicalized = succeeded(lb.canonicalize());
2719
2720 AffineValueMap ub = op.getUpperBoundsValueMap();
2721 bool ubCanonicalized = succeeded(ub.canonicalize());
2722
2723 // Any canonicalization change always leads to updated map(s).
2724 if (!lbCanonicalized && !ubCanonicalized)
2725 return failure();
2726
2727 if (lbCanonicalized)
2728 op.setLowerBounds(lb.getOperands(), lb.getAffineMap());
2729 if (ubCanonicalized)
2730 op.setUpperBounds(ub.getOperands(), ub.getAffineMap());
2731
2732 return success();
2733 }
2734
fold(ArrayRef<Attribute> operands,SmallVectorImpl<OpFoldResult> & results)2735 LogicalResult AffineParallelOp::fold(ArrayRef<Attribute> operands,
2736 SmallVectorImpl<OpFoldResult> &results) {
2737 return canonicalizeLoopBounds(*this);
2738 }
2739
print(OpAsmPrinter & p,AffineParallelOp op)2740 static void print(OpAsmPrinter &p, AffineParallelOp op) {
2741 p << op.getOperationName() << " (" << op.getBody()->getArguments() << ") = (";
2742 p.printAffineMapOfSSAIds(op.lowerBoundsMapAttr(),
2743 op.getLowerBoundsOperands());
2744 p << ") to (";
2745 p.printAffineMapOfSSAIds(op.upperBoundsMapAttr(),
2746 op.getUpperBoundsOperands());
2747 p << ')';
2748 SmallVector<int64_t, 8> steps = op.getSteps();
2749 bool elideSteps = llvm::all_of(steps, [](int64_t step) { return step == 1; });
2750 if (!elideSteps) {
2751 p << " step (";
2752 llvm::interleaveComma(steps, p);
2753 p << ')';
2754 }
2755 if (op.getNumResults()) {
2756 p << " reduce (";
2757 llvm::interleaveComma(op.reductions(), p, [&](auto &attr) {
2758 AtomicRMWKind sym =
2759 *symbolizeAtomicRMWKind(attr.template cast<IntegerAttr>().getInt());
2760 p << "\"" << stringifyAtomicRMWKind(sym) << "\"";
2761 });
2762 p << ") -> (" << op.getResultTypes() << ")";
2763 }
2764
2765 p.printRegion(op.region(), /*printEntryBlockArgs=*/false,
2766 /*printBlockTerminators=*/op.getNumResults());
2767 p.printOptionalAttrDict(
2768 op.getAttrs(),
2769 /*elidedAttrs=*/{AffineParallelOp::getReductionsAttrName(),
2770 AffineParallelOp::getLowerBoundsMapAttrName(),
2771 AffineParallelOp::getUpperBoundsMapAttrName(),
2772 AffineParallelOp::getStepsAttrName()});
2773 }
2774
2775 //
2776 // operation ::= `affine.parallel` `(` ssa-ids `)` `=` `(` map-of-ssa-ids `)`
2777 // `to` `(` map-of-ssa-ids `)` steps? region attr-dict?
2778 // steps ::= `steps` `(` integer-literals `)`
2779 //
parseAffineParallelOp(OpAsmParser & parser,OperationState & result)2780 static ParseResult parseAffineParallelOp(OpAsmParser &parser,
2781 OperationState &result) {
2782 auto &builder = parser.getBuilder();
2783 auto indexType = builder.getIndexType();
2784 AffineMapAttr lowerBoundsAttr, upperBoundsAttr;
2785 SmallVector<OpAsmParser::OperandType, 4> ivs;
2786 SmallVector<OpAsmParser::OperandType, 4> lowerBoundsMapOperands;
2787 SmallVector<OpAsmParser::OperandType, 4> upperBoundsMapOperands;
2788 if (parser.parseRegionArgumentList(ivs, /*requiredOperandCount=*/-1,
2789 OpAsmParser::Delimiter::Paren) ||
2790 parser.parseEqual() ||
2791 parser.parseAffineMapOfSSAIds(
2792 lowerBoundsMapOperands, lowerBoundsAttr,
2793 AffineParallelOp::getLowerBoundsMapAttrName(), result.attributes,
2794 OpAsmParser::Delimiter::Paren) ||
2795 parser.resolveOperands(lowerBoundsMapOperands, indexType,
2796 result.operands) ||
2797 parser.parseKeyword("to") ||
2798 parser.parseAffineMapOfSSAIds(
2799 upperBoundsMapOperands, upperBoundsAttr,
2800 AffineParallelOp::getUpperBoundsMapAttrName(), result.attributes,
2801 OpAsmParser::Delimiter::Paren) ||
2802 parser.resolveOperands(upperBoundsMapOperands, indexType,
2803 result.operands))
2804 return failure();
2805
2806 AffineMapAttr stepsMapAttr;
2807 NamedAttrList stepsAttrs;
2808 SmallVector<OpAsmParser::OperandType, 4> stepsMapOperands;
2809 if (failed(parser.parseOptionalKeyword("step"))) {
2810 SmallVector<int64_t, 4> steps(ivs.size(), 1);
2811 result.addAttribute(AffineParallelOp::getStepsAttrName(),
2812 builder.getI64ArrayAttr(steps));
2813 } else {
2814 if (parser.parseAffineMapOfSSAIds(stepsMapOperands, stepsMapAttr,
2815 AffineParallelOp::getStepsAttrName(),
2816 stepsAttrs,
2817 OpAsmParser::Delimiter::Paren))
2818 return failure();
2819
2820 // Convert steps from an AffineMap into an I64ArrayAttr.
2821 SmallVector<int64_t, 4> steps;
2822 auto stepsMap = stepsMapAttr.getValue();
2823 for (const auto &result : stepsMap.getResults()) {
2824 auto constExpr = result.dyn_cast<AffineConstantExpr>();
2825 if (!constExpr)
2826 return parser.emitError(parser.getNameLoc(),
2827 "steps must be constant integers");
2828 steps.push_back(constExpr.getValue());
2829 }
2830 result.addAttribute(AffineParallelOp::getStepsAttrName(),
2831 builder.getI64ArrayAttr(steps));
2832 }
2833
2834 // Parse optional clause of the form: `reduce ("addf", "maxf")`, where the
2835 // quoted strings are a member of the enum AtomicRMWKind.
2836 SmallVector<Attribute, 4> reductions;
2837 if (succeeded(parser.parseOptionalKeyword("reduce"))) {
2838 if (parser.parseLParen())
2839 return failure();
2840 do {
2841 // Parse a single quoted string via the attribute parsing, and then
2842 // verify it is a member of the enum and convert to it's integer
2843 // representation.
2844 StringAttr attrVal;
2845 NamedAttrList attrStorage;
2846 auto loc = parser.getCurrentLocation();
2847 if (parser.parseAttribute(attrVal, builder.getNoneType(), "reduce",
2848 attrStorage))
2849 return failure();
2850 llvm::Optional<AtomicRMWKind> reduction =
2851 symbolizeAtomicRMWKind(attrVal.getValue());
2852 if (!reduction)
2853 return parser.emitError(loc, "invalid reduction value: ") << attrVal;
2854 reductions.push_back(builder.getI64IntegerAttr(
2855 static_cast<int64_t>(reduction.getValue())));
2856 // While we keep getting commas, keep parsing.
2857 } while (succeeded(parser.parseOptionalComma()));
2858 if (parser.parseRParen())
2859 return failure();
2860 }
2861 result.addAttribute(AffineParallelOp::getReductionsAttrName(),
2862 builder.getArrayAttr(reductions));
2863
2864 // Parse return types of reductions (if any)
2865 if (parser.parseOptionalArrowTypeList(result.types))
2866 return failure();
2867
2868 // Now parse the body.
2869 Region *body = result.addRegion();
2870 SmallVector<Type, 4> types(ivs.size(), indexType);
2871 if (parser.parseRegion(*body, ivs, types) ||
2872 parser.parseOptionalAttrDict(result.attributes))
2873 return failure();
2874
2875 // Add a terminator if none was parsed.
2876 AffineParallelOp::ensureTerminator(*body, builder, result.location);
2877 return success();
2878 }
2879
2880 //===----------------------------------------------------------------------===//
2881 // AffineYieldOp
2882 //===----------------------------------------------------------------------===//
2883
verify(AffineYieldOp op)2884 static LogicalResult verify(AffineYieldOp op) {
2885 auto *parentOp = op->getParentOp();
2886 auto results = parentOp->getResults();
2887 auto operands = op.getOperands();
2888
2889 if (!isa<AffineParallelOp, AffineIfOp, AffineForOp>(parentOp))
2890 return op.emitOpError() << "only terminates affine.if/for/parallel regions";
2891 if (parentOp->getNumResults() != op.getNumOperands())
2892 return op.emitOpError() << "parent of yield must have same number of "
2893 "results as the yield operands";
2894 for (auto it : llvm::zip(results, operands)) {
2895 if (std::get<0>(it).getType() != std::get<1>(it).getType())
2896 return op.emitOpError()
2897 << "types mismatch between yield op and its parent";
2898 }
2899
2900 return success();
2901 }
2902
2903 //===----------------------------------------------------------------------===//
2904 // AffineVectorLoadOp
2905 //===----------------------------------------------------------------------===//
2906
build(OpBuilder & builder,OperationState & result,VectorType resultType,AffineMap map,ValueRange operands)2907 void AffineVectorLoadOp::build(OpBuilder &builder, OperationState &result,
2908 VectorType resultType, AffineMap map,
2909 ValueRange operands) {
2910 assert(operands.size() == 1 + map.getNumInputs() && "inconsistent operands");
2911 result.addOperands(operands);
2912 if (map)
2913 result.addAttribute(getMapAttrName(), AffineMapAttr::get(map));
2914 result.types.push_back(resultType);
2915 }
2916
build(OpBuilder & builder,OperationState & result,VectorType resultType,Value memref,AffineMap map,ValueRange mapOperands)2917 void AffineVectorLoadOp::build(OpBuilder &builder, OperationState &result,
2918 VectorType resultType, Value memref,
2919 AffineMap map, ValueRange mapOperands) {
2920 assert(map.getNumInputs() == mapOperands.size() && "inconsistent index info");
2921 result.addOperands(memref);
2922 result.addOperands(mapOperands);
2923 result.addAttribute(getMapAttrName(), AffineMapAttr::get(map));
2924 result.types.push_back(resultType);
2925 }
2926
build(OpBuilder & builder,OperationState & result,VectorType resultType,Value memref,ValueRange indices)2927 void AffineVectorLoadOp::build(OpBuilder &builder, OperationState &result,
2928 VectorType resultType, Value memref,
2929 ValueRange indices) {
2930 auto memrefType = memref.getType().cast<MemRefType>();
2931 int64_t rank = memrefType.getRank();
2932 // Create identity map for memrefs with at least one dimension or () -> ()
2933 // for zero-dimensional memrefs.
2934 auto map =
2935 rank ? builder.getMultiDimIdentityMap(rank) : builder.getEmptyAffineMap();
2936 build(builder, result, resultType, memref, map, indices);
2937 }
2938
parseAffineVectorLoadOp(OpAsmParser & parser,OperationState & result)2939 static ParseResult parseAffineVectorLoadOp(OpAsmParser &parser,
2940 OperationState &result) {
2941 auto &builder = parser.getBuilder();
2942 auto indexTy = builder.getIndexType();
2943
2944 MemRefType memrefType;
2945 VectorType resultType;
2946 OpAsmParser::OperandType memrefInfo;
2947 AffineMapAttr mapAttr;
2948 SmallVector<OpAsmParser::OperandType, 1> mapOperands;
2949 return failure(
2950 parser.parseOperand(memrefInfo) ||
2951 parser.parseAffineMapOfSSAIds(mapOperands, mapAttr,
2952 AffineVectorLoadOp::getMapAttrName(),
2953 result.attributes) ||
2954 parser.parseOptionalAttrDict(result.attributes) ||
2955 parser.parseColonType(memrefType) || parser.parseComma() ||
2956 parser.parseType(resultType) ||
2957 parser.resolveOperand(memrefInfo, memrefType, result.operands) ||
2958 parser.resolveOperands(mapOperands, indexTy, result.operands) ||
2959 parser.addTypeToList(resultType, result.types));
2960 }
2961
print(OpAsmPrinter & p,AffineVectorLoadOp op)2962 static void print(OpAsmPrinter &p, AffineVectorLoadOp op) {
2963 p << "affine.vector_load " << op.getMemRef() << '[';
2964 if (AffineMapAttr mapAttr =
2965 op->getAttrOfType<AffineMapAttr>(op.getMapAttrName()))
2966 p.printAffineMapOfSSAIds(mapAttr, op.getMapOperands());
2967 p << ']';
2968 p.printOptionalAttrDict(op.getAttrs(), /*elidedAttrs=*/{op.getMapAttrName()});
2969 p << " : " << op.getMemRefType() << ", " << op.getType();
2970 }
2971
2972 /// Verify common invariants of affine.vector_load and affine.vector_store.
verifyVectorMemoryOp(Operation * op,MemRefType memrefType,VectorType vectorType)2973 static LogicalResult verifyVectorMemoryOp(Operation *op, MemRefType memrefType,
2974 VectorType vectorType) {
2975 // Check that memref and vector element types match.
2976 if (memrefType.getElementType() != vectorType.getElementType())
2977 return op->emitOpError(
2978 "requires memref and vector types of the same elemental type");
2979 return success();
2980 }
2981
verify(AffineVectorLoadOp op)2982 static LogicalResult verify(AffineVectorLoadOp op) {
2983 MemRefType memrefType = op.getMemRefType();
2984 if (failed(verifyMemoryOpIndexing(
2985 op.getOperation(),
2986 op->getAttrOfType<AffineMapAttr>(op.getMapAttrName()),
2987 op.getMapOperands(), memrefType,
2988 /*numIndexOperands=*/op.getNumOperands() - 1)))
2989 return failure();
2990
2991 if (failed(verifyVectorMemoryOp(op.getOperation(), memrefType,
2992 op.getVectorType())))
2993 return failure();
2994
2995 return success();
2996 }
2997
2998 //===----------------------------------------------------------------------===//
2999 // AffineVectorStoreOp
3000 //===----------------------------------------------------------------------===//
3001
build(OpBuilder & builder,OperationState & result,Value valueToStore,Value memref,AffineMap map,ValueRange mapOperands)3002 void AffineVectorStoreOp::build(OpBuilder &builder, OperationState &result,
3003 Value valueToStore, Value memref, AffineMap map,
3004 ValueRange mapOperands) {
3005 assert(map.getNumInputs() == mapOperands.size() && "inconsistent index info");
3006 result.addOperands(valueToStore);
3007 result.addOperands(memref);
3008 result.addOperands(mapOperands);
3009 result.addAttribute(getMapAttrName(), AffineMapAttr::get(map));
3010 }
3011
3012 // Use identity map.
build(OpBuilder & builder,OperationState & result,Value valueToStore,Value memref,ValueRange indices)3013 void AffineVectorStoreOp::build(OpBuilder &builder, OperationState &result,
3014 Value valueToStore, Value memref,
3015 ValueRange indices) {
3016 auto memrefType = memref.getType().cast<MemRefType>();
3017 int64_t rank = memrefType.getRank();
3018 // Create identity map for memrefs with at least one dimension or () -> ()
3019 // for zero-dimensional memrefs.
3020 auto map =
3021 rank ? builder.getMultiDimIdentityMap(rank) : builder.getEmptyAffineMap();
3022 build(builder, result, valueToStore, memref, map, indices);
3023 }
3024
parseAffineVectorStoreOp(OpAsmParser & parser,OperationState & result)3025 static ParseResult parseAffineVectorStoreOp(OpAsmParser &parser,
3026 OperationState &result) {
3027 auto indexTy = parser.getBuilder().getIndexType();
3028
3029 MemRefType memrefType;
3030 VectorType resultType;
3031 OpAsmParser::OperandType storeValueInfo;
3032 OpAsmParser::OperandType memrefInfo;
3033 AffineMapAttr mapAttr;
3034 SmallVector<OpAsmParser::OperandType, 1> mapOperands;
3035 return failure(
3036 parser.parseOperand(storeValueInfo) || parser.parseComma() ||
3037 parser.parseOperand(memrefInfo) ||
3038 parser.parseAffineMapOfSSAIds(mapOperands, mapAttr,
3039 AffineVectorStoreOp::getMapAttrName(),
3040 result.attributes) ||
3041 parser.parseOptionalAttrDict(result.attributes) ||
3042 parser.parseColonType(memrefType) || parser.parseComma() ||
3043 parser.parseType(resultType) ||
3044 parser.resolveOperand(storeValueInfo, resultType, result.operands) ||
3045 parser.resolveOperand(memrefInfo, memrefType, result.operands) ||
3046 parser.resolveOperands(mapOperands, indexTy, result.operands));
3047 }
3048
print(OpAsmPrinter & p,AffineVectorStoreOp op)3049 static void print(OpAsmPrinter &p, AffineVectorStoreOp op) {
3050 p << "affine.vector_store " << op.getValueToStore();
3051 p << ", " << op.getMemRef() << '[';
3052 if (AffineMapAttr mapAttr =
3053 op->getAttrOfType<AffineMapAttr>(op.getMapAttrName()))
3054 p.printAffineMapOfSSAIds(mapAttr, op.getMapOperands());
3055 p << ']';
3056 p.printOptionalAttrDict(op.getAttrs(), /*elidedAttrs=*/{op.getMapAttrName()});
3057 p << " : " << op.getMemRefType() << ", " << op.getValueToStore().getType();
3058 }
3059
verify(AffineVectorStoreOp op)3060 static LogicalResult verify(AffineVectorStoreOp op) {
3061 MemRefType memrefType = op.getMemRefType();
3062 if (failed(verifyMemoryOpIndexing(
3063 op.getOperation(),
3064 op->getAttrOfType<AffineMapAttr>(op.getMapAttrName()),
3065 op.getMapOperands(), memrefType,
3066 /*numIndexOperands=*/op.getNumOperands() - 2)))
3067 return failure();
3068
3069 if (failed(verifyVectorMemoryOp(op.getOperation(), memrefType,
3070 op.getVectorType())))
3071 return failure();
3072
3073 return success();
3074 }
3075
3076 //===----------------------------------------------------------------------===//
3077 // TableGen'd op method definitions
3078 //===----------------------------------------------------------------------===//
3079
3080 #define GET_OP_CLASSES
3081 #include "mlir/Dialect/Affine/IR/AffineOps.cpp.inc"
3082