1 //===- VectorOps.cpp - MLIR Vector Dialect Operations ---------------------===//
2 //
3 // Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions.
4 // See https://llvm.org/LICENSE.txt for license information.
5 // SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception
6 //
7 //===----------------------------------------------------------------------===//
8 //
9 // This file implements convenience types for working with super-vectorization
10 // operations, in particular super-vector loads and stores.
11 //
12 //===----------------------------------------------------------------------===//
13
14 #include "mlir/Dialect/Vector/VectorOps.h"
15 #include "mlir/Dialect/StandardOps/IR/Ops.h"
16 #include "mlir/Dialect/Utils/StructuredOpsUtils.h"
17 #include "mlir/Dialect/Vector/VectorUtils.h"
18 #include "mlir/IR/AffineExpr.h"
19 #include "mlir/IR/AffineMap.h"
20 #include "mlir/IR/Builders.h"
21 #include "mlir/IR/BuiltinOps.h"
22 #include "mlir/IR/OpImplementation.h"
23 #include "mlir/IR/PatternMatch.h"
24 #include "mlir/IR/TypeUtilities.h"
25 #include "mlir/Support/LLVM.h"
26 #include "mlir/Support/MathExtras.h"
27 #include "llvm/ADT/StringSet.h"
28 #include <numeric>
29
30 using namespace mlir;
31 using namespace mlir::vector;
32
/// Helper enum to classify mask value.
enum class MaskFormat {
  AllTrue = 0,  // all bits of the mask are statically known to be set
  AllFalse = 1, // all bits of the mask are statically known to be cleared
  Unknown = 2,  // mixed bits, or not statically analyzable
};
39
40 /// Helper method to classify a 1-D mask value. Currently, the method
41 /// looks "under the hood" of a constant value with dense attributes
42 /// and a constant mask operation (since the client may be called at
43 /// various stages during progressive lowering).
get1DMaskFormat(Value mask)44 static MaskFormat get1DMaskFormat(Value mask) {
45 if (auto c = mask.getDefiningOp<ConstantOp>()) {
46 // Inspect constant dense values. We count up for bits that
47 // are set, count down for bits that are cleared, and bail
48 // when a mix is detected.
49 if (auto denseElts = c.value().dyn_cast<DenseIntElementsAttr>()) {
50 int64_t val = 0;
51 for (bool b : denseElts.getValues<bool>())
52 if (b && val >= 0)
53 val++;
54 else if (!b && val <= 0)
55 val--;
56 else
57 return MaskFormat::Unknown;
58 if (val > 0)
59 return MaskFormat::AllTrue;
60 if (val < 0)
61 return MaskFormat::AllFalse;
62 }
63 } else if (auto m = mask.getDefiningOp<ConstantMaskOp>()) {
64 // Inspect constant mask index. If the index exceeds the
65 // dimension size, all bits are set. If the index is zero
66 // or less, no bits are set.
67 ArrayAttr masks = m.mask_dim_sizes();
68 assert(masks.size() == 1);
69 int64_t i = masks[0].cast<IntegerAttr>().getInt();
70 int64_t u = m.getType().cast<VectorType>().getDimSize(0);
71 if (i >= u)
72 return MaskFormat::AllTrue;
73 if (i <= 0)
74 return MaskFormat::AllFalse;
75 }
76 return MaskFormat::Unknown;
77 }
78
79 /// Helper method to cast a 1-D memref<10xf32> "base" into a
80 /// memref<vector<10xf32>> in the output parameter "newBase",
81 /// using the 'element' vector type "vt". Returns true on success.
castedToMemRef(Location loc,Value base,MemRefType mt,VectorType vt,PatternRewriter & rewriter,Value & newBase)82 static bool castedToMemRef(Location loc, Value base, MemRefType mt,
83 VectorType vt, PatternRewriter &rewriter,
84 Value &newBase) {
85 // The vector.type_cast operation does not accept unknown memref<?xf32>.
86 // TODO: generalize the cast and accept this case too
87 if (!mt.hasStaticShape())
88 return false;
89 newBase = rewriter.create<TypeCastOp>(loc, MemRefType::get({}, vt), base);
90 return true;
91 }
92
93 //===----------------------------------------------------------------------===//
94 // VectorDialect
95 //===----------------------------------------------------------------------===//
96
/// Registers all operations of the Vector dialect, generated from
/// VectorOps.td via the GET_OP_LIST include below.
void VectorDialect::initialize() {
  addOperations<
#define GET_OP_LIST
#include "mlir/Dialect/Vector/VectorOps.cpp.inc"
      >();
}
103
/// Materialize a single constant operation from a given attribute value with
/// the desired resultant type.
Operation *VectorDialect::materializeConstant(OpBuilder &builder,
                                              Attribute value, Type type,
                                              Location loc) {
  // All vector dialect constants fold to the standard constant op.
  return builder.create<ConstantOp>(loc, type, value);
}
111
/// Returns the integer type used for vector subscripts (positions/indices),
/// currently fixed to i64.
IntegerType vector::getVectorSubscriptType(Builder &builder) {
  return builder.getIntegerType(64);
}
115
/// Wraps a list of static subscripts in the i64 ArrayAttr form that vector
/// operations carrying positions expect.
ArrayAttr vector::getVectorSubscriptAttr(Builder &builder,
                                         ArrayRef<int64_t> values) {
  return builder.getI64ArrayAttr(values);
}
120
121 //===----------------------------------------------------------------------===//
122 // ReductionOp
123 //===----------------------------------------------------------------------===//
124
verify(ReductionOp op)125 static LogicalResult verify(ReductionOp op) {
126 // Verify for 1-D vector.
127 int64_t rank = op.getVectorType().getRank();
128 if (rank != 1)
129 return op.emitOpError("unsupported reduction rank: ") << rank;
130
131 // Verify supported reduction kind.
132 auto kind = op.kind();
133 Type eltType = op.dest().getType();
134 if (kind == "add" || kind == "mul" || kind == "min" || kind == "max") {
135 if (!eltType.isIntOrIndexOrFloat())
136 return op.emitOpError("unsupported reduction type");
137 } else if (kind == "and" || kind == "or" || kind == "xor") {
138 if (!eltType.isIntOrIndex())
139 return op.emitOpError("unsupported reduction type");
140 } else {
141 return op.emitOpError("unknown reduction kind: ") << kind;
142 }
143
144 // Verify optional accumulator.
145 if (!op.acc().empty()) {
146 if (kind != "add" && kind != "mul")
147 return op.emitOpError("no accumulator for reduction kind: ") << kind;
148 if (!eltType.isa<FloatType>())
149 return op.emitOpError("no accumulator for type: ") << eltType;
150 }
151
152 return success();
153 }
154
/// Parses a vector.reduction in the form:
///   vector.reduction "kind", %vector[, %acc] : red-type into res-type
static ParseResult parseReductionOp(OpAsmParser &parser,
                                    OperationState &result) {
  SmallVector<OpAsmParser::OperandType, 2> operandsInfo;
  Type redType;
  Type resType;
  Attribute attr;
  // Operand resolution is interleaved into the short-circuited chain so the
  // optional accumulator (second operand) is only resolved when present.
  if (parser.parseAttribute(attr, "kind", result.attributes) ||
      parser.parseComma() || parser.parseOperandList(operandsInfo) ||
      parser.parseColonType(redType) ||
      parser.parseKeywordType("into", resType) ||
      (operandsInfo.size() > 0 &&
       parser.resolveOperand(operandsInfo[0], redType, result.operands)) ||
      (operandsInfo.size() > 1 &&
       parser.resolveOperand(operandsInfo[1], resType, result.operands)) ||
      parser.addTypeToList(resType, result.types))
    return failure();
  // Exactly one vector operand, plus at most one accumulator.
  if (operandsInfo.size() < 1 || operandsInfo.size() > 2)
    return parser.emitError(parser.getNameLoc(),
                            "unsupported number of operands");
  return success();
}
176
/// Prints a vector.reduction in the form:
///   vector.reduction "kind", %vector[, %acc] : vector-type into result-type
static void print(OpAsmPrinter &p, ReductionOp op) {
  p << op.getOperationName() << " \"" << op.kind() << "\", " << op.vector();
  // The accumulator operand is optional.
  if (!op.acc().empty())
    p << ", " << op.acc();
  p << " : " << op.vector().getType() << " into " << op.dest().getType();
}
183
184 //===----------------------------------------------------------------------===//
185 // ContractionOp
186 //===----------------------------------------------------------------------===//
187
/// Builds a contraction from affine expression lists; the indexing maps are
/// inferred from 'indexingExprs' and the result type is taken from the
/// accumulator.
void vector::ContractionOp::build(OpBuilder &builder, OperationState &result,
                                  Value lhs, Value rhs, Value acc,
                                  ArrayRef<ArrayRef<AffineExpr>> indexingExprs,
                                  ArrayRef<StringRef> iteratorTypes) {
  result.addOperands({lhs, rhs, acc});
  // The contraction result always has the accumulator's type.
  result.addTypes(acc.getType());
  result.addAttribute(getIndexingMapsAttrName(),
                      builder.getAffineMapArrayAttr(
                          AffineMap::inferFromExprList(indexingExprs)));
  result.addAttribute(getIteratorTypesAttrName(),
                      builder.getStrArrayAttr(iteratorTypes));
}
200
/// Builds a contraction from pre-built attribute forms of the indexing maps
/// and iterator types; the result type is taken from the accumulator.
void vector::ContractionOp::build(OpBuilder &builder, OperationState &result,
                                  Value lhs, Value rhs, Value acc,
                                  ArrayAttr indexingMaps,
                                  ArrayAttr iteratorTypes) {
  result.addOperands({lhs, rhs, acc});
  // The contraction result always has the accumulator's type.
  result.addTypes(acc.getType());
  result.addAttribute(getIndexingMapsAttrName(), indexingMaps);
  result.addAttribute(getIteratorTypesAttrName(), iteratorTypes);
}
210
parseContractionOp(OpAsmParser & parser,OperationState & result)211 static ParseResult parseContractionOp(OpAsmParser &parser,
212 OperationState &result) {
213 OpAsmParser::OperandType lhsInfo;
214 OpAsmParser::OperandType rhsInfo;
215 OpAsmParser::OperandType accInfo;
216 SmallVector<OpAsmParser::OperandType, 2> masksInfo;
217 SmallVector<Type, 2> types;
218 Type resultType;
219 auto loc = parser.getCurrentLocation();
220 DictionaryAttr dictAttr;
221 // TODO: Unify linalg op attribute parsing.
222 if (parser.parseAttribute(dictAttr, "_", result.attributes) ||
223 parser.parseOperand(lhsInfo) || parser.parseComma() ||
224 parser.parseOperand(rhsInfo) || parser.parseComma() ||
225 parser.parseOperand(accInfo) ||
226 parser.parseTrailingOperandList(masksInfo) ||
227 parser.parseOptionalAttrDict(result.attributes) ||
228 parser.parseColonTypeList(types) ||
229 parser.parseKeywordType("into", resultType) ||
230 parser.resolveOperand(lhsInfo, types[0], result.operands) ||
231 parser.resolveOperand(rhsInfo, types[1], result.operands) ||
232 parser.resolveOperand(accInfo, resultType, result.operands) ||
233 parser.addTypeToList(resultType, result.types))
234 return failure();
235 result.attributes.assign(dictAttr.getValue().begin(),
236 dictAttr.getValue().end());
237 if (masksInfo.empty())
238 return success();
239 if (masksInfo.size() != 2)
240 return parser.emitError(parser.getNameLoc(),
241 "expected zero or exactly 2 vector mask operands");
242 auto lhsType = types[0].cast<VectorType>();
243 auto rhsType = types[1].cast<VectorType>();
244 auto maskElementType = parser.getBuilder().getI1Type();
245 std::array<Type, 2> maskTypes = {
246 VectorType::get(lhsType.getShape(), maskElementType),
247 VectorType::get(rhsType.getShape(), maskElementType)};
248 if (parser.resolveOperands(masksInfo, maskTypes, loc, result.operands))
249 return failure();
250 return success();
251 }
252
/// Prints a vector.contract op: the trait attributes (indexing maps and
/// iterator types) are printed as a leading dictionary, followed by the
/// operands, optional masks, the remaining attributes, and the types.
static void print(OpAsmPrinter &p, ContractionOp op) {
  // TODO: Unify printing code with linalg ops.
  auto attrNames = op.getTraitAttrNames();
  llvm::StringSet<> traitAttrsSet;
  traitAttrsSet.insert(attrNames.begin(), attrNames.end());
  // Collect only the trait attributes into the leading dictionary.
  SmallVector<NamedAttribute, 8> attrs;
  for (auto attr : op.getAttrs())
    if (traitAttrsSet.count(attr.first.strref()) > 0)
      attrs.push_back(attr);

  auto dictAttr = DictionaryAttr::get(attrs, op.getContext());
  p << op.getOperationName() << " " << dictAttr << " " << op.lhs() << ", ";
  p << op.rhs() << ", " << op.acc();
  // Masks are either absent or come as an lhs/rhs pair.
  if (op.masks().size() == 2)
    p << ", " << op.masks();

  // Print remaining attributes, eliding the traits printed above.
  p.printOptionalAttrDict(op.getAttrs(), attrNames);
  p << " : " << op.lhs().getType() << ", " << op.rhs().getType() << " into "
    << op.getResultType();
}
273
verifyDimMap(VectorType lhsType,VectorType rhsType,const std::vector<std::pair<int64_t,int64_t>> & map)274 static bool verifyDimMap(VectorType lhsType, VectorType rhsType,
275 const std::vector<std::pair<int64_t, int64_t>> &map) {
276 for (auto &dimPair : map) {
277 if (dimPair.first < 0 || dimPair.first >= lhsType.getRank() ||
278 dimPair.second < 0 || dimPair.second >= rhsType.getRank() ||
279 lhsType.getDimSize(dimPair.first) != rhsType.getDimSize(dimPair.second))
280 return false;
281 }
282 return true;
283 }
284
/// Verifies that the accumulator and result types of 'op' match the shape
/// implied by the lhs/rhs vector types together with the contracting and
/// batch dimension maps.
static LogicalResult verifyOutputShape(
    ContractionOp op, VectorType lhsType, VectorType rhsType, Type accType,
    Type resType,
    const std::vector<std::pair<int64_t, int64_t>> &contractingDimMap,
    const std::vector<std::pair<int64_t, int64_t>> &batchDimMap) {
  DenseSet<int64_t> lhsContractingDimSet;
  DenseSet<int64_t> rhsContractingDimSet;
  for (auto &dimPair : contractingDimMap) {
    lhsContractingDimSet.insert(dimPair.first);
    rhsContractingDimSet.insert(dimPair.second);
  }
  DenseSet<int64_t> rhsBatchDimSet;
  for (auto &dimPair : batchDimMap)
    rhsBatchDimSet.insert(dimPair.second);

  // Add free and batch dimensions from 'lhsType' to 'expectedResultDims'.
  SmallVector<int64_t, 4> expectedResultDims;
  for (int64_t i = 0, e = lhsType.getRank(); i < e; ++i) {
    if (lhsContractingDimSet.count(i) > 0)
      continue;
    expectedResultDims.push_back(lhsType.getDimSize(i));
  }

  // Add free dimensions from 'rhsType' to 'expectedResultDims'.
  // Rhs batch dimensions are skipped since the lhs already contributed them.
  for (int64_t i = 0, e = rhsType.getRank(); i < e; ++i) {
    if (rhsContractingDimSet.count(i) > 0 || rhsBatchDimSet.count(i) > 0)
      continue;
    expectedResultDims.push_back(rhsType.getDimSize(i));
  }

  // Verify 'expectedResultDims'.
  if (expectedResultDims.size() == 0) {
    // No batch or free dimension implies a scalar result.
    if (resType.isa<VectorType>() || accType.isa<VectorType>())
      return op.emitOpError("invalid accumulator/result vector shape");
  } else {
    // At least one batch or free dimension implies a vector result.
    auto resVectorType = resType.dyn_cast<VectorType>();
    auto accVectorType = accType.dyn_cast<VectorType>();
    if (!resVectorType || !accVectorType)
      return op.emitOpError("invalid accumulator/result vector shape");

    // Infer expected result vector type. Lhs + rhs map and lhs + rhs vector
    // types fully define the result vector type. This assumes the affine maps
    // are well-formed, which must have been verified already.
    MLIRContext *ctx = op.getContext();
    AffineMap lhsMap = op.getIndexingMaps()[0];
    AffineMap rhsMap = op.getIndexingMaps()[1];
    // Record, for each iteration-space dimension, its constant extent as
    // observed through the lhs/rhs shapes (first writer wins).
    SmallVector<AffineExpr, 4> extents(lhsMap.getNumInputs());
    for (auto pair :
         {std::make_pair(lhsType, lhsMap), std::make_pair(rhsType, rhsMap)}) {
      VectorType v = pair.first;
      auto map = pair.second;
      for (unsigned idx = 0, e = v.getRank(); idx < e; ++idx) {
        unsigned pos = map.getDimPosition(idx);
        if (!extents[pos])
          extents[pos] = getAffineConstantExpr(v.getShape()[idx], ctx);
      }
    }
    assert(llvm::all_of(extents, [](AffineExpr e) { return e; }) &&
           "expected extent along all dimensions.");

    AffineMap resMap = op.getIndexingMaps()[2];
    auto extentsMap = AffineMap::get(/*dimCount=*/extents.size(),
                                     /*symCount=*/0, extents, ctx);
    // Compose the resMap with the extentsMap, which is a constant map.
    AffineMap expectedMap = simplifyAffineMap(resMap.compose(extentsMap));
    assert(llvm::all_of(
               expectedMap.getResults(),
               [](AffineExpr e) { return e.isa<AffineConstantExpr>(); }) &&
           "expected constant extent along all dimensions.");
    // Extract the expected shape and build the type.
    auto expectedShape = llvm::to_vector<4>(
        llvm::map_range(expectedMap.getResults(), [](AffineExpr e) {
          return e.cast<AffineConstantExpr>().getValue();
        }));
    auto expected =
        VectorType::get(expectedShape, resVectorType.getElementType());
    // Both the accumulator and the result must have the inferred type.
    if (resVectorType != expected || accVectorType != expected)
      return op.emitOpError(
                 "invalid accumulator/result vector shape, expected: ")
             << expected;
  }
  return success();
}
370
/// Verifies a vector.contract op: indexing-map structure, presence of at
/// least one contracting pair, dimension-map consistency, output shape, and
/// mask arity/rank.
static LogicalResult verify(ContractionOp op) {
  auto lhsType = op.getLhsType();
  auto rhsType = op.getRhsType();
  auto accType = op.getAccType();
  auto resType = op.getResultType();

  // Verify that an indexing map was specified for each vector operand.
  if (op.indexing_maps().size() != 3)
    return op.emitOpError("expected an indexing map for each vector operand");

  // Verify that each index map has 'numIterators' inputs, no symbols, and
  // that the number of map outputs equals the rank of its associated
  // vector operand.
  unsigned numIterators = op.iterator_types().getValue().size();
  for (auto it : llvm::enumerate(op.indexing_maps())) {
    auto index = it.index();
    auto map = it.value().cast<AffineMapAttr>().getValue();
    if (map.getNumSymbols() != 0)
      return op.emitOpError("expected indexing map ")
             << index << " to have no symbols";
    // The third operand is the accumulator, which may be a scalar
    // (non-vector) type; treat that as rank 0.
    auto vectorType = op.getOperand(index).getType().dyn_cast<VectorType>();
    unsigned rank = vectorType ? vectorType.getShape().size() : 0;
    // Verify that the map has the right number of inputs, outputs, and indices.
    // This also correctly accounts for (..) -> () for rank-0 results.
    if (map.getNumDims() != numIterators)
      return op.emitOpError("expected indexing map ")
             << index << " to have " << numIterators << " number of inputs";
    if (map.getNumResults() != rank)
      return op.emitOpError("expected indexing map ")
             << index << " to have " << rank << " number of outputs";
    if (!map.isProjectedPermutation())
      return op.emitOpError("expected indexing map ")
             << index << " to be a projected permutation of its inputs";
  }

  auto contractingDimMap = op.getContractingDimMap();
  auto batchDimMap = op.getBatchDimMap();

  // Verify at least one contracting dimension pair was specified.
  if (contractingDimMap.empty())
    return op.emitOpError("expected at least one contracting dimension pair");

  // Verify contracting dimension map was properly constructed.
  if (!verifyDimMap(lhsType, rhsType, contractingDimMap))
    return op.emitOpError("invalid contracting dimension map");

  // Verify batch dimension map was properly constructed.
  if (!verifyDimMap(lhsType, rhsType, batchDimMap))
    return op.emitOpError("invalid batch dimension map");

  // Verify 'accType' and 'resType' shape.
  if (failed(verifyOutputShape(op, lhsType, rhsType, accType, resType,
                               contractingDimMap, batchDimMap)))
    return failure();

  // Verify that either two vector masks are set or none are set.
  auto lhsMaskType = op.getLHSVectorMaskType();
  auto rhsMaskType = op.getRHSVectorMaskType();
  if ((lhsMaskType && !rhsMaskType) || (!lhsMaskType && rhsMaskType))
    return op.emitOpError("invalid number of vector masks specified");
  if (lhsMaskType && rhsMaskType) {
    // Verify mask rank == argument rank.
    if (lhsMaskType.getShape().size() != lhsType.getShape().size() ||
        rhsMaskType.getShape().size() != rhsType.getShape().size())
      return op.emitOpError("invalid vector mask rank");
  }
  return success();
}
439
/// Returns the names of the attributes that encode this contraction's trait
/// (indexing maps and iterator types). The backing array is static so the
/// returned ArrayRef stays valid after this function returns.
ArrayRef<StringRef> ContractionOp::getTraitAttrNames() {
  static constexpr StringRef names[2] = {getIndexingMapsAttrName(),
                                         getIteratorTypesAttrName()};
  return llvm::makeArrayRef(names);
}
445
getResultIndex(AffineMap map,AffineExpr targetExpr)446 static int64_t getResultIndex(AffineMap map, AffineExpr targetExpr) {
447 for (int64_t i = 0, e = map.getNumResults(); i < e; ++i)
448 if (targetExpr == map.getResult(i))
449 return i;
450 return -1;
451 }
452
/// Collects (lhsDim, rhsDim) result-index pairs for every iterator of kind
/// 'targetIteratorTypeName' that appears in both the lhs and rhs indexing
/// maps.
static std::vector<std::pair<int64_t, int64_t>>
getDimMap(ArrayRef<AffineMap> indexingMaps, ArrayAttr iteratorTypes,
          StringRef targetIteratorTypeName, MLIRContext *context) {
  std::vector<std::pair<int64_t, int64_t>> dimMap;
  for (auto it : llvm::enumerate(iteratorTypes)) {
    auto iteratorTypeName = it.value().cast<StringAttr>().getValue();
    if (iteratorTypeName != targetIteratorTypeName)
      continue;
    // Search lhs/rhs map results for 'targetExpr'.
    auto targetExpr = getAffineDimExpr(it.index(), context);
    int64_t lhsDim = getResultIndex(indexingMaps[0], targetExpr);
    int64_t rhsDim = getResultIndex(indexingMaps[1], targetExpr);
    // Only record pairs where the iterator is used by both operands.
    if (lhsDim >= 0 && rhsDim >= 0)
      dimMap.push_back({lhsDim, rhsDim});
  }
  return dimMap;
}
470
getIterationBounds(SmallVectorImpl<int64_t> & iterationBounds)471 void ContractionOp::getIterationBounds(
472 SmallVectorImpl<int64_t> &iterationBounds) {
473 auto lhsShape = getLhsType().getShape();
474 auto resVectorType = getResultType().dyn_cast<VectorType>();
475 SmallVector<AffineMap, 4> indexingMaps(getIndexingMaps());
476 SmallVector<int64_t, 2> iterationShape;
477 for (auto it : llvm::enumerate(iterator_types())) {
478 // Search lhs/rhs map results for 'targetExpr'.
479 auto targetExpr = getAffineDimExpr(it.index(), getContext());
480 auto iteratorTypeName = it.value().cast<StringAttr>().getValue();
481 if (iteratorTypeName == getReductionIteratorTypeName()) {
482 // Get reduction dim size from lhs shape (same size in rhsShape).
483 int64_t lhsDimIndex = getResultIndex(indexingMaps[0], targetExpr);
484 assert(lhsDimIndex >= 0);
485 iterationBounds.push_back(lhsShape[lhsDimIndex]);
486 continue;
487 }
488 // Get parallel dimension size from result shape.
489 int64_t resDimIndex = getResultIndex(indexingMaps[2], targetExpr);
490 assert(resDimIndex >= 0);
491 assert(resVectorType != nullptr);
492 iterationBounds.push_back(resVectorType.getShape()[resDimIndex]);
493 }
494 }
495
/// For each indexing map, builds a table mapping an iteration-space
/// dimension position to the result index of that map where the dimension
/// occurs. One table per indexing map, in operand order.
void ContractionOp::getIterationIndexMap(
    std::vector<DenseMap<int64_t, int64_t>> &iterationIndexMap) {
  unsigned numMaps = indexing_maps().getValue().size();
  iterationIndexMap.resize(numMaps);
  for (auto it : llvm::enumerate(indexing_maps())) {
    auto index = it.index();
    auto map = it.value().cast<AffineMapAttr>().getValue();
    for (unsigned i = 0, e = map.getNumResults(); i < e; ++i) {
      // Maps are projected permutations (checked by the verifier), so each
      // result is a plain dimension expression.
      auto dim = map.getResult(i).cast<AffineDimExpr>();
      iterationIndexMap[index][dim.getPosition()] = i;
    }
  }
}
509
/// Returns the (lhsDim, rhsDim) pairs contracted by this op, i.e. the
/// dimension pairs associated with reduction iterators.
std::vector<std::pair<int64_t, int64_t>> ContractionOp::getContractingDimMap() {
  SmallVector<AffineMap, 4> indexingMaps(getIndexingMaps());
  return getDimMap(indexingMaps, iterator_types(),
                   getReductionIteratorTypeName(), getContext());
}
515
/// Returns the (lhsDim, rhsDim) batch dimension pairs, i.e. the dimension
/// pairs associated with parallel iterators that appear in both operands.
std::vector<std::pair<int64_t, int64_t>> ContractionOp::getBatchDimMap() {
  SmallVector<AffineMap, 4> indexingMaps(getIndexingMaps());
  return getDimMap(indexingMaps, iterator_types(),
                   getParallelIteratorTypeName(), getContext());
}
521
/// Returns the indexing_maps attribute unpacked into AffineMaps, in operand
/// order (lhs, rhs, result).
SmallVector<AffineMap, 4> ContractionOp::getIndexingMaps() {
  return llvm::to_vector<4>(
      llvm::map_range(indexing_maps().getValue(), [](Attribute mapAttr) {
        return mapAttr.cast<AffineMapAttr>().getValue();
      }));
}
528
/// Returns the iteration-space bounds of this contraction, which drive
/// unrolling of the op.
Optional<SmallVector<int64_t, 4>> ContractionOp::getShapeForUnroll() {
  SmallVector<int64_t, 4> shape;
  getIterationBounds(shape);
  return shape;
}
534
535 //===----------------------------------------------------------------------===//
536 // ExtractElementOp
537 //===----------------------------------------------------------------------===//
538
/// Builds an extractelement op with a dynamic position; the result type is
/// the element type of the source vector.
void vector::ExtractElementOp::build(OpBuilder &builder, OperationState &result,
                                     Value source, Value position) {
  result.addOperands({source, position});
  result.addTypes(source.getType().cast<VectorType>().getElementType());
}
544
/// Convenience builder that materializes the static 'position' as an i32
/// constant and forwards to the dynamic-position builder.
void vector::ExtractElementOp::build(OpBuilder &builder, OperationState &result,
                                     Value source, int64_t position) {
  Value pos = builder.create<ConstantIntOp>(result.location, position, 32);
  build(builder, result, source, pos);
}
550
verify(vector::ExtractElementOp op)551 static LogicalResult verify(vector::ExtractElementOp op) {
552 VectorType vectorType = op.getVectorType();
553 if (vectorType.getRank() != 1)
554 return op.emitOpError("expected 1-D vector");
555 return success();
556 }
557
558 //===----------------------------------------------------------------------===//
559 // ExtractOp
560 //===----------------------------------------------------------------------===//
561
inferExtractOpResultType(VectorType vectorType,ArrayAttr position)562 static Type inferExtractOpResultType(VectorType vectorType,
563 ArrayAttr position) {
564 if (static_cast<int64_t>(position.size()) == vectorType.getRank())
565 return vectorType.getElementType();
566 return VectorType::get(vectorType.getShape().drop_front(position.size()),
567 vectorType.getElementType());
568 }
569
/// Builds an extract op from a static position; the result type is inferred
/// from the source vector type and the position rank.
void vector::ExtractOp::build(OpBuilder &builder, OperationState &result,
                              Value source, ArrayRef<int64_t> position) {
  result.addOperands(source);
  auto positionAttr = getVectorSubscriptAttr(builder, position);
  result.addTypes(inferExtractOpResultType(source.getType().cast<VectorType>(),
                                           positionAttr));
  result.addAttribute(getPositionAttrName(), positionAttr);
}
578
// Convenience builder which assumes the values are constant indices.
void vector::ExtractOp::build(OpBuilder &builder, OperationState &result,
                              Value source, ValueRange position) {
  // Each position value must be produced by a constant index op; the
  // getDefiningOp result is dereferenced unconditionally here.
  SmallVector<int64_t, 4> positionConstants =
      llvm::to_vector<4>(llvm::map_range(position, [](Value pos) {
        return pos.getDefiningOp<ConstantIndexOp>().getValue();
      }));
  build(builder, result, source, positionConstants);
}
588
/// Prints a vector.extract in the form:
///   vector.extract %source[pos0, pos1] : source-vector-type
static void print(OpAsmPrinter &p, vector::ExtractOp op) {
  p << op.getOperationName() << " " << op.vector() << op.position();
  // 'position' is printed inline above, so elide it from the attr-dict.
  p.printOptionalAttrDict(op.getAttrs(), {"position"});
  p << " : " << op.vector().getType();
}
594
/// Parses a vector.extract op; the result type is inferred from the source
/// vector type and the rank of the position attribute.
static ParseResult parseExtractOp(OpAsmParser &parser, OperationState &result) {
  llvm::SMLoc attributeLoc, typeLoc;
  NamedAttrList attrs;
  OpAsmParser::OperandType vector;
  Type type;
  Attribute attr;
  if (parser.parseOperand(vector) || parser.getCurrentLocation(&attributeLoc) ||
      parser.parseAttribute(attr, "position", attrs) ||
      parser.parseOptionalAttrDict(attrs) ||
      parser.getCurrentLocation(&typeLoc) || parser.parseColonType(type))
    return failure();

  auto vectorType = type.dyn_cast<VectorType>();
  if (!vectorType)
    return parser.emitError(typeLoc, "expected vector type");

  // The position must be an array attribute no longer than the vector rank;
  // deeper checks (non-negative, in-bounds entries) happen in the verifier.
  auto positionAttr = attr.dyn_cast<ArrayAttr>();
  if (!positionAttr ||
      static_cast<int64_t>(positionAttr.size()) > vectorType.getRank())
    return parser.emitError(
        attributeLoc,
        "expected position attribute of rank smaller than vector rank");

  Type resType = inferExtractOpResultType(vectorType, positionAttr);
  result.attributes = attrs;
  return failure(parser.resolveOperand(vector, type, result.operands) ||
                 parser.addTypeToList(resType, result.types));
}
623
/// Verifies a vector.extract op: the position attribute must be non-empty,
/// no deeper than the source rank, and each entry must be a non-negative
/// integer in bounds for the corresponding source dimension.
static LogicalResult verify(vector::ExtractOp op) {
  auto positionAttr = op.position().getValue();
  if (positionAttr.empty())
    return op.emitOpError("expected non-empty position attribute");
  if (positionAttr.size() > static_cast<unsigned>(op.getVectorType().getRank()))
    return op.emitOpError(
        "expected position attribute of rank smaller than vector rank");
  for (auto en : llvm::enumerate(positionAttr)) {
    auto attr = en.value().dyn_cast<IntegerAttr>();
    // Each entry must be an IntegerAttr in [0, dimSize).
    if (!attr || attr.getInt() < 0 ||
        attr.getInt() >= op.getVectorType().getDimSize(en.index()))
      return op.emitOpError("expected position attribute #")
             << (en.index() + 1)
             << " to be a non-negative integer smaller than the corresponding "
                "vector dimension";
  }
  return success();
}
642
643 template <typename IntType>
extractVector(ArrayAttr arrayAttr)644 static SmallVector<IntType, 4> extractVector(ArrayAttr arrayAttr) {
645 return llvm::to_vector<4>(llvm::map_range(
646 arrayAttr.getAsRange<IntegerAttr>(),
647 [](IntegerAttr attr) { return static_cast<IntType>(attr.getInt()); }));
648 }
649
/// Fold the result of chains of ExtractOp in place by simply concatenating the
/// positions.
static LogicalResult foldExtractOpFromExtractChain(ExtractOp extractOp) {
  if (!extractOp.vector().getDefiningOp<ExtractOp>())
    return failure();

  // Accumulate positions while walking up the producer chain; each step
  // appends its position reversed so that a single reverse at the end yields
  // outermost-first order.
  SmallVector<int64_t, 4> globalPosition;
  ExtractOp currentOp = extractOp;
  auto extractedPos = extractVector<int64_t>(currentOp.position());
  globalPosition.append(extractedPos.rbegin(), extractedPos.rend());
  while (ExtractOp nextOp = currentOp.vector().getDefiningOp<ExtractOp>()) {
    currentOp = nextOp;
    auto extractedPos = extractVector<int64_t>(currentOp.position());
    globalPosition.append(extractedPos.rbegin(), extractedPos.rend());
  }
  // Rewrite in place: extract directly from the root of the chain.
  extractOp.setOperand(currentOp.vector());
  // OpBuilder is only used as a helper to build an I64ArrayAttr.
  OpBuilder b(extractOp.getContext());
  std::reverse(globalPosition.begin(), globalPosition.end());
  extractOp.setAttr(ExtractOp::getPositionAttrName(),
                    b.getI64ArrayAttr(globalPosition));
  return success();
}
673
/// Fold the result of an ExtractOp in place when it comes from a TransposeOp.
static LogicalResult foldExtractOpFromTranspose(ExtractOp extractOp) {
  auto transposeOp = extractOp.vector().getDefiningOp<vector::TransposeOp>();
  if (!transposeOp)
    return failure();

  auto permutation = extractVector<unsigned>(transposeOp.transp());
  auto extractedPos = extractVector<int64_t>(extractOp.position());

  // If transposition permutation is larger than the ExtractOp, all minor
  // dimensions must be an identity for folding to occur. If not, individual
  // elements within the extracted value are transposed and this is not just a
  // simple folding.
  unsigned minorRank = permutation.size() - extractedPos.size();
  MLIRContext *ctx = extractOp.getContext();
  AffineMap permutationMap = AffineMap::getPermutationMap(permutation, ctx);
  AffineMap minorMap = permutationMap.getMinorSubMap(minorRank);
  if (minorMap && !minorMap.isMinorIdentity())
    return failure();

  // %1 = transpose %0[x, y, z] : vector<axbxcxf32>
  // %2 = extract %1[u, v] : vector<..xf32>
  // may turn into:
  // %2 = extract %0[w, x] : vector<..xf32>
  // iff z == 2 and [w, x] = [x, y]^-1 o [u, v] here o denotes composition and
  // -1 denotes the inverse.
  permutationMap = permutationMap.getMajorSubMap(extractedPos.size());
  // The major submap has fewer results but the same number of dims. To compose
  // cleanly, we need to drop dims to form a "square matrix". This is possible
  // because:
  // (a) this is a permutation map and
  // (b) the minor map has already been checked to be identity.
  // Therefore, the major map cannot contain dims of position greater or equal
  // than the number of results.
  assert(llvm::all_of(permutationMap.getResults(),
                      [&](AffineExpr e) {
                        auto dim = e.dyn_cast<AffineDimExpr>();
                        return dim && dim.getPosition() <
                                          permutationMap.getNumResults();
                      }) &&
         "Unexpected map results depend on higher rank positions");
  // Project on the first domain dimensions to allow composition.
  permutationMap = AffineMap::get(permutationMap.getNumResults(), 0,
                                  permutationMap.getResults(), ctx);

  // Rewrite in place: extract from the transpose's source vector.
  extractOp.setOperand(transposeOp.vector());
  // Compose the inverse permutation map with the extractedPos.
  auto newExtractedPos =
      inversePermutation(permutationMap).compose(extractedPos);
  // OpBuilder is only used as a helper to build an I64ArrayAttr.
  OpBuilder b(extractOp.getContext());
  extractOp.setAttr(ExtractOp::getPositionAttrName(),
                    b.getI64ArrayAttr(newExtractedPos));

  return success();
}
730
/// Fold an ExtractOp that is fed by a chain of InsertOps and TransposeOps. The
/// result is always the input to some InsertOp. Returns the matched InsertOp
/// source on success, or an empty Value when no producer in the chain supplies
/// the extracted position.
static Value foldExtractOpFromInsertChainAndTranspose(ExtractOp extractOp) {
  MLIRContext *context = extractOp.getContext();
  // Composition of all TransposeOp permutations seen while walking back.
  // Stays null until the first TransposeOp is encountered.
  AffineMap permutationMap;
  auto extractedPos = extractVector<unsigned>(extractOp.position());
  // Walk back a chain of InsertOp/TransposeOp until we hit a match.
  // Compose TransposeOp permutations as we walk back.
  auto insertOp = extractOp.vector().getDefiningOp<vector::InsertOp>();
  auto transposeOp = extractOp.vector().getDefiningOp<vector::TransposeOp>();
  while (insertOp || transposeOp) {
    if (transposeOp) {
      // If it is transposed, compose the map and iterate.
      auto permutation = extractVector<unsigned>(transposeOp.transp());
      AffineMap newMap = AffineMap::getPermutationMap(permutation, context);
      if (!permutationMap)
        permutationMap = newMap;
      else if (newMap.getNumInputs() != permutationMap.getNumResults())
        // Rank mismatch between successive permutations: give up.
        return Value();
      else
        permutationMap = newMap.compose(permutationMap);
      // Compute insert/transpose for the next iteration.
      Value transposed = transposeOp.vector();
      insertOp = transposed.getDefiningOp<vector::InsertOp>();
      transposeOp = transposed.getDefiningOp<vector::TransposeOp>();
      continue;
    }

    assert(insertOp);
    Value insertionDest = insertOp.dest();
    // If it is inserted into, either the position matches and we have a
    // successful folding; or we iterate until we run out of
    // InsertOp/TransposeOp. This is because `vector.insert %scalar, %vector`
    // produces a new vector with 1 modified value/slice in exactly the static
    // position we need to match.
    auto insertedPos = extractVector<unsigned>(insertOp.position());
    // Trivial permutations are solved with position equality checks.
    if (!permutationMap || permutationMap.isIdentity()) {
      if (extractedPos == insertedPos)
        return insertOp.source();
      // Fallthrough: if the position does not match, just skip to the next
      // producing `vector.insert` / `vector.transpose`.
      // Compute insert/transpose for the next iteration.
      insertOp = insertionDest.getDefiningOp<vector::InsertOp>();
      transposeOp = insertionDest.getDefiningOp<vector::TransposeOp>();
      continue;
    }

    // More advanced permutations require application of the permutation.
    // However, the rank of `insertedPos` may be different from that of the
    // `permutationMap`. To support such case, we need to:
    // 1. apply on the `insertedPos.size()` major dimensions
    // 2. check the other dimensions of the permutation form a minor identity.
    assert(permutationMap.isPermutation() && "expected a permutation");
    if (insertedPos.size() == extractedPos.size()) {
      // Match only if every extracted index equals the inserted index routed
      // through the accumulated permutation.
      bool fold = true;
      for (unsigned idx = 0, sz = extractedPos.size(); idx < sz; ++idx) {
        auto pos = permutationMap.getDimPosition(idx);
        if (pos >= sz || insertedPos[pos] != extractedPos[idx]) {
          fold = false;
          break;
        }
      }
      if (fold) {
        assert(permutationMap.getNumResults() >= insertedPos.size() &&
               "expected map of rank larger than insert indexing");
        unsigned minorRank =
            permutationMap.getNumResults() - insertedPos.size();
        AffineMap minorMap = permutationMap.getMinorSubMap(minorRank);
        // The untouched minor dimensions must not be permuted.
        if (!minorMap || minorMap.isMinorIdentity())
          return insertOp.source();
      }
    }

    // If we haven't found a match, just continue to the next producing
    // `vector.insert` / `vector.transpose`.
    // Compute insert/transpose for the next iteration.
    insertOp = insertionDest.getDefiningOp<vector::InsertOp>();
    transposeOp = insertionDest.getDefiningOp<vector::TransposeOp>();
  }
  return Value();
}
813
814 /// Fold extractOp with scalar result coming from BroadcastOp.
foldExtractFromBroadcast(ExtractOp extractOp)815 static Value foldExtractFromBroadcast(ExtractOp extractOp) {
816 auto broadcastOp = extractOp.vector().getDefiningOp<vector::BroadcastOp>();
817 if (!broadcastOp)
818 return Value();
819 if (extractOp.getType() == broadcastOp.getSourceType())
820 return broadcastOp.source();
821 auto getRank = [](Type type) {
822 return type.isa<VectorType>() ? type.cast<VectorType>().getRank() : 0;
823 };
824 unsigned broadcasrSrcRank = getRank(broadcastOp.getSourceType());
825 unsigned extractResultRank = getRank(extractOp.getType());
826 if (extractResultRank < broadcasrSrcRank) {
827 auto extractPos = extractVector<int64_t>(extractOp.position());
828 unsigned rankDiff = broadcasrSrcRank - extractResultRank;
829 extractPos.erase(
830 extractPos.begin(),
831 std::next(extractPos.begin(), extractPos.size() - rankDiff));
832 extractOp.setOperand(broadcastOp.source());
833 // OpBuilder is only used as a helper to build an I64ArrayAttr.
834 OpBuilder b(extractOp.getContext());
835 extractOp.setAttr(ExtractOp::getPositionAttrName(),
836 b.getI64ArrayAttr(extractPos));
837 return extractOp.getResult();
838 }
839 // TODO: In case the rank of the broadcast source is greater than the rank of
840 // the extract result this can be combined into a new broadcast op. This needs
841 // to be added a canonicalization pattern if needed.
842 return Value();
843 }
844
// Fold extractOp with source coming from ShapeCast op.
//
// The extract position is linearized using the extract-source strides and
// then delinearized using the shape-cast-source strides, so the extract can
// index directly into the shape-cast source. Only applies when the
// `destinationRank` trailing dimensions of both vector types match.
static Value foldExtractFromShapeCast(ExtractOp extractOp) {
  auto shapeCastOp = extractOp.vector().getDefiningOp<vector::ShapeCastOp>();
  if (!shapeCastOp)
    return Value();
  // Get the nth dimension size starting from lowest dimension.
  auto getDimReverse = [](VectorType type, int64_t n) {
    return type.getShape().take_back(n+1).front();
  };
  // Rank of the extract result (0 when extracting a scalar).
  int64_t destinationRank =
      extractOp.getType().isa<VectorType>()
          ? extractOp.getType().cast<VectorType>().getRank()
          : 0;
  if (destinationRank > shapeCastOp.getSourceVectorType().getRank())
    return Value();
  if (destinationRank > 0) {
    auto destinationType = extractOp.getResult().getType().cast<VectorType>();
    for (int64_t i = 0; i < destinationRank; i++) {
      // The lowest dimensions of the destination must match the lowest
      // dimensions of the shapecast op source.
      // TODO: This case could be supported in a canonicalization pattern.
      if (getDimReverse(shapeCastOp.getSourceVectorType(), i) !=
          getDimReverse(destinationType, i))
        return Value();
    }
  }
  // Extract the strides associated with the extract op vector source. Then use
  // this to calculate a linearized position for the extract.
  // Positions and strides are handled minor-to-major, hence the reversals.
  auto extractedPos = extractVector<int64_t>(extractOp.position());
  std::reverse(extractedPos.begin(), extractedPos.end());
  SmallVector<int64_t, 4> strides;
  int64_t stride = 1;
  for (int64_t i = 0, e = extractedPos.size(); i < e; i++) {
    strides.push_back(stride);
    stride *= getDimReverse(extractOp.getVectorType(), i + destinationRank);
  }

  int64_t position = linearize(extractedPos, strides);
  // Then extract the strides associated to the shapeCast op vector source and
  // delinearize the position using those strides.
  SmallVector<int64_t, 4> newStrides;
  int64_t numDimension =
      shapeCastOp.getSourceVectorType().getRank() - destinationRank;
  stride = 1;
  for (int64_t i = 0; i < numDimension; i++) {
    newStrides.push_back(stride);
    stride *=
        getDimReverse(shapeCastOp.getSourceVectorType(), i + destinationRank);
  }
  // delinearize expects major-to-minor strides.
  std::reverse(newStrides.begin(), newStrides.end());
  SmallVector<int64_t, 4> newPosition = delinearize(newStrides, position);
  // OpBuilder is only used as a helper to build an I64ArrayAttr.
  OpBuilder b(extractOp.getContext());
  extractOp.setAttr(ExtractOp::getPositionAttrName(),
                    b.getI64ArrayAttr(newPosition));
  extractOp.setOperand(shapeCastOp.source());
  return extractOp.getResult();
}
903
fold(ArrayRef<Attribute>)904 OpFoldResult ExtractOp::fold(ArrayRef<Attribute>) {
905 if (succeeded(foldExtractOpFromExtractChain(*this)))
906 return getResult();
907 if (succeeded(foldExtractOpFromTranspose(*this)))
908 return getResult();
909 if (auto val = foldExtractOpFromInsertChainAndTranspose(*this))
910 return val;
911 if (auto val = foldExtractFromBroadcast(*this))
912 return val;
913 if (auto val = foldExtractFromShapeCast(*this))
914 return val;
915 return OpFoldResult();
916 }
917
918 //===----------------------------------------------------------------------===//
919 // ExtractSlicesOp
920 //===----------------------------------------------------------------------===//
921
build(OpBuilder & builder,OperationState & result,TupleType tupleType,Value vector,ArrayRef<int64_t> sizes,ArrayRef<int64_t> strides)922 void ExtractSlicesOp::build(OpBuilder &builder, OperationState &result,
923 TupleType tupleType, Value vector,
924 ArrayRef<int64_t> sizes,
925 ArrayRef<int64_t> strides) {
926 result.addOperands(vector);
927 auto sizesAttr = getVectorSubscriptAttr(builder, sizes);
928 auto stridesAttr = getVectorSubscriptAttr(builder, strides);
929 result.addTypes(tupleType);
930 result.addAttribute(getSizesAttrName(), sizesAttr);
931 result.addAttribute(getStridesAttrName(), stridesAttr);
932 }
933
934 static LogicalResult
isValidExtractOrInsertSlicesType(Operation * op,VectorType vectorType,TupleType tupleType,ArrayRef<int64_t> sizes,ArrayRef<int64_t> strides)935 isValidExtractOrInsertSlicesType(Operation *op, VectorType vectorType,
936 TupleType tupleType, ArrayRef<int64_t> sizes,
937 ArrayRef<int64_t> strides) {
938 // Check for non-unit strides.
939 // TODO: Support non-1 strides.
940 if (llvm::any_of(strides, [](int64_t s) { return s != 1; }))
941 return op->emitError("requires unit strides");
942 // Check that 'vectorType' rank matches rank of tuple element vectors.
943 unsigned rank = vectorType.getRank();
944 auto is_vector_type_of_rank = [&](Type t) {
945 return t.isa<VectorType>() && t.cast<VectorType>().getRank() == rank;
946 };
947 if (!llvm::all_of(tupleType.getTypes(), is_vector_type_of_rank))
948 return op->emitError("requires vector tuple elements of rank ") << rank;
949 // Check that 'sizes' and 'strides' are of size == 'rank'.
950 if (sizes.size() != rank || strides.size() != rank)
951 return op->emitError("requires sizes and strides of rank ") << rank;
952
953 // Generate each slice shape based on 'sizes', 'strides' and 'vectorType',
954 // and verify that the same matches the corresponding tuple element 'i'.
955 auto shape = vectorType.getShape();
956 auto sliceStrides = computeStrides(shape, sizes);
957 for (int64_t i = 0, e = tupleType.size(); i < e; ++i) {
958 auto vectorOffsets = delinearize(sliceStrides, i);
959 auto elementOffsets =
960 computeElementOffsetsFromVectorSliceOffsets(sizes, vectorOffsets);
961 auto sliceSizes = computeSliceSizes(shape, sizes, elementOffsets);
962 // Create slice VectorType type.
963 auto sliceVectorType =
964 VectorType::get(sliceSizes, vectorType.getElementType());
965 // Verify that 'sliceVectorType' matches tupleType.getTypes(i)
966 if (sliceVectorType != tupleType.getType(i))
967 return op->emitError("invalid tuple element type ") << sliceVectorType;
968 }
969 return success();
970 }
971
verify(ExtractSlicesOp op)972 static LogicalResult verify(ExtractSlicesOp op) {
973 SmallVector<int64_t, 4> sizes;
974 op.getSizes(sizes);
975 SmallVector<int64_t, 4> strides;
976 op.getStrides(strides);
977 return isValidExtractOrInsertSlicesType(
978 op.getOperation(), op.getSourceVectorType(), op.getResultTupleType(),
979 sizes, strides);
980 }
981
populateFromInt64AttrArray(ArrayAttr arrayAttr,SmallVectorImpl<int64_t> & results)982 static void populateFromInt64AttrArray(ArrayAttr arrayAttr,
983 SmallVectorImpl<int64_t> &results) {
984 for (auto attr : arrayAttr)
985 results.push_back(attr.cast<IntegerAttr>().getInt());
986 }
987
getSizes(SmallVectorImpl<int64_t> & results)988 void ExtractSlicesOp::getSizes(SmallVectorImpl<int64_t> &results) {
989 populateFromInt64AttrArray(sizes(), results);
990 }
991
getStrides(SmallVectorImpl<int64_t> & results)992 void ExtractSlicesOp::getStrides(SmallVectorImpl<int64_t> &results) {
993 populateFromInt64AttrArray(strides(), results);
994 }
995
996 //===----------------------------------------------------------------------===//
997 // ExtractMapOp
998 //===----------------------------------------------------------------------===//
999
build(OpBuilder & builder,OperationState & result,Value vector,ValueRange ids,ArrayRef<int64_t> multiplicity,AffineMap permutationMap)1000 void ExtractMapOp::build(OpBuilder &builder, OperationState &result,
1001 Value vector, ValueRange ids,
1002 ArrayRef<int64_t> multiplicity,
1003 AffineMap permutationMap) {
1004 assert(ids.size() == multiplicity.size() &&
1005 ids.size() == permutationMap.getNumResults());
1006 assert(permutationMap.isProjectedPermutation());
1007 VectorType type = vector.getType().cast<VectorType>();
1008 SmallVector<int64_t, 4> newShape(type.getShape().begin(),
1009 type.getShape().end());
1010 for (unsigned i = 0, e = permutationMap.getNumResults(); i < e; i++) {
1011 AffineExpr expr = permutationMap.getResult(i);
1012 auto dim = expr.cast<AffineDimExpr>();
1013 newShape[dim.getPosition()] = newShape[dim.getPosition()] / multiplicity[i];
1014 }
1015 VectorType resultType = VectorType::get(newShape, type.getElementType());
1016 ExtractMapOp::build(builder, result, resultType, vector, ids);
1017 }
1018
verify(ExtractMapOp op)1019 static LogicalResult verify(ExtractMapOp op) {
1020 if (op.getSourceVectorType().getRank() != op.getResultType().getRank())
1021 return op.emitOpError(
1022 "expected source and destination vectors of same rank");
1023 unsigned numId = 0;
1024 for (unsigned i = 0, e = op.getSourceVectorType().getRank(); i < e; ++i) {
1025 if (op.getSourceVectorType().getDimSize(i) %
1026 op.getResultType().getDimSize(i) !=
1027 0)
1028 return op.emitOpError("source vector dimensions must be a multiple of "
1029 "destination vector dimensions");
1030 if (op.getSourceVectorType().getDimSize(i) !=
1031 op.getResultType().getDimSize(i))
1032 numId++;
1033 }
1034 if (numId != op.ids().size())
1035 return op.emitOpError("expected number of ids must match the number of "
1036 "dimensions distributed");
1037 return success();
1038 }
1039
fold(ArrayRef<Attribute> operands)1040 OpFoldResult ExtractMapOp::fold(ArrayRef<Attribute> operands) {
1041 auto insert = vector().getDefiningOp<vector::InsertMapOp>();
1042 if (insert == nullptr || getType() != insert.vector().getType() ||
1043 ids() != insert.ids())
1044 return {};
1045 return insert.vector();
1046 }
1047
getMultiplicity(SmallVectorImpl<int64_t> & multiplicity)1048 void ExtractMapOp::getMultiplicity(SmallVectorImpl<int64_t> &multiplicity) {
1049 assert(multiplicity.empty());
1050 for (unsigned i = 0, e = getSourceVectorType().getRank(); i < e; i++) {
1051 if (getSourceVectorType().getDimSize(i) != getResultType().getDimSize(i))
1052 multiplicity.push_back(getSourceVectorType().getDimSize(i) /
1053 getResultType().getDimSize(i));
1054 }
1055 }
1056
/// Computes the implicit map of an ExtractMapOp/InsertMapOp: a projection
/// keeping exactly the dimensions whose source and result sizes differ (the
/// distributed dimensions), in order.
template <typename MapOp>
AffineMap calculateImplicitMap(MapOp op) {
  SmallVector<AffineExpr, 4> perm;
  // Check which dimensions have a multiplicity greater than 1 and associate
  // them to the IDs in order.
  for (unsigned i = 0, e = op.getSourceVectorType().getRank(); i < e; i++) {
    if (op.getSourceVectorType().getDimSize(i) !=
        op.getResultType().getDimSize(i))
      perm.push_back(getAffineDimExpr(i, op.getContext()));
  }
  auto map = AffineMap::get(op.getSourceVectorType().getRank(), 0, perm,
                            op.getContext());
  return map;
}
1071
map()1072 AffineMap ExtractMapOp::map() { return calculateImplicitMap(*this); }
1073
1074 //===----------------------------------------------------------------------===//
1075 // BroadcastOp
1076 //===----------------------------------------------------------------------===//
1077
verify(BroadcastOp op)1078 static LogicalResult verify(BroadcastOp op) {
1079 VectorType srcVectorType = op.getSourceType().dyn_cast<VectorType>();
1080 VectorType dstVectorType = op.getVectorType();
1081 // Scalar to vector broadcast is always valid. A vector
1082 // to vector broadcast needs some additional checking.
1083 if (srcVectorType) {
1084 int64_t srcRank = srcVectorType.getRank();
1085 int64_t dstRank = dstVectorType.getRank();
1086 if (srcRank > dstRank)
1087 return op.emitOpError("source rank higher than destination rank");
1088 // Source has an exact match or singleton value for all trailing dimensions
1089 // (all leading dimensions are simply duplicated).
1090 int64_t lead = dstRank - srcRank;
1091 for (int64_t r = 0; r < srcRank; ++r) {
1092 int64_t srcDim = srcVectorType.getDimSize(r);
1093 int64_t dstDim = dstVectorType.getDimSize(lead + r);
1094 if (srcDim != 1 && srcDim != dstDim)
1095 return op.emitOpError("dimension mismatch (")
1096 << srcDim << " vs. " << dstDim << ")";
1097 }
1098 }
1099 return success();
1100 }
1101
fold(ArrayRef<Attribute> operands)1102 OpFoldResult BroadcastOp::fold(ArrayRef<Attribute> operands) {
1103 if (!operands[0])
1104 return {};
1105 auto vectorType = getVectorType();
1106 if (operands[0].getType().isIntOrIndexOrFloat())
1107 return DenseElementsAttr::get(vectorType, operands[0]);
1108 if (auto attr = operands[0].dyn_cast<SplatElementsAttr>())
1109 return DenseElementsAttr::get(vectorType, attr.getSplatValue());
1110 return {};
1111 }
1112
1113 //===----------------------------------------------------------------------===//
1114 // ShuffleOp
1115 //===----------------------------------------------------------------------===//
1116
build(OpBuilder & builder,OperationState & result,Value v1,Value v2,ArrayRef<int64_t> mask)1117 void ShuffleOp::build(OpBuilder &builder, OperationState &result, Value v1,
1118 Value v2, ArrayRef<int64_t> mask) {
1119 result.addOperands({v1, v2});
1120 auto maskAttr = getVectorSubscriptAttr(builder, mask);
1121 result.addTypes(v1.getType());
1122 result.addAttribute(getMaskAttrName(), maskAttr);
1123 }
1124
print(OpAsmPrinter & p,ShuffleOp op)1125 static void print(OpAsmPrinter &p, ShuffleOp op) {
1126 p << op.getOperationName() << " " << op.v1() << ", " << op.v2() << " "
1127 << op.mask();
1128 p.printOptionalAttrDict(op.getAttrs(), {ShuffleOp::getMaskAttrName()});
1129 p << " : " << op.v1().getType() << ", " << op.v2().getType();
1130 }
1131
verify(ShuffleOp op)1132 static LogicalResult verify(ShuffleOp op) {
1133 VectorType resultType = op.getVectorType();
1134 VectorType v1Type = op.getV1VectorType();
1135 VectorType v2Type = op.getV2VectorType();
1136 // Verify ranks.
1137 int64_t resRank = resultType.getRank();
1138 int64_t v1Rank = v1Type.getRank();
1139 int64_t v2Rank = v2Type.getRank();
1140 if (resRank != v1Rank || v1Rank != v2Rank)
1141 return op.emitOpError("rank mismatch");
1142 // Verify all but leading dimension sizes.
1143 for (int64_t r = 1; r < v1Rank; ++r) {
1144 int64_t resDim = resultType.getDimSize(r);
1145 int64_t v1Dim = v1Type.getDimSize(r);
1146 int64_t v2Dim = v2Type.getDimSize(r);
1147 if (resDim != v1Dim || v1Dim != v2Dim)
1148 return op.emitOpError("dimension mismatch");
1149 }
1150 // Verify mask length.
1151 auto maskAttr = op.mask().getValue();
1152 int64_t maskLength = maskAttr.size();
1153 if (maskLength != resultType.getDimSize(0))
1154 return op.emitOpError("mask length mismatch");
1155 // Verify all indices.
1156 int64_t indexSize = v1Type.getDimSize(0) + v2Type.getDimSize(0);
1157 for (auto en : llvm::enumerate(maskAttr)) {
1158 auto attr = en.value().dyn_cast<IntegerAttr>();
1159 if (!attr || attr.getInt() < 0 || attr.getInt() >= indexSize)
1160 return op.emitOpError("mask index #")
1161 << (en.index() + 1) << " out of range";
1162 }
1163 return success();
1164 }
1165
parseShuffleOp(OpAsmParser & parser,OperationState & result)1166 static ParseResult parseShuffleOp(OpAsmParser &parser, OperationState &result) {
1167 OpAsmParser::OperandType v1, v2;
1168 Attribute attr;
1169 VectorType v1Type, v2Type;
1170 if (parser.parseOperand(v1) || parser.parseComma() ||
1171 parser.parseOperand(v2) ||
1172 parser.parseAttribute(attr, ShuffleOp::getMaskAttrName(),
1173 result.attributes) ||
1174 parser.parseOptionalAttrDict(result.attributes) ||
1175 parser.parseColonType(v1Type) || parser.parseComma() ||
1176 parser.parseType(v2Type) ||
1177 parser.resolveOperand(v1, v1Type, result.operands) ||
1178 parser.resolveOperand(v2, v2Type, result.operands))
1179 return failure();
1180 // Construct resulting type: leading dimension matches mask length,
1181 // all trailing dimensions match the operands.
1182 auto maskAttr = attr.dyn_cast<ArrayAttr>();
1183 if (!maskAttr)
1184 return parser.emitError(parser.getNameLoc(), "missing mask attribute");
1185 int64_t maskLength = maskAttr.size();
1186 if (maskLength <= 0)
1187 return parser.emitError(parser.getNameLoc(), "invalid mask length");
1188 int64_t v1Rank = v1Type.getRank();
1189 SmallVector<int64_t, 4> shape;
1190 shape.reserve(v1Rank);
1191 shape.push_back(maskLength);
1192 for (int64_t r = 1; r < v1Rank; ++r)
1193 shape.push_back(v1Type.getDimSize(r));
1194 VectorType resType = VectorType::get(shape, v1Type.getElementType());
1195 parser.addTypeToList(resType, result.types);
1196 return success();
1197 }
1198
1199 //===----------------------------------------------------------------------===//
1200 // InsertElementOp
1201 //===----------------------------------------------------------------------===//
1202
build(OpBuilder & builder,OperationState & result,Value source,Value dest,Value position)1203 void InsertElementOp::build(OpBuilder &builder, OperationState &result,
1204 Value source, Value dest, Value position) {
1205 result.addOperands({source, dest, position});
1206 result.addTypes(dest.getType());
1207 }
1208
build(OpBuilder & builder,OperationState & result,Value source,Value dest,int64_t position)1209 void InsertElementOp::build(OpBuilder &builder, OperationState &result,
1210 Value source, Value dest, int64_t position) {
1211 Value pos = builder.create<ConstantIntOp>(result.location, position, 32);
1212 build(builder, result, source, dest, pos);
1213 }
1214
verify(InsertElementOp op)1215 static LogicalResult verify(InsertElementOp op) {
1216 auto dstVectorType = op.getDestVectorType();
1217 if (dstVectorType.getRank() != 1)
1218 return op.emitOpError("expected 1-D vector");
1219 return success();
1220 }
1221
1222 //===----------------------------------------------------------------------===//
1223 // InsertOp
1224 //===----------------------------------------------------------------------===//
1225
build(OpBuilder & builder,OperationState & result,Value source,Value dest,ArrayRef<int64_t> position)1226 void InsertOp::build(OpBuilder &builder, OperationState &result, Value source,
1227 Value dest, ArrayRef<int64_t> position) {
1228 result.addOperands({source, dest});
1229 auto positionAttr = getVectorSubscriptAttr(builder, position);
1230 result.addTypes(dest.getType());
1231 result.addAttribute(getPositionAttrName(), positionAttr);
1232 }
1233
// Convenience builder which assumes the values are constant indices.
// NOTE(review): each `position` value must be defined by a ConstantIndexOp;
// getValue() is called on the result of getDefiningOp without a null check,
// so a non-constant position would crash here — confirm callers guarantee
// this.
void InsertOp::build(OpBuilder &builder, OperationState &result, Value source,
                     Value dest, ValueRange position) {
  SmallVector<int64_t, 4> positionConstants =
      llvm::to_vector<4>(llvm::map_range(position, [](Value pos) {
        return pos.getDefiningOp<ConstantIndexOp>().getValue();
      }));
  build(builder, result, source, dest, positionConstants);
}
1243
verify(InsertOp op)1244 static LogicalResult verify(InsertOp op) {
1245 auto positionAttr = op.position().getValue();
1246 if (positionAttr.empty())
1247 return op.emitOpError("expected non-empty position attribute");
1248 auto destVectorType = op.getDestVectorType();
1249 if (positionAttr.size() > static_cast<unsigned>(destVectorType.getRank()))
1250 return op.emitOpError(
1251 "expected position attribute of rank smaller than dest vector rank");
1252 auto srcVectorType = op.getSourceType().dyn_cast<VectorType>();
1253 if (srcVectorType &&
1254 (static_cast<unsigned>(srcVectorType.getRank()) + positionAttr.size() !=
1255 static_cast<unsigned>(destVectorType.getRank())))
1256 return op.emitOpError("expected position attribute rank + source rank to "
1257 "match dest vector rank");
1258 else if (!srcVectorType && (positionAttr.size() !=
1259 static_cast<unsigned>(destVectorType.getRank())))
1260 return op.emitOpError(
1261 "expected position attribute rank to match the dest vector rank");
1262 for (auto en : llvm::enumerate(positionAttr)) {
1263 auto attr = en.value().dyn_cast<IntegerAttr>();
1264 if (!attr || attr.getInt() < 0 ||
1265 attr.getInt() >= destVectorType.getDimSize(en.index()))
1266 return op.emitOpError("expected position attribute #")
1267 << (en.index() + 1)
1268 << " to be a non-negative integer smaller than the corresponding "
1269 "dest vector dimension";
1270 }
1271 return success();
1272 }
1273
1274 //===----------------------------------------------------------------------===//
1275 // InsertSlicesOp
1276 //===----------------------------------------------------------------------===//
1277
verify(InsertSlicesOp op)1278 static LogicalResult verify(InsertSlicesOp op) {
1279 SmallVector<int64_t, 4> sizes;
1280 op.getSizes(sizes);
1281 SmallVector<int64_t, 4> strides;
1282 op.getStrides(strides);
1283 return isValidExtractOrInsertSlicesType(
1284 op.getOperation(), op.getResultVectorType(), op.getSourceTupleType(),
1285 sizes, strides);
1286 }
1287
getSizes(SmallVectorImpl<int64_t> & results)1288 void InsertSlicesOp::getSizes(SmallVectorImpl<int64_t> &results) {
1289 populateFromInt64AttrArray(sizes(), results);
1290 }
1291
getStrides(SmallVectorImpl<int64_t> & results)1292 void InsertSlicesOp::getStrides(SmallVectorImpl<int64_t> &results) {
1293 populateFromInt64AttrArray(strides(), results);
1294 }
1295
1296 //===----------------------------------------------------------------------===//
1297 // InsertMapOp
1298 //===----------------------------------------------------------------------===//
1299
build(OpBuilder & builder,OperationState & result,Value vector,Value dest,ValueRange ids)1300 void InsertMapOp::build(OpBuilder &builder, OperationState &result,
1301 Value vector, Value dest, ValueRange ids) {
1302 InsertMapOp::build(builder, result, dest.getType(), vector, dest, ids);
1303 }
1304
verify(InsertMapOp op)1305 static LogicalResult verify(InsertMapOp op) {
1306 if (op.getSourceVectorType().getRank() != op.getResultType().getRank())
1307 return op.emitOpError(
1308 "expected source and destination vectors of same rank");
1309 unsigned numId = 0;
1310 for (unsigned i = 0, e = op.getResultType().getRank(); i < e; i++) {
1311 if (op.getResultType().getDimSize(i) %
1312 op.getSourceVectorType().getDimSize(i) !=
1313 0)
1314 return op.emitOpError(
1315 "destination vector size must be a multiple of source vector size");
1316 if (op.getResultType().getDimSize(i) !=
1317 op.getSourceVectorType().getDimSize(i))
1318 numId++;
1319 }
1320 if (numId != op.ids().size())
1321 return op.emitOpError("expected number of ids must match the number of "
1322 "dimensions distributed");
1323 return success();
1324 }
1325
map()1326 AffineMap InsertMapOp::map() { return calculateImplicitMap(*this); }
1327
1328 //===----------------------------------------------------------------------===//
1329 // InsertStridedSliceOp
1330 //===----------------------------------------------------------------------===//
1331
build(OpBuilder & builder,OperationState & result,Value source,Value dest,ArrayRef<int64_t> offsets,ArrayRef<int64_t> strides)1332 void InsertStridedSliceOp::build(OpBuilder &builder, OperationState &result,
1333 Value source, Value dest,
1334 ArrayRef<int64_t> offsets,
1335 ArrayRef<int64_t> strides) {
1336 result.addOperands({source, dest});
1337 auto offsetsAttr = getVectorSubscriptAttr(builder, offsets);
1338 auto stridesAttr = getVectorSubscriptAttr(builder, strides);
1339 result.addTypes(dest.getType());
1340 result.addAttribute(getOffsetsAttrName(), offsetsAttr);
1341 result.addAttribute(getStridesAttrName(), stridesAttr);
1342 }
1343
// TODO: Should be moved to Tablegen Confined attributes.
/// Emits an error unless `arrayAttr` has at most as many elements as `shape`
/// has dimensions; `attrName` names the attribute in the diagnostic.
template <typename OpType>
static LogicalResult isIntegerArrayAttrSmallerThanShape(OpType op,
                                                        ArrayAttr arrayAttr,
                                                        ArrayRef<int64_t> shape,
                                                        StringRef attrName) {
  if (arrayAttr.size() > shape.size())
    return op.emitOpError("expected ")
           << attrName << " attribute of rank smaller than vector rank";
  return success();
}
1355
// Returns success if all integers in `arrayAttr` are in the admissible
// interval. If `halfOpen` is true then the admissible interval is [min, max).
// Otherwise, the admissible interval is [min, max].
template <typename OpType>
static LogicalResult
isIntegerArrayAttrConfinedToRange(OpType op, ArrayAttr arrayAttr, int64_t min,
                                  int64_t max, StringRef attrName,
                                  bool halfOpen = true) {
  for (auto attr : arrayAttr) {
    auto val = attr.cast<IntegerAttr>().getInt();
    // Widen the bound by one for the closed-interval case so a single
    // half-open comparison handles both.
    auto upper = max;
    if (!halfOpen)
      upper += 1;
    if (val < min || val >= upper)
      return op.emitOpError("expected ") << attrName << " to be confined to ["
                                         << min << ", " << upper << ")";
  }
  return success();
}
1375
// Returns success if each integer in `arrayAttr` is in the admissible
// interval for the corresponding dimension of `shape`. If `halfOpen` is true
// then the admissible interval is [min, max). Otherwise, the admissible
// interval is [min, max].
template <typename OpType>
static LogicalResult
isIntegerArrayAttrConfinedToShape(OpType op, ArrayAttr arrayAttr,
                                  ArrayRef<int64_t> shape, StringRef attrName,
                                  bool halfOpen = true, int64_t min = 0) {
  assert(arrayAttr.size() <= shape.size());
  unsigned index = 0;
  for (auto it : llvm::zip(arrayAttr, shape)) {
    auto val = std::get<0>(it).cast<IntegerAttr>().getInt();
    // Widen the bound by one for the closed-interval case so a single
    // half-open comparison handles both.
    auto max = std::get<1>(it);
    if (!halfOpen)
      max += 1;
    if (val < min || val >= max)
      return op.emitOpError("expected ")
             << attrName << " dimension " << index << " to be confined to ["
             << min << ", " << max << ")";
    ++index;
  }
  return success();
}
1399
// Returns success if, for each dimension, the sum of the corresponding
// entries of `arrayAttr1` and `arrayAttr2` is in the admissible interval for
// that dimension of `shape`. If `halfOpen` is true then the admissible
// interval is [min, max). Otherwise, the admissible interval is [min, max].
// NOTE(review): the lower-bound comparison below uses 0, not `min`, even
// though the diagnostic prints `min` (default 1) — confirm whether the check
// should read `val1 + val2 < min`.
template <typename OpType>
static LogicalResult isSumOfIntegerArrayAttrConfinedToShape(
    OpType op, ArrayAttr arrayAttr1, ArrayAttr arrayAttr2,
    ArrayRef<int64_t> shape, StringRef attrName1, StringRef attrName2,
    bool halfOpen = true, int64_t min = 1) {
  assert(arrayAttr1.size() <= shape.size());
  assert(arrayAttr2.size() <= shape.size());
  unsigned index = 0;
  for (auto it : llvm::zip(arrayAttr1, arrayAttr2, shape)) {
    auto val1 = std::get<0>(it).cast<IntegerAttr>().getInt();
    auto val2 = std::get<1>(it).cast<IntegerAttr>().getInt();
    // Widen the bound by one for the closed-interval case so a single
    // half-open comparison handles both.
    auto max = std::get<2>(it);
    if (!halfOpen)
      max += 1;
    if (val1 + val2 < 0 || val1 + val2 >= max)
      return op.emitOpError("expected sum(")
             << attrName1 << ", " << attrName2 << ") dimension " << index
             << " to be confined to [" << min << ", " << max << ")";
    ++index;
  }
  return success();
}
1425
makeI64ArrayAttr(ArrayRef<int64_t> values,MLIRContext * context)1426 static ArrayAttr makeI64ArrayAttr(ArrayRef<int64_t> values,
1427 MLIRContext *context) {
1428 auto attrs = llvm::map_range(values, [context](int64_t v) -> Attribute {
1429 return IntegerAttr::get(IntegerType::get(64, context), APInt(64, v));
1430 });
1431 return ArrayAttr::get(llvm::to_vector<8>(attrs), context);
1432 }
1433
verify(InsertStridedSliceOp op)1434 static LogicalResult verify(InsertStridedSliceOp op) {
1435 auto sourceVectorType = op.getSourceVectorType();
1436 auto destVectorType = op.getDestVectorType();
1437 auto offsets = op.offsets();
1438 auto strides = op.strides();
1439 if (offsets.size() != static_cast<unsigned>(destVectorType.getRank()))
1440 return op.emitOpError(
1441 "expected offsets of same size as destination vector rank");
1442 if (strides.size() != static_cast<unsigned>(sourceVectorType.getRank()))
1443 return op.emitOpError(
1444 "expected strides of same size as source vector rank");
1445 if (sourceVectorType.getRank() > destVectorType.getRank())
1446 return op.emitOpError(
1447 "expected source rank to be smaller than destination rank");
1448
1449 auto sourceShape = sourceVectorType.getShape();
1450 auto destShape = destVectorType.getShape();
1451 SmallVector<int64_t, 4> sourceShapeAsDestShape(
1452 destShape.size() - sourceShape.size(), 0);
1453 sourceShapeAsDestShape.append(sourceShape.begin(), sourceShape.end());
1454 auto offName = InsertStridedSliceOp::getOffsetsAttrName();
1455 auto stridesName = InsertStridedSliceOp::getStridesAttrName();
1456 if (failed(
1457 isIntegerArrayAttrConfinedToShape(op, offsets, destShape, offName)) ||
1458 failed(isIntegerArrayAttrConfinedToRange(op, strides, 1, 1, stridesName,
1459 /*halfOpen=*/false)) ||
1460 failed(isSumOfIntegerArrayAttrConfinedToShape(
1461 op, offsets,
1462 makeI64ArrayAttr(sourceShapeAsDestShape, op.getContext()), destShape,
1463 offName, "source vector shape",
1464 /*halfOpen=*/false, /*min=*/1)))
1465 return failure();
1466
1467 return success();
1468 }
1469
1470 //===----------------------------------------------------------------------===//
1471 // OuterProductOp
1472 //===----------------------------------------------------------------------===//
1473
1474 /// Build an op without mask, use the type of `acc` as the return type.
build(OpBuilder & builder,OperationState & result,Value lhs,Value rhs,Value acc)1475 void OuterProductOp::build(OpBuilder &builder, OperationState &result,
1476 Value lhs, Value rhs, Value acc) {
1477 result.addOperands({lhs, rhs, acc});
1478 result.addTypes(acc.getType());
1479 }
1480
print(OpAsmPrinter & p,OuterProductOp op)1481 static void print(OpAsmPrinter &p, OuterProductOp op) {
1482 p << op.getOperationName() << " " << op.lhs() << ", " << op.rhs();
1483 if (!op.acc().empty())
1484 p << ", " << op.acc();
1485 p << " : " << op.lhs().getType() << ", " << op.rhs().getType();
1486 }
1487
parseOuterProductOp(OpAsmParser & parser,OperationState & result)1488 static ParseResult parseOuterProductOp(OpAsmParser &parser,
1489 OperationState &result) {
1490 SmallVector<OpAsmParser::OperandType, 3> operandsInfo;
1491 Type tLHS, tRHS;
1492 if (parser.parseOperandList(operandsInfo) || parser.parseColonType(tLHS) ||
1493 parser.parseComma() || parser.parseType(tRHS))
1494 return failure();
1495 if (operandsInfo.size() < 2)
1496 return parser.emitError(parser.getNameLoc(),
1497 "expected at least 2 operands");
1498 VectorType vLHS = tLHS.dyn_cast<VectorType>();
1499 VectorType vRHS = tRHS.dyn_cast<VectorType>();
1500 if (!vLHS)
1501 return parser.emitError(parser.getNameLoc(),
1502 "expected vector type for operand #1");
1503 VectorType resType =
1504 vRHS ? VectorType::get({vLHS.getDimSize(0), vRHS.getDimSize(0)},
1505 vLHS.getElementType())
1506 : VectorType::get({vLHS.getDimSize(0)}, vLHS.getElementType());
1507 return failure(
1508 parser.resolveOperand(operandsInfo[0], tLHS, result.operands) ||
1509 parser.resolveOperand(operandsInfo[1], tRHS, result.operands) ||
1510 (operandsInfo.size() > 2 &&
1511 parser.resolveOperand(operandsInfo[2], resType, result.operands)) ||
1512 parser.addTypeToList(resType, result.types));
1513 }
1514
verify(OuterProductOp op)1515 static LogicalResult verify(OuterProductOp op) {
1516 Type tRHS = op.getOperandTypeRHS();
1517 VectorType vLHS = op.getOperandVectorTypeLHS(),
1518 vRHS = tRHS.dyn_cast<VectorType>(),
1519 vACC = op.getOperandVectorTypeACC(), vRES = op.getVectorType();
1520
1521 if (vLHS.getRank() != 1)
1522 return op.emitOpError("expected 1-d vector for operand #1");
1523
1524 if (vRHS) {
1525 // Proper OUTER operation.
1526 if (vRHS.getRank() != 1)
1527 return op.emitOpError("expected 1-d vector for operand #2");
1528 if (vRES.getRank() != 2)
1529 return op.emitOpError("expected 2-d vector result");
1530 if (vLHS.getDimSize(0) != vRES.getDimSize(0))
1531 return op.emitOpError("expected #1 operand dim to match result dim #1");
1532 if (vRHS.getDimSize(0) != vRES.getDimSize(1))
1533 return op.emitOpError("expected #2 operand dim to match result dim #2");
1534 } else {
1535 // An AXPY operation.
1536 if (vRES.getRank() != 1)
1537 return op.emitOpError("expected 1-d vector result");
1538 if (vLHS.getDimSize(0) != vRES.getDimSize(0))
1539 return op.emitOpError("expected #1 operand dim to match result dim #1");
1540 }
1541
1542 if (vACC && vACC != vRES)
1543 return op.emitOpError("expected operand #3 of same type as result type");
1544 return success();
1545 }
1546
1547 //===----------------------------------------------------------------------===//
1548 // ReshapeOp
1549 //===----------------------------------------------------------------------===//
1550
verify(ReshapeOp op)1551 static LogicalResult verify(ReshapeOp op) {
1552 // Verify that rank(numInputs/outputs) + numFixedVec dim matches vec rank.
1553 auto inputVectorType = op.getInputVectorType();
1554 auto outputVectorType = op.getOutputVectorType();
1555 int64_t inputShapeRank = op.getNumInputShapeSizes();
1556 int64_t outputShapeRank = op.getNumOutputShapeSizes();
1557 SmallVector<int64_t, 4> fixedVectorSizes;
1558 op.getFixedVectorSizes(fixedVectorSizes);
1559 int64_t numFixedVectorSizes = fixedVectorSizes.size();
1560
1561 if (inputVectorType.getRank() != inputShapeRank + numFixedVectorSizes)
1562 return op.emitError("invalid input shape for vector type ")
1563 << inputVectorType;
1564
1565 if (outputVectorType.getRank() != outputShapeRank + numFixedVectorSizes)
1566 return op.emitError("invalid output shape for vector type ")
1567 << outputVectorType;
1568
1569 // Verify that the 'fixedVectorSizes' match an input/output vector shape
1570 // suffix.
1571 unsigned inputVectorRank = inputVectorType.getRank();
1572 for (unsigned i = 0; i < numFixedVectorSizes; ++i) {
1573 unsigned index = inputVectorRank - numFixedVectorSizes - i;
1574 if (fixedVectorSizes[i] != inputVectorType.getShape()[index])
1575 return op.emitError("fixed vector size must match input vector for dim ")
1576 << i;
1577 }
1578
1579 unsigned outputVectorRank = outputVectorType.getRank();
1580 for (unsigned i = 0; i < numFixedVectorSizes; ++i) {
1581 unsigned index = outputVectorRank - numFixedVectorSizes - i;
1582 if (fixedVectorSizes[i] != outputVectorType.getShape()[index])
1583 return op.emitError("fixed vector size must match output vector for dim ")
1584 << i;
1585 }
1586
1587 // If all shape operands are produced by constant ops, verify that product
1588 // of dimensions for input/output shape match.
1589 auto isDefByConstant = [](Value operand) {
1590 return isa_and_nonnull<ConstantIndexOp>(operand.getDefiningOp());
1591 };
1592 if (llvm::all_of(op.input_shape(), isDefByConstant) &&
1593 llvm::all_of(op.output_shape(), isDefByConstant)) {
1594 int64_t numInputElements = 1;
1595 for (auto operand : op.input_shape())
1596 numInputElements *=
1597 cast<ConstantIndexOp>(operand.getDefiningOp()).getValue();
1598 int64_t numOutputElements = 1;
1599 for (auto operand : op.output_shape())
1600 numOutputElements *=
1601 cast<ConstantIndexOp>(operand.getDefiningOp()).getValue();
1602 if (numInputElements != numOutputElements)
1603 return op.emitError("product of input and output shape sizes must match");
1604 }
1605 return success();
1606 }
1607
getFixedVectorSizes(SmallVectorImpl<int64_t> & results)1608 void ReshapeOp::getFixedVectorSizes(SmallVectorImpl<int64_t> &results) {
1609 populateFromInt64AttrArray(fixed_vector_sizes(), results);
1610 }
1611
1612 //===----------------------------------------------------------------------===//
1613 // ExtractStridedSliceOp
1614 //===----------------------------------------------------------------------===//
1615
1616 // Inference works as follows:
1617 // 1. Add 'sizes' from prefix of dims in 'offsets'.
1618 // 2. Add sizes from 'vectorType' for remaining dims.
inferStridedSliceOpResultType(VectorType vectorType,ArrayAttr offsets,ArrayAttr sizes,ArrayAttr strides)1619 static Type inferStridedSliceOpResultType(VectorType vectorType,
1620 ArrayAttr offsets, ArrayAttr sizes,
1621 ArrayAttr strides) {
1622 assert(offsets.size() == sizes.size() && offsets.size() == strides.size());
1623 SmallVector<int64_t, 4> shape;
1624 shape.reserve(vectorType.getRank());
1625 unsigned idx = 0;
1626 for (unsigned e = offsets.size(); idx < e; ++idx)
1627 shape.push_back(sizes[idx].cast<IntegerAttr>().getInt());
1628 for (unsigned e = vectorType.getShape().size(); idx < e; ++idx)
1629 shape.push_back(vectorType.getShape()[idx]);
1630
1631 return VectorType::get(shape, vectorType.getElementType());
1632 }
1633
build(OpBuilder & builder,OperationState & result,Value source,ArrayRef<int64_t> offsets,ArrayRef<int64_t> sizes,ArrayRef<int64_t> strides)1634 void ExtractStridedSliceOp::build(OpBuilder &builder, OperationState &result,
1635 Value source, ArrayRef<int64_t> offsets,
1636 ArrayRef<int64_t> sizes,
1637 ArrayRef<int64_t> strides) {
1638 result.addOperands(source);
1639 auto offsetsAttr = getVectorSubscriptAttr(builder, offsets);
1640 auto sizesAttr = getVectorSubscriptAttr(builder, sizes);
1641 auto stridesAttr = getVectorSubscriptAttr(builder, strides);
1642 result.addTypes(
1643 inferStridedSliceOpResultType(source.getType().cast<VectorType>(),
1644 offsetsAttr, sizesAttr, stridesAttr));
1645 result.addAttribute(getOffsetsAttrName(), offsetsAttr);
1646 result.addAttribute(getSizesAttrName(), sizesAttr);
1647 result.addAttribute(getStridesAttrName(), stridesAttr);
1648 }
1649
verify(ExtractStridedSliceOp op)1650 static LogicalResult verify(ExtractStridedSliceOp op) {
1651 auto type = op.getVectorType();
1652 auto offsets = op.offsets();
1653 auto sizes = op.sizes();
1654 auto strides = op.strides();
1655 if (offsets.size() != sizes.size() || offsets.size() != strides.size()) {
1656 op.emitOpError(
1657 "expected offsets, sizes and strides attributes of same size");
1658 return failure();
1659 }
1660
1661 auto shape = type.getShape();
1662 auto offName = ExtractStridedSliceOp::getOffsetsAttrName();
1663 auto sizesName = ExtractStridedSliceOp::getSizesAttrName();
1664 auto stridesName = ExtractStridedSliceOp::getStridesAttrName();
1665 if (failed(isIntegerArrayAttrSmallerThanShape(op, offsets, shape, offName)) ||
1666 failed(isIntegerArrayAttrSmallerThanShape(op, sizes, shape, sizesName)) ||
1667 failed(isIntegerArrayAttrSmallerThanShape(op, strides, shape,
1668 stridesName)) ||
1669 failed(isIntegerArrayAttrConfinedToShape(op, offsets, shape, offName)) ||
1670 failed(isIntegerArrayAttrConfinedToShape(op, sizes, shape, sizesName,
1671 /*halfOpen=*/false,
1672 /*min=*/1)) ||
1673 failed(isIntegerArrayAttrConfinedToRange(op, strides, 1, 1, stridesName,
1674 /*halfOpen=*/false)) ||
1675 failed(isSumOfIntegerArrayAttrConfinedToShape(op, offsets, sizes, shape,
1676 offName, sizesName,
1677 /*halfOpen=*/false)))
1678 return failure();
1679
1680 auto resultType = inferStridedSliceOpResultType(
1681 op.getVectorType(), op.offsets(), op.sizes(), op.strides());
1682 if (op.getResult().getType() != resultType) {
1683 op.emitOpError("expected result type to be ") << resultType;
1684 return failure();
1685 }
1686
1687 return success();
1688 }
1689
// When the source of ExtractStrided comes from a chain of InsertStrided ops,
// try to use the source of the InsertStrided ops if we can detect that the
// extracted vector is a subset of one of the inserted vectors.
1693 static LogicalResult
foldExtractStridedOpFromInsertChain(ExtractStridedSliceOp op)1694 foldExtractStridedOpFromInsertChain(ExtractStridedSliceOp op) {
1695 // Helper to extract integer out of ArrayAttr.
1696 auto getElement = [](ArrayAttr array, int idx) {
1697 return array[idx].cast<IntegerAttr>().getInt();
1698 };
1699 ArrayAttr extractOffsets = op.offsets();
1700 ArrayAttr extractStrides = op.strides();
1701 ArrayAttr extractSizes = op.sizes();
1702 auto insertOp = op.vector().getDefiningOp<InsertStridedSliceOp>();
1703 while (insertOp) {
1704 if (op.getVectorType().getRank() !=
1705 insertOp.getSourceVectorType().getRank())
1706 return failure();
1707 ArrayAttr insertOffsets = insertOp.offsets();
1708 ArrayAttr insertStrides = insertOp.strides();
1709 // If the rank of extract is greater than the rank of insert, we are likely
1710 // extracting a partial chunk of the vector inserted.
1711 if (extractOffsets.size() > insertOffsets.size())
1712 return failure();
1713 bool patialoverlap = false;
1714 bool disjoint = false;
1715 SmallVector<int64_t, 4> offsetDiffs;
1716 for (unsigned dim = 0, e = extractOffsets.size(); dim < e; ++dim) {
1717 if (getElement(extractStrides, dim) != getElement(insertStrides, dim))
1718 return failure();
1719 int64_t start = getElement(insertOffsets, dim);
1720 int64_t end = start + insertOp.getSourceVectorType().getDimSize(dim);
1721 int64_t offset = getElement(extractOffsets, dim);
1722 int64_t size = getElement(extractSizes, dim);
1723 // Check if the start of the extract offset is in the interval inserted.
1724 if (start <= offset && offset < end) {
1725 // If the extract interval overlaps but is not fully included we may
1726 // have a partial overlap that will prevent any folding.
1727 if (offset + size > end)
1728 patialoverlap = true;
1729 offsetDiffs.push_back(offset - start);
1730 continue;
1731 }
1732 disjoint = true;
1733 break;
1734 }
1735 // The extract element chunk is a subset of the insert element.
1736 if (!disjoint && !patialoverlap) {
1737 op.setOperand(insertOp.source());
1738 // OpBuilder is only used as a helper to build an I64ArrayAttr.
1739 OpBuilder b(op.getContext());
1740 op.setAttr(ExtractStridedSliceOp::getOffsetsAttrName(),
1741 b.getI64ArrayAttr(offsetDiffs));
1742 return success();
1743 }
1744 // If the chunk extracted is disjoint from the chunk inserted, keep looking
1745 // in the insert chain.
1746 if (disjoint)
1747 insertOp = insertOp.dest().getDefiningOp<InsertStridedSliceOp>();
1748 else {
1749 // The extracted vector partially overlap the inserted vector, we cannot
1750 // fold.
1751 return failure();
1752 }
1753 }
1754 return failure();
1755 }
1756
fold(ArrayRef<Attribute> operands)1757 OpFoldResult ExtractStridedSliceOp::fold(ArrayRef<Attribute> operands) {
1758 if (getVectorType() == getResult().getType())
1759 return vector();
1760 if (succeeded(foldExtractStridedOpFromInsertChain(*this)))
1761 return getResult();
1762 return {};
1763 }
1764
getOffsets(SmallVectorImpl<int64_t> & results)1765 void ExtractStridedSliceOp::getOffsets(SmallVectorImpl<int64_t> &results) {
1766 populateFromInt64AttrArray(offsets(), results);
1767 }
1768
1769 namespace {
1770
1771 // Pattern to rewrite a ExtractStridedSliceOp(ConstantMaskOp) -> ConstantMaskOp.
1772 class StridedSliceConstantMaskFolder final
1773 : public OpRewritePattern<ExtractStridedSliceOp> {
1774 public:
1775 using OpRewritePattern<ExtractStridedSliceOp>::OpRewritePattern;
1776
matchAndRewrite(ExtractStridedSliceOp extractStridedSliceOp,PatternRewriter & rewriter) const1777 LogicalResult matchAndRewrite(ExtractStridedSliceOp extractStridedSliceOp,
1778 PatternRewriter &rewriter) const override {
1779 // Return if 'extractStridedSliceOp' operand is not defined by a
1780 // ConstantMaskOp.
1781 auto defOp = extractStridedSliceOp.vector().getDefiningOp();
1782 auto constantMaskOp = dyn_cast_or_null<ConstantMaskOp>(defOp);
1783 if (!constantMaskOp)
1784 return failure();
1785 // Return if 'extractStridedSliceOp' has non-unit strides.
1786 if (llvm::any_of(extractStridedSliceOp.strides(), [](Attribute attr) {
1787 return attr.cast<IntegerAttr>().getInt() != 1;
1788 }))
1789 return failure();
1790 // Gather constant mask dimension sizes.
1791 SmallVector<int64_t, 4> maskDimSizes;
1792 populateFromInt64AttrArray(constantMaskOp.mask_dim_sizes(), maskDimSizes);
1793 // Gather strided slice offsets and sizes.
1794 SmallVector<int64_t, 4> sliceOffsets;
1795 populateFromInt64AttrArray(extractStridedSliceOp.offsets(), sliceOffsets);
1796 SmallVector<int64_t, 4> sliceSizes;
1797 populateFromInt64AttrArray(extractStridedSliceOp.sizes(), sliceSizes);
1798
1799 // Compute slice of vector mask region.
1800 SmallVector<int64_t, 4> sliceMaskDimSizes;
1801 assert(sliceOffsets.size() == maskDimSizes.size());
1802 for (auto it : llvm::zip(maskDimSizes, sliceOffsets, sliceSizes)) {
1803 int64_t maskDimSize = std::get<0>(it);
1804 int64_t sliceOffset = std::get<1>(it);
1805 int64_t sliceSize = std::get<2>(it);
1806 int64_t sliceMaskDimSize = std::max(
1807 static_cast<int64_t>(0),
1808 std::min(sliceOffset + sliceSize, maskDimSize) - sliceOffset);
1809 sliceMaskDimSizes.push_back(sliceMaskDimSize);
1810 }
1811 // If any of 'sliceMaskDimSizes' are zero, then set all to zero (masked
1812 // region is a conjunction of mask dim intervals).
1813 if (llvm::any_of(sliceMaskDimSizes, [](int64_t sz) { return sz == 0; }))
1814 sliceMaskDimSizes.assign(maskDimSizes.size(), 0);
1815
1816 // Replace 'extractStridedSliceOp' with ConstantMaskOp with sliced mask
1817 // region.
1818 rewriter.replaceOpWithNewOp<ConstantMaskOp>(
1819 extractStridedSliceOp, extractStridedSliceOp.getResult().getType(),
1820 vector::getVectorSubscriptAttr(rewriter, sliceMaskDimSizes));
1821 return success();
1822 }
1823 };
1824
1825 // Pattern to rewrite a ExtractStridedSliceOp(splat ConstantOp) -> ConstantOp.
1826 class StridedSliceConstantFolder final
1827 : public OpRewritePattern<ExtractStridedSliceOp> {
1828 public:
1829 using OpRewritePattern<ExtractStridedSliceOp>::OpRewritePattern;
1830
matchAndRewrite(ExtractStridedSliceOp extractStridedSliceOp,PatternRewriter & rewriter) const1831 LogicalResult matchAndRewrite(ExtractStridedSliceOp extractStridedSliceOp,
1832 PatternRewriter &rewriter) const override {
1833 // Return if 'extractStridedSliceOp' operand is not defined by a
1834 // ConstantOp.
1835 auto constantOp =
1836 extractStridedSliceOp.vector().getDefiningOp<ConstantOp>();
1837 if (!constantOp)
1838 return failure();
1839 auto dense = constantOp.value().dyn_cast<SplatElementsAttr>();
1840 if (!dense)
1841 return failure();
1842 auto newAttr = DenseElementsAttr::get(
1843 extractStridedSliceOp.getType().cast<VectorType>(),
1844 dense.getSplatValue());
1845 rewriter.replaceOpWithNewOp<ConstantOp>(extractStridedSliceOp, newAttr);
1846 return success();
1847 }
1848 };
1849
1850 } // end anonymous namespace
1851
getCanonicalizationPatterns(OwningRewritePatternList & results,MLIRContext * context)1852 void ExtractStridedSliceOp::getCanonicalizationPatterns(
1853 OwningRewritePatternList &results, MLIRContext *context) {
1854 // Pattern to rewrite a ExtractStridedSliceOp(ConstantMaskOp) ->
1855 // ConstantMaskOp and ExtractStridedSliceOp(ConstantOp) -> ConstantOp.
1856 results.insert<StridedSliceConstantMaskFolder, StridedSliceConstantFolder>(
1857 context);
1858 }
1859
1860 //===----------------------------------------------------------------------===//
1861 // TransferReadOp
1862 //===----------------------------------------------------------------------===//
1863
1864 template <typename EmitFun>
verifyPermutationMap(AffineMap permutationMap,EmitFun emitOpError)1865 static LogicalResult verifyPermutationMap(AffineMap permutationMap,
1866 EmitFun emitOpError) {
1867 SmallVector<bool, 8> seen(permutationMap.getNumInputs(), false);
1868 for (auto expr : permutationMap.getResults()) {
1869 auto dim = expr.dyn_cast<AffineDimExpr>();
1870 auto zero = expr.dyn_cast<AffineConstantExpr>();
1871 if (zero) {
1872 if (zero.getValue() != 0) {
1873 return emitOpError(
1874 "requires a projected permutation_map (at most one dim or the zero "
1875 "constant can appear in each result)");
1876 }
1877 continue;
1878 }
1879 if (!dim) {
1880 return emitOpError("requires a projected permutation_map (at most one "
1881 "dim or the zero constant can appear in each result)");
1882 }
1883 if (seen[dim.getPosition()]) {
1884 return emitOpError(
1885 "requires a permutation_map that is a permutation (found one dim "
1886 "used more than once)");
1887 }
1888 seen[dim.getPosition()] = true;
1889 }
1890 return success();
1891 }
1892
verifyTransferOp(Operation * op,MemRefType memrefType,VectorType vectorType,AffineMap permutationMap,ArrayAttr optionalMasked)1893 static LogicalResult verifyTransferOp(Operation *op, MemRefType memrefType,
1894 VectorType vectorType,
1895 AffineMap permutationMap,
1896 ArrayAttr optionalMasked) {
1897 auto memrefElementType = memrefType.getElementType();
1898 if (auto memrefVectorElementType = memrefElementType.dyn_cast<VectorType>()) {
1899 // Memref has vector element type.
1900
1901 unsigned memrefVecSize = memrefVectorElementType.getElementTypeBitWidth() *
1902 memrefVectorElementType.getShape().back();
1903 unsigned resultVecSize =
1904 vectorType.getElementTypeBitWidth() * vectorType.getShape().back();
1905 if (resultVecSize % memrefVecSize != 0)
1906 return op->emitOpError(
1907 "requires the bitwidth of the minor 1-D vector to be an integral "
1908 "multiple of the bitwidth of the minor 1-D vector of the memref");
1909
1910 unsigned memrefVecEltRank = memrefVectorElementType.getRank();
1911 unsigned resultVecRank = vectorType.getRank();
1912 if (memrefVecEltRank > resultVecRank)
1913 return op->emitOpError(
1914 "requires memref vector element and vector result ranks to match.");
1915 unsigned rankOffset = resultVecRank - memrefVecEltRank;
1916 // Check that permutation map results match 'rankOffset' of vector type.
1917 if (permutationMap.getNumResults() != rankOffset)
1918 return op->emitOpError("requires a permutation_map with result dims of "
1919 "the same rank as the vector type");
1920 } else {
1921 // Memref has scalar element type.
1922 unsigned resultVecSize =
1923 vectorType.getElementTypeBitWidth() * vectorType.getShape().back();
1924 if (resultVecSize % memrefElementType.getIntOrFloatBitWidth() != 0)
1925 return op->emitOpError(
1926 "requires the bitwidth of the minor 1-D vector to be an integral "
1927 "multiple of the bitwidth of the memref element type");
1928
1929 // Check that permutation map results match rank of vector type.
1930 if (permutationMap.getNumResults() != vectorType.getRank())
1931 return op->emitOpError("requires a permutation_map with result dims of "
1932 "the same rank as the vector type");
1933 }
1934
1935 if (permutationMap.getNumSymbols() != 0)
1936 return op->emitOpError("requires permutation_map without symbols");
1937 if (permutationMap.getNumInputs() != memrefType.getRank())
1938 return op->emitOpError("requires a permutation_map with input dims of the "
1939 "same rank as the memref type");
1940
1941 if (optionalMasked) {
1942 if (permutationMap.getNumResults() !=
1943 static_cast<int64_t>(optionalMasked.size()))
1944 return op->emitOpError("expects the optional masked attr of same rank as "
1945 "permutation_map results: ")
1946 << AffineMapAttr::get(permutationMap);
1947 }
1948
1949 return success();
1950 }
1951
1952 /// Builder that sets padding to zero.
build(OpBuilder & builder,OperationState & result,VectorType vector,Value memref,ValueRange indices,AffineMap permutationMap,ArrayRef<bool> maybeMasked)1953 void TransferReadOp::build(OpBuilder &builder, OperationState &result,
1954 VectorType vector, Value memref, ValueRange indices,
1955 AffineMap permutationMap,
1956 ArrayRef<bool> maybeMasked) {
1957 Type elemType = memref.getType().cast<MemRefType>().getElementType();
1958 Value padding = builder.create<ConstantOp>(result.location, elemType,
1959 builder.getZeroAttr(elemType));
1960 if (maybeMasked.empty())
1961 return build(builder, result, vector, memref, indices, permutationMap,
1962 padding, ArrayAttr());
1963 ArrayAttr maskedArrayAttr = builder.getBoolArrayAttr(maybeMasked);
1964 build(builder, result, vector, memref, indices, permutationMap, padding,
1965 maskedArrayAttr);
1966 }
1967
/// Builder that sets permutation map (resp. padding) to 'getMinorIdentityMap'
/// (resp. zero). Delegates to the builder above, which fills in the zero
/// padding value.
void TransferReadOp::build(OpBuilder &builder, OperationState &result,
                           VectorType vectorType, Value memref,
                           ValueRange indices, ArrayRef<bool> maybeMasked) {
  auto permMap = getTransferMinorIdentityMap(
      memref.getType().cast<MemRefType>(), vectorType);
  build(builder, result, vectorType, memref, indices, permMap, maybeMasked);
}
1977
printTransferAttrs(OpAsmPrinter & p,VectorTransferOpInterface op)1978 static void printTransferAttrs(OpAsmPrinter &p, VectorTransferOpInterface op) {
1979 SmallVector<StringRef, 2> elidedAttrs;
1980 if (op.permutation_map() ==
1981 getTransferMinorIdentityMap(op.getMemRefType(), op.getVectorType()))
1982 elidedAttrs.push_back(op.getPermutationMapAttrName());
1983 bool elideMasked = true;
1984 if (auto maybeMasked = op.masked()) {
1985 for (auto attr : *maybeMasked) {
1986 if (!attr.template cast<BoolAttr>().getValue()) {
1987 elideMasked = false;
1988 break;
1989 }
1990 }
1991 }
1992 if (elideMasked)
1993 elidedAttrs.push_back(op.getMaskedAttrName());
1994 p.printOptionalAttrDict(op.getAttrs(), elidedAttrs);
1995 }
1996
print(OpAsmPrinter & p,TransferReadOp op)1997 static void print(OpAsmPrinter &p, TransferReadOp op) {
1998 p << op.getOperationName() << " " << op.memref() << "[" << op.indices()
1999 << "], " << op.padding();
2000 printTransferAttrs(p, cast<VectorTransferOpInterface>(op.getOperation()));
2001 p << " : " << op.getMemRefType() << ", " << op.getVectorType();
2002 }
2003
parseTransferReadOp(OpAsmParser & parser,OperationState & result)2004 static ParseResult parseTransferReadOp(OpAsmParser &parser,
2005 OperationState &result) {
2006 llvm::SMLoc typesLoc;
2007 OpAsmParser::OperandType memrefInfo;
2008 SmallVector<OpAsmParser::OperandType, 8> indexInfo;
2009 OpAsmParser::OperandType paddingInfo;
2010 SmallVector<Type, 2> types;
2011 // Parsing with support for paddingValue.
2012 if (parser.parseOperand(memrefInfo) ||
2013 parser.parseOperandList(indexInfo, OpAsmParser::Delimiter::Square) ||
2014 parser.parseComma() || parser.parseOperand(paddingInfo) ||
2015 parser.parseOptionalAttrDict(result.attributes) ||
2016 parser.getCurrentLocation(&typesLoc) || parser.parseColonTypeList(types))
2017 return failure();
2018 if (types.size() != 2)
2019 return parser.emitError(typesLoc, "requires two types");
2020 auto indexType = parser.getBuilder().getIndexType();
2021 MemRefType memRefType = types[0].dyn_cast<MemRefType>();
2022 if (!memRefType)
2023 return parser.emitError(typesLoc, "requires memref type");
2024 VectorType vectorType = types[1].dyn_cast<VectorType>();
2025 if (!vectorType)
2026 return parser.emitError(typesLoc, "requires vector type");
2027 auto permutationAttrName = TransferReadOp::getPermutationMapAttrName();
2028 auto attr = result.attributes.get(permutationAttrName);
2029 if (!attr) {
2030 auto permMap = getTransferMinorIdentityMap(memRefType, vectorType);
2031 result.attributes.set(permutationAttrName, AffineMapAttr::get(permMap));
2032 }
2033 return failure(
2034 parser.resolveOperand(memrefInfo, memRefType, result.operands) ||
2035 parser.resolveOperands(indexInfo, indexType, result.operands) ||
2036 parser.resolveOperand(paddingInfo, memRefType.getElementType(),
2037 result.operands) ||
2038 parser.addTypeToList(vectorType, result.types));
2039 }
2040
verify(TransferReadOp op)2041 static LogicalResult verify(TransferReadOp op) {
2042 // Consistency of elemental types in memref and vector.
2043 MemRefType memrefType = op.getMemRefType();
2044 VectorType vectorType = op.getVectorType();
2045 auto paddingType = op.padding().getType();
2046 auto permutationMap = op.permutation_map();
2047 auto memrefElementType = memrefType.getElementType();
2048
2049 if (static_cast<int64_t>(op.indices().size()) != memrefType.getRank())
2050 return op.emitOpError("requires ") << memrefType.getRank() << " indices";
2051
2052 if (failed(verifyTransferOp(op.getOperation(), memrefType, vectorType,
2053 permutationMap,
2054 op.masked() ? *op.masked() : ArrayAttr())))
2055 return failure();
2056
2057 if (auto memrefVectorElementType = memrefElementType.dyn_cast<VectorType>()) {
2058 // Memref has vector element type.
2059 // Check that 'memrefVectorElementType' and 'paddingType' types match.
2060 if (memrefVectorElementType != paddingType)
2061 return op.emitOpError(
2062 "requires memref element type and padding type to match.");
2063
2064 } else {
2065 // Check that 'paddingType' is valid to store in a vector type.
2066 if (!VectorType::isValidElementType(paddingType))
2067 return op.emitOpError("requires valid padding vector elemental type");
2068
2069 // Check that padding type and vector element types match.
2070 if (paddingType != memrefElementType)
2071 return op.emitOpError(
2072 "requires formal padding and memref of the same elemental type");
2073 }
2074
2075 return verifyPermutationMap(permutationMap,
2076 [&op](Twine t) { return op.emitOpError(t); });
2077 }
2078
2079 /// This is a common class used for patterns of the form
2080 /// ```
2081 /// someop(memrefcast) -> someop
2082 /// ```
2083 /// It folds the source of the memref_cast into the root operation directly.
foldMemRefCast(Operation * op)2084 static LogicalResult foldMemRefCast(Operation *op) {
2085 bool folded = false;
2086 for (OpOperand &operand : op->getOpOperands()) {
2087 auto castOp = operand.get().getDefiningOp<MemRefCastOp>();
2088 if (castOp && canFoldIntoConsumerOp(castOp)) {
2089 operand.set(castOp.getOperand());
2090 folded = true;
2091 }
2092 }
2093 return success(folded);
2094 }
2095
2096 template <typename TransferOp>
isInBounds(TransferOp op,int64_t resultIdx,int64_t indicesIdx)2097 static bool isInBounds(TransferOp op, int64_t resultIdx, int64_t indicesIdx) {
2098 // TODO: support more aggressive createOrFold on:
2099 // `op.indices()[indicesIdx] + vectorType < dim(op.memref(), indicesIdx)`
2100 if (op.getMemRefType().isDynamicDim(indicesIdx))
2101 return false;
2102 Value index = op.indices()[indicesIdx];
2103 auto cstOp = index.getDefiningOp<ConstantIndexOp>();
2104 if (!cstOp)
2105 return false;
2106
2107 int64_t memrefSize = op.getMemRefType().getDimSize(indicesIdx);
2108 int64_t vectorSize = op.getVectorType().getDimSize(resultIdx);
2109
2110 return cstOp.getValue() + vectorSize <= memrefSize;
2111 }
2112
/// When the permutation map is a minor identity, refines the transfer op's
/// `masked` attribute by statically proving individual dimensions in-bounds
/// (via isInBounds). Returns success iff at least one dimension was newly
/// proven in-bounds, i.e. the attribute actually changed.
template <typename TransferOp>
static LogicalResult foldTransferMaskAttribute(TransferOp op) {
  AffineMap permutationMap = op.permutation_map();
  // The simple result-dim <-> index pairing below only holds for the minor
  // identity map.
  if (!permutationMap.isMinorIdentity())
    return failure();
  bool changed = false;
  SmallVector<bool, 4> isMasked;
  isMasked.reserve(op.getTransferRank());
  op.zipResultAndIndexing([&](int64_t resultIdx, int64_t indicesIdx) {
    // Already marked unmasked, nothing to see here.
    if (!op.isMaskedDim(resultIdx)) {
      isMasked.push_back(false);
      return;
    }
    // Currently masked, check whether we can statically determine it is
    // inBounds.
    auto inBounds = isInBounds(op, resultIdx, indicesIdx);
    isMasked.push_back(!inBounds);
    // We commit the pattern if it is "more inbounds".
    changed |= inBounds;
  });
  if (!changed)
    return failure();
  // OpBuilder is only used as a helper to build an I64ArrayAttr.
  OpBuilder b(op.getContext());
  op.setAttr(TransferOp::getMaskedAttrName(), b.getBoolArrayAttr(isMasked));
  return success();
}
2141
fold(ArrayRef<Attribute>)2142 OpFoldResult TransferReadOp::fold(ArrayRef<Attribute>) {
2143 /// transfer_read(memrefcast) -> transfer_read
2144 if (succeeded(foldTransferMaskAttribute(*this)))
2145 return getResult();
2146 if (succeeded(foldMemRefCast(*this)))
2147 return getResult();
2148 return OpFoldResult();
2149 }
2150
getShapeForUnroll()2151 Optional<SmallVector<int64_t, 4>> TransferReadOp::getShapeForUnroll() {
2152 auto s = getVectorType().getShape();
2153 return SmallVector<int64_t, 4>{s.begin(), s.end()};
2154 }
2155
2156 //===----------------------------------------------------------------------===//
2157 // TransferWriteOp
2158 //===----------------------------------------------------------------------===//
2159
2160 /// Builder that sets permutation map to 'getMinorIdentityMap'.
build(OpBuilder & builder,OperationState & result,Value vector,Value memref,ValueRange indices,ArrayRef<bool> maybeMasked)2161 void TransferWriteOp::build(OpBuilder &builder, OperationState &result,
2162 Value vector, Value memref, ValueRange indices,
2163 ArrayRef<bool> maybeMasked) {
2164 auto vectorType = vector.getType().cast<VectorType>();
2165 auto permMap = getTransferMinorIdentityMap(
2166 memref.getType().cast<MemRefType>(), vectorType);
2167 if (maybeMasked.empty())
2168 return build(builder, result, vector, memref, indices, permMap,
2169 ArrayAttr());
2170 ArrayAttr maskedArrayAttr = builder.getBoolArrayAttr(maybeMasked);
2171 build(builder, result, vector, memref, indices, permMap, maskedArrayAttr);
2172 }
2173
build(OpBuilder & builder,OperationState & result,Value vector,Value memref,ValueRange indices,AffineMap permutationMap)2174 void TransferWriteOp::build(OpBuilder &builder, OperationState &result,
2175 Value vector, Value memref, ValueRange indices,
2176 AffineMap permutationMap) {
2177 build(builder, result, vector, memref, indices, permutationMap,
2178 /*maybeMasked=*/ArrayAttr());
2179 }
2180
/// Parses a transfer_write op of the form:
///   %vector, %memref[%indices...] {attrs} : vector-type, memref-type
/// When the permutation_map attribute is absent, a minor identity map is
/// synthesized from the memref and vector types.
static ParseResult parseTransferWriteOp(OpAsmParser &parser,
                                        OperationState &result) {
  llvm::SMLoc typesLoc;
  OpAsmParser::OperandType vectorInfo, memrefInfo;
  SmallVector<OpAsmParser::OperandType, 8> indexInfo;
  SmallVector<Type, 2> types;
  // Parse order matters: operands, attr dict, then the trailing type list.
  if (parser.parseOperand(vectorInfo) || parser.parseComma() ||
      parser.parseOperand(memrefInfo) ||
      parser.parseOperandList(indexInfo, OpAsmParser::Delimiter::Square) ||
      parser.parseOptionalAttrDict(result.attributes) ||
      parser.getCurrentLocation(&typesLoc) || parser.parseColonTypeList(types))
    return failure();
  // Exactly two trailing types: the vector type and the memref type.
  if (types.size() != 2)
    return parser.emitError(typesLoc, "requires two types");
  auto indexType = parser.getBuilder().getIndexType();
  VectorType vectorType = types[0].dyn_cast<VectorType>();
  if (!vectorType)
    return parser.emitError(typesLoc, "requires vector type");
  MemRefType memRefType = types[1].dyn_cast<MemRefType>();
  if (!memRefType)
    return parser.emitError(typesLoc, "requires memref type");
  // Default the permutation map to the minor identity when not provided.
  auto permutationAttrName = TransferWriteOp::getPermutationMapAttrName();
  auto attr = result.attributes.get(permutationAttrName);
  if (!attr) {
    auto permMap = getTransferMinorIdentityMap(memRefType, vectorType);
    result.attributes.set(permutationAttrName, AffineMapAttr::get(permMap));
  }
  return failure(
      parser.resolveOperand(vectorInfo, vectorType, result.operands) ||
      parser.resolveOperand(memrefInfo, memRefType, result.operands) ||
      parser.resolveOperands(indexInfo, indexType, result.operands));
}
2213
print(OpAsmPrinter & p,TransferWriteOp op)2214 static void print(OpAsmPrinter &p, TransferWriteOp op) {
2215 p << op.getOperationName() << " " << op.vector() << ", " << op.memref() << "["
2216 << op.indices() << "]";
2217 printTransferAttrs(p, cast<VectorTransferOpInterface>(op.getOperation()));
2218 p << " : " << op.getVectorType() << ", " << op.getMemRefType();
2219 }
2220
verify(TransferWriteOp op)2221 static LogicalResult verify(TransferWriteOp op) {
2222 // Consistency of elemental types in memref and vector.
2223 MemRefType memrefType = op.getMemRefType();
2224 VectorType vectorType = op.getVectorType();
2225 auto permutationMap = op.permutation_map();
2226
2227 if (llvm::size(op.indices()) != memrefType.getRank())
2228 return op.emitOpError("requires ") << memrefType.getRank() << " indices";
2229
2230 if (failed(verifyTransferOp(op.getOperation(), memrefType, vectorType,
2231 permutationMap,
2232 op.masked() ? *op.masked() : ArrayAttr())))
2233 return failure();
2234
2235 return verifyPermutationMap(permutationMap,
2236 [&op](Twine t) { return op.emitOpError(t); });
2237 }
2238
fold(ArrayRef<Attribute>,SmallVectorImpl<OpFoldResult> &)2239 LogicalResult TransferWriteOp::fold(ArrayRef<Attribute>,
2240 SmallVectorImpl<OpFoldResult> &) {
2241 if (succeeded(foldTransferMaskAttribute(*this)))
2242 return success();
2243 return foldMemRefCast(*this);
2244 }
2245
getShapeForUnroll()2246 Optional<SmallVector<int64_t, 4>> TransferWriteOp::getShapeForUnroll() {
2247 return llvm::to_vector<4>(getVectorType().getShape());
2248 }
2249
2250 //===----------------------------------------------------------------------===//
2251 // MaskedLoadOp
2252 //===----------------------------------------------------------------------===//
2253
verify(MaskedLoadOp op)2254 static LogicalResult verify(MaskedLoadOp op) {
2255 VectorType maskVType = op.getMaskVectorType();
2256 VectorType passVType = op.getPassThruVectorType();
2257 VectorType resVType = op.getResultVectorType();
2258
2259 if (resVType.getElementType() != op.getMemRefType().getElementType())
2260 return op.emitOpError("base and result element type should match");
2261
2262 if (resVType.getDimSize(0) != maskVType.getDimSize(0))
2263 return op.emitOpError("expected result dim to match mask dim");
2264 if (resVType != passVType)
2265 return op.emitOpError("expected pass_thru of same type as result type");
2266 return success();
2267 }
2268
2269 namespace {
2270 class MaskedLoadFolder final : public OpRewritePattern<MaskedLoadOp> {
2271 public:
2272 using OpRewritePattern<MaskedLoadOp>::OpRewritePattern;
matchAndRewrite(MaskedLoadOp load,PatternRewriter & rewriter) const2273 LogicalResult matchAndRewrite(MaskedLoadOp load,
2274 PatternRewriter &rewriter) const override {
2275 Value newBase;
2276 switch (get1DMaskFormat(load.mask())) {
2277 case MaskFormat::AllTrue:
2278 if (!castedToMemRef(load.getLoc(), load.base(), load.getMemRefType(),
2279 load.getResultVectorType(), rewriter, newBase))
2280 return failure();
2281 rewriter.replaceOpWithNewOp<LoadOp>(load, newBase);
2282 return success();
2283 case MaskFormat::AllFalse:
2284 rewriter.replaceOp(load, load.pass_thru());
2285 return success();
2286 case MaskFormat::Unknown:
2287 return failure();
2288 }
2289 llvm_unreachable("Unexpected 1DMaskFormat on MaskedLoad");
2290 }
2291 };
2292 } // namespace
2293
getCanonicalizationPatterns(OwningRewritePatternList & results,MLIRContext * context)2294 void MaskedLoadOp::getCanonicalizationPatterns(
2295 OwningRewritePatternList &results, MLIRContext *context) {
2296 results.insert<MaskedLoadFolder>(context);
2297 }
2298
2299 //===----------------------------------------------------------------------===//
2300 // MaskedStoreOp
2301 //===----------------------------------------------------------------------===//
2302
verify(MaskedStoreOp op)2303 static LogicalResult verify(MaskedStoreOp op) {
2304 VectorType maskVType = op.getMaskVectorType();
2305 VectorType valueVType = op.getValueVectorType();
2306
2307 if (valueVType.getElementType() != op.getMemRefType().getElementType())
2308 return op.emitOpError("base and value element type should match");
2309
2310 if (valueVType.getDimSize(0) != maskVType.getDimSize(0))
2311 return op.emitOpError("expected value dim to match mask dim");
2312 return success();
2313 }
2314
2315 namespace {
2316 class MaskedStoreFolder final : public OpRewritePattern<MaskedStoreOp> {
2317 public:
2318 using OpRewritePattern<MaskedStoreOp>::OpRewritePattern;
matchAndRewrite(MaskedStoreOp store,PatternRewriter & rewriter) const2319 LogicalResult matchAndRewrite(MaskedStoreOp store,
2320 PatternRewriter &rewriter) const override {
2321 Value newBase;
2322 switch (get1DMaskFormat(store.mask())) {
2323 case MaskFormat::AllTrue:
2324 if (!castedToMemRef(store.getLoc(), store.base(), store.getMemRefType(),
2325 store.getValueVectorType(), rewriter, newBase))
2326 return failure();
2327 rewriter.replaceOpWithNewOp<StoreOp>(store, store.value(), newBase);
2328 return success();
2329 case MaskFormat::AllFalse:
2330 rewriter.eraseOp(store);
2331 return success();
2332 case MaskFormat::Unknown:
2333 return failure();
2334 }
2335 llvm_unreachable("Unexpected 1DMaskFormat on MaskedStore");
2336 }
2337 };
2338 } // namespace
2339
getCanonicalizationPatterns(OwningRewritePatternList & results,MLIRContext * context)2340 void MaskedStoreOp::getCanonicalizationPatterns(
2341 OwningRewritePatternList &results, MLIRContext *context) {
2342 results.insert<MaskedStoreFolder>(context);
2343 }
2344
2345 //===----------------------------------------------------------------------===//
2346 // GatherOp
2347 //===----------------------------------------------------------------------===//
2348
verify(GatherOp op)2349 static LogicalResult verify(GatherOp op) {
2350 VectorType indicesVType = op.getIndicesVectorType();
2351 VectorType maskVType = op.getMaskVectorType();
2352 VectorType resVType = op.getResultVectorType();
2353
2354 if (resVType.getElementType() != op.getMemRefType().getElementType())
2355 return op.emitOpError("base and result element type should match");
2356
2357 if (resVType.getDimSize(0) != indicesVType.getDimSize(0))
2358 return op.emitOpError("expected result dim to match indices dim");
2359 if (resVType.getDimSize(0) != maskVType.getDimSize(0))
2360 return op.emitOpError("expected result dim to match mask dim");
2361 if (llvm::size(op.pass_thru()) != 0) {
2362 VectorType passVType = op.getPassThruVectorType();
2363 if (resVType != passVType)
2364 return op.emitOpError("expected pass_thru of same type as result type");
2365 }
2366 return success();
2367 }
2368
2369 namespace {
2370 class GatherFolder final : public OpRewritePattern<GatherOp> {
2371 public:
2372 using OpRewritePattern<GatherOp>::OpRewritePattern;
matchAndRewrite(GatherOp gather,PatternRewriter & rewriter) const2373 LogicalResult matchAndRewrite(GatherOp gather,
2374 PatternRewriter &rewriter) const override {
2375 switch (get1DMaskFormat(gather.mask())) {
2376 case MaskFormat::AllTrue:
2377 return failure(); // no unmasked equivalent
2378 case MaskFormat::AllFalse:
2379 rewriter.replaceOp(gather, gather.pass_thru());
2380 return success();
2381 case MaskFormat::Unknown:
2382 return failure();
2383 }
2384 llvm_unreachable("Unexpected 1DMaskFormat on GatherFolder");
2385 }
2386 };
2387 } // namespace
2388
getCanonicalizationPatterns(OwningRewritePatternList & results,MLIRContext * context)2389 void GatherOp::getCanonicalizationPatterns(OwningRewritePatternList &results,
2390 MLIRContext *context) {
2391 results.insert<GatherFolder>(context);
2392 }
2393
2394 //===----------------------------------------------------------------------===//
2395 // ScatterOp
2396 //===----------------------------------------------------------------------===//
2397
verify(ScatterOp op)2398 static LogicalResult verify(ScatterOp op) {
2399 VectorType indicesVType = op.getIndicesVectorType();
2400 VectorType maskVType = op.getMaskVectorType();
2401 VectorType valueVType = op.getValueVectorType();
2402
2403 if (valueVType.getElementType() != op.getMemRefType().getElementType())
2404 return op.emitOpError("base and value element type should match");
2405
2406 if (valueVType.getDimSize(0) != indicesVType.getDimSize(0))
2407 return op.emitOpError("expected value dim to match indices dim");
2408 if (valueVType.getDimSize(0) != maskVType.getDimSize(0))
2409 return op.emitOpError("expected value dim to match mask dim");
2410 return success();
2411 }
2412
2413 namespace {
2414 class ScatterFolder final : public OpRewritePattern<ScatterOp> {
2415 public:
2416 using OpRewritePattern<ScatterOp>::OpRewritePattern;
matchAndRewrite(ScatterOp scatter,PatternRewriter & rewriter) const2417 LogicalResult matchAndRewrite(ScatterOp scatter,
2418 PatternRewriter &rewriter) const override {
2419 switch (get1DMaskFormat(scatter.mask())) {
2420 case MaskFormat::AllTrue:
2421 return failure(); // no unmasked equivalent
2422 case MaskFormat::AllFalse:
2423 rewriter.eraseOp(scatter);
2424 return success();
2425 case MaskFormat::Unknown:
2426 return failure();
2427 }
2428 llvm_unreachable("Unexpected 1DMaskFormat on ScatterFolder");
2429 }
2430 };
2431 } // namespace
2432
getCanonicalizationPatterns(OwningRewritePatternList & results,MLIRContext * context)2433 void ScatterOp::getCanonicalizationPatterns(OwningRewritePatternList &results,
2434 MLIRContext *context) {
2435 results.insert<ScatterFolder>(context);
2436 }
2437
2438 //===----------------------------------------------------------------------===//
2439 // ExpandLoadOp
2440 //===----------------------------------------------------------------------===//
2441
verify(ExpandLoadOp op)2442 static LogicalResult verify(ExpandLoadOp op) {
2443 VectorType maskVType = op.getMaskVectorType();
2444 VectorType passVType = op.getPassThruVectorType();
2445 VectorType resVType = op.getResultVectorType();
2446
2447 if (resVType.getElementType() != op.getMemRefType().getElementType())
2448 return op.emitOpError("base and result element type should match");
2449
2450 if (resVType.getDimSize(0) != maskVType.getDimSize(0))
2451 return op.emitOpError("expected result dim to match mask dim");
2452 if (resVType != passVType)
2453 return op.emitOpError("expected pass_thru of same type as result type");
2454 return success();
2455 }
2456
2457 namespace {
2458 class ExpandLoadFolder final : public OpRewritePattern<ExpandLoadOp> {
2459 public:
2460 using OpRewritePattern<ExpandLoadOp>::OpRewritePattern;
matchAndRewrite(ExpandLoadOp expand,PatternRewriter & rewriter) const2461 LogicalResult matchAndRewrite(ExpandLoadOp expand,
2462 PatternRewriter &rewriter) const override {
2463 Value newBase;
2464 switch (get1DMaskFormat(expand.mask())) {
2465 case MaskFormat::AllTrue:
2466 if (!castedToMemRef(expand.getLoc(), expand.base(),
2467 expand.getMemRefType(), expand.getResultVectorType(),
2468 rewriter, newBase))
2469 return failure();
2470 rewriter.replaceOpWithNewOp<LoadOp>(expand, newBase);
2471 return success();
2472 case MaskFormat::AllFalse:
2473 rewriter.replaceOp(expand, expand.pass_thru());
2474 return success();
2475 case MaskFormat::Unknown:
2476 return failure();
2477 }
2478 llvm_unreachable("Unexpected 1DMaskFormat on ExpandLoadFolder");
2479 }
2480 };
2481 } // namespace
2482
getCanonicalizationPatterns(OwningRewritePatternList & results,MLIRContext * context)2483 void ExpandLoadOp::getCanonicalizationPatterns(
2484 OwningRewritePatternList &results, MLIRContext *context) {
2485 results.insert<ExpandLoadFolder>(context);
2486 }
2487
2488 //===----------------------------------------------------------------------===//
2489 // CompressStoreOp
2490 //===----------------------------------------------------------------------===//
2491
verify(CompressStoreOp op)2492 static LogicalResult verify(CompressStoreOp op) {
2493 VectorType maskVType = op.getMaskVectorType();
2494 VectorType valueVType = op.getValueVectorType();
2495
2496 if (valueVType.getElementType() != op.getMemRefType().getElementType())
2497 return op.emitOpError("base and value element type should match");
2498
2499 if (valueVType.getDimSize(0) != maskVType.getDimSize(0))
2500 return op.emitOpError("expected value dim to match mask dim");
2501 return success();
2502 }
2503
2504 namespace {
2505 class CompressStoreFolder final : public OpRewritePattern<CompressStoreOp> {
2506 public:
2507 using OpRewritePattern<CompressStoreOp>::OpRewritePattern;
matchAndRewrite(CompressStoreOp compress,PatternRewriter & rewriter) const2508 LogicalResult matchAndRewrite(CompressStoreOp compress,
2509 PatternRewriter &rewriter) const override {
2510 Value newBase;
2511 switch (get1DMaskFormat(compress.mask())) {
2512 case MaskFormat::AllTrue:
2513 if (!castedToMemRef(compress.getLoc(), compress.base(),
2514 compress.getMemRefType(),
2515 compress.getValueVectorType(), rewriter, newBase))
2516 return failure();
2517 rewriter.replaceOpWithNewOp<StoreOp>(compress, compress.value(), newBase);
2518 return success();
2519 case MaskFormat::AllFalse:
2520 rewriter.eraseOp(compress);
2521 return success();
2522 case MaskFormat::Unknown:
2523 return failure();
2524 }
2525 llvm_unreachable("Unexpected 1DMaskFormat on CompressStoreFolder");
2526 }
2527 };
2528 } // namespace
2529
getCanonicalizationPatterns(OwningRewritePatternList & results,MLIRContext * context)2530 void CompressStoreOp::getCanonicalizationPatterns(
2531 OwningRewritePatternList &results, MLIRContext *context) {
2532 results.insert<CompressStoreFolder>(context);
2533 }
2534
2535 //===----------------------------------------------------------------------===//
2536 // ShapeCastOp
2537 //===----------------------------------------------------------------------===//
2538
/// Returns true if each element of 'a' is equal to the product of a contiguous
/// sequence of the elements of 'b'. Returns false otherwise.
static bool isValidShapeCast(ArrayRef<int64_t> a, ArrayRef<int64_t> b) {
  unsigned rankA = a.size();
  unsigned rankB = b.size();
  assert(rankA < rankB);

  // Walk both shapes in lockstep: for every dim of 'a', accumulate a
  // contiguous run of dims of 'b' until the products match.
  unsigned i = 0;
  unsigned j = 0;
  while (i < rankA && j < rankB) {
    int64_t dimA = a[i];
    int64_t dimB = 1;
    while (dimB < dimA && j < rankB)
      dimB *= b[j++];
    // The run of b-dims overshot (or ran out) without hitting dimA exactly.
    if (dimA != dimB)
      break;
    ++i;

    // Handle the case when trailing dimensions are of size 1.
    // Include them into the contiguous sequence.
    auto isOne = [](int64_t v) { return v == 1; };
    if (i < rankA && llvm::all_of(a.slice(i), isOne))
      i = rankA;
    if (j < rankB && llvm::all_of(b.slice(j), isOne))
      j = rankB;
  }

  // Valid only if both shapes were consumed completely.
  return i == rankA && j == rankB;
}
2568
verifyVectorShapeCast(Operation * op,VectorType sourceVectorType,VectorType resultVectorType)2569 static LogicalResult verifyVectorShapeCast(Operation *op,
2570 VectorType sourceVectorType,
2571 VectorType resultVectorType) {
2572 // Check that element type is the same.
2573 if (sourceVectorType.getElementType() != resultVectorType.getElementType())
2574 return op->emitOpError("source/result vectors must have same element type");
2575 auto sourceShape = sourceVectorType.getShape();
2576 auto resultShape = resultVectorType.getShape();
2577
2578 // Check that product of source dim sizes matches product of result dim sizes.
2579 int64_t sourceDimProduct = std::accumulate(
2580 sourceShape.begin(), sourceShape.end(), 1LL, std::multiplies<int64_t>{});
2581 int64_t resultDimProduct = std::accumulate(
2582 resultShape.begin(), resultShape.end(), 1LL, std::multiplies<int64_t>{});
2583 if (sourceDimProduct != resultDimProduct)
2584 return op->emitOpError("source/result number of elements must match");
2585
2586 // Check that expanding/contracting rank cases.
2587 unsigned sourceRank = sourceVectorType.getRank();
2588 unsigned resultRank = resultVectorType.getRank();
2589 if (sourceRank < resultRank) {
2590 if (!isValidShapeCast(sourceShape, resultShape))
2591 return op->emitOpError("invalid shape cast");
2592 } else if (sourceRank > resultRank) {
2593 if (!isValidShapeCast(resultShape, sourceShape))
2594 return op->emitOpError("invalid shape cast");
2595 }
2596 return success();
2597 }
2598
verify(ShapeCastOp op)2599 static LogicalResult verify(ShapeCastOp op) {
2600 auto sourceVectorType = op.source().getType().dyn_cast_or_null<VectorType>();
2601 auto resultVectorType = op.result().getType().dyn_cast_or_null<VectorType>();
2602
2603 // Check if source/result are of vector type.
2604 if (sourceVectorType && resultVectorType)
2605 return verifyVectorShapeCast(op, sourceVectorType, resultVectorType);
2606
2607 // Check if source/result are "tuple of vectors" type.
2608 auto sourceTupleType = op.source().getType().dyn_cast_or_null<TupleType>();
2609 auto resultTupleType = op.result().getType().dyn_cast_or_null<TupleType>();
2610 if (!sourceTupleType || !resultTupleType)
2611 return op.emitOpError("source/result must be of same type");
2612
2613 // Check that source/result tuple sizes are the same.
2614 if (sourceTupleType.size() != resultTupleType.size())
2615 return op.emitOpError("source/result tuples must be the same size");
2616
2617 // Check each source/result tuple element pair.
2618 for (unsigned i = 0, e = sourceTupleType.size(); i < e; ++i)
2619 if (failed(verifyVectorShapeCast(
2620 op, sourceTupleType.getType(i).cast<VectorType>(),
2621 resultTupleType.getType(i).cast<VectorType>())))
2622 return failure();
2623
2624 return success();
2625 }
2626
fold(ArrayRef<Attribute> operands)2627 OpFoldResult ShapeCastOp::fold(ArrayRef<Attribute> operands) {
2628 // Nop shape cast.
2629 if (source().getType() == result().getType())
2630 return source();
2631
2632 // Canceling shape casts.
2633 if (auto otherOp = source().getDefiningOp<ShapeCastOp>())
2634 if (result().getType() == otherOp.source().getType())
2635 return otherOp.source();
2636
2637 return {};
2638 }
2639
2640 namespace {
2641 // Pattern to rewrite a ShapeCast(splat ConstantOp) -> ConstantOp.
2642 class ShapeCastConstantFolder final : public OpRewritePattern<ShapeCastOp> {
2643 public:
2644 using OpRewritePattern<ShapeCastOp>::OpRewritePattern;
2645
matchAndRewrite(ShapeCastOp shapeCastOp,PatternRewriter & rewriter) const2646 LogicalResult matchAndRewrite(ShapeCastOp shapeCastOp,
2647 PatternRewriter &rewriter) const override {
2648 auto constantOp = shapeCastOp.source().getDefiningOp<ConstantOp>();
2649 if (!constantOp)
2650 return failure();
2651 // Only handle splat for now.
2652 auto dense = constantOp.value().dyn_cast<SplatElementsAttr>();
2653 if (!dense)
2654 return failure();
2655 auto newAttr = DenseElementsAttr::get(
2656 shapeCastOp.getType().cast<VectorType>(), dense.getSplatValue());
2657 rewriter.replaceOpWithNewOp<ConstantOp>(shapeCastOp, newAttr);
2658 return success();
2659 }
2660 };
2661
2662 } // namespace
2663
getCanonicalizationPatterns(OwningRewritePatternList & results,MLIRContext * context)2664 void ShapeCastOp::getCanonicalizationPatterns(OwningRewritePatternList &results,
2665 MLIRContext *context) {
2666 // Pattern to rewrite a ShapeCastOp(ConstantOp) -> ConstantOp.
2667 results.insert<ShapeCastConstantFolder>(context);
2668 }
2669
2670 //===----------------------------------------------------------------------===//
2671 // VectorBitCastOp
2672 //===----------------------------------------------------------------------===//
2673
verify(BitCastOp op)2674 static LogicalResult verify(BitCastOp op) {
2675 auto sourceVectorType = op.getSourceVectorType();
2676 auto resultVectorType = op.getResultVectorType();
2677
2678 for (int64_t i = 0, e = sourceVectorType.getRank() - 1; i < e; i++) {
2679 if (sourceVectorType.getDimSize(i) != resultVectorType.getDimSize(i))
2680 return op.emitOpError("dimension size mismatch at: ") << i;
2681 }
2682
2683 if (sourceVectorType.getElementTypeBitWidth() *
2684 sourceVectorType.getShape().back() !=
2685 resultVectorType.getElementTypeBitWidth() *
2686 resultVectorType.getShape().back())
2687 return op.emitOpError(
2688 "source/result bitwidth of the minor 1-D vectors must be equal");
2689
2690 return success();
2691 }
2692
fold(ArrayRef<Attribute> operands)2693 OpFoldResult BitCastOp::fold(ArrayRef<Attribute> operands) {
2694 // Nop cast.
2695 if (source().getType() == result().getType())
2696 return source();
2697
2698 // Canceling bitcasts.
2699 if (auto otherOp = source().getDefiningOp<BitCastOp>())
2700 if (result().getType() == otherOp.source().getType())
2701 return otherOp.source();
2702
2703 return {};
2704 }
2705
2706 //===----------------------------------------------------------------------===//
2707 // TypeCastOp
2708 //===----------------------------------------------------------------------===//
2709
extractShape(MemRefType memRefType)2710 static SmallVector<int64_t, 8> extractShape(MemRefType memRefType) {
2711 auto vectorType = memRefType.getElementType().dyn_cast<VectorType>();
2712 SmallVector<int64_t, 8> res(memRefType.getShape().begin(),
2713 memRefType.getShape().end());
2714 if (vectorType)
2715 res.append(vectorType.getShape().begin(), vectorType.getShape().end());
2716 return res;
2717 }
2718
2719 /// Build the canonical memRefType with a single vector.
2720 /// E.g. memref<4 x 5 x vector<6 x f32>> -> memref<vector<4 x 5 x 6 x f32>>.
build(OpBuilder & builder,OperationState & result,Value source)2721 void TypeCastOp::build(OpBuilder &builder, OperationState &result,
2722 Value source) {
2723 result.addOperands(source);
2724 MemRefType memRefType = source.getType().cast<MemRefType>();
2725 VectorType vectorType =
2726 VectorType::get(extractShape(memRefType),
2727 getElementTypeOrSelf(getElementTypeOrSelf(memRefType)));
2728 result.addTypes(
2729 MemRefType::get({}, vectorType, {}, memRefType.getMemorySpace()));
2730 }
2731
verify(TypeCastOp op)2732 static LogicalResult verify(TypeCastOp op) {
2733 MemRefType canonicalType = canonicalizeStridedLayout(op.getMemRefType());
2734 if (!canonicalType.getAffineMaps().empty())
2735 return op.emitOpError("expects operand to be a memref with no layout");
2736 if (!op.getResultMemRefType().getAffineMaps().empty())
2737 return op.emitOpError("expects result to be a memref with no layout");
2738 if (op.getResultMemRefType().getMemorySpace() !=
2739 op.getMemRefType().getMemorySpace())
2740 return op.emitOpError("expects result in same memory space");
2741
2742 auto sourceType = op.getMemRefType();
2743 auto resultType = op.getResultMemRefType();
2744 if (getElementTypeOrSelf(getElementTypeOrSelf(sourceType)) !=
2745 getElementTypeOrSelf(getElementTypeOrSelf(resultType)))
2746 return op.emitOpError(
2747 "expects result and operand with same underlying scalar type: ")
2748 << resultType;
2749 if (extractShape(sourceType) != extractShape(resultType))
2750 return op.emitOpError(
2751 "expects concatenated result and operand shapes to be equal: ")
2752 << resultType;
2753 return success();
2754 }
2755
2756 //===----------------------------------------------------------------------===//
2757 // TupleOp
2758 //===----------------------------------------------------------------------===//
2759
parseTupleOp(OpAsmParser & parser,OperationState & result)2760 static ParseResult parseTupleOp(OpAsmParser &parser, OperationState &result) {
2761 SmallVector<OpAsmParser::OperandType, 4> operandInfos;
2762 SmallVector<Type, 4> types;
2763 auto loc = parser.getCurrentLocation();
2764 auto *ctx = parser.getBuilder().getContext();
2765 return failure(
2766 parser.parseOperandList(operandInfos) ||
2767 parser.parseOptionalAttrDict(result.attributes) ||
2768 parser.parseColonTypeList(types) ||
2769 parser.resolveOperands(operandInfos, types, loc, result.operands) ||
2770 parser.addTypeToList(TupleType::get(types, ctx), result.types));
2771 }
2772
print(OpAsmPrinter & p,TupleOp op)2773 static void print(OpAsmPrinter &p, TupleOp op) {
2774 p << op.getOperationName() << ' ';
2775 p.printOperands(op.getOperands());
2776 p.printOptionalAttrDict(op.getAttrs());
2777 p << " : ";
2778 llvm::interleaveComma(op->getOperandTypes(), p);
2779 }
2780
verify(TupleOp op)2781 static LogicalResult verify(TupleOp op) { return success(); }
2782
2783 //===----------------------------------------------------------------------===//
2784 // TransposeOp
2785 //===----------------------------------------------------------------------===//
2786
build(OpBuilder & builder,OperationState & result,Value vector,ArrayRef<int64_t> transp)2787 void vector::TransposeOp::build(OpBuilder &builder, OperationState &result,
2788 Value vector, ArrayRef<int64_t> transp) {
2789 VectorType vt = vector.getType().cast<VectorType>();
2790 SmallVector<int64_t, 4> transposedShape(vt.getRank());
2791 for (unsigned i = 0; i < transp.size(); ++i)
2792 transposedShape[i] = vt.getShape()[transp[i]];
2793
2794 result.addOperands(vector);
2795 result.addTypes(VectorType::get(transposedShape, vt.getElementType()));
2796 result.addAttribute(getTranspAttrName(), builder.getI64ArrayAttr(transp));
2797 }
2798
2799 // Eliminates transpose operations, which produce values identical to their
2800 // input values. This happens when the dimensions of the input vector remain in
2801 // their original order after the transpose operation.
fold(ArrayRef<Attribute> operands)2802 OpFoldResult vector::TransposeOp::fold(ArrayRef<Attribute> operands) {
2803 SmallVector<int64_t, 4> transp;
2804 getTransp(transp);
2805
2806 // Check if the permutation of the dimensions contains sequential values:
2807 // {0, 1, 2, ...}.
2808 for (int64_t i = 0, e = transp.size(); i < e; i++) {
2809 if (transp[i] != i)
2810 return {};
2811 }
2812
2813 return vector();
2814 }
2815
/// Verifies that `transp` is a valid permutation of the source rank and that
/// the result shape is the source shape permuted accordingly.
static LogicalResult verify(vector::TransposeOp op) {
  VectorType vectorType = op.getVectorType();
  VectorType resultType = op.getResultType();
  int64_t rank = resultType.getRank();
  // Source and result must have the same rank.
  if (vectorType.getRank() != rank)
    return op.emitOpError("vector result rank mismatch: ") << rank;
  // Verify transposition array.
  auto transpAttr = op.transp().getValue();
  int64_t size = transpAttr.size();
  if (rank != size)
    return op.emitOpError("transposition length mismatch: ") << size;
  // Each index must appear exactly once, i.e. `transp` is a permutation.
  SmallVector<bool, 8> seen(rank, false);
  for (auto ta : llvm::enumerate(transpAttr)) {
    int64_t i = ta.value().cast<IntegerAttr>().getInt();
    if (i < 0 || i >= rank)
      return op.emitOpError("transposition index out of range: ") << i;
    if (seen[i])
      return op.emitOpError("duplicate position index: ") << i;
    seen[i] = true;
    // Result dim at position ta.index() must equal source dim i.
    if (resultType.getDimSize(ta.index()) != vectorType.getDimSize(i))
      return op.emitOpError("dimension size mismatch at: ") << i;
  }
  return success();
}
2840
2841 namespace {
2842
2843 // Rewrites two back-to-back TransposeOp operations into a single TransposeOp.
2844 class TransposeFolder final : public OpRewritePattern<vector::TransposeOp> {
2845 public:
2846 using OpRewritePattern<vector::TransposeOp>::OpRewritePattern;
2847
matchAndRewrite(vector::TransposeOp transposeOp,PatternRewriter & rewriter) const2848 LogicalResult matchAndRewrite(vector::TransposeOp transposeOp,
2849 PatternRewriter &rewriter) const override {
2850 // Wrapper around vector::TransposeOp::getTransp() for cleaner code.
2851 auto getPermutation = [](vector::TransposeOp transpose) {
2852 SmallVector<int64_t, 4> permutation;
2853 transpose.getTransp(permutation);
2854 return permutation;
2855 };
2856
2857 // Composes two permutations: result[i] = permutation1[permutation2[i]].
2858 auto composePermutations = [](ArrayRef<int64_t> permutation1,
2859 ArrayRef<int64_t> permutation2) {
2860 SmallVector<int64_t, 4> result;
2861 for (auto index : permutation2)
2862 result.push_back(permutation1[index]);
2863 return result;
2864 };
2865
2866 // Return if the input of 'transposeOp' is not defined by another transpose.
2867 vector::TransposeOp parentTransposeOp =
2868 transposeOp.vector().getDefiningOp<vector::TransposeOp>();
2869 if (!parentTransposeOp)
2870 return failure();
2871
2872 SmallVector<int64_t, 4> permutation = composePermutations(
2873 getPermutation(parentTransposeOp), getPermutation(transposeOp));
2874 // Replace 'transposeOp' with a new transpose operation.
2875 rewriter.replaceOpWithNewOp<vector::TransposeOp>(
2876 transposeOp, transposeOp.getResult().getType(),
2877 parentTransposeOp.vector(),
2878 vector::getVectorSubscriptAttr(rewriter, permutation));
2879 return success();
2880 }
2881 };
2882
2883 } // end anonymous namespace
2884
/// Registers the canonicalization patterns for vector.transpose; currently
/// only the folding of back-to-back transpose operations.
void vector::TransposeOp::getCanonicalizationPatterns(
    OwningRewritePatternList &results, MLIRContext *context) {
  results.insert<TransposeFolder>(context);
}
2889
/// Extracts the transposition permutation from the 'transp' array attribute
/// into 'results' as plain int64_t values.
void vector::TransposeOp::getTransp(SmallVectorImpl<int64_t> &results) {
  populateFromInt64AttrArray(transp(), results);
}
2893
2894 //===----------------------------------------------------------------------===//
2895 // TupleGetOp
2896 //===----------------------------------------------------------------------===//
2897
parseTupleGetOp(OpAsmParser & parser,OperationState & result)2898 static ParseResult parseTupleGetOp(OpAsmParser &parser,
2899 OperationState &result) {
2900 OpAsmParser::OperandType operandInfo;
2901 IntegerAttr indexAttr;
2902 StringRef indexAttrName = TupleGetOp::getIndexAttrName();
2903 Type indexType = parser.getBuilder().getIndexType();
2904 TupleType tupleType;
2905 if (parser.parseOperand(operandInfo) || parser.parseComma() ||
2906 parser.parseAttribute(indexAttr, indexType, indexAttrName,
2907 result.attributes) ||
2908 parser.parseOptionalAttrDict(result.attributes) ||
2909 parser.parseColonType(tupleType) ||
2910 parser.resolveOperand(operandInfo, tupleType, result.operands))
2911 return failure();
2912 if (indexAttr.getInt() < 0 ||
2913 indexAttr.getInt() >= static_cast<int64_t>(tupleType.size()))
2914 return failure();
2915 parser.addTypeToList(tupleType.getType(indexAttr.getInt()), result.types);
2916 return success();
2917 }
2918
print(OpAsmPrinter & p,TupleGetOp op)2919 static void print(OpAsmPrinter &p, TupleGetOp op) {
2920 p << op.getOperationName() << ' ' << op.getOperand() << ", " << op.index();
2921 p.printOptionalAttrDict(op.getAttrs(),
2922 /*elidedAttrs=*/{TupleGetOp::getIndexAttrName()});
2923 p << " : " << op.getOperand().getType();
2924 }
2925
verify(TupleGetOp op)2926 static LogicalResult verify(TupleGetOp op) {
2927 auto tupleType = op.getOperand().getType().cast<TupleType>();
2928 if (op.getIndex() < 0 ||
2929 op.getIndex() >= static_cast<int64_t>(tupleType.size()))
2930 return op.emitOpError("tuple get index out of range");
2931 return success();
2932 }
2933
fold(ArrayRef<Attribute> operands)2934 OpFoldResult TupleGetOp::fold(ArrayRef<Attribute> operands) {
2935 // Rewrite:
2936 // %t = vector.tuple .., %e_i, ..
2937 // %x = vector.tuple_get %t, i
2938 // into:
2939 // %t = vector.tuple .., %e_i, .. // one less use
2940 // %x = %e_i
2941 if (auto tupleOp = getOperand().getDefiningOp<TupleOp>())
2942 return tupleOp.getOperand(getIndex());
2943 return {};
2944 }
2945
2946 //===----------------------------------------------------------------------===//
2947 // ConstantMaskOp
2948 //===----------------------------------------------------------------------===//
2949
verify(ConstantMaskOp & op)2950 static LogicalResult verify(ConstantMaskOp &op) {
2951 // Verify that array attr size matches the rank of the vector result.
2952 auto resultType = op.getResult().getType().cast<VectorType>();
2953 if (static_cast<int64_t>(op.mask_dim_sizes().size()) != resultType.getRank())
2954 return op.emitOpError(
2955 "must specify array attr of size equal vector result rank");
2956 // Verify that each array attr element is in bounds of corresponding vector
2957 // result dimension size.
2958 auto resultShape = resultType.getShape();
2959 SmallVector<int64_t, 4> maskDimSizes;
2960 for (auto it : llvm::enumerate(op.mask_dim_sizes())) {
2961 int64_t attrValue = it.value().cast<IntegerAttr>().getInt();
2962 if (attrValue < 0 || attrValue > resultShape[it.index()])
2963 return op.emitOpError(
2964 "array attr of size out of bounds of vector result dimension size");
2965 maskDimSizes.push_back(attrValue);
2966 }
2967 // Verify that if one mask dim size is zero, they all should be zero (because
2968 // the mask region is a conjunction of each mask dimension interval).
2969 bool any_zeros = llvm::is_contained(maskDimSizes, 0);
2970 bool all_zeros = llvm::all_of(maskDimSizes, [](int64_t s) { return s == 0; });
2971 if (any_zeros && !all_zeros)
2972 return op.emitOpError("expected all mask dim sizes to be zeros, "
2973 "as a result of conjunction with zero mask dim");
2974 return success();
2975 }
2976
2977 //===----------------------------------------------------------------------===//
2978 // CreateMaskOp
2979 //===----------------------------------------------------------------------===//
2980
verify(CreateMaskOp op)2981 static LogicalResult verify(CreateMaskOp op) {
2982 // Verify that an operand was specified for each result vector each dimension.
2983 if (op.getNumOperands() !=
2984 op.getResult().getType().cast<VectorType>().getRank())
2985 return op.emitOpError(
2986 "must specify an operand for each result vector dimension");
2987 return success();
2988 }
2989
2990 namespace {
2991
2992 // Pattern to rewrite a CreateMaskOp with a ConstantMaskOp.
2993 class CreateMaskFolder final : public OpRewritePattern<CreateMaskOp> {
2994 public:
2995 using OpRewritePattern<CreateMaskOp>::OpRewritePattern;
2996
matchAndRewrite(CreateMaskOp createMaskOp,PatternRewriter & rewriter) const2997 LogicalResult matchAndRewrite(CreateMaskOp createMaskOp,
2998 PatternRewriter &rewriter) const override {
2999 // Return if any of 'createMaskOp' operands are not defined by a constant.
3000 auto is_not_def_by_constant = [](Value operand) {
3001 return !isa_and_nonnull<ConstantIndexOp>(operand.getDefiningOp());
3002 };
3003 if (llvm::any_of(createMaskOp.operands(), is_not_def_by_constant))
3004 return failure();
3005 // Gather constant mask dimension sizes.
3006 SmallVector<int64_t, 4> maskDimSizes;
3007 for (auto operand : createMaskOp.operands()) {
3008 auto defOp = operand.getDefiningOp();
3009 maskDimSizes.push_back(cast<ConstantIndexOp>(defOp).getValue());
3010 }
3011 // Replace 'createMaskOp' with ConstantMaskOp.
3012 rewriter.replaceOpWithNewOp<ConstantMaskOp>(
3013 createMaskOp, createMaskOp.getResult().getType(),
3014 vector::getVectorSubscriptAttr(rewriter, maskDimSizes));
3015 return success();
3016 }
3017 };
3018
3019 } // end anonymous namespace
3020
/// Registers the canonicalization patterns for vector.create_mask; currently
/// only the rewrite of all-constant create_mask into constant_mask.
void CreateMaskOp::getCanonicalizationPatterns(
    OwningRewritePatternList &results, MLIRContext *context) {
  results.insert<CreateMaskFolder>(context);
}
3025
/// Populates 'patterns' with the vector-to-vector canonicalization patterns
/// defined throughout this file (mask, masked memory access, gather/scatter,
/// expand/compress, strided-slice, and transpose folders).
void mlir::vector::populateVectorToVectorCanonicalizationPatterns(
    OwningRewritePatternList &patterns, MLIRContext *context) {
  patterns.insert<CreateMaskFolder, MaskedLoadFolder, MaskedStoreFolder,
                  GatherFolder, ScatterFolder, ExpandLoadFolder,
                  CompressStoreFolder, StridedSliceConstantMaskFolder,
                  TransposeFolder>(context);
}
3033
3034 #define GET_OP_CLASSES
3035 #include "mlir/Dialect/Vector/VectorOps.cpp.inc"
3036