//===- AffineOps.cpp - MLIR Affine Operations -----------------------------===// // // Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions. // See https://llvm.org/LICENSE.txt for license information. // SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception // //===----------------------------------------------------------------------===// #include "mlir/Dialect/Affine/IR/AffineOps.h" #include "mlir/Dialect/Affine/IR/AffineValueMap.h" #include "mlir/Dialect/StandardOps/IR/Ops.h" #include "mlir/IR/BuiltinOps.h" #include "mlir/IR/IntegerSet.h" #include "mlir/IR/Matchers.h" #include "mlir/IR/OpImplementation.h" #include "mlir/IR/PatternMatch.h" #include "mlir/Transforms/InliningUtils.h" #include "llvm/ADT/SetVector.h" #include "llvm/ADT/SmallBitVector.h" #include "llvm/ADT/TypeSwitch.h" #include "llvm/Support/Debug.h" using namespace mlir; using llvm::dbgs; #define DEBUG_TYPE "affine-analysis" //===----------------------------------------------------------------------===// // AffineDialect Interfaces //===----------------------------------------------------------------------===// namespace { /// This class defines the interface for handling inlining with affine /// operations. struct AffineInlinerInterface : public DialectInlinerInterface { using DialectInlinerInterface::DialectInlinerInterface; //===--------------------------------------------------------------------===// // Analysis Hooks //===--------------------------------------------------------------------===// /// Returns true if the given region 'src' can be inlined into the region /// 'dest' that is attached to an operation registered to the current dialect. bool isLegalToInline(Region *dest, Region *src, bool wouldBeCloned, BlockAndValueMapping &valueMapping) const final { // Conservatively don't allow inlining into affine structures. 
return false; } /// Returns true if the given operation 'op', that is registered to this /// dialect, can be inlined into the given region, false otherwise. bool isLegalToInline(Operation *op, Region *region, bool wouldBeCloned, BlockAndValueMapping &valueMapping) const final { // Always allow inlining affine operations into the top-level region of a // function. There are some edge cases when inlining *into* affine // structures, but that is handled in the other 'isLegalToInline' hook // above. // TODO: We should be able to inline into other regions than functions. return isa(region->getParentOp()); } /// Affine regions should be analyzed recursively. bool shouldAnalyzeRecursively(Operation *op) const final { return true; } }; } // end anonymous namespace //===----------------------------------------------------------------------===// // AffineDialect //===----------------------------------------------------------------------===// void AffineDialect::initialize() { addOperations(); addInterfaces(); } /// Materialize a single constant operation from a given attribute value with /// the desired resultant type. Operation *AffineDialect::materializeConstant(OpBuilder &builder, Attribute value, Type type, Location loc) { return builder.create(loc, type, value); } /// A utility function to check if a value is defined at the top level of an /// op with trait `AffineScope`. If the value is defined in an unlinked region, /// conservatively assume it is not top-level. A value of index type defined at /// the top level is always a valid symbol. bool mlir::isTopLevelValue(Value value) { if (auto arg = value.dyn_cast()) { // The block owning the argument may be unlinked, e.g. when the surrounding // region has not yet been attached to an Op, at which point the parent Op // is null. Operation *parentOp = arg.getOwner()->getParentOp(); return parentOp && parentOp->hasTrait(); } // The defining Op may live in an unlinked block so its parent Op may be null. 
Operation *parentOp = value.getDefiningOp()->getParentOp(); return parentOp && parentOp->hasTrait(); } /// A utility function to check if a value is defined at the top level of /// `region` or is an argument of `region`. A value of index type defined at the /// top level of a `AffineScope` region is always a valid symbol for all /// uses in that region. static bool isTopLevelValue(Value value, Region *region) { if (auto arg = value.dyn_cast()) return arg.getParentRegion() == region; return value.getDefiningOp()->getParentRegion() == region; } /// Returns the closest region enclosing `op` that is held by an operation with /// trait `AffineScope`; `nullptr` if there is no such region. // TODO: getAffineScope should be publicly exposed for affine passes/utilities. static Region *getAffineScope(Operation *op) { auto *curOp = op; while (auto *parentOp = curOp->getParentOp()) { if (parentOp->hasTrait()) return curOp->getParentRegion(); curOp = parentOp; } return nullptr; } // A Value can be used as a dimension id iff it meets one of the following // conditions: // *) It is valid as a symbol. // *) It is an induction variable. // *) It is the result of affine apply operation with dimension id arguments. bool mlir::isValidDim(Value value) { // The value must be an index type. if (!value.getType().isIndex()) return false; if (auto *defOp = value.getDefiningOp()) return isValidDim(value, getAffineScope(defOp)); // This value has to be a block argument for an op that has the // `AffineScope` trait or for an affine.for or affine.parallel. auto *parentOp = value.cast().getOwner()->getParentOp(); return parentOp && (parentOp->hasTrait() || isa(parentOp)); } // Value can be used as a dimension id iff it meets one of the following // conditions: // *) It is valid as a symbol. // *) It is an induction variable. // *) It is the result of an affine apply operation with dimension id operands. bool mlir::isValidDim(Value value, Region *region) { // The value must be an index type. 
if (!value.getType().isIndex()) return false; // All valid symbols are okay. if (isValidSymbol(value, region)) return true; auto *op = value.getDefiningOp(); if (!op) { // This value has to be a block argument for an affine.for or an // affine.parallel. auto *parentOp = value.cast().getOwner()->getParentOp(); return isa(parentOp); } // Affine apply operation is ok if all of its operands are ok. if (auto applyOp = dyn_cast(op)) return applyOp.isValidDim(region); // The dim op is okay if its operand memref/tensor is defined at the top // level. if (auto dimOp = dyn_cast(op)) return isTopLevelValue(dimOp.memrefOrTensor()); return false; } /// Returns true if the 'index' dimension of the `memref` defined by /// `memrefDefOp` is a statically shaped one or defined using a valid symbol /// for `region`. template static bool isMemRefSizeValidSymbol(AnyMemRefDefOp memrefDefOp, unsigned index, Region *region) { auto memRefType = memrefDefOp.getType(); // Statically shaped. if (!memRefType.isDynamicDim(index)) return true; // Get the position of the dimension among dynamic dimensions; unsigned dynamicDimPos = memRefType.getDynamicDimIndex(index); return isValidSymbol(*(memrefDefOp.getDynamicSizes().begin() + dynamicDimPos), region); } /// Returns true if the result of the dim op is a valid symbol for `region`. static bool isDimOpValidSymbol(DimOp dimOp, Region *region) { // The dim op is okay if its operand memref/tensor is defined at the top // level. if (isTopLevelValue(dimOp.memrefOrTensor())) return true; // Conservatively handle remaining BlockArguments as non-valid symbols. // E.g. scf.for iterArgs. if (dimOp.memrefOrTensor().isa()) return false; // The dim op is also okay if its operand memref/tensor is a view/subview // whose corresponding size is a valid symbol. 
Optional index = dimOp.getConstantIndex(); assert(index.hasValue() && "expect only `dim` operations with a constant index"); int64_t i = index.getValue(); return TypeSwitch(dimOp.memrefOrTensor().getDefiningOp()) .Case( [&](auto op) { return isMemRefSizeValidSymbol(op, i, region); }) .Default([](Operation *) { return false; }); } // A value can be used as a symbol (at all its use sites) iff it meets one of // the following conditions: // *) It is a constant. // *) Its defining op or block arg appearance is immediately enclosed by an op // with `AffineScope` trait. // *) It is the result of an affine.apply operation with symbol operands. // *) It is a result of the dim op on a memref whose corresponding size is a // valid symbol. bool mlir::isValidSymbol(Value value) { // The value must be an index type. if (!value.getType().isIndex()) return false; // Check that the value is a top level value. if (isTopLevelValue(value)) return true; if (auto *defOp = value.getDefiningOp()) return isValidSymbol(value, getAffineScope(defOp)); return false; } /// A value can be used as a symbol for `region` iff it meets onf of the the /// following conditions: /// *) It is a constant. /// *) It is the result of an affine apply operation with symbol arguments. /// *) It is a result of the dim op on a memref whose corresponding size is /// a valid symbol. /// *) It is defined at the top level of 'region' or is its argument. /// *) It dominates `region`'s parent op. /// If `region` is null, conservatively assume the symbol definition scope does /// not exist and only accept the values that would be symbols regardless of /// the surrounding region structure, i.e. the first three cases above. bool mlir::isValidSymbol(Value value, Region *region) { // The value must be an index type. if (!value.getType().isIndex()) return false; // A top-level value is a valid symbol. 
if (region && ::isTopLevelValue(value, region)) return true; auto *defOp = value.getDefiningOp(); if (!defOp) { // A block argument that is not a top-level value is a valid symbol if it // dominates region's parent op. if (region && !region->getParentOp()->isKnownIsolatedFromAbove()) if (auto *parentOpRegion = region->getParentOp()->getParentRegion()) return isValidSymbol(value, parentOpRegion); return false; } // Constant operation is ok. Attribute operandCst; if (matchPattern(defOp, m_Constant(&operandCst))) return true; // Affine apply operation is ok if all of its operands are ok. if (auto applyOp = dyn_cast(defOp)) return applyOp.isValidSymbol(region); // Dim op results could be valid symbols at any level. if (auto dimOp = dyn_cast(defOp)) return isDimOpValidSymbol(dimOp, region); // Check for values dominating `region`'s parent op. if (region && !region->getParentOp()->isKnownIsolatedFromAbove()) if (auto *parentRegion = region->getParentOp()->getParentRegion()) return isValidSymbol(value, parentRegion); return false; } // Returns true if 'value' is a valid index to an affine operation (e.g. // affine.load, affine.store, affine.dma_start, affine.dma_wait) where // `region` provides the polyhedral symbol scope. Returns false otherwise. static bool isValidAffineIndexOperand(Value value, Region *region) { return isValidDim(value, region) || isValidSymbol(value, region); } /// Prints dimension and symbol list. static void printDimAndSymbolList(Operation::operand_iterator begin, Operation::operand_iterator end, unsigned numDims, OpAsmPrinter &printer) { OperandRange operands(begin, end); printer << '(' << operands.take_front(numDims) << ')'; if (operands.size() > numDims) printer << '[' << operands.drop_front(numDims) << ']'; } /// Parses dimension and symbol list and returns true if parsing failed. 
ParseResult mlir::parseDimAndSymbolList(OpAsmParser &parser, SmallVectorImpl &operands, unsigned &numDims) { SmallVector opInfos; if (parser.parseOperandList(opInfos, OpAsmParser::Delimiter::Paren)) return failure(); // Store number of dimensions for validation by caller. numDims = opInfos.size(); // Parse the optional symbol operands. auto indexTy = parser.getBuilder().getIndexType(); return failure(parser.parseOperandList( opInfos, OpAsmParser::Delimiter::OptionalSquare) || parser.resolveOperands(opInfos, indexTy, operands)); } /// Utility function to verify that a set of operands are valid dimension and /// symbol identifiers. The operands should be laid out such that the dimension /// operands are before the symbol operands. This function returns failure if /// there was an invalid operand. An operation is provided to emit any necessary /// errors. template static LogicalResult verifyDimAndSymbolIdentifiers(OpTy &op, Operation::operand_range operands, unsigned numDims) { unsigned opIt = 0; for (auto operand : operands) { if (opIt++ < numDims) { if (!isValidDim(operand, getAffineScope(op))) return op.emitOpError("operand cannot be used as a dimension id"); } else if (!isValidSymbol(operand, getAffineScope(op))) { return op.emitOpError("operand cannot be used as a symbol"); } } return success(); } //===----------------------------------------------------------------------===// // AffineApplyOp //===----------------------------------------------------------------------===// AffineValueMap AffineApplyOp::getAffineValueMap() { return AffineValueMap(getAffineMap(), getOperands(), getResult()); } static ParseResult parseAffineApplyOp(OpAsmParser &parser, OperationState &result) { auto &builder = parser.getBuilder(); auto indexTy = builder.getIndexType(); AffineMapAttr mapAttr; unsigned numDims; if (parser.parseAttribute(mapAttr, "map", result.attributes) || parseDimAndSymbolList(parser, result.operands, numDims) || parser.parseOptionalAttrDict(result.attributes)) 
return failure(); auto map = mapAttr.getValue(); if (map.getNumDims() != numDims || numDims + map.getNumSymbols() != result.operands.size()) { return parser.emitError(parser.getNameLoc(), "dimension or symbol index mismatch"); } result.types.append(map.getNumResults(), indexTy); return success(); } static void print(OpAsmPrinter &p, AffineApplyOp op) { p << AffineApplyOp::getOperationName() << " " << op.mapAttr(); printDimAndSymbolList(op.operand_begin(), op.operand_end(), op.getAffineMap().getNumDims(), p); p.printOptionalAttrDict(op.getAttrs(), /*elidedAttrs=*/{"map"}); } static LogicalResult verify(AffineApplyOp op) { // Check input and output dimensions match. auto map = op.map(); // Verify that operand count matches affine map dimension and symbol count. if (op.getNumOperands() != map.getNumDims() + map.getNumSymbols()) return op.emitOpError( "operand count and affine map dimension and symbol count must match"); // Verify that the map only produces one result. if (map.getNumResults() != 1) return op.emitOpError("mapping must produce one value"); return success(); } // The result of the affine apply operation can be used as a dimension id if all // its operands are valid dimension ids. bool AffineApplyOp::isValidDim() { return llvm::all_of(getOperands(), [](Value op) { return mlir::isValidDim(op); }); } // The result of the affine apply operation can be used as a dimension id if all // its operands are valid dimension ids with the parent operation of `region` // defining the polyhedral scope for symbols. bool AffineApplyOp::isValidDim(Region *region) { return llvm::all_of(getOperands(), [&](Value op) { return ::isValidDim(op, region); }); } // The result of the affine apply operation can be used as a symbol if all its // operands are symbols. 
// NOTE(review): this copy of the file appears to have lost its angle-bracketed
// template argument lists (e.g. `dyn_cast()`, `SmallVector`, `DenseMap`); the
// code tokens below are reproduced as found -- confirm against upstream.
bool AffineApplyOp::isValidSymbol() {
  return llvm::all_of(getOperands(),
                      [](Value op) { return mlir::isValidSymbol(op); });
}

// The result of the affine apply operation can be used as a symbol in `region`
// if all its operands are symbols in `region`.
bool AffineApplyOp::isValidSymbol(Region *region) {
  return llvm::all_of(getOperands(), [&](Value operand) {
    return mlir::isValidSymbol(operand, region);
  });
}

/// Folds this op. If the map's single result is a bare dim or symbol, the
/// corresponding operand is returned; otherwise the map is constant-folded
/// over `operands`, and an empty result is returned if that fails.
OpFoldResult AffineApplyOp::fold(ArrayRef operands) {
  auto map = getAffineMap();

  // Fold dims and symbols to existing values.
  auto expr = map.getResult(0);
  if (auto dim = expr.dyn_cast())
    return getOperand(dim.getPosition());
  // Symbol operands follow the dim operands, hence the offset by getNumDims().
  if (auto sym = expr.dyn_cast())
    return getOperand(map.getNumDims() + sym.getPosition());

  // Otherwise, default to folding the map.
  SmallVector result;
  if (failed(map.constantFold(operands, result)))
    return {};
  return result[0];
}

/// Returns the dim expression assigned to `v` in this normalizer. The first
/// time `v` is seen it is given the next free position in
/// `dimValueToPosition` and appended to `reorderedDims`.
AffineDimExpr AffineApplyNormalizer::renumberOneDim(Value v) {
  DenseMap::iterator iterPos;
  bool inserted = false;
  std::tie(iterPos, inserted) =
      dimValueToPosition.insert(std::make_pair(v, dimValueToPosition.size()));
  if (inserted) {
    reorderedDims.push_back(v);
  }
  return getAffineDimExpr(iterPos->second, v.getContext())
      .cast();
}

/// Rewrites `other.affineMap` into this normalizer's numbering and returns it.
/// Dims of `other` are renumbered through renumberOneDim; symbols of `other`
/// are shifted past the symbols already held in `concatenatedSymbols`, and
/// `other.concatenatedSymbols` is appended to this normalizer's list.
AffineMap AffineApplyNormalizer::renumber(const AffineApplyNormalizer &other) {
  SmallVector dimRemapping;
  for (auto v : other.reorderedDims) {
    auto kvp = other.dimValueToPosition.find(v);
    // Grow on demand so the remapping can be indexed by `other`'s position.
    if (dimRemapping.size() <= kvp->second)
      dimRemapping.resize(kvp->second + 1);
    dimRemapping[kvp->second] = renumberOneDim(kvp->first);
  }
  unsigned numSymbols = concatenatedSymbols.size();
  unsigned numOtherSymbols = other.concatenatedSymbols.size();
  SmallVector symRemapping(numOtherSymbols);
  for (unsigned idx = 0; idx < numOtherSymbols; ++idx) {
    symRemapping[idx] =
        getAffineSymbolExpr(idx + numSymbols, other.affineMap.getContext());
  }
  concatenatedSymbols.insert(concatenatedSymbols.end(),
                             other.concatenatedSymbols.begin(),
                             other.concatenatedSymbols.end());
  auto map = other.affineMap;
  return map.replaceDimsAndSymbols(dimRemapping, symRemapping,
                                   reorderedDims.size(),
                                   concatenatedSymbols.size());
}

// Gather the positions of the operands that are produced by an AffineApplyOp.
static llvm::SetVector
indicesFromAffineApplyOp(ArrayRef operands) {
  llvm::SetVector res;
  for (auto en : llvm::enumerate(operands))
    if (isa_and_nonnull(en.value().getDefiningOp()))
      res.insert(en.index());
  return res;
}

// Support the special case of a symbol coming from an AffineApplyOp that needs
// to be composed into the current AffineApplyOp.
// This case is handled by rewriting all such symbols into dims for the purpose
// of allowing mathematical AffineMap composition.
// Returns an AffineMap where symbols that come from an AffineApplyOp have been
// rewritten as dims and are ordered after the original dims.
// TODO: This promotion makes AffineMap lose track of which
// symbols are represented as dims. This loss is static but can still be
// recovered dynamically (with `isValidSymbol`). Still this is annoying for the
// semi-affine map case. A dynamic canonicalization of all dims that are valid
// symbols (a.k.a `canonicalizePromotedSymbols`) into symbols helps and even
// results in better simplifications and foldings. But we should evaluate
// whether this behavior is what we really want after using more.
static AffineMap promoteComposedSymbolsAsDims(AffineMap map,
                                              ArrayRef symbols) {
  if (symbols.empty()) {
    return map;
  }

  // Sanity check on symbols.
  for (auto sym : symbols) {
    assert(isValidSymbol(sym) && "Expected only valid symbols");
    // Silences the unused-variable warning when asserts are compiled out.
    (void)sym;
  }

  // Extract the symbol positions that come from an AffineApplyOp and
  // needs to be rewritten as dims.
  auto symPositions = indicesFromAffineApplyOp(symbols);
  if (symPositions.empty()) {
    return map;
  }

  // Create the new map by replacing each symbol at pos by the next new dim.
  unsigned numDims = map.getNumDims();
  unsigned numSymbols = map.getNumSymbols();
  unsigned numNewDims = 0;
  unsigned numNewSymbols = 0;
  SmallVector symReplacements(numSymbols);
  for (unsigned i = 0; i < numSymbols; ++i) {
    // Promoted symbols become dims numbered after the original dims; the
    // remaining symbols are renumbered densely from zero.
    symReplacements[i] =
        symPositions.count(i) > 0
            ? getAffineDimExpr(numDims + numNewDims++, map.getContext())
            : getAffineSymbolExpr(numNewSymbols++, map.getContext());
  }
  assert(numSymbols >= numNewDims);
  AffineMap newMap = map.replaceDimsAndSymbols(
      {}, symReplacements, numDims + numNewDims, numNewSymbols);

  return newMap;
}

/// The AffineNormalizer composes AffineApplyOp recursively. Its purpose is to
/// keep a correspondence between the mathematical `map` and the `operands` of
/// a given AffineApplyOp. This correspondence is maintained by iterating over
/// the operands and forming an `auxiliaryMap` that can be composed
/// mathematically with `map`. To keep this correspondence in cases where
/// symbols are produced by affine.apply operations, we perform a local rewrite
/// of symbols as dims.
///
/// Rationale for locally rewriting symbols as dims:
/// ================================================
/// The mathematical composition of AffineMap must always concatenate symbols
/// because it does not have enough information to do otherwise. For example,
/// composing `(d0)[s0] -> (d0 + s0)` with itself must produce
/// `(d0)[s0, s1] -> (d0 + s0 + s1)`.
///
/// The result is only equivalent to `(d0)[s0] -> (d0 + 2 * s0)` when
/// applied to the same mlir::Value for both s0 and s1.
/// As a consequence mathematical composition of AffineMap always concatenates
/// symbols.
///
/// When AffineMaps are used in AffineApplyOp however, they may specify
/// composition via symbols, which is ambiguous mathematically. This corner case
/// is handled by locally rewriting such symbols that come from AffineApplyOp
/// into dims and composing through dims.
/// TODO: Composition via symbols comes at a significant code
/// complexity.
/// Alternatively we should investigate whether we want to
/// explicitly disallow symbols coming from affine.apply and instead force the
/// user to compose symbols beforehand. The annoyances may be small (i.e. 1 or 2
/// extra API calls for such uses, which haven't popped up until now) and the
/// benefit potentially big: simpler and more maintainable code for a
/// non-trivial, recursive, procedure.
//
// NOTE(review): this copy of the file appears to have lost its angle-bracketed
// template argument lists (e.g. `ArrayRef operands`, `getDefiningOp()`); the
// code tokens below are reproduced as found -- confirm against upstream.
AffineApplyNormalizer::AffineApplyNormalizer(AffineMap map,
                                             ArrayRef operands)
    : AffineApplyNormalizer() {
  static_assert(kMaxAffineApplyDepth > 0, "kMaxAffineApplyDepth must be > 0");
  assert(map.getNumInputs() == operands.size() &&
         "number of operands does not match the number of map inputs");

  LLVM_DEBUG(map.print(dbgs() << "\nInput map: "));

  // Promote symbols that come from an AffineApplyOp to dims by rewriting the
  // map to always refer to:
  //   (dims, symbols coming from AffineApplyOp, other symbols).
  // The order of operands can remain unchanged.
  // This is a simplification that relies on 2 ordering properties:
  //   1. rewritten symbols always appear after the original dims in the map;
  //   2. operands are traversed in order and either dispatched to:
  //      a. auxiliaryExprs (dims and symbols rewritten as dims);
  //      b. concatenatedSymbols (all other symbols)
  // This allows operand order to remain unchanged.
  // Captured before the rewrite because promotion increases the dim count.
  unsigned numDimsBeforeRewrite = map.getNumDims();
  map = promoteComposedSymbolsAsDims(map,
                                     operands.take_back(map.getNumSymbols()));

  LLVM_DEBUG(map.print(dbgs() << "\nRewritten map: "));

  SmallVector auxiliaryExprs;
  bool furtherCompose = (affineApplyDepth() <= kMaxAffineApplyDepth);
  // We fully spell out the 2 cases below. In this particular instance a little
  // code duplication greatly improves readability.
  // Note that the first branch would disappear if we only supported full
  // composition (i.e. infinite kMaxAffineApplyDepth).
  if (!furtherCompose) {
    // 1. Only dispatch dims or symbols.
    for (auto en : llvm::enumerate(operands)) {
      auto t = en.value();
      assert(t.getType().isIndex());
      // Uses the dim count of the rewritten map, i.e. promoted symbols are
      // treated as dims here.
      bool isDim = (en.index() < map.getNumDims());
      if (isDim) {
        // a. The mathematical composition of AffineMap composes dims.
        auxiliaryExprs.push_back(renumberOneDim(t));
      } else {
        // b. The mathematical composition of AffineMap concatenates symbols.
        //    We do the same for symbol operands.
        concatenatedSymbols.push_back(t);
      }
    }
  } else {
    assert(numDimsBeforeRewrite <= operands.size());
    // 2. Compose AffineApplyOps and dispatch dims or symbols.
    for (unsigned i = 0, e = operands.size(); i < e; ++i) {
      auto t = operands[i];
      auto affineApply = t.getDefiningOp();
      if (affineApply) {
        // a. Compose affine.apply operations.
        LLVM_DEBUG(affineApply->print(
            dbgs() << "\nCompose AffineApplyOp recursively: "));
        AffineMap affineApplyMap = affineApply.getAffineMap();
        SmallVector affineApplyOperands(
            affineApply.getOperands().begin(), affineApply.getOperands().end());
        // Recursive step: normalize the producer, then fold its numbering
        // into this normalizer.
        AffineApplyNormalizer normalizer(affineApplyMap, affineApplyOperands);

        LLVM_DEBUG(normalizer.affineMap.print(
            dbgs() << "\nRenumber into current normalizer: "));

        auto renumberedMap = renumber(normalizer);

        LLVM_DEBUG(
            renumberedMap.print(dbgs() << "\nRecursive composition yields: "));

        auxiliaryExprs.push_back(renumberedMap.getResult(0));
      } else {
        if (i < numDimsBeforeRewrite) {
          // b. The mathematical composition of AffineMap composes dims.
          auxiliaryExprs.push_back(renumberOneDim(t));
        } else {
          // c. The mathematical composition of AffineMap concatenates symbols.
          //    Note that the map composition will put symbols already present
          //    in the map before any symbols coming from the auxiliary map, so
          //    we insert them before any symbols that are due to renumbering,
          //    and after the proper symbols we have seen already.
          concatenatedSymbols.insert(
              std::next(concatenatedSymbols.begin(), numProperSymbols++), t);
        }
      }
    }
  }

  // Early exit if `map` is already composed.
  if (auxiliaryExprs.empty()) {
    affineMap = map;
    return;
  }

  assert(concatenatedSymbols.size() >= map.getNumSymbols() &&
         "Unexpected number of concatenated symbols");
  auto numDims = dimValueToPosition.size();
  // Symbols of the auxiliary map are those collected beyond `map`'s own.
  auto numSymbols = concatenatedSymbols.size() - map.getNumSymbols();
  auto auxiliaryMap =
      AffineMap::get(numDims, numSymbols, auxiliaryExprs, map.getContext());

  LLVM_DEBUG(map.print(dbgs() << "\nCompose map: "));
  LLVM_DEBUG(auxiliaryMap.print(dbgs() << "\nWith map: "));
  LLVM_DEBUG(map.compose(auxiliaryMap).print(dbgs() << "\nResult: "));

  // TODO: Disabling simplification results in major speed gains.
  // Another option is to cache the results as it is expected a lot of redundant
  // work is performed in practice.
  affineMap = simplifyAffineMap(map.compose(auxiliaryMap));

  LLVM_DEBUG(affineMap.print(dbgs() << "\nSimplified result: "));
  LLVM_DEBUG(dbgs() << "\n");
}

/// Normalizes `*otherMap` / `*otherOperands` through a fresh
/// AffineApplyNormalizer and renumbers the result into this normalizer.
/// On return `*otherMap` is the renumbered map and `*otherOperands` holds this
/// normalizer's `reorderedDims` followed by its `concatenatedSymbols`.
void AffineApplyNormalizer::normalize(AffineMap *otherMap,
                                      SmallVectorImpl *otherOperands) {
  AffineApplyNormalizer other(*otherMap, *otherOperands);
  *otherMap = renumber(other);

  otherOperands->reserve(reorderedDims.size() + concatenatedSymbols.size());
  otherOperands->assign(reorderedDims.begin(), reorderedDims.end());
  otherOperands->append(concatenatedSymbols.begin(), concatenatedSymbols.end());
}

/// Implements `map` and `operands` composition and simplification to support
/// `makeComposedAffineApply`. This can be called to achieve the same effects
/// on `map` and `operands` without creating an AffineApplyOp that needs to be
/// immediately deleted.
// NOTE(review): this copy of the file appears to have lost its angle-bracketed
// template argument lists (e.g. `SmallVectorImpl *operands`, bare `template`,
// `results.insert>(context)`); the code tokens below are reproduced as found
// -- confirm against the upstream source.
static void composeAffineMapAndOperands(AffineMap *map,
                                        SmallVectorImpl *operands) {
  AffineApplyNormalizer normalizer(*map, *operands);
  auto normalizedMap = normalizer.getAffineMap();
  auto normalizedOperands = normalizer.getOperands();
  canonicalizeMapAndOperands(&normalizedMap, &normalizedOperands);
  *map = normalizedMap;
  *operands = normalizedOperands;
  assert(*map);
}

/// Repeatedly composes `map` and `operands` for as long as the predicate
/// below matches some operand's defining op.
/// NOTE(review): presumably "defined by an affine.apply" -- the template
/// argument of `isa_and_nonnull` is missing in this copy; confirm.
void mlir::fullyComposeAffineMapAndOperands(AffineMap *map,
                                            SmallVectorImpl *operands) {
  while (llvm::any_of(*operands, [](Value v) {
    return isa_and_nonnull(v.getDefiningOp());
  })) {
    composeAffineMapAndOperands(map, operands);
  }
}

/// Composes `map` and `operands` once via composeAffineMapAndOperands and
/// creates an op from the normalized map and operands using `b` at `loc`.
AffineApplyOp mlir::makeComposedAffineApply(OpBuilder &b, Location loc,
                                            AffineMap map,
                                            ArrayRef operands) {
  AffineMap normalizedMap = map;
  SmallVector normalizedOperands(operands.begin(), operands.end());
  composeAffineMapAndOperands(&normalizedMap, &normalizedOperands);
  assert(normalizedMap);
  return b.create(loc, normalizedMap, normalizedOperands);
}

// A symbol may appear as a dim in affine.apply operations. This function
// canonicalizes dims that are valid symbols into actual symbols.
template
static void canonicalizePromotedSymbols(MapOrSet *mapOrSet,
                                        SmallVectorImpl *operands) {
  if (!mapOrSet || operands->empty())
    return;

  assert(mapOrSet->getNumInputs() == operands->size() &&
         "map/set inputs must match number of operands");

  auto *context = mapOrSet->getContext();
  SmallVector resultOperands;
  resultOperands.reserve(operands->size());
  SmallVector remappedSymbols;
  remappedSymbols.reserve(operands->size());
  unsigned nextDim = 0;
  unsigned nextSym = 0;
  unsigned oldNumSyms = mapOrSet->getNumSymbols();
  SmallVector dimRemapping(mapOrSet->getNumDims());
  for (unsigned i = 0, e = mapOrSet->getNumInputs(); i != e; ++i) {
    if (i < mapOrSet->getNumDims()) {
      if (isValidSymbol((*operands)[i])) {
        // This is a valid symbol that appears as a dim, canonicalize it.
        // New symbols are numbered after the pre-existing ones.
        dimRemapping[i] = getAffineSymbolExpr(oldNumSyms + nextSym++, context);
        remappedSymbols.push_back((*operands)[i]);
      } else {
        dimRemapping[i] = getAffineDimExpr(nextDim++, context);
        resultOperands.push_back((*operands)[i]);
      }
    } else {
      // Existing symbol operands keep their relative order.
      resultOperands.push_back((*operands)[i]);
    }
  }

  // Final operand order: remaining dims, original symbols, promoted symbols.
  resultOperands.append(remappedSymbols.begin(), remappedSymbols.end());
  *operands = resultOperands;
  *mapOrSet = mapOrSet->replaceDimsAndSymbols(dimRemapping, {}, nextDim,
                                              oldNumSyms + nextSym);

  assert(mapOrSet->getNumInputs() == operands->size() &&
         "map/set inputs must match number of operands");
}

// Works for either an affine map or an integer set.
// Canonicalizes `*mapOrSet` and `*operands` in place: dims that are valid
// symbols are first promoted to symbols, then unused dims/symbols are dropped,
// duplicate operands are merged onto a single position, and symbol operands
// matching a constant are replaced by constant expressions.
template
static void canonicalizeMapOrSetAndOperands(MapOrSet *mapOrSet,
                                            SmallVectorImpl *operands) {
  static_assert(llvm::is_one_of::value,
                "Argument must be either of AffineMap or IntegerSet type");

  if (!mapOrSet || operands->empty())
    return;

  assert(mapOrSet->getNumInputs() == operands->size() &&
         "map/set inputs must match number of operands");

  canonicalizePromotedSymbols(mapOrSet, operands);

  // Check to see what dims are used.
  llvm::SmallBitVector usedDims(mapOrSet->getNumDims());
  llvm::SmallBitVector usedSyms(mapOrSet->getNumSymbols());
  mapOrSet->walkExprs([&](AffineExpr expr) {
    if (auto dimExpr = expr.dyn_cast())
      usedDims[dimExpr.getPosition()] = true;
    else if (auto symExpr = expr.dyn_cast())
      usedSyms[symExpr.getPosition()] = true;
  });

  auto *context = mapOrSet->getContext();

  SmallVector resultOperands;
  resultOperands.reserve(operands->size());

  llvm::SmallDenseMap seenDims;
  SmallVector dimRemapping(mapOrSet->getNumDims());
  unsigned nextDim = 0;
  for (unsigned i = 0, e = mapOrSet->getNumDims(); i != e; ++i) {
    // Unused dims keep a default-constructed entry in `dimRemapping` and
    // contribute no operand.
    if (usedDims[i]) {
      // Remap dim positions for duplicate operands.
      auto it = seenDims.find((*operands)[i]);
      if (it == seenDims.end()) {
        dimRemapping[i] = getAffineDimExpr(nextDim++, context);
        resultOperands.push_back((*operands)[i]);
        seenDims.insert(std::make_pair((*operands)[i], dimRemapping[i]));
      } else {
        dimRemapping[i] = it->second;
      }
    }
  }
  llvm::SmallDenseMap seenSymbols;
  SmallVector symRemapping(mapOrSet->getNumSymbols());
  unsigned nextSym = 0;
  for (unsigned i = 0, e = mapOrSet->getNumSymbols(); i != e; ++i) {
    if (!usedSyms[i])
      continue;
    // Handle constant operands (only needed for symbolic operands since
    // constant operands in dimensional positions would have already been
    // promoted to symbolic positions above).
    IntegerAttr operandCst;
    if (matchPattern((*operands)[i + mapOrSet->getNumDims()],
                     m_Constant(&operandCst))) {
      symRemapping[i] =
          getAffineConstantExpr(operandCst.getValue().getSExtValue(), context);
      continue;
    }
    // Remap symbol positions for duplicate operands.
    auto it = seenSymbols.find((*operands)[i + mapOrSet->getNumDims()]);
    if (it == seenSymbols.end()) {
      symRemapping[i] = getAffineSymbolExpr(nextSym++, context);
      resultOperands.push_back((*operands)[i + mapOrSet->getNumDims()]);
      seenSymbols.insert(std::make_pair((*operands)[i + mapOrSet->getNumDims()],
                                        symRemapping[i]));
    } else {
      symRemapping[i] = it->second;
    }
  }
  *mapOrSet = mapOrSet->replaceDimsAndSymbols(dimRemapping, symRemapping,
                                              nextDim, nextSym);
  *operands = resultOperands;
}

/// Public entry point: canonicalizes an affine map and its operands in place.
void mlir::canonicalizeMapAndOperands(AffineMap *map,
                                      SmallVectorImpl *operands) {
  canonicalizeMapOrSetAndOperands(map, operands);
}

/// Public entry point: canonicalizes an integer set and its operands in place.
void mlir::canonicalizeSetAndOperands(IntegerSet *set,
                                      SmallVectorImpl *operands) {
  canonicalizeMapOrSetAndOperands(set, operands);
}

namespace {
/// Simplify AffineApply, AffineLoad, and AffineStore operations by composing
/// maps that supply results into them.
///
template
struct SimplifyAffineOp : public OpRewritePattern {
  using OpRewritePattern::OpRewritePattern;

  /// Replace the affine op with another instance of it with the supplied
  /// map and mapOperands.
  void replaceAffineOp(PatternRewriter &rewriter, AffineOpTy affineOp,
                       AffineMap map, ArrayRef mapOperands) const;

  /// Composes the op's map with its operands and replaces the op when either
  /// the map or the operand list changed; reports failure otherwise.
  LogicalResult matchAndRewrite(AffineOpTy affineOp,
                                PatternRewriter &rewriter) const override {
    static_assert(llvm::is_one_of::value,
                  "affine load/store/apply/prefetch/min/max op expected");
    auto map = affineOp.getAffineMap();
    AffineMap oldMap = map;
    auto oldOperands = affineOp.getMapOperands();
    SmallVector resultOperands(oldOperands);
    composeAffineMapAndOperands(&map, &resultOperands);
    if (map == oldMap && std::equal(oldOperands.begin(), oldOperands.end(),
                                    resultOperands.begin()))
      return failure();

    replaceAffineOp(rewriter, affineOp, map, resultOperands);
    return success();
  }
};

// Specialize the template to account for the different build signatures for
// affine load, store, and apply ops.
// NOTE(review): the specialization target types are missing in this copy; the
// parameter types below (load / prefetch / store) indicate which op each
// specialization is for -- confirm against upstream.
template <>
void SimplifyAffineOp::replaceAffineOp(
    PatternRewriter &rewriter, AffineLoadOp load, AffineMap map,
    ArrayRef mapOperands) const {
  rewriter.replaceOpWithNewOp(load, load.getMemRef(), map,
                                            mapOperands);
}
template <>
void SimplifyAffineOp::replaceAffineOp(
    PatternRewriter &rewriter, AffinePrefetchOp prefetch, AffineMap map,
    ArrayRef mapOperands) const {
  rewriter.replaceOpWithNewOp(
      prefetch, prefetch.memref(), map, mapOperands, prefetch.localityHint(),
      prefetch.isWrite(), prefetch.isDataCache());
}
template <>
void SimplifyAffineOp::replaceAffineOp(
    PatternRewriter &rewriter, AffineStoreOp store, AffineMap map,
    ArrayRef mapOperands) const {
  rewriter.replaceOpWithNewOp(
      store, store.getValueToStore(), store.getMemRef(), map, mapOperands);
}

// Generic version for ops that don't have extra operands.
template
void SimplifyAffineOp::replaceAffineOp(
    PatternRewriter &rewriter, AffineOpTy op, AffineMap map,
    ArrayRef mapOperands) const {
  rewriter.replaceOpWithNewOp(op, map, mapOperands);
}
} // end anonymous namespace.

/// Registers the canonicalization pattern(s) for this op in `results`.
void AffineApplyOp::getCanonicalizationPatterns(
    OwningRewritePatternList &results, MLIRContext *context) {
  results.insert>(context);
}

//===----------------------------------------------------------------------===//
// Common canonicalization pattern support logic
//===----------------------------------------------------------------------===//

/// This is a common class used for patterns of the form
/// "someop(memrefcast) -> someop".  It folds the source of any memref_cast
/// into the root operation directly.
/// Returns success if at least one operand was replaced.
static LogicalResult foldMemRefCast(Operation *op) {
  bool folded = false;
  for (OpOperand &operand : op->getOpOperands()) {
    auto cast = operand.get().getDefiningOp();
    // NOTE(review): the type tested by `isa()` is missing in this copy; the
    // cast is skipped when its source has that type -- confirm which type.
    if (cast && !cast.getOperand().getType().isa()) {
      operand.set(cast.getOperand());
      folded = true;
    }
  }
  return success(folded);
}

//===----------------------------------------------------------------------===//
// AffineDmaStartOp
//===----------------------------------------------------------------------===//

// TODO: Check that map operands are loop IVs or symbols.
void AffineDmaStartOp::build(OpBuilder &builder, OperationState &result, Value srcMemRef, AffineMap srcMap, ValueRange srcIndices, Value destMemRef, AffineMap dstMap, ValueRange destIndices, Value tagMemRef, AffineMap tagMap, ValueRange tagIndices, Value numElements, Value stride, Value elementsPerStride) { result.addOperands(srcMemRef); result.addAttribute(getSrcMapAttrName(), AffineMapAttr::get(srcMap)); result.addOperands(srcIndices); result.addOperands(destMemRef); result.addAttribute(getDstMapAttrName(), AffineMapAttr::get(dstMap)); result.addOperands(destIndices); result.addOperands(tagMemRef); result.addAttribute(getTagMapAttrName(), AffineMapAttr::get(tagMap)); result.addOperands(tagIndices); result.addOperands(numElements); if (stride) { result.addOperands({stride, elementsPerStride}); } } void AffineDmaStartOp::print(OpAsmPrinter &p) { p << "affine.dma_start " << getSrcMemRef() << '['; p.printAffineMapOfSSAIds(getSrcMapAttr(), getSrcIndices()); p << "], " << getDstMemRef() << '['; p.printAffineMapOfSSAIds(getDstMapAttr(), getDstIndices()); p << "], " << getTagMemRef() << '['; p.printAffineMapOfSSAIds(getTagMapAttr(), getTagIndices()); p << "], " << getNumElements(); if (isStrided()) { p << ", " << getStride(); p << ", " << getNumElementsPerStride(); } p << " : " << getSrcMemRefType() << ", " << getDstMemRefType() << ", " << getTagMemRefType(); } // Parse AffineDmaStartOp. 
// Ex: // affine.dma_start %src[%i, %j], %dst[%k, %l], %tag[%index], %size, // %stride, %num_elt_per_stride // : memref<3076 x f32, 0>, memref<1024 x f32, 2>, memref<1 x i32> // ParseResult AffineDmaStartOp::parse(OpAsmParser &parser, OperationState &result) { OpAsmParser::OperandType srcMemRefInfo; AffineMapAttr srcMapAttr; SmallVector srcMapOperands; OpAsmParser::OperandType dstMemRefInfo; AffineMapAttr dstMapAttr; SmallVector dstMapOperands; OpAsmParser::OperandType tagMemRefInfo; AffineMapAttr tagMapAttr; SmallVector tagMapOperands; OpAsmParser::OperandType numElementsInfo; SmallVector strideInfo; SmallVector types; auto indexType = parser.getBuilder().getIndexType(); // Parse and resolve the following list of operands: // *) dst memref followed by its affine maps operands (in square brackets). // *) src memref followed by its affine map operands (in square brackets). // *) tag memref followed by its affine map operands (in square brackets). // *) number of elements transferred by DMA operation. if (parser.parseOperand(srcMemRefInfo) || parser.parseAffineMapOfSSAIds(srcMapOperands, srcMapAttr, getSrcMapAttrName(), result.attributes) || parser.parseComma() || parser.parseOperand(dstMemRefInfo) || parser.parseAffineMapOfSSAIds(dstMapOperands, dstMapAttr, getDstMapAttrName(), result.attributes) || parser.parseComma() || parser.parseOperand(tagMemRefInfo) || parser.parseAffineMapOfSSAIds(tagMapOperands, tagMapAttr, getTagMapAttrName(), result.attributes) || parser.parseComma() || parser.parseOperand(numElementsInfo)) return failure(); // Parse optional stride and elements per stride. 
if (parser.parseTrailingOperandList(strideInfo)) { return failure(); } if (!strideInfo.empty() && strideInfo.size() != 2) { return parser.emitError(parser.getNameLoc(), "expected two stride related operands"); } bool isStrided = strideInfo.size() == 2; if (parser.parseColonTypeList(types)) return failure(); if (types.size() != 3) return parser.emitError(parser.getNameLoc(), "expected three types"); if (parser.resolveOperand(srcMemRefInfo, types[0], result.operands) || parser.resolveOperands(srcMapOperands, indexType, result.operands) || parser.resolveOperand(dstMemRefInfo, types[1], result.operands) || parser.resolveOperands(dstMapOperands, indexType, result.operands) || parser.resolveOperand(tagMemRefInfo, types[2], result.operands) || parser.resolveOperands(tagMapOperands, indexType, result.operands) || parser.resolveOperand(numElementsInfo, indexType, result.operands)) return failure(); if (isStrided) { if (parser.resolveOperands(strideInfo, indexType, result.operands)) return failure(); } // Check that src/dst/tag operand counts match their map.numInputs. if (srcMapOperands.size() != srcMapAttr.getValue().getNumInputs() || dstMapOperands.size() != dstMapAttr.getValue().getNumInputs() || tagMapOperands.size() != tagMapAttr.getValue().getNumInputs()) return parser.emitError(parser.getNameLoc(), "memref operand count not equal to map.numInputs"); return success(); } LogicalResult AffineDmaStartOp::verify() { if (!getOperand(getSrcMemRefOperandIndex()).getType().isa()) return emitOpError("expected DMA source to be of memref type"); if (!getOperand(getDstMemRefOperandIndex()).getType().isa()) return emitOpError("expected DMA destination to be of memref type"); if (!getOperand(getTagMemRefOperandIndex()).getType().isa()) return emitOpError("expected DMA tag to be of memref type"); // DMAs from different memory spaces supported. 
if (getSrcMemorySpace() == getDstMemorySpace()) { return emitOpError("DMA should be between different memory spaces"); } unsigned numInputsAllMaps = getSrcMap().getNumInputs() + getDstMap().getNumInputs() + getTagMap().getNumInputs(); if (getNumOperands() != numInputsAllMaps + 3 + 1 && getNumOperands() != numInputsAllMaps + 3 + 1 + 2) { return emitOpError("incorrect number of operands"); } Region *scope = getAffineScope(*this); for (auto idx : getSrcIndices()) { if (!idx.getType().isIndex()) return emitOpError("src index to dma_start must have 'index' type"); if (!isValidAffineIndexOperand(idx, scope)) return emitOpError("src index must be a dimension or symbol identifier"); } for (auto idx : getDstIndices()) { if (!idx.getType().isIndex()) return emitOpError("dst index to dma_start must have 'index' type"); if (!isValidAffineIndexOperand(idx, scope)) return emitOpError("dst index must be a dimension or symbol identifier"); } for (auto idx : getTagIndices()) { if (!idx.getType().isIndex()) return emitOpError("tag index to dma_start must have 'index' type"); if (!isValidAffineIndexOperand(idx, scope)) return emitOpError("tag index must be a dimension or symbol identifier"); } return success(); } LogicalResult AffineDmaStartOp::fold(ArrayRef cstOperands, SmallVectorImpl &results) { /// dma_start(memrefcast) -> dma_start return foldMemRefCast(*this); } //===----------------------------------------------------------------------===// // AffineDmaWaitOp //===----------------------------------------------------------------------===// // TODO: Check that map operands are loop IVs or symbols. 
void AffineDmaWaitOp::build(OpBuilder &builder, OperationState &result, Value tagMemRef, AffineMap tagMap, ValueRange tagIndices, Value numElements) { result.addOperands(tagMemRef); result.addAttribute(getTagMapAttrName(), AffineMapAttr::get(tagMap)); result.addOperands(tagIndices); result.addOperands(numElements); } void AffineDmaWaitOp::print(OpAsmPrinter &p) { p << "affine.dma_wait " << getTagMemRef() << '['; SmallVector operands(getTagIndices()); p.printAffineMapOfSSAIds(getTagMapAttr(), operands); p << "], "; p.printOperand(getNumElements()); p << " : " << getTagMemRef().getType(); } // Parse AffineDmaWaitOp. // Eg: // affine.dma_wait %tag[%index], %num_elements // : memref<1 x i32, (d0) -> (d0), 4> // ParseResult AffineDmaWaitOp::parse(OpAsmParser &parser, OperationState &result) { OpAsmParser::OperandType tagMemRefInfo; AffineMapAttr tagMapAttr; SmallVector tagMapOperands; Type type; auto indexType = parser.getBuilder().getIndexType(); OpAsmParser::OperandType numElementsInfo; // Parse tag memref, its map operands, and dma size. 
if (parser.parseOperand(tagMemRefInfo) || parser.parseAffineMapOfSSAIds(tagMapOperands, tagMapAttr, getTagMapAttrName(), result.attributes) || parser.parseComma() || parser.parseOperand(numElementsInfo) || parser.parseColonType(type) || parser.resolveOperand(tagMemRefInfo, type, result.operands) || parser.resolveOperands(tagMapOperands, indexType, result.operands) || parser.resolveOperand(numElementsInfo, indexType, result.operands)) return failure(); if (!type.isa()) return parser.emitError(parser.getNameLoc(), "expected tag to be of memref type"); if (tagMapOperands.size() != tagMapAttr.getValue().getNumInputs()) return parser.emitError(parser.getNameLoc(), "tag memref operand count != to map.numInputs"); return success(); } LogicalResult AffineDmaWaitOp::verify() { if (!getOperand(0).getType().isa()) return emitOpError("expected DMA tag to be of memref type"); Region *scope = getAffineScope(*this); for (auto idx : getTagIndices()) { if (!idx.getType().isIndex()) return emitOpError("index to dma_wait must have 'index' type"); if (!isValidAffineIndexOperand(idx, scope)) return emitOpError("index must be a dimension or symbol identifier"); } return success(); } LogicalResult AffineDmaWaitOp::fold(ArrayRef cstOperands, SmallVectorImpl &results) { /// dma_wait(memrefcast) -> dma_wait return foldMemRefCast(*this); } //===----------------------------------------------------------------------===// // AffineForOp //===----------------------------------------------------------------------===// /// 'bodyBuilder' is used to build the body of affine.for. If iterArgs and /// bodyBuilder are empty/null, we include default terminator op. 
void AffineForOp::build(OpBuilder &builder, OperationState &result, ValueRange lbOperands, AffineMap lbMap, ValueRange ubOperands, AffineMap ubMap, int64_t step, ValueRange iterArgs, BodyBuilderFn bodyBuilder) { assert(((!lbMap && lbOperands.empty()) || lbOperands.size() == lbMap.getNumInputs()) && "lower bound operand count does not match the affine map"); assert(((!ubMap && ubOperands.empty()) || ubOperands.size() == ubMap.getNumInputs()) && "upper bound operand count does not match the affine map"); assert(step > 0 && "step has to be a positive integer constant"); for (Value val : iterArgs) result.addTypes(val.getType()); // Add an attribute for the step. result.addAttribute(getStepAttrName(), builder.getIntegerAttr(builder.getIndexType(), step)); // Add the lower bound. result.addAttribute(getLowerBoundAttrName(), AffineMapAttr::get(lbMap)); result.addOperands(lbOperands); // Add the upper bound. result.addAttribute(getUpperBoundAttrName(), AffineMapAttr::get(ubMap)); result.addOperands(ubOperands); result.addOperands(iterArgs); // Create a region and a block for the body. The argument of the region is // the loop induction variable. Region *bodyRegion = result.addRegion(); bodyRegion->push_back(new Block); Block &bodyBlock = bodyRegion->front(); Value inductionVar = bodyBlock.addArgument(builder.getIndexType()); for (Value val : iterArgs) bodyBlock.addArgument(val.getType()); // Create the default terminator if the builder is not provided and if the // iteration arguments are not provided. Otherwise, leave this to the caller // because we don't know which values to return from the loop. 
if (iterArgs.empty() && !bodyBuilder) { ensureTerminator(*bodyRegion, builder, result.location); } else if (bodyBuilder) { OpBuilder::InsertionGuard guard(builder); builder.setInsertionPointToStart(&bodyBlock); bodyBuilder(builder, result.location, inductionVar, bodyBlock.getArguments().drop_front()); } } void AffineForOp::build(OpBuilder &builder, OperationState &result, int64_t lb, int64_t ub, int64_t step, ValueRange iterArgs, BodyBuilderFn bodyBuilder) { auto lbMap = AffineMap::getConstantMap(lb, builder.getContext()); auto ubMap = AffineMap::getConstantMap(ub, builder.getContext()); return build(builder, result, {}, lbMap, {}, ubMap, step, iterArgs, bodyBuilder); } static LogicalResult verify(AffineForOp op) { // Check that the body defines as single block argument for the induction // variable. auto *body = op.getBody(); if (body->getNumArguments() == 0 || !body->getArgument(0).getType().isIndex()) return op.emitOpError( "expected body to have a single index argument for the " "induction variable"); // Verify that the bound operands are valid dimension/symbols. /// Lower bound. if (op.getLowerBoundMap().getNumInputs() > 0) if (failed( verifyDimAndSymbolIdentifiers(op, op.getLowerBoundOperands(), op.getLowerBoundMap().getNumDims()))) return failure(); /// Upper bound. if (op.getUpperBoundMap().getNumInputs() > 0) if (failed( verifyDimAndSymbolIdentifiers(op, op.getUpperBoundOperands(), op.getUpperBoundMap().getNumDims()))) return failure(); unsigned opNumResults = op.getNumResults(); if (opNumResults == 0) return success(); // If ForOp defines values, check that the number and types of the defined // values match ForOp initial iter operands and backedge basic block // arguments. 
if (op.getNumIterOperands() != opNumResults) return op.emitOpError( "mismatch between the number of loop-carried values and results"); if (op.getNumRegionIterArgs() != opNumResults) return op.emitOpError( "mismatch between the number of basic block args and results"); return success(); } /// Parse a for operation loop bounds. static ParseResult parseBound(bool isLower, OperationState &result, OpAsmParser &p) { // 'min' / 'max' prefixes are generally syntactic sugar, but are required if // the map has multiple results. bool failedToParsedMinMax = failed(p.parseOptionalKeyword(isLower ? "max" : "min")); auto &builder = p.getBuilder(); auto boundAttrName = isLower ? AffineForOp::getLowerBoundAttrName() : AffineForOp::getUpperBoundAttrName(); // Parse ssa-id as identity map. SmallVector boundOpInfos; if (p.parseOperandList(boundOpInfos)) return failure(); if (!boundOpInfos.empty()) { // Check that only one operand was parsed. if (boundOpInfos.size() > 1) return p.emitError(p.getNameLoc(), "expected only one loop bound operand"); // TODO: improve error message when SSA value is not of index type. // Currently it is 'use of value ... expects different type than prior uses' if (p.resolveOperand(boundOpInfos.front(), builder.getIndexType(), result.operands)) return failure(); // Create an identity map using symbol id. This representation is optimized // for storage. Analysis passes may expand it into a multi-dimensional map // if desired. AffineMap map = builder.getSymbolIdentityMap(); result.addAttribute(boundAttrName, AffineMapAttr::get(map)); return success(); } // Get the attribute location. llvm::SMLoc attrLoc = p.getCurrentLocation(); Attribute boundAttr; if (p.parseAttribute(boundAttr, builder.getIndexType(), boundAttrName, result.attributes)) return failure(); // Parse full form - affine map followed by dim and symbol list. 
if (auto affineMapAttr = boundAttr.dyn_cast()) { unsigned currentNumOperands = result.operands.size(); unsigned numDims; if (parseDimAndSymbolList(p, result.operands, numDims)) return failure(); auto map = affineMapAttr.getValue(); if (map.getNumDims() != numDims) return p.emitError( p.getNameLoc(), "dim operand count and affine map dim count must match"); unsigned numDimAndSymbolOperands = result.operands.size() - currentNumOperands; if (numDims + map.getNumSymbols() != numDimAndSymbolOperands) return p.emitError( p.getNameLoc(), "symbol operand count and affine map symbol count must match"); // If the map has multiple results, make sure that we parsed the min/max // prefix. if (map.getNumResults() > 1 && failedToParsedMinMax) { if (isLower) { return p.emitError(attrLoc, "lower loop bound affine map with " "multiple results requires 'max' prefix"); } return p.emitError(attrLoc, "upper loop bound affine map with multiple " "results requires 'min' prefix"); } return success(); } // Parse custom assembly form. if (auto integerAttr = boundAttr.dyn_cast()) { result.attributes.pop_back(); result.addAttribute( boundAttrName, AffineMapAttr::get(builder.getConstantAffineMap(integerAttr.getInt()))); return success(); } return p.emitError( p.getNameLoc(), "expected valid affine map representation for loop bounds"); } static ParseResult parseAffineForOp(OpAsmParser &parser, OperationState &result) { auto &builder = parser.getBuilder(); OpAsmParser::OperandType inductionVariable; // Parse the induction variable followed by '='. if (parser.parseRegionArgument(inductionVariable) || parser.parseEqual()) return failure(); // Parse loop bounds. if (parseBound(/*isLower=*/true, result, parser) || parser.parseKeyword("to", " between bounds") || parseBound(/*isLower=*/false, result, parser)) return failure(); // Parse the optional loop step, we default to 1 if one is not present. 
if (parser.parseOptionalKeyword("step")) { result.addAttribute( AffineForOp::getStepAttrName(), builder.getIntegerAttr(builder.getIndexType(), /*value=*/1)); } else { llvm::SMLoc stepLoc = parser.getCurrentLocation(); IntegerAttr stepAttr; if (parser.parseAttribute(stepAttr, builder.getIndexType(), AffineForOp::getStepAttrName().data(), result.attributes)) return failure(); if (stepAttr.getValue().getSExtValue() < 0) return parser.emitError( stepLoc, "expected step to be representable as a positive signed integer"); } // Parse the optional initial iteration arguments. SmallVector regionArgs, operands; SmallVector argTypes; regionArgs.push_back(inductionVariable); if (succeeded(parser.parseOptionalKeyword("iter_args"))) { // Parse assignment list and results type list. if (parser.parseAssignmentList(regionArgs, operands) || parser.parseArrowTypeList(result.types)) return failure(); // Resolve input operands. for (auto operandType : llvm::zip(operands, result.types)) if (parser.resolveOperand(std::get<0>(operandType), std::get<1>(operandType), result.operands)) return failure(); } // Induction variable. Type indexType = builder.getIndexType(); argTypes.push_back(indexType); // Loop carried variables. argTypes.append(result.types.begin(), result.types.end()); // Parse the body region. Region *body = result.addRegion(); if (regionArgs.size() != argTypes.size()) return parser.emitError( parser.getNameLoc(), "mismatch between the number of loop-carried values and results"); if (parser.parseRegion(*body, regionArgs, argTypes)) return failure(); AffineForOp::ensureTerminator(*body, builder, result.location); // Parse the optional attribute list. return parser.parseOptionalAttrDict(result.attributes); } static void printBound(AffineMapAttr boundMap, Operation::operand_range boundOperands, const char *prefix, OpAsmPrinter &p) { AffineMap map = boundMap.getValue(); // Check if this bound should be printed using custom assembly form. 
// The decision to restrict printing custom assembly form to trivial cases // comes from the will to roundtrip MLIR binary -> text -> binary in a // lossless way. // Therefore, custom assembly form parsing and printing is only supported for // zero-operand constant maps and single symbol operand identity maps. if (map.getNumResults() == 1) { AffineExpr expr = map.getResult(0); // Print constant bound. if (map.getNumDims() == 0 && map.getNumSymbols() == 0) { if (auto constExpr = expr.dyn_cast()) { p << constExpr.getValue(); return; } } // Print bound that consists of a single SSA symbol if the map is over a // single symbol. if (map.getNumDims() == 0 && map.getNumSymbols() == 1) { if (auto symExpr = expr.dyn_cast()) { p.printOperand(*boundOperands.begin()); return; } } } else { // Map has multiple results. Print 'min' or 'max' prefix. p << prefix << ' '; } // Print the map and its operands. p << boundMap; printDimAndSymbolList(boundOperands.begin(), boundOperands.end(), map.getNumDims(), p); } unsigned AffineForOp::getNumIterOperands() { AffineMap lbMap = getLowerBoundMapAttr().getValue(); AffineMap ubMap = getUpperBoundMapAttr().getValue(); return getNumOperands() - lbMap.getNumInputs() - ubMap.getNumInputs(); } static void print(OpAsmPrinter &p, AffineForOp op) { p << op.getOperationName() << ' '; p.printOperand(op.getBody()->getArgument(0)); p << " = "; printBound(op.getLowerBoundMapAttr(), op.getLowerBoundOperands(), "max", p); p << " to "; printBound(op.getUpperBoundMapAttr(), op.getUpperBoundOperands(), "min", p); if (op.getStep() != 1) p << " step " << op.getStep(); bool printBlockTerminators = false; if (op.getNumIterOperands() > 0) { p << " iter_args("; auto regionArgs = op.getRegionIterArgs(); auto operands = op.getIterOperands(); llvm::interleaveComma(llvm::zip(regionArgs, operands), p, [&](auto it) { p << std::get<0>(it) << " = " << std::get<1>(it); }); p << ") -> (" << op.getResultTypes() << ")"; printBlockTerminators = true; } 
p.printRegion(op.region(), /*printEntryBlockArgs=*/false, printBlockTerminators); p.printOptionalAttrDict(op.getAttrs(), /*elidedAttrs=*/{op.getLowerBoundAttrName(), op.getUpperBoundAttrName(), op.getStepAttrName()}); } /// Fold the constant bounds of a loop. static LogicalResult foldLoopBounds(AffineForOp forOp) { auto foldLowerOrUpperBound = [&forOp](bool lower) { // Check to see if each of the operands is the result of a constant. If // so, get the value. If not, ignore it. SmallVector operandConstants; auto boundOperands = lower ? forOp.getLowerBoundOperands() : forOp.getUpperBoundOperands(); for (auto operand : boundOperands) { Attribute operandCst; matchPattern(operand, m_Constant(&operandCst)); operandConstants.push_back(operandCst); } AffineMap boundMap = lower ? forOp.getLowerBoundMap() : forOp.getUpperBoundMap(); assert(boundMap.getNumResults() >= 1 && "bound maps should have at least one result"); SmallVector foldedResults; if (failed(boundMap.constantFold(operandConstants, foldedResults))) return failure(); // Compute the max or min as applicable over the results. assert(!foldedResults.empty() && "bounds should have at least one result"); auto maxOrMin = foldedResults[0].cast().getValue(); for (unsigned i = 1, e = foldedResults.size(); i < e; i++) { auto foldedResult = foldedResults[i].cast().getValue(); maxOrMin = lower ? llvm::APIntOps::smax(maxOrMin, foldedResult) : llvm::APIntOps::smin(maxOrMin, foldedResult); } lower ? forOp.setConstantLowerBound(maxOrMin.getSExtValue()) : forOp.setConstantUpperBound(maxOrMin.getSExtValue()); return success(); }; // Try to fold the lower bound. bool folded = false; if (!forOp.hasConstantLowerBound()) folded |= succeeded(foldLowerOrUpperBound(/*lower=*/true)); // Try to fold the upper bound. if (!forOp.hasConstantUpperBound()) folded |= succeeded(foldLowerOrUpperBound(/*lower=*/false)); return success(folded); } /// Canonicalize the bounds of the given loop. 
static LogicalResult canonicalizeLoopBounds(AffineForOp forOp) { SmallVector lbOperands(forOp.getLowerBoundOperands()); SmallVector ubOperands(forOp.getUpperBoundOperands()); auto lbMap = forOp.getLowerBoundMap(); auto ubMap = forOp.getUpperBoundMap(); auto prevLbMap = lbMap; auto prevUbMap = ubMap; canonicalizeMapAndOperands(&lbMap, &lbOperands); lbMap = removeDuplicateExprs(lbMap); canonicalizeMapAndOperands(&ubMap, &ubOperands); ubMap = removeDuplicateExprs(ubMap); // Any canonicalization change always leads to updated map(s). if (lbMap == prevLbMap && ubMap == prevUbMap) return failure(); if (lbMap != prevLbMap) forOp.setLowerBound(lbOperands, lbMap); if (ubMap != prevUbMap) forOp.setUpperBound(ubOperands, ubMap); return success(); } namespace { /// This is a pattern to fold trivially empty loops. struct AffineForEmptyLoopFolder : public OpRewritePattern { using OpRewritePattern::OpRewritePattern; LogicalResult matchAndRewrite(AffineForOp forOp, PatternRewriter &rewriter) const override { // Check that the body only contains a yield. 
if (!llvm::hasSingleElement(*forOp.getBody())) return failure(); rewriter.eraseOp(forOp); return success(); } }; } // end anonymous namespace void AffineForOp::getCanonicalizationPatterns(OwningRewritePatternList &results, MLIRContext *context) { results.insert(context); } LogicalResult AffineForOp::fold(ArrayRef operands, SmallVectorImpl &results) { bool folded = succeeded(foldLoopBounds(*this)); folded |= succeeded(canonicalizeLoopBounds(*this)); return success(folded); } AffineBound AffineForOp::getLowerBound() { auto lbMap = getLowerBoundMap(); return AffineBound(AffineForOp(*this), 0, lbMap.getNumInputs(), lbMap); } AffineBound AffineForOp::getUpperBound() { auto lbMap = getLowerBoundMap(); auto ubMap = getUpperBoundMap(); return AffineBound(AffineForOp(*this), lbMap.getNumInputs(), lbMap.getNumInputs() + ubMap.getNumInputs(), ubMap); } void AffineForOp::setLowerBound(ValueRange lbOperands, AffineMap map) { assert(lbOperands.size() == map.getNumInputs()); assert(map.getNumResults() >= 1 && "bound map has at least one result"); SmallVector newOperands(lbOperands.begin(), lbOperands.end()); auto ubOperands = getUpperBoundOperands(); newOperands.append(ubOperands.begin(), ubOperands.end()); auto iterOperands = getIterOperands(); newOperands.append(iterOperands.begin(), iterOperands.end()); (*this)->setOperands(newOperands); setAttr(getLowerBoundAttrName(), AffineMapAttr::get(map)); } void AffineForOp::setUpperBound(ValueRange ubOperands, AffineMap map) { assert(ubOperands.size() == map.getNumInputs()); assert(map.getNumResults() >= 1 && "bound map has at least one result"); SmallVector newOperands(getLowerBoundOperands()); newOperands.append(ubOperands.begin(), ubOperands.end()); auto iterOperands = getIterOperands(); newOperands.append(iterOperands.begin(), iterOperands.end()); (*this)->setOperands(newOperands); setAttr(getUpperBoundAttrName(), AffineMapAttr::get(map)); } void AffineForOp::setLowerBoundMap(AffineMap map) { auto lbMap = getLowerBoundMap(); 
assert(lbMap.getNumDims() == map.getNumDims() && lbMap.getNumSymbols() == map.getNumSymbols()); assert(map.getNumResults() >= 1 && "bound map has at least one result"); (void)lbMap; setAttr(getLowerBoundAttrName(), AffineMapAttr::get(map)); } void AffineForOp::setUpperBoundMap(AffineMap map) { auto ubMap = getUpperBoundMap(); assert(ubMap.getNumDims() == map.getNumDims() && ubMap.getNumSymbols() == map.getNumSymbols()); assert(map.getNumResults() >= 1 && "bound map has at least one result"); (void)ubMap; setAttr(getUpperBoundAttrName(), AffineMapAttr::get(map)); } bool AffineForOp::hasConstantLowerBound() { return getLowerBoundMap().isSingleConstant(); } bool AffineForOp::hasConstantUpperBound() { return getUpperBoundMap().isSingleConstant(); } int64_t AffineForOp::getConstantLowerBound() { return getLowerBoundMap().getSingleConstantResult(); } int64_t AffineForOp::getConstantUpperBound() { return getUpperBoundMap().getSingleConstantResult(); } void AffineForOp::setConstantLowerBound(int64_t value) { setLowerBound({}, AffineMap::getConstantMap(value, getContext())); } void AffineForOp::setConstantUpperBound(int64_t value) { setUpperBound({}, AffineMap::getConstantMap(value, getContext())); } AffineForOp::operand_range AffineForOp::getLowerBoundOperands() { return {operand_begin(), operand_begin() + getLowerBoundMap().getNumInputs()}; } AffineForOp::operand_range AffineForOp::getUpperBoundOperands() { return {operand_begin() + getLowerBoundMap().getNumInputs(), operand_begin() + getLowerBoundMap().getNumInputs() + getUpperBoundMap().getNumInputs()}; } bool AffineForOp::matchingBoundOperandList() { auto lbMap = getLowerBoundMap(); auto ubMap = getUpperBoundMap(); if (lbMap.getNumDims() != ubMap.getNumDims() || lbMap.getNumSymbols() != ubMap.getNumSymbols()) return false; unsigned numOperands = lbMap.getNumInputs(); for (unsigned i = 0, e = lbMap.getNumInputs(); i < e; i++) { // Compare Value 's. 
if (getOperand(i) != getOperand(numOperands + i)) return false; } return true; } Region &AffineForOp::getLoopBody() { return region(); } bool AffineForOp::isDefinedOutsideOfLoop(Value value) { return !region().isAncestor(value.getParentRegion()); } LogicalResult AffineForOp::moveOutOfLoop(ArrayRef ops) { for (auto *op : ops) op->moveBefore(*this); return success(); } /// Returns true if the provided value is the induction variable of a /// AffineForOp. bool mlir::isForInductionVar(Value val) { return getForInductionVarOwner(val) != AffineForOp(); } /// Returns the loop parent of an induction variable. If the provided value is /// not an induction variable, then return nullptr. AffineForOp mlir::getForInductionVarOwner(Value val) { auto ivArg = val.dyn_cast(); if (!ivArg || !ivArg.getOwner()) return AffineForOp(); auto *containingInst = ivArg.getOwner()->getParent()->getParentOp(); return dyn_cast(containingInst); } /// Extracts the induction variables from a list of AffineForOps and returns /// them. void mlir::extractForInductionVars(ArrayRef forInsts, SmallVectorImpl *ivs) { ivs->reserve(forInsts.size()); for (auto forInst : forInsts) ivs->push_back(forInst.getInductionVar()); } /// Builds an affine loop nest, using "loopCreatorFn" to create individual loop /// operations. template static void buildAffineLoopNestImpl( OpBuilder &builder, Location loc, BoundListTy lbs, BoundListTy ubs, ArrayRef steps, function_ref bodyBuilderFn, LoopCreatorTy &&loopCreatorFn) { assert(lbs.size() == ubs.size() && "Mismatch in number of arguments"); assert(lbs.size() == steps.size() && "Mismatch in number of arguments"); // If there are no loops to be constructed, construct the body anyway. OpBuilder::InsertionGuard guard(builder); if (lbs.empty()) { if (bodyBuilderFn) bodyBuilderFn(builder, loc, ValueRange()); return; } // Create the loops iteratively and store the induction variables. 
SmallVector ivs; ivs.reserve(lbs.size()); for (unsigned i = 0, e = lbs.size(); i < e; ++i) { // Callback for creating the loop body, always creates the terminator. auto loopBody = [&](OpBuilder &nestedBuilder, Location nestedLoc, Value iv, ValueRange iterArgs) { ivs.push_back(iv); // In the innermost loop, call the body builder. if (i == e - 1 && bodyBuilderFn) { OpBuilder::InsertionGuard nestedGuard(nestedBuilder); bodyBuilderFn(nestedBuilder, nestedLoc, ivs); } nestedBuilder.create(nestedLoc); }; // Delegate actual loop creation to the callback in order to dispatch // between constant- and variable-bound loops. auto loop = loopCreatorFn(builder, loc, lbs[i], ubs[i], steps[i], loopBody); builder.setInsertionPointToStart(loop.getBody()); } } /// Creates an affine loop from the bounds known to be constants. static AffineForOp buildAffineLoopFromConstants(OpBuilder &builder, Location loc, int64_t lb, int64_t ub, int64_t step, AffineForOp::BodyBuilderFn bodyBuilderFn) { return builder.create(loc, lb, ub, step, /*iterArgs=*/llvm::None, bodyBuilderFn); } /// Creates an affine loop from the bounds that may or may not be constants. 
static AffineForOp buildAffineLoopFromValues(OpBuilder &builder, Location loc, Value lb, Value ub, int64_t step, AffineForOp::BodyBuilderFn bodyBuilderFn) { auto lbConst = lb.getDefiningOp(); auto ubConst = ub.getDefiningOp(); if (lbConst && ubConst) return buildAffineLoopFromConstants(builder, loc, lbConst.getValue(), ubConst.getValue(), step, bodyBuilderFn); return builder.create(loc, lb, builder.getDimIdentityMap(), ub, builder.getDimIdentityMap(), step, /*iterArgs=*/llvm::None, bodyBuilderFn); } void mlir::buildAffineLoopNest( OpBuilder &builder, Location loc, ArrayRef lbs, ArrayRef ubs, ArrayRef steps, function_ref bodyBuilderFn) { buildAffineLoopNestImpl(builder, loc, lbs, ubs, steps, bodyBuilderFn, buildAffineLoopFromConstants); } void mlir::buildAffineLoopNest( OpBuilder &builder, Location loc, ValueRange lbs, ValueRange ubs, ArrayRef steps, function_ref bodyBuilderFn) { buildAffineLoopNestImpl(builder, loc, lbs, ubs, steps, bodyBuilderFn, buildAffineLoopFromValues); } //===----------------------------------------------------------------------===// // AffineIfOp //===----------------------------------------------------------------------===// namespace { /// Remove else blocks that have nothing other than a zero value yield. struct SimplifyDeadElse : public OpRewritePattern { using OpRewritePattern::OpRewritePattern; LogicalResult matchAndRewrite(AffineIfOp ifOp, PatternRewriter &rewriter) const override { if (ifOp.elseRegion().empty() || !llvm::hasSingleElement(*ifOp.getElseBlock()) || ifOp.getNumResults()) return failure(); rewriter.startRootUpdate(ifOp); rewriter.eraseBlock(ifOp.getElseBlock()); rewriter.finalizeRootUpdate(ifOp); return success(); } }; } // end anonymous namespace. static LogicalResult verify(AffineIfOp op) { // Verify that we have a condition attribute. 
auto conditionAttr = op->getAttrOfType(op.getConditionAttrName()); if (!conditionAttr) return op.emitOpError( "requires an integer set attribute named 'condition'"); // Verify that there are enough operands for the condition. IntegerSet condition = conditionAttr.getValue(); if (op.getNumOperands() != condition.getNumInputs()) return op.emitOpError( "operand count and condition integer set dimension and " "symbol count must match"); // Verify that the operands are valid dimension/symbols. if (failed(verifyDimAndSymbolIdentifiers(op, op.getOperands(), condition.getNumDims()))) return failure(); return success(); } static ParseResult parseAffineIfOp(OpAsmParser &parser, OperationState &result) { // Parse the condition attribute set. IntegerSetAttr conditionAttr; unsigned numDims; if (parser.parseAttribute(conditionAttr, AffineIfOp::getConditionAttrName(), result.attributes) || parseDimAndSymbolList(parser, result.operands, numDims)) return failure(); // Verify the condition operands. auto set = conditionAttr.getValue(); if (set.getNumDims() != numDims) return parser.emitError( parser.getNameLoc(), "dim operand count and integer set dim count must match"); if (numDims + set.getNumSymbols() != result.operands.size()) return parser.emitError( parser.getNameLoc(), "symbol operand count and integer set symbol count must match"); if (parser.parseOptionalArrowTypeList(result.types)) return failure(); // Create the regions for 'then' and 'else'. The latter must be created even // if it remains empty for the validity of the operation. result.regions.reserve(2); Region *thenRegion = result.addRegion(); Region *elseRegion = result.addRegion(); // Parse the 'then' region. if (parser.parseRegion(*thenRegion, {}, {})) return failure(); AffineIfOp::ensureTerminator(*thenRegion, parser.getBuilder(), result.location); // If we find an 'else' keyword then parse the 'else' region. 
if (!parser.parseOptionalKeyword("else")) { if (parser.parseRegion(*elseRegion, {}, {})) return failure(); AffineIfOp::ensureTerminator(*elseRegion, parser.getBuilder(), result.location); } // Parse the optional attribute list. if (parser.parseOptionalAttrDict(result.attributes)) return failure(); return success(); } static void print(OpAsmPrinter &p, AffineIfOp op) { auto conditionAttr = op->getAttrOfType(op.getConditionAttrName()); p << "affine.if " << conditionAttr; printDimAndSymbolList(op.operand_begin(), op.operand_end(), conditionAttr.getValue().getNumDims(), p); p.printOptionalArrowTypeList(op.getResultTypes()); p.printRegion(op.thenRegion(), /*printEntryBlockArgs=*/false, /*printBlockTerminators=*/op.getNumResults()); // Print the 'else' regions if it has any blocks. auto &elseRegion = op.elseRegion(); if (!elseRegion.empty()) { p << " else"; p.printRegion(elseRegion, /*printEntryBlockArgs=*/false, /*printBlockTerminators=*/op.getNumResults()); } // Print the attribute list. p.printOptionalAttrDict(op.getAttrs(), /*elidedAttrs=*/op.getConditionAttrName()); } IntegerSet AffineIfOp::getIntegerSet() { return (*this) ->getAttrOfType(getConditionAttrName()) .getValue(); } void AffineIfOp::setIntegerSet(IntegerSet newSet) { setAttr(getConditionAttrName(), IntegerSetAttr::get(newSet)); } void AffineIfOp::setConditional(IntegerSet set, ValueRange operands) { setIntegerSet(set); (*this)->setOperands(operands); } void AffineIfOp::build(OpBuilder &builder, OperationState &result, TypeRange resultTypes, IntegerSet set, ValueRange args, bool withElseRegion) { assert(resultTypes.empty() || withElseRegion); result.addTypes(resultTypes); result.addOperands(args); result.addAttribute(getConditionAttrName(), IntegerSetAttr::get(set)); Region *thenRegion = result.addRegion(); thenRegion->push_back(new Block()); if (resultTypes.empty()) AffineIfOp::ensureTerminator(*thenRegion, builder, result.location); Region *elseRegion = result.addRegion(); if (withElseRegion) { 
elseRegion->push_back(new Block()); if (resultTypes.empty()) AffineIfOp::ensureTerminator(*elseRegion, builder, result.location); } } void AffineIfOp::build(OpBuilder &builder, OperationState &result, IntegerSet set, ValueRange args, bool withElseRegion) { AffineIfOp::build(builder, result, /*resultTypes=*/{}, set, args, withElseRegion); } /// Canonicalize an affine if op's conditional (integer set + operands). LogicalResult AffineIfOp::fold(ArrayRef, SmallVectorImpl &) { auto set = getIntegerSet(); SmallVector operands(getOperands()); canonicalizeSetAndOperands(&set, &operands); // Any canonicalization change always leads to either a reduction in the // number of operands or a change in the number of symbolic operands // (promotion of dims to symbols). if (operands.size() < getIntegerSet().getNumInputs() || set.getNumSymbols() > getIntegerSet().getNumSymbols()) { setConditional(set, operands); return success(); } return failure(); } void AffineIfOp::getCanonicalizationPatterns(OwningRewritePatternList &results, MLIRContext *context) { results.insert(context); } //===----------------------------------------------------------------------===// // AffineLoadOp //===----------------------------------------------------------------------===// void AffineLoadOp::build(OpBuilder &builder, OperationState &result, AffineMap map, ValueRange operands) { assert(operands.size() == 1 + map.getNumInputs() && "inconsistent operands"); result.addOperands(operands); if (map) result.addAttribute(getMapAttrName(), AffineMapAttr::get(map)); auto memrefType = operands[0].getType().cast(); result.types.push_back(memrefType.getElementType()); } void AffineLoadOp::build(OpBuilder &builder, OperationState &result, Value memref, AffineMap map, ValueRange mapOperands) { assert(map.getNumInputs() == mapOperands.size() && "inconsistent index info"); result.addOperands(memref); result.addOperands(mapOperands); auto memrefType = memref.getType().cast(); result.addAttribute(getMapAttrName(), 
AffineMapAttr::get(map)); result.types.push_back(memrefType.getElementType()); } void AffineLoadOp::build(OpBuilder &builder, OperationState &result, Value memref, ValueRange indices) { auto memrefType = memref.getType().cast(); int64_t rank = memrefType.getRank(); // Create identity map for memrefs with at least one dimension or () -> () // for zero-dimensional memrefs. auto map = rank ? builder.getMultiDimIdentityMap(rank) : builder.getEmptyAffineMap(); build(builder, result, memref, map, indices); } static ParseResult parseAffineLoadOp(OpAsmParser &parser, OperationState &result) { auto &builder = parser.getBuilder(); auto indexTy = builder.getIndexType(); MemRefType type; OpAsmParser::OperandType memrefInfo; AffineMapAttr mapAttr; SmallVector mapOperands; return failure( parser.parseOperand(memrefInfo) || parser.parseAffineMapOfSSAIds(mapOperands, mapAttr, AffineLoadOp::getMapAttrName(), result.attributes) || parser.parseOptionalAttrDict(result.attributes) || parser.parseColonType(type) || parser.resolveOperand(memrefInfo, type, result.operands) || parser.resolveOperands(mapOperands, indexTy, result.operands) || parser.addTypeToList(type.getElementType(), result.types)); } static void print(OpAsmPrinter &p, AffineLoadOp op) { p << "affine.load " << op.getMemRef() << '['; if (AffineMapAttr mapAttr = op->getAttrOfType(op.getMapAttrName())) p.printAffineMapOfSSAIds(mapAttr, op.getMapOperands()); p << ']'; p.printOptionalAttrDict(op.getAttrs(), /*elidedAttrs=*/{op.getMapAttrName()}); p << " : " << op.getMemRefType(); } /// Verify common indexing invariants of affine.load, affine.store, /// affine.vector_load and affine.vector_store. 
static LogicalResult verifyMemoryOpIndexing(Operation *op, AffineMapAttr mapAttr, Operation::operand_range mapOperands, MemRefType memrefType, unsigned numIndexOperands) { if (mapAttr) { AffineMap map = mapAttr.getValue(); if (map.getNumResults() != memrefType.getRank()) return op->emitOpError("affine map num results must equal memref rank"); if (map.getNumInputs() != numIndexOperands) return op->emitOpError("expects as many subscripts as affine map inputs"); } else { if (memrefType.getRank() != numIndexOperands) return op->emitOpError( "expects the number of subscripts to be equal to memref rank"); } Region *scope = getAffineScope(op); for (auto idx : mapOperands) { if (!idx.getType().isIndex()) return op->emitOpError("index to load must have 'index' type"); if (!isValidAffineIndexOperand(idx, scope)) return op->emitOpError("index must be a dimension or symbol identifier"); } return success(); } LogicalResult verify(AffineLoadOp op) { auto memrefType = op.getMemRefType(); if (op.getType() != memrefType.getElementType()) return op.emitOpError("result type must match element type of memref"); if (failed(verifyMemoryOpIndexing( op.getOperation(), op->getAttrOfType(op.getMapAttrName()), op.getMapOperands(), memrefType, /*numIndexOperands=*/op.getNumOperands() - 1))) return failure(); return success(); } void AffineLoadOp::getCanonicalizationPatterns( OwningRewritePatternList &results, MLIRContext *context) { results.insert>(context); } OpFoldResult AffineLoadOp::fold(ArrayRef cstOperands) { /// load(memrefcast) -> load if (succeeded(foldMemRefCast(*this))) return getResult(); return OpFoldResult(); } //===----------------------------------------------------------------------===// // AffineStoreOp //===----------------------------------------------------------------------===// void AffineStoreOp::build(OpBuilder &builder, OperationState &result, Value valueToStore, Value memref, AffineMap map, ValueRange mapOperands) { assert(map.getNumInputs() == mapOperands.size() && 
"inconsistent index info"); result.addOperands(valueToStore); result.addOperands(memref); result.addOperands(mapOperands); result.addAttribute(getMapAttrName(), AffineMapAttr::get(map)); } // Use identity map. void AffineStoreOp::build(OpBuilder &builder, OperationState &result, Value valueToStore, Value memref, ValueRange indices) { auto memrefType = memref.getType().cast(); int64_t rank = memrefType.getRank(); // Create identity map for memrefs with at least one dimension or () -> () // for zero-dimensional memrefs. auto map = rank ? builder.getMultiDimIdentityMap(rank) : builder.getEmptyAffineMap(); build(builder, result, valueToStore, memref, map, indices); } static ParseResult parseAffineStoreOp(OpAsmParser &parser, OperationState &result) { auto indexTy = parser.getBuilder().getIndexType(); MemRefType type; OpAsmParser::OperandType storeValueInfo; OpAsmParser::OperandType memrefInfo; AffineMapAttr mapAttr; SmallVector mapOperands; return failure(parser.parseOperand(storeValueInfo) || parser.parseComma() || parser.parseOperand(memrefInfo) || parser.parseAffineMapOfSSAIds(mapOperands, mapAttr, AffineStoreOp::getMapAttrName(), result.attributes) || parser.parseOptionalAttrDict(result.attributes) || parser.parseColonType(type) || parser.resolveOperand(storeValueInfo, type.getElementType(), result.operands) || parser.resolveOperand(memrefInfo, type, result.operands) || parser.resolveOperands(mapOperands, indexTy, result.operands)); } static void print(OpAsmPrinter &p, AffineStoreOp op) { p << "affine.store " << op.getValueToStore(); p << ", " << op.getMemRef() << '['; if (AffineMapAttr mapAttr = op->getAttrOfType(op.getMapAttrName())) p.printAffineMapOfSSAIds(mapAttr, op.getMapOperands()); p << ']'; p.printOptionalAttrDict(op.getAttrs(), /*elidedAttrs=*/{op.getMapAttrName()}); p << " : " << op.getMemRefType(); } LogicalResult verify(AffineStoreOp op) { // First operand must have same type as memref element type. 
auto memrefType = op.getMemRefType(); if (op.getValueToStore().getType() != memrefType.getElementType()) return op.emitOpError( "first operand must have same type memref element type"); if (failed(verifyMemoryOpIndexing( op.getOperation(), op->getAttrOfType(op.getMapAttrName()), op.getMapOperands(), memrefType, /*numIndexOperands=*/op.getNumOperands() - 2))) return failure(); return success(); } void AffineStoreOp::getCanonicalizationPatterns( OwningRewritePatternList &results, MLIRContext *context) { results.insert>(context); } LogicalResult AffineStoreOp::fold(ArrayRef cstOperands, SmallVectorImpl &results) { /// store(memrefcast) -> store return foldMemRefCast(*this); } //===----------------------------------------------------------------------===// // AffineMinMaxOpBase //===----------------------------------------------------------------------===// template static LogicalResult verifyAffineMinMaxOp(T op) { // Verify that operand count matches affine map dimension and symbol count. if (op.getNumOperands() != op.map().getNumDims() + op.map().getNumSymbols()) return op.emitOpError( "operand count and affine map dimension and symbol count must match"); return success(); } template static void printAffineMinMaxOp(OpAsmPrinter &p, T op) { p << op.getOperationName() << ' ' << op.getAttr(T::getMapAttrName()); auto operands = op.getOperands(); unsigned numDims = op.map().getNumDims(); p << '(' << operands.take_front(numDims) << ')'; if (operands.size() != numDims) p << '[' << operands.drop_front(numDims) << ']'; p.printOptionalAttrDict(op.getAttrs(), /*elidedAttrs=*/{T::getMapAttrName()}); } template static ParseResult parseAffineMinMaxOp(OpAsmParser &parser, OperationState &result) { auto &builder = parser.getBuilder(); auto indexType = builder.getIndexType(); SmallVector dim_infos; SmallVector sym_infos; AffineMapAttr mapAttr; return failure( parser.parseAttribute(mapAttr, T::getMapAttrName(), result.attributes) || parser.parseOperandList(dim_infos, 
OpAsmParser::Delimiter::Paren) || parser.parseOperandList(sym_infos, OpAsmParser::Delimiter::OptionalSquare) || parser.parseOptionalAttrDict(result.attributes) || parser.resolveOperands(dim_infos, indexType, result.operands) || parser.resolveOperands(sym_infos, indexType, result.operands) || parser.addTypeToList(indexType, result.types)); } /// Fold an affine min or max operation with the given operands. The operand /// list may contain nulls, which are interpreted as the operand not being a /// constant. template static OpFoldResult foldMinMaxOp(T op, ArrayRef operands) { static_assert(llvm::is_one_of::value, "expected affine min or max op"); // Fold the affine map. // TODO: Fold more cases: // min(some_affine, some_affine + constant, ...), etc. SmallVector results; auto foldedMap = op.map().partialConstantFold(operands, &results); // If some of the map results are not constant, try changing the map in-place. if (results.empty()) { // If the map is the same, report that folding did not happen. if (foldedMap == op.map()) return {}; op.setAttr("map", AffineMapAttr::get(foldedMap)); return op.getResult(); } // Otherwise, completely fold the op into a constant. auto resultIt = std::is_same::value ? 
std::min_element(results.begin(), results.end()) : std::max_element(results.begin(), results.end()); if (resultIt == results.end()) return {}; return IntegerAttr::get(IndexType::get(op.getContext()), *resultIt); } //===----------------------------------------------------------------------===// // AffineMinOp //===----------------------------------------------------------------------===// // // %0 = affine.min (d0) -> (1000, d0 + 512) (%i0) // OpFoldResult AffineMinOp::fold(ArrayRef operands) { return foldMinMaxOp(*this, operands); } void AffineMinOp::getCanonicalizationPatterns( OwningRewritePatternList &patterns, MLIRContext *context) { patterns.insert>(context); } //===----------------------------------------------------------------------===// // AffineMaxOp //===----------------------------------------------------------------------===// // // %0 = affine.max (d0) -> (1000, d0 + 512) (%i0) // OpFoldResult AffineMaxOp::fold(ArrayRef operands) { return foldMinMaxOp(*this, operands); } void AffineMaxOp::getCanonicalizationPatterns( OwningRewritePatternList &patterns, MLIRContext *context) { patterns.insert>(context); } //===----------------------------------------------------------------------===// // AffinePrefetchOp //===----------------------------------------------------------------------===// // // affine.prefetch %0[%i, %j + 5], read, locality<3>, data : memref<400x400xi32> // static ParseResult parseAffinePrefetchOp(OpAsmParser &parser, OperationState &result) { auto &builder = parser.getBuilder(); auto indexTy = builder.getIndexType(); MemRefType type; OpAsmParser::OperandType memrefInfo; IntegerAttr hintInfo; auto i32Type = parser.getBuilder().getIntegerType(32); StringRef readOrWrite, cacheType; AffineMapAttr mapAttr; SmallVector mapOperands; if (parser.parseOperand(memrefInfo) || parser.parseAffineMapOfSSAIds(mapOperands, mapAttr, AffinePrefetchOp::getMapAttrName(), result.attributes) || parser.parseComma() || parser.parseKeyword(&readOrWrite) || 
parser.parseComma() || parser.parseKeyword("locality") || parser.parseLess() || parser.parseAttribute(hintInfo, i32Type, AffinePrefetchOp::getLocalityHintAttrName(), result.attributes) || parser.parseGreater() || parser.parseComma() || parser.parseKeyword(&cacheType) || parser.parseOptionalAttrDict(result.attributes) || parser.parseColonType(type) || parser.resolveOperand(memrefInfo, type, result.operands) || parser.resolveOperands(mapOperands, indexTy, result.operands)) return failure(); if (!readOrWrite.equals("read") && !readOrWrite.equals("write")) return parser.emitError(parser.getNameLoc(), "rw specifier has to be 'read' or 'write'"); result.addAttribute( AffinePrefetchOp::getIsWriteAttrName(), parser.getBuilder().getBoolAttr(readOrWrite.equals("write"))); if (!cacheType.equals("data") && !cacheType.equals("instr")) return parser.emitError(parser.getNameLoc(), "cache type has to be 'data' or 'instr'"); result.addAttribute( AffinePrefetchOp::getIsDataCacheAttrName(), parser.getBuilder().getBoolAttr(cacheType.equals("data"))); return success(); } static void print(OpAsmPrinter &p, AffinePrefetchOp op) { p << AffinePrefetchOp::getOperationName() << " " << op.memref() << '['; AffineMapAttr mapAttr = op->getAttrOfType(op.getMapAttrName()); if (mapAttr) { SmallVector operands(op.getMapOperands()); p.printAffineMapOfSSAIds(mapAttr, operands); } p << ']' << ", " << (op.isWrite() ? "write" : "read") << ", " << "locality<" << op.localityHint() << ">, " << (op.isDataCache() ? 
"data" : "instr"); p.printOptionalAttrDict( op.getAttrs(), /*elidedAttrs=*/{op.getMapAttrName(), op.getLocalityHintAttrName(), op.getIsDataCacheAttrName(), op.getIsWriteAttrName()}); p << " : " << op.getMemRefType(); } static LogicalResult verify(AffinePrefetchOp op) { auto mapAttr = op->getAttrOfType(op.getMapAttrName()); if (mapAttr) { AffineMap map = mapAttr.getValue(); if (map.getNumResults() != op.getMemRefType().getRank()) return op.emitOpError("affine.prefetch affine map num results must equal" " memref rank"); if (map.getNumInputs() + 1 != op.getNumOperands()) return op.emitOpError("too few operands"); } else { if (op.getNumOperands() != 1) return op.emitOpError("too few operands"); } Region *scope = getAffineScope(op); for (auto idx : op.getMapOperands()) { if (!isValidAffineIndexOperand(idx, scope)) return op.emitOpError("index must be a dimension or symbol identifier"); } return success(); } void AffinePrefetchOp::getCanonicalizationPatterns( OwningRewritePatternList &results, MLIRContext *context) { // prefetch(memrefcast) -> prefetch results.insert>(context); } LogicalResult AffinePrefetchOp::fold(ArrayRef cstOperands, SmallVectorImpl &results) { /// prefetch(memrefcast) -> prefetch return foldMemRefCast(*this); } //===----------------------------------------------------------------------===// // AffineParallelOp //===----------------------------------------------------------------------===// void AffineParallelOp::build(OpBuilder &builder, OperationState &result, TypeRange resultTypes, ArrayRef reductions, ArrayRef ranges) { SmallVector lbExprs(ranges.size(), builder.getAffineConstantExpr(0)); auto lbMap = AffineMap::get(0, 0, lbExprs, builder.getContext()); SmallVector ubExprs; for (int64_t range : ranges) ubExprs.push_back(builder.getAffineConstantExpr(range)); auto ubMap = AffineMap::get(0, 0, ubExprs, builder.getContext()); build(builder, result, resultTypes, reductions, lbMap, /*lbArgs=*/{}, ubMap, /*ubArgs=*/{}); } void 
AffineParallelOp::build(OpBuilder &builder, OperationState &result, TypeRange resultTypes, ArrayRef reductions, AffineMap lbMap, ValueRange lbArgs, AffineMap ubMap, ValueRange ubArgs) { auto numDims = lbMap.getNumResults(); // Verify that the dimensionality of both maps are the same. assert(numDims == ubMap.getNumResults() && "num dims and num results mismatch"); // Make default step sizes of 1. SmallVector steps(numDims, 1); build(builder, result, resultTypes, reductions, lbMap, lbArgs, ubMap, ubArgs, steps); } void AffineParallelOp::build(OpBuilder &builder, OperationState &result, TypeRange resultTypes, ArrayRef reductions, AffineMap lbMap, ValueRange lbArgs, AffineMap ubMap, ValueRange ubArgs, ArrayRef steps) { auto numDims = lbMap.getNumResults(); // Verify that the dimensionality of the maps matches the number of steps. assert(numDims == ubMap.getNumResults() && "num dims and num results mismatch"); assert(numDims == steps.size() && "num dims and num steps mismatch"); result.addTypes(resultTypes); // Convert the reductions to integer attributes. SmallVector reductionAttrs; for (AtomicRMWKind reduction : reductions) reductionAttrs.push_back( builder.getI64IntegerAttr(static_cast(reduction))); result.addAttribute(getReductionsAttrName(), builder.getArrayAttr(reductionAttrs)); result.addAttribute(getLowerBoundsMapAttrName(), AffineMapAttr::get(lbMap)); result.addAttribute(getUpperBoundsMapAttrName(), AffineMapAttr::get(ubMap)); result.addAttribute(getStepsAttrName(), builder.getI64ArrayAttr(steps)); result.addOperands(lbArgs); result.addOperands(ubArgs); // Create a region and a block for the body. auto bodyRegion = result.addRegion(); auto body = new Block(); // Add all the block arguments. 
for (unsigned i = 0; i < numDims; ++i) body->addArgument(IndexType::get(builder.getContext())); bodyRegion->push_back(body); if (resultTypes.empty()) ensureTerminator(*bodyRegion, builder, result.location); } Region &AffineParallelOp::getLoopBody() { return region(); } bool AffineParallelOp::isDefinedOutsideOfLoop(Value value) { return !region().isAncestor(value.getParentRegion()); } LogicalResult AffineParallelOp::moveOutOfLoop(ArrayRef ops) { for (Operation *op : ops) op->moveBefore(*this); return success(); } unsigned AffineParallelOp::getNumDims() { return steps().size(); } AffineParallelOp::operand_range AffineParallelOp::getLowerBoundsOperands() { return getOperands().take_front(lowerBoundsMap().getNumInputs()); } AffineParallelOp::operand_range AffineParallelOp::getUpperBoundsOperands() { return getOperands().drop_front(lowerBoundsMap().getNumInputs()); } AffineValueMap AffineParallelOp::getLowerBoundsValueMap() { return AffineValueMap(lowerBoundsMap(), getLowerBoundsOperands()); } AffineValueMap AffineParallelOp::getUpperBoundsValueMap() { return AffineValueMap(upperBoundsMap(), getUpperBoundsOperands()); } AffineValueMap AffineParallelOp::getRangesValueMap() { AffineValueMap out; AffineValueMap::difference(getUpperBoundsValueMap(), getLowerBoundsValueMap(), &out); return out; } Optional> AffineParallelOp::getConstantRanges() { // Try to convert all the ranges to constant expressions. 
SmallVector out; AffineValueMap rangesValueMap = getRangesValueMap(); out.reserve(rangesValueMap.getNumResults()); for (unsigned i = 0, e = rangesValueMap.getNumResults(); i < e; ++i) { auto expr = rangesValueMap.getResult(i); auto cst = expr.dyn_cast(); if (!cst) return llvm::None; out.push_back(cst.getValue()); } return out; } Block *AffineParallelOp::getBody() { return ®ion().front(); } OpBuilder AffineParallelOp::getBodyBuilder() { return OpBuilder(getBody(), std::prev(getBody()->end())); } void AffineParallelOp::setLowerBounds(ValueRange lbOperands, AffineMap map) { assert(lbOperands.size() == map.getNumInputs() && "operands to map must match number of inputs"); assert(map.getNumResults() >= 1 && "bounds map has at least one result"); auto ubOperands = getUpperBoundsOperands(); SmallVector newOperands(lbOperands); newOperands.append(ubOperands.begin(), ubOperands.end()); (*this)->setOperands(newOperands); lowerBoundsMapAttr(AffineMapAttr::get(map)); } void AffineParallelOp::setUpperBounds(ValueRange ubOperands, AffineMap map) { assert(ubOperands.size() == map.getNumInputs() && "operands to map must match number of inputs"); assert(map.getNumResults() >= 1 && "bounds map has at least one result"); SmallVector newOperands(getLowerBoundsOperands()); newOperands.append(ubOperands.begin(), ubOperands.end()); (*this)->setOperands(newOperands); upperBoundsMapAttr(AffineMapAttr::get(map)); } void AffineParallelOp::setLowerBoundsMap(AffineMap map) { AffineMap lbMap = lowerBoundsMap(); assert(lbMap.getNumDims() == map.getNumDims() && lbMap.getNumSymbols() == map.getNumSymbols()); (void)lbMap; lowerBoundsMapAttr(AffineMapAttr::get(map)); } void AffineParallelOp::setUpperBoundsMap(AffineMap map) { AffineMap ubMap = upperBoundsMap(); assert(ubMap.getNumDims() == map.getNumDims() && ubMap.getNumSymbols() == map.getNumSymbols()); (void)ubMap; upperBoundsMapAttr(AffineMapAttr::get(map)); } SmallVector AffineParallelOp::getSteps() { SmallVector result; for (Attribute attr : 
steps()) { result.push_back(attr.cast().getInt()); } return result; } void AffineParallelOp::setSteps(ArrayRef newSteps) { stepsAttr(getBodyBuilder().getI64ArrayAttr(newSteps)); } static LogicalResult verify(AffineParallelOp op) { auto numDims = op.getNumDims(); if (op.lowerBoundsMap().getNumResults() != numDims || op.upperBoundsMap().getNumResults() != numDims || op.steps().size() != numDims || op.getBody()->getNumArguments() != numDims) return op.emitOpError("region argument count and num results of upper " "bounds, lower bounds, and steps must all match"); if (op.reductions().size() != op.getNumResults()) return op.emitOpError("a reduction must be specified for each output"); // Verify reduction ops are all valid for (Attribute attr : op.reductions()) { auto intAttr = attr.dyn_cast(); if (!intAttr || !symbolizeAtomicRMWKind(intAttr.getInt())) return op.emitOpError("invalid reduction attribute"); } // Verify that the bound operands are valid dimension/symbols. /// Lower bounds. if (failed(verifyDimAndSymbolIdentifiers(op, op.getLowerBoundsOperands(), op.lowerBoundsMap().getNumDims()))) return failure(); /// Upper bounds. if (failed(verifyDimAndSymbolIdentifiers(op, op.getUpperBoundsOperands(), op.upperBoundsMap().getNumDims()))) return failure(); return success(); } LogicalResult AffineValueMap::canonicalize() { SmallVector newOperands{operands}; auto newMap = getAffineMap(); composeAffineMapAndOperands(&newMap, &newOperands); if (newMap == getAffineMap() && newOperands == operands) return failure(); reset(newMap, newOperands); return success(); } /// Canonicalize the bounds of the given loop. static LogicalResult canonicalizeLoopBounds(AffineParallelOp op) { AffineValueMap lb = op.getLowerBoundsValueMap(); bool lbCanonicalized = succeeded(lb.canonicalize()); AffineValueMap ub = op.getUpperBoundsValueMap(); bool ubCanonicalized = succeeded(ub.canonicalize()); // Any canonicalization change always leads to updated map(s). 
if (!lbCanonicalized && !ubCanonicalized) return failure(); if (lbCanonicalized) op.setLowerBounds(lb.getOperands(), lb.getAffineMap()); if (ubCanonicalized) op.setUpperBounds(ub.getOperands(), ub.getAffineMap()); return success(); } LogicalResult AffineParallelOp::fold(ArrayRef operands, SmallVectorImpl &results) { return canonicalizeLoopBounds(*this); } static void print(OpAsmPrinter &p, AffineParallelOp op) { p << op.getOperationName() << " (" << op.getBody()->getArguments() << ") = ("; p.printAffineMapOfSSAIds(op.lowerBoundsMapAttr(), op.getLowerBoundsOperands()); p << ") to ("; p.printAffineMapOfSSAIds(op.upperBoundsMapAttr(), op.getUpperBoundsOperands()); p << ')'; SmallVector steps = op.getSteps(); bool elideSteps = llvm::all_of(steps, [](int64_t step) { return step == 1; }); if (!elideSteps) { p << " step ("; llvm::interleaveComma(steps, p); p << ')'; } if (op.getNumResults()) { p << " reduce ("; llvm::interleaveComma(op.reductions(), p, [&](auto &attr) { AtomicRMWKind sym = *symbolizeAtomicRMWKind(attr.template cast().getInt()); p << "\"" << stringifyAtomicRMWKind(sym) << "\""; }); p << ") -> (" << op.getResultTypes() << ")"; } p.printRegion(op.region(), /*printEntryBlockArgs=*/false, /*printBlockTerminators=*/op.getNumResults()); p.printOptionalAttrDict( op.getAttrs(), /*elidedAttrs=*/{AffineParallelOp::getReductionsAttrName(), AffineParallelOp::getLowerBoundsMapAttrName(), AffineParallelOp::getUpperBoundsMapAttrName(), AffineParallelOp::getStepsAttrName()}); } // // operation ::= `affine.parallel` `(` ssa-ids `)` `=` `(` map-of-ssa-ids `)` // `to` `(` map-of-ssa-ids `)` steps? region attr-dict? 
// steps ::= `steps` `(` integer-literals `)` // static ParseResult parseAffineParallelOp(OpAsmParser &parser, OperationState &result) { auto &builder = parser.getBuilder(); auto indexType = builder.getIndexType(); AffineMapAttr lowerBoundsAttr, upperBoundsAttr; SmallVector ivs; SmallVector lowerBoundsMapOperands; SmallVector upperBoundsMapOperands; if (parser.parseRegionArgumentList(ivs, /*requiredOperandCount=*/-1, OpAsmParser::Delimiter::Paren) || parser.parseEqual() || parser.parseAffineMapOfSSAIds( lowerBoundsMapOperands, lowerBoundsAttr, AffineParallelOp::getLowerBoundsMapAttrName(), result.attributes, OpAsmParser::Delimiter::Paren) || parser.resolveOperands(lowerBoundsMapOperands, indexType, result.operands) || parser.parseKeyword("to") || parser.parseAffineMapOfSSAIds( upperBoundsMapOperands, upperBoundsAttr, AffineParallelOp::getUpperBoundsMapAttrName(), result.attributes, OpAsmParser::Delimiter::Paren) || parser.resolveOperands(upperBoundsMapOperands, indexType, result.operands)) return failure(); AffineMapAttr stepsMapAttr; NamedAttrList stepsAttrs; SmallVector stepsMapOperands; if (failed(parser.parseOptionalKeyword("step"))) { SmallVector steps(ivs.size(), 1); result.addAttribute(AffineParallelOp::getStepsAttrName(), builder.getI64ArrayAttr(steps)); } else { if (parser.parseAffineMapOfSSAIds(stepsMapOperands, stepsMapAttr, AffineParallelOp::getStepsAttrName(), stepsAttrs, OpAsmParser::Delimiter::Paren)) return failure(); // Convert steps from an AffineMap into an I64ArrayAttr. 
SmallVector steps; auto stepsMap = stepsMapAttr.getValue(); for (const auto &result : stepsMap.getResults()) { auto constExpr = result.dyn_cast(); if (!constExpr) return parser.emitError(parser.getNameLoc(), "steps must be constant integers"); steps.push_back(constExpr.getValue()); } result.addAttribute(AffineParallelOp::getStepsAttrName(), builder.getI64ArrayAttr(steps)); } // Parse optional clause of the form: `reduce ("addf", "maxf")`, where the // quoted strings are a member of the enum AtomicRMWKind. SmallVector reductions; if (succeeded(parser.parseOptionalKeyword("reduce"))) { if (parser.parseLParen()) return failure(); do { // Parse a single quoted string via the attribute parsing, and then // verify it is a member of the enum and convert to it's integer // representation. StringAttr attrVal; NamedAttrList attrStorage; auto loc = parser.getCurrentLocation(); if (parser.parseAttribute(attrVal, builder.getNoneType(), "reduce", attrStorage)) return failure(); llvm::Optional reduction = symbolizeAtomicRMWKind(attrVal.getValue()); if (!reduction) return parser.emitError(loc, "invalid reduction value: ") << attrVal; reductions.push_back(builder.getI64IntegerAttr( static_cast(reduction.getValue()))); // While we keep getting commas, keep parsing. } while (succeeded(parser.parseOptionalComma())); if (parser.parseRParen()) return failure(); } result.addAttribute(AffineParallelOp::getReductionsAttrName(), builder.getArrayAttr(reductions)); // Parse return types of reductions (if any) if (parser.parseOptionalArrowTypeList(result.types)) return failure(); // Now parse the body. Region *body = result.addRegion(); SmallVector types(ivs.size(), indexType); if (parser.parseRegion(*body, ivs, types) || parser.parseOptionalAttrDict(result.attributes)) return failure(); // Add a terminator if none was parsed. 
AffineParallelOp::ensureTerminator(*body, builder, result.location); return success(); } //===----------------------------------------------------------------------===// // AffineYieldOp //===----------------------------------------------------------------------===// static LogicalResult verify(AffineYieldOp op) { auto *parentOp = op->getParentOp(); auto results = parentOp->getResults(); auto operands = op.getOperands(); if (!isa(parentOp)) return op.emitOpError() << "only terminates affine.if/for/parallel regions"; if (parentOp->getNumResults() != op.getNumOperands()) return op.emitOpError() << "parent of yield must have same number of " "results as the yield operands"; for (auto it : llvm::zip(results, operands)) { if (std::get<0>(it).getType() != std::get<1>(it).getType()) return op.emitOpError() << "types mismatch between yield op and its parent"; } return success(); } //===----------------------------------------------------------------------===// // AffineVectorLoadOp //===----------------------------------------------------------------------===// void AffineVectorLoadOp::build(OpBuilder &builder, OperationState &result, VectorType resultType, AffineMap map, ValueRange operands) { assert(operands.size() == 1 + map.getNumInputs() && "inconsistent operands"); result.addOperands(operands); if (map) result.addAttribute(getMapAttrName(), AffineMapAttr::get(map)); result.types.push_back(resultType); } void AffineVectorLoadOp::build(OpBuilder &builder, OperationState &result, VectorType resultType, Value memref, AffineMap map, ValueRange mapOperands) { assert(map.getNumInputs() == mapOperands.size() && "inconsistent index info"); result.addOperands(memref); result.addOperands(mapOperands); result.addAttribute(getMapAttrName(), AffineMapAttr::get(map)); result.types.push_back(resultType); } void AffineVectorLoadOp::build(OpBuilder &builder, OperationState &result, VectorType resultType, Value memref, ValueRange indices) { auto memrefType = memref.getType().cast(); 
int64_t rank = memrefType.getRank(); // Create identity map for memrefs with at least one dimension or () -> () // for zero-dimensional memrefs. auto map = rank ? builder.getMultiDimIdentityMap(rank) : builder.getEmptyAffineMap(); build(builder, result, resultType, memref, map, indices); } static ParseResult parseAffineVectorLoadOp(OpAsmParser &parser, OperationState &result) { auto &builder = parser.getBuilder(); auto indexTy = builder.getIndexType(); MemRefType memrefType; VectorType resultType; OpAsmParser::OperandType memrefInfo; AffineMapAttr mapAttr; SmallVector mapOperands; return failure( parser.parseOperand(memrefInfo) || parser.parseAffineMapOfSSAIds(mapOperands, mapAttr, AffineVectorLoadOp::getMapAttrName(), result.attributes) || parser.parseOptionalAttrDict(result.attributes) || parser.parseColonType(memrefType) || parser.parseComma() || parser.parseType(resultType) || parser.resolveOperand(memrefInfo, memrefType, result.operands) || parser.resolveOperands(mapOperands, indexTy, result.operands) || parser.addTypeToList(resultType, result.types)); } static void print(OpAsmPrinter &p, AffineVectorLoadOp op) { p << "affine.vector_load " << op.getMemRef() << '['; if (AffineMapAttr mapAttr = op->getAttrOfType(op.getMapAttrName())) p.printAffineMapOfSSAIds(mapAttr, op.getMapOperands()); p << ']'; p.printOptionalAttrDict(op.getAttrs(), /*elidedAttrs=*/{op.getMapAttrName()}); p << " : " << op.getMemRefType() << ", " << op.getType(); } /// Verify common invariants of affine.vector_load and affine.vector_store. static LogicalResult verifyVectorMemoryOp(Operation *op, MemRefType memrefType, VectorType vectorType) { // Check that memref and vector element types match. 
if (memrefType.getElementType() != vectorType.getElementType()) return op->emitOpError( "requires memref and vector types of the same elemental type"); return success(); } static LogicalResult verify(AffineVectorLoadOp op) { MemRefType memrefType = op.getMemRefType(); if (failed(verifyMemoryOpIndexing( op.getOperation(), op->getAttrOfType(op.getMapAttrName()), op.getMapOperands(), memrefType, /*numIndexOperands=*/op.getNumOperands() - 1))) return failure(); if (failed(verifyVectorMemoryOp(op.getOperation(), memrefType, op.getVectorType()))) return failure(); return success(); } //===----------------------------------------------------------------------===// // AffineVectorStoreOp //===----------------------------------------------------------------------===// void AffineVectorStoreOp::build(OpBuilder &builder, OperationState &result, Value valueToStore, Value memref, AffineMap map, ValueRange mapOperands) { assert(map.getNumInputs() == mapOperands.size() && "inconsistent index info"); result.addOperands(valueToStore); result.addOperands(memref); result.addOperands(mapOperands); result.addAttribute(getMapAttrName(), AffineMapAttr::get(map)); } // Use identity map. void AffineVectorStoreOp::build(OpBuilder &builder, OperationState &result, Value valueToStore, Value memref, ValueRange indices) { auto memrefType = memref.getType().cast(); int64_t rank = memrefType.getRank(); // Create identity map for memrefs with at least one dimension or () -> () // for zero-dimensional memrefs. auto map = rank ? 
builder.getMultiDimIdentityMap(rank) : builder.getEmptyAffineMap(); build(builder, result, valueToStore, memref, map, indices); } static ParseResult parseAffineVectorStoreOp(OpAsmParser &parser, OperationState &result) { auto indexTy = parser.getBuilder().getIndexType(); MemRefType memrefType; VectorType resultType; OpAsmParser::OperandType storeValueInfo; OpAsmParser::OperandType memrefInfo; AffineMapAttr mapAttr; SmallVector mapOperands; return failure( parser.parseOperand(storeValueInfo) || parser.parseComma() || parser.parseOperand(memrefInfo) || parser.parseAffineMapOfSSAIds(mapOperands, mapAttr, AffineVectorStoreOp::getMapAttrName(), result.attributes) || parser.parseOptionalAttrDict(result.attributes) || parser.parseColonType(memrefType) || parser.parseComma() || parser.parseType(resultType) || parser.resolveOperand(storeValueInfo, resultType, result.operands) || parser.resolveOperand(memrefInfo, memrefType, result.operands) || parser.resolveOperands(mapOperands, indexTy, result.operands)); } static void print(OpAsmPrinter &p, AffineVectorStoreOp op) { p << "affine.vector_store " << op.getValueToStore(); p << ", " << op.getMemRef() << '['; if (AffineMapAttr mapAttr = op->getAttrOfType(op.getMapAttrName())) p.printAffineMapOfSSAIds(mapAttr, op.getMapOperands()); p << ']'; p.printOptionalAttrDict(op.getAttrs(), /*elidedAttrs=*/{op.getMapAttrName()}); p << " : " << op.getMemRefType() << ", " << op.getValueToStore().getType(); } static LogicalResult verify(AffineVectorStoreOp op) { MemRefType memrefType = op.getMemRefType(); if (failed(verifyMemoryOpIndexing( op.getOperation(), op->getAttrOfType(op.getMapAttrName()), op.getMapOperands(), memrefType, /*numIndexOperands=*/op.getNumOperands() - 2))) return failure(); if (failed(verifyVectorMemoryOp(op.getOperation(), memrefType, op.getVectorType()))) return failure(); return success(); } //===----------------------------------------------------------------------===// // TableGen'd op method definitions 
//===----------------------------------------------------------------------===// #define GET_OP_CLASSES #include "mlir/Dialect/Affine/IR/AffineOps.cpp.inc"