1 //===- VectorOps.cpp - MLIR Vector Dialect Operations ---------------------===//
2 //
3 // Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions.
4 // See https://llvm.org/LICENSE.txt for license information.
5 // SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception
6 //
7 //===----------------------------------------------------------------------===//
8 //
9 // This file implements convenience types for working with super-vectorization
10 // operations, in particular super-vector loads and stores.
11 //
12 //===----------------------------------------------------------------------===//
13 
14 #include "mlir/Dialect/Vector/VectorOps.h"
15 #include "mlir/Dialect/StandardOps/IR/Ops.h"
16 #include "mlir/Dialect/Utils/StructuredOpsUtils.h"
17 #include "mlir/Dialect/Vector/VectorUtils.h"
18 #include "mlir/IR/AffineExpr.h"
19 #include "mlir/IR/AffineMap.h"
20 #include "mlir/IR/Builders.h"
21 #include "mlir/IR/BuiltinOps.h"
22 #include "mlir/IR/OpImplementation.h"
23 #include "mlir/IR/PatternMatch.h"
24 #include "mlir/IR/TypeUtilities.h"
25 #include "mlir/Support/LLVM.h"
26 #include "mlir/Support/MathExtras.h"
27 #include "llvm/ADT/StringSet.h"
28 #include <numeric>
29 
30 using namespace mlir;
31 using namespace mlir::vector;
32 
/// Helper enum to classify mask value.
enum class MaskFormat {
  AllTrue = 0,  // All bits of the mask are statically known to be set.
  AllFalse = 1, // All bits of the mask are statically known to be cleared.
  Unknown = 2,  // The mask cannot be classified statically.
};
39 
40 /// Helper method to classify a 1-D mask value. Currently, the method
41 /// looks "under the hood" of a constant value with dense attributes
42 /// and a constant mask operation (since the client may be called at
43 /// various stages during progressive lowering).
get1DMaskFormat(Value mask)44 static MaskFormat get1DMaskFormat(Value mask) {
45   if (auto c = mask.getDefiningOp<ConstantOp>()) {
46     // Inspect constant dense values. We count up for bits that
47     // are set, count down for bits that are cleared, and bail
48     // when a mix is detected.
49     if (auto denseElts = c.value().dyn_cast<DenseIntElementsAttr>()) {
50       int64_t val = 0;
51       for (bool b : denseElts.getValues<bool>())
52         if (b && val >= 0)
53           val++;
54         else if (!b && val <= 0)
55           val--;
56         else
57           return MaskFormat::Unknown;
58       if (val > 0)
59         return MaskFormat::AllTrue;
60       if (val < 0)
61         return MaskFormat::AllFalse;
62     }
63   } else if (auto m = mask.getDefiningOp<ConstantMaskOp>()) {
64     // Inspect constant mask index. If the index exceeds the
65     // dimension size, all bits are set. If the index is zero
66     // or less, no bits are set.
67     ArrayAttr masks = m.mask_dim_sizes();
68     assert(masks.size() == 1);
69     int64_t i = masks[0].cast<IntegerAttr>().getInt();
70     int64_t u = m.getType().cast<VectorType>().getDimSize(0);
71     if (i >= u)
72       return MaskFormat::AllTrue;
73     if (i <= 0)
74       return MaskFormat::AllFalse;
75   }
76   return MaskFormat::Unknown;
77 }
78 
79 /// Helper method to cast a 1-D memref<10xf32> "base" into a
80 /// memref<vector<10xf32>> in the output parameter "newBase",
81 /// using the 'element' vector type "vt". Returns true on success.
castedToMemRef(Location loc,Value base,MemRefType mt,VectorType vt,PatternRewriter & rewriter,Value & newBase)82 static bool castedToMemRef(Location loc, Value base, MemRefType mt,
83                            VectorType vt, PatternRewriter &rewriter,
84                            Value &newBase) {
85   // The vector.type_cast operation does not accept unknown memref<?xf32>.
86   // TODO: generalize the cast and accept this case too
87   if (!mt.hasStaticShape())
88     return false;
89   newBase = rewriter.create<TypeCastOp>(loc, MemRefType::get({}, vt), base);
90   return true;
91 }
92 
93 //===----------------------------------------------------------------------===//
94 // VectorDialect
95 //===----------------------------------------------------------------------===//
96 
/// Registers all Vector dialect operations generated from the ODS
/// definitions in VectorOps.td.
void VectorDialect::initialize() {
  addOperations<
#define GET_OP_LIST
#include "mlir/Dialect/Vector/VectorOps.cpp.inc"
      >();
}
103 
/// Materialize a single constant operation from a given attribute value with
/// the desired resultant type. Used by folding to turn computed attributes
/// back into ops; delegates to the standard ConstantOp.
Operation *VectorDialect::materializeConstant(OpBuilder &builder,
                                              Attribute value, Type type,
                                              Location loc) {
  return builder.create<ConstantOp>(loc, type, value);
}
111 
/// Returns the integer type used for vector subscripts (always i64).
IntegerType vector::getVectorSubscriptType(Builder &builder) {
  return builder.getIntegerType(64);
}
115 
/// Builds an i64 ArrayAttr from 'values', the canonical encoding for vector
/// subscript (position) attributes.
ArrayAttr vector::getVectorSubscriptAttr(Builder &builder,
                                         ArrayRef<int64_t> values) {
  return builder.getI64ArrayAttr(values);
}
120 
121 //===----------------------------------------------------------------------===//
122 // ReductionOp
123 //===----------------------------------------------------------------------===//
124 
verify(ReductionOp op)125 static LogicalResult verify(ReductionOp op) {
126   // Verify for 1-D vector.
127   int64_t rank = op.getVectorType().getRank();
128   if (rank != 1)
129     return op.emitOpError("unsupported reduction rank: ") << rank;
130 
131   // Verify supported reduction kind.
132   auto kind = op.kind();
133   Type eltType = op.dest().getType();
134   if (kind == "add" || kind == "mul" || kind == "min" || kind == "max") {
135     if (!eltType.isIntOrIndexOrFloat())
136       return op.emitOpError("unsupported reduction type");
137   } else if (kind == "and" || kind == "or" || kind == "xor") {
138     if (!eltType.isIntOrIndex())
139       return op.emitOpError("unsupported reduction type");
140   } else {
141     return op.emitOpError("unknown reduction kind: ") << kind;
142   }
143 
144   // Verify optional accumulator.
145   if (!op.acc().empty()) {
146     if (kind != "add" && kind != "mul")
147       return op.emitOpError("no accumulator for reduction kind: ") << kind;
148     if (!eltType.isa<FloatType>())
149       return op.emitOpError("no accumulator for type: ") << eltType;
150   }
151 
152   return success();
153 }
154 
/// Parses a vector.reduction op of the form:
///   vector.reduction "kind", %vector (, %acc)? : red-type into res-type
static ParseResult parseReductionOp(OpAsmParser &parser,
                                    OperationState &result) {
  SmallVector<OpAsmParser::OperandType, 2> operandsInfo;
  Type redType;
  Type resType;
  Attribute attr;
  // NOTE: this short-circuit chain is order-dependent. Operands are resolved
  // only after both types have been parsed; the optional second (accumulator)
  // operand is resolved only when it was actually present.
  if (parser.parseAttribute(attr, "kind", result.attributes) ||
      parser.parseComma() || parser.parseOperandList(operandsInfo) ||
      parser.parseColonType(redType) ||
      parser.parseKeywordType("into", resType) ||
      (operandsInfo.size() > 0 &&
       parser.resolveOperand(operandsInfo[0], redType, result.operands)) ||
      (operandsInfo.size() > 1 &&
       parser.resolveOperand(operandsInfo[1], resType, result.operands)) ||
      parser.addTypeToList(resType, result.types))
    return failure();
  // Only one (vector) or two (vector + accumulator) operands are valid.
  if (operandsInfo.size() < 1 || operandsInfo.size() > 2)
    return parser.emitError(parser.getNameLoc(),
                            "unsupported number of operands");
  return success();
}
176 
print(OpAsmPrinter & p,ReductionOp op)177 static void print(OpAsmPrinter &p, ReductionOp op) {
178   p << op.getOperationName() << " \"" << op.kind() << "\", " << op.vector();
179   if (!op.acc().empty())
180     p << ", " << op.acc();
181   p << " : " << op.vector().getType() << " into " << op.dest().getType();
182 }
183 
184 //===----------------------------------------------------------------------===//
185 // ContractionOp
186 //===----------------------------------------------------------------------===//
187 
build(OpBuilder & builder,OperationState & result,Value lhs,Value rhs,Value acc,ArrayRef<ArrayRef<AffineExpr>> indexingExprs,ArrayRef<StringRef> iteratorTypes)188 void vector::ContractionOp::build(OpBuilder &builder, OperationState &result,
189                                   Value lhs, Value rhs, Value acc,
190                                   ArrayRef<ArrayRef<AffineExpr>> indexingExprs,
191                                   ArrayRef<StringRef> iteratorTypes) {
192   result.addOperands({lhs, rhs, acc});
193   result.addTypes(acc.getType());
194   result.addAttribute(getIndexingMapsAttrName(),
195                       builder.getAffineMapArrayAttr(
196                           AffineMap::inferFromExprList(indexingExprs)));
197   result.addAttribute(getIteratorTypesAttrName(),
198                       builder.getStrArrayAttr(iteratorTypes));
199 }
200 
build(OpBuilder & builder,OperationState & result,Value lhs,Value rhs,Value acc,ArrayAttr indexingMaps,ArrayAttr iteratorTypes)201 void vector::ContractionOp::build(OpBuilder &builder, OperationState &result,
202                                   Value lhs, Value rhs, Value acc,
203                                   ArrayAttr indexingMaps,
204                                   ArrayAttr iteratorTypes) {
205   result.addOperands({lhs, rhs, acc});
206   result.addTypes(acc.getType());
207   result.addAttribute(getIndexingMapsAttrName(), indexingMaps);
208   result.addAttribute(getIteratorTypesAttrName(), iteratorTypes);
209 }
210 
parseContractionOp(OpAsmParser & parser,OperationState & result)211 static ParseResult parseContractionOp(OpAsmParser &parser,
212                                       OperationState &result) {
213   OpAsmParser::OperandType lhsInfo;
214   OpAsmParser::OperandType rhsInfo;
215   OpAsmParser::OperandType accInfo;
216   SmallVector<OpAsmParser::OperandType, 2> masksInfo;
217   SmallVector<Type, 2> types;
218   Type resultType;
219   auto loc = parser.getCurrentLocation();
220   DictionaryAttr dictAttr;
221   // TODO: Unify linalg op attribute parsing.
222   if (parser.parseAttribute(dictAttr, "_", result.attributes) ||
223       parser.parseOperand(lhsInfo) || parser.parseComma() ||
224       parser.parseOperand(rhsInfo) || parser.parseComma() ||
225       parser.parseOperand(accInfo) ||
226       parser.parseTrailingOperandList(masksInfo) ||
227       parser.parseOptionalAttrDict(result.attributes) ||
228       parser.parseColonTypeList(types) ||
229       parser.parseKeywordType("into", resultType) ||
230       parser.resolveOperand(lhsInfo, types[0], result.operands) ||
231       parser.resolveOperand(rhsInfo, types[1], result.operands) ||
232       parser.resolveOperand(accInfo, resultType, result.operands) ||
233       parser.addTypeToList(resultType, result.types))
234     return failure();
235   result.attributes.assign(dictAttr.getValue().begin(),
236                            dictAttr.getValue().end());
237   if (masksInfo.empty())
238     return success();
239   if (masksInfo.size() != 2)
240     return parser.emitError(parser.getNameLoc(),
241                             "expected zero or exactly 2 vector mask operands");
242   auto lhsType = types[0].cast<VectorType>();
243   auto rhsType = types[1].cast<VectorType>();
244   auto maskElementType = parser.getBuilder().getI1Type();
245   std::array<Type, 2> maskTypes = {
246       VectorType::get(lhsType.getShape(), maskElementType),
247       VectorType::get(rhsType.getShape(), maskElementType)};
248   if (parser.resolveOperands(masksInfo, maskTypes, loc, result.operands))
249     return failure();
250   return success();
251 }
252 
print(OpAsmPrinter & p,ContractionOp op)253 static void print(OpAsmPrinter &p, ContractionOp op) {
254   // TODO: Unify printing code with linalg ops.
255   auto attrNames = op.getTraitAttrNames();
256   llvm::StringSet<> traitAttrsSet;
257   traitAttrsSet.insert(attrNames.begin(), attrNames.end());
258   SmallVector<NamedAttribute, 8> attrs;
259   for (auto attr : op.getAttrs())
260     if (traitAttrsSet.count(attr.first.strref()) > 0)
261       attrs.push_back(attr);
262 
263   auto dictAttr = DictionaryAttr::get(attrs, op.getContext());
264   p << op.getOperationName() << " " << dictAttr << " " << op.lhs() << ", ";
265   p << op.rhs() << ", " << op.acc();
266   if (op.masks().size() == 2)
267     p << ", " << op.masks();
268 
269   p.printOptionalAttrDict(op.getAttrs(), attrNames);
270   p << " : " << op.lhs().getType() << ", " << op.rhs().getType() << " into "
271     << op.getResultType();
272 }
273 
verifyDimMap(VectorType lhsType,VectorType rhsType,const std::vector<std::pair<int64_t,int64_t>> & map)274 static bool verifyDimMap(VectorType lhsType, VectorType rhsType,
275                          const std::vector<std::pair<int64_t, int64_t>> &map) {
276   for (auto &dimPair : map) {
277     if (dimPair.first < 0 || dimPair.first >= lhsType.getRank() ||
278         dimPair.second < 0 || dimPair.second >= rhsType.getRank() ||
279         lhsType.getDimSize(dimPair.first) != rhsType.getDimSize(dimPair.second))
280       return false;
281   }
282   return true;
283 }
284 
/// Verifies that the accumulator and result types of 'op' agree with the
/// shape implied by the lhs/rhs vector types, the indexing maps, and the
/// contracting/batch dimension maps. Returns failure with a diagnostic on
/// mismatch.
static LogicalResult verifyOutputShape(
    ContractionOp op, VectorType lhsType, VectorType rhsType, Type accType,
    Type resType,
    const std::vector<std::pair<int64_t, int64_t>> &contractingDimMap,
    const std::vector<std::pair<int64_t, int64_t>> &batchDimMap) {
  // Record which lhs/rhs dimensions are contracted away, and which rhs
  // dimensions are batch dimensions (those are kept on the lhs side only).
  DenseSet<int64_t> lhsContractingDimSet;
  DenseSet<int64_t> rhsContractingDimSet;
  for (auto &dimPair : contractingDimMap) {
    lhsContractingDimSet.insert(dimPair.first);
    rhsContractingDimSet.insert(dimPair.second);
  }
  DenseSet<int64_t> rhsBatchDimSet;
  for (auto &dimPair : batchDimMap)
    rhsBatchDimSet.insert(dimPair.second);

  // Add free and batch dimensions from 'lhsType' to 'expectedResultDims'.
  SmallVector<int64_t, 4> expectedResultDims;
  for (int64_t i = 0, e = lhsType.getRank(); i < e; ++i) {
    if (lhsContractingDimSet.count(i) > 0)
      continue;
    expectedResultDims.push_back(lhsType.getDimSize(i));
  }

  // Add free dimensions from 'rhsType' to 'expectedResultDims'.
  for (int64_t i = 0, e = rhsType.getRank(); i < e; ++i) {
    if (rhsContractingDimSet.count(i) > 0 || rhsBatchDimSet.count(i) > 0)
      continue;
    expectedResultDims.push_back(rhsType.getDimSize(i));
  }

  // Verify 'expectedResultDims'.
  if (expectedResultDims.size() == 0) {
    // No batch or free dimension implies a scalar result.
    if (resType.isa<VectorType>() || accType.isa<VectorType>())
      return op.emitOpError("invalid accumulator/result vector shape");
  } else {
    // At least one batch or free dimension implies a vector result.
    auto resVectorType = resType.dyn_cast<VectorType>();
    auto accVectorType = accType.dyn_cast<VectorType>();
    if (!resVectorType || !accVectorType)
      return op.emitOpError("invalid accumulator/result vector shape");

    // Infer expected result vector type. Lhs + rhs map and lhs + rhs vector
    // types fully define the result vector type. This assumes the affine maps
    // are well-formed, which must have been verified already.
    MLIRContext *ctx = op.getContext();
    AffineMap lhsMap = op.getIndexingMaps()[0];
    AffineMap rhsMap = op.getIndexingMaps()[1];
    // For each iteration-space dimension, record its extent as a constant
    // affine expression, taken from whichever operand first references it.
    SmallVector<AffineExpr, 4> extents(lhsMap.getNumInputs());
    for (auto pair :
         {std::make_pair(lhsType, lhsMap), std::make_pair(rhsType, rhsMap)}) {
      VectorType v = pair.first;
      auto map = pair.second;
      for (unsigned idx = 0, e = v.getRank(); idx < e; ++idx) {
        unsigned pos = map.getDimPosition(idx);
        if (!extents[pos])
          extents[pos] = getAffineConstantExpr(v.getShape()[idx], ctx);
      }
    }
    assert(llvm::all_of(extents, [](AffineExpr e) { return e; }) &&
           "expected extent along all dimensions.");

    // Composing the result map with the constant extents map substitutes
    // each dimension with its extent, yielding the expected result shape.
    AffineMap resMap = op.getIndexingMaps()[2];
    auto extentsMap = AffineMap::get(/*dimCount=*/extents.size(),
                                     /*symCount=*/0, extents, ctx);
    // Compose the resMap with the extentsMap, which is a constant map.
    AffineMap expectedMap = simplifyAffineMap(resMap.compose(extentsMap));
    assert(llvm::all_of(
               expectedMap.getResults(),
               [](AffineExpr e) { return e.isa<AffineConstantExpr>(); }) &&
           "expected constant extent along all dimensions.");
    // Extract the expected shape and build the type.
    auto expectedShape = llvm::to_vector<4>(
        llvm::map_range(expectedMap.getResults(), [](AffineExpr e) {
          return e.cast<AffineConstantExpr>().getValue();
        }));
    auto expected =
        VectorType::get(expectedShape, resVectorType.getElementType());
    // Both the accumulator and the result must match the inferred type.
    if (resVectorType != expected || accVectorType != expected)
      return op.emitOpError(
                 "invalid accumulator/result vector shape, expected: ")
             << expected;
  }
  return success();
}
370 
/// Verifies a vector.contract: indexing-map well-formedness, the presence of
/// at least one contracting dimension pair, consistency of the dim maps,
/// output shape, and the optional pair of mask operands.
static LogicalResult verify(ContractionOp op) {
  auto lhsType = op.getLhsType();
  auto rhsType = op.getRhsType();
  auto accType = op.getAccType();
  auto resType = op.getResultType();

  // Verify that an indexing map was specified for each vector operand.
  if (op.indexing_maps().size() != 3)
    return op.emitOpError("expected an indexing map for each vector operand");

  // Verify that each index map has 'numIterators' inputs, no symbols, and
  // that the number of map outputs equals the rank of its associated
  // vector operand.
  unsigned numIterators = op.iterator_types().getValue().size();
  for (auto it : llvm::enumerate(op.indexing_maps())) {
    auto index = it.index();
    auto map = it.value().cast<AffineMapAttr>().getValue();
    if (map.getNumSymbols() != 0)
      return op.emitOpError("expected indexing map ")
             << index << " to have no symbols";
    // The third operand (acc) may be a scalar, in which case rank is 0.
    auto vectorType = op.getOperand(index).getType().dyn_cast<VectorType>();
    unsigned rank = vectorType ? vectorType.getShape().size() : 0;
    // Verify that the map has the right number of inputs, outputs, and indices.
    // This also correctly accounts for (..) -> () for rank-0 results.
    if (map.getNumDims() != numIterators)
      return op.emitOpError("expected indexing map ")
             << index << " to have " << numIterators << " number of inputs";
    if (map.getNumResults() != rank)
      return op.emitOpError("expected indexing map ")
             << index << " to have " << rank << " number of outputs";
    if (!map.isProjectedPermutation())
      return op.emitOpError("expected indexing map ")
             << index << " to be a projected permutation of its inputs";
  }

  auto contractingDimMap = op.getContractingDimMap();
  auto batchDimMap = op.getBatchDimMap();

  // Verify at least one contracting dimension pair was specified.
  if (contractingDimMap.empty())
    return op.emitOpError("expected at least one contracting dimension pair");

  // Verify contracting dimension map was properly constructed.
  if (!verifyDimMap(lhsType, rhsType, contractingDimMap))
    return op.emitOpError("invalid contracting dimension map");

  // Verify batch dimension map was properly constructed.
  if (!verifyDimMap(lhsType, rhsType, batchDimMap))
    return op.emitOpError("invalid batch dimension map");

  // Verify 'accType' and 'resType' shape.
  if (failed(verifyOutputShape(op, lhsType, rhsType, accType, resType,
                               contractingDimMap, batchDimMap)))
    return failure();

  // Verify that either two vector masks are set or none are set.
  auto lhsMaskType = op.getLHSVectorMaskType();
  auto rhsMaskType = op.getRHSVectorMaskType();
  if ((lhsMaskType && !rhsMaskType) || (!lhsMaskType && rhsMaskType))
    return op.emitOpError("invalid number of vector masks specified");
  if (lhsMaskType && rhsMaskType) {
    // Verify mask rank == argument rank.
    if (lhsMaskType.getShape().size() != lhsType.getShape().size() ||
        rhsMaskType.getShape().size() != rhsType.getShape().size())
      return op.emitOpError("invalid vector mask rank");
  }
  return success();
}
439 
/// Returns the names of the trait attributes printed in the leading
/// dictionary of the custom assembly form. The array has static storage so
/// the returned ArrayRef remains valid after this function returns.
ArrayRef<StringRef> ContractionOp::getTraitAttrNames() {
  static constexpr StringRef names[2] = {getIndexingMapsAttrName(),
                                         getIteratorTypesAttrName()};
  return llvm::makeArrayRef(names);
}
445 
getResultIndex(AffineMap map,AffineExpr targetExpr)446 static int64_t getResultIndex(AffineMap map, AffineExpr targetExpr) {
447   for (int64_t i = 0, e = map.getNumResults(); i < e; ++i)
448     if (targetExpr == map.getResult(i))
449       return i;
450   return -1;
451 }
452 
453 static std::vector<std::pair<int64_t, int64_t>>
getDimMap(ArrayRef<AffineMap> indexingMaps,ArrayAttr iteratorTypes,StringRef targetIteratorTypeName,MLIRContext * context)454 getDimMap(ArrayRef<AffineMap> indexingMaps, ArrayAttr iteratorTypes,
455           StringRef targetIteratorTypeName, MLIRContext *context) {
456   std::vector<std::pair<int64_t, int64_t>> dimMap;
457   for (auto it : llvm::enumerate(iteratorTypes)) {
458     auto iteratorTypeName = it.value().cast<StringAttr>().getValue();
459     if (iteratorTypeName != targetIteratorTypeName)
460       continue;
461     // Search lhs/rhs map results for 'targetExpr'.
462     auto targetExpr = getAffineDimExpr(it.index(), context);
463     int64_t lhsDim = getResultIndex(indexingMaps[0], targetExpr);
464     int64_t rhsDim = getResultIndex(indexingMaps[1], targetExpr);
465     if (lhsDim >= 0 && rhsDim >= 0)
466       dimMap.push_back({lhsDim, rhsDim});
467   }
468   return dimMap;
469 }
470 
getIterationBounds(SmallVectorImpl<int64_t> & iterationBounds)471 void ContractionOp::getIterationBounds(
472     SmallVectorImpl<int64_t> &iterationBounds) {
473   auto lhsShape = getLhsType().getShape();
474   auto resVectorType = getResultType().dyn_cast<VectorType>();
475   SmallVector<AffineMap, 4> indexingMaps(getIndexingMaps());
476   SmallVector<int64_t, 2> iterationShape;
477   for (auto it : llvm::enumerate(iterator_types())) {
478     // Search lhs/rhs map results for 'targetExpr'.
479     auto targetExpr = getAffineDimExpr(it.index(), getContext());
480     auto iteratorTypeName = it.value().cast<StringAttr>().getValue();
481     if (iteratorTypeName == getReductionIteratorTypeName()) {
482       // Get reduction dim size from lhs shape (same size in rhsShape).
483       int64_t lhsDimIndex = getResultIndex(indexingMaps[0], targetExpr);
484       assert(lhsDimIndex >= 0);
485       iterationBounds.push_back(lhsShape[lhsDimIndex]);
486       continue;
487     }
488     // Get parallel dimension size from result shape.
489     int64_t resDimIndex = getResultIndex(indexingMaps[2], targetExpr);
490     assert(resDimIndex >= 0);
491     assert(resVectorType != nullptr);
492     iterationBounds.push_back(resVectorType.getShape()[resDimIndex]);
493   }
494 }
495 
getIterationIndexMap(std::vector<DenseMap<int64_t,int64_t>> & iterationIndexMap)496 void ContractionOp::getIterationIndexMap(
497     std::vector<DenseMap<int64_t, int64_t>> &iterationIndexMap) {
498   unsigned numMaps = indexing_maps().getValue().size();
499   iterationIndexMap.resize(numMaps);
500   for (auto it : llvm::enumerate(indexing_maps())) {
501     auto index = it.index();
502     auto map = it.value().cast<AffineMapAttr>().getValue();
503     for (unsigned i = 0, e = map.getNumResults(); i < e; ++i) {
504       auto dim = map.getResult(i).cast<AffineDimExpr>();
505       iterationIndexMap[index][dim.getPosition()] = i;
506     }
507   }
508 }
509 
getContractingDimMap()510 std::vector<std::pair<int64_t, int64_t>> ContractionOp::getContractingDimMap() {
511   SmallVector<AffineMap, 4> indexingMaps(getIndexingMaps());
512   return getDimMap(indexingMaps, iterator_types(),
513                    getReductionIteratorTypeName(), getContext());
514 }
515 
getBatchDimMap()516 std::vector<std::pair<int64_t, int64_t>> ContractionOp::getBatchDimMap() {
517   SmallVector<AffineMap, 4> indexingMaps(getIndexingMaps());
518   return getDimMap(indexingMaps, iterator_types(),
519                    getParallelIteratorTypeName(), getContext());
520 }
521 
getIndexingMaps()522 SmallVector<AffineMap, 4> ContractionOp::getIndexingMaps() {
523   return llvm::to_vector<4>(
524       llvm::map_range(indexing_maps().getValue(), [](Attribute mapAttr) {
525         return mapAttr.cast<AffineMapAttr>().getValue();
526       }));
527 }
528 
getShapeForUnroll()529 Optional<SmallVector<int64_t, 4>> ContractionOp::getShapeForUnroll() {
530   SmallVector<int64_t, 4> shape;
531   getIterationBounds(shape);
532   return shape;
533 }
534 
535 //===----------------------------------------------------------------------===//
536 // ExtractElementOp
537 //===----------------------------------------------------------------------===//
538 
build(OpBuilder & builder,OperationState & result,Value source,Value position)539 void vector::ExtractElementOp::build(OpBuilder &builder, OperationState &result,
540                                      Value source, Value position) {
541   result.addOperands({source, position});
542   result.addTypes(source.getType().cast<VectorType>().getElementType());
543 }
544 
build(OpBuilder & builder,OperationState & result,Value source,int64_t position)545 void vector::ExtractElementOp::build(OpBuilder &builder, OperationState &result,
546                                      Value source, int64_t position) {
547   Value pos = builder.create<ConstantIntOp>(result.location, position, 32);
548   build(builder, result, source, pos);
549 }
550 
verify(vector::ExtractElementOp op)551 static LogicalResult verify(vector::ExtractElementOp op) {
552   VectorType vectorType = op.getVectorType();
553   if (vectorType.getRank() != 1)
554     return op.emitOpError("expected 1-D vector");
555   return success();
556 }
557 
558 //===----------------------------------------------------------------------===//
559 // ExtractOp
560 //===----------------------------------------------------------------------===//
561 
inferExtractOpResultType(VectorType vectorType,ArrayAttr position)562 static Type inferExtractOpResultType(VectorType vectorType,
563                                      ArrayAttr position) {
564   if (static_cast<int64_t>(position.size()) == vectorType.getRank())
565     return vectorType.getElementType();
566   return VectorType::get(vectorType.getShape().drop_front(position.size()),
567                          vectorType.getElementType());
568 }
569 
build(OpBuilder & builder,OperationState & result,Value source,ArrayRef<int64_t> position)570 void vector::ExtractOp::build(OpBuilder &builder, OperationState &result,
571                               Value source, ArrayRef<int64_t> position) {
572   result.addOperands(source);
573   auto positionAttr = getVectorSubscriptAttr(builder, position);
574   result.addTypes(inferExtractOpResultType(source.getType().cast<VectorType>(),
575                                            positionAttr));
576   result.addAttribute(getPositionAttrName(), positionAttr);
577 }
578 
// Convenience builder which assumes the values are constant indices.
void vector::ExtractOp::build(OpBuilder &builder, OperationState &result,
                              Value source, ValueRange position) {
  // NOTE(review): each position value must be defined by a ConstantIndexOp;
  // getDefiningOp returns null otherwise and getValue() would then crash —
  // confirm all callers uphold this precondition.
  SmallVector<int64_t, 4> positionConstants =
      llvm::to_vector<4>(llvm::map_range(position, [](Value pos) {
        return pos.getDefiningOp<ConstantIndexOp>().getValue();
      }));
  build(builder, result, source, positionConstants);
}
588 
print(OpAsmPrinter & p,vector::ExtractOp op)589 static void print(OpAsmPrinter &p, vector::ExtractOp op) {
590   p << op.getOperationName() << " " << op.vector() << op.position();
591   p.printOptionalAttrDict(op.getAttrs(), {"position"});
592   p << " : " << op.vector().getType();
593 }
594 
/// Parses a vector.extract op:
///   vector.extract %vector[position-array] attr-dict : vector-type
/// The result type is inferred from the vector type and the position rank.
static ParseResult parseExtractOp(OpAsmParser &parser, OperationState &result) {
  llvm::SMLoc attributeLoc, typeLoc;
  NamedAttrList attrs;
  OpAsmParser::OperandType vector;
  Type type;
  Attribute attr;
  // Locations are captured before each element so diagnostics point at the
  // offending token.
  if (parser.parseOperand(vector) || parser.getCurrentLocation(&attributeLoc) ||
      parser.parseAttribute(attr, "position", attrs) ||
      parser.parseOptionalAttrDict(attrs) ||
      parser.getCurrentLocation(&typeLoc) || parser.parseColonType(type))
    return failure();

  auto vectorType = type.dyn_cast<VectorType>();
  if (!vectorType)
    return parser.emitError(typeLoc, "expected vector type");

  // The position must be an array no longer than the vector rank.
  auto positionAttr = attr.dyn_cast<ArrayAttr>();
  if (!positionAttr ||
      static_cast<int64_t>(positionAttr.size()) > vectorType.getRank())
    return parser.emitError(
        attributeLoc,
        "expected position attribute of rank smaller than vector rank");

  Type resType = inferExtractOpResultType(vectorType, positionAttr);
  result.attributes = attrs;
  return failure(parser.resolveOperand(vector, type, result.operands) ||
                 parser.addTypeToList(resType, result.types));
}
623 
verify(vector::ExtractOp op)624 static LogicalResult verify(vector::ExtractOp op) {
625   auto positionAttr = op.position().getValue();
626   if (positionAttr.empty())
627     return op.emitOpError("expected non-empty position attribute");
628   if (positionAttr.size() > static_cast<unsigned>(op.getVectorType().getRank()))
629     return op.emitOpError(
630         "expected position attribute of rank smaller than vector rank");
631   for (auto en : llvm::enumerate(positionAttr)) {
632     auto attr = en.value().dyn_cast<IntegerAttr>();
633     if (!attr || attr.getInt() < 0 ||
634         attr.getInt() >= op.getVectorType().getDimSize(en.index()))
635       return op.emitOpError("expected position attribute #")
636              << (en.index() + 1)
637              << " to be a non-negative integer smaller than the corresponding "
638                 "vector dimension";
639   }
640   return success();
641 }
642 
643 template <typename IntType>
extractVector(ArrayAttr arrayAttr)644 static SmallVector<IntType, 4> extractVector(ArrayAttr arrayAttr) {
645   return llvm::to_vector<4>(llvm::map_range(
646       arrayAttr.getAsRange<IntegerAttr>(),
647       [](IntegerAttr attr) { return static_cast<IntType>(attr.getInt()); }));
648 }
649 
/// Fold the result of chains of ExtractOp in place by simply concatenating the
/// positions.
static LogicalResult foldExtractOpFromExtractChain(ExtractOp extractOp) {
  // Nothing to fold unless the source is itself produced by an ExtractOp.
  if (!extractOp.vector().getDefiningOp<ExtractOp>())
    return failure();

  // Walk up the chain accumulating position indices. Each op's indices are
  // appended in reverse so that the single std::reverse below yields the
  // root-first concatenation.
  SmallVector<int64_t, 4> globalPosition;
  ExtractOp currentOp = extractOp;
  auto extractedPos = extractVector<int64_t>(currentOp.position());
  globalPosition.append(extractedPos.rbegin(), extractedPos.rend());
  while (ExtractOp nextOp = currentOp.vector().getDefiningOp<ExtractOp>()) {
    currentOp = nextOp;
    auto extractedPos = extractVector<int64_t>(currentOp.position());
    globalPosition.append(extractedPos.rbegin(), extractedPos.rend());
  }
  // Rewrite extractOp in place: extract directly from the chain's root
  // vector with the concatenated position.
  extractOp.setOperand(currentOp.vector());
  // OpBuilder is only used as a helper to build an I64ArrayAttr.
  OpBuilder b(extractOp.getContext());
  std::reverse(globalPosition.begin(), globalPosition.end());
  extractOp.setAttr(ExtractOp::getPositionAttrName(),
                    b.getI64ArrayAttr(globalPosition));
  return success();
}
673 
/// Fold the result of an ExtractOp in place when it comes from a TransposeOp.
/// On success, `extractOp` is rewritten to extract directly from the
/// transpose's source with a position remapped through the inverse of the
/// transposition's major permutation.
static LogicalResult foldExtractOpFromTranspose(ExtractOp extractOp) {
  auto transposeOp = extractOp.vector().getDefiningOp<vector::TransposeOp>();
  if (!transposeOp)
    return failure();

  auto permutation = extractVector<unsigned>(transposeOp.transp());
  auto extractedPos = extractVector<int64_t>(extractOp.position());

  // If transposition permutation is larger than the ExtractOp, all minor
  // dimensions must be an identity for folding to occur. If not, individual
  // elements within the extracted value are transposed and this is not just a
  // simple folding.
  unsigned minorRank = permutation.size() - extractedPos.size();
  MLIRContext *ctx = extractOp.getContext();
  AffineMap permutationMap = AffineMap::getPermutationMap(permutation, ctx);
  AffineMap minorMap = permutationMap.getMinorSubMap(minorRank);
  if (minorMap && !minorMap.isMinorIdentity())
    return failure();

  //   %1 = transpose %0[x, y, z] : vector<axbxcxf32>
  //   %2 = extract %1[u, v] : vector<..xf32>
  // may turn into:
  //   %2 = extract %0[w, x] : vector<..xf32>
  // iff z == 2 and [w, x] = [x, y]^-1 o [u, v] here o denotes composition and
  // -1 denotes the inverse.
  permutationMap = permutationMap.getMajorSubMap(extractedPos.size());
  // The major submap has fewer results but the same number of dims. To compose
  // cleanly, we need to drop dims to form a "square matrix". This is possible
  // because:
  //   (a) this is a permutation map and
  //   (b) the minor map has already been checked to be identity.
  // Therefore, the major map cannot contain dims of position greater or equal
  // than the number of results.
  assert(llvm::all_of(permutationMap.getResults(),
                      [&](AffineExpr e) {
                        auto dim = e.dyn_cast<AffineDimExpr>();
                        return dim && dim.getPosition() <
                                          permutationMap.getNumResults();
                      }) &&
         "Unexpected map results depend on higher rank positions");
  // Project on the first domain dimensions to allow composition.
  permutationMap = AffineMap::get(permutationMap.getNumResults(), 0,
                                  permutationMap.getResults(), ctx);

  extractOp.setOperand(transposeOp.vector());
  // Compose the inverse permutation map with the extractedPos.
  auto newExtractedPos =
      inversePermutation(permutationMap).compose(extractedPos);
  // OpBuilder is only used as a helper to build an I64ArrayAttr.
  OpBuilder b(extractOp.getContext());
  extractOp.setAttr(ExtractOp::getPositionAttrName(),
                    b.getI64ArrayAttr(newExtractedPos));

  return success();
}
730 
/// Fold an ExtractOp that is fed by a chain of InsertOps and TransposeOps. The
/// result is always the input to some InsertOp. Returns a null Value when no
/// matching insertion is found along the chain.
static Value foldExtractOpFromInsertChainAndTranspose(ExtractOp extractOp) {
  MLIRContext *context = extractOp.getContext();
  // `permutationMap` accumulates the composition of all TransposeOps seen so
  // far; it stays null until the first transpose is encountered.
  AffineMap permutationMap;
  auto extractedPos = extractVector<unsigned>(extractOp.position());
  // Walk back a chain of InsertOp/TransposeOp until we hit a match.
  // Compose TransposeOp permutations as we walk back.
  auto insertOp = extractOp.vector().getDefiningOp<vector::InsertOp>();
  auto transposeOp = extractOp.vector().getDefiningOp<vector::TransposeOp>();
  while (insertOp || transposeOp) {
    if (transposeOp) {
      // If it is transposed, compose the map and iterate.
      auto permutation = extractVector<unsigned>(transposeOp.transp());
      AffineMap newMap = AffineMap::getPermutationMap(permutation, context);
      if (!permutationMap)
        permutationMap = newMap;
      else if (newMap.getNumInputs() != permutationMap.getNumResults())
        // Incompatible ranks: give up rather than compose ill-formed maps.
        return Value();
      else
        permutationMap = newMap.compose(permutationMap);
      // Compute insert/transpose for the next iteration.
      Value transposed = transposeOp.vector();
      insertOp = transposed.getDefiningOp<vector::InsertOp>();
      transposeOp = transposed.getDefiningOp<vector::TransposeOp>();
      continue;
    }

    assert(insertOp);
    Value insertionDest = insertOp.dest();
    // If it is inserted into, either the position matches and we have a
    // successful folding; or we iterate until we run out of
    // InsertOp/TransposeOp. This is because `vector.insert %scalar, %vector`
    // produces a new vector with 1 modified value/slice in exactly the static
    // position we need to match.
    auto insertedPos = extractVector<unsigned>(insertOp.position());
    // Trivial permutations are solved with position equality checks.
    if (!permutationMap || permutationMap.isIdentity()) {
      if (extractedPos == insertedPos)
        return insertOp.source();
      // Fallthrough: if the position does not match, just skip to the next
      // producing `vector.insert` / `vector.transpose`.
      // Compute insert/transpose for the next iteration.
      insertOp = insertionDest.getDefiningOp<vector::InsertOp>();
      transposeOp = insertionDest.getDefiningOp<vector::TransposeOp>();
      continue;
    }

    // More advanced permutations require application of the permutation.
    // However, the rank of `insertedPos` may be different from that of the
    // `permutationMap`. To support such case, we need to:
    //   1. apply on the `insertedPos.size()` major dimensions
    //   2. check the other dimensions of the permutation form a minor identity.
    assert(permutationMap.isPermutation() && "expected a permutation");
    if (insertedPos.size() == extractedPos.size()) {
      bool fold = true;
      for (unsigned idx = 0, sz = extractedPos.size(); idx < sz; ++idx) {
        auto pos = permutationMap.getDimPosition(idx);
        if (pos >= sz || insertedPos[pos] != extractedPos[idx]) {
          fold = false;
          break;
        }
      }
      if (fold) {
        assert(permutationMap.getNumResults() >= insertedPos.size() &&
               "expected map of rank larger than insert indexing");
        unsigned minorRank =
            permutationMap.getNumResults() - insertedPos.size();
        AffineMap minorMap = permutationMap.getMinorSubMap(minorRank);
        if (!minorMap || minorMap.isMinorIdentity())
          return insertOp.source();
      }
    }

    // If we haven't found a match, just continue to the next producing
    // `vector.insert` / `vector.transpose`.
    // Compute insert/transpose for the next iteration.
    insertOp = insertionDest.getDefiningOp<vector::InsertOp>();
    transposeOp = insertionDest.getDefiningOp<vector::TransposeOp>();
  }
  return Value();
}
813 
814 /// Fold extractOp with scalar result coming from BroadcastOp.
foldExtractFromBroadcast(ExtractOp extractOp)815 static Value foldExtractFromBroadcast(ExtractOp extractOp) {
816   auto broadcastOp = extractOp.vector().getDefiningOp<vector::BroadcastOp>();
817   if (!broadcastOp)
818     return Value();
819   if (extractOp.getType() == broadcastOp.getSourceType())
820     return broadcastOp.source();
821   auto getRank = [](Type type) {
822     return type.isa<VectorType>() ? type.cast<VectorType>().getRank() : 0;
823   };
824   unsigned broadcasrSrcRank = getRank(broadcastOp.getSourceType());
825   unsigned extractResultRank = getRank(extractOp.getType());
826   if (extractResultRank < broadcasrSrcRank) {
827     auto extractPos = extractVector<int64_t>(extractOp.position());
828     unsigned rankDiff = broadcasrSrcRank - extractResultRank;
829     extractPos.erase(
830         extractPos.begin(),
831         std::next(extractPos.begin(), extractPos.size() - rankDiff));
832     extractOp.setOperand(broadcastOp.source());
833     // OpBuilder is only used as a helper to build an I64ArrayAttr.
834     OpBuilder b(extractOp.getContext());
835     extractOp.setAttr(ExtractOp::getPositionAttrName(),
836                       b.getI64ArrayAttr(extractPos));
837     return extractOp.getResult();
838   }
839   // TODO: In case the rank of the broadcast source is greater than the rank of
840   // the extract result this can be combined into a new broadcast op. This needs
841   // to be added a canonicalization pattern if needed.
842   return Value();
843 }
844 
// Fold extractOp with source coming from ShapeCast op. The extract position is
// linearized using the extract-source strides, then delinearized using the
// shape-cast-source strides, and the op is rewritten in place to extract
// directly from the shape cast's source.
static Value foldExtractFromShapeCast(ExtractOp extractOp) {
  auto shapeCastOp = extractOp.vector().getDefiningOp<vector::ShapeCastOp>();
  if (!shapeCastOp)
    return Value();
  // Get the nth dimension size starting from lowest dimension.
  auto getDimReverse = [](VectorType type, int64_t n) {
    return type.getShape().take_back(n+1).front();
  };
  int64_t destinationRank =
      extractOp.getType().isa<VectorType>()
          ? extractOp.getType().cast<VectorType>().getRank()
          : 0;
  if (destinationRank > shapeCastOp.getSourceVectorType().getRank())
    return Value();
  if (destinationRank > 0) {
    auto destinationType = extractOp.getResult().getType().cast<VectorType>();
    for (int64_t i = 0; i < destinationRank; i++) {
      // The lowest dimensions of the destination must match the lowest
      // dimensions of the shapecast op source.
      // TODO: This case could be supported in a canonicalization pattern.
      if (getDimReverse(shapeCastOp.getSourceVectorType(), i) !=
          getDimReverse(destinationType, i))
        return Value();
    }
  }
  // Extract the strides associated with the extract op vector source. Then use
  // this to calculate a linearized position for the extract.
  auto extractedPos = extractVector<int64_t>(extractOp.position());
  std::reverse(extractedPos.begin(), extractedPos.end());
  SmallVector<int64_t, 4> strides;
  int64_t stride = 1;
  for (int64_t i = 0, e = extractedPos.size(); i < e; i++) {
    strides.push_back(stride);
    stride *= getDimReverse(extractOp.getVectorType(), i + destinationRank);
  }

  int64_t position = linearize(extractedPos, strides);
  // Then extract the strides associated to the shapeCast op vector source and
  // delinearize the position using those strides.
  SmallVector<int64_t, 4> newStrides;
  int64_t numDimension =
      shapeCastOp.getSourceVectorType().getRank() - destinationRank;
  stride = 1;
  for (int64_t i = 0; i < numDimension; i++) {
    newStrides.push_back(stride);
    stride *=
        getDimReverse(shapeCastOp.getSourceVectorType(), i + destinationRank);
  }
  // delinearize expects strides ordered from major to minor dimension.
  std::reverse(newStrides.begin(), newStrides.end());
  SmallVector<int64_t, 4> newPosition = delinearize(newStrides, position);
  // OpBuilder is only used as a helper to build an I64ArrayAttr.
  OpBuilder b(extractOp.getContext());
  extractOp.setAttr(ExtractOp::getPositionAttrName(),
                    b.getI64ArrayAttr(newPosition));
  extractOp.setOperand(shapeCastOp.source());
  return extractOp.getResult();
}
903 
fold(ArrayRef<Attribute>)904 OpFoldResult ExtractOp::fold(ArrayRef<Attribute>) {
905   if (succeeded(foldExtractOpFromExtractChain(*this)))
906     return getResult();
907   if (succeeded(foldExtractOpFromTranspose(*this)))
908     return getResult();
909   if (auto val = foldExtractOpFromInsertChainAndTranspose(*this))
910     return val;
911   if (auto val = foldExtractFromBroadcast(*this))
912     return val;
913   if (auto val = foldExtractFromShapeCast(*this))
914     return val;
915   return OpFoldResult();
916 }
917 
918 //===----------------------------------------------------------------------===//
919 // ExtractSlicesOp
920 //===----------------------------------------------------------------------===//
921 
build(OpBuilder & builder,OperationState & result,TupleType tupleType,Value vector,ArrayRef<int64_t> sizes,ArrayRef<int64_t> strides)922 void ExtractSlicesOp::build(OpBuilder &builder, OperationState &result,
923                             TupleType tupleType, Value vector,
924                             ArrayRef<int64_t> sizes,
925                             ArrayRef<int64_t> strides) {
926   result.addOperands(vector);
927   auto sizesAttr = getVectorSubscriptAttr(builder, sizes);
928   auto stridesAttr = getVectorSubscriptAttr(builder, strides);
929   result.addTypes(tupleType);
930   result.addAttribute(getSizesAttrName(), sizesAttr);
931   result.addAttribute(getStridesAttrName(), stridesAttr);
932 }
933 
/// Shared verifier for ExtractSlicesOp and InsertSlicesOp: checks that
/// `tupleType` holds exactly the slice vector types obtained by tiling
/// `vectorType` with `sizes` (and unit `strides`), in linearized slice order.
static LogicalResult
isValidExtractOrInsertSlicesType(Operation *op, VectorType vectorType,
                                 TupleType tupleType, ArrayRef<int64_t> sizes,
                                 ArrayRef<int64_t> strides) {
  // Check for non-unit strides.
  // TODO: Support non-1 strides.
  if (llvm::any_of(strides, [](int64_t s) { return s != 1; }))
    return op->emitError("requires unit strides");
  // Check that 'vectorType' rank matches rank of tuple element vectors.
  unsigned rank = vectorType.getRank();
  auto is_vector_type_of_rank = [&](Type t) {
    return t.isa<VectorType>() && t.cast<VectorType>().getRank() == rank;
  };
  if (!llvm::all_of(tupleType.getTypes(), is_vector_type_of_rank))
    return op->emitError("requires vector tuple elements of rank ") << rank;
  // Check that 'sizes' and 'strides' are of size == 'rank'.
  if (sizes.size() != rank || strides.size() != rank)
    return op->emitError("requires sizes and strides of rank ") << rank;

  // Generate each slice shape based on 'sizes', 'strides' and 'vectorType',
  // and verify that the same matches the corresponding tuple element 'i'.
  // Boundary slices may be smaller than 'sizes' when the shape is not evenly
  // divisible, hence the per-slice computeSliceSizes call.
  auto shape = vectorType.getShape();
  auto sliceStrides = computeStrides(shape, sizes);
  for (int64_t i = 0, e = tupleType.size(); i < e; ++i) {
    auto vectorOffsets = delinearize(sliceStrides, i);
    auto elementOffsets =
        computeElementOffsetsFromVectorSliceOffsets(sizes, vectorOffsets);
    auto sliceSizes = computeSliceSizes(shape, sizes, elementOffsets);
    // Create slice VectorType type.
    auto sliceVectorType =
        VectorType::get(sliceSizes, vectorType.getElementType());
    // Verify that 'sliceVectorType' matches tupleType.getTypes(i)
    if (sliceVectorType != tupleType.getType(i))
      return op->emitError("invalid tuple element type ") << sliceVectorType;
  }
  return success();
}
971 
verify(ExtractSlicesOp op)972 static LogicalResult verify(ExtractSlicesOp op) {
973   SmallVector<int64_t, 4> sizes;
974   op.getSizes(sizes);
975   SmallVector<int64_t, 4> strides;
976   op.getStrides(strides);
977   return isValidExtractOrInsertSlicesType(
978       op.getOperation(), op.getSourceVectorType(), op.getResultTupleType(),
979       sizes, strides);
980 }
981 
populateFromInt64AttrArray(ArrayAttr arrayAttr,SmallVectorImpl<int64_t> & results)982 static void populateFromInt64AttrArray(ArrayAttr arrayAttr,
983                                        SmallVectorImpl<int64_t> &results) {
984   for (auto attr : arrayAttr)
985     results.push_back(attr.cast<IntegerAttr>().getInt());
986 }
987 
/// Copies the op's `sizes` attribute values into `results`.
void ExtractSlicesOp::getSizes(SmallVectorImpl<int64_t> &results) {
  populateFromInt64AttrArray(sizes(), results);
}
991 
/// Copies the op's `strides` attribute values into `results`.
void ExtractSlicesOp::getStrides(SmallVectorImpl<int64_t> &results) {
  populateFromInt64AttrArray(strides(), results);
}
995 
996 //===----------------------------------------------------------------------===//
997 // ExtractMapOp
998 //===----------------------------------------------------------------------===//
999 
/// Builds an ExtractMapOp whose result type is inferred by dividing each
/// distributed dimension of `vector` (named by a result of `permutationMap`)
/// by the corresponding `multiplicity`.
void ExtractMapOp::build(OpBuilder &builder, OperationState &result,
                         Value vector, ValueRange ids,
                         ArrayRef<int64_t> multiplicity,
                         AffineMap permutationMap) {
  assert(ids.size() == multiplicity.size() &&
         ids.size() == permutationMap.getNumResults());
  assert(permutationMap.isProjectedPermutation());
  VectorType type = vector.getType().cast<VectorType>();
  SmallVector<int64_t, 4> newShape(type.getShape().begin(),
                                   type.getShape().end());
  // Each map result is a plain dim expr (projected permutation), naming the
  // source dimension to shrink by the matching multiplicity.
  for (unsigned i = 0, e = permutationMap.getNumResults(); i < e; i++) {
    AffineExpr expr = permutationMap.getResult(i);
    auto dim = expr.cast<AffineDimExpr>();
    newShape[dim.getPosition()] = newShape[dim.getPosition()] / multiplicity[i];
  }
  VectorType resultType = VectorType::get(newShape, type.getElementType());
  ExtractMapOp::build(builder, result, resultType, vector, ids);
}
1018 
/// Verifies an ExtractMapOp: equal ranks, each source dimension divisible by
/// the result dimension, and one id per distributed (shrunk) dimension.
static LogicalResult verify(ExtractMapOp op) {
  if (op.getSourceVectorType().getRank() != op.getResultType().getRank())
    return op.emitOpError(
        "expected source and destination vectors of same rank");
  unsigned numId = 0;
  for (unsigned i = 0, e = op.getSourceVectorType().getRank(); i < e; ++i) {
    if (op.getSourceVectorType().getDimSize(i) %
            op.getResultType().getDimSize(i) !=
        0)
      return op.emitOpError("source vector dimensions must be a multiple of "
                            "destination vector dimensions");
    // A dimension whose size changes is distributed and consumes one id.
    if (op.getSourceVectorType().getDimSize(i) !=
        op.getResultType().getDimSize(i))
      numId++;
  }
  if (numId != op.ids().size())
    return op.emitOpError("expected number of ids must match the number of "
                          "dimensions distributed");
  return success();
}
1039 
fold(ArrayRef<Attribute> operands)1040 OpFoldResult ExtractMapOp::fold(ArrayRef<Attribute> operands) {
1041   auto insert = vector().getDefiningOp<vector::InsertMapOp>();
1042   if (insert == nullptr || getType() != insert.vector().getType() ||
1043       ids() != insert.ids())
1044     return {};
1045   return insert.vector();
1046 }
1047 
getMultiplicity(SmallVectorImpl<int64_t> & multiplicity)1048 void ExtractMapOp::getMultiplicity(SmallVectorImpl<int64_t> &multiplicity) {
1049   assert(multiplicity.empty());
1050   for (unsigned i = 0, e = getSourceVectorType().getRank(); i < e; i++) {
1051     if (getSourceVectorType().getDimSize(i) != getResultType().getDimSize(i))
1052       multiplicity.push_back(getSourceVectorType().getDimSize(i) /
1053                              getResultType().getDimSize(i));
1054   }
1055 }
1056 
/// Returns the affine map that marks the distributed dimensions of `op`:
/// one dim expr per dimension whose source and result sizes differ, in
/// dimension order. Shared by ExtractMapOp and InsertMapOp.
template <typename MapOp>
AffineMap calculateImplicitMap(MapOp op) {
  SmallVector<AffineExpr, 4> perm;
  // Check which dimensions have a multiplicity greater than 1 and associate
  // them to the IDs in order.
  for (unsigned i = 0, e = op.getSourceVectorType().getRank(); i < e; i++) {
    if (op.getSourceVectorType().getDimSize(i) !=
        op.getResultType().getDimSize(i))
      perm.push_back(getAffineDimExpr(i, op.getContext()));
  }
  auto map = AffineMap::get(op.getSourceVectorType().getRank(), 0, perm,
                            op.getContext());
  return map;
}
1071 
map()1072 AffineMap ExtractMapOp::map() { return calculateImplicitMap(*this); }
1073 
1074 //===----------------------------------------------------------------------===//
1075 // BroadcastOp
1076 //===----------------------------------------------------------------------===//
1077 
verify(BroadcastOp op)1078 static LogicalResult verify(BroadcastOp op) {
1079   VectorType srcVectorType = op.getSourceType().dyn_cast<VectorType>();
1080   VectorType dstVectorType = op.getVectorType();
1081   // Scalar to vector broadcast is always valid. A vector
1082   // to vector broadcast needs some additional checking.
1083   if (srcVectorType) {
1084     int64_t srcRank = srcVectorType.getRank();
1085     int64_t dstRank = dstVectorType.getRank();
1086     if (srcRank > dstRank)
1087       return op.emitOpError("source rank higher than destination rank");
1088     // Source has an exact match or singleton value for all trailing dimensions
1089     // (all leading dimensions are simply duplicated).
1090     int64_t lead = dstRank - srcRank;
1091     for (int64_t r = 0; r < srcRank; ++r) {
1092       int64_t srcDim = srcVectorType.getDimSize(r);
1093       int64_t dstDim = dstVectorType.getDimSize(lead + r);
1094       if (srcDim != 1 && srcDim != dstDim)
1095         return op.emitOpError("dimension mismatch (")
1096                << srcDim << " vs. " << dstDim << ")";
1097     }
1098   }
1099   return success();
1100 }
1101 
/// Folds a broadcast of a constant operand: a scalar or splat constant becomes
/// a dense splat of the result vector type.
OpFoldResult BroadcastOp::fold(ArrayRef<Attribute> operands) {
  if (!operands[0])
    return {};
  auto vectorType = getVectorType();
  // Scalar constant: splat it over the full result shape.
  if (operands[0].getType().isIntOrIndexOrFloat())
    return DenseElementsAttr::get(vectorType, operands[0]);
  // Vector constant: only a splat can be re-splatted to the new shape.
  if (auto attr = operands[0].dyn_cast<SplatElementsAttr>())
    return DenseElementsAttr::get(vectorType, attr.getSplatValue());
  return {};
}
1112 
1113 //===----------------------------------------------------------------------===//
1114 // ShuffleOp
1115 //===----------------------------------------------------------------------===//
1116 
build(OpBuilder & builder,OperationState & result,Value v1,Value v2,ArrayRef<int64_t> mask)1117 void ShuffleOp::build(OpBuilder &builder, OperationState &result, Value v1,
1118                       Value v2, ArrayRef<int64_t> mask) {
1119   result.addOperands({v1, v2});
1120   auto maskAttr = getVectorSubscriptAttr(builder, mask);
1121   result.addTypes(v1.getType());
1122   result.addAttribute(getMaskAttrName(), maskAttr);
1123 }
1124 
/// Prints the custom form: `vector.shuffle %v1, %v2 [mask] attr-dict : t1, t2`.
static void print(OpAsmPrinter &p, ShuffleOp op) {
  p << op.getOperationName() << " " << op.v1() << ", " << op.v2() << " "
    << op.mask();
  // The mask is printed inline above, so elide it from the attribute dict.
  p.printOptionalAttrDict(op.getAttrs(), {ShuffleOp::getMaskAttrName()});
  p << " : " << op.v1().getType() << ", " << op.v2().getType();
}
1131 
/// Verifies a ShuffleOp: the two operands and the result must agree in rank
/// and in every trailing dimension; the mask length must equal the result's
/// leading dimension; and each mask index must address an element of the
/// concatenated leading dimensions of v1 and v2.
static LogicalResult verify(ShuffleOp op) {
  VectorType resultType = op.getVectorType();
  VectorType v1Type = op.getV1VectorType();
  VectorType v2Type = op.getV2VectorType();
  // Verify ranks.
  int64_t resRank = resultType.getRank();
  int64_t v1Rank = v1Type.getRank();
  int64_t v2Rank = v2Type.getRank();
  if (resRank != v1Rank || v1Rank != v2Rank)
    return op.emitOpError("rank mismatch");
  // Verify all but leading dimension sizes.
  for (int64_t r = 1; r < v1Rank; ++r) {
    int64_t resDim = resultType.getDimSize(r);
    int64_t v1Dim = v1Type.getDimSize(r);
    int64_t v2Dim = v2Type.getDimSize(r);
    if (resDim != v1Dim || v1Dim != v2Dim)
      return op.emitOpError("dimension mismatch");
  }
  // Verify mask length.
  auto maskAttr = op.mask().getValue();
  int64_t maskLength = maskAttr.size();
  if (maskLength != resultType.getDimSize(0))
    return op.emitOpError("mask length mismatch");
  // Verify all indices: valid range is [0, v1-leading + v2-leading).
  int64_t indexSize = v1Type.getDimSize(0) + v2Type.getDimSize(0);
  for (auto en : llvm::enumerate(maskAttr)) {
    auto attr = en.value().dyn_cast<IntegerAttr>();
    if (!attr || attr.getInt() < 0 || attr.getInt() >= indexSize)
      return op.emitOpError("mask index #")
             << (en.index() + 1) << " out of range";
  }
  return success();
}
1165 
/// Parses `vector.shuffle %v1, %v2 [mask] attr-dict : t1, t2` and constructs
/// the result type: leading dimension is the mask length, trailing dimensions
/// are taken from v1.
static ParseResult parseShuffleOp(OpAsmParser &parser, OperationState &result) {
  OpAsmParser::OperandType v1, v2;
  Attribute attr;
  VectorType v1Type, v2Type;
  if (parser.parseOperand(v1) || parser.parseComma() ||
      parser.parseOperand(v2) ||
      parser.parseAttribute(attr, ShuffleOp::getMaskAttrName(),
                            result.attributes) ||
      parser.parseOptionalAttrDict(result.attributes) ||
      parser.parseColonType(v1Type) || parser.parseComma() ||
      parser.parseType(v2Type) ||
      parser.resolveOperand(v1, v1Type, result.operands) ||
      parser.resolveOperand(v2, v2Type, result.operands))
    return failure();
  // Construct resulting type: leading dimension matches mask length,
  // all trailing dimensions match the operands.
  auto maskAttr = attr.dyn_cast<ArrayAttr>();
  if (!maskAttr)
    return parser.emitError(parser.getNameLoc(), "missing mask attribute");
  int64_t maskLength = maskAttr.size();
  if (maskLength <= 0)
    return parser.emitError(parser.getNameLoc(), "invalid mask length");
  int64_t v1Rank = v1Type.getRank();
  SmallVector<int64_t, 4> shape;
  shape.reserve(v1Rank);
  shape.push_back(maskLength);
  for (int64_t r = 1; r < v1Rank; ++r)
    shape.push_back(v1Type.getDimSize(r));
  VectorType resType = VectorType::get(shape, v1Type.getElementType());
  parser.addTypeToList(resType, result.types);
  return success();
}
1198 
1199 //===----------------------------------------------------------------------===//
1200 // InsertElementOp
1201 //===----------------------------------------------------------------------===//
1202 
/// Builds an InsertElementOp with a dynamic `position`; the result type is the
/// destination vector type.
void InsertElementOp::build(OpBuilder &builder, OperationState &result,
                            Value source, Value dest, Value position) {
  result.addOperands({source, dest, position});
  result.addTypes(dest.getType());
}
1208 
/// Convenience builder that materializes the static `position` as a constant
/// i32 value and delegates to the dynamic-position builder.
void InsertElementOp::build(OpBuilder &builder, OperationState &result,
                            Value source, Value dest, int64_t position) {
  Value pos = builder.create<ConstantIntOp>(result.location, position, 32);
  build(builder, result, source, dest, pos);
}
1214 
verify(InsertElementOp op)1215 static LogicalResult verify(InsertElementOp op) {
1216   auto dstVectorType = op.getDestVectorType();
1217   if (dstVectorType.getRank() != 1)
1218     return op.emitOpError("expected 1-D vector");
1219   return success();
1220 }
1221 
1222 //===----------------------------------------------------------------------===//
1223 // InsertOp
1224 //===----------------------------------------------------------------------===//
1225 
build(OpBuilder & builder,OperationState & result,Value source,Value dest,ArrayRef<int64_t> position)1226 void InsertOp::build(OpBuilder &builder, OperationState &result, Value source,
1227                      Value dest, ArrayRef<int64_t> position) {
1228   result.addOperands({source, dest});
1229   auto positionAttr = getVectorSubscriptAttr(builder, position);
1230   result.addTypes(dest.getType());
1231   result.addAttribute(getPositionAttrName(), positionAttr);
1232 }
1233 
// Convenience builder which assumes the values are constant indices.
// NOTE(review): each `position` value must be defined by a ConstantIndexOp;
// a non-constant value would yield a null op handle here — confirm callers
// guarantee this.
void InsertOp::build(OpBuilder &builder, OperationState &result, Value source,
                     Value dest, ValueRange position) {
  SmallVector<int64_t, 4> positionConstants =
      llvm::to_vector<4>(llvm::map_range(position, [](Value pos) {
        return pos.getDefiningOp<ConstantIndexOp>().getValue();
      }));
  build(builder, result, source, dest, positionConstants);
}
1243 
/// Verifies an InsertOp: the position must be non-empty and no longer than the
/// destination rank; for a vector source, position rank + source rank must
/// equal the destination rank (for a scalar source, position rank must equal
/// the destination rank); and each position entry must be in bounds.
static LogicalResult verify(InsertOp op) {
  auto positionAttr = op.position().getValue();
  if (positionAttr.empty())
    return op.emitOpError("expected non-empty position attribute");
  auto destVectorType = op.getDestVectorType();
  if (positionAttr.size() > static_cast<unsigned>(destVectorType.getRank()))
    return op.emitOpError(
        "expected position attribute of rank smaller than dest vector rank");
  auto srcVectorType = op.getSourceType().dyn_cast<VectorType>();
  if (srcVectorType &&
      (static_cast<unsigned>(srcVectorType.getRank()) + positionAttr.size() !=
       static_cast<unsigned>(destVectorType.getRank())))
    return op.emitOpError("expected position attribute rank + source rank to "
                          "match dest vector rank");
  else if (!srcVectorType && (positionAttr.size() !=
                              static_cast<unsigned>(destVectorType.getRank())))
    return op.emitOpError(
        "expected position attribute rank to match the dest vector rank");
  // Each position indexes the destination dimension at the same rank level.
  for (auto en : llvm::enumerate(positionAttr)) {
    auto attr = en.value().dyn_cast<IntegerAttr>();
    if (!attr || attr.getInt() < 0 ||
        attr.getInt() >= destVectorType.getDimSize(en.index()))
      return op.emitOpError("expected position attribute #")
             << (en.index() + 1)
             << " to be a non-negative integer smaller than the corresponding "
                "dest vector dimension";
  }
  return success();
}
1273 
1274 //===----------------------------------------------------------------------===//
1275 // InsertSlicesOp
1276 //===----------------------------------------------------------------------===//
1277 
verify(InsertSlicesOp op)1278 static LogicalResult verify(InsertSlicesOp op) {
1279   SmallVector<int64_t, 4> sizes;
1280   op.getSizes(sizes);
1281   SmallVector<int64_t, 4> strides;
1282   op.getStrides(strides);
1283   return isValidExtractOrInsertSlicesType(
1284       op.getOperation(), op.getResultVectorType(), op.getSourceTupleType(),
1285       sizes, strides);
1286 }
1287 
/// Copies the op's `sizes` attribute values into `results`.
void InsertSlicesOp::getSizes(SmallVectorImpl<int64_t> &results) {
  populateFromInt64AttrArray(sizes(), results);
}
1291 
/// Copies the op's `strides` attribute values into `results`.
void InsertSlicesOp::getStrides(SmallVectorImpl<int64_t> &results) {
  populateFromInt64AttrArray(strides(), results);
}
1295 
1296 //===----------------------------------------------------------------------===//
1297 // InsertMapOp
1298 //===----------------------------------------------------------------------===//
1299 
/// Convenience builder: the result type is taken from the destination vector.
void InsertMapOp::build(OpBuilder &builder, OperationState &result,
                        Value vector, Value dest, ValueRange ids) {
  InsertMapOp::build(builder, result, dest.getType(), vector, dest, ids);
}
1304 
verify(InsertMapOp op)1305 static LogicalResult verify(InsertMapOp op) {
1306   if (op.getSourceVectorType().getRank() != op.getResultType().getRank())
1307     return op.emitOpError(
1308         "expected source and destination vectors of same rank");
1309   unsigned numId = 0;
1310   for (unsigned i = 0, e = op.getResultType().getRank(); i < e; i++) {
1311     if (op.getResultType().getDimSize(i) %
1312             op.getSourceVectorType().getDimSize(i) !=
1313         0)
1314       return op.emitOpError(
1315           "destination vector size must be a multiple of source vector size");
1316     if (op.getResultType().getDimSize(i) !=
1317         op.getSourceVectorType().getDimSize(i))
1318       numId++;
1319   }
1320   if (numId != op.ids().size())
1321     return op.emitOpError("expected number of ids must match the number of "
1322                           "dimensions distributed");
1323   return success();
1324 }
1325 
/// Returns the affine map implied by this op (see calculateImplicitMap).
AffineMap InsertMapOp::map() { return calculateImplicitMap(*this); }
1327 
1328 //===----------------------------------------------------------------------===//
1329 // InsertStridedSliceOp
1330 //===----------------------------------------------------------------------===//
1331 
build(OpBuilder & builder,OperationState & result,Value source,Value dest,ArrayRef<int64_t> offsets,ArrayRef<int64_t> strides)1332 void InsertStridedSliceOp::build(OpBuilder &builder, OperationState &result,
1333                                  Value source, Value dest,
1334                                  ArrayRef<int64_t> offsets,
1335                                  ArrayRef<int64_t> strides) {
1336   result.addOperands({source, dest});
1337   auto offsetsAttr = getVectorSubscriptAttr(builder, offsets);
1338   auto stridesAttr = getVectorSubscriptAttr(builder, strides);
1339   result.addTypes(dest.getType());
1340   result.addAttribute(getOffsetsAttrName(), offsetsAttr);
1341   result.addAttribute(getStridesAttrName(), stridesAttr);
1342 }
1343 
1344 // TODO: Should be moved to Tablegen Confined attributes.
1345 template <typename OpType>
isIntegerArrayAttrSmallerThanShape(OpType op,ArrayAttr arrayAttr,ArrayRef<int64_t> shape,StringRef attrName)1346 static LogicalResult isIntegerArrayAttrSmallerThanShape(OpType op,
1347                                                         ArrayAttr arrayAttr,
1348                                                         ArrayRef<int64_t> shape,
1349                                                         StringRef attrName) {
1350   if (arrayAttr.size() > shape.size())
1351     return op.emitOpError("expected ")
1352            << attrName << " attribute of rank smaller than vector rank";
1353   return success();
1354 }
1355 
// Returns success if all integers in `arrayAttr` are in the admissible
// interval. If `halfOpen` is true then the admissible interval is [min, max).
// Otherwise, the admissible interval is [min, max].
1359 template <typename OpType>
1360 static LogicalResult
isIntegerArrayAttrConfinedToRange(OpType op,ArrayAttr arrayAttr,int64_t min,int64_t max,StringRef attrName,bool halfOpen=true)1361 isIntegerArrayAttrConfinedToRange(OpType op, ArrayAttr arrayAttr, int64_t min,
1362                                   int64_t max, StringRef attrName,
1363                                   bool halfOpen = true) {
1364   for (auto attr : arrayAttr) {
1365     auto val = attr.cast<IntegerAttr>().getInt();
1366     auto upper = max;
1367     if (!halfOpen)
1368       upper += 1;
1369     if (val < min || val >= upper)
1370       return op.emitOpError("expected ") << attrName << " to be confined to ["
1371                                          << min << ", " << upper << ")";
1372   }
1373   return success();
1374 }
1375 
// Returns success if all integers in `arrayAttr` are in the admissible
// interval for the corresponding `shape` dimension. If `halfOpen` is true the
// admissible interval is [min, max); otherwise it is [min, max].
1379 template <typename OpType>
1380 static LogicalResult
isIntegerArrayAttrConfinedToShape(OpType op,ArrayAttr arrayAttr,ArrayRef<int64_t> shape,StringRef attrName,bool halfOpen=true,int64_t min=0)1381 isIntegerArrayAttrConfinedToShape(OpType op, ArrayAttr arrayAttr,
1382                                   ArrayRef<int64_t> shape, StringRef attrName,
1383                                   bool halfOpen = true, int64_t min = 0) {
1384   assert(arrayAttr.size() <= shape.size());
1385   unsigned index = 0;
1386   for (auto it : llvm::zip(arrayAttr, shape)) {
1387     auto val = std::get<0>(it).cast<IntegerAttr>().getInt();
1388     auto max = std::get<1>(it);
1389     if (!halfOpen)
1390       max += 1;
1391     if (val < min || val >= max)
1392       return op.emitOpError("expected ")
1393              << attrName << " dimension " << index << " to be confined to ["
1394              << min << ", " << max << ")";
1395     ++index;
1396   }
1397   return success();
1398 }
1399 
// Returns success if, for every dimension, the sum of the corresponding
// entries of `arrayAttr1` and `arrayAttr2` is in the admissible interval for
// that `shape` dimension. If `halfOpen` is true the admissible interval is
// [min, max); otherwise it is [min, max].
1403 template <typename OpType>
isSumOfIntegerArrayAttrConfinedToShape(OpType op,ArrayAttr arrayAttr1,ArrayAttr arrayAttr2,ArrayRef<int64_t> shape,StringRef attrName1,StringRef attrName2,bool halfOpen=true,int64_t min=1)1404 static LogicalResult isSumOfIntegerArrayAttrConfinedToShape(
1405     OpType op, ArrayAttr arrayAttr1, ArrayAttr arrayAttr2,
1406     ArrayRef<int64_t> shape, StringRef attrName1, StringRef attrName2,
1407     bool halfOpen = true, int64_t min = 1) {
1408   assert(arrayAttr1.size() <= shape.size());
1409   assert(arrayAttr2.size() <= shape.size());
1410   unsigned index = 0;
1411   for (auto it : llvm::zip(arrayAttr1, arrayAttr2, shape)) {
1412     auto val1 = std::get<0>(it).cast<IntegerAttr>().getInt();
1413     auto val2 = std::get<1>(it).cast<IntegerAttr>().getInt();
1414     auto max = std::get<2>(it);
1415     if (!halfOpen)
1416       max += 1;
1417     if (val1 + val2 < 0 || val1 + val2 >= max)
1418       return op.emitOpError("expected sum(")
1419              << attrName1 << ", " << attrName2 << ") dimension " << index
1420              << " to be confined to [" << min << ", " << max << ")";
1421     ++index;
1422   }
1423   return success();
1424 }
1425 
makeI64ArrayAttr(ArrayRef<int64_t> values,MLIRContext * context)1426 static ArrayAttr makeI64ArrayAttr(ArrayRef<int64_t> values,
1427                                   MLIRContext *context) {
1428   auto attrs = llvm::map_range(values, [context](int64_t v) -> Attribute {
1429     return IntegerAttr::get(IntegerType::get(64, context), APInt(64, v));
1430   });
1431   return ArrayAttr::get(llvm::to_vector<8>(attrs), context);
1432 }
1433 
verify(InsertStridedSliceOp op)1434 static LogicalResult verify(InsertStridedSliceOp op) {
1435   auto sourceVectorType = op.getSourceVectorType();
1436   auto destVectorType = op.getDestVectorType();
1437   auto offsets = op.offsets();
1438   auto strides = op.strides();
1439   if (offsets.size() != static_cast<unsigned>(destVectorType.getRank()))
1440     return op.emitOpError(
1441         "expected offsets of same size as destination vector rank");
1442   if (strides.size() != static_cast<unsigned>(sourceVectorType.getRank()))
1443     return op.emitOpError(
1444         "expected strides of same size as source vector rank");
1445   if (sourceVectorType.getRank() > destVectorType.getRank())
1446     return op.emitOpError(
1447         "expected source rank to be smaller than destination rank");
1448 
1449   auto sourceShape = sourceVectorType.getShape();
1450   auto destShape = destVectorType.getShape();
1451   SmallVector<int64_t, 4> sourceShapeAsDestShape(
1452       destShape.size() - sourceShape.size(), 0);
1453   sourceShapeAsDestShape.append(sourceShape.begin(), sourceShape.end());
1454   auto offName = InsertStridedSliceOp::getOffsetsAttrName();
1455   auto stridesName = InsertStridedSliceOp::getStridesAttrName();
1456   if (failed(
1457           isIntegerArrayAttrConfinedToShape(op, offsets, destShape, offName)) ||
1458       failed(isIntegerArrayAttrConfinedToRange(op, strides, 1, 1, stridesName,
1459                                                /*halfOpen=*/false)) ||
1460       failed(isSumOfIntegerArrayAttrConfinedToShape(
1461           op, offsets,
1462           makeI64ArrayAttr(sourceShapeAsDestShape, op.getContext()), destShape,
1463           offName, "source vector shape",
1464           /*halfOpen=*/false, /*min=*/1)))
1465     return failure();
1466 
1467   return success();
1468 }
1469 
1470 //===----------------------------------------------------------------------===//
1471 // OuterProductOp
1472 //===----------------------------------------------------------------------===//
1473 
1474 /// Build an op without mask, use the type of `acc` as the return type.
build(OpBuilder & builder,OperationState & result,Value lhs,Value rhs,Value acc)1475 void OuterProductOp::build(OpBuilder &builder, OperationState &result,
1476                            Value lhs, Value rhs, Value acc) {
1477   result.addOperands({lhs, rhs, acc});
1478   result.addTypes(acc.getType());
1479 }
1480 
print(OpAsmPrinter & p,OuterProductOp op)1481 static void print(OpAsmPrinter &p, OuterProductOp op) {
1482   p << op.getOperationName() << " " << op.lhs() << ", " << op.rhs();
1483   if (!op.acc().empty())
1484     p << ", " << op.acc();
1485   p << " : " << op.lhs().getType() << ", " << op.rhs().getType();
1486 }
1487 
parseOuterProductOp(OpAsmParser & parser,OperationState & result)1488 static ParseResult parseOuterProductOp(OpAsmParser &parser,
1489                                        OperationState &result) {
1490   SmallVector<OpAsmParser::OperandType, 3> operandsInfo;
1491   Type tLHS, tRHS;
1492   if (parser.parseOperandList(operandsInfo) || parser.parseColonType(tLHS) ||
1493       parser.parseComma() || parser.parseType(tRHS))
1494     return failure();
1495   if (operandsInfo.size() < 2)
1496     return parser.emitError(parser.getNameLoc(),
1497                             "expected at least 2 operands");
1498   VectorType vLHS = tLHS.dyn_cast<VectorType>();
1499   VectorType vRHS = tRHS.dyn_cast<VectorType>();
1500   if (!vLHS)
1501     return parser.emitError(parser.getNameLoc(),
1502                             "expected vector type for operand #1");
1503   VectorType resType =
1504       vRHS ? VectorType::get({vLHS.getDimSize(0), vRHS.getDimSize(0)},
1505                              vLHS.getElementType())
1506            : VectorType::get({vLHS.getDimSize(0)}, vLHS.getElementType());
1507   return failure(
1508       parser.resolveOperand(operandsInfo[0], tLHS, result.operands) ||
1509       parser.resolveOperand(operandsInfo[1], tRHS, result.operands) ||
1510       (operandsInfo.size() > 2 &&
1511        parser.resolveOperand(operandsInfo[2], resType, result.operands)) ||
1512       parser.addTypeToList(resType, result.types));
1513 }
1514 
verify(OuterProductOp op)1515 static LogicalResult verify(OuterProductOp op) {
1516   Type tRHS = op.getOperandTypeRHS();
1517   VectorType vLHS = op.getOperandVectorTypeLHS(),
1518              vRHS = tRHS.dyn_cast<VectorType>(),
1519              vACC = op.getOperandVectorTypeACC(), vRES = op.getVectorType();
1520 
1521   if (vLHS.getRank() != 1)
1522     return op.emitOpError("expected 1-d vector for operand #1");
1523 
1524   if (vRHS) {
1525     // Proper OUTER operation.
1526     if (vRHS.getRank() != 1)
1527       return op.emitOpError("expected 1-d vector for operand #2");
1528     if (vRES.getRank() != 2)
1529       return op.emitOpError("expected 2-d vector result");
1530     if (vLHS.getDimSize(0) != vRES.getDimSize(0))
1531       return op.emitOpError("expected #1 operand dim to match result dim #1");
1532     if (vRHS.getDimSize(0) != vRES.getDimSize(1))
1533       return op.emitOpError("expected #2 operand dim to match result dim #2");
1534   } else {
1535     // An AXPY operation.
1536     if (vRES.getRank() != 1)
1537       return op.emitOpError("expected 1-d vector result");
1538     if (vLHS.getDimSize(0) != vRES.getDimSize(0))
1539       return op.emitOpError("expected #1 operand dim to match result dim #1");
1540   }
1541 
1542   if (vACC && vACC != vRES)
1543     return op.emitOpError("expected operand #3 of same type as result type");
1544   return success();
1545 }
1546 
1547 //===----------------------------------------------------------------------===//
1548 // ReshapeOp
1549 //===----------------------------------------------------------------------===//
1550 
verify(ReshapeOp op)1551 static LogicalResult verify(ReshapeOp op) {
1552   // Verify that rank(numInputs/outputs) + numFixedVec dim matches vec rank.
1553   auto inputVectorType = op.getInputVectorType();
1554   auto outputVectorType = op.getOutputVectorType();
1555   int64_t inputShapeRank = op.getNumInputShapeSizes();
1556   int64_t outputShapeRank = op.getNumOutputShapeSizes();
1557   SmallVector<int64_t, 4> fixedVectorSizes;
1558   op.getFixedVectorSizes(fixedVectorSizes);
1559   int64_t numFixedVectorSizes = fixedVectorSizes.size();
1560 
1561   if (inputVectorType.getRank() != inputShapeRank + numFixedVectorSizes)
1562     return op.emitError("invalid input shape for vector type ")
1563            << inputVectorType;
1564 
1565   if (outputVectorType.getRank() != outputShapeRank + numFixedVectorSizes)
1566     return op.emitError("invalid output shape for vector type ")
1567            << outputVectorType;
1568 
1569   // Verify that the 'fixedVectorSizes' match an input/output vector shape
1570   // suffix.
1571   unsigned inputVectorRank = inputVectorType.getRank();
1572   for (unsigned i = 0; i < numFixedVectorSizes; ++i) {
1573     unsigned index = inputVectorRank - numFixedVectorSizes - i;
1574     if (fixedVectorSizes[i] != inputVectorType.getShape()[index])
1575       return op.emitError("fixed vector size must match input vector for dim ")
1576              << i;
1577   }
1578 
1579   unsigned outputVectorRank = outputVectorType.getRank();
1580   for (unsigned i = 0; i < numFixedVectorSizes; ++i) {
1581     unsigned index = outputVectorRank - numFixedVectorSizes - i;
1582     if (fixedVectorSizes[i] != outputVectorType.getShape()[index])
1583       return op.emitError("fixed vector size must match output vector for dim ")
1584              << i;
1585   }
1586 
1587   // If all shape operands are produced by constant ops, verify that product
1588   // of dimensions for input/output shape match.
1589   auto isDefByConstant = [](Value operand) {
1590     return isa_and_nonnull<ConstantIndexOp>(operand.getDefiningOp());
1591   };
1592   if (llvm::all_of(op.input_shape(), isDefByConstant) &&
1593       llvm::all_of(op.output_shape(), isDefByConstant)) {
1594     int64_t numInputElements = 1;
1595     for (auto operand : op.input_shape())
1596       numInputElements *=
1597           cast<ConstantIndexOp>(operand.getDefiningOp()).getValue();
1598     int64_t numOutputElements = 1;
1599     for (auto operand : op.output_shape())
1600       numOutputElements *=
1601           cast<ConstantIndexOp>(operand.getDefiningOp()).getValue();
1602     if (numInputElements != numOutputElements)
1603       return op.emitError("product of input and output shape sizes must match");
1604   }
1605   return success();
1606 }
1607 
/// Populates `results` with the integer values of the op's
/// `fixed_vector_sizes` attribute.
void ReshapeOp::getFixedVectorSizes(SmallVectorImpl<int64_t> &results) {
  populateFromInt64AttrArray(fixed_vector_sizes(), results);
}
1611 
1612 //===----------------------------------------------------------------------===//
1613 // ExtractStridedSliceOp
1614 //===----------------------------------------------------------------------===//
1615 
1616 // Inference works as follows:
1617 //   1. Add 'sizes' from prefix of dims in 'offsets'.
1618 //   2. Add sizes from 'vectorType' for remaining dims.
inferStridedSliceOpResultType(VectorType vectorType,ArrayAttr offsets,ArrayAttr sizes,ArrayAttr strides)1619 static Type inferStridedSliceOpResultType(VectorType vectorType,
1620                                           ArrayAttr offsets, ArrayAttr sizes,
1621                                           ArrayAttr strides) {
1622   assert(offsets.size() == sizes.size() && offsets.size() == strides.size());
1623   SmallVector<int64_t, 4> shape;
1624   shape.reserve(vectorType.getRank());
1625   unsigned idx = 0;
1626   for (unsigned e = offsets.size(); idx < e; ++idx)
1627     shape.push_back(sizes[idx].cast<IntegerAttr>().getInt());
1628   for (unsigned e = vectorType.getShape().size(); idx < e; ++idx)
1629     shape.push_back(vectorType.getShape()[idx]);
1630 
1631   return VectorType::get(shape, vectorType.getElementType());
1632 }
1633 
build(OpBuilder & builder,OperationState & result,Value source,ArrayRef<int64_t> offsets,ArrayRef<int64_t> sizes,ArrayRef<int64_t> strides)1634 void ExtractStridedSliceOp::build(OpBuilder &builder, OperationState &result,
1635                                   Value source, ArrayRef<int64_t> offsets,
1636                                   ArrayRef<int64_t> sizes,
1637                                   ArrayRef<int64_t> strides) {
1638   result.addOperands(source);
1639   auto offsetsAttr = getVectorSubscriptAttr(builder, offsets);
1640   auto sizesAttr = getVectorSubscriptAttr(builder, sizes);
1641   auto stridesAttr = getVectorSubscriptAttr(builder, strides);
1642   result.addTypes(
1643       inferStridedSliceOpResultType(source.getType().cast<VectorType>(),
1644                                     offsetsAttr, sizesAttr, stridesAttr));
1645   result.addAttribute(getOffsetsAttrName(), offsetsAttr);
1646   result.addAttribute(getSizesAttrName(), sizesAttr);
1647   result.addAttribute(getStridesAttrName(), stridesAttr);
1648 }
1649 
verify(ExtractStridedSliceOp op)1650 static LogicalResult verify(ExtractStridedSliceOp op) {
1651   auto type = op.getVectorType();
1652   auto offsets = op.offsets();
1653   auto sizes = op.sizes();
1654   auto strides = op.strides();
1655   if (offsets.size() != sizes.size() || offsets.size() != strides.size()) {
1656     op.emitOpError(
1657         "expected offsets, sizes and strides attributes of same size");
1658     return failure();
1659   }
1660 
1661   auto shape = type.getShape();
1662   auto offName = ExtractStridedSliceOp::getOffsetsAttrName();
1663   auto sizesName = ExtractStridedSliceOp::getSizesAttrName();
1664   auto stridesName = ExtractStridedSliceOp::getStridesAttrName();
1665   if (failed(isIntegerArrayAttrSmallerThanShape(op, offsets, shape, offName)) ||
1666       failed(isIntegerArrayAttrSmallerThanShape(op, sizes, shape, sizesName)) ||
1667       failed(isIntegerArrayAttrSmallerThanShape(op, strides, shape,
1668                                                 stridesName)) ||
1669       failed(isIntegerArrayAttrConfinedToShape(op, offsets, shape, offName)) ||
1670       failed(isIntegerArrayAttrConfinedToShape(op, sizes, shape, sizesName,
1671                                                /*halfOpen=*/false,
1672                                                /*min=*/1)) ||
1673       failed(isIntegerArrayAttrConfinedToRange(op, strides, 1, 1, stridesName,
1674                                                /*halfOpen=*/false)) ||
1675       failed(isSumOfIntegerArrayAttrConfinedToShape(op, offsets, sizes, shape,
1676                                                     offName, sizesName,
1677                                                     /*halfOpen=*/false)))
1678     return failure();
1679 
1680   auto resultType = inferStridedSliceOpResultType(
1681       op.getVectorType(), op.offsets(), op.sizes(), op.strides());
1682   if (op.getResult().getType() != resultType) {
1683     op.emitOpError("expected result type to be ") << resultType;
1684     return failure();
1685   }
1686 
1687   return success();
1688 }
1689 
// When the source of an ExtractStridedSliceOp comes from a chain of
// InsertStridedSliceOps, try to use the source of one of those inserts
// directly if the extracted vector can be shown to be a subset of one of the
// inserted vectors.
1693 static LogicalResult
foldExtractStridedOpFromInsertChain(ExtractStridedSliceOp op)1694 foldExtractStridedOpFromInsertChain(ExtractStridedSliceOp op) {
1695   // Helper to extract integer out of ArrayAttr.
1696   auto getElement = [](ArrayAttr array, int idx) {
1697     return array[idx].cast<IntegerAttr>().getInt();
1698   };
1699   ArrayAttr extractOffsets = op.offsets();
1700   ArrayAttr extractStrides = op.strides();
1701   ArrayAttr extractSizes = op.sizes();
1702   auto insertOp = op.vector().getDefiningOp<InsertStridedSliceOp>();
1703   while (insertOp) {
1704     if (op.getVectorType().getRank() !=
1705         insertOp.getSourceVectorType().getRank())
1706       return failure();
1707     ArrayAttr insertOffsets = insertOp.offsets();
1708     ArrayAttr insertStrides = insertOp.strides();
1709     // If the rank of extract is greater than the rank of insert, we are likely
1710     // extracting a partial chunk of the vector inserted.
1711     if (extractOffsets.size() > insertOffsets.size())
1712       return failure();
1713     bool patialoverlap = false;
1714     bool disjoint = false;
1715     SmallVector<int64_t, 4> offsetDiffs;
1716     for (unsigned dim = 0, e = extractOffsets.size(); dim < e; ++dim) {
1717       if (getElement(extractStrides, dim) != getElement(insertStrides, dim))
1718         return failure();
1719       int64_t start = getElement(insertOffsets, dim);
1720       int64_t end = start + insertOp.getSourceVectorType().getDimSize(dim);
1721       int64_t offset = getElement(extractOffsets, dim);
1722       int64_t size = getElement(extractSizes, dim);
1723       // Check if the start of the extract offset is in the interval inserted.
1724       if (start <= offset && offset < end) {
1725         // If the extract interval overlaps but is not fully included we may
1726         // have a partial overlap that will prevent any folding.
1727         if (offset + size > end)
1728           patialoverlap = true;
1729         offsetDiffs.push_back(offset - start);
1730         continue;
1731       }
1732       disjoint = true;
1733       break;
1734     }
1735     // The extract element chunk is a subset of the insert element.
1736     if (!disjoint && !patialoverlap) {
1737       op.setOperand(insertOp.source());
1738       // OpBuilder is only used as a helper to build an I64ArrayAttr.
1739       OpBuilder b(op.getContext());
1740       op.setAttr(ExtractStridedSliceOp::getOffsetsAttrName(),
1741                  b.getI64ArrayAttr(offsetDiffs));
1742       return success();
1743     }
1744     // If the chunk extracted is disjoint from the chunk inserted, keep looking
1745     // in the insert chain.
1746     if (disjoint)
1747       insertOp = insertOp.dest().getDefiningOp<InsertStridedSliceOp>();
1748     else {
1749       // The extracted vector partially overlap the inserted vector, we cannot
1750       // fold.
1751       return failure();
1752     }
1753   }
1754   return failure();
1755 }
1756 
fold(ArrayRef<Attribute> operands)1757 OpFoldResult ExtractStridedSliceOp::fold(ArrayRef<Attribute> operands) {
1758   if (getVectorType() == getResult().getType())
1759     return vector();
1760   if (succeeded(foldExtractStridedOpFromInsertChain(*this)))
1761     return getResult();
1762   return {};
1763 }
1764 
/// Populates `results` with the integer values of the op's `offsets` attribute.
void ExtractStridedSliceOp::getOffsets(SmallVectorImpl<int64_t> &results) {
  populateFromInt64AttrArray(offsets(), results);
}
1768 
1769 namespace {
1770 
1771 // Pattern to rewrite a ExtractStridedSliceOp(ConstantMaskOp) -> ConstantMaskOp.
1772 class StridedSliceConstantMaskFolder final
1773     : public OpRewritePattern<ExtractStridedSliceOp> {
1774 public:
1775   using OpRewritePattern<ExtractStridedSliceOp>::OpRewritePattern;
1776 
matchAndRewrite(ExtractStridedSliceOp extractStridedSliceOp,PatternRewriter & rewriter) const1777   LogicalResult matchAndRewrite(ExtractStridedSliceOp extractStridedSliceOp,
1778                                 PatternRewriter &rewriter) const override {
1779     // Return if 'extractStridedSliceOp' operand is not defined by a
1780     // ConstantMaskOp.
1781     auto defOp = extractStridedSliceOp.vector().getDefiningOp();
1782     auto constantMaskOp = dyn_cast_or_null<ConstantMaskOp>(defOp);
1783     if (!constantMaskOp)
1784       return failure();
1785     // Return if 'extractStridedSliceOp' has non-unit strides.
1786     if (llvm::any_of(extractStridedSliceOp.strides(), [](Attribute attr) {
1787           return attr.cast<IntegerAttr>().getInt() != 1;
1788         }))
1789       return failure();
1790     // Gather constant mask dimension sizes.
1791     SmallVector<int64_t, 4> maskDimSizes;
1792     populateFromInt64AttrArray(constantMaskOp.mask_dim_sizes(), maskDimSizes);
1793     // Gather strided slice offsets and sizes.
1794     SmallVector<int64_t, 4> sliceOffsets;
1795     populateFromInt64AttrArray(extractStridedSliceOp.offsets(), sliceOffsets);
1796     SmallVector<int64_t, 4> sliceSizes;
1797     populateFromInt64AttrArray(extractStridedSliceOp.sizes(), sliceSizes);
1798 
1799     // Compute slice of vector mask region.
1800     SmallVector<int64_t, 4> sliceMaskDimSizes;
1801     assert(sliceOffsets.size() == maskDimSizes.size());
1802     for (auto it : llvm::zip(maskDimSizes, sliceOffsets, sliceSizes)) {
1803       int64_t maskDimSize = std::get<0>(it);
1804       int64_t sliceOffset = std::get<1>(it);
1805       int64_t sliceSize = std::get<2>(it);
1806       int64_t sliceMaskDimSize = std::max(
1807           static_cast<int64_t>(0),
1808           std::min(sliceOffset + sliceSize, maskDimSize) - sliceOffset);
1809       sliceMaskDimSizes.push_back(sliceMaskDimSize);
1810     }
1811     // If any of 'sliceMaskDimSizes' are zero, then set all to zero (masked
1812     // region is a conjunction of mask dim intervals).
1813     if (llvm::any_of(sliceMaskDimSizes, [](int64_t sz) { return sz == 0; }))
1814       sliceMaskDimSizes.assign(maskDimSizes.size(), 0);
1815 
1816     // Replace 'extractStridedSliceOp' with ConstantMaskOp with sliced mask
1817     // region.
1818     rewriter.replaceOpWithNewOp<ConstantMaskOp>(
1819         extractStridedSliceOp, extractStridedSliceOp.getResult().getType(),
1820         vector::getVectorSubscriptAttr(rewriter, sliceMaskDimSizes));
1821     return success();
1822   }
1823 };
1824 
1825 // Pattern to rewrite a ExtractStridedSliceOp(splat ConstantOp) -> ConstantOp.
1826 class StridedSliceConstantFolder final
1827     : public OpRewritePattern<ExtractStridedSliceOp> {
1828 public:
1829   using OpRewritePattern<ExtractStridedSliceOp>::OpRewritePattern;
1830 
matchAndRewrite(ExtractStridedSliceOp extractStridedSliceOp,PatternRewriter & rewriter) const1831   LogicalResult matchAndRewrite(ExtractStridedSliceOp extractStridedSliceOp,
1832                                 PatternRewriter &rewriter) const override {
1833     // Return if 'extractStridedSliceOp' operand is not defined by a
1834     // ConstantOp.
1835     auto constantOp =
1836         extractStridedSliceOp.vector().getDefiningOp<ConstantOp>();
1837     if (!constantOp)
1838       return failure();
1839     auto dense = constantOp.value().dyn_cast<SplatElementsAttr>();
1840     if (!dense)
1841       return failure();
1842     auto newAttr = DenseElementsAttr::get(
1843         extractStridedSliceOp.getType().cast<VectorType>(),
1844         dense.getSplatValue());
1845     rewriter.replaceOpWithNewOp<ConstantOp>(extractStridedSliceOp, newAttr);
1846     return success();
1847   }
1848 };
1849 
1850 } // end anonymous namespace
1851 
void ExtractStridedSliceOp::getCanonicalizationPatterns(
    OwningRewritePatternList &results, MLIRContext *context) {
  // Patterns to rewrite ExtractStridedSliceOp(ConstantMaskOp) ->
  // ConstantMaskOp and ExtractStridedSliceOp(splat ConstantOp) -> ConstantOp.
  results.insert<StridedSliceConstantMaskFolder, StridedSliceConstantFolder>(
      context);
}
1859 
1860 //===----------------------------------------------------------------------===//
1861 // TransferReadOp
1862 //===----------------------------------------------------------------------===//
1863 
1864 template <typename EmitFun>
verifyPermutationMap(AffineMap permutationMap,EmitFun emitOpError)1865 static LogicalResult verifyPermutationMap(AffineMap permutationMap,
1866                                           EmitFun emitOpError) {
1867   SmallVector<bool, 8> seen(permutationMap.getNumInputs(), false);
1868   for (auto expr : permutationMap.getResults()) {
1869     auto dim = expr.dyn_cast<AffineDimExpr>();
1870     auto zero = expr.dyn_cast<AffineConstantExpr>();
1871     if (zero) {
1872       if (zero.getValue() != 0) {
1873         return emitOpError(
1874             "requires a projected permutation_map (at most one dim or the zero "
1875             "constant can appear in each result)");
1876       }
1877       continue;
1878     }
1879     if (!dim) {
1880       return emitOpError("requires a projected permutation_map (at most one "
1881                          "dim or the zero constant can appear in each result)");
1882     }
1883     if (seen[dim.getPosition()]) {
1884       return emitOpError(
1885           "requires a permutation_map that is a permutation (found one dim "
1886           "used more than once)");
1887     }
1888     seen[dim.getPosition()] = true;
1889   }
1890   return success();
1891 }
1892 
/// Common verification for vector transfer ops (read/write): checks the
/// compatibility of the memref element type, the transferred vector type, the
/// permutation map, and the optional `masked` attribute. The diagnostics below
/// are emitted in a fixed order; callers' tests may depend on that order.
static LogicalResult verifyTransferOp(Operation *op, MemRefType memrefType,
                                      VectorType vectorType,
                                      AffineMap permutationMap,
                                      ArrayAttr optionalMasked) {
  auto memrefElementType = memrefType.getElementType();
  if (auto memrefVectorElementType = memrefElementType.dyn_cast<VectorType>()) {
    // Memref has vector element type.

    // The minor (innermost) 1-D vector bitwidth of the result must be a
    // multiple of the memref element vector's minor 1-D bitwidth.
    unsigned memrefVecSize = memrefVectorElementType.getElementTypeBitWidth() *
                             memrefVectorElementType.getShape().back();
    unsigned resultVecSize =
        vectorType.getElementTypeBitWidth() * vectorType.getShape().back();
    if (resultVecSize % memrefVecSize != 0)
      return op->emitOpError(
          "requires the bitwidth of the minor 1-D vector to be an integral "
          "multiple of the bitwidth of the minor 1-D vector of the memref");

    // The element vector may cover only the trailing dims of the result; the
    // permutation map describes the remaining leading ("offset") dims.
    unsigned memrefVecEltRank = memrefVectorElementType.getRank();
    unsigned resultVecRank = vectorType.getRank();
    if (memrefVecEltRank > resultVecRank)
      return op->emitOpError(
          "requires memref vector element and vector result ranks to match.");
    unsigned rankOffset = resultVecRank - memrefVecEltRank;
    // Check that permutation map results match 'rankOffset' of vector type.
    if (permutationMap.getNumResults() != rankOffset)
      return op->emitOpError("requires a permutation_map with result dims of "
                             "the same rank as the vector type");
  } else {
    // Memref has scalar element type.
    unsigned resultVecSize =
        vectorType.getElementTypeBitWidth() * vectorType.getShape().back();
    if (resultVecSize % memrefElementType.getIntOrFloatBitWidth() != 0)
      return op->emitOpError(
          "requires the bitwidth of the minor 1-D vector to be an integral "
          "multiple of the bitwidth of the memref element type");

    // Check that permutation map results match rank of vector type.
    if (permutationMap.getNumResults() != vectorType.getRank())
      return op->emitOpError("requires a permutation_map with result dims of "
                             "the same rank as the vector type");
  }

  // The map must be a pure dim map over the memref's dimensions.
  if (permutationMap.getNumSymbols() != 0)
    return op->emitOpError("requires permutation_map without symbols");
  if (permutationMap.getNumInputs() != memrefType.getRank())
    return op->emitOpError("requires a permutation_map with input dims of the "
                           "same rank as the memref type");

  // When present, `masked` carries one flag per permutation-map result.
  if (optionalMasked) {
    if (permutationMap.getNumResults() !=
        static_cast<int64_t>(optionalMasked.size()))
      return op->emitOpError("expects the optional masked attr of same rank as "
                             "permutation_map results: ")
             << AffineMapAttr::get(permutationMap);
  }

  return success();
}
1951 
1952 /// Builder that sets padding to zero.
build(OpBuilder & builder,OperationState & result,VectorType vector,Value memref,ValueRange indices,AffineMap permutationMap,ArrayRef<bool> maybeMasked)1953 void TransferReadOp::build(OpBuilder &builder, OperationState &result,
1954                            VectorType vector, Value memref, ValueRange indices,
1955                            AffineMap permutationMap,
1956                            ArrayRef<bool> maybeMasked) {
1957   Type elemType = memref.getType().cast<MemRefType>().getElementType();
1958   Value padding = builder.create<ConstantOp>(result.location, elemType,
1959                                              builder.getZeroAttr(elemType));
1960   if (maybeMasked.empty())
1961     return build(builder, result, vector, memref, indices, permutationMap,
1962                  padding, ArrayAttr());
1963   ArrayAttr maskedArrayAttr = builder.getBoolArrayAttr(maybeMasked);
1964   build(builder, result, vector, memref, indices, permutationMap, padding,
1965         maskedArrayAttr);
1966 }
1967 
1968 /// Builder that sets permutation map (resp. padding) to 'getMinorIdentityMap'
1969 /// (resp. zero).
build(OpBuilder & builder,OperationState & result,VectorType vectorType,Value memref,ValueRange indices,ArrayRef<bool> maybeMasked)1970 void TransferReadOp::build(OpBuilder &builder, OperationState &result,
1971                            VectorType vectorType, Value memref,
1972                            ValueRange indices, ArrayRef<bool> maybeMasked) {
1973   auto permMap = getTransferMinorIdentityMap(
1974       memref.getType().cast<MemRefType>(), vectorType);
1975   build(builder, result, vectorType, memref, indices, permMap, maybeMasked);
1976 }
1977 
printTransferAttrs(OpAsmPrinter & p,VectorTransferOpInterface op)1978 static void printTransferAttrs(OpAsmPrinter &p, VectorTransferOpInterface op) {
1979   SmallVector<StringRef, 2> elidedAttrs;
1980   if (op.permutation_map() ==
1981       getTransferMinorIdentityMap(op.getMemRefType(), op.getVectorType()))
1982     elidedAttrs.push_back(op.getPermutationMapAttrName());
1983   bool elideMasked = true;
1984   if (auto maybeMasked = op.masked()) {
1985     for (auto attr : *maybeMasked) {
1986       if (!attr.template cast<BoolAttr>().getValue()) {
1987         elideMasked = false;
1988         break;
1989       }
1990     }
1991   }
1992   if (elideMasked)
1993     elidedAttrs.push_back(op.getMaskedAttrName());
1994   p.printOptionalAttrDict(op.getAttrs(), elidedAttrs);
1995 }
1996 
print(OpAsmPrinter & p,TransferReadOp op)1997 static void print(OpAsmPrinter &p, TransferReadOp op) {
1998   p << op.getOperationName() << " " << op.memref() << "[" << op.indices()
1999     << "], " << op.padding();
2000   printTransferAttrs(p, cast<VectorTransferOpInterface>(op.getOperation()));
2001   p << " : " << op.getMemRefType() << ", " << op.getVectorType();
2002 }
2003 
/// Parses a transfer_read op of the form:
///   %memref[%indices...], %padding {attrs} : memref-type, vector-type
/// When the permutation_map attribute is elided, the minor identity map
/// inferred from the memref/vector types is installed as the default.
static ParseResult parseTransferReadOp(OpAsmParser &parser,
                                       OperationState &result) {
  llvm::SMLoc typesLoc;
  OpAsmParser::OperandType memrefInfo;
  SmallVector<OpAsmParser::OperandType, 8> indexInfo;
  OpAsmParser::OperandType paddingInfo;
  SmallVector<Type, 2> types;
  // Parsing with support for paddingValue.
  if (parser.parseOperand(memrefInfo) ||
      parser.parseOperandList(indexInfo, OpAsmParser::Delimiter::Square) ||
      parser.parseComma() || parser.parseOperand(paddingInfo) ||
      parser.parseOptionalAttrDict(result.attributes) ||
      parser.getCurrentLocation(&typesLoc) || parser.parseColonTypeList(types))
    return failure();
  // Exactly two types are expected: the memref type then the vector type.
  if (types.size() != 2)
    return parser.emitError(typesLoc, "requires two types");
  auto indexType = parser.getBuilder().getIndexType();
  MemRefType memRefType = types[0].dyn_cast<MemRefType>();
  if (!memRefType)
    return parser.emitError(typesLoc, "requires memref type");
  VectorType vectorType = types[1].dyn_cast<VectorType>();
  if (!vectorType)
    return parser.emitError(typesLoc, "requires vector type");
  // Install the default minor identity permutation map when elided.
  auto permutationAttrName = TransferReadOp::getPermutationMapAttrName();
  auto attr = result.attributes.get(permutationAttrName);
  if (!attr) {
    auto permMap = getTransferMinorIdentityMap(memRefType, vectorType);
    result.attributes.set(permutationAttrName, AffineMapAttr::get(permMap));
  }
  // Operand resolution order must match the op's operand list: memref,
  // indices, then padding.
  return failure(
      parser.resolveOperand(memrefInfo, memRefType, result.operands) ||
      parser.resolveOperands(indexInfo, indexType, result.operands) ||
      parser.resolveOperand(paddingInfo, memRefType.getElementType(),
                            result.operands) ||
      parser.addTypeToList(vectorType, result.types));
}
2040 
/// Verifies a transfer_read op: index count, shared transfer invariants, and
/// the padding value's compatibility with the memref element type.
static LogicalResult verify(TransferReadOp op) {
  // Consistency of elemental types in memref and vector.
  MemRefType memrefType = op.getMemRefType();
  VectorType vectorType = op.getVectorType();
  auto paddingType = op.padding().getType();
  auto permutationMap = op.permutation_map();
  auto memrefElementType = memrefType.getElementType();

  // One index operand per memref dimension.
  if (static_cast<int64_t>(op.indices().size()) != memrefType.getRank())
    return op.emitOpError("requires ") << memrefType.getRank() << " indices";

  // Invariants shared with transfer_write (bitwidths, permutation map shape,
  // optional masked attribute).
  if (failed(verifyTransferOp(op.getOperation(), memrefType, vectorType,
                              permutationMap,
                              op.masked() ? *op.masked() : ArrayAttr())))
    return failure();

  if (auto memrefVectorElementType = memrefElementType.dyn_cast<VectorType>()) {
    // Memref has vector element type.
    // Check that 'memrefVectorElementType' and 'paddingType' types match.
    if (memrefVectorElementType != paddingType)
      return op.emitOpError(
          "requires memref element type and padding type to match.");

  } else {
    // Check that 'paddingType' is valid to store in a vector type.
    if (!VectorType::isValidElementType(paddingType))
      return op.emitOpError("requires valid padding vector elemental type");

    // Check that padding type and vector element types match.
    if (paddingType != memrefElementType)
      return op.emitOpError(
          "requires formal padding and memref of the same elemental type");
  }

  // Finally, the permutation map itself must be a valid (partial)
  // permutation.
  return verifyPermutationMap(permutationMap,
                              [&op](Twine t) { return op.emitOpError(t); });
}
2078 
2079 /// This is a common class used for patterns of the form
2080 /// ```
2081 ///    someop(memrefcast) -> someop
2082 /// ```
2083 /// It folds the source of the memref_cast into the root operation directly.
foldMemRefCast(Operation * op)2084 static LogicalResult foldMemRefCast(Operation *op) {
2085   bool folded = false;
2086   for (OpOperand &operand : op->getOpOperands()) {
2087     auto castOp = operand.get().getDefiningOp<MemRefCastOp>();
2088     if (castOp && canFoldIntoConsumerOp(castOp)) {
2089       operand.set(castOp.getOperand());
2090       folded = true;
2091     }
2092   }
2093   return success(folded);
2094 }
2095 
2096 template <typename TransferOp>
isInBounds(TransferOp op,int64_t resultIdx,int64_t indicesIdx)2097 static bool isInBounds(TransferOp op, int64_t resultIdx, int64_t indicesIdx) {
2098   // TODO: support more aggressive createOrFold on:
2099   // `op.indices()[indicesIdx] + vectorType < dim(op.memref(), indicesIdx)`
2100   if (op.getMemRefType().isDynamicDim(indicesIdx))
2101     return false;
2102   Value index = op.indices()[indicesIdx];
2103   auto cstOp = index.getDefiningOp<ConstantIndexOp>();
2104   if (!cstOp)
2105     return false;
2106 
2107   int64_t memrefSize = op.getMemRefType().getDimSize(indicesIdx);
2108   int64_t vectorSize = op.getVectorType().getDimSize(resultIdx);
2109 
2110   return cstOp.getValue() + vectorSize <= memrefSize;
2111 }
2112 
/// Folds the `masked` attribute of a transfer op: every dimension whose
/// access can be statically proven in-bounds is rewritten as unmasked. Only
/// applies to minor identity permutation maps, where result dims and index
/// dims line up one-to-one. Returns success iff the attribute was tightened.
template <typename TransferOp>
static LogicalResult foldTransferMaskAttribute(TransferOp op) {
  AffineMap permutationMap = op.permutation_map();
  if (!permutationMap.isMinorIdentity())
    return failure();
  bool changed = false;
  SmallVector<bool, 4> isMasked;
  isMasked.reserve(op.getTransferRank());
  op.zipResultAndIndexing([&](int64_t resultIdx, int64_t indicesIdx) {
    // Already marked unmasked, nothing to see here.
    if (!op.isMaskedDim(resultIdx)) {
      isMasked.push_back(false);
      return;
    }
    // Currently masked, check whether we can statically determine it is
    // inBounds.
    auto inBounds = isInBounds(op, resultIdx, indicesIdx);
    isMasked.push_back(!inBounds);
    // We commit the pattern if it is "more inbounds".
    changed |= inBounds;
  });
  if (!changed)
    return failure();
  // OpBuilder is only used as a helper to build a boolean ArrayAttr.
  OpBuilder b(op.getContext());
  op.setAttr(TransferOp::getMaskedAttrName(), b.getBoolArrayAttr(isMasked));
  return success();
}
2141 
fold(ArrayRef<Attribute>)2142 OpFoldResult TransferReadOp::fold(ArrayRef<Attribute>) {
2143   /// transfer_read(memrefcast) -> transfer_read
2144   if (succeeded(foldTransferMaskAttribute(*this)))
2145     return getResult();
2146   if (succeeded(foldMemRefCast(*this)))
2147     return getResult();
2148   return OpFoldResult();
2149 }
2150 
getShapeForUnroll()2151 Optional<SmallVector<int64_t, 4>> TransferReadOp::getShapeForUnroll() {
2152   auto s = getVectorType().getShape();
2153   return SmallVector<int64_t, 4>{s.begin(), s.end()};
2154 }
2155 
2156 //===----------------------------------------------------------------------===//
2157 // TransferWriteOp
2158 //===----------------------------------------------------------------------===//
2159 
2160 /// Builder that sets permutation map to 'getMinorIdentityMap'.
build(OpBuilder & builder,OperationState & result,Value vector,Value memref,ValueRange indices,ArrayRef<bool> maybeMasked)2161 void TransferWriteOp::build(OpBuilder &builder, OperationState &result,
2162                             Value vector, Value memref, ValueRange indices,
2163                             ArrayRef<bool> maybeMasked) {
2164   auto vectorType = vector.getType().cast<VectorType>();
2165   auto permMap = getTransferMinorIdentityMap(
2166       memref.getType().cast<MemRefType>(), vectorType);
2167   if (maybeMasked.empty())
2168     return build(builder, result, vector, memref, indices, permMap,
2169                  ArrayAttr());
2170   ArrayAttr maskedArrayAttr = builder.getBoolArrayAttr(maybeMasked);
2171   build(builder, result, vector, memref, indices, permMap, maskedArrayAttr);
2172 }
2173 
build(OpBuilder & builder,OperationState & result,Value vector,Value memref,ValueRange indices,AffineMap permutationMap)2174 void TransferWriteOp::build(OpBuilder &builder, OperationState &result,
2175                             Value vector, Value memref, ValueRange indices,
2176                             AffineMap permutationMap) {
2177   build(builder, result, vector, memref, indices, permutationMap,
2178         /*maybeMasked=*/ArrayAttr());
2179 }
2180 
/// Parses a transfer_write op of the form:
///   %vector, %memref[%indices...] {attrs} : vector-type, memref-type
/// When the permutation_map attribute is elided, the minor identity map
/// inferred from the memref/vector types is installed as the default.
static ParseResult parseTransferWriteOp(OpAsmParser &parser,
                                        OperationState &result) {
  llvm::SMLoc typesLoc;
  OpAsmParser::OperandType vectorInfo, memrefInfo;
  SmallVector<OpAsmParser::OperandType, 8> indexInfo;
  SmallVector<Type, 2> types;
  if (parser.parseOperand(vectorInfo) || parser.parseComma() ||
      parser.parseOperand(memrefInfo) ||
      parser.parseOperandList(indexInfo, OpAsmParser::Delimiter::Square) ||
      parser.parseOptionalAttrDict(result.attributes) ||
      parser.getCurrentLocation(&typesLoc) || parser.parseColonTypeList(types))
    return failure();
  // Exactly two types are expected: the vector type then the memref type.
  if (types.size() != 2)
    return parser.emitError(typesLoc, "requires two types");
  auto indexType = parser.getBuilder().getIndexType();
  VectorType vectorType = types[0].dyn_cast<VectorType>();
  if (!vectorType)
    return parser.emitError(typesLoc, "requires vector type");
  MemRefType memRefType = types[1].dyn_cast<MemRefType>();
  if (!memRefType)
    return parser.emitError(typesLoc, "requires memref type");
  // Install the default minor identity permutation map when elided.
  auto permutationAttrName = TransferWriteOp::getPermutationMapAttrName();
  auto attr = result.attributes.get(permutationAttrName);
  if (!attr) {
    auto permMap = getTransferMinorIdentityMap(memRefType, vectorType);
    result.attributes.set(permutationAttrName, AffineMapAttr::get(permMap));
  }
  // Operand resolution order must match the op's operand list: vector,
  // memref, then indices.
  return failure(
      parser.resolveOperand(vectorInfo, vectorType, result.operands) ||
      parser.resolveOperand(memrefInfo, memRefType, result.operands) ||
      parser.resolveOperands(indexInfo, indexType, result.operands));
}
2213 
print(OpAsmPrinter & p,TransferWriteOp op)2214 static void print(OpAsmPrinter &p, TransferWriteOp op) {
2215   p << op.getOperationName() << " " << op.vector() << ", " << op.memref() << "["
2216     << op.indices() << "]";
2217   printTransferAttrs(p, cast<VectorTransferOpInterface>(op.getOperation()));
2218   p << " : " << op.getVectorType() << ", " << op.getMemRefType();
2219 }
2220 
verify(TransferWriteOp op)2221 static LogicalResult verify(TransferWriteOp op) {
2222   // Consistency of elemental types in memref and vector.
2223   MemRefType memrefType = op.getMemRefType();
2224   VectorType vectorType = op.getVectorType();
2225   auto permutationMap = op.permutation_map();
2226 
2227   if (llvm::size(op.indices()) != memrefType.getRank())
2228     return op.emitOpError("requires ") << memrefType.getRank() << " indices";
2229 
2230   if (failed(verifyTransferOp(op.getOperation(), memrefType, vectorType,
2231                               permutationMap,
2232                               op.masked() ? *op.masked() : ArrayAttr())))
2233     return failure();
2234 
2235   return verifyPermutationMap(permutationMap,
2236                               [&op](Twine t) { return op.emitOpError(t); });
2237 }
2238 
fold(ArrayRef<Attribute>,SmallVectorImpl<OpFoldResult> &)2239 LogicalResult TransferWriteOp::fold(ArrayRef<Attribute>,
2240                                     SmallVectorImpl<OpFoldResult> &) {
2241   if (succeeded(foldTransferMaskAttribute(*this)))
2242     return success();
2243   return foldMemRefCast(*this);
2244 }
2245 
getShapeForUnroll()2246 Optional<SmallVector<int64_t, 4>> TransferWriteOp::getShapeForUnroll() {
2247   return llvm::to_vector<4>(getVectorType().getShape());
2248 }
2249 
2250 //===----------------------------------------------------------------------===//
2251 // MaskedLoadOp
2252 //===----------------------------------------------------------------------===//
2253 
verify(MaskedLoadOp op)2254 static LogicalResult verify(MaskedLoadOp op) {
2255   VectorType maskVType = op.getMaskVectorType();
2256   VectorType passVType = op.getPassThruVectorType();
2257   VectorType resVType = op.getResultVectorType();
2258 
2259   if (resVType.getElementType() != op.getMemRefType().getElementType())
2260     return op.emitOpError("base and result element type should match");
2261 
2262   if (resVType.getDimSize(0) != maskVType.getDimSize(0))
2263     return op.emitOpError("expected result dim to match mask dim");
2264   if (resVType != passVType)
2265     return op.emitOpError("expected pass_thru of same type as result type");
2266   return success();
2267 }
2268 
2269 namespace {
2270 class MaskedLoadFolder final : public OpRewritePattern<MaskedLoadOp> {
2271 public:
2272   using OpRewritePattern<MaskedLoadOp>::OpRewritePattern;
matchAndRewrite(MaskedLoadOp load,PatternRewriter & rewriter) const2273   LogicalResult matchAndRewrite(MaskedLoadOp load,
2274                                 PatternRewriter &rewriter) const override {
2275     Value newBase;
2276     switch (get1DMaskFormat(load.mask())) {
2277     case MaskFormat::AllTrue:
2278       if (!castedToMemRef(load.getLoc(), load.base(), load.getMemRefType(),
2279                           load.getResultVectorType(), rewriter, newBase))
2280         return failure();
2281       rewriter.replaceOpWithNewOp<LoadOp>(load, newBase);
2282       return success();
2283     case MaskFormat::AllFalse:
2284       rewriter.replaceOp(load, load.pass_thru());
2285       return success();
2286     case MaskFormat::Unknown:
2287       return failure();
2288     }
2289     llvm_unreachable("Unexpected 1DMaskFormat on MaskedLoad");
2290   }
2291 };
2292 } // namespace
2293 
/// Registers canonicalization patterns: fold masked loads whose mask is a
/// compile-time constant (all-true or all-false).
void MaskedLoadOp::getCanonicalizationPatterns(
    OwningRewritePatternList &results, MLIRContext *context) {
  results.insert<MaskedLoadFolder>(context);
}
2298 
2299 //===----------------------------------------------------------------------===//
2300 // MaskedStoreOp
2301 //===----------------------------------------------------------------------===//
2302 
verify(MaskedStoreOp op)2303 static LogicalResult verify(MaskedStoreOp op) {
2304   VectorType maskVType = op.getMaskVectorType();
2305   VectorType valueVType = op.getValueVectorType();
2306 
2307   if (valueVType.getElementType() != op.getMemRefType().getElementType())
2308     return op.emitOpError("base and value element type should match");
2309 
2310   if (valueVType.getDimSize(0) != maskVType.getDimSize(0))
2311     return op.emitOpError("expected value dim to match mask dim");
2312   return success();
2313 }
2314 
2315 namespace {
2316 class MaskedStoreFolder final : public OpRewritePattern<MaskedStoreOp> {
2317 public:
2318   using OpRewritePattern<MaskedStoreOp>::OpRewritePattern;
matchAndRewrite(MaskedStoreOp store,PatternRewriter & rewriter) const2319   LogicalResult matchAndRewrite(MaskedStoreOp store,
2320                                 PatternRewriter &rewriter) const override {
2321     Value newBase;
2322     switch (get1DMaskFormat(store.mask())) {
2323     case MaskFormat::AllTrue:
2324       if (!castedToMemRef(store.getLoc(), store.base(), store.getMemRefType(),
2325                           store.getValueVectorType(), rewriter, newBase))
2326         return failure();
2327       rewriter.replaceOpWithNewOp<StoreOp>(store, store.value(), newBase);
2328       return success();
2329     case MaskFormat::AllFalse:
2330       rewriter.eraseOp(store);
2331       return success();
2332     case MaskFormat::Unknown:
2333       return failure();
2334     }
2335     llvm_unreachable("Unexpected 1DMaskFormat on MaskedStore");
2336   }
2337 };
2338 } // namespace
2339 
/// Registers canonicalization patterns: fold masked stores whose mask is a
/// compile-time constant (all-true or all-false).
void MaskedStoreOp::getCanonicalizationPatterns(
    OwningRewritePatternList &results, MLIRContext *context) {
  results.insert<MaskedStoreFolder>(context);
}
2344 
2345 //===----------------------------------------------------------------------===//
2346 // GatherOp
2347 //===----------------------------------------------------------------------===//
2348 
verify(GatherOp op)2349 static LogicalResult verify(GatherOp op) {
2350   VectorType indicesVType = op.getIndicesVectorType();
2351   VectorType maskVType = op.getMaskVectorType();
2352   VectorType resVType = op.getResultVectorType();
2353 
2354   if (resVType.getElementType() != op.getMemRefType().getElementType())
2355     return op.emitOpError("base and result element type should match");
2356 
2357   if (resVType.getDimSize(0) != indicesVType.getDimSize(0))
2358     return op.emitOpError("expected result dim to match indices dim");
2359   if (resVType.getDimSize(0) != maskVType.getDimSize(0))
2360     return op.emitOpError("expected result dim to match mask dim");
2361   if (llvm::size(op.pass_thru()) != 0) {
2362     VectorType passVType = op.getPassThruVectorType();
2363     if (resVType != passVType)
2364       return op.emitOpError("expected pass_thru of same type as result type");
2365   }
2366   return success();
2367 }
2368 
2369 namespace {
2370 class GatherFolder final : public OpRewritePattern<GatherOp> {
2371 public:
2372   using OpRewritePattern<GatherOp>::OpRewritePattern;
matchAndRewrite(GatherOp gather,PatternRewriter & rewriter) const2373   LogicalResult matchAndRewrite(GatherOp gather,
2374                                 PatternRewriter &rewriter) const override {
2375     switch (get1DMaskFormat(gather.mask())) {
2376     case MaskFormat::AllTrue:
2377       return failure(); // no unmasked equivalent
2378     case MaskFormat::AllFalse:
2379       rewriter.replaceOp(gather, gather.pass_thru());
2380       return success();
2381     case MaskFormat::Unknown:
2382       return failure();
2383     }
2384     llvm_unreachable("Unexpected 1DMaskFormat on GatherFolder");
2385   }
2386 };
2387 } // namespace
2388 
/// Registers canonicalization patterns: fold gathers with an all-false
/// compile-time constant mask.
void GatherOp::getCanonicalizationPatterns(OwningRewritePatternList &results,
                                           MLIRContext *context) {
  results.insert<GatherFolder>(context);
}
2393 
2394 //===----------------------------------------------------------------------===//
2395 // ScatterOp
2396 //===----------------------------------------------------------------------===//
2397 
verify(ScatterOp op)2398 static LogicalResult verify(ScatterOp op) {
2399   VectorType indicesVType = op.getIndicesVectorType();
2400   VectorType maskVType = op.getMaskVectorType();
2401   VectorType valueVType = op.getValueVectorType();
2402 
2403   if (valueVType.getElementType() != op.getMemRefType().getElementType())
2404     return op.emitOpError("base and value element type should match");
2405 
2406   if (valueVType.getDimSize(0) != indicesVType.getDimSize(0))
2407     return op.emitOpError("expected value dim to match indices dim");
2408   if (valueVType.getDimSize(0) != maskVType.getDimSize(0))
2409     return op.emitOpError("expected value dim to match mask dim");
2410   return success();
2411 }
2412 
2413 namespace {
2414 class ScatterFolder final : public OpRewritePattern<ScatterOp> {
2415 public:
2416   using OpRewritePattern<ScatterOp>::OpRewritePattern;
matchAndRewrite(ScatterOp scatter,PatternRewriter & rewriter) const2417   LogicalResult matchAndRewrite(ScatterOp scatter,
2418                                 PatternRewriter &rewriter) const override {
2419     switch (get1DMaskFormat(scatter.mask())) {
2420     case MaskFormat::AllTrue:
2421       return failure(); // no unmasked equivalent
2422     case MaskFormat::AllFalse:
2423       rewriter.eraseOp(scatter);
2424       return success();
2425     case MaskFormat::Unknown:
2426       return failure();
2427     }
2428     llvm_unreachable("Unexpected 1DMaskFormat on ScatterFolder");
2429   }
2430 };
2431 } // namespace
2432 
/// Registers canonicalization patterns: fold scatters with an all-false
/// compile-time constant mask.
void ScatterOp::getCanonicalizationPatterns(OwningRewritePatternList &results,
                                            MLIRContext *context) {
  results.insert<ScatterFolder>(context);
}
2437 
2438 //===----------------------------------------------------------------------===//
2439 // ExpandLoadOp
2440 //===----------------------------------------------------------------------===//
2441 
verify(ExpandLoadOp op)2442 static LogicalResult verify(ExpandLoadOp op) {
2443   VectorType maskVType = op.getMaskVectorType();
2444   VectorType passVType = op.getPassThruVectorType();
2445   VectorType resVType = op.getResultVectorType();
2446 
2447   if (resVType.getElementType() != op.getMemRefType().getElementType())
2448     return op.emitOpError("base and result element type should match");
2449 
2450   if (resVType.getDimSize(0) != maskVType.getDimSize(0))
2451     return op.emitOpError("expected result dim to match mask dim");
2452   if (resVType != passVType)
2453     return op.emitOpError("expected pass_thru of same type as result type");
2454   return success();
2455 }
2456 
2457 namespace {
2458 class ExpandLoadFolder final : public OpRewritePattern<ExpandLoadOp> {
2459 public:
2460   using OpRewritePattern<ExpandLoadOp>::OpRewritePattern;
matchAndRewrite(ExpandLoadOp expand,PatternRewriter & rewriter) const2461   LogicalResult matchAndRewrite(ExpandLoadOp expand,
2462                                 PatternRewriter &rewriter) const override {
2463     Value newBase;
2464     switch (get1DMaskFormat(expand.mask())) {
2465     case MaskFormat::AllTrue:
2466       if (!castedToMemRef(expand.getLoc(), expand.base(),
2467                           expand.getMemRefType(), expand.getResultVectorType(),
2468                           rewriter, newBase))
2469         return failure();
2470       rewriter.replaceOpWithNewOp<LoadOp>(expand, newBase);
2471       return success();
2472     case MaskFormat::AllFalse:
2473       rewriter.replaceOp(expand, expand.pass_thru());
2474       return success();
2475     case MaskFormat::Unknown:
2476       return failure();
2477     }
2478     llvm_unreachable("Unexpected 1DMaskFormat on ExpandLoadFolder");
2479   }
2480 };
2481 } // namespace
2482 
/// Registers canonicalization patterns: fold expanding loads whose mask is a
/// compile-time constant (all-true or all-false).
void ExpandLoadOp::getCanonicalizationPatterns(
    OwningRewritePatternList &results, MLIRContext *context) {
  results.insert<ExpandLoadFolder>(context);
}
2487 
2488 //===----------------------------------------------------------------------===//
2489 // CompressStoreOp
2490 //===----------------------------------------------------------------------===//
2491 
verify(CompressStoreOp op)2492 static LogicalResult verify(CompressStoreOp op) {
2493   VectorType maskVType = op.getMaskVectorType();
2494   VectorType valueVType = op.getValueVectorType();
2495 
2496   if (valueVType.getElementType() != op.getMemRefType().getElementType())
2497     return op.emitOpError("base and value element type should match");
2498 
2499   if (valueVType.getDimSize(0) != maskVType.getDimSize(0))
2500     return op.emitOpError("expected value dim to match mask dim");
2501   return success();
2502 }
2503 
2504 namespace {
2505 class CompressStoreFolder final : public OpRewritePattern<CompressStoreOp> {
2506 public:
2507   using OpRewritePattern<CompressStoreOp>::OpRewritePattern;
matchAndRewrite(CompressStoreOp compress,PatternRewriter & rewriter) const2508   LogicalResult matchAndRewrite(CompressStoreOp compress,
2509                                 PatternRewriter &rewriter) const override {
2510     Value newBase;
2511     switch (get1DMaskFormat(compress.mask())) {
2512     case MaskFormat::AllTrue:
2513       if (!castedToMemRef(compress.getLoc(), compress.base(),
2514                           compress.getMemRefType(),
2515                           compress.getValueVectorType(), rewriter, newBase))
2516         return failure();
2517       rewriter.replaceOpWithNewOp<StoreOp>(compress, compress.value(), newBase);
2518       return success();
2519     case MaskFormat::AllFalse:
2520       rewriter.eraseOp(compress);
2521       return success();
2522     case MaskFormat::Unknown:
2523       return failure();
2524     }
2525     llvm_unreachable("Unexpected 1DMaskFormat on CompressStoreFolder");
2526   }
2527 };
2528 } // namespace
2529 
/// Registers canonicalization patterns: fold compressing stores whose mask is
/// a compile-time constant (all-true or all-false).
void CompressStoreOp::getCanonicalizationPatterns(
    OwningRewritePatternList &results, MLIRContext *context) {
  results.insert<CompressStoreFolder>(context);
}
2534 
2535 //===----------------------------------------------------------------------===//
2536 // ShapeCastOp
2537 //===----------------------------------------------------------------------===//
2538 
/// Returns true if each element of 'a' is equal to the product of a contiguous
/// sequence of the elements of 'b'. Returns false otherwise.
static bool isValidShapeCast(ArrayRef<int64_t> a, ArrayRef<int64_t> b) {
  unsigned rankA = a.size();
  unsigned rankB = b.size();
  // Caller guarantees 'a' is the lower-rank shape.
  assert(rankA < rankB);

  // Greedily match each dim of 'a' against a contiguous run of dims of 'b'
  // whose product equals it.
  unsigned i = 0;
  unsigned j = 0;
  while (i < rankA && j < rankB) {
    int64_t dimA = a[i];
    int64_t dimB = 1;
    // Accumulate dims of 'b' until the running product reaches (or
    // overshoots) dimA.
    while (dimB < dimA && j < rankB)
      dimB *= b[j++];
    if (dimA != dimB)
      break;
    ++i;

    // Handle the case when trailing dimensions are of size 1.
    // Include them into the contiguous sequence.
    auto isOne = [](int64_t v) { return v == 1; };
    if (i < rankA && llvm::all_of(a.slice(i), isOne))
      i = rankA;
    if (j < rankB && llvm::all_of(b.slice(j), isOne))
      j = rankB;
  }

  // Valid only if both shapes were fully consumed.
  return i == rankA && j == rankB;
}
2568 
verifyVectorShapeCast(Operation * op,VectorType sourceVectorType,VectorType resultVectorType)2569 static LogicalResult verifyVectorShapeCast(Operation *op,
2570                                            VectorType sourceVectorType,
2571                                            VectorType resultVectorType) {
2572   // Check that element type is the same.
2573   if (sourceVectorType.getElementType() != resultVectorType.getElementType())
2574     return op->emitOpError("source/result vectors must have same element type");
2575   auto sourceShape = sourceVectorType.getShape();
2576   auto resultShape = resultVectorType.getShape();
2577 
2578   // Check that product of source dim sizes matches product of result dim sizes.
2579   int64_t sourceDimProduct = std::accumulate(
2580       sourceShape.begin(), sourceShape.end(), 1LL, std::multiplies<int64_t>{});
2581   int64_t resultDimProduct = std::accumulate(
2582       resultShape.begin(), resultShape.end(), 1LL, std::multiplies<int64_t>{});
2583   if (sourceDimProduct != resultDimProduct)
2584     return op->emitOpError("source/result number of elements must match");
2585 
2586   // Check that expanding/contracting rank cases.
2587   unsigned sourceRank = sourceVectorType.getRank();
2588   unsigned resultRank = resultVectorType.getRank();
2589   if (sourceRank < resultRank) {
2590     if (!isValidShapeCast(sourceShape, resultShape))
2591       return op->emitOpError("invalid shape cast");
2592   } else if (sourceRank > resultRank) {
2593     if (!isValidShapeCast(resultShape, sourceShape))
2594       return op->emitOpError("invalid shape cast");
2595   }
2596   return success();
2597 }
2598 
verify(ShapeCastOp op)2599 static LogicalResult verify(ShapeCastOp op) {
2600   auto sourceVectorType = op.source().getType().dyn_cast_or_null<VectorType>();
2601   auto resultVectorType = op.result().getType().dyn_cast_or_null<VectorType>();
2602 
2603   // Check if source/result are of vector type.
2604   if (sourceVectorType && resultVectorType)
2605     return verifyVectorShapeCast(op, sourceVectorType, resultVectorType);
2606 
2607   // Check if source/result are "tuple of vectors" type.
2608   auto sourceTupleType = op.source().getType().dyn_cast_or_null<TupleType>();
2609   auto resultTupleType = op.result().getType().dyn_cast_or_null<TupleType>();
2610   if (!sourceTupleType || !resultTupleType)
2611     return op.emitOpError("source/result must be of same type");
2612 
2613   // Check that source/result tuple sizes are the same.
2614   if (sourceTupleType.size() != resultTupleType.size())
2615     return op.emitOpError("source/result tuples must be the same size");
2616 
2617   // Check each source/result tuple element pair.
2618   for (unsigned i = 0, e = sourceTupleType.size(); i < e; ++i)
2619     if (failed(verifyVectorShapeCast(
2620             op, sourceTupleType.getType(i).cast<VectorType>(),
2621             resultTupleType.getType(i).cast<VectorType>())))
2622       return failure();
2623 
2624   return success();
2625 }
2626 
fold(ArrayRef<Attribute> operands)2627 OpFoldResult ShapeCastOp::fold(ArrayRef<Attribute> operands) {
2628   // Nop shape cast.
2629   if (source().getType() == result().getType())
2630     return source();
2631 
2632   // Canceling shape casts.
2633   if (auto otherOp = source().getDefiningOp<ShapeCastOp>())
2634     if (result().getType() == otherOp.source().getType())
2635       return otherOp.source();
2636 
2637   return {};
2638 }
2639 
2640 namespace {
2641 // Pattern to rewrite a ShapeCast(splat ConstantOp) -> ConstantOp.
2642 class ShapeCastConstantFolder final : public OpRewritePattern<ShapeCastOp> {
2643 public:
2644   using OpRewritePattern<ShapeCastOp>::OpRewritePattern;
2645 
matchAndRewrite(ShapeCastOp shapeCastOp,PatternRewriter & rewriter) const2646   LogicalResult matchAndRewrite(ShapeCastOp shapeCastOp,
2647                                 PatternRewriter &rewriter) const override {
2648     auto constantOp = shapeCastOp.source().getDefiningOp<ConstantOp>();
2649     if (!constantOp)
2650       return failure();
2651     // Only handle splat for now.
2652     auto dense = constantOp.value().dyn_cast<SplatElementsAttr>();
2653     if (!dense)
2654       return failure();
2655     auto newAttr = DenseElementsAttr::get(
2656         shapeCastOp.getType().cast<VectorType>(), dense.getSplatValue());
2657     rewriter.replaceOpWithNewOp<ConstantOp>(shapeCastOp, newAttr);
2658     return success();
2659   }
2660 };
2661 
2662 } // namespace
2663 
/// Registers the ShapeCastOp canonicalization patterns defined above.
void ShapeCastOp::getCanonicalizationPatterns(OwningRewritePatternList &results,
                                              MLIRContext *context) {
  // Pattern to rewrite a ShapeCastOp(ConstantOp) -> ConstantOp.
  results.insert<ShapeCastConstantFolder>(context);
}
2669 
2670 //===----------------------------------------------------------------------===//
2671 // VectorBitCastOp
2672 //===----------------------------------------------------------------------===//
2673 
verify(BitCastOp op)2674 static LogicalResult verify(BitCastOp op) {
2675   auto sourceVectorType = op.getSourceVectorType();
2676   auto resultVectorType = op.getResultVectorType();
2677 
2678   for (int64_t i = 0, e = sourceVectorType.getRank() - 1; i < e; i++) {
2679     if (sourceVectorType.getDimSize(i) != resultVectorType.getDimSize(i))
2680       return op.emitOpError("dimension size mismatch at: ") << i;
2681   }
2682 
2683   if (sourceVectorType.getElementTypeBitWidth() *
2684           sourceVectorType.getShape().back() !=
2685       resultVectorType.getElementTypeBitWidth() *
2686           resultVectorType.getShape().back())
2687     return op.emitOpError(
2688         "source/result bitwidth of the minor 1-D vectors must be equal");
2689 
2690   return success();
2691 }
2692 
fold(ArrayRef<Attribute> operands)2693 OpFoldResult BitCastOp::fold(ArrayRef<Attribute> operands) {
2694   // Nop cast.
2695   if (source().getType() == result().getType())
2696     return source();
2697 
2698   // Canceling bitcasts.
2699   if (auto otherOp = source().getDefiningOp<BitCastOp>())
2700     if (result().getType() == otherOp.source().getType())
2701       return otherOp.source();
2702 
2703   return {};
2704 }
2705 
2706 //===----------------------------------------------------------------------===//
2707 // TypeCastOp
2708 //===----------------------------------------------------------------------===//
2709 
extractShape(MemRefType memRefType)2710 static SmallVector<int64_t, 8> extractShape(MemRefType memRefType) {
2711   auto vectorType = memRefType.getElementType().dyn_cast<VectorType>();
2712   SmallVector<int64_t, 8> res(memRefType.getShape().begin(),
2713                               memRefType.getShape().end());
2714   if (vectorType)
2715     res.append(vectorType.getShape().begin(), vectorType.getShape().end());
2716   return res;
2717 }
2718 
2719 /// Build the canonical memRefType with a single vector.
2720 /// E.g. memref<4 x 5 x vector<6 x f32>> -> memref<vector<4 x 5 x 6 x f32>>.
build(OpBuilder & builder,OperationState & result,Value source)2721 void TypeCastOp::build(OpBuilder &builder, OperationState &result,
2722                        Value source) {
2723   result.addOperands(source);
2724   MemRefType memRefType = source.getType().cast<MemRefType>();
2725   VectorType vectorType =
2726       VectorType::get(extractShape(memRefType),
2727                       getElementTypeOrSelf(getElementTypeOrSelf(memRefType)));
2728   result.addTypes(
2729       MemRefType::get({}, vectorType, {}, memRefType.getMemorySpace()));
2730 }
2731 
verify(TypeCastOp op)2732 static LogicalResult verify(TypeCastOp op) {
2733   MemRefType canonicalType = canonicalizeStridedLayout(op.getMemRefType());
2734   if (!canonicalType.getAffineMaps().empty())
2735     return op.emitOpError("expects operand to be a memref with no layout");
2736   if (!op.getResultMemRefType().getAffineMaps().empty())
2737     return op.emitOpError("expects result to be a memref with no layout");
2738   if (op.getResultMemRefType().getMemorySpace() !=
2739       op.getMemRefType().getMemorySpace())
2740     return op.emitOpError("expects result in same memory space");
2741 
2742   auto sourceType = op.getMemRefType();
2743   auto resultType = op.getResultMemRefType();
2744   if (getElementTypeOrSelf(getElementTypeOrSelf(sourceType)) !=
2745       getElementTypeOrSelf(getElementTypeOrSelf(resultType)))
2746     return op.emitOpError(
2747                "expects result and operand with same underlying scalar type: ")
2748            << resultType;
2749   if (extractShape(sourceType) != extractShape(resultType))
2750     return op.emitOpError(
2751                "expects concatenated result and operand shapes to be equal: ")
2752            << resultType;
2753   return success();
2754 }
2755 
2756 //===----------------------------------------------------------------------===//
2757 // TupleOp
2758 //===----------------------------------------------------------------------===//
2759 
parseTupleOp(OpAsmParser & parser,OperationState & result)2760 static ParseResult parseTupleOp(OpAsmParser &parser, OperationState &result) {
2761   SmallVector<OpAsmParser::OperandType, 4> operandInfos;
2762   SmallVector<Type, 4> types;
2763   auto loc = parser.getCurrentLocation();
2764   auto *ctx = parser.getBuilder().getContext();
2765   return failure(
2766       parser.parseOperandList(operandInfos) ||
2767       parser.parseOptionalAttrDict(result.attributes) ||
2768       parser.parseColonTypeList(types) ||
2769       parser.resolveOperands(operandInfos, types, loc, result.operands) ||
2770       parser.addTypeToList(TupleType::get(types, ctx), result.types));
2771 }
2772 
/// Prints a TupleOp in the form:
///   vector.tuple %a, %b, ... attr-dict : type(%a), type(%b), ...
static void print(OpAsmPrinter &p, TupleOp op) {
  p << op.getOperationName() << ' ';
  p.printOperands(op.getOperands());
  p.printOptionalAttrDict(op.getAttrs());
  p << " : ";
  // The result tuple type is implied by the comma-separated operand types.
  llvm::interleaveComma(op->getOperandTypes(), p);
}
2780 
/// TupleOp has no constraints to check beyond what ODS already enforces.
static LogicalResult verify(TupleOp op) { return success(); }
2782 
2783 //===----------------------------------------------------------------------===//
2784 // TransposeOp
2785 //===----------------------------------------------------------------------===//
2786 
build(OpBuilder & builder,OperationState & result,Value vector,ArrayRef<int64_t> transp)2787 void vector::TransposeOp::build(OpBuilder &builder, OperationState &result,
2788                                 Value vector, ArrayRef<int64_t> transp) {
2789   VectorType vt = vector.getType().cast<VectorType>();
2790   SmallVector<int64_t, 4> transposedShape(vt.getRank());
2791   for (unsigned i = 0; i < transp.size(); ++i)
2792     transposedShape[i] = vt.getShape()[transp[i]];
2793 
2794   result.addOperands(vector);
2795   result.addTypes(VectorType::get(transposedShape, vt.getElementType()));
2796   result.addAttribute(getTranspAttrName(), builder.getI64ArrayAttr(transp));
2797 }
2798 
2799 // Eliminates transpose operations, which produce values identical to their
2800 // input values. This happens when the dimensions of the input vector remain in
2801 // their original order after the transpose operation.
fold(ArrayRef<Attribute> operands)2802 OpFoldResult vector::TransposeOp::fold(ArrayRef<Attribute> operands) {
2803   SmallVector<int64_t, 4> transp;
2804   getTransp(transp);
2805 
2806   // Check if the permutation of the dimensions contains sequential values:
2807   // {0, 1, 2, ...}.
2808   for (int64_t i = 0, e = transp.size(); i < e; i++) {
2809     if (transp[i] != i)
2810       return {};
2811   }
2812 
2813   return vector();
2814 }
2815 
verify(vector::TransposeOp op)2816 static LogicalResult verify(vector::TransposeOp op) {
2817   VectorType vectorType = op.getVectorType();
2818   VectorType resultType = op.getResultType();
2819   int64_t rank = resultType.getRank();
2820   if (vectorType.getRank() != rank)
2821     return op.emitOpError("vector result rank mismatch: ") << rank;
2822   // Verify transposition array.
2823   auto transpAttr = op.transp().getValue();
2824   int64_t size = transpAttr.size();
2825   if (rank != size)
2826     return op.emitOpError("transposition length mismatch: ") << size;
2827   SmallVector<bool, 8> seen(rank, false);
2828   for (auto ta : llvm::enumerate(transpAttr)) {
2829     int64_t i = ta.value().cast<IntegerAttr>().getInt();
2830     if (i < 0 || i >= rank)
2831       return op.emitOpError("transposition index out of range: ") << i;
2832     if (seen[i])
2833       return op.emitOpError("duplicate position index: ") << i;
2834     seen[i] = true;
2835     if (resultType.getDimSize(ta.index()) != vectorType.getDimSize(i))
2836       return op.emitOpError("dimension size mismatch at: ") << i;
2837   }
2838   return success();
2839 }
2840 
2841 namespace {
2842 
2843 // Rewrites two back-to-back TransposeOp operations into a single TransposeOp.
2844 class TransposeFolder final : public OpRewritePattern<vector::TransposeOp> {
2845 public:
2846   using OpRewritePattern<vector::TransposeOp>::OpRewritePattern;
2847 
matchAndRewrite(vector::TransposeOp transposeOp,PatternRewriter & rewriter) const2848   LogicalResult matchAndRewrite(vector::TransposeOp transposeOp,
2849                                 PatternRewriter &rewriter) const override {
2850     // Wrapper around vector::TransposeOp::getTransp() for cleaner code.
2851     auto getPermutation = [](vector::TransposeOp transpose) {
2852       SmallVector<int64_t, 4> permutation;
2853       transpose.getTransp(permutation);
2854       return permutation;
2855     };
2856 
2857     // Composes two permutations: result[i] = permutation1[permutation2[i]].
2858     auto composePermutations = [](ArrayRef<int64_t> permutation1,
2859                                   ArrayRef<int64_t> permutation2) {
2860       SmallVector<int64_t, 4> result;
2861       for (auto index : permutation2)
2862         result.push_back(permutation1[index]);
2863       return result;
2864     };
2865 
2866     // Return if the input of 'transposeOp' is not defined by another transpose.
2867     vector::TransposeOp parentTransposeOp =
2868         transposeOp.vector().getDefiningOp<vector::TransposeOp>();
2869     if (!parentTransposeOp)
2870       return failure();
2871 
2872     SmallVector<int64_t, 4> permutation = composePermutations(
2873         getPermutation(parentTransposeOp), getPermutation(transposeOp));
2874     // Replace 'transposeOp' with a new transpose operation.
2875     rewriter.replaceOpWithNewOp<vector::TransposeOp>(
2876         transposeOp, transposeOp.getResult().getType(),
2877         parentTransposeOp.vector(),
2878         vector::getVectorSubscriptAttr(rewriter, permutation));
2879     return success();
2880   }
2881 };
2882 
2883 } // end anonymous namespace
2884 
/// Registers the transpose(transpose(x)) folding pattern.
void vector::TransposeOp::getCanonicalizationPatterns(
    OwningRewritePatternList &results, MLIRContext *context) {
  results.insert<TransposeFolder>(context);
}
2889 
/// Copies the `transp` I64ArrayAttr into `results` as plain integers.
void vector::TransposeOp::getTransp(SmallVectorImpl<int64_t> &results) {
  populateFromInt64AttrArray(transp(), results);
}
2893 
2894 //===----------------------------------------------------------------------===//
2895 // TupleGetOp
2896 //===----------------------------------------------------------------------===//
2897 
parseTupleGetOp(OpAsmParser & parser,OperationState & result)2898 static ParseResult parseTupleGetOp(OpAsmParser &parser,
2899                                    OperationState &result) {
2900   OpAsmParser::OperandType operandInfo;
2901   IntegerAttr indexAttr;
2902   StringRef indexAttrName = TupleGetOp::getIndexAttrName();
2903   Type indexType = parser.getBuilder().getIndexType();
2904   TupleType tupleType;
2905   if (parser.parseOperand(operandInfo) || parser.parseComma() ||
2906       parser.parseAttribute(indexAttr, indexType, indexAttrName,
2907                             result.attributes) ||
2908       parser.parseOptionalAttrDict(result.attributes) ||
2909       parser.parseColonType(tupleType) ||
2910       parser.resolveOperand(operandInfo, tupleType, result.operands))
2911     return failure();
2912   if (indexAttr.getInt() < 0 ||
2913       indexAttr.getInt() >= static_cast<int64_t>(tupleType.size()))
2914     return failure();
2915   parser.addTypeToList(tupleType.getType(indexAttr.getInt()), result.types);
2916   return success();
2917 }
2918 
/// Prints a TupleGetOp in the form:
///   vector.tuple_get %tuple, <index> attr-dict : tuple-type
static void print(OpAsmPrinter &p, TupleGetOp op) {
  p << op.getOperationName() << ' ' << op.getOperand() << ", " << op.index();
  // The index is printed inline above, so elide it from the attribute dict.
  p.printOptionalAttrDict(op.getAttrs(),
                          /*elidedAttrs=*/{TupleGetOp::getIndexAttrName()});
  p << " : " << op.getOperand().getType();
}
2925 
verify(TupleGetOp op)2926 static LogicalResult verify(TupleGetOp op) {
2927   auto tupleType = op.getOperand().getType().cast<TupleType>();
2928   if (op.getIndex() < 0 ||
2929       op.getIndex() >= static_cast<int64_t>(tupleType.size()))
2930     return op.emitOpError("tuple get index out of range");
2931   return success();
2932 }
2933 
fold(ArrayRef<Attribute> operands)2934 OpFoldResult TupleGetOp::fold(ArrayRef<Attribute> operands) {
2935   // Rewrite:
2936   //    %t = vector.tuple .., %e_i, ..
2937   //    %x = vector.tuple_get %t, i
2938   // into:
2939   //    %t = vector.tuple .., %e_i, ..  // one less use
2940   //    %x = %e_i
2941   if (auto tupleOp = getOperand().getDefiningOp<TupleOp>())
2942     return tupleOp.getOperand(getIndex());
2943   return {};
2944 }
2945 
2946 //===----------------------------------------------------------------------===//
2947 // ConstantMaskOp
2948 //===----------------------------------------------------------------------===//
2949 
verify(ConstantMaskOp & op)2950 static LogicalResult verify(ConstantMaskOp &op) {
2951   // Verify that array attr size matches the rank of the vector result.
2952   auto resultType = op.getResult().getType().cast<VectorType>();
2953   if (static_cast<int64_t>(op.mask_dim_sizes().size()) != resultType.getRank())
2954     return op.emitOpError(
2955         "must specify array attr of size equal vector result rank");
2956   // Verify that each array attr element is in bounds of corresponding vector
2957   // result dimension size.
2958   auto resultShape = resultType.getShape();
2959   SmallVector<int64_t, 4> maskDimSizes;
2960   for (auto it : llvm::enumerate(op.mask_dim_sizes())) {
2961     int64_t attrValue = it.value().cast<IntegerAttr>().getInt();
2962     if (attrValue < 0 || attrValue > resultShape[it.index()])
2963       return op.emitOpError(
2964           "array attr of size out of bounds of vector result dimension size");
2965     maskDimSizes.push_back(attrValue);
2966   }
2967   // Verify that if one mask dim size is zero, they all should be zero (because
2968   // the mask region is a conjunction of each mask dimension interval).
2969   bool any_zeros = llvm::is_contained(maskDimSizes, 0);
2970   bool all_zeros = llvm::all_of(maskDimSizes, [](int64_t s) { return s == 0; });
2971   if (any_zeros && !all_zeros)
2972     return op.emitOpError("expected all mask dim sizes to be zeros, "
2973                           "as a result of conjunction with zero mask dim");
2974   return success();
2975 }
2976 
2977 //===----------------------------------------------------------------------===//
2978 // CreateMaskOp
2979 //===----------------------------------------------------------------------===//
2980 
verify(CreateMaskOp op)2981 static LogicalResult verify(CreateMaskOp op) {
2982   // Verify that an operand was specified for each result vector each dimension.
2983   if (op.getNumOperands() !=
2984       op.getResult().getType().cast<VectorType>().getRank())
2985     return op.emitOpError(
2986         "must specify an operand for each result vector dimension");
2987   return success();
2988 }
2989 
2990 namespace {
2991 
2992 // Pattern to rewrite a CreateMaskOp with a ConstantMaskOp.
2993 class CreateMaskFolder final : public OpRewritePattern<CreateMaskOp> {
2994 public:
2995   using OpRewritePattern<CreateMaskOp>::OpRewritePattern;
2996 
matchAndRewrite(CreateMaskOp createMaskOp,PatternRewriter & rewriter) const2997   LogicalResult matchAndRewrite(CreateMaskOp createMaskOp,
2998                                 PatternRewriter &rewriter) const override {
2999     // Return if any of 'createMaskOp' operands are not defined by a constant.
3000     auto is_not_def_by_constant = [](Value operand) {
3001       return !isa_and_nonnull<ConstantIndexOp>(operand.getDefiningOp());
3002     };
3003     if (llvm::any_of(createMaskOp.operands(), is_not_def_by_constant))
3004       return failure();
3005     // Gather constant mask dimension sizes.
3006     SmallVector<int64_t, 4> maskDimSizes;
3007     for (auto operand : createMaskOp.operands()) {
3008       auto defOp = operand.getDefiningOp();
3009       maskDimSizes.push_back(cast<ConstantIndexOp>(defOp).getValue());
3010     }
3011     // Replace 'createMaskOp' with ConstantMaskOp.
3012     rewriter.replaceOpWithNewOp<ConstantMaskOp>(
3013         createMaskOp, createMaskOp.getResult().getType(),
3014         vector::getVectorSubscriptAttr(rewriter, maskDimSizes));
3015     return success();
3016   }
3017 };
3018 
3019 } // end anonymous namespace
3020 
/// Registers the pattern that folds CreateMaskOp with constant operands into
/// ConstantMaskOp.
void CreateMaskOp::getCanonicalizationPatterns(
    OwningRewritePatternList &results, MLIRContext *context) {
  results.insert<CreateMaskFolder>(context);
}
3025 
/// Collects the vector-to-vector canonicalization patterns defined throughout
/// this file into `patterns`.
void mlir::vector::populateVectorToVectorCanonicalizationPatterns(
    OwningRewritePatternList &patterns, MLIRContext *context) {
  patterns.insert<CreateMaskFolder, MaskedLoadFolder, MaskedStoreFolder,
                  GatherFolder, ScatterFolder, ExpandLoadFolder,
                  CompressStoreFolder, StridedSliceConstantMaskFolder,
                  TransposeFolder>(context);
}
3033 
3034 #define GET_OP_CLASSES
3035 #include "mlir/Dialect/Vector/VectorOps.cpp.inc"
3036