//===- VectorOps.cpp - MLIR Vector Dialect Operations ---------------------===//
//
// Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions.
// See https://llvm.org/LICENSE.txt for license information.
// SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception
//
//===----------------------------------------------------------------------===//
//
// This file implements convenience types for working with super-vectorization
// operations, in particular super-vector loads and stores.
//
//===----------------------------------------------------------------------===//

#include "mlir/Dialect/Vector/VectorOps.h"
#include "mlir/Dialect/StandardOps/IR/Ops.h"
#include "mlir/Dialect/Utils/StructuredOpsUtils.h"
#include "mlir/Dialect/Vector/VectorUtils.h"
#include "mlir/IR/AffineExpr.h"
#include "mlir/IR/AffineMap.h"
#include "mlir/IR/Builders.h"
#include "mlir/IR/BuiltinOps.h"
#include "mlir/IR/OpImplementation.h"
#include "mlir/IR/PatternMatch.h"
#include "mlir/IR/TypeUtilities.h"
#include "mlir/Support/LLVM.h"
#include "mlir/Support/MathExtras.h"
#include "llvm/ADT/StringSet.h"
#include <numeric>

using namespace mlir;
using namespace mlir::vector;

/// Helper enum to classify mask value.
enum class MaskFormat {
  AllTrue = 0,
  AllFalse = 1,
  Unknown = 2,
};

/// Helper method to classify a 1-D mask value. Currently, the method
/// looks "under the hood" of a constant value with dense attributes
/// and a constant mask operation (since the client may be called at
/// various stages during progressive lowering).
static MaskFormat get1DMaskFormat(Value mask) {
  if (auto c = mask.getDefiningOp<ConstantOp>()) {
    // Inspect constant dense values. We count up for bits that
    // are set, count down for bits that are cleared, and bail
    // when a mix is detected.
    if (auto denseElts = c.value().dyn_cast<DenseIntElementsAttr>()) {
      int64_t val = 0;
      for (bool b : denseElts.getValues<bool>())
        if (b && val >= 0)
          val++;
        else if (!b && val <= 0)
          val--;
        else
          return MaskFormat::Unknown;
      if (val > 0)
        return MaskFormat::AllTrue;
      if (val < 0)
        return MaskFormat::AllFalse;
    }
  } else if (auto m = mask.getDefiningOp<ConstantMaskOp>()) {
    // Inspect constant mask index. If the index exceeds the
    // dimension size, all bits are set. If the index is zero
    // or less, no bits are set.
    ArrayAttr masks = m.mask_dim_sizes();
    assert(masks.size() == 1);
    int64_t i = masks[0].cast<IntegerAttr>().getInt();
    int64_t u = m.getType().cast<VectorType>().getDimSize(0);
    if (i >= u)
      return MaskFormat::AllTrue;
    if (i <= 0)
      return MaskFormat::AllFalse;
  }
  return MaskFormat::Unknown;
}

/// Helper method to cast a 1-D memref<10xf32> "base" into a
/// memref<vector<10xf32>> in the output parameter "newBase",
/// using the 'element' vector type "vt". Returns true on success.
static bool castedToMemRef(Location loc, Value base, MemRefType mt,
                           VectorType vt, PatternRewriter &rewriter,
                           Value &newBase) {
  // The vector.type_cast operation does not accept unknown memref<?xf32>.
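  // For example (illustrative only), a statically shaped memref<10xf32> base
  // can be rewritten as
  //   %vbase = vector.type_cast %base
  //       : memref<10xf32> to memref<vector<10xf32>>
  // whereas a dynamically shaped memref<?xf32> cannot be cast this way.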
  // TODO: generalize the cast and accept this case too
  if (!mt.hasStaticShape())
    return false;
  newBase = rewriter.create<TypeCastOp>(loc, MemRefType::get({}, vt), base);
  return true;
}

//===----------------------------------------------------------------------===//
// VectorDialect
//===----------------------------------------------------------------------===//

void VectorDialect::initialize() {
  addOperations<
#define GET_OP_LIST
#include "mlir/Dialect/Vector/VectorOps.cpp.inc"
      >();
}

/// Materialize a single constant operation from a given attribute value with
/// the desired resultant type.
Operation *VectorDialect::materializeConstant(OpBuilder &builder,
                                              Attribute value, Type type,
                                              Location loc) {
  return builder.create<ConstantOp>(loc, type, value);
}

IntegerType vector::getVectorSubscriptType(Builder &builder) {
  return builder.getIntegerType(64);
}

ArrayAttr vector::getVectorSubscriptAttr(Builder &builder,
                                         ArrayRef<int64_t> values) {
  return builder.getI64ArrayAttr(values);
}

//===----------------------------------------------------------------------===//
// ReductionOp
//===----------------------------------------------------------------------===//

static LogicalResult verify(ReductionOp op) {
  // Verify for 1-D vector.
  int64_t rank = op.getVectorType().getRank();
  if (rank != 1)
    return op.emitOpError("unsupported reduction rank: ") << rank;

  // Verify supported reduction kind.
  auto kind = op.kind();
  Type eltType = op.dest().getType();
  if (kind == "add" || kind == "mul" || kind == "min" || kind == "max") {
    if (!eltType.isIntOrIndexOrFloat())
      return op.emitOpError("unsupported reduction type");
  } else if (kind == "and" || kind == "or" || kind == "xor") {
    if (!eltType.isIntOrIndex())
      return op.emitOpError("unsupported reduction type");
  } else {
    return op.emitOpError("unknown reduction kind: ") << kind;
  }

  // Verify optional accumulator.
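  // When present, the accumulator is only allowed for floating-point "add"
  // and "mul", e.g. (illustrative):
  //   %0 = vector.reduction "add", %v, %acc : vector<16xf32> into f32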
  if (!op.acc().empty()) {
    if (kind != "add" && kind != "mul")
      return op.emitOpError("no accumulator for reduction kind: ") << kind;
    if (!eltType.isa<FloatType>())
      return op.emitOpError("no accumulator for type: ") << eltType;
  }

  return success();
}

static ParseResult parseReductionOp(OpAsmParser &parser,
                                    OperationState &result) {
  SmallVector<OpAsmParser::OperandType, 2> operandsInfo;
  Type redType;
  Type resType;
  Attribute attr;
  if (parser.parseAttribute(attr, "kind", result.attributes) ||
      parser.parseComma() || parser.parseOperandList(operandsInfo) ||
      parser.parseColonType(redType) ||
      parser.parseKeywordType("into", resType) ||
      (operandsInfo.size() > 0 &&
       parser.resolveOperand(operandsInfo[0], redType, result.operands)) ||
      (operandsInfo.size() > 1 &&
       parser.resolveOperand(operandsInfo[1], resType, result.operands)) ||
      parser.addTypeToList(resType, result.types))
    return failure();
  if (operandsInfo.size() < 1 || operandsInfo.size() > 2)
    return parser.emitError(parser.getNameLoc(),
                            "unsupported number of operands");
  return success();
}

static void print(OpAsmPrinter &p, ReductionOp op) {
  p << op.getOperationName() << " \"" << op.kind() << "\", " << op.vector();
  if (!op.acc().empty())
    p << ", " << op.acc();
  p << " : " << op.vector().getType() << " into " << op.dest().getType();
}

//===----------------------------------------------------------------------===//
// ContractionOp
//===----------------------------------------------------------------------===//

void vector::ContractionOp::build(OpBuilder &builder, OperationState &result,
                                  Value lhs, Value rhs, Value acc,
                                  ArrayRef<ArrayRef<AffineExpr>> indexingExprs,
                                  ArrayRef<StringRef> iteratorTypes) {
  result.addOperands({lhs, rhs, acc});
  result.addTypes(acc.getType());
  result.addAttribute(getIndexingMapsAttrName(),
                      builder.getAffineMapArrayAttr(
                          AffineMap::inferFromExprList(indexingExprs)));
  result.addAttribute(getIteratorTypesAttrName(),
                      builder.getStrArrayAttr(iteratorTypes));
}

void vector::ContractionOp::build(OpBuilder &builder, OperationState &result,
                                  Value lhs, Value rhs, Value acc,
                                  ArrayAttr indexingMaps,
                                  ArrayAttr iteratorTypes) {
  result.addOperands({lhs, rhs, acc});
  result.addTypes(acc.getType());
  result.addAttribute(getIndexingMapsAttrName(), indexingMaps);
  result.addAttribute(getIteratorTypesAttrName(), iteratorTypes);
}

static ParseResult parseContractionOp(OpAsmParser &parser,
                                      OperationState &result) {
  OpAsmParser::OperandType lhsInfo;
  OpAsmParser::OperandType rhsInfo;
  OpAsmParser::OperandType accInfo;
  SmallVector<OpAsmParser::OperandType, 2> masksInfo;
  SmallVector<Type, 2> types;
  Type resultType;
  auto loc = parser.getCurrentLocation();
  DictionaryAttr dictAttr;
  // TODO: Unify linalg op attribute parsing.
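  // The custom assembly parsed here looks roughly like (illustrative):
  //   vector.contract {indexing_maps = [...], iterator_types = [...]}
  //     %lhs, %rhs, %acc
  //     : vector<MxKxf32>, vector<KxNxf32> into vector<MxNxf32>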
  if (parser.parseAttribute(dictAttr, "_", result.attributes) ||
      parser.parseOperand(lhsInfo) || parser.parseComma() ||
      parser.parseOperand(rhsInfo) || parser.parseComma() ||
      parser.parseOperand(accInfo) ||
      parser.parseTrailingOperandList(masksInfo) ||
      parser.parseOptionalAttrDict(result.attributes) ||
      parser.parseColonTypeList(types) ||
      parser.parseKeywordType("into", resultType) ||
      parser.resolveOperand(lhsInfo, types[0], result.operands) ||
      parser.resolveOperand(rhsInfo, types[1], result.operands) ||
      parser.resolveOperand(accInfo, resultType, result.operands) ||
      parser.addTypeToList(resultType, result.types))
    return failure();
  result.attributes.assign(dictAttr.getValue().begin(),
                           dictAttr.getValue().end());
  if (masksInfo.empty())
    return success();
  if (masksInfo.size() != 2)
    return parser.emitError(parser.getNameLoc(),
                            "expected zero or exactly 2 vector mask operands");
  auto lhsType = types[0].cast<VectorType>();
  auto rhsType = types[1].cast<VectorType>();
  auto maskElementType = parser.getBuilder().getI1Type();
  std::array<Type, 2> maskTypes = {
      VectorType::get(lhsType.getShape(), maskElementType),
      VectorType::get(rhsType.getShape(), maskElementType)};
  if (parser.resolveOperands(masksInfo, maskTypes, loc, result.operands))
    return failure();
  return success();
}

static void print(OpAsmPrinter &p, ContractionOp op) {
  // TODO: Unify printing code with linalg ops.
  auto attrNames = op.getTraitAttrNames();
  llvm::StringSet<> traitAttrsSet;
  traitAttrsSet.insert(attrNames.begin(), attrNames.end());
  SmallVector<NamedAttribute, 8> attrs;
  for (auto attr : op.getAttrs())
    if (traitAttrsSet.count(attr.first.strref()) > 0)
      attrs.push_back(attr);
  auto dictAttr = DictionaryAttr::get(attrs, op.getContext());
  p << op.getOperationName() << " " << dictAttr << " " << op.lhs() << ", ";
  p << op.rhs() << ", " << op.acc();
  if (op.masks().size() == 2)
    p << ", " << op.masks();
  p.printOptionalAttrDict(op.getAttrs(), attrNames);
  p << " : " << op.lhs().getType() << ", " << op.rhs().getType() << " into "
    << op.getResultType();
}

static bool verifyDimMap(VectorType lhsType, VectorType rhsType,
                         const std::vector<std::pair<int64_t, int64_t>> &map) {
  for (auto &dimPair : map) {
    if (dimPair.first < 0 || dimPair.first >= lhsType.getRank() ||
        dimPair.second < 0 || dimPair.second >= rhsType.getRank() ||
        lhsType.getDimSize(dimPair.first) !=
            rhsType.getDimSize(dimPair.second))
      return false;
  }
  return true;
}

static LogicalResult verifyOutputShape(
    ContractionOp op, VectorType lhsType, VectorType rhsType, Type accType,
    Type resType,
    const std::vector<std::pair<int64_t, int64_t>> &contractingDimMap,
    const std::vector<std::pair<int64_t, int64_t>> &batchDimMap) {
  DenseSet<int64_t> lhsContractingDimSet;
  DenseSet<int64_t> rhsContractingDimSet;
  for (auto &dimPair : contractingDimMap) {
    lhsContractingDimSet.insert(dimPair.first);
    rhsContractingDimSet.insert(dimPair.second);
  }
  DenseSet<int64_t> rhsBatchDimSet;
  for (auto &dimPair : batchDimMap)
    rhsBatchDimSet.insert(dimPair.second);

  // Add free and batch dimensions from 'lhsType' to 'expectedResultDims'.
  SmallVector<int64_t, 4> expectedResultDims;
  for (int64_t i = 0, e = lhsType.getRank(); i < e; ++i) {
    if (lhsContractingDimSet.count(i) > 0)
      continue;
    expectedResultDims.push_back(lhsType.getDimSize(i));
  }

  // Add free dimensions from 'rhsType' to 'expectedResultDims'.
  for (int64_t i = 0, e = rhsType.getRank(); i < e; ++i) {
    if (rhsContractingDimSet.count(i) > 0 || rhsBatchDimSet.count(i) > 0)
      continue;
    expectedResultDims.push_back(rhsType.getDimSize(i));
  }

  // Verify 'expectedResultDims'.
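  // E.g. (illustrative) for a plain matmul with lhs = MxK and rhs = KxN, the
  // expected result dims are [M, N]; an empty list means a scalar result, as
  // in a 1-D dot product.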
  if (expectedResultDims.size() == 0) {
    // No batch or free dimension implies a scalar result.
    if (resType.isa<VectorType>() || accType.isa<VectorType>())
      return op.emitOpError("invalid accumulator/result vector shape");
  } else {
    // At least one batch or free dimension implies a vector result.
    auto resVectorType = resType.dyn_cast<VectorType>();
    auto accVectorType = accType.dyn_cast<VectorType>();
    if (!resVectorType || !accVectorType)
      return op.emitOpError("invalid accumulator/result vector shape");

    // Infer expected result vector type. Lhs + rhs map and lhs + rhs vector
    // types fully define the result vector type. This assumes the affine maps
    // are well-formed, which must have been verified already.
    MLIRContext *ctx = op.getContext();
    AffineMap lhsMap = op.getIndexingMaps()[0];
    AffineMap rhsMap = op.getIndexingMaps()[1];
    SmallVector<AffineExpr, 4> extents(lhsMap.getNumInputs());
    for (auto pair :
         {std::make_pair(lhsType, lhsMap), std::make_pair(rhsType, rhsMap)}) {
      VectorType v = pair.first;
      auto map = pair.second;
      for (unsigned idx = 0, e = v.getRank(); idx < e; ++idx) {
        unsigned pos = map.getDimPosition(idx);
        if (!extents[pos])
          extents[pos] = getAffineConstantExpr(v.getShape()[idx], ctx);
      }
    }
    assert(llvm::all_of(extents, [](AffineExpr e) { return e; }) &&
           "expected extent along all dimensions.");

    AffineMap resMap = op.getIndexingMaps()[2];
    auto extentsMap = AffineMap::get(/*dimCount=*/extents.size(),
                                     /*symCount=*/0, extents, ctx);
    // Compose the resMap with the extentsMap, which is a constant map.
    AffineMap expectedMap = simplifyAffineMap(resMap.compose(extentsMap));
    assert(llvm::all_of(
               expectedMap.getResults(),
               [](AffineExpr e) { return e.isa<AffineConstantExpr>(); }) &&
           "expected constant extent along all dimensions.");
    // Extract the expected shape and build the type.
    auto expectedShape = llvm::to_vector<4>(
        llvm::map_range(expectedMap.getResults(), [](AffineExpr e) {
          return e.cast<AffineConstantExpr>().getValue();
        }));
    auto expected =
        VectorType::get(expectedShape, resVectorType.getElementType());
    if (resVectorType != expected || accVectorType != expected)
      return op.emitOpError(
                 "invalid accumulator/result vector shape, expected: ")
             << expected;
  }
  return success();
}

static LogicalResult verify(ContractionOp op) {
  auto lhsType = op.getLhsType();
  auto rhsType = op.getRhsType();
  auto accType = op.getAccType();
  auto resType = op.getResultType();

  // Verify that an indexing map was specified for each vector operand.
  if (op.indexing_maps().size() != 3)
    return op.emitOpError("expected an indexing map for each vector operand");

  // Verify that each index map has 'numIterators' inputs, no symbols, and
  // that the number of map outputs equals the rank of its associated
  // vector operand.
  unsigned numIterators = op.iterator_types().getValue().size();
  for (auto it : llvm::enumerate(op.indexing_maps())) {
    auto index = it.index();
    auto map = it.value().cast<AffineMapAttr>().getValue();
    if (map.getNumSymbols() != 0)
      return op.emitOpError("expected indexing map ")
             << index << " to have no symbols";
    auto vectorType = op.getOperand(index).getType().dyn_cast<VectorType>();
    unsigned rank = vectorType ? vectorType.getShape().size() : 0;
    // Verify that the map has the right number of inputs, outputs, and
    // indices. This also correctly accounts for (..) -> () for rank-0 results.
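    // For instance (illustrative), with 3 iterators a valid lhs map is
    //   affine_map<(d0, d1, d2) -> (d0, d2)>
    // i.e. 3 inputs, one output per lhs dimension, and a projected
    // permutation.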
    if (map.getNumDims() != numIterators)
      return op.emitOpError("expected indexing map ")
             << index << " to have " << numIterators << " number of inputs";
    if (map.getNumResults() != rank)
      return op.emitOpError("expected indexing map ")
             << index << " to have " << rank << " number of outputs";
    if (!map.isProjectedPermutation())
      return op.emitOpError("expected indexing map ")
             << index << " to be a projected permutation of its inputs";
  }

  auto contractingDimMap = op.getContractingDimMap();
  auto batchDimMap = op.getBatchDimMap();

  // Verify at least one contracting dimension pair was specified.
  if (contractingDimMap.empty())
    return op.emitOpError("expected at least one contracting dimension pair");

  // Verify contracting dimension map was properly constructed.
  if (!verifyDimMap(lhsType, rhsType, contractingDimMap))
    return op.emitOpError("invalid contracting dimension map");

  // Verify batch dimension map was properly constructed.
  if (!verifyDimMap(lhsType, rhsType, batchDimMap))
    return op.emitOpError("invalid batch dimension map");

  // Verify 'accType' and 'resType' shape.
  if (failed(verifyOutputShape(op, lhsType, rhsType, accType, resType,
                               contractingDimMap, batchDimMap)))
    return failure();

  // Verify that either two vector masks are set or none are set.
  auto lhsMaskType = op.getLHSVectorMaskType();
  auto rhsMaskType = op.getRHSVectorMaskType();
  if ((lhsMaskType && !rhsMaskType) || (!lhsMaskType && rhsMaskType))
    return op.emitOpError("invalid number of vector masks specified");
  if (lhsMaskType && rhsMaskType) {
    // Verify mask rank == argument rank.
    if (lhsMaskType.getShape().size() != lhsType.getShape().size() ||
        rhsMaskType.getShape().size() != rhsType.getShape().size())
      return op.emitOpError("invalid vector mask rank");
  }
  return success();
}

ArrayRef<StringRef> ContractionOp::getTraitAttrNames() {
  static constexpr StringRef names[2] = {getIndexingMapsAttrName(),
                                         getIteratorTypesAttrName()};
  return llvm::makeArrayRef(names);
}

static int64_t getResultIndex(AffineMap map, AffineExpr targetExpr) {
  for (int64_t i = 0, e = map.getNumResults(); i < e; ++i)
    if (targetExpr == map.getResult(i))
      return i;
  return -1;
}

static std::vector<std::pair<int64_t, int64_t>>
getDimMap(ArrayRef<AffineMap> indexingMaps, ArrayAttr iteratorTypes,
          StringRef targetIteratorTypeName, MLIRContext *context) {
  std::vector<std::pair<int64_t, int64_t>> dimMap;
  for (auto it : llvm::enumerate(iteratorTypes)) {
    auto iteratorTypeName = it.value().cast<StringAttr>().getValue();
    if (iteratorTypeName != targetIteratorTypeName)
      continue;
    // Search lhs/rhs map results for 'targetExpr'.
    auto targetExpr = getAffineDimExpr(it.index(), context);
    int64_t lhsDim = getResultIndex(indexingMaps[0], targetExpr);
    int64_t rhsDim = getResultIndex(indexingMaps[1], targetExpr);
    if (lhsDim >= 0 && rhsDim >= 0)
      dimMap.push_back({lhsDim, rhsDim});
  }
  return dimMap;
}

void ContractionOp::getIterationBounds(
    SmallVectorImpl<int64_t> &iterationBounds) {
  auto lhsShape = getLhsType().getShape();
  auto resVectorType = getResultType().dyn_cast<VectorType>();
  SmallVector<AffineMap, 4> indexingMaps(getIndexingMaps());
  SmallVector<int64_t, 2> iterationShape;
  for (auto it : llvm::enumerate(iterator_types())) {
    // Search lhs/rhs map results for 'targetExpr'.
    auto targetExpr = getAffineDimExpr(it.index(), getContext());
    auto iteratorTypeName = it.value().cast<StringAttr>().getValue();
    if (iteratorTypeName == getReductionIteratorTypeName()) {
      // Get reduction dim size from lhs shape (same size in rhsShape).
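      // (For a matmul-style contraction this is the shared K dimension.)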
      int64_t lhsDimIndex = getResultIndex(indexingMaps[0], targetExpr);
      assert(lhsDimIndex >= 0);
      iterationBounds.push_back(lhsShape[lhsDimIndex]);
      continue;
    }
    // Get parallel dimension size from result shape.
    int64_t resDimIndex = getResultIndex(indexingMaps[2], targetExpr);
    assert(resDimIndex >= 0);
    assert(resVectorType != nullptr);
    iterationBounds.push_back(resVectorType.getShape()[resDimIndex]);
  }
}

void ContractionOp::getIterationIndexMap(
    std::vector<DenseMap<int64_t, int64_t>> &iterationIndexMap) {
  unsigned numMaps = indexing_maps().getValue().size();
  iterationIndexMap.resize(numMaps);
  for (auto it : llvm::enumerate(indexing_maps())) {
    auto index = it.index();
    auto map = it.value().cast<AffineMapAttr>().getValue();
    for (unsigned i = 0, e = map.getNumResults(); i < e; ++i) {
      auto dim = map.getResult(i).cast<AffineDimExpr>();
      iterationIndexMap[index][dim.getPosition()] = i;
    }
  }
}

std::vector<std::pair<int64_t, int64_t>> ContractionOp::getContractingDimMap() {
  SmallVector<AffineMap, 4> indexingMaps(getIndexingMaps());
  return getDimMap(indexingMaps, iterator_types(),
                   getReductionIteratorTypeName(), getContext());
}

std::vector<std::pair<int64_t, int64_t>> ContractionOp::getBatchDimMap() {
  SmallVector<AffineMap, 4> indexingMaps(getIndexingMaps());
  return getDimMap(indexingMaps, iterator_types(),
                   getParallelIteratorTypeName(), getContext());
}

SmallVector<AffineMap, 4> ContractionOp::getIndexingMaps() {
  return llvm::to_vector<4>(
      llvm::map_range(indexing_maps().getValue(), [](Attribute mapAttr) {
        return mapAttr.cast<AffineMapAttr>().getValue();
      }));
}

Optional<SmallVector<int64_t, 4>> ContractionOp::getShapeForUnroll() {
  SmallVector<int64_t, 4> shape;
  getIterationBounds(shape);
  return shape;
}

//===----------------------------------------------------------------------===//
// ExtractElementOp
//===----------------------------------------------------------------------===//

void vector::ExtractElementOp::build(OpBuilder &builder,
                                     OperationState &result, Value source,
                                     Value position) {
  result.addOperands({source, position});
  result.addTypes(source.getType().cast<VectorType>().getElementType());
}

void vector::ExtractElementOp::build(OpBuilder &builder,
                                     OperationState &result, Value source,
                                     int64_t position) {
  Value pos = builder.create<ConstantIntOp>(result.location, position, 32);
  build(builder, result, source, pos);
}

static LogicalResult verify(vector::ExtractElementOp op) {
  VectorType vectorType = op.getVectorType();
  if (vectorType.getRank() != 1)
    return op.emitOpError("expected 1-D vector");
  return success();
}

//===----------------------------------------------------------------------===//
// ExtractOp
//===----------------------------------------------------------------------===//

static Type inferExtractOpResultType(VectorType vectorType,
                                     ArrayAttr position) {
  if (static_cast<int64_t>(position.size()) == vectorType.getRank())
    return vectorType.getElementType();
  return VectorType::get(vectorType.getShape().drop_front(position.size()),
                         vectorType.getElementType());
}

void vector::ExtractOp::build(OpBuilder &builder, OperationState &result,
                              Value source, ArrayRef<int64_t> position) {
  result.addOperands(source);
  auto positionAttr = getVectorSubscriptAttr(builder, position);
  result.addTypes(inferExtractOpResultType(source.getType().cast<VectorType>(),
                                           positionAttr));
  result.addAttribute(getPositionAttrName(), positionAttr);
}

// Convenience builder which assumes the values are constant indices.
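// Note: each value in `position` must come from a constant index op;
// getDefiningOp<ConstantIndexOp>() below yields a null op otherwise, so this
// convenience builder cannot be used with non-constant positions.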
void vector::ExtractOp::build(OpBuilder &builder, OperationState &result,
                              Value source, ValueRange position) {
  SmallVector<int64_t, 4> positionConstants =
      llvm::to_vector<4>(llvm::map_range(position, [](Value pos) {
        return pos.getDefiningOp<ConstantIndexOp>().getValue();
      }));
  build(builder, result, source, positionConstants);
}

static void print(OpAsmPrinter &p, vector::ExtractOp op) {
  p << op.getOperationName() << " " << op.vector() << op.position();
  p.printOptionalAttrDict(op.getAttrs(), {"position"});
  p << " : " << op.vector().getType();
}

static ParseResult parseExtractOp(OpAsmParser &parser,
                                  OperationState &result) {
  llvm::SMLoc attributeLoc, typeLoc;
  NamedAttrList attrs;
  OpAsmParser::OperandType vector;
  Type type;
  Attribute attr;
  if (parser.parseOperand(vector) ||
      parser.getCurrentLocation(&attributeLoc) ||
      parser.parseAttribute(attr, "position", attrs) ||
      parser.parseOptionalAttrDict(attrs) ||
      parser.getCurrentLocation(&typeLoc) || parser.parseColonType(type))
    return failure();

  auto vectorType = type.dyn_cast<VectorType>();
  if (!vectorType)
    return parser.emitError(typeLoc, "expected vector type");

  auto positionAttr = attr.dyn_cast<ArrayAttr>();
  if (!positionAttr ||
      static_cast<int64_t>(positionAttr.size()) > vectorType.getRank())
    return parser.emitError(
        attributeLoc,
        "expected position attribute of rank smaller than vector rank");

  Type resType = inferExtractOpResultType(vectorType, positionAttr);
  result.attributes = attrs;
  return failure(parser.resolveOperand(vector, type, result.operands) ||
                 parser.addTypeToList(resType, result.types));
}

static LogicalResult verify(vector::ExtractOp op) {
  auto positionAttr = op.position().getValue();
  if (positionAttr.empty())
    return op.emitOpError("expected non-empty position attribute");
  if (positionAttr.size() >
      static_cast<unsigned>(op.getVectorType().getRank()))
    return op.emitOpError(
        "expected position attribute of rank smaller than vector rank");
  for (auto en : llvm::enumerate(positionAttr)) {
    auto attr = en.value().dyn_cast<IntegerAttr>();
    if (!attr || attr.getInt() < 0 ||
        attr.getInt() >= op.getVectorType().getDimSize(en.index()))
      return op.emitOpError("expected position attribute #")
             << (en.index() + 1)
             << " to be a non-negative integer smaller than the corresponding "
                "vector dimension";
  }
  return success();
}

template <typename IntType>
static SmallVector<IntType, 4> extractVector(ArrayAttr arrayAttr) {
  return llvm::to_vector<4>(llvm::map_range(
      arrayAttr.getAsRange<IntegerAttr>(),
      [](IntegerAttr attr) { return static_cast<IntType>(attr.getInt()); }));
}

/// Fold the result of chains of ExtractOp in place by simply concatenating the
/// positions.
static LogicalResult foldExtractOpFromExtractChain(ExtractOp extractOp) {
  if (!extractOp.vector().getDefiningOp<ExtractOp>())
    return failure();
  SmallVector<int64_t, 4> globalPosition;
  ExtractOp currentOp = extractOp;
  auto extractedPos = extractVector<int64_t>(currentOp.position());
  globalPosition.append(extractedPos.rbegin(), extractedPos.rend());
  while (ExtractOp nextOp = currentOp.vector().getDefiningOp<ExtractOp>()) {
    currentOp = nextOp;
    auto extractedPos = extractVector<int64_t>(currentOp.position());
    globalPosition.append(extractedPos.rbegin(), extractedPos.rend());
  }
  extractOp.setOperand(currentOp.vector());
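  // Illustrative chain fold (positions concatenate producer-first):
  //   %1 = vector.extract %0[3] : vector<4x8x16xf32>
  //   %2 = vector.extract %1[5] : vector<8x16xf32>
  // becomes
  //   %2 = vector.extract %0[3, 5] : vector<4x8x16xf32>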
  // OpBuilder is only used as a helper to build an I64ArrayAttr.
  OpBuilder b(extractOp.getContext());
  std::reverse(globalPosition.begin(), globalPosition.end());
  extractOp.setAttr(ExtractOp::getPositionAttrName(),
                    b.getI64ArrayAttr(globalPosition));
  return success();
}

/// Fold the result of an ExtractOp in place when it comes from a TransposeOp.
static LogicalResult foldExtractOpFromTranspose(ExtractOp extractOp) {
  auto transposeOp = extractOp.vector().getDefiningOp<vector::TransposeOp>();
  if (!transposeOp)
    return failure();

  auto permutation = extractVector<unsigned>(transposeOp.transp());
  auto extractedPos = extractVector<int64_t>(extractOp.position());

  // If transposition permutation is larger than the ExtractOp, all minor
  // dimensions must be an identity for folding to occur. If not, individual
  // elements within the extracted value are transposed and this is not just a
  // simple folding.
  unsigned minorRank = permutation.size() - extractedPos.size();
  MLIRContext *ctx = extractOp.getContext();
  AffineMap permutationMap = AffineMap::getPermutationMap(permutation, ctx);
  AffineMap minorMap = permutationMap.getMinorSubMap(minorRank);
  if (minorMap && !minorMap.isMinorIdentity())
    return failure();

  //   %1 = transpose %0[x, y, z] : vector<axbxcxf32>
  //   %2 = extract %1[u, v] : vector<..xf32>
  // may turn into:
  //   %2 = extract %0[w, x] : vector<..xf32>
  // iff z == 2 and [w, x] = [x, y]^-1 o [u, v], where o denotes composition
  // and -1 denotes the inverse.
  permutationMap = permutationMap.getMajorSubMap(extractedPos.size());
  // The major submap has fewer results but the same number of dims. To compose
  // cleanly, we need to drop dims to form a "square matrix". This is possible
  // because:
  //   (a) this is a permutation map and
  //   (b) the minor map has already been checked to be identity.
  // Therefore, the major map cannot contain dims of position greater or equal
  // than the number of results.
  assert(llvm::all_of(permutationMap.getResults(),
                      [&](AffineExpr e) {
                        auto dim = e.dyn_cast<AffineDimExpr>();
                        return dim && dim.getPosition() <
                                          permutationMap.getNumResults();
                      }) &&
         "Unexpected map results depend on higher rank positions");
  // Project on the first domain dimensions to allow composition.
  permutationMap = AffineMap::get(permutationMap.getNumResults(), 0,
                                  permutationMap.getResults(), ctx);

  extractOp.setOperand(transposeOp.vector());
  // Compose the inverse permutation map with the extractedPos.
  auto newExtractedPos =
      inversePermutation(permutationMap).compose(extractedPos);
  // OpBuilder is only used as a helper to build an I64ArrayAttr.
  OpBuilder b(extractOp.getContext());
  extractOp.setAttr(ExtractOp::getPositionAttrName(),
                    b.getI64ArrayAttr(newExtractedPos));
  return success();
}

/// Fold an ExtractOp that is fed by a chain of InsertOps and TransposeOps. The
/// result is always the input to some InsertOp.
static Value foldExtractOpFromInsertChainAndTranspose(ExtractOp extractOp) {
  MLIRContext *context = extractOp.getContext();
  AffineMap permutationMap;
  auto extractedPos = extractVector<unsigned>(extractOp.position());
  // Walk back a chain of InsertOp/TransposeOp until we hit a match.
  // Compose TransposeOp permutations as we walk back.
  auto insertOp = extractOp.vector().getDefiningOp<vector::InsertOp>();
  auto transposeOp = extractOp.vector().getDefiningOp<vector::TransposeOp>();
  while (insertOp || transposeOp) {
    if (transposeOp) {
      // If it is transposed, compose the map and iterate.
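      // (Successive transposes thus behave as a single composed permutation
      // map, e.g. two vector.transpose ops in a row, illustratively.)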
      auto permutation = extractVector<unsigned>(transposeOp.transp());
      AffineMap newMap = AffineMap::getPermutationMap(permutation, context);
      if (!permutationMap)
        permutationMap = newMap;
      else if (newMap.getNumInputs() != permutationMap.getNumResults())
        return Value();
      else
        permutationMap = newMap.compose(permutationMap);
      // Compute insert/transpose for the next iteration.
      Value transposed = transposeOp.vector();
      insertOp = transposed.getDefiningOp<vector::InsertOp>();
      transposeOp = transposed.getDefiningOp<vector::TransposeOp>();
      continue;
    }
    assert(insertOp);
    Value insertionDest = insertOp.dest();
    // If it is inserted into, either the position matches and we have a
    // successful folding; or we iterate until we run out of
    // InsertOp/TransposeOp. This is because `vector.insert %scalar, %vector`
    // produces a new vector with 1 modified value/slice in exactly the static
    // position we need to match.
    auto insertedPos = extractVector<unsigned>(insertOp.position());
    // Trivial permutations are solved with position equality checks.
    if (!permutationMap || permutationMap.isIdentity()) {
      if (extractedPos == insertedPos)
        return insertOp.source();
      // Fallthrough: if the position does not match, just skip to the next
      // producing `vector.insert` / `vector.transpose`.
      // Compute insert/transpose for the next iteration.
      insertOp = insertionDest.getDefiningOp<vector::InsertOp>();
      transposeOp = insertionDest.getDefiningOp<vector::TransposeOp>();
      continue;
    }

    // More advanced permutations require application of the permutation.
    // However, the rank of `insertedPos` may be different from that of the
    // `permutationMap`. To support such cases, we need to:
    //   1. apply on the `insertedPos.size()` major dimensions
    //   2. check the other dimensions of the permutation form a minor
    //      identity.
    assert(permutationMap.isPermutation() && "expected a permutation");
    if (insertedPos.size() == extractedPos.size()) {
      bool fold = true;
      for (unsigned idx = 0, sz = extractedPos.size(); idx < sz; ++idx) {
        auto pos = permutationMap.getDimPosition(idx);
        if (pos >= sz || insertedPos[pos] != extractedPos[idx]) {
          fold = false;
          break;
        }
      }
      if (fold) {
        assert(permutationMap.getNumResults() >= insertedPos.size() &&
               "expected map of rank larger than insert indexing");
        unsigned minorRank =
            permutationMap.getNumResults() - insertedPos.size();
        AffineMap minorMap = permutationMap.getMinorSubMap(minorRank);
        if (!minorMap || minorMap.isMinorIdentity())
          return insertOp.source();
      }
    }

    // If we haven't found a match, just continue to the next producing
    // `vector.insert` / `vector.transpose`.
    // Compute insert/transpose for the next iteration.
    insertOp = insertionDest.getDefiningOp<vector::InsertOp>();
    transposeOp = insertionDest.getDefiningOp<vector::TransposeOp>();
  }
  return Value();
}

/// Fold extractOp with scalar result coming from BroadcastOp.
static Value foldExtractFromBroadcast(ExtractOp extractOp) {
  auto broadcastOp = extractOp.vector().getDefiningOp<vector::BroadcastOp>();
  if (!broadcastOp)
    return Value();
  if (extractOp.getType() == broadcastOp.getSourceType())
    return broadcastOp.source();
  auto getRank = [](Type type) {
    return type.isa<VectorType>() ? type.cast<VectorType>().getRank() : 0;
  };
  unsigned broadcastSrcRank = getRank(broadcastOp.getSourceType());
  unsigned extractResultRank = getRank(extractOp.getType());
  if (extractResultRank < broadcastSrcRank) {
    auto extractPos = extractVector<int64_t>(extractOp.position());
    unsigned rankDiff = broadcastSrcRank - extractResultRank;
    extractPos.erase(
        extractPos.begin(),
        std::next(extractPos.begin(), extractPos.size() - rankDiff));
    extractOp.setOperand(broadcastOp.source());
    // OpBuilder is only used as a helper to build an I64ArrayAttr.
    OpBuilder b(extractOp.getContext());
    extractOp.setAttr(ExtractOp::getPositionAttrName(),
                      b.getI64ArrayAttr(extractPos));
    return extractOp.getResult();
  }
  // TODO: In case the rank of the broadcast source is greater than the rank
  // of the extract result, this can be combined into a new broadcast op. This
  // can be added as a canonicalization pattern if needed.
  return Value();
}

// Fold extractOp with source coming from ShapeCast op.
static Value foldExtractFromShapeCast(ExtractOp extractOp) {
  auto shapeCastOp = extractOp.vector().getDefiningOp<vector::ShapeCastOp>();
  if (!shapeCastOp)
    return Value();
  // Get the nth dimension size starting from lowest dimension.
  auto getDimReverse = [](VectorType type, int64_t n) {
    return type.getShape().take_back(n + 1).front();
  };
  int64_t destinationRank =
      extractOp.getType().isa<VectorType>()
          ? extractOp.getType().cast<VectorType>().getRank()
          : 0;
  if (destinationRank > shapeCastOp.getSourceVectorType().getRank())
    return Value();
  if (destinationRank > 0) {
    auto destinationType = extractOp.getResult().getType().cast<VectorType>();
    for (int64_t i = 0; i < destinationRank; i++) {
      // The lowest dimension of the destination must match the lowest
      // dimension of the shapecast op source.
      // TODO: This case could be supported in a canonicalization pattern.
      if (getDimReverse(shapeCastOp.getSourceVectorType(), i) !=
          getDimReverse(destinationType, i))
        return Value();
    }
  }
  // Extract the strides associated with the extract op vector source. Then
  // use this to calculate a linearized position for the extract.
  auto extractedPos = extractVector<int64_t>(extractOp.position());
  std::reverse(extractedPos.begin(), extractedPos.end());
  SmallVector<int64_t, 4> strides;
  int64_t stride = 1;
  for (int64_t i = 0, e = extractedPos.size(); i < e; i++) {
    strides.push_back(stride);
    stride *= getDimReverse(extractOp.getVectorType(), i + destinationRank);
  }
  int64_t position = linearize(extractedPos, strides);
  // Then extract the strides associated to the shapeCast op vector source and
  // delinearize the position using those strides.
  SmallVector<int64_t, 4> newStrides;
  int64_t numDimension =
      shapeCastOp.getSourceVectorType().getRank() - destinationRank;
  stride = 1;
  for (int64_t i = 0; i < numDimension; i++) {
    newStrides.push_back(stride);
    stride *=
        getDimReverse(shapeCastOp.getSourceVectorType(), i + destinationRank);
  }
  std::reverse(newStrides.begin(), newStrides.end());
  SmallVector<int64_t, 4> newPosition = delinearize(newStrides, position);
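  // E.g. (illustrative):
  //   %1 = vector.shape_cast %0 : vector<4x4xf32> to vector<2x2x4xf32>
  //   %2 = vector.extract %1[1, 1] : vector<2x2x4xf32>
  // linearizes [1, 1] to offset 3 and delinearizes it on the source shape,
  // folding to:
  //   %2 = vector.extract %0[3] : vector<4x4xf32>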
  // OpBuilder is only used as a helper to build an I64ArrayAttr.
  OpBuilder b(extractOp.getContext());
  extractOp.setAttr(ExtractOp::getPositionAttrName(),
                    b.getI64ArrayAttr(newPosition));
  extractOp.setOperand(shapeCastOp.source());
  return extractOp.getResult();
}

OpFoldResult ExtractOp::fold(ArrayRef<Attribute>) {
  if (succeeded(foldExtractOpFromExtractChain(*this)))
    return getResult();
  if (succeeded(foldExtractOpFromTranspose(*this)))
    return getResult();
  if (auto val = foldExtractOpFromInsertChainAndTranspose(*this))
    return val;
  if (auto val = foldExtractFromBroadcast(*this))
    return val;
  if (auto val = foldExtractFromShapeCast(*this))
    return val;
  return OpFoldResult();
}

//===----------------------------------------------------------------------===//
// ExtractSlicesOp
//===----------------------------------------------------------------------===//

void ExtractSlicesOp::build(OpBuilder &builder, OperationState &result,
                            TupleType tupleType, Value vector,
                            ArrayRef<int64_t> sizes,
                            ArrayRef<int64_t> strides) {
  result.addOperands(vector);
  auto sizesAttr = getVectorSubscriptAttr(builder, sizes);
  auto stridesAttr = getVectorSubscriptAttr(builder, strides);
  result.addTypes(tupleType);
  result.addAttribute(getSizesAttrName(), sizesAttr);
  result.addAttribute(getStridesAttrName(), stridesAttr);
}

static LogicalResult
isValidExtractOrInsertSlicesType(Operation *op, VectorType vectorType,
                                 TupleType tupleType, ArrayRef<int64_t> sizes,
                                 ArrayRef<int64_t> strides) {
  // Check for non-unit strides.
  // TODO: Support non-1 strides.
  if (llvm::any_of(strides, [](int64_t s) { return s != 1; }))
    return op->emitError("requires unit strides");
  // Check that 'vectorType' rank matches rank of tuple element vectors.
  unsigned rank = vectorType.getRank();
  auto is_vector_type_of_rank = [&](Type t) {
    return t.isa<VectorType>() && t.cast<VectorType>().getRank() == rank;
  };
  if (!llvm::all_of(tupleType.getTypes(), is_vector_type_of_rank))
    return op->emitError("requires vector tuple elements of rank ") << rank;
  // Check that 'sizes' and 'strides' are of size == 'rank'.
  if (sizes.size() != rank || strides.size() != rank)
    return op->emitError("requires sizes and strides of rank ") << rank;

  // Generate each slice shape based on 'sizes', 'strides' and 'vectorType',
  // and verify that the same matches the corresponding tuple element 'i'.
  auto shape = vectorType.getShape();
  auto sliceStrides = computeStrides(shape, sizes);
  for (int64_t i = 0, e = tupleType.size(); i < e; ++i) {
    auto vectorOffsets = delinearize(sliceStrides, i);
    auto elementOffsets =
        computeElementOffsetsFromVectorSliceOffsets(sizes, vectorOffsets);
    auto sliceSizes = computeSliceSizes(shape, sizes, elementOffsets);
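    // E.g. (illustrative) a vector<4x4xf32> tiled with sizes [2, 2] and unit
    // strides yields four vector<2x2xf32> slices, in row-major slice order.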
    // Create slice VectorType type.
    auto sliceVectorType =
        VectorType::get(sliceSizes, vectorType.getElementType());
    // Verify that 'sliceVectorType' matches tupleType.getType(i).
    if (sliceVectorType != tupleType.getType(i))
      return op->emitError("invalid tuple element type ") << sliceVectorType;
  }
  return success();
}

static LogicalResult verify(ExtractSlicesOp op) {
  SmallVector<int64_t, 4> sizes;
  op.getSizes(sizes);
  SmallVector<int64_t, 4> strides;
  op.getStrides(strides);
  return isValidExtractOrInsertSlicesType(
      op.getOperation(), op.getSourceVectorType(), op.getResultTupleType(),
      sizes, strides);
}

static void populateFromInt64AttrArray(ArrayAttr arrayAttr,
                                       SmallVectorImpl<int64_t> &results) {
  for (auto attr : arrayAttr)
    results.push_back(attr.cast<IntegerAttr>().getInt());
}

void ExtractSlicesOp::getSizes(SmallVectorImpl<int64_t> &results) {
  populateFromInt64AttrArray(sizes(), results);
}

void ExtractSlicesOp::getStrides(SmallVectorImpl<int64_t> &results) {
  populateFromInt64AttrArray(strides(), results);
}

//===----------------------------------------------------------------------===//
// ExtractMapOp
//===----------------------------------------------------------------------===//

void ExtractMapOp::build(OpBuilder &builder, OperationState &result,
                         Value vector, ValueRange ids,
                         ArrayRef<int64_t> multiplicity,
                         AffineMap permutationMap) {
  assert(ids.size() == multiplicity.size() &&
         ids.size() == permutationMap.getNumResults());
  assert(permutationMap.isProjectedPermutation());
  VectorType type = vector.getType().cast<VectorType>();
  SmallVector<int64_t, 4> newShape(type.getShape().begin(),
                                   type.getShape().end());
  for (unsigned i = 0, e = permutationMap.getNumResults(); i < e; i++) {
    AffineExpr expr = permutationMap.getResult(i);
    auto dim = expr.cast<AffineDimExpr>();
    newShape[dim.getPosition()] = newShape[dim.getPosition()] / multiplicity[i];
  }
  VectorType resultType = VectorType::get(newShape, type.getElementType());
  ExtractMapOp::build(builder, result, resultType, vector, ids);
}

static LogicalResult verify(ExtractMapOp op) {
  if (op.getSourceVectorType().getRank() != op.getResultType().getRank())
    return op.emitOpError(
        "expected source and destination vectors of same rank");
  unsigned numId = 0;
  for (unsigned i = 0, e = op.getSourceVectorType().getRank(); i < e; ++i) {
    if (op.getSourceVectorType().getDimSize(i) %
            op.getResultType().getDimSize(i) !=
        0)
      return op.emitOpError("source vector dimensions must be a multiple of "
                            "destination vector dimensions");
    if (op.getSourceVectorType().getDimSize(i) !=
        op.getResultType().getDimSize(i))
      numId++;
  }
  if (numId != op.ids().size())
    return op.emitOpError("expected number of ids to match the number of "
                          "dimensions distributed");
  return success();
}

OpFoldResult ExtractMapOp::fold(ArrayRef<Attribute> operands) {
  auto insert = vector().getDefiningOp<vector::InsertMapOp>();
  if (insert == nullptr || getType() != insert.vector().getType() ||
      ids() != insert.ids())
    return {};
  return insert.vector();
}

void ExtractMapOp::getMultiplicity(SmallVectorImpl<int64_t> &multiplicity) {
  assert(multiplicity.empty());
  for (unsigned i = 0, e = getSourceVectorType().getRank(); i < e; i++) {
    if (getSourceVectorType().getDimSize(i) != getResultType().getDimSize(i))
      multiplicity.push_back(getSourceVectorType().getDimSize(i) /
                             getResultType().getDimSize(i));
  }
}

template <typename MapOp>
AffineMap calculateImplicitMap(MapOp op) {
  SmallVector<AffineExpr, 4> perm;
  // Check which dimensions have a multiplicity greater than 1 and associate
  // them with the IDs in order.
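  // E.g. (illustrative) distributing vector<4x8xf32> into vector<4x2xf32>
  // yields the implicit map affine_map<(d0, d1) -> (d1)>.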
  for (unsigned i = 0, e = op.getSourceVectorType().getRank(); i < e; i++) {
    if (op.getSourceVectorType().getDimSize(i) !=
        op.getResultType().getDimSize(i))
      perm.push_back(getAffineDimExpr(i, op.getContext()));
  }
  auto map = AffineMap::get(op.getSourceVectorType().getRank(), 0, perm,
                            op.getContext());
  return map;
}

AffineMap ExtractMapOp::map() { return calculateImplicitMap(*this); }

//===----------------------------------------------------------------------===//
// BroadcastOp
//===----------------------------------------------------------------------===//

static LogicalResult verify(BroadcastOp op) {
  VectorType srcVectorType = op.getSourceType().dyn_cast<VectorType>();
  VectorType dstVectorType = op.getVectorType();
  // Scalar to vector broadcast is always valid. A vector
  // to vector broadcast needs some additional checking.
  if (srcVectorType) {
    int64_t srcRank = srcVectorType.getRank();
    int64_t dstRank = dstVectorType.getRank();
    if (srcRank > dstRank)
      return op.emitOpError("source rank higher than destination rank");
    // Source has an exact match or singleton value for all trailing
    // dimensions (all leading dimensions are simply duplicated).
    int64_t lead = dstRank - srcRank;
    for (int64_t r = 0; r < srcRank; ++r) {
      int64_t srcDim = srcVectorType.getDimSize(r);
      int64_t dstDim = dstVectorType.getDimSize(lead + r);
      if (srcDim != 1 && srcDim != dstDim)
        return op.emitOpError("dimension mismatch (")
               << srcDim << " vs. " << dstDim << ")";
    }
  }
  return success();
}

OpFoldResult BroadcastOp::fold(ArrayRef<Attribute> operands) {
  if (!operands[0])
    return {};
  auto vectorType = getVectorType();
  if (operands[0].getType().isIntOrIndexOrFloat())
    return DenseElementsAttr::get(vectorType, operands[0]);
  if (auto attr = operands[0].dyn_cast<SplatElementsAttr>())
    return DenseElementsAttr::get(vectorType, attr.getSplatValue());
  return {};
}

//===----------------------------------------------------------------------===//
// ShuffleOp
//===----------------------------------------------------------------------===//

void ShuffleOp::build(OpBuilder &builder, OperationState &result, Value v1,
                      Value v2, ArrayRef<int64_t> mask) {
  result.addOperands({v1, v2});
  auto maskAttr = getVectorSubscriptAttr(builder, mask);
  result.addTypes(v1.getType());
  result.addAttribute(getMaskAttrName(), maskAttr);
}

static void print(OpAsmPrinter &p, ShuffleOp op) {
  p << op.getOperationName() << " " << op.v1() << ", " << op.v2() << " "
    << op.mask();
  p.printOptionalAttrDict(op.getAttrs(), {ShuffleOp::getMaskAttrName()});
  p << " : " << op.v1().getType() << ", " << op.v2().getType();
}

static LogicalResult verify(ShuffleOp op) {
  VectorType resultType = op.getVectorType();
  VectorType v1Type = op.getV1VectorType();
  VectorType v2Type = op.getV2VectorType();
  // Verify ranks.
  int64_t resRank = resultType.getRank();
  int64_t v1Rank = v1Type.getRank();
  int64_t v2Rank = v2Type.getRank();
  if (resRank != v1Rank || v1Rank != v2Rank)
    return op.emitOpError("rank mismatch");
  // Verify all but leading dimension sizes.
  for (int64_t r = 1; r < v1Rank; ++r) {
    int64_t resDim = resultType.getDimSize(r);
    int64_t v1Dim = v1Type.getDimSize(r);
    int64_t v2Dim = v2Type.getDimSize(r);
    if (resDim != v1Dim || v1Dim != v2Dim)
      return op.emitOpError("dimension mismatch");
  }
  // Verify mask length.
  auto maskAttr = op.mask().getValue();
  int64_t maskLength = maskAttr.size();
  if (maskLength != resultType.getDimSize(0))
    return op.emitOpError("mask length mismatch");
  // Verify all indices.
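  // Mask indices select from the concatenation of v1 and v2 along the
  // leading dimension, e.g. (illustrative):
  //   %s = vector.shuffle %a, %b [0, 4, 1, 5]
  //       : vector<4xf32>, vector<4xf32>
  // where valid indices are 0..7.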
  int64_t indexSize = v1Type.getDimSize(0) + v2Type.getDimSize(0);
  for (auto en : llvm::enumerate(maskAttr)) {
    auto attr = en.value().dyn_cast<IntegerAttr>();
    if (!attr || attr.getInt() < 0 || attr.getInt() >= indexSize)
      return op.emitOpError("mask index #")
             << (en.index() + 1) << " out of range";
  }
  return success();
}

static ParseResult parseShuffleOp(OpAsmParser &parser,
                                  OperationState &result) {
  OpAsmParser::OperandType v1, v2;
  Attribute attr;
  VectorType v1Type, v2Type;
  if (parser.parseOperand(v1) || parser.parseComma() ||
      parser.parseOperand(v2) ||
      parser.parseAttribute(attr, ShuffleOp::getMaskAttrName(),
                            result.attributes) ||
      parser.parseOptionalAttrDict(result.attributes) ||
      parser.parseColonType(v1Type) || parser.parseComma() ||
      parser.parseType(v2Type) ||
      parser.resolveOperand(v1, v1Type, result.operands) ||
      parser.resolveOperand(v2, v2Type, result.operands))
    return failure();
  // Construct resulting type: leading dimension matches mask length,
  // all trailing dimensions match the operands.
  auto maskAttr = attr.dyn_cast<ArrayAttr>();
  if (!maskAttr)
    return parser.emitError(parser.getNameLoc(), "missing mask attribute");
  int64_t maskLength = maskAttr.size();
  if (maskLength <= 0)
    return parser.emitError(parser.getNameLoc(), "invalid mask length");
  int64_t v1Rank = v1Type.getRank();
  SmallVector<int64_t, 4> shape;
  shape.reserve(v1Rank);
  shape.push_back(maskLength);
  for (int64_t r = 1; r < v1Rank; ++r)
    shape.push_back(v1Type.getDimSize(r));
  VectorType resType = VectorType::get(shape, v1Type.getElementType());
  parser.addTypeToList(resType, result.types);
  return success();
}

//===----------------------------------------------------------------------===//
// InsertElementOp
//===----------------------------------------------------------------------===//

void InsertElementOp::build(OpBuilder &builder, OperationState &result,
                            Value source, Value dest, Value position) {
  result.addOperands({source, dest, position});
  result.addTypes(dest.getType());
}

void InsertElementOp::build(OpBuilder &builder, OperationState &result,
                            Value source, Value dest, int64_t position) {
  Value pos = builder.create<ConstantIntOp>(result.location, position, 32);
  build(builder, result, source, dest, pos);
}

static LogicalResult verify(InsertElementOp op) {
  auto dstVectorType = op.getDestVectorType();
  if (dstVectorType.getRank() != 1)
    return op.emitOpError("expected 1-D vector");
  return success();
}

//===----------------------------------------------------------------------===//
// InsertOp
//===----------------------------------------------------------------------===//

void InsertOp::build(OpBuilder &builder, OperationState &result, Value source,
                     Value dest, ArrayRef<int64_t> position) {
  result.addOperands({source, dest});
  auto positionAttr = getVectorSubscriptAttr(builder, position);
  result.addTypes(dest.getType());
  result.addAttribute(getPositionAttrName(), positionAttr);
}

// Convenience builder which assumes the values are constant indices.
void InsertOp::build(OpBuilder &builder, OperationState &result, Value source,
                     Value dest, ValueRange position) {
  SmallVector<int64_t, 4> positionConstants =
      llvm::to_vector<4>(llvm::map_range(position, [](Value pos) {
        return pos.getDefiningOp<ConstantIndexOp>().getValue();
      }));
  build(builder, result, source, dest, positionConstants);
}

static LogicalResult verify(InsertOp op) {
  auto positionAttr = op.position().getValue();
  if (positionAttr.empty())
    return op.emitOpError("expected non-empty position attribute");
  auto destVectorType = op.getDestVectorType();
  if (positionAttr.size() > static_cast<unsigned>(destVectorType.getRank()))
    return op.emitOpError(
        "expected position attribute of rank smaller than dest vector rank");
  auto srcVectorType = op.getSourceType().dyn_cast<VectorType>();
  if (srcVectorType &&
      (static_cast<unsigned>(srcVectorType.getRank()) + positionAttr.size() !=
       static_cast<unsigned>(destVectorType.getRank())))
    return op.emitOpError("expected position attribute rank + source rank to "
                          "match dest vector rank");
  else if (!srcVectorType &&
           (positionAttr.size() !=
            static_cast<unsigned>(destVectorType.getRank())))
    return op.emitOpError(
        "expected position attribute rank to match the dest vector rank");
  for (auto en : llvm::enumerate(positionAttr)) {
    auto attr = en.value().dyn_cast<IntegerAttr>();
    if (!attr || attr.getInt() < 0 ||
        attr.getInt() >= destVectorType.getDimSize(en.index()))
      return op.emitOpError("expected position attribute #")
             << (en.index() + 1)
             << " to be a non-negative integer smaller than the corresponding "
                "dest vector dimension";
  }
  return success();
}

//===----------------------------------------------------------------------===//
// InsertSlicesOp
//===----------------------------------------------------------------------===//

static LogicalResult verify(InsertSlicesOp op) {
  SmallVector<int64_t, 4> sizes;
  op.getSizes(sizes);
  SmallVector<int64_t, 4> strides;
  op.getStrides(strides);
  return isValidExtractOrInsertSlicesType(
      op.getOperation(), op.getResultVectorType(), op.getSourceTupleType(),
      sizes, strides);
}

void InsertSlicesOp::getSizes(SmallVectorImpl<int64_t> &results) {
  populateFromInt64AttrArray(sizes(), results);
}

void InsertSlicesOp::getStrides(SmallVectorImpl<int64_t> &results) {
  populateFromInt64AttrArray(strides(), results);
}

//===----------------------------------------------------------------------===//
// InsertMapOp
//===----------------------------------------------------------------------===//

void InsertMapOp::build(OpBuilder &builder, OperationState &result,
                        Value vector, Value dest, ValueRange ids) {
  InsertMapOp::build(builder, result, dest.getType(), vector, dest, ids);
}

static LogicalResult verify(InsertMapOp op) {
  if (op.getSourceVectorType().getRank() != op.getResultType().getRank())
    return op.emitOpError(
        "expected source and destination vectors of same rank");
  unsigned numId = 0;
  for (unsigned i = 0, e = op.getResultType().getRank(); i < e; i++) {
    if (op.getResultType().getDimSize(i) %
            op.getSourceVectorType().getDimSize(i) !=
        0)
      return op.emitOpError(
          "destination vector size must be a multiple of source vector size");
    if (op.getResultType().getDimSize(i) !=
        op.getSourceVectorType().getDimSize(i))
      numId++;
  }
  if (numId != op.ids().size())
    return op.emitOpError("expected number of ids to match the number of "
                          "dimensions distributed");
  return success();
}

AffineMap InsertMapOp::map() { return calculateImplicitMap(*this); }

//===----------------------------------------------------------------------===//
// InsertStridedSliceOp
//===----------------------------------------------------------------------===//

void InsertStridedSliceOp::build(OpBuilder &builder, OperationState &result,
                                 Value source, Value dest,
                                 ArrayRef<int64_t> offsets,
                                 ArrayRef<int64_t> strides) {
  result.addOperands({source, dest});
  auto offsetsAttr = getVectorSubscriptAttr(builder, offsets);
  auto stridesAttr = getVectorSubscriptAttr(builder, strides);
  result.addTypes(dest.getType());
  result.addAttribute(getOffsetsAttrName(), offsetsAttr);
  result.addAttribute(getStridesAttrName(), stridesAttr);
}

// TODO: Should be moved to Tablegen Confined attributes.
template <typename OpType>
static LogicalResult isIntegerArrayAttrSmallerThanShape(OpType op,
                                                        ArrayAttr arrayAttr,
                                                        ArrayRef<int64_t> shape,
                                                        StringRef attrName) {
  if (arrayAttr.size() > shape.size())
    return op.emitOpError("expected ")
           << attrName << " attribute of rank smaller than vector rank";
  return success();
}

// Succeeds if all integers in `arrayAttr` are in the admissible interval:
// [min, max) if `halfOpen` is true, [min, max] otherwise.
template <typename OpType>
static LogicalResult
isIntegerArrayAttrConfinedToRange(OpType op, ArrayAttr arrayAttr, int64_t min,
                                  int64_t max, StringRef attrName,
                                  bool halfOpen = true) {
  for (auto attr : arrayAttr) {
    auto val = attr.cast<IntegerAttr>().getInt();
    auto upper = max;
    if (!halfOpen)
      upper += 1;
    if (val < min || val >= upper)
      return op.emitOpError("expected ")
             << attrName << " to be confined to [" << min << ", " << upper
             << ")";
  }
  return success();
}

// Succeeds if all integers in `arrayAttr` are confined, per dimension, to the
// admissible interval: [min, shape[i]) if `halfOpen` is true,
// [min, shape[i]] otherwise.
template <typename OpType>
static LogicalResult
isIntegerArrayAttrConfinedToShape(OpType op, ArrayAttr arrayAttr,
                                  ArrayRef<int64_t> shape, StringRef attrName,
                                  bool halfOpen = true, int64_t min = 0) {
  assert(arrayAttr.size() <= shape.size());
  unsigned index = 0;
  for (auto it : llvm::zip(arrayAttr, shape)) {
    auto val = std::get<0>(it).cast<IntegerAttr>().getInt();
    auto max = std::get<1>(it);
    if (!halfOpen)
      max += 1;
    if (val < min || val >= max)
      return op.emitOpError("expected ")
             << attrName << " dimension " << index << " to be confined to ["
             << min << ", " << max << ")";
    ++index;
  }
  return success();
}

// Succeeds if, per dimension, the sum of the corresponding integers in
// `arrayAttr1` and `arrayAttr2` is confined to the admissible interval:
// [min, shape[i]) if `halfOpen` is true, [min, shape[i]] otherwise.
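// E.g. (illustrative) for insert_strided_slice, per-dimension offsets [2, 3]
// plus the source shape [4, 4] must stay within the destination shape [8, 8].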
template <typename OpType>
static LogicalResult isSumOfIntegerArrayAttrConfinedToShape(
    OpType op, ArrayAttr arrayAttr1, ArrayAttr arrayAttr2,
    ArrayRef<int64_t> shape, StringRef attrName1, StringRef attrName2,
    bool halfOpen = true, int64_t min = 1) {
  assert(arrayAttr1.size() <= shape.size());
  assert(arrayAttr2.size() <= shape.size());
  unsigned index = 0;
  for (auto it : llvm::zip(arrayAttr1, arrayAttr2, shape)) {
    auto val1 = std::get<0>(it).cast<IntegerAttr>().getInt();
    auto val2 = std::get<1>(it).cast<IntegerAttr>().getInt();
    auto max = std::get<2>(it);
    if (!halfOpen)
      max += 1;
    if (val1 + val2 < 0 || val1 + val2 >= max)
      return op.emitOpError("expected sum(")
             << attrName1 << ", " << attrName2 << ") dimension " << index
             << " to be confined to [" << min << ", " << max << ")";
    ++index;
  }
  return success();
}

static ArrayAttr makeI64ArrayAttr(ArrayRef<int64_t> values,
                                  MLIRContext *context) {
  auto attrs = llvm::map_range(values, [context](int64_t v) -> Attribute {
    return IntegerAttr::get(IntegerType::get(64, context), APInt(64, v));
  });
  return ArrayAttr::get(llvm::to_vector<8>(attrs), context);
}

static LogicalResult verify(InsertStridedSliceOp op) {
  auto sourceVectorType = op.getSourceVectorType();
  auto destVectorType = op.getDestVectorType();
  auto offsets = op.offsets();
  auto strides = op.strides();
  if (offsets.size() != static_cast<unsigned>(destVectorType.getRank()))
    return op.emitOpError(
        "expected offsets of same size as destination vector rank");
  if (strides.size() != static_cast<unsigned>(sourceVectorType.getRank()))
    return op.emitOpError(
        "expected strides of same size as source vector rank");
  if (sourceVectorType.getRank() > destVectorType.getRank())
    return op.emitOpError(
        "expected source rank to be smaller than destination rank");

  auto sourceShape = sourceVectorType.getShape();
  auto destShape = destVectorType.getShape();
  SmallVector<int64_t, 4> sourceShapeAsDestShape(
      destShape.size() - sourceShape.size(), 0);
  sourceShapeAsDestShape.append(sourceShape.begin(), sourceShape.end());
  auto offName = InsertStridedSliceOp::getOffsetsAttrName();
  auto stridesName = InsertStridedSliceOp::getStridesAttrName();
  if (failed(isIntegerArrayAttrConfinedToShape(op, offsets, destShape,
                                               offName)) ||
      failed(isIntegerArrayAttrConfinedToRange(op, strides, 1, 1, stridesName,
                                               /*halfOpen=*/false)) ||
      failed(isSumOfIntegerArrayAttrConfinedToShape(
          op, offsets,
          makeI64ArrayAttr(sourceShapeAsDestShape, op.getContext()), destShape,
          offName, "source vector shape",
          /*halfOpen=*/false, /*min=*/1)))
    return failure();

  return success();
}

//===----------------------------------------------------------------------===//
// OuterProductOp
//===----------------------------------------------------------------------===//

/// Build an op without mask, use the type of `acc` as the return type.
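/// Illustrative result shapes:
///   vector.outerproduct %a, %b, %acc : vector<4xf32>, vector<8xf32>
/// yields vector<4x8xf32>; with a scalar rhs (e.g. f32) the op degenerates
/// to an AXPY producing vector<4xf32>.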
void OuterProductOp::build(OpBuilder &builder, OperationState &result,
                           Value lhs, Value rhs, Value acc) {
  result.addOperands({lhs, rhs, acc});
  result.addTypes(acc.getType());
}

static void print(OpAsmPrinter &p, OuterProductOp op) {
  p << op.getOperationName() << " " << op.lhs() << ", " << op.rhs();
  if (!op.acc().empty())
    p << ", " << op.acc();
  p << " : " << op.lhs().getType() << ", " << op.rhs().getType();
}

static ParseResult parseOuterProductOp(OpAsmParser &parser,
                                       OperationState &result) {
  SmallVector<OpAsmParser::OperandType, 3> operandsInfo;
  Type tLHS, tRHS;
  if (parser.parseOperandList(operandsInfo) || parser.parseColonType(tLHS) ||
      parser.parseComma() || parser.parseType(tRHS))
    return failure();
  if (operandsInfo.size() < 2)
    return parser.emitError(parser.getNameLoc(),
                            "expected at least 2 operands");
  VectorType vLHS = tLHS.dyn_cast<VectorType>();
  VectorType vRHS = tRHS.dyn_cast<VectorType>();
  if (!vLHS)
    return parser.emitError(parser.getNameLoc(),
                            "expected vector type for operand #1");
  VectorType resType =
      vRHS ? VectorType::get({vLHS.getDimSize(0), vRHS.getDimSize(0)},
                             vLHS.getElementType())
           : VectorType::get({vLHS.getDimSize(0)}, vLHS.getElementType());
  return failure(
      parser.resolveOperand(operandsInfo[0], tLHS, result.operands) ||
      parser.resolveOperand(operandsInfo[1], tRHS, result.operands) ||
      (operandsInfo.size() > 2 &&
       parser.resolveOperand(operandsInfo[2], resType, result.operands)) ||
      parser.addTypeToList(resType, result.types));
}

static LogicalResult verify(OuterProductOp op) {
  Type tRHS = op.getOperandTypeRHS();
  VectorType vLHS = op.getOperandVectorTypeLHS(),
             vRHS = tRHS.dyn_cast<VectorType>(),
             vACC = op.getOperandVectorTypeACC(), vRES = op.getVectorType();

  if (vLHS.getRank() != 1)
    return op.emitOpError("expected 1-d vector for operand #1");

  if (vRHS) {
    // Proper OUTER operation.
    if (vRHS.getRank() != 1)
      return op.emitOpError("expected 1-d vector for operand #2");
    if (vRES.getRank() != 2)
      return op.emitOpError("expected 2-d vector result");
    if (vLHS.getDimSize(0) != vRES.getDimSize(0))
      return op.emitOpError("expected #1 operand dim to match result dim #1");
    if (vRHS.getDimSize(0) != vRES.getDimSize(1))
      return op.emitOpError("expected #2 operand dim to match result dim #2");
  } else {
    // An AXPY operation.
    if (vRES.getRank() != 1)
      return op.emitOpError("expected 1-d vector result");
    if (vLHS.getDimSize(0) != vRES.getDimSize(0))
      return op.emitOpError("expected #1 operand dim to match result dim #1");
  }

  if (vACC && vACC != vRES)
    return op.emitOpError("expected operand #3 of same type as result type");
  return success();
}

//===----------------------------------------------------------------------===//
// ReshapeOp
//===----------------------------------------------------------------------===//

static LogicalResult verify(ReshapeOp op) {
  // Verify that rank(numInputs/outputs) + numFixedVec dim matches vec rank.
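  // E.g. (illustrative) an input of vector<3x2x4xf32> with
  // fixed_vector_sizes = [4] expects 2 input shape sizes:
  // 2 (shape rank) + 1 (fixed) == 3 (vector rank).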
auto inputVectorType = op.getInputVectorType(); auto outputVectorType = op.getOutputVectorType(); int64_t inputShapeRank = op.getNumInputShapeSizes(); int64_t outputShapeRank = op.getNumOutputShapeSizes(); SmallVector<int64_t, 4> fixedVectorSizes; op.getFixedVectorSizes(fixedVectorSizes); int64_t numFixedVectorSizes = fixedVectorSizes.size(); if (inputVectorType.getRank() != inputShapeRank + numFixedVectorSizes) return op.emitError("invalid input shape for vector type ") << inputVectorType; if (outputVectorType.getRank() != outputShapeRank + numFixedVectorSizes) return op.emitError("invalid output shape for vector type ") << outputVectorType; // Verify that the 'fixedVectorSizes' match an input/output vector shape // suffix. unsigned inputVectorRank = inputVectorType.getRank(); for (unsigned i = 0; i < numFixedVectorSizes; ++i) { unsigned index = inputVectorRank - numFixedVectorSizes + i; if (fixedVectorSizes[i] != inputVectorType.getShape()[index]) return op.emitError("fixed vector size must match input vector for dim ") << i; } unsigned outputVectorRank = outputVectorType.getRank(); for (unsigned i = 0; i < numFixedVectorSizes; ++i) { unsigned index = outputVectorRank - numFixedVectorSizes + i; if (fixedVectorSizes[i] != outputVectorType.getShape()[index]) return op.emitError("fixed vector size must match output vector for dim ") << i; } // If all shape operands are produced by constant ops, verify that product // of dimensions for input/output shape match. auto isDefByConstant = [](Value operand) { return isa_and_nonnull<ConstantIndexOp>(operand.getDefiningOp()); }; if (llvm::all_of(op.input_shape(), isDefByConstant) && llvm::all_of(op.output_shape(), isDefByConstant)) { int64_t numInputElements = 1; for (auto operand : op.input_shape()) numInputElements *= cast<ConstantIndexOp>(operand.getDefiningOp()).getValue(); int64_t numOutputElements = 1; for (auto operand : op.output_shape()) numOutputElements *= cast<ConstantIndexOp>(operand.getDefiningOp()).getValue(); if (numInputElements != numOutputElements) return op.emitError("product of input and output shape sizes must match"); } return success(); } void ReshapeOp::getFixedVectorSizes(SmallVectorImpl<int64_t> &results) { populateFromInt64AttrArray(fixed_vector_sizes(), results); } //===----------------------------------------------------------------------===// // ExtractStridedSliceOp //===----------------------------------------------------------------------===// // Inference works as follows: // 1. Add 'sizes' from prefix of dims in 'offsets'. // 2. Add sizes from 'vectorType' for remaining dims.
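// E.g. (illustrative): for a source of type vector<4x8x16xf32> with
// offsets = [0, 2], sizes = [2, 4], strides = [1, 1], step 1 takes [2, 4]
// from 'sizes' and step 2 appends the remaining 16 from the vector type,
// yielding vector<2x4x16xf32>.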
static Type inferStridedSliceOpResultType(VectorType vectorType, ArrayAttr offsets, ArrayAttr sizes, ArrayAttr strides) { assert(offsets.size() == sizes.size() && offsets.size() == strides.size()); SmallVector<int64_t, 4> shape; shape.reserve(vectorType.getRank()); unsigned idx = 0; for (unsigned e = offsets.size(); idx < e; ++idx) shape.push_back(sizes[idx].cast<IntegerAttr>().getInt()); for (unsigned e = vectorType.getShape().size(); idx < e; ++idx) shape.push_back(vectorType.getShape()[idx]); return VectorType::get(shape, vectorType.getElementType()); } void ExtractStridedSliceOp::build(OpBuilder &builder, OperationState &result, Value source, ArrayRef<int64_t> offsets, ArrayRef<int64_t> sizes, ArrayRef<int64_t> strides) { result.addOperands(source); auto offsetsAttr = getVectorSubscriptAttr(builder, offsets); auto sizesAttr = getVectorSubscriptAttr(builder, sizes); auto stridesAttr = getVectorSubscriptAttr(builder, strides); result.addTypes( inferStridedSliceOpResultType(source.getType().cast<VectorType>(), offsetsAttr, sizesAttr, stridesAttr)); result.addAttribute(getOffsetsAttrName(), offsetsAttr); result.addAttribute(getSizesAttrName(), sizesAttr); result.addAttribute(getStridesAttrName(), stridesAttr); } static LogicalResult verify(ExtractStridedSliceOp op) { auto type = op.getVectorType(); auto offsets = op.offsets(); auto sizes = op.sizes(); auto strides = op.strides(); if (offsets.size() != sizes.size() || offsets.size() != strides.size()) { op.emitOpError( "expected offsets, sizes and strides attributes of same size"); return failure(); } auto shape = type.getShape(); auto offName = ExtractStridedSliceOp::getOffsetsAttrName(); auto sizesName = ExtractStridedSliceOp::getSizesAttrName(); auto stridesName = ExtractStridedSliceOp::getStridesAttrName(); if (failed(isIntegerArrayAttrSmallerThanShape(op, offsets, shape, offName)) || failed(isIntegerArrayAttrSmallerThanShape(op, sizes, shape, sizesName)) || failed(isIntegerArrayAttrSmallerThanShape(op, strides, shape, stridesName)) || failed(isIntegerArrayAttrConfinedToShape(op, offsets, shape, offName)) || failed(isIntegerArrayAttrConfinedToShape(op, sizes, shape, sizesName, /*halfOpen=*/false, /*min=*/1)) || failed(isIntegerArrayAttrConfinedToRange(op, strides, 1, 1, stridesName, /*halfOpen=*/false)) || failed(isSumOfIntegerArrayAttrConfinedToShape(op, offsets, sizes, shape, offName, sizesName, /*halfOpen=*/false))) return failure(); auto resultType = inferStridedSliceOpResultType( op.getVectorType(), op.offsets(), op.sizes(), op.strides()); if (op.getResult().getType() != resultType) { op.emitOpError("expected result type to be ") << resultType; return failure(); } return success(); } // When the source of ExtractStrided comes from a chain of InsertStrided ops, // try to use the source of the InsertStrided ops if we can detect that the // extracted vector is a subset of one of the inserted vectors. static LogicalResult foldExtractStridedOpFromInsertChain(ExtractStridedSliceOp op) { // Helper to extract integer out of ArrayAttr.
auto getElement = [](ArrayAttr array, int idx) { return array[idx].cast<IntegerAttr>().getInt(); }; ArrayAttr extractOffsets = op.offsets(); ArrayAttr extractStrides = op.strides(); ArrayAttr extractSizes = op.sizes(); auto insertOp = op.vector().getDefiningOp<InsertStridedSliceOp>(); while (insertOp) { if (op.getVectorType().getRank() != insertOp.getSourceVectorType().getRank()) return failure(); ArrayAttr insertOffsets = insertOp.offsets(); ArrayAttr insertStrides = insertOp.strides(); // If the rank of extract is greater than the rank of insert, we are likely // extracting a partial chunk of the vector inserted. if (extractOffsets.size() > insertOffsets.size()) return failure(); bool partialOverlap = false; bool disjoint = false; SmallVector<int64_t, 4> offsetDiffs; for (unsigned dim = 0, e = extractOffsets.size(); dim < e; ++dim) { if (getElement(extractStrides, dim) != getElement(insertStrides, dim)) return failure(); int64_t start = getElement(insertOffsets, dim); int64_t end = start + insertOp.getSourceVectorType().getDimSize(dim); int64_t offset = getElement(extractOffsets, dim); int64_t size = getElement(extractSizes, dim); // Check if the start of the extract offset is in the interval inserted. if (start <= offset && offset < end) { // If the extract interval overlaps but is not fully included we may // have a partial overlap that will prevent any folding. if (offset + size > end) partialOverlap = true; offsetDiffs.push_back(offset - start); continue; } disjoint = true; break; } // The extracted element chunk is a subset of the inserted element chunk. if (!disjoint && !partialOverlap) { op.setOperand(insertOp.source()); // OpBuilder is only used as a helper to build an I64ArrayAttr. OpBuilder b(op.getContext()); op.setAttr(ExtractStridedSliceOp::getOffsetsAttrName(), b.getI64ArrayAttr(offsetDiffs)); return success(); } // If the chunk extracted is disjoint from the chunk inserted, keep looking // in the insert chain. if (disjoint) insertOp = insertOp.dest().getDefiningOp<InsertStridedSliceOp>(); else { // The extracted vector partially overlaps the inserted vector; we cannot // fold. return failure(); } } return failure(); } OpFoldResult ExtractStridedSliceOp::fold(ArrayRef<Attribute> operands) { if (getVectorType() == getResult().getType()) return vector(); if (succeeded(foldExtractStridedOpFromInsertChain(*this))) return getResult(); return {}; } void ExtractStridedSliceOp::getOffsets(SmallVectorImpl<int64_t> &results) { populateFromInt64AttrArray(offsets(), results); } namespace { // Pattern to rewrite a ExtractStridedSliceOp(ConstantMaskOp) -> ConstantMaskOp. class StridedSliceConstantMaskFolder final : public OpRewritePattern<ExtractStridedSliceOp> { public: using OpRewritePattern<ExtractStridedSliceOp>::OpRewritePattern; LogicalResult matchAndRewrite(ExtractStridedSliceOp extractStridedSliceOp, PatternRewriter &rewriter) const override { // Return if 'extractStridedSliceOp' operand is not defined by a // ConstantMaskOp. auto defOp = extractStridedSliceOp.vector().getDefiningOp(); auto constantMaskOp = dyn_cast_or_null<ConstantMaskOp>(defOp); if (!constantMaskOp) return failure(); // Return if 'extractStridedSliceOp' has non-unit strides. if (llvm::any_of(extractStridedSliceOp.strides(), [](Attribute attr) { return attr.cast<IntegerAttr>().getInt() != 1; })) return failure(); // Gather constant mask dimension sizes. SmallVector<int64_t, 4> maskDimSizes; populateFromInt64AttrArray(constantMaskOp.mask_dim_sizes(), maskDimSizes); // Gather strided slice offsets and sizes.
SmallVector<int64_t, 4> sliceOffsets; populateFromInt64AttrArray(extractStridedSliceOp.offsets(), sliceOffsets); SmallVector<int64_t, 4> sliceSizes; populateFromInt64AttrArray(extractStridedSliceOp.sizes(), sliceSizes); // Compute slice of vector mask region. SmallVector<int64_t, 4> sliceMaskDimSizes; assert(sliceOffsets.size() == maskDimSizes.size()); for (auto it : llvm::zip(maskDimSizes, sliceOffsets, sliceSizes)) { int64_t maskDimSize = std::get<0>(it); int64_t sliceOffset = std::get<1>(it); int64_t sliceSize = std::get<2>(it); int64_t sliceMaskDimSize = std::max( static_cast<int64_t>(0), std::min(sliceOffset + sliceSize, maskDimSize) - sliceOffset); sliceMaskDimSizes.push_back(sliceMaskDimSize); } // If any of 'sliceMaskDimSizes' are zero, then set all to zero (masked // region is a conjunction of mask dim intervals). if (llvm::any_of(sliceMaskDimSizes, [](int64_t sz) { return sz == 0; })) sliceMaskDimSizes.assign(maskDimSizes.size(), 0); // Replace 'extractStridedSliceOp' with ConstantMaskOp with sliced mask // region. rewriter.replaceOpWithNewOp<ConstantMaskOp>( extractStridedSliceOp, extractStridedSliceOp.getResult().getType(), vector::getVectorSubscriptAttr(rewriter, sliceMaskDimSizes)); return success(); } }; // Pattern to rewrite a ExtractStridedSliceOp(splat ConstantOp) -> ConstantOp. class StridedSliceConstantFolder final : public OpRewritePattern<ExtractStridedSliceOp> { public: using OpRewritePattern<ExtractStridedSliceOp>::OpRewritePattern; LogicalResult matchAndRewrite(ExtractStridedSliceOp extractStridedSliceOp, PatternRewriter &rewriter) const override { // Return if 'extractStridedSliceOp' operand is not defined by a // ConstantOp. auto constantOp = extractStridedSliceOp.vector().getDefiningOp<ConstantOp>(); if (!constantOp) return failure(); auto dense = constantOp.value().dyn_cast<SplatElementsAttr>(); if (!dense) return failure(); auto newAttr = DenseElementsAttr::get( extractStridedSliceOp.getType().cast<VectorType>(), dense.getSplatValue()); rewriter.replaceOpWithNewOp<ConstantOp>(extractStridedSliceOp, newAttr); return success(); } }; } // end anonymous namespace void ExtractStridedSliceOp::getCanonicalizationPatterns( OwningRewritePatternList &results, MLIRContext *context) { // Pattern to rewrite a ExtractStridedSliceOp(ConstantMaskOp) -> // ConstantMaskOp and ExtractStridedSliceOp(ConstantOp) -> ConstantOp. 
results.insert<StridedSliceConstantMaskFolder, StridedSliceConstantFolder>( context); } //===----------------------------------------------------------------------===// // TransferReadOp //===----------------------------------------------------------------------===// template <typename EmitFun> static LogicalResult verifyPermutationMap(AffineMap permutationMap, EmitFun emitOpError) { SmallVector<bool, 8> seen(permutationMap.getNumInputs(), false); for (auto expr : permutationMap.getResults()) { auto dim = expr.dyn_cast<AffineDimExpr>(); auto zero = expr.dyn_cast<AffineConstantExpr>(); if (zero) { if (zero.getValue() != 0) { return emitOpError( "requires a projected permutation_map (at most one dim or the zero " "constant can appear in each result)"); } continue; } if (!dim) { return emitOpError("requires a projected permutation_map (at most one " "dim or the zero constant can appear in each result)"); } if (seen[dim.getPosition()]) { return emitOpError( "requires a permutation_map that is a permutation (found one dim " "used more than once)"); } seen[dim.getPosition()] = true; } return success(); } static LogicalResult verifyTransferOp(Operation *op, MemRefType memrefType, VectorType vectorType, AffineMap permutationMap, ArrayAttr optionalMasked) { auto memrefElementType = memrefType.getElementType(); if (auto memrefVectorElementType = memrefElementType.dyn_cast<VectorType>()) { // Memref has vector element type. unsigned memrefVecSize = memrefVectorElementType.getElementTypeBitWidth() * memrefVectorElementType.getShape().back(); unsigned resultVecSize = vectorType.getElementTypeBitWidth() * vectorType.getShape().back(); if (resultVecSize % memrefVecSize != 0) return op->emitOpError( "requires the bitwidth of the minor 1-D vector to be an integral " "multiple of the bitwidth of the minor 1-D vector of the memref"); unsigned memrefVecEltRank = memrefVectorElementType.getRank(); unsigned resultVecRank = vectorType.getRank(); if (memrefVecEltRank > resultVecRank) return op->emitOpError( "requires memref vector element and vector result ranks to match."); unsigned rankOffset = resultVecRank - memrefVecEltRank; // Check that permutation map results match 'rankOffset' of vector type. if (permutationMap.getNumResults() != rankOffset) return op->emitOpError("requires a permutation_map with result dims of " "the same rank as the vector type"); } else { // Memref has scalar element type. unsigned resultVecSize = vectorType.getElementTypeBitWidth() * vectorType.getShape().back(); if (resultVecSize % memrefElementType.getIntOrFloatBitWidth() != 0) return op->emitOpError( "requires the bitwidth of the minor 1-D vector to be an integral " "multiple of the bitwidth of the memref element type"); // Check that permutation map results match rank of vector type. if (permutationMap.getNumResults() != vectorType.getRank()) return op->emitOpError("requires a permutation_map with result dims of " "the same rank as the vector type"); } if (permutationMap.getNumSymbols() != 0) return op->emitOpError("requires permutation_map without symbols"); if (permutationMap.getNumInputs() != memrefType.getRank()) return op->emitOpError("requires a permutation_map with input dims of the " "same rank as the memref type"); if (optionalMasked) { if (permutationMap.getNumResults() != static_cast<int64_t>(optionalMasked.size())) return op->emitOpError("expects the optional masked attr of same rank as " "permutation_map results: ") << AffineMapAttr::get(permutationMap); } return success(); } /// Builder that sets padding to zero. 
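/// E.g. (illustrative, with hypothetical SSA names): reading a vector<8xf32>
/// from a memref<?xf32> with this builder produces the equivalent of:
///   %pad = constant 0.0 : f32
///   %v = vector.transfer_read %mem[%i], %pad : memref<?xf32>, vector<8xf32>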
void TransferReadOp::build(OpBuilder &builder, OperationState &result, VectorType vector, Value memref, ValueRange indices, AffineMap permutationMap, ArrayRef<bool> maybeMasked) { Type elemType = memref.getType().cast<MemRefType>().getElementType(); Value padding = builder.create<ConstantOp>(result.location, elemType, builder.getZeroAttr(elemType)); if (maybeMasked.empty()) return build(builder, result, vector, memref, indices, permutationMap, padding, ArrayAttr()); ArrayAttr maskedArrayAttr = builder.getBoolArrayAttr(maybeMasked); build(builder, result, vector, memref, indices, permutationMap, padding, maskedArrayAttr); } /// Builder that sets permutation map (resp. padding) to 'getMinorIdentityMap' /// (resp. zero). void TransferReadOp::build(OpBuilder &builder, OperationState &result, VectorType vectorType, Value memref, ValueRange indices, ArrayRef<bool> maybeMasked) { auto permMap = getTransferMinorIdentityMap( memref.getType().cast<MemRefType>(), vectorType); build(builder, result, vectorType, memref, indices, permMap, maybeMasked); } static void printTransferAttrs(OpAsmPrinter &p, VectorTransferOpInterface op) { SmallVector<StringRef, 2> elidedAttrs; if (op.permutation_map() == getTransferMinorIdentityMap(op.getMemRefType(), op.getVectorType())) elidedAttrs.push_back(op.getPermutationMapAttrName()); bool elideMasked = true; if (auto maybeMasked = op.masked()) { for (auto attr : *maybeMasked) { if (!attr.template cast<BoolAttr>().getValue()) { elideMasked = false; break; } } } if (elideMasked) elidedAttrs.push_back(op.getMaskedAttrName()); p.printOptionalAttrDict(op.getAttrs(), elidedAttrs); } static void print(OpAsmPrinter &p, TransferReadOp op) { p << op.getOperationName() << " " << op.memref() << "[" << op.indices() << "], " << op.padding(); printTransferAttrs(p, cast<VectorTransferOpInterface>(op.getOperation())); p << " : " << op.getMemRefType() << ", " << op.getVectorType(); } static ParseResult parseTransferReadOp(OpAsmParser &parser, OperationState &result) { llvm::SMLoc typesLoc; OpAsmParser::OperandType memrefInfo; SmallVector<OpAsmParser::OperandType, 8> indexInfo; OpAsmParser::OperandType paddingInfo; SmallVector<Type, 2> types; // Parsing with support for paddingValue. 
if (parser.parseOperand(memrefInfo) || parser.parseOperandList(indexInfo, OpAsmParser::Delimiter::Square) || parser.parseComma() || parser.parseOperand(paddingInfo) || parser.parseOptionalAttrDict(result.attributes) || parser.getCurrentLocation(&typesLoc) || parser.parseColonTypeList(types)) return failure(); if (types.size() != 2) return parser.emitError(typesLoc, "requires two types"); auto indexType = parser.getBuilder().getIndexType(); MemRefType memRefType = types[0].dyn_cast<MemRefType>(); if (!memRefType) return parser.emitError(typesLoc, "requires memref type"); VectorType vectorType = types[1].dyn_cast<VectorType>(); if (!vectorType) return parser.emitError(typesLoc, "requires vector type"); auto permutationAttrName = TransferReadOp::getPermutationMapAttrName(); auto attr = result.attributes.get(permutationAttrName); if (!attr) { auto permMap = getTransferMinorIdentityMap(memRefType, vectorType); result.attributes.set(permutationAttrName, AffineMapAttr::get(permMap)); } return failure( parser.resolveOperand(memrefInfo, memRefType, result.operands) || parser.resolveOperands(indexInfo, indexType, result.operands) || parser.resolveOperand(paddingInfo, memRefType.getElementType(), result.operands) || parser.addTypeToList(vectorType, result.types)); } static LogicalResult verify(TransferReadOp op) { // Consistency of elemental types in memref and vector. MemRefType memrefType = op.getMemRefType(); VectorType vectorType = op.getVectorType(); auto paddingType = op.padding().getType(); auto permutationMap = op.permutation_map(); auto memrefElementType = memrefType.getElementType(); if (static_cast<int64_t>(op.indices().size()) != memrefType.getRank()) return op.emitOpError("requires ") << memrefType.getRank() << " indices"; if (failed(verifyTransferOp(op.getOperation(), memrefType, vectorType, permutationMap, op.masked() ? *op.masked() : ArrayAttr()))) return failure(); if (auto memrefVectorElementType = memrefElementType.dyn_cast<VectorType>()) { // Memref has vector element type. // Check that 'memrefVectorElementType' and 'paddingType' types match. if (memrefVectorElementType != paddingType) return op.emitOpError( "requires memref element type and padding type to match."); } else { // Check that 'paddingType' is valid to store in a vector type. if (!VectorType::isValidElementType(paddingType)) return op.emitOpError("requires valid padding vector elemental type"); // Check that padding type and vector element types match. if (paddingType != memrefElementType) return op.emitOpError( "requires formal padding and memref of the same elemental type"); } return verifyPermutationMap(permutationMap, [&op](Twine t) { return op.emitOpError(t); }); } /// This is a common class used for patterns of the form /// ``` /// someop(memrefcast) -> someop /// ``` /// It folds the source of the memref_cast into the root operation directly. 
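/// E.g. (illustrative, with hypothetical SSA names): given
///   %0 = memref_cast %A : memref<8xf32> to memref<?xf32>
/// a consuming vector.transfer_read of %0 is rewritten to read directly from
/// %A, since the cast only erases static shape information.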
static LogicalResult foldMemRefCast(Operation *op) { bool folded = false; for (OpOperand &operand : op->getOpOperands()) { auto castOp = operand.get().getDefiningOp<MemRefCastOp>(); if (castOp && canFoldIntoConsumerOp(castOp)) { operand.set(castOp.getOperand()); folded = true; } } return success(folded); } template <typename TransferOp> static bool isInBounds(TransferOp op, int64_t resultIdx, int64_t indicesIdx) { // TODO: support more aggressive createOrFold on: // `op.indices()[indicesIdx] + vectorType < dim(op.memref(), indicesIdx)` if (op.getMemRefType().isDynamicDim(indicesIdx)) return false; Value index = op.indices()[indicesIdx]; auto cstOp = index.getDefiningOp<ConstantIndexOp>(); if (!cstOp) return false; int64_t memrefSize = op.getMemRefType().getDimSize(indicesIdx); int64_t vectorSize = op.getVectorType().getDimSize(resultIdx); return cstOp.getValue() + vectorSize <= memrefSize; } template <typename TransferOp> static LogicalResult foldTransferMaskAttribute(TransferOp op) { AffineMap permutationMap = op.permutation_map(); if (!permutationMap.isMinorIdentity()) return failure(); bool changed = false; SmallVector<bool, 4> isMasked; isMasked.reserve(op.getTransferRank()); op.zipResultAndIndexing([&](int64_t resultIdx, int64_t indicesIdx) { // Already marked unmasked, nothing to see here. if (!op.isMaskedDim(resultIdx)) { isMasked.push_back(false); return; } // Currently masked, check whether we can statically determine it is // inBounds. auto inBounds = isInBounds(op, resultIdx, indicesIdx); isMasked.push_back(!inBounds); // We commit the pattern if it is "more inbounds". changed |= inBounds; }); if (!changed) return failure(); // OpBuilder is only used as a helper to build an I64ArrayAttr. OpBuilder b(op.getContext()); op.setAttr(TransferOp::getMaskedAttrName(), b.getBoolArrayAttr(isMasked)); return success(); } OpFoldResult TransferReadOp::fold(ArrayRef<Attribute>) { /// transfer_read(memrefcast) -> transfer_read if (succeeded(foldTransferMaskAttribute(*this))) return getResult(); if (succeeded(foldMemRefCast(*this))) return getResult(); return OpFoldResult(); } Optional<SmallVector<int64_t, 4>> TransferReadOp::getShapeForUnroll() { auto s = getVectorType().getShape(); return SmallVector<int64_t, 4>{s.begin(), s.end()}; } //===----------------------------------------------------------------------===// // TransferWriteOp //===----------------------------------------------------------------------===// /// Builder that sets permutation map to 'getMinorIdentityMap'. 
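/// E.g. (illustrative): for a memref of rank 3 and a vector of rank 2, the
/// inferred minor identity map is (d0, d1, d2) -> (d1, d2).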
void TransferWriteOp::build(OpBuilder &builder, OperationState &result, Value vector, Value memref, ValueRange indices, ArrayRef<bool> maybeMasked) { auto vectorType = vector.getType().cast<VectorType>(); auto permMap = getTransferMinorIdentityMap( memref.getType().cast<MemRefType>(), vectorType); if (maybeMasked.empty()) return build(builder, result, vector, memref, indices, permMap, ArrayAttr()); ArrayAttr maskedArrayAttr = builder.getBoolArrayAttr(maybeMasked); build(builder, result, vector, memref, indices, permMap, maskedArrayAttr); } void TransferWriteOp::build(OpBuilder &builder, OperationState &result, Value vector, Value memref, ValueRange indices, AffineMap permutationMap) { build(builder, result, vector, memref, indices, permutationMap, /*maybeMasked=*/ArrayAttr()); } static ParseResult parseTransferWriteOp(OpAsmParser &parser, OperationState &result) { llvm::SMLoc typesLoc; OpAsmParser::OperandType vectorInfo, memrefInfo; SmallVector<OpAsmParser::OperandType, 8> indexInfo; SmallVector<Type, 2> types; if (parser.parseOperand(vectorInfo) || parser.parseComma() || parser.parseOperand(memrefInfo) || parser.parseOperandList(indexInfo, OpAsmParser::Delimiter::Square) || parser.parseOptionalAttrDict(result.attributes) || parser.getCurrentLocation(&typesLoc) || parser.parseColonTypeList(types)) return failure(); if (types.size() != 2) return parser.emitError(typesLoc, "requires two types"); auto indexType = parser.getBuilder().getIndexType(); VectorType vectorType = types[0].dyn_cast<VectorType>(); if (!vectorType) return parser.emitError(typesLoc, "requires vector type"); MemRefType memRefType = types[1].dyn_cast<MemRefType>(); if (!memRefType) return parser.emitError(typesLoc, "requires memref type"); auto permutationAttrName = TransferWriteOp::getPermutationMapAttrName(); auto attr = result.attributes.get(permutationAttrName); if (!attr) { auto permMap = getTransferMinorIdentityMap(memRefType, vectorType); result.attributes.set(permutationAttrName, AffineMapAttr::get(permMap)); } return failure( parser.resolveOperand(vectorInfo, vectorType, result.operands) || parser.resolveOperand(memrefInfo, memRefType, result.operands) || parser.resolveOperands(indexInfo, indexType, result.operands)); } static void print(OpAsmPrinter &p, TransferWriteOp op) { p << op.getOperationName() << " " << op.vector() << ", " << op.memref() << "[" << op.indices() << "]"; printTransferAttrs(p, cast<VectorTransferOpInterface>(op.getOperation())); p << " : " << op.getVectorType() << ", " << op.getMemRefType(); } static LogicalResult verify(TransferWriteOp op) { // Consistency of elemental types in memref and vector. MemRefType memrefType = op.getMemRefType(); VectorType vectorType = op.getVectorType(); auto permutationMap = op.permutation_map(); if (llvm::size(op.indices()) != memrefType.getRank()) return op.emitOpError("requires ") << memrefType.getRank() << " indices"; if (failed(verifyTransferOp(op.getOperation(), memrefType, vectorType, permutationMap, op.masked() ? 
*op.masked() : ArrayAttr()))) return failure(); return verifyPermutationMap(permutationMap, [&op](Twine t) { return op.emitOpError(t); }); } LogicalResult TransferWriteOp::fold(ArrayRef<Attribute>, SmallVectorImpl<OpFoldResult> &) { if (succeeded(foldTransferMaskAttribute(*this))) return success(); return foldMemRefCast(*this); } Optional<SmallVector<int64_t, 4>> TransferWriteOp::getShapeForUnroll() { return llvm::to_vector<4>(getVectorType().getShape()); } //===----------------------------------------------------------------------===// // MaskedLoadOp //===----------------------------------------------------------------------===// static LogicalResult verify(MaskedLoadOp op) { VectorType maskVType = op.getMaskVectorType(); VectorType passVType = op.getPassThruVectorType(); VectorType resVType = op.getResultVectorType(); if (resVType.getElementType() != op.getMemRefType().getElementType()) return op.emitOpError("base and result element type should match"); if (resVType.getDimSize(0) != maskVType.getDimSize(0)) return op.emitOpError("expected result dim to match mask dim"); if (resVType != passVType) return op.emitOpError("expected pass_thru of same type as result type"); return success(); } namespace { class MaskedLoadFolder final : public OpRewritePattern<MaskedLoadOp> { public: using OpRewritePattern<MaskedLoadOp>::OpRewritePattern; LogicalResult matchAndRewrite(MaskedLoadOp load, PatternRewriter &rewriter) const override { Value newBase; switch (get1DMaskFormat(load.mask())) { case MaskFormat::AllTrue: if (!castedToMemRef(load.getLoc(), load.base(), load.getMemRefType(), load.getResultVectorType(), rewriter, newBase)) return failure(); rewriter.replaceOpWithNewOp<LoadOp>(load, newBase); return success(); case MaskFormat::AllFalse: rewriter.replaceOp(load, load.pass_thru()); return success(); case MaskFormat::Unknown: return failure(); } llvm_unreachable("Unexpected 1DMaskFormat on MaskedLoad"); } }; } // namespace void MaskedLoadOp::getCanonicalizationPatterns( OwningRewritePatternList &results, MLIRContext *context) { results.insert<MaskedLoadFolder>(context); } //===----------------------------------------------------------------------===// // MaskedStoreOp //===----------------------------------------------------------------------===// static LogicalResult verify(MaskedStoreOp op) { VectorType maskVType = op.getMaskVectorType(); VectorType valueVType = op.getValueVectorType(); if (valueVType.getElementType() != op.getMemRefType().getElementType()) return op.emitOpError("base and value element type should match"); if (valueVType.getDimSize(0) != maskVType.getDimSize(0)) return op.emitOpError("expected value dim to match mask dim"); return success(); } namespace { class MaskedStoreFolder final : public OpRewritePattern<MaskedStoreOp> { public: using OpRewritePattern<MaskedStoreOp>::OpRewritePattern; LogicalResult matchAndRewrite(MaskedStoreOp store, PatternRewriter &rewriter) const override { Value newBase; switch (get1DMaskFormat(store.mask())) { case MaskFormat::AllTrue: if (!castedToMemRef(store.getLoc(), store.base(), store.getMemRefType(), store.getValueVectorType(), rewriter, newBase)) return failure(); rewriter.replaceOpWithNewOp<StoreOp>(store, store.value(), newBase); return success(); case MaskFormat::AllFalse: rewriter.eraseOp(store); return success(); case MaskFormat::Unknown: return failure(); } llvm_unreachable("Unexpected 1DMaskFormat on MaskedStore"); } }; } // namespace void MaskedStoreOp::getCanonicalizationPatterns( OwningRewritePatternList &results, MLIRContext 
*context) { results.insert<MaskedStoreFolder>(context); } //===----------------------------------------------------------------------===// // GatherOp //===----------------------------------------------------------------------===// static LogicalResult verify(GatherOp op) { VectorType indicesVType = op.getIndicesVectorType(); VectorType maskVType = op.getMaskVectorType(); VectorType resVType = op.getResultVectorType(); if (resVType.getElementType() != op.getMemRefType().getElementType()) return op.emitOpError("base and result element type should match"); if (resVType.getDimSize(0) != indicesVType.getDimSize(0)) return op.emitOpError("expected result dim to match indices dim"); if (resVType.getDimSize(0) != maskVType.getDimSize(0)) return op.emitOpError("expected result dim to match mask dim"); if (llvm::size(op.pass_thru()) != 0) { VectorType passVType = op.getPassThruVectorType(); if (resVType != passVType) return op.emitOpError("expected pass_thru of same type as result type"); } return success(); } namespace { class GatherFolder final : public OpRewritePattern<GatherOp> { public: using OpRewritePattern<GatherOp>::OpRewritePattern; LogicalResult matchAndRewrite(GatherOp gather, PatternRewriter &rewriter) const override { switch (get1DMaskFormat(gather.mask())) { case MaskFormat::AllTrue: return failure(); // no unmasked equivalent case MaskFormat::AllFalse: rewriter.replaceOp(gather, gather.pass_thru()); return success(); case MaskFormat::Unknown: return failure(); } llvm_unreachable("Unexpected 1DMaskFormat on GatherFolder"); } }; } // namespace void GatherOp::getCanonicalizationPatterns(OwningRewritePatternList &results, MLIRContext *context) { results.insert<GatherFolder>(context); } //===----------------------------------------------------------------------===// // ScatterOp //===----------------------------------------------------------------------===// static LogicalResult verify(ScatterOp op) { VectorType indicesVType = op.getIndicesVectorType(); VectorType maskVType = op.getMaskVectorType(); VectorType valueVType = op.getValueVectorType(); if (valueVType.getElementType() != op.getMemRefType().getElementType()) return op.emitOpError("base and value element type should match"); if (valueVType.getDimSize(0) != indicesVType.getDimSize(0)) return op.emitOpError("expected value dim to match indices dim"); if (valueVType.getDimSize(0) != maskVType.getDimSize(0)) return op.emitOpError("expected value dim to match mask dim"); return success(); } namespace { class ScatterFolder final : public OpRewritePattern<ScatterOp> { public: using OpRewritePattern<ScatterOp>::OpRewritePattern; LogicalResult matchAndRewrite(ScatterOp scatter, PatternRewriter &rewriter) const override { switch (get1DMaskFormat(scatter.mask())) { case MaskFormat::AllTrue: return failure(); // no unmasked equivalent case MaskFormat::AllFalse: rewriter.eraseOp(scatter); return success(); case MaskFormat::Unknown: return failure(); } llvm_unreachable("Unexpected 1DMaskFormat on ScatterFolder"); } }; } // namespace void ScatterOp::getCanonicalizationPatterns(OwningRewritePatternList &results, MLIRContext *context) { results.insert<ScatterFolder>(context); } //===----------------------------------------------------------------------===// // ExpandLoadOp //===----------------------------------------------------------------------===// static LogicalResult verify(ExpandLoadOp op) { VectorType maskVType = op.getMaskVectorType(); VectorType passVType = op.getPassThruVectorType(); VectorType resVType = op.getResultVectorType(); if 
(resVType.getElementType() != op.getMemRefType().getElementType()) return op.emitOpError("base and result element type should match"); if (resVType.getDimSize(0) != maskVType.getDimSize(0)) return op.emitOpError("expected result dim to match mask dim"); if (resVType != passVType) return op.emitOpError("expected pass_thru of same type as result type"); return success(); } namespace { class ExpandLoadFolder final : public OpRewritePattern<ExpandLoadOp> { public: using OpRewritePattern<ExpandLoadOp>::OpRewritePattern; LogicalResult matchAndRewrite(ExpandLoadOp expand, PatternRewriter &rewriter) const override { Value newBase; switch (get1DMaskFormat(expand.mask())) { case MaskFormat::AllTrue: if (!castedToMemRef(expand.getLoc(), expand.base(), expand.getMemRefType(), expand.getResultVectorType(), rewriter, newBase)) return failure(); rewriter.replaceOpWithNewOp<LoadOp>(expand, newBase); return success(); case MaskFormat::AllFalse: rewriter.replaceOp(expand, expand.pass_thru()); return success(); case MaskFormat::Unknown: return failure(); } llvm_unreachable("Unexpected 1DMaskFormat on ExpandLoadFolder"); } }; } // namespace void ExpandLoadOp::getCanonicalizationPatterns( OwningRewritePatternList &results, MLIRContext *context) { results.insert<ExpandLoadFolder>(context); } //===----------------------------------------------------------------------===// // CompressStoreOp //===----------------------------------------------------------------------===// static LogicalResult verify(CompressStoreOp op) { VectorType maskVType = op.getMaskVectorType(); VectorType valueVType = op.getValueVectorType(); if (valueVType.getElementType() != op.getMemRefType().getElementType()) return op.emitOpError("base and value element type should match"); if (valueVType.getDimSize(0) != maskVType.getDimSize(0)) return op.emitOpError("expected value dim to match mask dim"); return success(); } namespace { class CompressStoreFolder final : public OpRewritePattern<CompressStoreOp> { public: using OpRewritePattern<CompressStoreOp>::OpRewritePattern; LogicalResult matchAndRewrite(CompressStoreOp compress, PatternRewriter &rewriter) const override { Value newBase; switch (get1DMaskFormat(compress.mask())) { case MaskFormat::AllTrue: if (!castedToMemRef(compress.getLoc(), compress.base(), compress.getMemRefType(), compress.getValueVectorType(), rewriter, newBase)) return failure(); rewriter.replaceOpWithNewOp<StoreOp>(compress, compress.value(), newBase); return success(); case MaskFormat::AllFalse: rewriter.eraseOp(compress); return success(); case MaskFormat::Unknown: return failure(); } llvm_unreachable("Unexpected 1DMaskFormat on CompressStoreFolder"); } }; } // namespace void CompressStoreOp::getCanonicalizationPatterns( OwningRewritePatternList &results, MLIRContext *context) { results.insert<CompressStoreFolder>(context); } //===----------------------------------------------------------------------===// // ShapeCastOp //===----------------------------------------------------------------------===// /// Returns true if each element of 'a' is equal to the product of a contiguous /// sequence of the elements of 'b'. Returns false otherwise. 
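/// E.g. (illustrative): isValidShapeCast({4, 6}, {2, 2, 6}) is true since
/// 4 == 2 * 2 and 6 == 6, while isValidShapeCast({4, 6}, {2, 3, 4}) is false
/// since no contiguous run of {2, 3, 4} multiplies to exactly 4.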
static bool isValidShapeCast(ArrayRef<int64_t> a, ArrayRef<int64_t> b) { unsigned rankA = a.size(); unsigned rankB = b.size(); assert(rankA < rankB); unsigned i = 0; unsigned j = 0; while (i < rankA && j < rankB) { int64_t dimA = a[i]; int64_t dimB = 1; while (dimB < dimA && j < rankB) dimB *= b[j++]; if (dimA != dimB) break; ++i; // Handle the case when trailing dimensions are of size 1. // Include them into the contiguous sequence. auto isOne = [](int64_t v) { return v == 1; }; if (i < rankA && llvm::all_of(a.slice(i), isOne)) i = rankA; if (j < rankB && llvm::all_of(b.slice(j), isOne)) j = rankB; } return i == rankA && j == rankB; } static LogicalResult verifyVectorShapeCast(Operation *op, VectorType sourceVectorType, VectorType resultVectorType) { // Check that element type is the same. if (sourceVectorType.getElementType() != resultVectorType.getElementType()) return op->emitOpError("source/result vectors must have same element type"); auto sourceShape = sourceVectorType.getShape(); auto resultShape = resultVectorType.getShape(); // Check that product of source dim sizes matches product of result dim sizes. int64_t sourceDimProduct = std::accumulate( sourceShape.begin(), sourceShape.end(), 1LL, std::multiplies<int64_t>{}); int64_t resultDimProduct = std::accumulate( resultShape.begin(), resultShape.end(), 1LL, std::multiplies<int64_t>{}); if (sourceDimProduct != resultDimProduct) return op->emitOpError("source/result number of elements must match"); // Check the expanding/contracting rank cases. unsigned sourceRank = sourceVectorType.getRank(); unsigned resultRank = resultVectorType.getRank(); if (sourceRank < resultRank) { if (!isValidShapeCast(sourceShape, resultShape)) return op->emitOpError("invalid shape cast"); } else if (sourceRank > resultRank) { if (!isValidShapeCast(resultShape, sourceShape)) return op->emitOpError("invalid shape cast"); } return success(); } static LogicalResult verify(ShapeCastOp op) { auto sourceVectorType = op.source().getType().dyn_cast_or_null<VectorType>(); auto resultVectorType = op.result().getType().dyn_cast_or_null<VectorType>(); // Check if source/result are of vector type. if (sourceVectorType && resultVectorType) return verifyVectorShapeCast(op, sourceVectorType, resultVectorType); // Check if source/result are "tuple of vectors" type. auto sourceTupleType = op.source().getType().dyn_cast_or_null<TupleType>(); auto resultTupleType = op.result().getType().dyn_cast_or_null<TupleType>(); if (!sourceTupleType || !resultTupleType) return op.emitOpError("source/result must be of same type"); // Check that source/result tuple sizes are the same. if (sourceTupleType.size() != resultTupleType.size()) return op.emitOpError("source/result tuples must be the same size"); // Check each source/result tuple element pair. for (unsigned i = 0, e = sourceTupleType.size(); i < e; ++i) if (failed(verifyVectorShapeCast( op, sourceTupleType.getType(i).cast<VectorType>(), resultTupleType.getType(i).cast<VectorType>()))) return failure(); return success(); } OpFoldResult ShapeCastOp::fold(ArrayRef<Attribute> operands) { // Nop shape cast. if (source().getType() == result().getType()) return source(); // Canceling shape casts. if (auto otherOp = source().getDefiningOp<ShapeCastOp>()) if (result().getType() == otherOp.source().getType()) return otherOp.source(); return {}; } namespace { // Pattern to rewrite a ShapeCast(splat ConstantOp) -> ConstantOp.
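// E.g. (illustrative): a shape_cast of a constant splat of 1.0 over
// vector<2x4xf32> folds to a constant splat of 1.0 over vector<8xf32>.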
class ShapeCastConstantFolder final : public OpRewritePattern<ShapeCastOp> { public: using OpRewritePattern<ShapeCastOp>::OpRewritePattern; LogicalResult matchAndRewrite(ShapeCastOp shapeCastOp, PatternRewriter &rewriter) const override { auto constantOp = shapeCastOp.source().getDefiningOp<ConstantOp>(); if (!constantOp) return failure(); // Only handle splat for now. auto dense = constantOp.value().dyn_cast<SplatElementsAttr>(); if (!dense) return failure(); auto newAttr = DenseElementsAttr::get( shapeCastOp.getType().cast<VectorType>(), dense.getSplatValue()); rewriter.replaceOpWithNewOp<ConstantOp>(shapeCastOp, newAttr); return success(); } }; } // namespace void ShapeCastOp::getCanonicalizationPatterns(OwningRewritePatternList &results, MLIRContext *context) { // Pattern to rewrite a ShapeCastOp(ConstantOp) -> ConstantOp. results.insert<ShapeCastConstantFolder>(context); } //===----------------------------------------------------------------------===// // VectorBitCastOp //===----------------------------------------------------------------------===// static LogicalResult verify(BitCastOp op) { auto sourceVectorType = op.getSourceVectorType(); auto resultVectorType = op.getResultVectorType(); for (int64_t i = 0, e = sourceVectorType.getRank() - 1; i < e; i++) { if (sourceVectorType.getDimSize(i) != resultVectorType.getDimSize(i)) return op.emitOpError("dimension size mismatch at: ") << i; } if (sourceVectorType.getElementTypeBitWidth() * sourceVectorType.getShape().back() != resultVectorType.getElementTypeBitWidth() * resultVectorType.getShape().back()) return op.emitOpError( "source/result bitwidth of the minor 1-D vectors must be equal"); return success(); } OpFoldResult BitCastOp::fold(ArrayRef<Attribute> operands) { // Nop cast. if (source().getType() == result().getType()) return source(); // Canceling bitcasts. if (auto otherOp = source().getDefiningOp<BitCastOp>()) if (result().getType() == otherOp.source().getType()) return otherOp.source(); return {}; } //===----------------------------------------------------------------------===// // TypeCastOp //===----------------------------------------------------------------------===// static SmallVector<int64_t, 8> extractShape(MemRefType memRefType) { auto vectorType = memRefType.getElementType().dyn_cast<VectorType>(); SmallVector<int64_t, 8> res(memRefType.getShape().begin(), memRefType.getShape().end()); if (vectorType) res.append(vectorType.getShape().begin(), vectorType.getShape().end()); return res; } /// Build the canonical memRefType with a single vector. /// E.g. memref<4 x 5 x vector<6 x f32>> -> memref<vector<4 x 5 x 6 x f32>>. 
void TypeCastOp::build(OpBuilder &builder, OperationState &result, Value source) { result.addOperands(source); MemRefType memRefType = source.getType().cast<MemRefType>(); VectorType vectorType = VectorType::get(extractShape(memRefType), getElementTypeOrSelf(getElementTypeOrSelf(memRefType))); result.addTypes( MemRefType::get({}, vectorType, {}, memRefType.getMemorySpace())); } static LogicalResult verify(TypeCastOp op) { MemRefType canonicalType = canonicalizeStridedLayout(op.getMemRefType()); if (!canonicalType.getAffineMaps().empty()) return op.emitOpError("expects operand to be a memref with no layout"); if (!op.getResultMemRefType().getAffineMaps().empty()) return op.emitOpError("expects result to be a memref with no layout"); if (op.getResultMemRefType().getMemorySpace() != op.getMemRefType().getMemorySpace()) return op.emitOpError("expects result in same memory space"); auto sourceType = op.getMemRefType(); auto resultType = op.getResultMemRefType(); if (getElementTypeOrSelf(getElementTypeOrSelf(sourceType)) != getElementTypeOrSelf(getElementTypeOrSelf(resultType))) return op.emitOpError( "expects result and operand with same underlying scalar type: ") << resultType; if (extractShape(sourceType) != extractShape(resultType)) return op.emitOpError( "expects concatenated result and operand shapes to be equal: ") << resultType; return success(); } //===----------------------------------------------------------------------===// // TupleOp //===----------------------------------------------------------------------===// static ParseResult parseTupleOp(OpAsmParser &parser, OperationState &result) { SmallVector<OpAsmParser::OperandType, 4> operandInfos; SmallVector<Type, 4> types; auto loc = parser.getCurrentLocation(); auto *ctx = parser.getBuilder().getContext(); return failure( parser.parseOperandList(operandInfos) || parser.parseOptionalAttrDict(result.attributes) || parser.parseColonTypeList(types) || parser.resolveOperands(operandInfos, types, loc, result.operands) || parser.addTypeToList(TupleType::get(types, ctx), result.types)); } static void print(OpAsmPrinter &p, TupleOp op) { p << op.getOperationName() << ' '; p.printOperands(op.getOperands()); p.printOptionalAttrDict(op.getAttrs()); p << " : "; llvm::interleaveComma(op.getOperation()->getOperandTypes(), p); } static LogicalResult verify(TupleOp op) { return success(); } //===----------------------------------------------------------------------===// // TransposeOp //===----------------------------------------------------------------------===// void vector::TransposeOp::build(OpBuilder &builder, OperationState &result, Value vector, ArrayRef<int64_t> transp) { VectorType vt = vector.getType().cast<VectorType>(); SmallVector<int64_t, 4> transposedShape(vt.getRank()); for (unsigned i = 0; i < transp.size(); ++i) transposedShape[i] = vt.getShape()[transp[i]]; result.addOperands(vector); result.addTypes(VectorType::get(transposedShape, vt.getElementType())); result.addAttribute(getTranspAttrName(), builder.getI64ArrayAttr(transp)); } // Eliminates transpose operations, which produce values identical to their // input values. This happens when the dimensions of the input vector remain in // their original order after the transpose operation. OpFoldResult vector::TransposeOp::fold(ArrayRef<Attribute> operands) { SmallVector<int64_t, 4> transp; getTransp(transp); // Check if the permutation of the dimensions contains sequential values: // {0, 1, 2, ...}.
for (int64_t i = 0, e = transp.size(); i < e; i++) { if (transp[i] != i) return {}; } return vector(); } static LogicalResult verify(vector::TransposeOp op) { VectorType vectorType = op.getVectorType(); VectorType resultType = op.getResultType(); int64_t rank = resultType.getRank(); if (vectorType.getRank() != rank) return op.emitOpError("vector result rank mismatch: ") << rank; // Verify transposition array. auto transpAttr = op.transp().getValue(); int64_t size = transpAttr.size(); if (rank != size) return op.emitOpError("transposition length mismatch: ") << size; SmallVector<bool, 8> seen(rank, false); for (auto ta : llvm::enumerate(transpAttr)) { int64_t i = ta.value().cast<IntegerAttr>().getInt(); if (i < 0 || i >= rank) return op.emitOpError("transposition index out of range: ") << i; if (seen[i]) return op.emitOpError("duplicate position index: ") << i; seen[i] = true; if (resultType.getDimSize(ta.index()) != vectorType.getDimSize(i)) return op.emitOpError("dimension size mismatch at: ") << i; } return success(); } namespace { // Rewrites two back-to-back TransposeOp operations into a single TransposeOp. class TransposeFolder final : public OpRewritePattern<vector::TransposeOp> { public: using OpRewritePattern<vector::TransposeOp>::OpRewritePattern; LogicalResult matchAndRewrite(vector::TransposeOp transposeOp, PatternRewriter &rewriter) const override { // Wrapper around vector::TransposeOp::getTransp() for cleaner code. auto getPermutation = [](vector::TransposeOp transpose) { SmallVector<int64_t, 4> permutation; transpose.getTransp(permutation); return permutation; }; // Composes two permutations: result[i] = permutation1[permutation2[i]]. auto composePermutations = [](ArrayRef<int64_t> permutation1, ArrayRef<int64_t> permutation2) { SmallVector<int64_t, 4> result; for (auto index : permutation2) result.push_back(permutation1[index]); return result; }; // Return if the input of 'transposeOp' is not defined by another transpose. vector::TransposeOp parentTransposeOp = transposeOp.vector().getDefiningOp<vector::TransposeOp>(); if (!parentTransposeOp) return failure(); SmallVector<int64_t, 4> permutation = composePermutations( getPermutation(parentTransposeOp), getPermutation(transposeOp)); // Replace 'transposeOp' with a new transpose operation. 
rewriter.replaceOpWithNewOp<vector::TransposeOp>( transposeOp, transposeOp.getResult().getType(), parentTransposeOp.vector(), vector::getVectorSubscriptAttr(rewriter, permutation)); return success(); } }; } // end anonymous namespace void vector::TransposeOp::getCanonicalizationPatterns( OwningRewritePatternList &results, MLIRContext *context) { results.insert<TransposeFolder>(context); } void vector::TransposeOp::getTransp(SmallVectorImpl<int64_t> &results) { populateFromInt64AttrArray(transp(), results); } //===----------------------------------------------------------------------===// // TupleGetOp //===----------------------------------------------------------------------===// static ParseResult parseTupleGetOp(OpAsmParser &parser, OperationState &result) { OpAsmParser::OperandType operandInfo; IntegerAttr indexAttr; StringRef indexAttrName = TupleGetOp::getIndexAttrName(); Type indexType = parser.getBuilder().getIndexType(); TupleType tupleType; if (parser.parseOperand(operandInfo) || parser.parseComma() || parser.parseAttribute(indexAttr, indexType, indexAttrName, result.attributes) || parser.parseOptionalAttrDict(result.attributes) || parser.parseColonType(tupleType) || parser.resolveOperand(operandInfo, tupleType, result.operands)) return failure(); if (indexAttr.getInt() < 0 || indexAttr.getInt() >= static_cast<int64_t>(tupleType.size())) return failure(); parser.addTypeToList(tupleType.getType(indexAttr.getInt()), result.types); return success(); } static void print(OpAsmPrinter &p, TupleGetOp op) { p << op.getOperationName() << ' ' << op.getOperand() << ", " << op.index(); p.printOptionalAttrDict(op.getAttrs(), /*elidedAttrs=*/{TupleGetOp::getIndexAttrName()}); p << " : " << op.getOperand().getType(); } static LogicalResult verify(TupleGetOp op) { auto tupleType = op.getOperand().getType().cast<TupleType>(); if (op.getIndex() < 0 || op.getIndex() >= static_cast<int64_t>(tupleType.size())) return op.emitOpError("tuple get index out of range"); return success(); } OpFoldResult TupleGetOp::fold(ArrayRef<Attribute> operands) { // Rewrite: // %t = vector.tuple .., %e_i, .. // %x = vector.tuple_get %t, i // into: // %t = vector.tuple .., %e_i, .. // one less use // %x = %e_i if (auto tupleOp = getOperand().getDefiningOp<TupleOp>()) return tupleOp.getOperand(getIndex()); return {}; } //===----------------------------------------------------------------------===// // ConstantMaskOp //===----------------------------------------------------------------------===// static LogicalResult verify(ConstantMaskOp &op) { // Verify that array attr size matches the rank of the vector result. auto resultType = op.getResult().getType().cast<VectorType>(); if (static_cast<int64_t>(op.mask_dim_sizes().size()) != resultType.getRank()) return op.emitOpError( "must specify array attr of size equal vector result rank"); // Verify that each array attr element is in bounds of corresponding vector // result dimension size. auto resultShape = resultType.getShape(); SmallVector<int64_t, 4> maskDimSizes; for (auto it : llvm::enumerate(op.mask_dim_sizes())) { int64_t attrValue = it.value().cast<IntegerAttr>().getInt(); if (attrValue < 0 || attrValue > resultShape[it.index()]) return op.emitOpError( "array attr of size out of bounds of vector result dimension size"); maskDimSizes.push_back(attrValue); } // Verify that if one mask dim size is zero, they all should be zero (because // the mask region is a conjunction of each mask dimension interval). 
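// E.g. (illustrative): mask_dim_sizes = [2, 0] covers an empty region, so the
// canonical all-zero form [0, 0] is required instead.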
bool anyZeros = llvm::is_contained(maskDimSizes, 0); bool allZeros = llvm::all_of(maskDimSizes, [](int64_t s) { return s == 0; }); if (anyZeros && !allZeros) return op.emitOpError("expected all mask dim sizes to be zeros, " "as a result of conjunction with zero mask dim"); return success(); } //===----------------------------------------------------------------------===// // CreateMaskOp //===----------------------------------------------------------------------===// static LogicalResult verify(CreateMaskOp op) { // Verify that an operand was specified for each result vector dimension. if (op.getNumOperands() != op.getResult().getType().cast<VectorType>().getRank()) return op.emitOpError( "must specify an operand for each result vector dimension"); return success(); } namespace { // Pattern to rewrite a CreateMaskOp with a ConstantMaskOp. class CreateMaskFolder final : public OpRewritePattern<CreateMaskOp> { public: using OpRewritePattern<CreateMaskOp>::OpRewritePattern; LogicalResult matchAndRewrite(CreateMaskOp createMaskOp, PatternRewriter &rewriter) const override { // Return if any of 'createMaskOp' operands are not defined by a constant. auto isNotDefByConstant = [](Value operand) { return !isa_and_nonnull<ConstantIndexOp>(operand.getDefiningOp()); }; if (llvm::any_of(createMaskOp.operands(), isNotDefByConstant)) return failure(); // Gather constant mask dimension sizes. SmallVector<int64_t, 4> maskDimSizes; for (auto operand : createMaskOp.operands()) { auto defOp = operand.getDefiningOp(); maskDimSizes.push_back(cast<ConstantIndexOp>(defOp).getValue()); } // Replace 'createMaskOp' with ConstantMaskOp. rewriter.replaceOpWithNewOp<ConstantMaskOp>( createMaskOp, createMaskOp.getResult().getType(), vector::getVectorSubscriptAttr(rewriter, maskDimSizes)); return success(); } }; } // end anonymous namespace void CreateMaskOp::getCanonicalizationPatterns( OwningRewritePatternList &results, MLIRContext *context) { results.insert<CreateMaskFolder>(context); } void mlir::vector::populateVectorToVectorCanonicalizationPatterns( OwningRewritePatternList &patterns, MLIRContext *context) { patterns.insert<CreateMaskFolder, MaskedLoadFolder, MaskedStoreFolder, GatherFolder, ScatterFolder, ExpandLoadFolder, CompressStoreFolder, StridedSliceConstantMaskFolder, TransposeFolder>(context); } #define GET_OP_CLASSES #include "mlir/Dialect/Vector/VectorOps.cpp.inc"