1 //===- Ops.cpp - Standard MLIR Operations ---------------------------------===//
2 //
3 // Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions.
4 // See https://llvm.org/LICENSE.txt for license information.
5 // SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception
6 //
7 //===----------------------------------------------------------------------===//
8
9 #include "mlir/Dialect/StandardOps/IR/Ops.h"
10
11 #include "mlir/Dialect/CommonFolders.h"
12 #include "mlir/IR/AffineExpr.h"
13 #include "mlir/IR/AffineMap.h"
14 #include "mlir/IR/BlockAndValueMapping.h"
15 #include "mlir/IR/Builders.h"
16 #include "mlir/IR/BuiltinOps.h"
17 #include "mlir/IR/BuiltinTypes.h"
18 #include "mlir/IR/Matchers.h"
19 #include "mlir/IR/OpImplementation.h"
20 #include "mlir/IR/PatternMatch.h"
21 #include "mlir/IR/TypeUtilities.h"
22 #include "mlir/IR/Value.h"
23 #include "mlir/Support/MathExtras.h"
24 #include "mlir/Transforms/InliningUtils.h"
25 #include "llvm/ADT/STLExtras.h"
26 #include "llvm/ADT/StringSwitch.h"
27 #include "llvm/Support/FormatVariadic.h"
28 #include "llvm/Support/raw_ostream.h"
29
30 // Pull in all enum type definitions and utility function declarations.
31 #include "mlir/Dialect/StandardOps/IR/OpsEnums.cpp.inc"
32
33 using namespace mlir;
34
35 //===----------------------------------------------------------------------===//
36 // StandardOpsDialect Interfaces
37 //===----------------------------------------------------------------------===//
38 namespace {
39 /// This class defines the interface for handling inlining with standard
40 /// operations.
41 struct StdInlinerInterface : public DialectInlinerInterface {
42 using DialectInlinerInterface::DialectInlinerInterface;
43
44 //===--------------------------------------------------------------------===//
45 // Analysis Hooks
46 //===--------------------------------------------------------------------===//
47
48 /// All call operations within standard ops can be inlined.
49 bool isLegalToInline(Operation *call, Operation *callable,
50 bool wouldBeCloned) const final {
51 return true;
52 }
53
54 /// All operations within standard ops can be inlined.
55 bool isLegalToInline(Operation *, Region *, bool,
56 BlockAndValueMapping &) const final {
57 return true;
58 }
59
60 //===--------------------------------------------------------------------===//
61 // Transformation Hooks
62 //===--------------------------------------------------------------------===//
63
64 /// Handle the given inlined terminator by replacing it with a new operation
65 /// as necessary.
66 void handleTerminator(Operation *op, Block *newDest) const final {
67 // Only "std.return" needs to be handled here.
68 auto returnOp = dyn_cast<ReturnOp>(op);
69 if (!returnOp)
70 return;
71
72 // Replace the return with a branch to the dest.
73 OpBuilder builder(op);
74 builder.create<BranchOp>(op->getLoc(), newDest, returnOp.getOperands());
75 op->erase();
76 }
77
78 /// Handle the given inlined terminator by replacing it with a new operation
79 /// as necessary.
80 void handleTerminator(Operation *op,
81 ArrayRef<Value> valuesToRepl) const final {
82 // Only "std.return" needs to be handled here.
83 auto returnOp = cast<ReturnOp>(op);
84
85 // Replace the values directly with the return operands.
86 assert(returnOp.getNumOperands() == valuesToRepl.size());
87 for (const auto &it : llvm::enumerate(returnOp.getOperands()))
88 valuesToRepl[it.index()].replaceAllUsesWith(it.value());
89 }
90 };
91 } // end anonymous namespace
92
93 //===----------------------------------------------------------------------===//
94 // StandardOpsDialect
95 //===----------------------------------------------------------------------===//
96
97 /// A custom unary operation printer that omits the "std." prefix from the
98 /// operation names.
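/// For illustration (hypothetical IR, assuming an op that uses this printer),
/// a unary op such as std.negf would print in the short form:
///   negf %a : f32
/// instead of the generic "std.negf"(%a) : (f32) -> f32 form.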
99 static void printStandardUnaryOp(Operation *op, OpAsmPrinter &p) {
100 assert(op->getNumOperands() == 1 && "unary op should have one operand");
101 assert(op->getNumResults() == 1 && "unary op should have one result");
102
103 int stdDotLen = StandardOpsDialect::getDialectNamespace().size() + 1;
104 p << op->getName().getStringRef().drop_front(stdDotLen) << ' '
105 << op->getOperand(0);
106 p.printOptionalAttrDict(op->getAttrs());
107 p << " : " << op->getOperand(0).getType();
108 }
109
110 /// A custom binary operation printer that omits the "std." prefix from the
111 /// operation names.
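/// For illustration (hypothetical IR), a binary op such as std.addf whose
/// operand and result types all match prints in the short form:
///   addf %a, %b : f32
/// Mismatched types fall back to printGenericOp, as done below.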
112 static void printStandardBinaryOp(Operation *op, OpAsmPrinter &p) {
113 assert(op->getNumOperands() == 2 && "binary op should have two operands");
114 assert(op->getNumResults() == 1 && "binary op should have one result");
115
116 // If not all the operand and result types are the same, just use the
117 // generic assembly form to avoid omitting information in printing.
118 auto resultType = op->getResult(0).getType();
119 if (op->getOperand(0).getType() != resultType ||
120 op->getOperand(1).getType() != resultType) {
121 p.printGenericOp(op);
122 return;
123 }
124
125 int stdDotLen = StandardOpsDialect::getDialectNamespace().size() + 1;
126 p << op->getName().getStringRef().drop_front(stdDotLen) << ' '
127 << op->getOperand(0) << ", " << op->getOperand(1);
128 p.printOptionalAttrDict(op->getAttrs());
129
130 // Now we can output only one type for all operands and the result.
131 p << " : " << op->getResult(0).getType();
132 }
133
134 /// A custom cast operation printer that omits the "std." prefix from the
135 /// operation names.
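/// For illustration (hypothetical IR), a cast op prints its operand followed
/// by both types, e.g.:
///   memref_cast %m : memref<4xf32> to memref<?xf32>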
136 static void printStandardCastOp(Operation *op, OpAsmPrinter &p) {
137 int stdDotLen = StandardOpsDialect::getDialectNamespace().size() + 1;
138 p << op->getName().getStringRef().drop_front(stdDotLen) << ' '
139 << op->getOperand(0) << " : " << op->getOperand(0).getType() << " to "
140 << op->getResult(0).getType();
141 }
142
143 /// A custom cast operation verifier.
144 template <typename T>
145 static LogicalResult verifyCastOp(T op) {
146 auto opType = op.getOperand().getType();
147 auto resType = op.getType();
148 if (!T::areCastCompatible(opType, resType))
149 return op.emitError("operand type ") << opType << " and result type "
150 << resType << " are cast incompatible";
151
152 return success();
153 }
154
155 void StandardOpsDialect::initialize() {
156 addOperations<DmaStartOp, DmaWaitOp,
157 #define GET_OP_LIST
158 #include "mlir/Dialect/StandardOps/IR/Ops.cpp.inc"
159 >();
160 addInterfaces<StdInlinerInterface>();
161 }
162
163 /// Materialize a single constant operation from a given attribute value with
164 /// the desired resultant type.
165 Operation *StandardOpsDialect::materializeConstant(OpBuilder &builder,
166 Attribute value, Type type,
167 Location loc) {
168 return builder.create<ConstantOp>(loc, type, value);
169 }
170
171 /// Matches a ConstantIndexOp.
172 /// TODO: This should probably just be a general matcher that uses m_Constant
173 /// and checks the operation for an index type.
174 static detail::op_matcher<ConstantIndexOp> m_ConstantIndex() {
175 return detail::op_matcher<ConstantIndexOp>();
176 }
177
178 //===----------------------------------------------------------------------===//
179 // Common canonicalization pattern support logic
180 //===----------------------------------------------------------------------===//
181
182 /// This is a common class used for patterns of the form
183 /// "someop(memrefcast) -> someop". It folds the source of any memref_cast
184 /// into the root operation directly.
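/// For illustration (hypothetical IR), a dealloc of a ranked memref_cast:
///   %0 = memref_cast %m : memref<4xf32> to memref<?xf32>
///   dealloc %0 : memref<?xf32>
/// is folded so the dealloc uses %m directly, leaving the cast dead.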
185 static LogicalResult foldMemRefCast(Operation *op) {
186 bool folded = false;
187 for (OpOperand &operand : op->getOpOperands()) {
188 auto cast = operand.get().getDefiningOp<MemRefCastOp>();
189 if (cast && !cast.getOperand().getType().isa<UnrankedMemRefType>()) {
190 operand.set(cast.getOperand());
191 folded = true;
192 }
193 }
194 return success(folded);
195 }
196
197 //===----------------------------------------------------------------------===//
198 // Common cast compatibility check for vector types.
199 //===----------------------------------------------------------------------===//
200
201 /// This method checks for cast compatibility of vector types.
202 /// If 'a' and 'b' are vector types, and they are cast compatible,
203 /// it calls the 'areElementsCastCompatible' function to check for
204 /// element cast compatibility.
205 /// Returns 'true' if the vector types are cast compatible, and 'false'
206 /// otherwise.
207 static bool areVectorCastSimpleCompatible(
208 Type a, Type b, function_ref<bool(Type, Type)> areElementsCastCompatible) {
209 if (auto va = a.dyn_cast<VectorType>())
210 if (auto vb = b.dyn_cast<VectorType>())
211 return va.getShape().equals(vb.getShape()) &&
212 areElementsCastCompatible(va.getElementType(),
213 vb.getElementType());
214 return false;
215 }
216
217 //===----------------------------------------------------------------------===//
218 // Helpers for Tensor[Load|Store]Op, TensorToMemrefOp, and GlobalMemrefOp
219 //===----------------------------------------------------------------------===//
220
221 static Type getTensorTypeFromMemRefType(Type type) {
222 if (auto memref = type.dyn_cast<MemRefType>())
223 return RankedTensorType::get(memref.getShape(), memref.getElementType());
224 if (auto memref = type.dyn_cast<UnrankedMemRefType>())
225 return UnrankedTensorType::get(memref.getElementType());
226 return NoneType::get(type.getContext());
227 }
228
229 //===----------------------------------------------------------------------===//
230 // AddFOp
231 //===----------------------------------------------------------------------===//
232
233 OpFoldResult AddFOp::fold(ArrayRef<Attribute> operands) {
234 return constFoldBinaryOp<FloatAttr>(
235 operands, [](APFloat a, APFloat b) { return a + b; });
236 }
237
238 //===----------------------------------------------------------------------===//
239 // AddIOp
240 //===----------------------------------------------------------------------===//
241
242 OpFoldResult AddIOp::fold(ArrayRef<Attribute> operands) {
243 /// addi(x, 0) -> x
244 if (matchPattern(rhs(), m_Zero()))
245 return lhs();
246
247 return constFoldBinaryOp<IntegerAttr>(operands,
248 [](APInt a, APInt b) { return a + b; });
249 }
250
251 /// Extract int64_t values from the assumed ArrayAttr of IntegerAttr.
252 static SmallVector<int64_t, 4> extractFromI64ArrayAttr(Attribute attr) {
253 return llvm::to_vector<4>(
254 llvm::map_range(attr.cast<ArrayAttr>(), [](Attribute a) -> int64_t {
255 return a.cast<IntegerAttr>().getInt();
256 }));
257 }
258
259 //===----------------------------------------------------------------------===//
260 // AllocOp / AllocaOp
261 //===----------------------------------------------------------------------===//
262
263 template <typename AllocLikeOp>
264 static LogicalResult verifyAllocLikeOp(AllocLikeOp op) {
265 static_assert(llvm::is_one_of<AllocLikeOp, AllocOp, AllocaOp>::value,
266 "applies to only alloc or alloca");
267 auto memRefType = op.getResult().getType().template dyn_cast<MemRefType>();
268 if (!memRefType)
269 return op.emitOpError("result must be a memref");
270
271 if (static_cast<int64_t>(op.dynamicSizes().size()) !=
272 memRefType.getNumDynamicDims())
273 return op.emitOpError("dimension operand count does not equal memref "
274 "dynamic dimension count");
275
276 unsigned numSymbols = 0;
277 if (!memRefType.getAffineMaps().empty())
278 numSymbols = memRefType.getAffineMaps().front().getNumSymbols();
279 if (op.symbolOperands().size() != numSymbols)
280 return op.emitOpError(
281 "symbol operand count does not equal memref symbol count");
282
283 return success();
284 }
285
286 static LogicalResult verify(AllocOp op) { return verifyAllocLikeOp(op); }
287
288 static LogicalResult verify(AllocaOp op) {
289 // An alloca op needs to have an ancestor with an allocation scope trait.
290 if (!op->getParentWithTrait<OpTrait::AutomaticAllocationScope>())
291 return op.emitOpError(
292 "requires an ancestor op with AutomaticAllocationScope trait");
293
294 return verifyAllocLikeOp(op);
295 }
296
297 namespace {
298 /// Fold constant dimensions into an alloc like operation.
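/// For illustration (hypothetical IR, operand names are arbitrary):
///   %c16 = constant 16 : index
///   %0 = alloc(%c16, %n) : memref<?x?xf32>
/// is rewritten to an alloc with a partially static type plus a cast back:
///   %1 = alloc(%n) : memref<16x?xf32>
///   %0 = memref_cast %1 : memref<16x?xf32> to memref<?x?xf32>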
299 template <typename AllocLikeOp>
300 struct SimplifyAllocConst : public OpRewritePattern<AllocLikeOp> {
301 using OpRewritePattern<AllocLikeOp>::OpRewritePattern;
302
303 LogicalResult matchAndRewrite(AllocLikeOp alloc,
304 PatternRewriter &rewriter) const override {
305 // Check to see if any dimension operands are constants. If so, we can
306 // substitute and drop them.
307 if (llvm::none_of(alloc.getOperands(), [](Value operand) {
308 return matchPattern(operand, m_ConstantIndex());
309 }))
310 return failure();
311
312 auto memrefType = alloc.getType();
313
314 // Ok, we have one or more constant operands. Collect the non-constant ones
315 // and keep track of the resultant memref type to build.
316 SmallVector<int64_t, 4> newShapeConstants;
317 newShapeConstants.reserve(memrefType.getRank());
318 SmallVector<Value, 4> newOperands;
319
320 unsigned dynamicDimPos = 0;
321 for (unsigned dim = 0, e = memrefType.getRank(); dim < e; ++dim) {
322 int64_t dimSize = memrefType.getDimSize(dim);
323 // If this is already a static dimension, keep it.
324 if (dimSize != -1) {
325 newShapeConstants.push_back(dimSize);
326 continue;
327 }
328 auto *defOp = alloc.getOperand(dynamicDimPos).getDefiningOp();
329 if (auto constantIndexOp = dyn_cast_or_null<ConstantIndexOp>(defOp)) {
330 // Dynamic shape dimension will be folded.
331 newShapeConstants.push_back(constantIndexOp.getValue());
332 } else {
333 // Dynamic shape dimension not folded; copy operand from old memref.
334 newShapeConstants.push_back(-1);
335 newOperands.push_back(alloc.getOperand(dynamicDimPos));
336 }
337 dynamicDimPos++;
338 }
339
340 // Create new memref type (which will have fewer dynamic dimensions).
341 MemRefType newMemRefType =
342 MemRefType::Builder(memrefType).setShape(newShapeConstants);
343 assert(static_cast<int64_t>(newOperands.size()) ==
344 newMemRefType.getNumDynamicDims());
345
346 // Create and insert the alloc op for the new memref.
347 auto newAlloc = rewriter.create<AllocLikeOp>(alloc.getLoc(), newMemRefType,
348 newOperands, IntegerAttr());
349 // Insert a cast so we have the same type as the old alloc.
350 auto resultCast = rewriter.create<MemRefCastOp>(alloc.getLoc(), newAlloc,
351 alloc.getType());
352
353 rewriter.replaceOp(alloc, {resultCast});
354 return success();
355 }
356 };
357
358 /// Fold alloc operations with no uses. Alloc has side effects on the heap,
359 /// but can still be deleted if it has zero uses.
360 struct SimplifyDeadAlloc : public OpRewritePattern<AllocOp> {
361 using OpRewritePattern<AllocOp>::OpRewritePattern;
362
363 LogicalResult matchAndRewrite(AllocOp alloc,
364 PatternRewriter &rewriter) const override {
365 if (alloc.use_empty()) {
366 rewriter.eraseOp(alloc);
367 return success();
368 }
369 return failure();
370 }
371 };
372 } // end anonymous namespace.
373
374 void AllocOp::getCanonicalizationPatterns(OwningRewritePatternList &results,
375 MLIRContext *context) {
376 results.insert<SimplifyAllocConst<AllocOp>, SimplifyDeadAlloc>(context);
377 }
378
379 void AllocaOp::getCanonicalizationPatterns(OwningRewritePatternList &results,
380 MLIRContext *context) {
381 results.insert<SimplifyAllocConst<AllocaOp>>(context);
382 }
383
384 //===----------------------------------------------------------------------===//
385 // AndOp
386 //===----------------------------------------------------------------------===//
387
388 OpFoldResult AndOp::fold(ArrayRef<Attribute> operands) {
389 /// and(x, 0) -> 0
390 if (matchPattern(rhs(), m_Zero()))
391 return rhs();
392 /// and(x, allOnes) -> x
393 APInt intValue;
394 if (matchPattern(rhs(), m_ConstantInt(&intValue)) &&
395 intValue.isAllOnesValue())
396 return lhs();
397 /// and(x,x) -> x
398 if (lhs() == rhs())
399 return rhs();
400
401 return constFoldBinaryOp<IntegerAttr>(operands,
402 [](APInt a, APInt b) { return a & b; });
403 }
404
405 //===----------------------------------------------------------------------===//
406 // AssertOp
407 //===----------------------------------------------------------------------===//
408
409 namespace {
410 struct EraseRedundantAssertions : public OpRewritePattern<AssertOp> {
411 using OpRewritePattern<AssertOp>::OpRewritePattern;
412
413 LogicalResult matchAndRewrite(AssertOp op,
414 PatternRewriter &rewriter) const override {
415 // Erase assertion if argument is constant true.
416 if (matchPattern(op.arg(), m_One())) {
417 rewriter.eraseOp(op);
418 return success();
419 }
420 return failure();
421 }
422 };
423 } // namespace
424
425 void AssertOp::getCanonicalizationPatterns(OwningRewritePatternList &patterns,
426 MLIRContext *context) {
427 patterns.insert<EraseRedundantAssertions>(context);
428 }
429
430 //===----------------------------------------------------------------------===//
431 // AssumeAlignmentOp
432 //===----------------------------------------------------------------------===//
433
434 static LogicalResult verify(AssumeAlignmentOp op) {
435 unsigned alignment = op.alignment();
436 if (!llvm::isPowerOf2_32(alignment))
437 return op.emitOpError("alignment must be power of 2");
438 return success();
439 }
440
441 //===----------------------------------------------------------------------===//
442 // AtomicRMWOp
443 //===----------------------------------------------------------------------===//
444
445 static LogicalResult verify(AtomicRMWOp op) {
446 if (op.getMemRefType().getRank() != op.getNumOperands() - 2)
447 return op.emitOpError(
448 "expects the number of subscripts to be equal to memref rank");
449 switch (op.kind()) {
450 case AtomicRMWKind::addf:
451 case AtomicRMWKind::maxf:
452 case AtomicRMWKind::minf:
453 case AtomicRMWKind::mulf:
454 if (!op.value().getType().isa<FloatType>())
455 return op.emitOpError()
456 << "with kind '" << stringifyAtomicRMWKind(op.kind())
457 << "' expects a floating-point type";
458 break;
459 case AtomicRMWKind::addi:
460 case AtomicRMWKind::maxs:
461 case AtomicRMWKind::maxu:
462 case AtomicRMWKind::mins:
463 case AtomicRMWKind::minu:
464 case AtomicRMWKind::muli:
465 if (!op.value().getType().isa<IntegerType>())
466 return op.emitOpError()
467 << "with kind '" << stringifyAtomicRMWKind(op.kind())
468 << "' expects an integer type";
469 break;
470 default:
471 break;
472 }
473 return success();
474 }
475
476 //===----------------------------------------------------------------------===//
477 // GenericAtomicRMWOp
478 //===----------------------------------------------------------------------===//
479
480 void GenericAtomicRMWOp::build(OpBuilder &builder, OperationState &result,
481 Value memref, ValueRange ivs) {
482 result.addOperands(memref);
483 result.addOperands(ivs);
484
485 if (auto memrefType = memref.getType().dyn_cast<MemRefType>()) {
486 Type elementType = memrefType.getElementType();
487 result.addTypes(elementType);
488
489 Region *bodyRegion = result.addRegion();
490 bodyRegion->push_back(new Block());
491 bodyRegion->addArgument(elementType);
492 }
493 }
494
495 static LogicalResult verify(GenericAtomicRMWOp op) {
496 auto &body = op.body();
497 if (body.getNumArguments() != 1)
498 return op.emitOpError("expected exactly one entry block argument");
499
500 if (op.getResult().getType() != body.getArgument(0).getType())
501 return op.emitOpError(
502 "expected block argument of the same type result type");
503
504 bool hasSideEffects =
505 body.walk([&](Operation *nestedOp) {
506 if (MemoryEffectOpInterface::hasNoEffect(nestedOp))
507 return WalkResult::advance();
508 nestedOp->emitError("body of 'generic_atomic_rmw' should contain "
509 "only operations with no side effects");
510 return WalkResult::interrupt();
511 })
512 .wasInterrupted();
513 return hasSideEffects ? failure() : success();
514 }
515
516 static ParseResult parseGenericAtomicRMWOp(OpAsmParser &parser,
517 OperationState &result) {
518 OpAsmParser::OperandType memref;
519 Type memrefType;
520 SmallVector<OpAsmParser::OperandType, 4> ivs;
521
522 Type indexType = parser.getBuilder().getIndexType();
523 if (parser.parseOperand(memref) ||
524 parser.parseOperandList(ivs, OpAsmParser::Delimiter::Square) ||
525 parser.parseColonType(memrefType) ||
526 parser.resolveOperand(memref, memrefType, result.operands) ||
527 parser.resolveOperands(ivs, indexType, result.operands))
528 return failure();
529
530 Region *body = result.addRegion();
531 if (parser.parseRegion(*body, llvm::None, llvm::None) ||
532 parser.parseOptionalAttrDict(result.attributes))
533 return failure();
534 result.types.push_back(memrefType.cast<MemRefType>().getElementType());
535 return success();
536 }
537
538 static void print(OpAsmPrinter &p, GenericAtomicRMWOp op) {
539 p << op.getOperationName() << ' ' << op.memref() << "[" << op.indices()
540 << "] : " << op.memref().getType();
541 p.printRegion(op.body());
542 p.printOptionalAttrDict(op.getAttrs());
543 }
544
545 //===----------------------------------------------------------------------===//
546 // AtomicYieldOp
547 //===----------------------------------------------------------------------===//
548
549 static LogicalResult verify(AtomicYieldOp op) {
550 Type parentType = op->getParentOp()->getResultTypes().front();
551 Type resultType = op.result().getType();
552 if (parentType != resultType)
553 return op.emitOpError() << "types mismatch between yield op: " << resultType
554 << " and its parent: " << parentType;
555 return success();
556 }
557
558 //===----------------------------------------------------------------------===//
559 // BranchOp
560 //===----------------------------------------------------------------------===//
561
562 /// Given a successor, try to collapse it to a new destination if it only
563 /// contains a passthrough unconditional branch. If the successor is
564 /// collapsable, `successor` and `successorOperands` are updated to reference
565 /// the new destination and values. `argStorage` is an optional storage to use
566 /// if operands to the collapsed successor need to be remapped.
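/// For illustration (hypothetical IR), given a successor of the form:
///   ^bb1(%a : i32):  // %a is only used by the branch below
///     br ^bb2(%a, %c : i32, i32)
/// a branch "br ^bb1(%x : i32)" can be redirected to
/// "br ^bb2(%x, %c : i32, i32)".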
567 static LogicalResult collapseBranch(Block *&successor,
568 ValueRange &successorOperands,
569 SmallVectorImpl<Value> &argStorage) {
570 // Check that the successor only contains an unconditional branch.
571 if (std::next(successor->begin()) != successor->end())
572 return failure();
573 // Check that the terminator is an unconditional branch.
574 BranchOp successorBranch = dyn_cast<BranchOp>(successor->getTerminator());
575 if (!successorBranch)
576 return failure();
577 // Check that the arguments are only used within the terminator.
578 for (BlockArgument arg : successor->getArguments()) {
579 for (Operation *user : arg.getUsers())
580 if (user != successorBranch)
581 return failure();
582 }
583 // Don't try to collapse branches to infinite loops.
584 Block *successorDest = successorBranch.getDest();
585 if (successorDest == successor)
586 return failure();
587
588 // Update the operands to the successor. If the branch parent has no
589 // arguments, we can use the branch operands directly.
590 OperandRange operands = successorBranch.getOperands();
591 if (successor->args_empty()) {
592 successor = successorDest;
593 successorOperands = operands;
594 return success();
595 }
596
597 // Otherwise, we need to remap any argument operands.
598 for (Value operand : operands) {
599 BlockArgument argOperand = operand.dyn_cast<BlockArgument>();
600 if (argOperand && argOperand.getOwner() == successor)
601 argStorage.push_back(successorOperands[argOperand.getArgNumber()]);
602 else
603 argStorage.push_back(operand);
604 }
605 successor = successorDest;
606 successorOperands = argStorage;
607 return success();
608 }
609
610 namespace {
611 /// Simplify a branch to a block that has a single predecessor. This effectively
612 /// merges the two blocks.
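/// For illustration (hypothetical IR):
///   ^bb0:
///     br ^bb1(%x : i32)
///   ^bb1(%a : i32):  // ^bb0 is the only predecessor
///     "foo.use"(%a) : (i32) -> ()
/// becomes a single block in which %a is replaced by %x.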
613 struct SimplifyBrToBlockWithSinglePred : public OpRewritePattern<BranchOp> {
614 using OpRewritePattern<BranchOp>::OpRewritePattern;
615
616 LogicalResult matchAndRewrite(BranchOp op,
617 PatternRewriter &rewriter) const override {
618 // Check that the successor block has a single predecessor.
619 Block *succ = op.getDest();
620 Block *opParent = op->getBlock();
621 if (succ == opParent || !llvm::hasSingleElement(succ->getPredecessors()))
622 return failure();
623
624 // Merge the successor into the current block and erase the branch.
625 rewriter.mergeBlocks(succ, opParent, op.getOperands());
626 rewriter.eraseOp(op);
627 return success();
628 }
629 };
630
631 /// br ^bb1
632 /// ^bb1
633 /// br ^bbN(...)
634 ///
635 /// -> br ^bbN(...)
636 ///
637 struct SimplifyPassThroughBr : public OpRewritePattern<BranchOp> {
638 using OpRewritePattern<BranchOp>::OpRewritePattern;
639
640 LogicalResult matchAndRewrite(BranchOp op,
641 PatternRewriter &rewriter) const override {
642 Block *dest = op.getDest();
643 ValueRange destOperands = op.getOperands();
644 SmallVector<Value, 4> destOperandStorage;
645
646 // Try to collapse the successor if it points somewhere other than this
647 // block.
648 if (dest == op->getBlock() ||
649 failed(collapseBranch(dest, destOperands, destOperandStorage)))
650 return failure();
651
652 // Create a new branch with the collapsed successor.
653 rewriter.replaceOpWithNewOp<BranchOp>(op, dest, destOperands);
654 return success();
655 }
656 };
657 } // end anonymous namespace.
658
659 Block *BranchOp::getDest() { return getSuccessor(); }
660
661 void BranchOp::setDest(Block *block) { return setSuccessor(block); }
662
663 void BranchOp::eraseOperand(unsigned index) { (*this)->eraseOperand(index); }
664
665 void BranchOp::getCanonicalizationPatterns(OwningRewritePatternList &results,
666 MLIRContext *context) {
667 results.insert<SimplifyBrToBlockWithSinglePred, SimplifyPassThroughBr>(
668 context);
669 }
670
671 Optional<MutableOperandRange>
672 BranchOp::getMutableSuccessorOperands(unsigned index) {
673 assert(index == 0 && "invalid successor index");
674 return destOperandsMutable();
675 }
676
677 Block *BranchOp::getSuccessorForOperands(ArrayRef<Attribute>) { return dest(); }
678
679 //===----------------------------------------------------------------------===//
680 // CallOp
681 //===----------------------------------------------------------------------===//
682
683 LogicalResult CallOp::verifySymbolUses(SymbolTableCollection &symbolTable) {
684 // Check that the callee attribute was specified.
685 auto fnAttr = (*this)->getAttrOfType<FlatSymbolRefAttr>("callee");
686 if (!fnAttr)
687 return emitOpError("requires a 'callee' symbol reference attribute");
688 FuncOp fn = symbolTable.lookupNearestSymbolFrom<FuncOp>(*this, fnAttr);
689 if (!fn)
690 return emitOpError() << "'" << fnAttr.getValue()
691 << "' does not reference a valid function";
692
693 // Verify that the operand and result types match the callee.
694 auto fnType = fn.getType();
695 if (fnType.getNumInputs() != getNumOperands())
696 return emitOpError("incorrect number of operands for callee");
697
698 for (unsigned i = 0, e = fnType.getNumInputs(); i != e; ++i)
699 if (getOperand(i).getType() != fnType.getInput(i))
700 return emitOpError("operand type mismatch: expected operand type ")
701 << fnType.getInput(i) << ", but provided "
702 << getOperand(i).getType() << " for operand number " << i;
703
704 if (fnType.getNumResults() != getNumResults())
705 return emitOpError("incorrect number of results for callee");
706
707 for (unsigned i = 0, e = fnType.getNumResults(); i != e; ++i)
708 if (getResult(i).getType() != fnType.getResult(i))
709 return emitOpError("result type mismatch");
710
711 return success();
712 }
713
714 FunctionType CallOp::getCalleeType() {
715 return FunctionType::get(getOperandTypes(), getResultTypes(), getContext());
716 }
717
718 //===----------------------------------------------------------------------===//
719 // CallIndirectOp
720 //===----------------------------------------------------------------------===//
721 namespace {
722 /// Fold indirect calls that have a constant function as the callee operand.
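/// For illustration (hypothetical IR, @callee is an arbitrary symbol):
///   %f = constant @callee : (i32) -> i32
///   %r = call_indirect %f(%x) : (i32) -> i32
/// becomes
///   %r = call @callee(%x) : (i32) -> i32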
723 struct SimplifyIndirectCallWithKnownCallee
724 : public OpRewritePattern<CallIndirectOp> {
725 using OpRewritePattern<CallIndirectOp>::OpRewritePattern;
726
727 LogicalResult matchAndRewrite(CallIndirectOp indirectCall,
728 PatternRewriter &rewriter) const override {
729 // Check that the callee is a constant callee.
730 SymbolRefAttr calledFn;
731 if (!matchPattern(indirectCall.getCallee(), m_Constant(&calledFn)))
732 return failure();
733
734 // Replace with a direct call.
735 rewriter.replaceOpWithNewOp<CallOp>(indirectCall, calledFn,
736 indirectCall.getResultTypes(),
737 indirectCall.getArgOperands());
738 return success();
739 }
740 };
741 } // end anonymous namespace.
742
743 void CallIndirectOp::getCanonicalizationPatterns(
744 OwningRewritePatternList &results, MLIRContext *context) {
745 results.insert<SimplifyIndirectCallWithKnownCallee>(context);
746 }
747
748 //===----------------------------------------------------------------------===//
749 // General helpers for comparison ops
750 //===----------------------------------------------------------------------===//
751
752 // Return the type of the same shape (scalar, vector or tensor) containing i1.
753 static Type getI1SameShape(Type type) {
754 auto i1Type = IntegerType::get(1, type.getContext());
755 if (auto tensorType = type.dyn_cast<RankedTensorType>())
756 return RankedTensorType::get(tensorType.getShape(), i1Type);
757 if (type.isa<UnrankedTensorType>())
758 return UnrankedTensorType::get(i1Type);
759 if (auto vectorType = type.dyn_cast<VectorType>())
760 return VectorType::get(vectorType.getShape(), i1Type);
761 return i1Type;
762 }
763
764 //===----------------------------------------------------------------------===//
765 // CmpIOp
766 //===----------------------------------------------------------------------===//
767
768 static void buildCmpIOp(OpBuilder &build, OperationState &result,
769 CmpIPredicate predicate, Value lhs, Value rhs) {
770 result.addOperands({lhs, rhs});
771 result.types.push_back(getI1SameShape(lhs.getType()));
772 result.addAttribute(CmpIOp::getPredicateAttrName(),
773 build.getI64IntegerAttr(static_cast<int64_t>(predicate)));
774 }
775
776 // Compute `lhs` `pred` `rhs`, where `pred` is one of the known integer
777 // comparison predicates.
778 bool mlir::applyCmpPredicate(CmpIPredicate predicate, const APInt &lhs,
779 const APInt &rhs) {
780 switch (predicate) {
781 case CmpIPredicate::eq:
782 return lhs.eq(rhs);
783 case CmpIPredicate::ne:
784 return lhs.ne(rhs);
785 case CmpIPredicate::slt:
786 return lhs.slt(rhs);
787 case CmpIPredicate::sle:
788 return lhs.sle(rhs);
789 case CmpIPredicate::sgt:
790 return lhs.sgt(rhs);
791 case CmpIPredicate::sge:
792 return lhs.sge(rhs);
793 case CmpIPredicate::ult:
794 return lhs.ult(rhs);
795 case CmpIPredicate::ule:
796 return lhs.ule(rhs);
797 case CmpIPredicate::ugt:
798 return lhs.ugt(rhs);
799 case CmpIPredicate::uge:
800 return lhs.uge(rhs);
801 }
802 llvm_unreachable("unknown comparison predicate");
803 }
804
805 // Returns true if the predicate is true for two equal operands.
806 static bool applyCmpPredicateToEqualOperands(CmpIPredicate predicate) {
807 switch (predicate) {
808 case CmpIPredicate::eq:
809 case CmpIPredicate::sle:
810 case CmpIPredicate::sge:
811 case CmpIPredicate::ule:
812 case CmpIPredicate::uge:
813 return true;
814 case CmpIPredicate::ne:
815 case CmpIPredicate::slt:
816 case CmpIPredicate::sgt:
817 case CmpIPredicate::ult:
818 case CmpIPredicate::ugt:
819 return false;
820 }
821 llvm_unreachable("unknown comparison predicate");
822 }
823
824 // Constant folding hook for comparisons.
825 OpFoldResult CmpIOp::fold(ArrayRef<Attribute> operands) {
826 assert(operands.size() == 2 && "cmpi takes two arguments");
827
828 if (lhs() == rhs()) {
829 auto val = applyCmpPredicateToEqualOperands(getPredicate());
830 return BoolAttr::get(val, getContext());
831 }
832
833 auto lhs = operands.front().dyn_cast_or_null<IntegerAttr>();
834 auto rhs = operands.back().dyn_cast_or_null<IntegerAttr>();
835 if (!lhs || !rhs)
836 return {};
837
838 auto val = applyCmpPredicate(getPredicate(), lhs.getValue(), rhs.getValue());
839 return BoolAttr::get(val, getContext());
840 }
841
842 //===----------------------------------------------------------------------===//
843 // CmpFOp
844 //===----------------------------------------------------------------------===//
845
846 static void buildCmpFOp(OpBuilder &build, OperationState &result,
847 CmpFPredicate predicate, Value lhs, Value rhs) {
848 result.addOperands({lhs, rhs});
849 result.types.push_back(getI1SameShape(lhs.getType()));
850 result.addAttribute(CmpFOp::getPredicateAttrName(),
851 build.getI64IntegerAttr(static_cast<int64_t>(predicate)));
852 }
853
854 /// Compute `lhs` `pred` `rhs`, where `pred` is one of the known floating point
855 /// comparison predicates.
856 bool mlir::applyCmpPredicate(CmpFPredicate predicate, const APFloat &lhs,
857 const APFloat &rhs) {
858 auto cmpResult = lhs.compare(rhs);
859 switch (predicate) {
860 case CmpFPredicate::AlwaysFalse:
861 return false;
862 case CmpFPredicate::OEQ:
863 return cmpResult == APFloat::cmpEqual;
864 case CmpFPredicate::OGT:
865 return cmpResult == APFloat::cmpGreaterThan;
866 case CmpFPredicate::OGE:
867 return cmpResult == APFloat::cmpGreaterThan ||
868 cmpResult == APFloat::cmpEqual;
869 case CmpFPredicate::OLT:
870 return cmpResult == APFloat::cmpLessThan;
871 case CmpFPredicate::OLE:
872 return cmpResult == APFloat::cmpLessThan || cmpResult == APFloat::cmpEqual;
873 case CmpFPredicate::ONE:
874 return cmpResult != APFloat::cmpUnordered && cmpResult != APFloat::cmpEqual;
875 case CmpFPredicate::ORD:
876 return cmpResult != APFloat::cmpUnordered;
877 case CmpFPredicate::UEQ:
878 return cmpResult == APFloat::cmpUnordered || cmpResult == APFloat::cmpEqual;
879 case CmpFPredicate::UGT:
880 return cmpResult == APFloat::cmpUnordered ||
881 cmpResult == APFloat::cmpGreaterThan;
882 case CmpFPredicate::UGE:
883 return cmpResult == APFloat::cmpUnordered ||
884 cmpResult == APFloat::cmpGreaterThan ||
885 cmpResult == APFloat::cmpEqual;
886 case CmpFPredicate::ULT:
887 return cmpResult == APFloat::cmpUnordered ||
888 cmpResult == APFloat::cmpLessThan;
889 case CmpFPredicate::ULE:
890 return cmpResult == APFloat::cmpUnordered ||
891 cmpResult == APFloat::cmpLessThan || cmpResult == APFloat::cmpEqual;
892 case CmpFPredicate::UNE:
893 return cmpResult != APFloat::cmpEqual;
894 case CmpFPredicate::UNO:
895 return cmpResult == APFloat::cmpUnordered;
896 case CmpFPredicate::AlwaysTrue:
897 return true;
898 }
899 llvm_unreachable("unknown comparison predicate");
900 }
901
902 // Constant folding hook for comparisons.
903 OpFoldResult CmpFOp::fold(ArrayRef<Attribute> operands) {
904 assert(operands.size() == 2 && "cmpf takes two arguments");
905
906 auto lhs = operands.front().dyn_cast_or_null<FloatAttr>();
907 auto rhs = operands.back().dyn_cast_or_null<FloatAttr>();
908
909 // TODO: We could actually do some intelligent things if we know only one
910 // of the operands and it is inf or NaN.
911 if (!lhs || !rhs)
912 return {};
913
914 auto val = applyCmpPredicate(getPredicate(), lhs.getValue(), rhs.getValue());
915 return IntegerAttr::get(IntegerType::get(1, getContext()), APInt(1, val));
916 }
917
918 //===----------------------------------------------------------------------===//
919 // CondBranchOp
920 //===----------------------------------------------------------------------===//
921
922 namespace {
923 /// cond_br true, ^bb1, ^bb2
924 /// -> br ^bb1
925 /// cond_br false, ^bb1, ^bb2
926 /// -> br ^bb2
927 ///
928 struct SimplifyConstCondBranchPred : public OpRewritePattern<CondBranchOp> {
929 using OpRewritePattern<CondBranchOp>::OpRewritePattern;
930
931 LogicalResult matchAndRewrite(CondBranchOp condbr,
932 PatternRewriter &rewriter) const override {
933 if (matchPattern(condbr.getCondition(), m_NonZero())) {
934 // True branch taken.
935 rewriter.replaceOpWithNewOp<BranchOp>(condbr, condbr.getTrueDest(),
936 condbr.getTrueOperands());
937 return success();
938 } else if (matchPattern(condbr.getCondition(), m_Zero())) {
939 // False branch taken.
940 rewriter.replaceOpWithNewOp<BranchOp>(condbr, condbr.getFalseDest(),
941 condbr.getFalseOperands());
942 return success();
943 }
944 return failure();
945 }
946 };
947
948 /// cond_br %cond, ^bb1, ^bb2
949 /// ^bb1
950 /// br ^bbN(...)
951 /// ^bb2
952 /// br ^bbK(...)
953 ///
954 /// -> cond_br %cond, ^bbN(...), ^bbK(...)
955 ///
956 struct SimplifyPassThroughCondBranch : public OpRewritePattern<CondBranchOp> {
957 using OpRewritePattern<CondBranchOp>::OpRewritePattern;
958
959 LogicalResult matchAndRewrite(CondBranchOp condbr,
960 PatternRewriter &rewriter) const override {
961 Block *trueDest = condbr.trueDest(), *falseDest = condbr.falseDest();
962 ValueRange trueDestOperands = condbr.getTrueOperands();
963 ValueRange falseDestOperands = condbr.getFalseOperands();
964 SmallVector<Value, 4> trueDestOperandStorage, falseDestOperandStorage;
965
966 // Try to collapse one of the current successors.
967 LogicalResult collapsedTrue =
968 collapseBranch(trueDest, trueDestOperands, trueDestOperandStorage);
969 LogicalResult collapsedFalse =
970 collapseBranch(falseDest, falseDestOperands, falseDestOperandStorage);
971 if (failed(collapsedTrue) && failed(collapsedFalse))
972 return failure();
973
974 // Create a new branch with the collapsed successors.
975 rewriter.replaceOpWithNewOp<CondBranchOp>(condbr, condbr.getCondition(),
976 trueDest, trueDestOperands,
977 falseDest, falseDestOperands);
978 return success();
979 }
980 };
981
982 /// cond_br %cond, ^bb1(A, ..., N), ^bb1(A, ..., N)
983 /// -> br ^bb1(A, ..., N)
984 ///
985 /// cond_br %cond, ^bb1(A), ^bb1(B)
986 /// -> %select = select %cond, A, B
987 /// br ^bb1(%select)
988 ///
989 struct SimplifyCondBranchIdenticalSuccessors
990 : public OpRewritePattern<CondBranchOp> {
991 using OpRewritePattern<CondBranchOp>::OpRewritePattern;
992
993 LogicalResult matchAndRewrite(CondBranchOp condbr,
994 PatternRewriter &rewriter) const override {
995 // Check that the true and false destinations are the same and have the same
996 // operands.
997 Block *trueDest = condbr.trueDest();
998 if (trueDest != condbr.falseDest())
999 return failure();
1000
1001 // If all of the operands match, no selects need to be generated.
1002 OperandRange trueOperands = condbr.getTrueOperands();
1003 OperandRange falseOperands = condbr.getFalseOperands();
1004 if (trueOperands == falseOperands) {
1005 rewriter.replaceOpWithNewOp<BranchOp>(condbr, trueDest, trueOperands);
1006 return success();
1007 }
1008
1009 // Otherwise, if the current block is the only predecessor, insert selects
1010 // for any mismatched branch operands.
1011 if (trueDest->getUniquePredecessor() != condbr->getBlock())
1012 return failure();
1013
1014 // Generate a select for any operands that differ between the two.
1015 SmallVector<Value, 8> mergedOperands;
1016 mergedOperands.reserve(trueOperands.size());
1017 Value condition = condbr.getCondition();
1018 for (auto it : llvm::zip(trueOperands, falseOperands)) {
1019 if (std::get<0>(it) == std::get<1>(it))
1020 mergedOperands.push_back(std::get<0>(it));
1021 else
1022 mergedOperands.push_back(rewriter.create<SelectOp>(
1023 condbr.getLoc(), condition, std::get<0>(it), std::get<1>(it)));
1024 }
1025
1026 rewriter.replaceOpWithNewOp<BranchOp>(condbr, trueDest, mergedOperands);
1027 return success();
1028 }
1029 };
1030
1031 /// ...
1032 /// cond_br %cond, ^bb1(...), ^bb2(...)
1033 /// ...
1034 /// ^bb1: // has single predecessor
1035 /// ...
1036 /// cond_br %cond, ^bb3(...), ^bb4(...)
1037 ///
1038 /// ->
1039 ///
1040 /// ...
1041 /// cond_br %cond, ^bb1(...), ^bb2(...)
1042 /// ...
1043 /// ^bb1: // has single predecessor
1044 /// ...
1045 /// br ^bb3(...)
1046 ///
1047 struct SimplifyCondBranchFromCondBranchOnSameCondition
1048 : public OpRewritePattern<CondBranchOp> {
1049 using OpRewritePattern<CondBranchOp>::OpRewritePattern;
1050
1051 LogicalResult matchAndRewrite(CondBranchOp condbr,
1052 PatternRewriter &rewriter) const override {
1053 // Check that we have a single distinct predecessor.
1054 Block *currentBlock = condbr->getBlock();
1055 Block *predecessor = currentBlock->getSinglePredecessor();
1056 if (!predecessor)
1057 return failure();
1058
1059 // Check that the predecessor terminates with a conditional branch to this
1060 // block and that it branches on the same condition.
1061 auto predBranch = dyn_cast<CondBranchOp>(predecessor->getTerminator());
1062 if (!predBranch || condbr.getCondition() != predBranch.getCondition())
1063 return failure();
1064
1065 // Fold this branch to an unconditional branch.
1066 if (currentBlock == predBranch.trueDest())
1067 rewriter.replaceOpWithNewOp<BranchOp>(condbr, condbr.trueDest(),
1068 condbr.trueDestOperands());
1069 else
1070 rewriter.replaceOpWithNewOp<BranchOp>(condbr, condbr.falseDest(),
1071 condbr.falseDestOperands());
1072 return success();
1073 }
1074 };
1075 } // end anonymous namespace
1076
1077 void CondBranchOp::getCanonicalizationPatterns(
1078 OwningRewritePatternList &results, MLIRContext *context) {
1079 results.insert<SimplifyConstCondBranchPred, SimplifyPassThroughCondBranch,
1080 SimplifyCondBranchIdenticalSuccessors,
1081 SimplifyCondBranchFromCondBranchOnSameCondition>(context);
1082 }
1083
1084 Optional<MutableOperandRange>
1085 CondBranchOp::getMutableSuccessorOperands(unsigned index) {
1086 assert(index < getNumSuccessors() && "invalid successor index");
1087 return index == trueIndex ? trueDestOperandsMutable()
1088 : falseDestOperandsMutable();
1089 }
1090
1091 Block *CondBranchOp::getSuccessorForOperands(ArrayRef<Attribute> operands) {
1092 if (IntegerAttr condAttr = operands.front().dyn_cast_or_null<IntegerAttr>())
1093 return condAttr.getValue().isOneValue() ? trueDest() : falseDest();
1094 return nullptr;
1095 }
1096
1097 //===----------------------------------------------------------------------===//
1098 // Constant*Op
1099 //===----------------------------------------------------------------------===//
1100
1101 static void print(OpAsmPrinter &p, ConstantOp &op) {
1102 p << "constant ";
1103 p.printOptionalAttrDict(op.getAttrs(), /*elidedAttrs=*/{"value"});
1104
1105 if (op.getAttrs().size() > 1)
1106 p << ' ';
1107 p << op.getValue();
1108
1109 // If the value is a symbol reference, print a trailing type.
1110 if (op.getValue().isa<SymbolRefAttr>())
1111 p << " : " << op.getType();
1112 }
1113
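// For illustration, hypothetical textual forms accepted by the parser below:
//   %0 = constant 42 : i32
//   %f = constant @some_func : (i32) -> i32
// For the symbol-reference form the trailing type is parsed explicitly;
// otherwise the type is taken from the attribute itself.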
1114 static ParseResult parseConstantOp(OpAsmParser &parser,
1115 OperationState &result) {
1116 Attribute valueAttr;
1117 if (parser.parseOptionalAttrDict(result.attributes) ||
1118 parser.parseAttribute(valueAttr, "value", result.attributes))
1119 return failure();
1120
1121 // If the attribute is a symbol reference, then we expect a trailing type.
1122 Type type;
1123 if (!valueAttr.isa<SymbolRefAttr>())
1124 type = valueAttr.getType();
1125 else if (parser.parseColonType(type))
1126 return failure();
1127
1128 // Add the attribute type to the list.
1129 return parser.addTypeToList(type, result.types);
1130 }
1131
1132 /// The constant op requires an attribute, and furthermore requires that it
1133 /// matches the return type.
1134 static LogicalResult verify(ConstantOp &op) {
1135 auto value = op.getValue();
1136 if (!value)
1137 return op.emitOpError("requires a 'value' attribute");
1138
1139 auto type = op.getType();
1140 if (!value.getType().isa<NoneType>() && type != value.getType())
1141 return op.emitOpError() << "requires attribute's type (" << value.getType()
1142 << ") to match op's return type (" << type << ")";
1143
1144 if (type.isa<IndexType>() || value.isa<BoolAttr>())
1145 return success();
1146
1147 if (auto intAttr = value.dyn_cast<IntegerAttr>()) {
1148 // If the type has a known bitwidth we verify that the value can be
1149 // represented with the given bitwidth.
1150 auto bitwidth = type.cast<IntegerType>().getWidth();
1151 auto intVal = intAttr.getValue();
1152 if (!intVal.isSignedIntN(bitwidth) && !intVal.isIntN(bitwidth))
1153 return op.emitOpError("requires 'value' to be an integer within the "
1154 "range of the integer result type");
1155 return success();
1156 }
1157
1158 if (type.isa<FloatType>()) {
1159 if (!value.isa<FloatAttr>())
1160 return op.emitOpError("requires 'value' to be a floating point constant");
1161 return success();
1162 }
1163
1164 if (type.isa<ShapedType>()) {
1165 if (!value.isa<ElementsAttr>())
1166 return op.emitOpError("requires 'value' to be a shaped constant");
1167 return success();
1168 }
1169
1170 if (type.isa<FunctionType>()) {
1171 auto fnAttr = value.dyn_cast<FlatSymbolRefAttr>();
1172 if (!fnAttr)
1173 return op.emitOpError("requires 'value' to be a function reference");
1174
1175 // Try to find the referenced function.
1176 auto fn =
1177 op->getParentOfType<ModuleOp>().lookupSymbol<FuncOp>(fnAttr.getValue());
1178 if (!fn)
1179 return op.emitOpError()
1180 << "reference to undefined function '" << fnAttr.getValue() << "'";
1181
1182 // Check that the referenced function has the correct type.
1183 if (fn.getType() != type)
1184 return op.emitOpError("reference to function with mismatched type");
1185
1186 return success();
1187 }
1188
1189 if (type.isa<NoneType>() && value.isa<UnitAttr>())
1190 return success();
1191
1192 return op.emitOpError("unsupported 'value' attribute: ") << value;
1193 }
1194
1195 OpFoldResult ConstantOp::fold(ArrayRef<Attribute> operands) {
1196 assert(operands.empty() && "constant has no operands");
1197 return getValue();
1198 }
1199
1200 void ConstantOp::getAsmResultNames(
1201 function_ref<void(Value, StringRef)> setNameFn) {
1202 Type type = getType();
1203 if (auto intCst = getValue().dyn_cast<IntegerAttr>()) {
1204 IntegerType intTy = type.dyn_cast<IntegerType>();
1205
1206 // Sugar i1 constants with 'true' and 'false'.
1207 if (intTy && intTy.getWidth() == 1)
1208 return setNameFn(getResult(), (intCst.getInt() ? "true" : "false"));
1209
1210 // Otherwise, build a complex name with the value and type.
1211 SmallString<32> specialNameBuffer;
1212 llvm::raw_svector_ostream specialName(specialNameBuffer);
1213 specialName << 'c' << intCst.getInt();
1214 if (intTy)
1215 specialName << '_' << type;
1216 setNameFn(getResult(), specialName.str());
1217
1218 } else if (type.isa<FunctionType>()) {
1219 setNameFn(getResult(), "f");
1220 } else {
1221 setNameFn(getResult(), "cst");
1222 }
1223 }
1224
1225 /// Returns true if a constant operation can be built with the given value and
1226 /// result type.
1227 bool ConstantOp::isBuildableWith(Attribute value, Type type) {
1228 // SymbolRefAttr can only be used with a function type.
1229 if (value.isa<SymbolRefAttr>())
1230 return type.isa<FunctionType>();
1231 // Otherwise, the attribute must have the same type as 'type'.
1232 if (value.getType() != type)
1233 return false;
1234 // Finally, check that the attribute kind is handled.
1235 return value.isa<IntegerAttr, FloatAttr, ElementsAttr, UnitAttr>();
1236 }
1237
1238 void ConstantFloatOp::build(OpBuilder &builder, OperationState &result,
1239 const APFloat &value, FloatType type) {
1240 ConstantOp::build(builder, result, type, builder.getFloatAttr(type, value));
1241 }
1242
1243 bool ConstantFloatOp::classof(Operation *op) {
1244 return ConstantOp::classof(op) && op->getResult(0).getType().isa<FloatType>();
1245 }
1246
1247 /// ConstantIntOp only matches values whose result type is an IntegerType.
1248 bool ConstantIntOp::classof(Operation *op) {
1249 return ConstantOp::classof(op) &&
1250 op->getResult(0).getType().isSignlessInteger();
1251 }
1252
1253 void ConstantIntOp::build(OpBuilder &builder, OperationState &result,
1254 int64_t value, unsigned width) {
1255 Type type = builder.getIntegerType(width);
1256 ConstantOp::build(builder, result, type, builder.getIntegerAttr(type, value));
1257 }
1258
1259 /// Build a constant int op producing an integer with the specified type,
1260 /// which must be an integer type.
1261 void ConstantIntOp::build(OpBuilder &builder, OperationState &result,
1262 int64_t value, Type type) {
1263 assert(type.isSignlessInteger() &&
1264 "ConstantIntOp can only have signless integer type");
1265 ConstantOp::build(builder, result, type, builder.getIntegerAttr(type, value));
1266 }
1267
1268 /// ConstantIndexOp only matches values whose result type is Index.
1269 bool ConstantIndexOp::classof(Operation *op) {
1270 return ConstantOp::classof(op) && op->getResult(0).getType().isIndex();
1271 }
1272
1273 void ConstantIndexOp::build(OpBuilder &builder, OperationState &result,
1274 int64_t value) {
1275 Type type = builder.getIndexType();
1276 ConstantOp::build(builder, result, type, builder.getIntegerAttr(type, value));
1277 }
1278
1279 //===----------------------------------------------------------------------===//
1280 // DeallocOp
1281 //===----------------------------------------------------------------------===//
1282 namespace {
1283 /// Fold Dealloc operations that are deallocating an AllocOp that is only used
1284 /// by other Dealloc operations.
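/// For illustration (hypothetical IR):
///   %0 = alloc() : memref<4xf32>
///   dealloc %0 : memref<4xf32>  // %0 has no uses other than deallocs
/// the dealloc can be erased, after which the alloc itself becomes dead.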
1285 struct SimplifyDeadDealloc : public OpRewritePattern<DeallocOp> {
1286 using OpRewritePattern<DeallocOp>::OpRewritePattern;
1287
1288 LogicalResult matchAndRewrite(DeallocOp dealloc,
1289 PatternRewriter &rewriter) const override {
1290 // Check that the memref operand's defining operation is an AllocOp.
1291 Value memref = dealloc.memref();
1292 if (!isa_and_nonnull<AllocOp>(memref.getDefiningOp()))
1293 return failure();
1294
1295 // Check that all of the uses of the AllocOp are other DeallocOps.
1296 for (auto *user : memref.getUsers())
1297 if (!isa<DeallocOp>(user))
1298 return failure();
1299
1300 // Erase the dealloc operation.
1301 rewriter.eraseOp(dealloc);
1302 return success();
1303 }
1304 };
1305 } // end anonymous namespace.
1306
1307 static LogicalResult verify(DeallocOp op) {
1308 if (!op.memref().getType().isa<MemRefType>())
1309 return op.emitOpError("operand must be a memref");
1310 return success();
1311 }
1312
1313 void DeallocOp::getCanonicalizationPatterns(OwningRewritePatternList &results,
1314 MLIRContext *context) {
1315 results.insert<SimplifyDeadDealloc>(context);
1316 }
1317
1318 LogicalResult DeallocOp::fold(ArrayRef<Attribute> cstOperands,
1319 SmallVectorImpl<OpFoldResult> &results) {
1320 /// dealloc(memrefcast) -> dealloc
1321 return foldMemRefCast(*this);
1322 }
1323
1324 //===----------------------------------------------------------------------===//
1325 // DimOp
1326 //===----------------------------------------------------------------------===//
1327
1328 void DimOp::build(OpBuilder &builder, OperationState &result,
1329 Value memrefOrTensor, int64_t index) {
1330 auto loc = result.location;
1331 Value indexValue = builder.create<ConstantIndexOp>(loc, index);
1332 build(builder, result, memrefOrTensor, indexValue);
1333 }
1334
1335 void DimOp::build(OpBuilder &builder, OperationState &result,
1336 Value memrefOrTensor, Value index) {
1337 auto indexTy = builder.getIndexType();
1338 build(builder, result, indexTy, memrefOrTensor, index);
1339 }
1340
1341 Optional<int64_t> DimOp::getConstantIndex() {
1342 if (auto constantOp = index().getDefiningOp<ConstantOp>())
1343 return constantOp.getValue().cast<IntegerAttr>().getInt();
1344 return {};
1345 }
1346
1347 static LogicalResult verify(DimOp op) {
1348 // Assume unknown index to be in range.
1349 Optional<int64_t> index = op.getConstantIndex();
1350 if (!index.hasValue())
1351 return success();
1352
1353 // Check that constant index is not knowingly out of range.
1354 auto type = op.memrefOrTensor().getType();
1355 if (auto tensorType = type.dyn_cast<RankedTensorType>()) {
1356 if (index.getValue() >= tensorType.getRank())
1357 return op.emitOpError("index is out of range");
1358 } else if (auto memrefType = type.dyn_cast<MemRefType>()) {
1359 if (index.getValue() >= memrefType.getRank())
1360 return op.emitOpError("index is out of range");
1361 } else if (type.isa<UnrankedTensorType>() || type.isa<UnrankedMemRefType>()) {
1362 // Assume index to be in range.
1363 } else {
1364 llvm_unreachable("expected operand with tensor or memref type");
1365 }
1366
1367 return success();
1368 }
1369
1370 OpFoldResult DimOp::fold(ArrayRef<Attribute> operands) {
1371 auto index = operands[1].dyn_cast_or_null<IntegerAttr>();
1372
1373 // All forms of folding require a known index.
1374 if (!index)
1375 return {};
1376
1377 auto argTy = memrefOrTensor().getType();
1378 // Fold if the shape extent along the given index is known.
1379 if (auto shapedTy = argTy.dyn_cast<ShapedType>()) {
1380 // Folding for unranked types (UnrankedMemRefType, UnrankedTensorType) is
1381 // not supported.
1382 if (!shapedTy.hasRank())
1383 return {};
1384 if (!shapedTy.isDynamicDim(index.getInt())) {
1385 Builder builder(getContext());
1386 return builder.getIndexAttr(shapedTy.getShape()[index.getInt()]);
1387 }
1388 }
1389
1390 Operation *definingOp = memrefOrTensor().getDefiningOp();
1391 // dim(tensor_load(memref)) -> dim(memref)
1392 if (auto tensorLoadOp = dyn_cast_or_null<TensorLoadOp>(definingOp)) {
1393 setOperand(0, tensorLoadOp.memref());
1394 return getResult();
1395 }
1396
1397 // Fold dim to the operand of dynamic_tensor_from_elements.
1398 if (auto fromElements =
1399 dyn_cast_or_null<DynamicTensorFromElementsOp>(definingOp)) {
1400 auto resultType =
1401 fromElements.getResult().getType().cast<RankedTensorType>();
1402 // The case where the type encodes the size of the dimension is handled
1403 // above.
1404 assert(resultType.getShape()[index.getInt()] ==
1405 RankedTensorType::kDynamicSize);
1406
1407 // Find the operand of the fromElements that corresponds to this index.
1408 auto dynExtents = fromElements.dynamicExtents().begin();
1409 for (auto dim : resultType.getShape().take_front(index.getInt()))
1410 if (dim == RankedTensorType::kDynamicSize)
1411 dynExtents++;
1412
1413 return Value{*dynExtents};
1414 }
1415
1416 // Fold dim to the size argument for an `AllocOp`, `ViewOp`, or `SubViewOp`.
1417 auto memrefType = argTy.dyn_cast<MemRefType>();
1418 if (!memrefType)
1419 return {};
1420
1421 // The size at the given index is now known to be a dynamic size of a memref.
1422 unsigned unsignedIndex = index.getValue().getZExtValue();
1423 if (auto alloc = dyn_cast_or_null<AllocOp>(definingOp))
1424 return *(alloc.getDynamicSizes().begin() +
1425 memrefType.getDynamicDimIndex(unsignedIndex));
1426
1427 if (auto view = dyn_cast_or_null<ViewOp>(definingOp))
1428 return *(view.getDynamicSizes().begin() +
1429 memrefType.getDynamicDimIndex(unsignedIndex));
1430
1431 if (auto subview = dyn_cast_or_null<SubViewOp>(definingOp)) {
1432 assert(subview.isDynamicSize(unsignedIndex) &&
1433 "Expected dynamic subview size");
1434 return subview.getDynamicSize(unsignedIndex);
1435 }
1436
1437 // dim(memrefcast) -> dim
1438 if (succeeded(foldMemRefCast(*this)))
1439 return getResult();
1440
1441 return {};
1442 }
1443
1444 namespace {
1445 /// Fold dim of a memref reshape operation to a load into the reshape's shape
1446 /// operand.
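///
/// Illustrative example (syntax approximate):
///   %r = memref_reshape %src(%shape)
///            : (memref<*xf32>, memref<?xindex>) -> memref<*xf32>
///   %d = dim %r, %idx : memref<*xf32>
/// becomes
///   %d = load %shape[%idx] : memref<?xindex>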
1447 struct DimOfMemRefReshape : public OpRewritePattern<DimOp> {
1448 using OpRewritePattern<DimOp>::OpRewritePattern;
1449
1450 LogicalResult matchAndRewrite(DimOp dim,
1451 PatternRewriter &rewriter) const override {
1452 auto reshape = dim.memrefOrTensor().getDefiningOp<MemRefReshapeOp>();
1453
1454 if (!reshape)
1455 return failure();
1456
1457 // Place the load directly after the reshape to ensure that the shape memref
1458 // has not been mutated before it is read.
1459 rewriter.setInsertionPointAfter(reshape);
1460 rewriter.replaceOpWithNewOp<LoadOp>(dim, reshape.shape(),
1461 llvm::makeArrayRef({dim.index()}));
1462 return success();
1463 }
1464 };
1465 } // end anonymous namespace.
1466
1467 void DimOp::getCanonicalizationPatterns(OwningRewritePatternList &results,
1468 MLIRContext *context) {
1469 results.insert<DimOfMemRefReshape>(context);
1470 }
1471
1472 // ---------------------------------------------------------------------------
1473 // DmaStartOp
1474 // ---------------------------------------------------------------------------
1475
1476 void DmaStartOp::build(OpBuilder &builder, OperationState &result,
1477 Value srcMemRef, ValueRange srcIndices, Value destMemRef,
1478 ValueRange destIndices, Value numElements,
1479 Value tagMemRef, ValueRange tagIndices, Value stride,
1480 Value elementsPerStride) {
1481 result.addOperands(srcMemRef);
1482 result.addOperands(srcIndices);
1483 result.addOperands(destMemRef);
1484 result.addOperands(destIndices);
1485 result.addOperands({numElements, tagMemRef});
1486 result.addOperands(tagIndices);
1487 if (stride)
1488 result.addOperands({stride, elementsPerStride});
1489 }
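
// The operand list produced above, and assumed by the getters and the
// verifier:
//   [srcMemRef, srcIndices..., dstMemRef, dstIndices...,
//    numElements, tagMemRef, tagIndices..., (stride, elementsPerStride)?]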
1490
1491 void DmaStartOp::print(OpAsmPrinter &p) {
1492 p << "dma_start " << getSrcMemRef() << '[' << getSrcIndices() << "], "
1493 << getDstMemRef() << '[' << getDstIndices() << "], " << getNumElements()
1494 << ", " << getTagMemRef() << '[' << getTagIndices() << ']';
1495 if (isStrided())
1496 p << ", " << getStride() << ", " << getNumElementsPerStride();
1497
1498 p.printOptionalAttrDict(getAttrs());
1499 p << " : " << getSrcMemRef().getType() << ", " << getDstMemRef().getType()
1500 << ", " << getTagMemRef().getType();
1501 }
1502
1503 // Parse DmaStartOp.
1504 // Ex:
1505 //   dma_start %src[%i, %j], %dst[%k, %l], %size,
1506 //             %tag[%index], %stride, %num_elt_per_stride :
1507 //               memref<3076 x f32, 0>,
1508 //               memref<1024 x f32, 2>,
1509 //               memref<1 x i32>
1510 //
1511 ParseResult DmaStartOp::parse(OpAsmParser &parser, OperationState &result) {
1512 OpAsmParser::OperandType srcMemRefInfo;
1513 SmallVector<OpAsmParser::OperandType, 4> srcIndexInfos;
1514 OpAsmParser::OperandType dstMemRefInfo;
1515 SmallVector<OpAsmParser::OperandType, 4> dstIndexInfos;
1516 OpAsmParser::OperandType numElementsInfo;
1517 OpAsmParser::OperandType tagMemrefInfo;
1518 SmallVector<OpAsmParser::OperandType, 4> tagIndexInfos;
1519 SmallVector<OpAsmParser::OperandType, 2> strideInfo;
1520
1521 SmallVector<Type, 3> types;
1522 auto indexType = parser.getBuilder().getIndexType();
1523
1524 // Parse and resolve the following list of operands:
1525 // *) source memref followed by its indices (in square brackets).
1526 // *) destination memref followed by its indices (in square brackets).
1527 // *) number of elements to transfer.
1528 if (parser.parseOperand(srcMemRefInfo) ||
1529 parser.parseOperandList(srcIndexInfos, OpAsmParser::Delimiter::Square) ||
1530 parser.parseComma() || parser.parseOperand(dstMemRefInfo) ||
1531 parser.parseOperandList(dstIndexInfos, OpAsmParser::Delimiter::Square) ||
1532 parser.parseComma() || parser.parseOperand(numElementsInfo) ||
1533 parser.parseComma() || parser.parseOperand(tagMemrefInfo) ||
1534 parser.parseOperandList(tagIndexInfos, OpAsmParser::Delimiter::Square))
1535 return failure();
1536
1537 // Parse optional stride and elements per stride.
1538 if (parser.parseTrailingOperandList(strideInfo))
1539 return failure();
1540
1541 bool isStrided = strideInfo.size() == 2;
1542 if (!strideInfo.empty() && !isStrided) {
1543 return parser.emitError(parser.getNameLoc(),
1544 "expected two stride related operands");
1545 }
1546
1547 if (parser.parseColonTypeList(types))
1548 return failure();
1549 if (types.size() != 3)
1550 return parser.emitError(parser.getNameLoc(), "fewer/more types expected");
1551
1552 if (parser.resolveOperand(srcMemRefInfo, types[0], result.operands) ||
1553 parser.resolveOperands(srcIndexInfos, indexType, result.operands) ||
1554 parser.resolveOperand(dstMemRefInfo, types[1], result.operands) ||
1555 parser.resolveOperands(dstIndexInfos, indexType, result.operands) ||
1556 // size should be an index.
1557 parser.resolveOperand(numElementsInfo, indexType, result.operands) ||
1558 parser.resolveOperand(tagMemrefInfo, types[2], result.operands) ||
1559 // tag indices should be index.
1560 parser.resolveOperands(tagIndexInfos, indexType, result.operands))
1561 return failure();
1562
1563 if (isStrided) {
1564 if (parser.resolveOperands(strideInfo, indexType, result.operands))
1565 return failure();
1566 }
1567
1568 return success();
1569 }
1570
1571 LogicalResult DmaStartOp::verify() {
1572 unsigned numOperands = getNumOperands();
1573
1574 // Mandatory non-variadic operands are: src memref, dst memref, tag memref and
1575 // the number of elements.
1576 if (numOperands < 4)
1577 return emitOpError("expected at least 4 operands");
1578
1579 // Check types of operands. The order of these calls is important: the later
1580 // calls rely on some type properties to compute the operand position.
1581 // 1. Source memref.
1582 if (!getSrcMemRef().getType().isa<MemRefType>())
1583 return emitOpError("expected source to be of memref type");
1584 if (numOperands < getSrcMemRefRank() + 4)
1585 return emitOpError() << "expected at least " << getSrcMemRefRank() + 4
1586 << " operands";
1587 if (!getSrcIndices().empty() &&
1588 !llvm::all_of(getSrcIndices().getTypes(),
1589 [](Type t) { return t.isIndex(); }))
1590 return emitOpError("expected source indices to be of index type");
1591
1592 // 2. Destination memref.
1593 if (!getDstMemRef().getType().isa<MemRefType>())
1594 return emitOpError("expected destination to be of memref type");
1595 unsigned numExpectedOperands = getSrcMemRefRank() + getDstMemRefRank() + 4;
1596 if (numOperands < numExpectedOperands)
1597 return emitOpError() << "expected at least " << numExpectedOperands
1598 << " operands";
1599 if (!getDstIndices().empty() &&
1600 !llvm::all_of(getDstIndices().getTypes(),
1601 [](Type t) { return t.isIndex(); }))
1602 return emitOpError("expected destination indices to be of index type");
1603
1604 // 3. Number of elements.
1605 if (!getNumElements().getType().isIndex())
1606 return emitOpError("expected num elements to be of index type");
1607
1608 // 4. Tag memref.
1609 if (!getTagMemRef().getType().isa<MemRefType>())
1610 return emitOpError("expected tag to be of memref type");
1611 numExpectedOperands += getTagMemRefRank();
1612 if (numOperands < numExpectedOperands)
1613 return emitOpError() << "expected at least " << numExpectedOperands
1614 << " operands";
1615 if (!getTagIndices().empty() &&
1616 !llvm::all_of(getTagIndices().getTypes(),
1617 [](Type t) { return t.isIndex(); }))
1618 return emitOpError("expected tag indices to be of index type");
1619
1620 // DMAs are only supported between different memory spaces.
1621 if (getSrcMemorySpace() == getDstMemorySpace())
1622 return emitOpError("DMA should be between different memory spaces");
1623
1624 // Optional stride-related operands must be either both present or both
1625 // absent.
1626 if (numOperands != numExpectedOperands &&
1627 numOperands != numExpectedOperands + 2)
1628 return emitOpError("incorrect number of operands");
1629
1630 // 5. Strides.
1631 if (isStrided()) {
1632 if (!getStride().getType().isIndex() ||
1633 !getNumElementsPerStride().getType().isIndex())
1634 return emitOpError(
1635 "expected stride and num elements per stride to be of type index");
1636 }
1637
1638 return success();
1639 }
1640
1641 LogicalResult DmaStartOp::fold(ArrayRef<Attribute> cstOperands,
1642 SmallVectorImpl<OpFoldResult> &results) {
1643 /// dma_start(memrefcast) -> dma_start
1644 return foldMemRefCast(*this);
1645 }
1646
1647 // ---------------------------------------------------------------------------
1648 // DmaWaitOp
1649 // ---------------------------------------------------------------------------
1650
1651 void DmaWaitOp::build(OpBuilder &builder, OperationState &result,
1652 Value tagMemRef, ValueRange tagIndices,
1653 Value numElements) {
1654 result.addOperands(tagMemRef);
1655 result.addOperands(tagIndices);
1656 result.addOperands(numElements);
1657 }
1658
1659 void DmaWaitOp::print(OpAsmPrinter &p) {
1660 p << "dma_wait " << getTagMemRef() << '[' << getTagIndices() << "], "
1661 << getNumElements();
1662 p.printOptionalAttrDict(getAttrs());
1663 p << " : " << getTagMemRef().getType();
1664 }
1665
1666 // Parse DmaWaitOp.
1667 // Eg:
1668 // dma_wait %tag[%index], %num_elements : memref<1 x i32, (d0) -> (d0), 4>
1669 //
1670 ParseResult DmaWaitOp::parse(OpAsmParser &parser, OperationState &result) {
1671 OpAsmParser::OperandType tagMemrefInfo;
1672 SmallVector<OpAsmParser::OperandType, 2> tagIndexInfos;
1673 Type type;
1674 auto indexType = parser.getBuilder().getIndexType();
1675 OpAsmParser::OperandType numElementsInfo;
1676
1677 // Parse tag memref, its indices, and dma size.
1678 if (parser.parseOperand(tagMemrefInfo) ||
1679 parser.parseOperandList(tagIndexInfos, OpAsmParser::Delimiter::Square) ||
1680 parser.parseComma() || parser.parseOperand(numElementsInfo) ||
1681 parser.parseColonType(type) ||
1682 parser.resolveOperand(tagMemrefInfo, type, result.operands) ||
1683 parser.resolveOperands(tagIndexInfos, indexType, result.operands) ||
1684 parser.resolveOperand(numElementsInfo, indexType, result.operands))
1685 return failure();
1686
1687 return success();
1688 }
1689
1690 LogicalResult DmaWaitOp::fold(ArrayRef<Attribute> cstOperands,
1691 SmallVectorImpl<OpFoldResult> &results) {
1692 /// dma_wait(memrefcast) -> dma_wait
1693 return foldMemRefCast(*this);
1694 }
1695
1696 LogicalResult DmaWaitOp::verify() {
1697 // Mandatory non-variadic operands are tag and the number of elements.
1698 if (getNumOperands() < 2)
1699 return emitOpError() << "expected at least 2 operands";
1700
1701 // Check types of operands. The order of these calls is important: the later
1702 // calls rely on some type properties to compute the operand position.
1703 if (!getTagMemRef().getType().isa<MemRefType>())
1704 return emitOpError() << "expected tag to be of memref type";
1705
1706 if (getNumOperands() != 2 + getTagMemRefRank())
1707 return emitOpError() << "expected " << 2 + getTagMemRefRank()
1708 << " operands";
1709
1710 if (!getTagIndices().empty() &&
1711 !llvm::all_of(getTagIndices().getTypes(),
1712 [](Type t) { return t.isIndex(); }))
1713 return emitOpError() << "expected tag indices to be of index type";
1714
1715 if (!getNumElements().getType().isIndex())
1716 return emitOpError()
1717 << "expected the number of elements to be of index type";
1718
1719 return success();
1720 }
1721
1722 //===----------------------------------------------------------------------===//
1723 // DynamicTensorFromElementsOp
1724 //===----------------------------------------------------------------------===//
1725
1726 static ParseResult parseDynamicTensorFromElementsOp(OpAsmParser &parser,
1727 OperationState &result) {
1728 // Parse operands.
1729 SmallVector<OpAsmParser::OperandType, 4> dynamicExtents;
1730 Type indexTy = parser.getBuilder().getIndexType();
1731 if (parser.parseOperandList(dynamicExtents) ||
1732 parser.resolveOperands(dynamicExtents, indexTy, result.operands))
1733 return failure();
1734
1735 // Parse body.
1736 Region *body = result.addRegion();
1737 if (parser.parseRegion(*body, {}, {}))
1738 return failure();
1739
1740 // Parse result type.
1741 Type resultType;
1742 if (parser.parseOptionalAttrDict(result.attributes) ||
1743 parser.parseColonType(resultType))
1744 return failure();
1745 result.addTypes(resultType);
1746
1747 return success();
1748 }
1749
1750 static void print(OpAsmPrinter &p, DynamicTensorFromElementsOp op) {
1751 p << "dynamic_tensor_from_elements " << op.dynamicExtents();
1752 p.printRegion(op.body());
1753 p.printOptionalAttrDict(op.getAttrs());
1754 p << " : " << op.getType();
1755 }
1756
1757 static LogicalResult verify(DynamicTensorFromElementsOp op) {
1758 // Ensure that the tensor type has as many dynamic dimensions as are specified
1759 // by the operands.
1760 RankedTensorType resultTy = op.getType().cast<RankedTensorType>();
1761 if (op.getNumOperands() != resultTy.getNumDynamicDims())
1762 return op.emitError("must have as many index operands as dynamic extents "
1763 "in the result type");
1764
1765 // Ensure that region arguments span the index space.
1766 if (!llvm::all_of(op.body().getArgumentTypes(),
1767 [](Type ty) { return ty.isIndex(); }))
1768 return op.emitError("all body arguments must be index");
1769 if (op.body().getNumArguments() != resultTy.getRank())
1770 return op.emitError("must have one body argument per input dimension");
1771
1772 // Ensure that the region yields an element of the right type.
1773 auto yieldOp =
1774 llvm::cast<YieldOp>(op.body().getBlocks().front().getTerminator());
1775 if (yieldOp.value().getType() != resultTy.getElementType())
1776 return op.emitOpError(
1777 "body must be terminated with a `yield` operation of the tensor "
1778 "element type");
1779
1780 return success();
1781 }
1782
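/// Example use of the body-builder overload below (illustrative sketch;
/// `loc`, `size`, and `elemTy` are placeholders, not names from this file):
///
///   Value t = b.create<DynamicTensorFromElementsOp>(
///       loc, RankedTensorType::get({ShapedType::kDynamicSize}, elemTy),
///       /*dynamicExtents=*/ValueRange{size},
///       [&](OpBuilder &nested, Location nestedLoc, ValueRange ivs) {
///         Value elem = ...; // compute the element at index ivs[0]
///         nested.create<YieldOp>(nestedLoc, elem);
///       });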
1783 void DynamicTensorFromElementsOp::build(
1784 OpBuilder &b, OperationState &result, Type resultTy,
1785 ValueRange dynamicExtents,
1786 function_ref<void(OpBuilder &, Location, ValueRange)> bodyBuilder) {
1787 build(b, result, resultTy, dynamicExtents);
1788
1789 // Build and populate body.
1790 OpBuilder::InsertionGuard guard(b);
1791 Region *bodyRegion = result.regions.front().get();
1792 auto rank = resultTy.cast<RankedTensorType>().getRank();
1793 SmallVector<Type, 2> argumentTypes(rank, b.getIndexType());
1794 Block *bodyBlock =
1795 b.createBlock(bodyRegion, bodyRegion->end(), argumentTypes);
1796 bodyBuilder(b, result.location, bodyBlock->getArguments());
1797 }
1798
1799 namespace {
1800
1801 /// Canonicalizes dynamic_tensor_from_elements operations with a constant
1802 /// operand into the equivalent operation with the operand expressed in the
1803 /// result type, instead. We also insert a type cast to make sure that the
1804 /// resulting IR is still well-typed.
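///
/// Illustrative example (syntax approximate):
///   %c5 = constant 5 : index
///   %0 = dynamic_tensor_from_elements %arg0, %c5 { ... } : tensor<?x?xf32>
/// becomes
///   %0 = dynamic_tensor_from_elements %arg0 { ... } : tensor<?x5xf32>
///   %1 = tensor_cast %0 : tensor<?x5xf32> to tensor<?x?xf32>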
1805 struct StaticDynamicTensorFromElements
1806 : public OpRewritePattern<DynamicTensorFromElementsOp> {
1807 using OpRewritePattern<DynamicTensorFromElementsOp>::OpRewritePattern;
1808
1809 LogicalResult matchAndRewrite(DynamicTensorFromElementsOp tensorFromElements,
1810 PatternRewriter &rewriter) const final {
1811 auto resultType =
1812 tensorFromElements.getResult().getType().cast<RankedTensorType>();
1813
1814 if (resultType.hasStaticShape())
1815 return failure();
1816
1817 SmallVector<Value, 4> newOperands;
1818 SmallVector<int64_t, 4> newShape;
1819 auto operandsIt = tensorFromElements.dynamicExtents().begin();
1820
1821 for (int64_t dim : resultType.getShape()) {
1822 if (dim != RankedTensorType::kDynamicSize) {
1823 newShape.push_back(dim);
1824 continue;
1825 }
1826 APInt index;
1827 if (!matchPattern(*operandsIt, m_ConstantInt(&index))) {
1828 newShape.push_back(RankedTensorType::kDynamicSize);
1829 newOperands.push_back(*operandsIt++);
1830 continue;
1831 }
1832 newShape.push_back(index.getSExtValue());
1833 operandsIt++;
1834 }
1835
1836 if (newOperands.size() == tensorFromElements.dynamicExtents().size())
1837 return failure();
1838
1839 auto loc = tensorFromElements.getLoc();
1840 auto newOp = rewriter.create<DynamicTensorFromElementsOp>(
1841 loc, RankedTensorType::get(newShape, resultType.getElementType()),
1842 newOperands);
1843 rewriter.inlineRegionBefore(tensorFromElements.body(), newOp.body(),
1844 newOp.body().begin());
1845 rewriter.replaceOpWithNewOp<TensorCastOp>(tensorFromElements, resultType,
1846 newOp);
1847 return success();
1848 }
1849 };
1850
1851 /// Canonicalizes the pattern of the form
1852 ///
1853 /// %tensor = dynamic_tensor_from_elements %x {
1854 /// ^bb0(%arg0: index): // no predecessors
1855 /// <computation>
1856 /// yield %1 : index
1857 /// } : tensor<?xindex>
1858 /// %extracted_element = extract_element %tensor[%c0] : tensor<?xindex>
1859 ///
1860 /// to just <computation> with %arg0 replaced by %c0. We only do this if the
1861 /// dynamic_tensor_from_elements operation has no side-effects.
1862 struct ExtractElementFromDynamicTensorFromElements
1863 : public OpRewritePattern<ExtractElementOp> {
1864 using OpRewritePattern<ExtractElementOp>::OpRewritePattern;
1865
1866 LogicalResult matchAndRewrite(ExtractElementOp extract,
1867 PatternRewriter &rewriter) const final {
1868 auto tensorFromElements =
1869 extract.aggregate().getDefiningOp<DynamicTensorFromElementsOp>();
1870 if (!tensorFromElements || !wouldOpBeTriviallyDead(tensorFromElements))
1871 return failure();
1872
1873 BlockAndValueMapping mapping;
1874 Block *body = tensorFromElements.getBody();
1875 mapping.map(body->getArguments(), extract.indices());
1876 for (auto &op : body->without_terminator())
1877 rewriter.clone(op, mapping);
1878
1879 auto yield = cast<YieldOp>(body->getTerminator());
1880
1881 rewriter.replaceOp(extract, mapping.lookupOrDefault(yield.value()));
1882 return success();
1883 }
1884 };
1885
1886 /// Canonicalizes the pattern of the form
1887 ///
1888 /// %val = tensor_cast %source : tensor<?xi32> to tensor<2xi32>
1889 /// %extracted_element = extract_element %val[%c0] : tensor<2xi32>
1890 ///
1891 /// to
1892 ///
1893 /// %extracted_element = extract_element %source[%c0] : tensor<?xi32>
1894 struct ExtractElementFromTensorCast
1895 : public OpRewritePattern<ExtractElementOp> {
1896 using OpRewritePattern<ExtractElementOp>::OpRewritePattern;
1897
1898 LogicalResult matchAndRewrite(ExtractElementOp extract,
1899 PatternRewriter &rewriter) const final {
1900 auto tensorCast = extract.aggregate().getDefiningOp<TensorCastOp>();
1901 if (!tensorCast)
1902 return failure();
1903
1904 rewriter.replaceOpWithNewOp<ExtractElementOp>(extract, tensorCast.source(),
1905 extract.getIndices());
1906 return success();
1907 }
1908 };
1909
1910 } // namespace
1911
1912 void DynamicTensorFromElementsOp::getCanonicalizationPatterns(
1913 OwningRewritePatternList &results, MLIRContext *context) {
1914 results.insert<ExtractElementFromDynamicTensorFromElements,
1915 ExtractElementFromTensorCast, StaticDynamicTensorFromElements>(
1916 context);
1917 }
1918
1919 //===----------------------------------------------------------------------===//
1920 // ExtractElementOp
1921 //===----------------------------------------------------------------------===//
1922
1923 static LogicalResult verify(ExtractElementOp op) {
1924 // Verify the # indices match if we have a ranked type.
1925 auto aggregateType = op.getAggregate().getType().cast<ShapedType>();
1926 if (aggregateType.hasRank() &&
1927 aggregateType.getRank() != op.getNumOperands() - 1)
1928 return op.emitOpError("incorrect number of indices for extract_element");
1929
1930 return success();
1931 }
1932
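// Folds extraction from constant aggregates. Illustrative example (syntax
// approximate):
//   %t = constant dense<[1, 2, 3]> : tensor<3xi32>
//   %e = extract_element %t[%c1] : tensor<3xi32>   // folds to 2 : i32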
1933 OpFoldResult ExtractElementOp::fold(ArrayRef<Attribute> operands) {
1934 assert(!operands.empty() && "extract_element takes at least one operand");
1935
1936 // The aggregate operand must be a known constant.
1937 Attribute aggregate = operands.front();
1938 if (!aggregate)
1939 return {};
1940
1941 // If this is a splat elements attribute, simply return the value. All of the
1942 // elements of a splat attribute are the same.
1943 if (auto splatAggregate = aggregate.dyn_cast<SplatElementsAttr>())
1944 return splatAggregate.getSplatValue();
1945
1946 // Otherwise, collect the constant indices into the aggregate.
1947 SmallVector<uint64_t, 8> indices;
1948 for (Attribute indice : llvm::drop_begin(operands, 1)) {
1949 if (!indice || !indice.isa<IntegerAttr>())
1950 return {};
1951 indices.push_back(indice.cast<IntegerAttr>().getInt());
1952 }
1953
1954 // If this is an elements attribute, query the value at the given indices.
1955 auto elementsAttr = aggregate.dyn_cast<ElementsAttr>();
1956 if (elementsAttr && elementsAttr.isValidIndex(indices))
1957 return elementsAttr.getValue(indices);
1958 return {};
1959 }
1960
1961 //===----------------------------------------------------------------------===//
1962 // TensorFromElementsOp
1963 //===----------------------------------------------------------------------===//
1964
1965 void TensorFromElementsOp::build(OpBuilder &builder, OperationState &result,
1966 Type elementType, ValueRange elements) {
1967 Type resultTy = RankedTensorType::get({static_cast<int64_t>(elements.size())},
1968 elementType);
1969 result.addOperands(elements);
1970 result.addTypes(resultTy);
1971 }
1972
1973 void TensorFromElementsOp::build(OpBuilder &builder, OperationState &result,
1974 ValueRange elements) {
1975 assert(!elements.empty() && "expected at least one element");
1976 build(builder, result, elements.front().getType(), elements);
1977 }
1978
1979 namespace {
1980
1981 // Canonicalizes the pattern of the form
1982 //
1983 // %tensor = tensor_from_elements(%element) : (i32) -> tensor<1xi32>
1984 // %extracted_element = extract_element %tensor[%c0] : tensor<1xi32>
1985 //
1986 // to just %element.
1987 struct ExtractElementFromTensorFromElements
1988 : public OpRewritePattern<ExtractElementOp> {
1989 using OpRewritePattern<ExtractElementOp>::OpRewritePattern;
1990
1991 LogicalResult matchAndRewrite(ExtractElementOp extract,
1992 PatternRewriter &rewriter) const final {
1993 if (extract.indices().size() != 1)
1994 return failure();
1995
1996 auto tensorFromElements = dyn_cast_or_null<TensorFromElementsOp>(
1997 extract.aggregate().getDefiningOp());
1998 if (tensorFromElements == nullptr)
1999 return failure();
2000
2001 APInt index;
2002 if (!matchPattern(*extract.indices().begin(), m_ConstantInt(&index)))
2003 return failure();
2004 rewriter.replaceOp(extract,
2005 tensorFromElements.getOperand(index.getZExtValue()));
2006 return success();
2007 }
2008 };
2009
2010 } // namespace
2011
2012 void TensorFromElementsOp::getCanonicalizationPatterns(
2013 OwningRewritePatternList &results, MLIRContext *context) {
2014 results.insert<ExtractElementFromTensorFromElements>(context);
2015 }
2016
2017 //===----------------------------------------------------------------------===//
2018 // FPExtOp
2019 //===----------------------------------------------------------------------===//
2020
2021 bool FPExtOp::areCastCompatible(Type a, Type b) {
2022 if (auto fa = a.dyn_cast<FloatType>())
2023 if (auto fb = b.dyn_cast<FloatType>())
2024 return fa.getWidth() < fb.getWidth();
2025 return areVectorCastSimpleCompatible(a, b, areCastCompatible);
2026 }
2027
2028 //===----------------------------------------------------------------------===//
2029 // FPToSIOp
2030 //===----------------------------------------------------------------------===//
2031
2032 bool FPToSIOp::areCastCompatible(Type a, Type b) {
2033 if (a.isa<FloatType>() && b.isSignlessInteger())
2034 return true;
2035 return areVectorCastSimpleCompatible(a, b, areCastCompatible);
2036 }
2037
2038 //===----------------------------------------------------------------------===//
2039 // FPToUIOp
2040 //===----------------------------------------------------------------------===//
2041
2042 bool FPToUIOp::areCastCompatible(Type a, Type b) {
2043 if (a.isa<FloatType>() && b.isSignlessInteger())
2044 return true;
2045 return areVectorCastSimpleCompatible(a, b, areCastCompatible);
2046 }
2047
2048 //===----------------------------------------------------------------------===//
2049 // FPTruncOp
2050 //===----------------------------------------------------------------------===//
2051
2052 bool FPTruncOp::areCastCompatible(Type a, Type b) {
2053 if (auto fa = a.dyn_cast<FloatType>())
2054 if (auto fb = b.dyn_cast<FloatType>())
2055 return fa.getWidth() > fb.getWidth();
2056 return areVectorCastSimpleCompatible(a, b, areCastCompatible);
2057 }
2058
2059 //===----------------------------------------------------------------------===//
2060 // GlobalMemrefOp
2061 //===----------------------------------------------------------------------===//
2062
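// The custom directive below prints/parses the type-and-initializer tail of a
// global_memref. Illustrative forms (approximate):
//   memref<2xf32> = dense<[0.0, 1.0]>   // initialized definition
//   memref<2xf32> = uninitialized       // definition with undefined contents
//   memref<2xf32>                       // external declaration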
2063 static void printGlobalMemrefOpTypeAndInitialValue(OpAsmPrinter &p,
2064 GlobalMemrefOp op,
2065 TypeAttr type,
2066 Attribute initialValue) {
2067 p << type;
2068 if (!op.isExternal()) {
2069 p << " = ";
2070 if (op.isUninitialized())
2071 p << "uninitialized";
2072 else
2073 p.printAttributeWithoutType(initialValue);
2074 }
2075 }
2076
2077 static ParseResult
2078 parseGlobalMemrefOpTypeAndInitialValue(OpAsmParser &parser, TypeAttr &typeAttr,
2079 Attribute &initialValue) {
2080 Type type;
2081 if (parser.parseType(type))
2082 return failure();
2083
2084 auto memrefType = type.dyn_cast<MemRefType>();
2085 if (!memrefType || !memrefType.hasStaticShape())
2086 return parser.emitError(parser.getNameLoc())
2087 << "type should be static shaped memref, but got " << type;
2088 typeAttr = TypeAttr::get(type);
2089
2090 if (parser.parseOptionalEqual())
2091 return success();
2092
2093 if (succeeded(parser.parseOptionalKeyword("uninitialized"))) {
2094 initialValue = UnitAttr::get(parser.getBuilder().getContext());
2095 return success();
2096 }
2097
2098 Type tensorType = getTensorTypeFromMemRefType(memrefType);
2099 if (parser.parseAttribute(initialValue, tensorType))
2100 return failure();
2101 if (!initialValue.isa<ElementsAttr>())
2102 return parser.emitError(parser.getNameLoc())
2103 << "initial value should be a unit or elements attribute";
2104 return success();
2105 }
2106
2107 static LogicalResult verify(GlobalMemrefOp op) {
2108 auto memrefType = op.type().dyn_cast<MemRefType>();
2109 if (!memrefType || !memrefType.hasStaticShape())
2110 return op.emitOpError("type should be static shaped memref, but got ")
2111 << op.type();
2112
2113 // Verify that the initial value, if present, is either a unit attribute or
2114 // an elements attribute.
2115 if (op.initial_value().hasValue()) {
2116 Attribute initValue = op.initial_value().getValue();
2117 if (!initValue.isa<UnitAttr>() && !initValue.isa<ElementsAttr>())
2118 return op.emitOpError("initial value should be a unit or elements "
2119 "attribute, but got ")
2120 << initValue;
2121
2122 // Check that the type of the initial value is compatible with the type of
2123 // the global variable.
2124 if (initValue.isa<ElementsAttr>()) {
2125 Type initType = initValue.getType();
2126 Type tensorType = getTensorTypeFromMemRefType(memrefType);
2127 if (initType != tensorType)
2128 return op.emitOpError("initial value expected to be of type ")
2129 << tensorType << ", but was of type " << initType;
2130 }
2131 }
2132
2133 // TODO: verify visibility for declarations.
2134 return success();
2135 }
2136
2137 //===----------------------------------------------------------------------===//
2138 // GetGlobalMemrefOp
2139 //===----------------------------------------------------------------------===//
2140
2141 LogicalResult
2142 GetGlobalMemrefOp::verifySymbolUses(SymbolTableCollection &symbolTable) {
2143 // Verify that the result type is the same as the type of the referenced
2144 // global_memref op.
2145 auto global =
2146 symbolTable.lookupNearestSymbolFrom<GlobalMemrefOp>(*this, nameAttr());
2147 if (!global)
2148 return emitOpError("'")
2149 << name() << "' does not reference a valid global memref";
2150
2151 Type resultType = result().getType();
2152 if (global.type() != resultType)
2153 return emitOpError("result type ")
2154 << resultType << " does not match type " << global.type()
2155 << " of the global memref @" << name();
2156 return success();
2157 }
2158
2159 //===----------------------------------------------------------------------===//
2160 // IndexCastOp
2161 //===----------------------------------------------------------------------===//
2162
2163 // Index cast is applicable from index to integer and backwards.
2164 bool IndexCastOp::areCastCompatible(Type a, Type b) {
2165 if (a.isa<ShapedType>() && b.isa<ShapedType>()) {
2166 auto aShaped = a.cast<ShapedType>();
2167 auto bShaped = b.cast<ShapedType>();
2168
2169 return (aShaped.getShape() == bShaped.getShape()) &&
2170 areCastCompatible(aShaped.getElementType(),
2171 bShaped.getElementType());
2172 }
2173
2174 return (a.isIndex() && b.isSignlessInteger()) ||
2175 (a.isSignlessInteger() && b.isIndex());
2176 }
2177
2178 OpFoldResult IndexCastOp::fold(ArrayRef<Attribute> cstOperands) {
2179 // Fold IndexCast(IndexCast(x)) -> x
2180 auto cast = getOperand().getDefiningOp<IndexCastOp>();
2181 if (cast && cast.getOperand().getType() == getType())
2182 return cast.getOperand();
2183
2184 // Fold IndexCast(constant) -> constant
2185 // The attribute is rebuilt from its integer value because its type (and
2186 // hence its storage width) must change to match the result type.
2187 if (auto value = cstOperands[0].dyn_cast_or_null<IntegerAttr>())
2188 return IntegerAttr::get(getType(), value.getInt());
2189
2190 return {};
2191 }
2192
2193 //===----------------------------------------------------------------------===//
2194 // LoadOp
2195 //===----------------------------------------------------------------------===//
2196
2197 static LogicalResult verify(LoadOp op) {
2198 if (op.getNumOperands() != 1 + op.getMemRefType().getRank())
2199 return op.emitOpError("incorrect number of indices for load");
2200 return success();
2201 }
2202
2203 OpFoldResult LoadOp::fold(ArrayRef<Attribute> cstOperands) {
2204 /// load(memrefcast) -> load
2205 if (succeeded(foldMemRefCast(*this)))
2206 return getResult();
2207 return OpFoldResult();
2208 }
2209
2210 namespace {
2211 /// Fold a load on a tensor_to_memref operation into an extract_element on the
2212 /// corresponding tensor.
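///
/// Illustrative example (syntax approximate):
///   %m = tensor_to_memref %t : memref<?xf32>
///   %v = load %m[%i] : memref<?xf32>
/// becomes
///   %v = extract_element %t[%i] : tensor<?xf32>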
2213 struct LoadOfTensorToMemref : public OpRewritePattern<LoadOp> {
2214 using OpRewritePattern<LoadOp>::OpRewritePattern;
2215
2216 LogicalResult matchAndRewrite(LoadOp load,
2217 PatternRewriter &rewriter) const override {
2218 auto tensorToMemref = load.memref().getDefiningOp<TensorToMemrefOp>();
2219 if (!tensorToMemref)
2220 return failure();
2221
2222 rewriter.replaceOpWithNewOp<ExtractElementOp>(load, tensorToMemref.tensor(),
2223 load.indices());
2224 return success();
2225 }
2226 };
2227 } // end anonymous namespace.
2228
2229 void LoadOp::getCanonicalizationPatterns(OwningRewritePatternList &results,
2230 MLIRContext *context) {
2231 results.insert<LoadOfTensorToMemref>(context);
2232 }
2233
2234 //===----------------------------------------------------------------------===//
2235 // MemRefCastOp
2236 //===----------------------------------------------------------------------===//
2237
2238 Value MemRefCastOp::getViewSource() { return source(); }
2239
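// Illustrative examples of compatible casts (see the checks below):
//   memref<4xf32> -> memref<?xf32>   // static size to dynamic size
//   memref<4xf32> -> memref<*xf32>   // ranked to unranked
//   memref<*xf32> -> memref<4xf32>   // unranked to ranked
// Unranked-to-unranked casts, element type mismatches, and memory space
// mismatches are rejected.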
2240 bool MemRefCastOp::areCastCompatible(Type a, Type b) {
2241 auto aT = a.dyn_cast<MemRefType>();
2242 auto bT = b.dyn_cast<MemRefType>();
2243
2244 auto uaT = a.dyn_cast<UnrankedMemRefType>();
2245 auto ubT = b.dyn_cast<UnrankedMemRefType>();
2246
2247 if (aT && bT) {
2248 if (aT.getElementType() != bT.getElementType())
2249 return false;
2250 if (aT.getAffineMaps() != bT.getAffineMaps()) {
2251 int64_t aOffset, bOffset;
2252 SmallVector<int64_t, 4> aStrides, bStrides;
2253 if (failed(getStridesAndOffset(aT, aStrides, aOffset)) ||
2254 failed(getStridesAndOffset(bT, bStrides, bOffset)) ||
2255 aStrides.size() != bStrides.size())
2256 return false;
2257
2258 // Strides along a dimension/offset are compatible if the value in the
2259 // source memref is static and the value in the target memref is the
2260 // same. They are also compatible if either one is dynamic (see
2261 // description of MemRefCastOp for details).
2262 auto checkCompatible = [](int64_t a, int64_t b) {
2263 return (a == MemRefType::getDynamicStrideOrOffset() ||
2264 b == MemRefType::getDynamicStrideOrOffset() || a == b);
2265 };
2266 if (!checkCompatible(aOffset, bOffset))
2267 return false;
2268 for (auto aStride : enumerate(aStrides))
2269 if (!checkCompatible(aStride.value(), bStrides[aStride.index()]))
2270 return false;
2271 }
2272 if (aT.getMemorySpace() != bT.getMemorySpace())
2273 return false;
2274
2275 // They must have the same rank, and any specified dimensions must match.
2276 if (aT.getRank() != bT.getRank())
2277 return false;
2278
2279 for (unsigned i = 0, e = aT.getRank(); i != e; ++i) {
2280 int64_t aDim = aT.getDimSize(i), bDim = bT.getDimSize(i);
2281 if (aDim != -1 && bDim != -1 && aDim != bDim)
2282 return false;
2283 }
2284 return true;
2285 } else {
2286 if (!aT && !uaT)
2287 return false;
2288 if (!bT && !ubT)
2289 return false;
2290 // Unranked to unranked casting is unsupported
2291 if (uaT && ubT)
2292 return false;
2293
2294 auto aEltType = (aT) ? aT.getElementType() : uaT.getElementType();
2295 auto bEltType = (bT) ? bT.getElementType() : ubT.getElementType();
2296 if (aEltType != bEltType)
2297 return false;
2298
2299 auto aMemSpace = (aT) ? aT.getMemorySpace() : uaT.getMemorySpace();
2300 auto bMemSpace = (bT) ? bT.getMemorySpace() : ubT.getMemorySpace();
2301 if (aMemSpace != bMemSpace)
2302 return false;
2303
2304 return true;
2305 }
2306
2307 return false;
2308 }
2309
2310 OpFoldResult MemRefCastOp::fold(ArrayRef<Attribute> operands) {
2311 if (Value folded = impl::foldCastOp(*this))
2312 return folded;
2313 return succeeded(foldMemRefCast(*this)) ? getResult() : Value();
2314 }
2315
2316 //===----------------------------------------------------------------------===//
2317 // MemRefReinterpretCastOp
2318 //===----------------------------------------------------------------------===//
2319
2320 void mlir::MemRefReinterpretCastOp::build(
2321 OpBuilder &b, OperationState &result, MemRefType resultType, Value source,
2322 int64_t staticOffset, ArrayRef<int64_t> staticSizes,
2323 ArrayRef<int64_t> staticStrides, ValueRange offset, ValueRange sizes,
2324 ValueRange strides, ArrayRef<NamedAttribute> attrs) {
2325 build(b, result, resultType, source, offset, sizes, strides,
2326 b.getI64ArrayAttr(staticOffset), b.getI64ArrayAttr(staticSizes),
2327 b.getI64ArrayAttr(staticStrides));
2328 result.addAttributes(attrs);
2329 }
2330
2331 /// Build a MemRefReinterpretCastOp with all dynamic entries: `staticOffsets`,
2332 /// `staticSizes` and `staticStrides` are automatically filled with
2333 /// source-memref-rank sentinel values that encode dynamic entries.
2334 void mlir::MemRefReinterpretCastOp::build(OpBuilder &b, OperationState &result,
2335 MemRefType resultType, Value source,
2336 Value offset, ValueRange sizes,
2337 ValueRange strides,
2338 ArrayRef<NamedAttribute> attrs) {
2339 unsigned rank = resultType.getRank();
2340 SmallVector<int64_t, 4> staticSizesVector(rank, ShapedType::kDynamicSize);
2341 SmallVector<int64_t, 4> staticStridesVector(
2342 rank, ShapedType::kDynamicStrideOrOffset);
2343 build(b, result, resultType, source,
2344 /*staticOffset=*/ShapedType::kDynamicStrideOrOffset, staticSizesVector,
2345 staticStridesVector, offset, sizes, strides, attrs);
2346 }
2347
2348 /// Print a memref_reinterpret_cast op of the form:
2349 /// ```
2350 /// `memref_reinterpret_cast` ssa-name to
2351 /// offset: `[` offset `]`
2352 /// sizes: `[` size-list `]`
2353 /// strides:`[` stride-list `]`
2354 /// `:` any-memref-type to strided-memref-type
2355 /// ```
2356 static void print(OpAsmPrinter &p, MemRefReinterpretCastOp op) {
2357 int stdDotLen = StandardOpsDialect::getDialectNamespace().size() + 1;
2358 p << op->getName().getStringRef().drop_front(stdDotLen) << ' ';
2359 p << op.source() << " ";
2360 printOffsetsSizesAndStrides(
2361 p, op, /*offsetPrefix=*/"to offset: ", /*sizePrefix=*/", sizes: ",
2362 /*stridePrefix=*/", strides: ");
2363 p << ": " << op.source().getType() << " to " << op.getType();
2364 }
2365
2366 /// Parse a memref_reinterpret_cast op of the form:
2367 /// ```
2368 /// `memref_reinterpret_cast` ssa-name to
2369 /// offset: `[` offset `]`
2370 /// sizes: `[` size-list `]`
2371 /// strides:`[` stride-list `]`
2372 /// `:` any-memref-type to strided-memref-type
2373 /// ```
2374 static ParseResult parseMemRefReinterpretCastOp(OpAsmParser &parser,
2375 OperationState &result) {
2376 // Parse `operand`
2377 OpAsmParser::OperandType srcInfo;
2378 if (parser.parseOperand(srcInfo))
2379 return failure();
2380
2381 auto parseOffsetPrefix = [](OpAsmParser &parser) {
2382 return failure(parser.parseKeyword("to") || parser.parseKeyword("offset") ||
2383 parser.parseColon());
2384 };
2385 auto parseSizePrefix = [](OpAsmParser &parser) {
2386 return failure(parser.parseComma() || parser.parseKeyword("sizes") ||
2387 parser.parseColon());
2388 };
2389 auto parseStridePrefix = [](OpAsmParser &parser) {
2390 return failure(parser.parseComma() || parser.parseKeyword("strides") ||
2391 parser.parseColon());
2392 };
2393
2394 Type srcType, dstType;
2395 auto preResolutionFn = [&](OpAsmParser &parser, OperationState &result) {
2396 return failure(parser.parseOptionalAttrDict(result.attributes) ||
2397 parser.parseColonType(srcType) ||
2398 parser.parseKeywordType("to", dstType) ||
2399 parser.resolveOperand(srcInfo, srcType, result.operands));
2400 };
2401 if (failed(parseOffsetsSizesAndStrides(parser, result,
2402 /*segmentSizes=*/{1}, // source memref
2403 preResolutionFn, parseOffsetPrefix,
2404 parseSizePrefix, parseStridePrefix)))
2405 return failure();
2406 return parser.addTypeToList(dstType, result.types);
2407 }
2408
2409 static LogicalResult verify(MemRefReinterpretCastOp op) {
2410 // The source and result memrefs should be in the same memory space.
2411 auto srcType = op.source().getType().cast<BaseMemRefType>();
2412 auto resultType = op.getType().cast<MemRefType>();
2413 if (srcType.getMemorySpace() != resultType.getMemorySpace())
2414 return op.emitError("different memory spaces specified for source type ")
2415 << srcType << " and result memref type " << resultType;
2416 if (srcType.getElementType() != resultType.getElementType())
2417 return op.emitError("different element types specified for source type ")
2418 << srcType << " and result memref type " << resultType;
2419
2420 // Match sizes in result memref type and in static_sizes attribute.
2421 for (auto &en :
2422 llvm::enumerate(llvm::zip(resultType.getShape(),
2423 extractFromI64ArrayAttr(op.static_sizes())))) {
2424 int64_t resultSize = std::get<0>(en.value());
2425 int64_t expectedSize = std::get<1>(en.value());
2426 if (resultSize != expectedSize)
2427 return op.emitError("expected result type with size = ")
2428 << expectedSize << " instead of " << resultSize
2429 << " in dim = " << en.index();
2430 }
2431
2432 // Match offset and strides in static_offset and static_strides attributes if
2433 // result memref type has an affine map specified.
2434 if (!resultType.getAffineMaps().empty()) {
2435 int64_t resultOffset;
2436 SmallVector<int64_t, 4> resultStrides;
2437 if (failed(getStridesAndOffset(resultType, resultStrides, resultOffset)))
2438 return failure();
2439
2440 // Match offset in result memref type and in static_offsets attribute.
2441 int64_t expectedOffset =
2442 extractFromI64ArrayAttr(op.static_offsets()).front();
2443 if (resultOffset != expectedOffset)
2444 return op.emitError("expected result type with offset = ")
2445 << resultOffset << " instead of " << expectedOffset;
2446
2447 // Match strides in result memref type and in static_strides attribute.
2448 for (auto &en : llvm::enumerate(llvm::zip(
2449 resultStrides, extractFromI64ArrayAttr(op.static_strides())))) {
2450 int64_t resultStride = std::get<0>(en.value());
2451 int64_t expectedStride = std::get<1>(en.value());
2452 if (resultStride != expectedStride)
2453 return op.emitError("expected result type with stride = ")
2454 << expectedStride << " instead of " << resultStride
2455 << " in dim = " << en.index();
2456 }
2457 }
2458 return success();
2459 }
2460
2461 //===----------------------------------------------------------------------===//
2462 // MemRefReshapeOp
2463 //===----------------------------------------------------------------------===//
2464
2465 static LogicalResult verify(MemRefReshapeOp op) {
2466 Type operandType = op.source().getType();
2467 Type resultType = op.result().getType();
2468
2469 Type operandElementType = operandType.cast<ShapedType>().getElementType();
2470 Type resultElementType = resultType.cast<ShapedType>().getElementType();
2471 if (operandElementType != resultElementType)
2472 return op.emitOpError("element types of source and destination memref "
2473 "types should be the same");
2474
2475 if (auto operandMemRefType = operandType.dyn_cast<MemRefType>())
2476 if (!operandMemRefType.getAffineMaps().empty())
2477 return op.emitOpError(
2478 "source memref type should have identity affine map");
2479
2480 int64_t shapeSize = op.shape().getType().cast<MemRefType>().getDimSize(0);
2481 auto resultMemRefType = resultType.dyn_cast<MemRefType>();
2482 if (resultMemRefType) {
2483 if (!resultMemRefType.getAffineMaps().empty())
2484 return op.emitOpError(
2485 "result memref type should have identity affine map");
2486 if (shapeSize == ShapedType::kDynamicSize)
2487 return op.emitOpError("cannot use shape operand with dynamic length to "
2488 "reshape to statically-ranked memref type");
2489 if (shapeSize != resultMemRefType.getRank())
2490 return op.emitOpError(
2491 "length of shape operand differs from the result's memref rank");
2492 }
2493 return success();
2494 }
2495
2496 //===----------------------------------------------------------------------===//
2497 // MulFOp
2498 //===----------------------------------------------------------------------===//
2499
2500 OpFoldResult MulFOp::fold(ArrayRef<Attribute> operands) {
2501 return constFoldBinaryOp<FloatAttr>(
2502 operands, [](APFloat a, APFloat b) { return a * b; });
2503 }
2504
2505 //===----------------------------------------------------------------------===//
2506 // MulIOp
2507 //===----------------------------------------------------------------------===//
2508
2509 OpFoldResult MulIOp::fold(ArrayRef<Attribute> operands) {
2510 /// muli(x, 0) -> 0
2511 if (matchPattern(rhs(), m_Zero()))
2512 return rhs();
2513 /// muli(x, 1) -> x
2514 if (matchPattern(rhs(), m_One()))
2515 return getOperand(0);
2516
2517 // TODO: Handle the overflow case.
2518 return constFoldBinaryOp<IntegerAttr>(operands,
2519 [](APInt a, APInt b) { return a * b; });
2520 }
2521
2522 //===----------------------------------------------------------------------===//
2523 // OrOp
2524 //===----------------------------------------------------------------------===//
2525
2526 OpFoldResult OrOp::fold(ArrayRef<Attribute> operands) {
2527 /// or(x, 0) -> x
2528 if (matchPattern(rhs(), m_Zero()))
2529 return lhs();
2530 /// or(x,x) -> x
2531 if (lhs() == rhs())
2532 return rhs();
2533
2534 return constFoldBinaryOp<IntegerAttr>(operands,
2535 [](APInt a, APInt b) { return a | b; });
2536 }
2537
2538 //===----------------------------------------------------------------------===//
2539 // PrefetchOp
2540 //===----------------------------------------------------------------------===//
2541
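// Illustrative textual form handled by the printer/parser below (approximate):
//   prefetch %0[%i, %j], read, locality<3>, data : memref<400x400xf32>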
2542 static void print(OpAsmPrinter &p, PrefetchOp op) {
2543 p << PrefetchOp::getOperationName() << " " << op.memref() << '[';
2544 p.printOperands(op.indices());
2545 p << ']' << ", " << (op.isWrite() ? "write" : "read");
2546 p << ", locality<" << op.localityHint();
2547 p << ">, " << (op.isDataCache() ? "data" : "instr");
2548 p.printOptionalAttrDict(
2549 op.getAttrs(),
2550 /*elidedAttrs=*/{"localityHint", "isWrite", "isDataCache"});
2551 p << " : " << op.getMemRefType();
2552 }
2553
2554 static ParseResult parsePrefetchOp(OpAsmParser &parser,
2555 OperationState &result) {
2556 OpAsmParser::OperandType memrefInfo;
2557 SmallVector<OpAsmParser::OperandType, 4> indexInfo;
2558 IntegerAttr localityHint;
2559 MemRefType type;
2560 StringRef readOrWrite, cacheType;
2561
2562 auto indexTy = parser.getBuilder().getIndexType();
2563 auto i32Type = parser.getBuilder().getIntegerType(32);
2564 if (parser.parseOperand(memrefInfo) ||
2565 parser.parseOperandList(indexInfo, OpAsmParser::Delimiter::Square) ||
2566 parser.parseComma() || parser.parseKeyword(&readOrWrite) ||
2567 parser.parseComma() || parser.parseKeyword("locality") ||
2568 parser.parseLess() ||
2569 parser.parseAttribute(localityHint, i32Type, "localityHint",
2570 result.attributes) ||
2571 parser.parseGreater() || parser.parseComma() ||
2572 parser.parseKeyword(&cacheType) || parser.parseColonType(type) ||
2573 parser.resolveOperand(memrefInfo, type, result.operands) ||
2574 parser.resolveOperands(indexInfo, indexTy, result.operands))
2575 return failure();
2576
2577 if (!readOrWrite.equals("read") && !readOrWrite.equals("write"))
2578 return parser.emitError(parser.getNameLoc(),
2579 "rw specifier has to be 'read' or 'write'");
2580 result.addAttribute(
2581 PrefetchOp::getIsWriteAttrName(),
2582 parser.getBuilder().getBoolAttr(readOrWrite.equals("write")));
2583
2584 if (!cacheType.equals("data") && !cacheType.equals("instr"))
2585 return parser.emitError(parser.getNameLoc(),
2586 "cache type has to be 'data' or 'instr'");
2587
2588 result.addAttribute(
2589 PrefetchOp::getIsDataCacheAttrName(),
2590 parser.getBuilder().getBoolAttr(cacheType.equals("data")));
2591
2592 return success();
2593 }
2594
2595 static LogicalResult verify(PrefetchOp op) {
2596 if (op.getNumOperands() != 1 + op.getMemRefType().getRank())
2597 return op.emitOpError("too few indices");
2598
2599 return success();
2600 }
2601
2602 LogicalResult PrefetchOp::fold(ArrayRef<Attribute> cstOperands,
2603 SmallVectorImpl<OpFoldResult> &results) {
2604 // prefetch(memrefcast) -> prefetch
2605 return foldMemRefCast(*this);
2606 }
2607
2608 //===----------------------------------------------------------------------===//
2609 // RankOp
2610 //===----------------------------------------------------------------------===//
2611
2612 OpFoldResult RankOp::fold(ArrayRef<Attribute> operands) {
2613 // Constant fold rank when the rank of the operand is known.
2614 auto type = getOperand().getType();
2615 if (auto shapedType = type.dyn_cast<ShapedType>())
2616 if (shapedType.hasRank())
2617 return IntegerAttr::get(IndexType::get(getContext()),
2618 shapedType.getRank());
2619 return IntegerAttr();
2620 }
2621
2622 //===----------------------------------------------------------------------===//
2623 // ReturnOp
2624 //===----------------------------------------------------------------------===//
2625
2626 static LogicalResult verify(ReturnOp op) {
2627 auto function = cast<FuncOp>(op->getParentOp());
2628
2629 // The operand number and types must match the function signature.
2630 const auto &results = function.getType().getResults();
2631 if (op.getNumOperands() != results.size())
2632 return op.emitOpError("has ")
2633 << op.getNumOperands() << " operands, but enclosing function (@"
2634 << function.getName() << ") returns " << results.size();
2635
2636 for (unsigned i = 0, e = results.size(); i != e; ++i)
2637 if (op.getOperand(i).getType() != results[i])
2638 return op.emitError()
2639 << "type of return operand " << i << " ("
2640 << op.getOperand(i).getType()
2641 << ") doesn't match function result type (" << results[i] << ")"
2642 << " in function @" << function.getName();
2643
2644 return success();
2645 }
2646
2647 //===----------------------------------------------------------------------===//
2648 // SelectOp
2649 //===----------------------------------------------------------------------===//
2650
2651 OpFoldResult SelectOp::fold(ArrayRef<Attribute> operands) {
2652 auto condition = getCondition();
2653
2654 // select true, %0, %1 => %0
2655 if (matchPattern(condition, m_One()))
2656 return getTrueValue();
2657
2658 // select false, %0, %1 => %1
2659 if (matchPattern(condition, m_Zero()))
2660 return getFalseValue();
2661 return nullptr;
2662 }
2663
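// Illustrative textual forms handled by the printer/parser below (approximate):
//   %r = select %cond, %a, %b : i32
//   %v = select %mask, %x, %y : vector<4xi1>, vector<4xf32>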
2664 static void print(OpAsmPrinter &p, SelectOp op) {
2665 p << "select " << op.getOperands();
2666 p.printOptionalAttrDict(op.getAttrs());
2667 p << " : ";
2668 if (ShapedType condType = op.getCondition().getType().dyn_cast<ShapedType>())
2669 p << condType << ", ";
2670 p << op.getType();
2671 }
2672
2673 static ParseResult parseSelectOp(OpAsmParser &parser, OperationState &result) {
2674 Type conditionType, resultType;
2675 SmallVector<OpAsmParser::OperandType, 3> operands;
2676 if (parser.parseOperandList(operands, /*requiredOperandCount=*/3) ||
2677 parser.parseOptionalAttrDict(result.attributes) ||
2678 parser.parseColonType(resultType))
2679 return failure();
2680
2681 // Check for the explicit condition type if this is a masked tensor or vector.
2682 if (succeeded(parser.parseOptionalComma())) {
2683 conditionType = resultType;
2684 if (parser.parseType(resultType))
2685 return failure();
2686 } else {
2687 conditionType = parser.getBuilder().getI1Type();
2688 }
2689
2690 result.addTypes(resultType);
2691 return parser.resolveOperands(operands,
2692 {conditionType, resultType, resultType},
2693 parser.getNameLoc(), result.operands);
2694 }
2695
2696 static LogicalResult verify(SelectOp op) {
2697 Type conditionType = op.getCondition().getType();
2698 if (conditionType.isSignlessInteger(1))
2699 return success();
2700
2701 // If the result type is a vector or tensor, the condition may be a mask,
2702 // i.e. an i1 vector or tensor of the same shape as the result.
2703 Type resultType = op.getType();
2704 if (!resultType.isa<TensorType, VectorType>())
2705 return op.emitOpError()
2706 << "expected condition to be a signless i1, but got "
2707 << conditionType;
2708 Type shapedConditionType = getI1SameShape(resultType);
2709 if (conditionType != shapedConditionType)
2710 return op.emitOpError()
2711 << "expected condition type to have the same shape "
2712 "as the result type, expected "
2713 << shapedConditionType << ", but got " << conditionType;
2714 return success();
2715 }
2716
2717 //===----------------------------------------------------------------------===//
2718 // SignExtendIOp
2719 //===----------------------------------------------------------------------===//
2720
2721 static LogicalResult verify(SignExtendIOp op) {
2722 // Get the scalar type (which is either directly the type of the operand
2723 // or the vector's/tensor's element type).
2724 auto srcType = getElementTypeOrSelf(op.getOperand().getType());
2725 auto dstType = getElementTypeOrSelf(op.getType());
2726
2727 // For now, index is forbidden for the source and the destination type.
2728 if (srcType.isa<IndexType>())
2729 return op.emitError() << srcType << " is not a valid operand type";
2730 if (dstType.isa<IndexType>())
2731 return op.emitError() << dstType << " is not a valid result type";
2732
2733 if (srcType.cast<IntegerType>().getWidth() >=
2734 dstType.cast<IntegerType>().getWidth())
2735 return op.emitError("result type ")
2736 << dstType << " must be wider than operand type " << srcType;
2737
2738 return success();
2739 }
2740
2741 //===----------------------------------------------------------------------===//
2742 // SignedDivIOp
2743 //===----------------------------------------------------------------------===//
2744
2745 OpFoldResult SignedDivIOp::fold(ArrayRef<Attribute> operands) {
2746 assert(operands.size() == 2 && "binary operation takes two operands");
2747
2748 // Don't fold if it would overflow or if it requires a division by zero.
2749 bool overflowOrDiv0 = false;
2750 auto result = constFoldBinaryOp<IntegerAttr>(operands, [&](APInt a, APInt b) {
2751 if (overflowOrDiv0 || !b) {
2752 overflowOrDiv0 = true;
2753 return a;
2754 }
2755 return a.sdiv_ov(b, overflowOrDiv0);
2756 });
2757
2758 // Fold out division by one. Assumes all tensors of all ones are splats.
2759 if (auto rhs = operands[1].dyn_cast_or_null<IntegerAttr>()) {
2760 if (rhs.getValue() == 1)
2761 return lhs();
2762 } else if (auto rhs = operands[1].dyn_cast_or_null<SplatElementsAttr>()) {
2763 if (rhs.getSplatValue<IntegerAttr>().getValue() == 1)
2764 return lhs();
2765 }
2766
2767 return overflowOrDiv0 ? Attribute() : result;
2768 }
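// Illustrative behavior (constants only, a sketch): 7 / -2 folds to -3 since
// APInt::sdiv_ov truncates toward zero, while INT_MIN / -1 (overflow) and any
// division by zero set `overflowOrDiv0` and are deliberately left unfolded.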
2769
2770 //===----------------------------------------------------------------------===//
2771 // SignedFloorDivIOp
2772 //===----------------------------------------------------------------------===//
2773
2774 static APInt signedCeilNonnegInputs(APInt a, APInt b, bool &overflow) {
2775 // Returns (a-1)/b + 1
2776 APInt one(a.getBitWidth(), 1, true); // Signed value 1.
2777 APInt val = a.ssub_ov(one, overflow).sdiv_ov(b, overflow);
2778 return val.sadd_ov(one, overflow);
2779 }
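// Worked example for the positive inputs used by the callers: for a = 7 and
// b = 2 this computes (7 - 1) / 2 + 1 = 4 = ceil(7 / 2); for a = 6 and b = 2
// it computes (6 - 1) / 2 + 1 = 3, so exact divisions stay exact.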
2780
2781 OpFoldResult SignedFloorDivIOp::fold(ArrayRef<Attribute> operands) {
2782 assert(operands.size() == 2 && "binary operation takes two operands");
2783
2784 // Don't fold if it would overflow or if it requires a division by zero.
2785 bool overflowOrDiv0 = false;
2786 auto result = constFoldBinaryOp<IntegerAttr>(operands, [&](APInt a, APInt b) {
2787 if (overflowOrDiv0 || !b) {
2788 overflowOrDiv0 = true;
2789 return a;
2790 }
2791 unsigned bits = a.getBitWidth();
2792 APInt zero = APInt::getNullValue(bits);
2793 if (a.sge(zero) && b.sgt(zero)) {
2794 // Both positive (or a is zero), return a / b.
2795 return a.sdiv_ov(b, overflowOrDiv0);
2796 } else if (a.sle(zero) && b.slt(zero)) {
2797 // Both negative (or a is zero), return -a / -b.
2798 APInt posA = zero.ssub_ov(a, overflowOrDiv0);
2799 APInt posB = zero.ssub_ov(b, overflowOrDiv0);
2800 return posA.sdiv_ov(posB, overflowOrDiv0);
2801 } else if (a.slt(zero) && b.sgt(zero)) {
2802 // A is negative, b is positive, return - ceil(-a, b).
2803 APInt posA = zero.ssub_ov(a, overflowOrDiv0);
2804 APInt ceil = signedCeilNonnegInputs(posA, b, overflowOrDiv0);
2805 return zero.ssub_ov(ceil, overflowOrDiv0);
2806 } else {
2807 // A is positive, b is negative, return - ceil(a, -b).
2808 APInt posB = zero.ssub_ov(b, overflowOrDiv0);
2809 APInt ceil = signedCeilNonnegInputs(a, posB, overflowOrDiv0);
2810 return zero.ssub_ov(ceil, overflowOrDiv0);
2811 }
2812 });
2813
2814 // Fold out floor division by one. Assumes all tensors of all ones are
2815 // splats.
2816 if (auto rhs = operands[1].dyn_cast_or_null<IntegerAttr>()) {
2817 if (rhs.getValue() == 1)
2818 return lhs();
2819 } else if (auto rhs = operands[1].dyn_cast_or_null<SplatElementsAttr>()) {
2820 if (rhs.getSplatValue<IntegerAttr>().getValue() == 1)
2821 return lhs();
2822 }
2823
2824 return overflowOrDiv0 ? Attribute() : result;
2825 }
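// Worked example of the sign handling above (constants only, a sketch):
// floordiv(-7, 2) takes the "a negative, b positive" branch and yields
// -ceil(7, 2) = -4, matching floor(-3.5), whereas floordiv(7, 2) yields 3.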
2826
2827 //===----------------------------------------------------------------------===//
2828 // SignedCeilDivIOp
2829 //===----------------------------------------------------------------------===//
2830
2831 OpFoldResult SignedCeilDivIOp::fold(ArrayRef<Attribute> operands) {
2832 assert(operands.size() == 2 && "binary operation takes two operands");
2833
2834 // Don't fold if it would overflow or if it requires a division by zero.
2835 bool overflowOrDiv0 = false;
2836 auto result = constFoldBinaryOp<IntegerAttr>(operands, [&](APInt a, APInt b) {
2837 if (overflowOrDiv0 || !b) {
2838 overflowOrDiv0 = true;
2839 return a;
2840 }
2841 unsigned bits = a.getBitWidth();
2842 APInt zero = APInt::getNullValue(bits);
2843 if (a.sgt(zero) && b.sgt(zero)) {
2844 // Both positive, return ceil(a, b).
2845 return signedCeilNonnegInputs(a, b, overflowOrDiv0);
2846 } else if (a.slt(zero) && b.slt(zero)) {
2847 // Both negative, return ceil(-a, -b).
2848 APInt posA = zero.ssub_ov(a, overflowOrDiv0);
2849 APInt posB = zero.ssub_ov(b, overflowOrDiv0);
2850 return signedCeilNonnegInputs(posA, posB, overflowOrDiv0);
2851 } else if (a.slt(zero) && b.sgt(zero)) {
2852 // A is negative, b is positive, return - ( -a / b).
2853 APInt posA = zero.ssub_ov(a, overflowOrDiv0);
2854 APInt div = posA.sdiv_ov(b, overflowOrDiv0);
2855 return zero.ssub_ov(div, overflowOrDiv0);
2856 } else {
2857 // A is positive (or zero), b is negative, return - (a / -b).
2858 APInt posB = zero.ssub_ov(b, overflowOrDiv0);
2859 APInt div = a.sdiv_ov(posB, overflowOrDiv0);
2860 return zero.ssub_ov(div, overflowOrDiv0);
2861 }
2862 });
2863
2864 // Fold out ceil division by one. Assumes all tensors of all ones are
2865 // splats.
2866 if (auto rhs = operands[1].dyn_cast_or_null<IntegerAttr>()) {
2867 if (rhs.getValue() == 1)
2868 return lhs();
2869 } else if (auto rhs = operands[1].dyn_cast_or_null<SplatElementsAttr>()) {
2870 if (rhs.getSplatValue<IntegerAttr>().getValue() == 1)
2871 return lhs();
2872 }
2873
2874 return overflowOrDiv0 ? Attribute() : result;
2875 }
2876
2877 //===----------------------------------------------------------------------===//
2878 // SignedRemIOp
2879 //===----------------------------------------------------------------------===//
2880
2881 OpFoldResult SignedRemIOp::fold(ArrayRef<Attribute> operands) {
2882 assert(operands.size() == 2 && "remi_signed takes two operands");
2883
2884 auto rhs = operands.back().dyn_cast_or_null<IntegerAttr>();
2885 if (!rhs)
2886 return {};
2887 auto rhsValue = rhs.getValue();
2888
2889 // x % 1 = 0
2890 if (rhsValue.isOneValue())
2891 return IntegerAttr::get(rhs.getType(), APInt(rhsValue.getBitWidth(), 0));
2892
2893 // Don't fold if it requires division by zero.
2894 if (rhsValue.isNullValue())
2895 return {};
2896
2897 auto lhs = operands.front().dyn_cast_or_null<IntegerAttr>();
2898 if (!lhs)
2899 return {};
2900 return IntegerAttr::get(lhs.getType(), lhs.getValue().srem(rhsValue));
2901 }
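// Illustrative folds (constants only): remi_signed of 7 and -2 yields 1 and
// remi_signed of -7 and 2 yields -1; like APInt::srem, the result takes the
// sign of the dividend.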
2902
2903 //===----------------------------------------------------------------------===//
2904 // SIToFPOp
2905 //===----------------------------------------------------------------------===//
2906
2907 // sitofp is applicable from integer types to float types.
2908 bool SIToFPOp::areCastCompatible(Type a, Type b) {
2909 if (a.isSignlessInteger() && b.isa<FloatType>())
2910 return true;
2911 return areVectorCastSimpleCompatible(a, b, areCastCompatible);
2912 }
2913
2914 //===----------------------------------------------------------------------===//
2915 // SplatOp
2916 //===----------------------------------------------------------------------===//
2917
2918 static LogicalResult verify(SplatOp op) {
2919 // TODO: we could replace this by a trait.
2920 if (op.getOperand().getType() !=
2921 op.getType().cast<ShapedType>().getElementType())
2922 return op.emitError("operand should be of elemental type of result type");
2923
2924 return success();
2925 }
2926
2927 // Constant folding hook for SplatOp.
2928 OpFoldResult SplatOp::fold(ArrayRef<Attribute> operands) {
2929 assert(operands.size() == 1 && "splat takes one operand");
2930
2931 auto constOperand = operands.front();
2932 if (!constOperand || !constOperand.isa<IntegerAttr, FloatAttr>())
2933 return {};
2934
2935 auto shapedType = getType().cast<ShapedType>();
2936 assert(shapedType.getElementType() == constOperand.getType() &&
2937 "incorrect input attribute type for folding");
2938
2939 // SplatElementsAttr::get treats a single value for the second arg as a splat.
2940 return SplatElementsAttr::get(shapedType, {constOperand});
2941 }
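// For example (a sketch, assuming a constant operand): splatting the constant
// 1.0 : f32 into a vector<4xf32> result folds to a splat elements attribute,
// printed roughly as dense<1.0> : vector<4xf32>.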
2942
2943 //===----------------------------------------------------------------------===//
2944 // StoreOp
2945 //===----------------------------------------------------------------------===//
2946
2947 static LogicalResult verify(StoreOp op) {
2948 if (op.getNumOperands() != 2 + op.getMemRefType().getRank())
2949 return op.emitOpError("store index operand count not equal to memref rank");
2950
2951 return success();
2952 }
2953
2954 LogicalResult StoreOp::fold(ArrayRef<Attribute> cstOperands,
2955 SmallVectorImpl<OpFoldResult> &results) {
2956 /// store(memrefcast) -> store
2957 return foldMemRefCast(*this);
2958 }
2959
2960 //===----------------------------------------------------------------------===//
2961 // SubFOp
2962 //===----------------------------------------------------------------------===//
2963
2964 OpFoldResult SubFOp::fold(ArrayRef<Attribute> operands) {
2965 return constFoldBinaryOp<FloatAttr>(
2966 operands, [](APFloat a, APFloat b) { return a - b; });
2967 }
2968
2969 //===----------------------------------------------------------------------===//
2970 // SubIOp
2971 //===----------------------------------------------------------------------===//
2972
2973 OpFoldResult SubIOp::fold(ArrayRef<Attribute> operands) {
2974 // subi(x,x) -> 0
2975 if (getOperand(0) == getOperand(1))
2976 return Builder(getContext()).getZeroAttr(getType());
2977 // subi(x,0) -> x
2978 if (matchPattern(rhs(), m_Zero()))
2979 return lhs();
2980
2981 return constFoldBinaryOp<IntegerAttr>(operands,
2982 [](APInt a, APInt b) { return a - b; });
2983 }
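// Illustrative folds: subi %x, %x becomes the zero attribute of the result
// type, subi %x, %c0 becomes %x, and subi of two integer constants (e.g.
// 5 - 3) folds to the constant 2.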
2984
2985 //===----------------------------------------------------------------------===//
2986 // UIToFPOp
2987 //===----------------------------------------------------------------------===//
2988
2989 // uitofp is applicable from integer types to float types.
2990 bool UIToFPOp::areCastCompatible(Type a, Type b) {
2991 if (a.isSignlessInteger() && b.isa<FloatType>())
2992 return true;
2993 return areVectorCastSimpleCompatible(a, b, areCastCompatible);
2994 }
2995
2996 //===----------------------------------------------------------------------===//
2997 // SubViewOp
2998 //===----------------------------------------------------------------------===//
2999
3000 namespace {
3001 /// Helpers to write more idiomatic operations.
3002 namespace saturated_arith {
3003 struct Wrapper {
3004 explicit Wrapper(int64_t v) : v(v) {}
3005 operator int64_t() { return v; }
3006 int64_t v;
3007 };
3008 Wrapper operator+(Wrapper a, int64_t b) {
3009 if (ShapedType::isDynamicStrideOrOffset(a) ||
3010 ShapedType::isDynamicStrideOrOffset(b))
3011 return Wrapper(ShapedType::kDynamicStrideOrOffset);
3012 return Wrapper(a.v + b);
3013 }
3014 Wrapper operator*(Wrapper a, int64_t b) {
3015 if (ShapedType::isDynamicStrideOrOffset(a) ||
3016 ShapedType::isDynamicStrideOrOffset(b))
3017 return Wrapper(ShapedType::kDynamicStrideOrOffset);
3018 return Wrapper(a.v * b);
3019 }
3020 } // end namespace saturated_arith
3021 } // end namespace
3022
3023 /// A subview result type can be fully inferred from the source type and the
3024 /// static representation of offsets, sizes and strides. Special sentinels
3025 /// encode the dynamic case.
3026 Type SubViewOp::inferResultType(MemRefType sourceMemRefType,
3027 ArrayRef<int64_t> staticOffsets,
3028 ArrayRef<int64_t> staticSizes,
3029 ArrayRef<int64_t> staticStrides) {
3030 unsigned rank = sourceMemRefType.getRank();
3031 (void)rank;
3032 assert(staticOffsets.size() == rank &&
3033 "unexpected staticOffsets size mismatch");
3034 assert(staticSizes.size() == rank && "unexpected staticSizes size mismatch");
3035 assert(staticStrides.size() == rank &&
3036 "unexpected staticStrides size mismatch");
3037
3038 // Extract source offset and strides.
3039 int64_t sourceOffset;
3040 SmallVector<int64_t, 4> sourceStrides;
3041 auto res = getStridesAndOffset(sourceMemRefType, sourceStrides, sourceOffset);
3042 assert(succeeded(res) && "SubViewOp expected strided memref type");
3043 (void)res;
3044
3045 // Compute target offset whose value is:
3046 // `sourceOffset + sum_i(staticOffset_i * sourceStrides_i)`.
3047 int64_t targetOffset = sourceOffset;
3048 for (auto it : llvm::zip(staticOffsets, sourceStrides)) {
3049 auto staticOffset = std::get<0>(it), targetStride = std::get<1>(it);
3050 using namespace saturated_arith;
3051 targetOffset = Wrapper(targetOffset) + Wrapper(staticOffset) * targetStride;
3052 }
3053
3054 // Compute target stride whose value is:
3055 // `sourceStrides_i * staticStrides_i`.
3056 SmallVector<int64_t, 4> targetStrides;
3057 targetStrides.reserve(staticOffsets.size());
3058 for (auto it : llvm::zip(sourceStrides, staticStrides)) {
3059 auto sourceStride = std::get<0>(it), staticStride = std::get<1>(it);
3060 using namespace saturated_arith;
3061 targetStrides.push_back(Wrapper(sourceStride) * staticStride);
3062 }
3063
3064 // The type is now known.
3065 return MemRefType::get(
3066 staticSizes, sourceMemRefType.getElementType(),
3067 makeStridedLinearLayoutMap(targetStrides, targetOffset,
3068 sourceMemRefType.getContext()),
3069 sourceMemRefType.getMemorySpace());
3070 }
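// Worked example (a sketch; the affine map may print differently): for a
// source memref<8x16xf32> (strides [16, 1], offset 0) with static offsets
// [2, 4], sizes [4, 8] and strides [1, 1], the target offset is
// 0 + 2 * 16 + 4 * 1 = 36 and the target strides are [16, 1], giving
// memref<4x8xf32, affine_map<(d0, d1) -> (d0 * 16 + d1 + 36)>>.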
3071
3072 /// Print a subview op of the form:
3073 /// ```
3074 /// `subview` ssa-name
3075 /// `[` offset-list `]` `[` size-list `]` `[` stride-list `]`
3076 /// `:` strided-memref-type `to` strided-memref-type
3077 /// ```
3078 static void print(OpAsmPrinter &p, SubViewOp op) {
3079 int stdDotLen = StandardOpsDialect::getDialectNamespace().size() + 1;
3080 p << op->getName().getStringRef().drop_front(stdDotLen) << ' ';
3081 p << op.source();
3082 printOffsetsSizesAndStrides(p, op);
3083 p << " : " << op.getSourceType() << " to " << op.getType();
3084 }
3085
3086 /// Parse a subview op of the form:
3087 /// ```
3088 /// `subview` ssa-name
3089 /// `[` offset-list `]` `[` size-list `]` `[` stride-list `]`
3090 /// `:` strided-memref-type `to` strided-memref-type
3091 /// ```
3092 static ParseResult parseSubViewOp(OpAsmParser &parser, OperationState &result) {
3093 OpAsmParser::OperandType srcInfo;
3094 if (parser.parseOperand(srcInfo))
3095 return failure();
3096 Type srcType, dstType;
3097 auto preResolutionFn = [&](OpAsmParser &parser, OperationState &result) {
3098 return failure(parser.parseOptionalAttrDict(result.attributes) ||
3099 parser.parseColonType(srcType) ||
3100 parser.parseKeywordType("to", dstType) ||
3101 parser.resolveOperand(srcInfo, srcType, result.operands));
3102 };
3103
3104 if (failed(parseOffsetsSizesAndStrides(parser, result,
3105 /*segmentSizes=*/{1}, // source memref
3106 preResolutionFn)))
3107 return failure();
3108 return parser.addTypeToList(dstType, result.types);
3109 }
3110
3111 void mlir::SubViewOp::build(OpBuilder &b, OperationState &result, Value source,
3112 ArrayRef<int64_t> staticOffsets,
3113 ArrayRef<int64_t> staticSizes,
3114 ArrayRef<int64_t> staticStrides, ValueRange offsets,
3115 ValueRange sizes, ValueRange strides,
3116 ArrayRef<NamedAttribute> attrs) {
3117 auto sourceMemRefType = source.getType().cast<MemRefType>();
3118 auto resultType = inferResultType(sourceMemRefType, staticOffsets,
3119 staticSizes, staticStrides);
3120 build(b, result, resultType, source, offsets, sizes, strides,
3121 b.getI64ArrayAttr(staticOffsets), b.getI64ArrayAttr(staticSizes),
3122 b.getI64ArrayAttr(staticStrides));
3123 result.addAttributes(attrs);
3124 }
3125
3126 /// Build a SubViewOp with all dynamic entries: `staticOffsets`, `staticSizes`
3127 /// and `staticStrides` are automatically filled with source-memref-rank
3128 /// sentinel values that encode dynamic entries.
3129 void mlir::SubViewOp::build(OpBuilder &b, OperationState &result, Value source,
3130 ValueRange offsets, ValueRange sizes,
3131 ValueRange strides,
3132 ArrayRef<NamedAttribute> attrs) {
3133 auto sourceMemRefType = source.getType().cast<MemRefType>();
3134 unsigned rank = sourceMemRefType.getRank();
3135 SmallVector<int64_t, 4> staticOffsetsVector;
3136 staticOffsetsVector.assign(rank, ShapedType::kDynamicStrideOrOffset);
3137 SmallVector<int64_t, 4> staticSizesVector;
3138 staticSizesVector.assign(rank, ShapedType::kDynamicSize);
3139 SmallVector<int64_t, 4> staticStridesVector;
3140 staticStridesVector.assign(rank, ShapedType::kDynamicStrideOrOffset);
3141 build(b, result, source, staticOffsetsVector, staticSizesVector,
3142 staticStridesVector, offsets, sizes, strides, attrs);
3143 }
3144
3145 /// Build a SubViewOp as above but with custom result type.
3146 void mlir::SubViewOp::build(OpBuilder &b, OperationState &result,
3147 MemRefType resultType, Value source,
3148 ArrayRef<int64_t> staticOffsets,
3149 ArrayRef<int64_t> staticSizes,
3150 ArrayRef<int64_t> staticStrides, ValueRange offsets,
3151 ValueRange sizes, ValueRange strides,
3152 ArrayRef<NamedAttribute> attrs) {
3153 build(b, result, resultType, source, offsets, sizes, strides,
3154 b.getI64ArrayAttr(staticOffsets), b.getI64ArrayAttr(staticSizes),
3155 b.getI64ArrayAttr(staticStrides));
3156 result.addAttributes(attrs);
3157 }
3158
3159 /// Build a SubViewOp as above but with custom result type.
3160 void mlir::SubViewOp::build(OpBuilder &b, OperationState &result,
3161 MemRefType resultType, Value source,
3162 ValueRange offsets, ValueRange sizes,
3163 ValueRange strides,
3164 ArrayRef<NamedAttribute> attrs) {
3165 auto sourceMemRefType = source.getType().cast<MemRefType>();
3166 unsigned rank = sourceMemRefType.getRank();
3167 SmallVector<int64_t, 4> staticOffsetsVector;
3168 staticOffsetsVector.assign(rank, ShapedType::kDynamicStrideOrOffset);
3169 SmallVector<int64_t, 4> staticSizesVector;
3170 staticSizesVector.assign(rank, ShapedType::kDynamicSize);
3171 SmallVector<int64_t, 4> staticStridesVector;
3172 staticStridesVector.assign(rank, ShapedType::kDynamicStrideOrOffset);
3173 build(b, result, resultType, source, staticOffsetsVector, staticSizesVector,
3174 staticStridesVector, offsets, sizes, strides, attrs);
3175 }
3176
3177 /// For ViewLikeOpInterface.
3178 Value SubViewOp::getViewSource() { return source(); }
3179
3180 llvm::Optional<SmallVector<bool, 4>>
3181 mlir::computeRankReductionMask(ArrayRef<int64_t> originalShape,
3182 ArrayRef<int64_t> reducedShape) {
3183 size_t originalRank = originalShape.size(), reducedRank = reducedShape.size();
3184 SmallVector<bool, 4> mask(originalRank);
3185 unsigned reducedIdx = 0;
3186 for (unsigned originalIdx = 0; originalIdx < originalRank; ++originalIdx) {
3187 // Skip matching dims greedily.
3188 mask[originalIdx] =
3189 (reducedIdx < reducedRank) &&
3190 (originalShape[originalIdx] == reducedShape[reducedIdx]);
3191 if (mask[originalIdx])
3192 reducedIdx++;
3193 // A dimension of size 1 is the only non-matching dimension allowed.
3194 else if (originalShape[originalIdx] != 1)
3195 return {};
3196 }
3197
3198 if (reducedIdx != reducedRank)
3199 return {};
3200
3201 return mask;
3202 }
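// Worked example: for originalShape = [1, 4, 1, 8] and reducedShape = [4, 8]
// the mask is [false, true, false, true]; for reducedShape = [4, 4] there is
// no valid reduction (the trailing 8 is neither matched nor 1) and llvm::None
// is returned.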
3203
3204 enum SubViewVerificationResult {
3205 Success,
3206 RankTooLarge,
3207 SizeMismatch,
3208 StrideMismatch,
3209 ElemTypeMismatch,
3210 MemSpaceMismatch,
3211 AffineMapMismatch
3212 };
3213
3214 /// Checks if the `original` type can be rank-reduced to the `reduced` type.
3215 /// This function is a slight variant of the `is subsequence` algorithm, where
3216 /// any non-matching dimension must be 1.
3217 static SubViewVerificationResult isRankReducedType(Type originalType,
3218 Type reducedType) {
3219 if (originalType == reducedType)
3220 return SubViewVerificationResult::Success;
3221 if (!originalType.isa<RankedTensorType>() && !originalType.isa<MemRefType>())
3222 return SubViewVerificationResult::Success;
3223 if (originalType.isa<RankedTensorType>() &&
3224 !reducedType.isa<RankedTensorType>())
3225 return SubViewVerificationResult::Success;
3226 if (originalType.isa<MemRefType>() && !reducedType.isa<MemRefType>())
3227 return SubViewVerificationResult::Success;
3228
3229 ShapedType originalShapedType = originalType.cast<ShapedType>();
3230 ShapedType reducedShapedType = reducedType.cast<ShapedType>();
3231
3232 // Rank and size logic is valid for all ShapedTypes.
3233 ArrayRef<int64_t> originalShape = originalShapedType.getShape();
3234 ArrayRef<int64_t> reducedShape = reducedShapedType.getShape();
3235 unsigned originalRank = originalShape.size(),
3236 reducedRank = reducedShape.size();
3237 if (reducedRank > originalRank)
3238 return SubViewVerificationResult::RankTooLarge;
3239
3240 auto optionalMask = computeRankReductionMask(originalShape, reducedShape);
3241
3242 // The sizes cannot be matched if no mask was computed.
3243 if (!optionalMask.hasValue())
3244 return SubViewVerificationResult::SizeMismatch;
3245
3246 // We are done for the tensor case.
3247 if (originalType.isa<RankedTensorType>())
3248 return SubViewVerificationResult::Success;
3249
3250 // Strided layout logic is relevant for MemRefType only.
3251 MemRefType original = originalType.cast<MemRefType>();
3252 MemRefType reduced = reducedType.cast<MemRefType>();
3253 MLIRContext *c = original.getContext();
3254 int64_t originalOffset, reducedOffset;
3255 SmallVector<int64_t, 4> originalStrides, reducedStrides, keepStrides;
3256 SmallVector<bool, 4> keepMask = optionalMask.getValue();
3257 getStridesAndOffset(original, originalStrides, originalOffset);
3258 getStridesAndOffset(reduced, reducedStrides, reducedOffset);
3259
3260 // Filter strides based on the mask and check that they are the same
3261 // as reduced ones.
3262 unsigned reducedIdx = 0;
3263 for (unsigned originalIdx = 0; originalIdx < originalRank; ++originalIdx) {
3264 if (keepMask[originalIdx]) {
3265 if (originalStrides[originalIdx] != reducedStrides[reducedIdx++])
3266 return SubViewVerificationResult::StrideMismatch;
3267 keepStrides.push_back(originalStrides[originalIdx]);
3268 }
3269 }
3270
3271 if (original.getElementType() != reduced.getElementType())
3272 return SubViewVerificationResult::ElemTypeMismatch;
3273
3274 if (original.getMemorySpace() != reduced.getMemorySpace())
3275 return SubViewVerificationResult::MemSpaceMismatch;
3276
3277 auto reducedMap = makeStridedLinearLayoutMap(keepStrides, originalOffset, c);
3278 if (!reduced.getAffineMaps().empty() &&
3279 reducedMap != reduced.getAffineMaps().front())
3280 return SubViewVerificationResult::AffineMapMismatch;
3281
3282 return SubViewVerificationResult::Success;
3283 }
3284
3285 template <typename OpTy>
3286 static LogicalResult produceSubViewErrorMsg(SubViewVerificationResult result,
3287 OpTy op, Type expectedType) {
3288 auto memrefType = expectedType.cast<ShapedType>();
3289 switch (result) {
3290 case SubViewVerificationResult::Success:
3291 return success();
3292 case SubViewVerificationResult::RankTooLarge:
3293 return op.emitError("expected result rank to be smaller or equal to ")
3294 << "the source rank.";
3295 case SubViewVerificationResult::SizeMismatch:
3296 return op.emitError("expected result type to be ")
3297 << expectedType
3298 << " or a rank-reduced version. (mismatch of result sizes)";
3299 case SubViewVerificationResult::StrideMismatch:
3300 return op.emitError("expected result type to be ")
3301 << expectedType
3302 << " or a rank-reduced version. (mismatch of result strides)";
3303 case SubViewVerificationResult::ElemTypeMismatch:
3304 return op.emitError("expected result element type to be ")
3305 << memrefType.getElementType();
3306 case SubViewVerificationResult::MemSpaceMismatch:
3307 return op.emitError("expected result and source memory spaces to match.");
3308 case SubViewVerificationResult::AffineMapMismatch:
3309 return op.emitError("expected result type to be ")
3310 << expectedType
3311 << " or a rank-reduced version. (mismatch of result affine map)";
3312 }
3313 llvm_unreachable("unexpected subview verification result");
3314 }
3315
3316 /// Verifier for SubViewOp.
3317 static LogicalResult verify(SubViewOp op) {
3318 MemRefType baseType = op.getSourceType();
3319 MemRefType subViewType = op.getType();
3320
3321 // The base memref and the view memref should be in the same memory space.
3322 if (baseType.getMemorySpace() != subViewType.getMemorySpace())
3323 return op.emitError("different memory spaces specified for base memref "
3324 "type ")
3325 << baseType << " and subview memref type " << subViewType;
3326
3327 // Verify that the base memref type has a strided layout map.
3328 if (!isStrided(baseType))
3329 return op.emitError("base type ") << baseType << " is not strided";
3330
3331 // Verify result type against inferred type.
3332 auto expectedType = SubViewOp::inferResultType(
3333 baseType, extractFromI64ArrayAttr(op.static_offsets()),
3334 extractFromI64ArrayAttr(op.static_sizes()),
3335 extractFromI64ArrayAttr(op.static_strides()));
3336
3337 auto result = isRankReducedType(expectedType, subViewType);
3338 return produceSubViewErrorMsg(result, op, expectedType);
3339 }
3340
3341 raw_ostream &mlir::operator<<(raw_ostream &os, Range &range) {
3342 return os << "range " << range.offset << ":" << range.size << ":"
3343 << range.stride;
3344 }
3345
3346 /// Return the list of Range (i.e. offset, size, stride). Each Range
3347 /// entry contains either the dynamic value or a ConstantIndexOp constructed
3348 /// with `b` at location `loc`.
3349 SmallVector<Range, 8> mlir::getOrCreateRanges(OffsetSizeAndStrideOpInterface op,
3350 OpBuilder &b, Location loc) {
3351 std::array<unsigned, 3> ranks = op.getArrayAttrRanks();
3352 assert(ranks[0] == ranks[1] && "expected offset and sizes of equal ranks");
3353 assert(ranks[1] == ranks[2] && "expected sizes and strides of equal ranks");
3354 SmallVector<Range, 8> res;
3355 unsigned rank = ranks[0];
3356 res.reserve(rank);
3357 for (unsigned idx = 0; idx < rank; ++idx) {
3358 Value offset =
3359 op.isDynamicOffset(idx)
3360 ? op.getDynamicOffset(idx)
3361 : b.create<ConstantIndexOp>(loc, op.getStaticOffset(idx));
3362 Value size = op.isDynamicSize(idx)
3363 ? op.getDynamicSize(idx)
3364 : b.create<ConstantIndexOp>(loc, op.getStaticSize(idx));
3365 Value stride =
3366 op.isDynamicStride(idx)
3367 ? op.getDynamicStride(idx)
3368 : b.create<ConstantIndexOp>(loc, op.getStaticStride(idx));
3369 res.emplace_back(Range{offset, size, stride});
3370 }
3371 return res;
3372 }
3373
3374 namespace {
3375
3376 /// Take a list of `values` with potential new constants to extract and a list
3377 /// of `constantValues` containing `values.size()` sentinels that evaluate to
3378 /// true under `isDynamic`.
3379 /// Detects the `values` produced by a ConstantIndexOp and places the new
3380 /// constant in place of the corresponding sentinel value.
3381 void canonicalizeSubViewPart(SmallVectorImpl<Value> &values,
3382 SmallVectorImpl<int64_t> &constantValues,
3383 llvm::function_ref<bool(int64_t)> isDynamic) {
3384 bool hasNewStaticValue = llvm::any_of(
3385 values, [](Value val) { return matchPattern(val, m_ConstantIndex()); });
3386 if (hasNewStaticValue) {
3387 for (unsigned cstIdx = 0, valIdx = 0, e = constantValues.size();
3388 cstIdx != e; ++cstIdx) {
3389 // Was already static, skip.
3390 if (!isDynamic(constantValues[cstIdx]))
3391 continue;
3392 // Newly static, move from Value to constant.
3393 if (matchPattern(values[valIdx], m_ConstantIndex())) {
3394 constantValues[cstIdx] =
3395 cast<ConstantIndexOp>(values[valIdx].getDefiningOp()).getValue();
3396 // Erase for impl. simplicity. Reverse iterator if we really must.
3397 values.erase(std::next(values.begin(), valIdx));
3398 continue;
3399 }
3400 // Remains dynamic; move to the next value.
3401 ++valIdx;
3402 }
3403 }
3404 }
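// Worked example (a sketch): with constantValues = [kDynamic, 0, kDynamic]
// and values = [%c2, %d], where %c2 is defined by a ConstantIndexOp, the call
// rewrites constantValues to [2, 0, kDynamic] and shrinks values to [%d]; the
// already-static 0 and the still-dynamic %d are left untouched.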
3405
3406 static void replaceWithNewOp(PatternRewriter &rewriter, SubViewOp op,
3407 SubViewOp newOp) {
3408 rewriter.replaceOpWithNewOp<MemRefCastOp>(op, newOp, op.getType());
3409 }
3410
3411 static void replaceWithNewOp(PatternRewriter &rewriter, SubTensorOp op,
3412 SubTensorOp newOp) {
3413 rewriter.replaceOpWithNewOp<TensorCastOp>(op, newOp, op.getType());
3414 }
3415
3416 /// Pattern to rewrite a subview op with constant arguments.
3417 template <typename OpType>
3418 class OpWithOffsetSizesAndStridesConstantArgumentFolder final
3419 : public OpRewritePattern<OpType> {
3420 public:
3421 using OpRewritePattern<OpType>::OpRewritePattern;
3422
3423 LogicalResult matchAndRewrite(OpType op,
3424 PatternRewriter &rewriter) const override {
3425 // No constant operand, just return.
3426 if (llvm::none_of(op.getOperands(), [](Value operand) {
3427 return matchPattern(operand, m_ConstantIndex());
3428 }))
3429 return failure();
3430
3431 // At least one of offsets/sizes/strides is a new constant.
3432 // Form the new list of operands and constant attributes from the existing.
3433 SmallVector<Value, 8> newOffsets(op.offsets());
3434 SmallVector<int64_t, 8> newStaticOffsets =
3435 extractFromI64ArrayAttr(op.static_offsets());
3436 std::array<unsigned, 3> ranks = op.getArrayAttrRanks();
3437 (void)ranks;
3438 assert(newStaticOffsets.size() == ranks[0]);
3439 canonicalizeSubViewPart(newOffsets, newStaticOffsets,
3440 ShapedType::isDynamicStrideOrOffset);
3441
3442 SmallVector<Value, 8> newSizes(op.sizes());
3443 SmallVector<int64_t, 8> newStaticSizes =
3444 extractFromI64ArrayAttr(op.static_sizes());
3445 assert(newStaticSizes.size() == ranks[1]);
3446 canonicalizeSubViewPart(newSizes, newStaticSizes, ShapedType::isDynamic);
3447
3448 SmallVector<Value, 8> newStrides(op.strides());
3449 SmallVector<int64_t, 8> newStaticStrides =
3450 extractFromI64ArrayAttr(op.static_strides());
3451 assert(newStaticStrides.size() == ranks[2]);
3452 canonicalizeSubViewPart(newStrides, newStaticStrides,
3453 ShapedType::isDynamicStrideOrOffset);
3454
3455 // Create the new op in canonical form.
3456 auto newOp = rewriter.create<OpType>(
3457 op.getLoc(), op.source(), newStaticOffsets, newStaticSizes,
3458 newStaticStrides, newOffsets, newSizes, newStrides);
3459
3460 replaceWithNewOp(rewriter, op, newOp);
3461
3462 return success();
3463 }
3464 };
3465
3466 } // end anonymous namespace
3467
3468 /// Determines whether MemRefCastOp casts to a more dynamic version of the
3469 /// source memref. This is useful to fold a memref_cast into a consuming op
3470 /// and implement canonicalization patterns for ops in different dialects that
3471 /// may consume the results of memref_cast operations. Such foldable memref_cast
3472 /// operations are typically inserted as `view` and `subview` ops are
3473 /// canonicalized, to preserve the type compatibility of their uses.
3474 ///
3475 /// Returns true when all conditions are met:
3476 /// 1. source and result are ranked memrefs with strided semantics and same
3477 /// element type and rank.
3478 /// 2. each of the source's size, offset or stride has more static information
3479 /// than the corresponding result's size, offset or stride.
3480 ///
3481 /// Example 1:
3482 /// ```mlir
3483 /// %1 = memref_cast %0 : memref<8x16xf32> to memref<?x?xf32>
3484 /// %2 = consumer %1 ... : memref<?x?xf32> ...
3485 /// ```
3486 ///
3487 /// may fold into:
3488 ///
3489 /// ```mlir
3490 /// %2 = consumer %0 ... : memref<8x16xf32> ...
3491 /// ```
3492 ///
3493 /// Example 2:
3494 /// ```
3495 /// %1 = memref_cast %0 : memref<?x16xf32, affine_map<(i, j)->(16 * i + j)>>
3496 /// to memref<?x?xf32>
3497 /// consumer %1 : memref<?x?xf32> ...
3498 /// ```
3499 ///
3500 /// may fold into:
3501 ///
3502 /// ```
3503 /// consumer %0 ... : memref<?x16xf32, affine_map<(i, j)->(16 * i + j)>>
3504 /// ```
3505 bool mlir::canFoldIntoConsumerOp(MemRefCastOp castOp) {
3506 MemRefType sourceType = castOp.source().getType().dyn_cast<MemRefType>();
3507 MemRefType resultType = castOp.getType().dyn_cast<MemRefType>();
3508
3509 // Requires ranked MemRefType.
3510 if (!sourceType || !resultType)
3511 return false;
3512
3513 // Requires same elemental type.
3514 if (sourceType.getElementType() != resultType.getElementType())
3515 return false;
3516
3517 // Requires same rank.
3518 if (sourceType.getRank() != resultType.getRank())
3519 return false;
3520
3521 // Only fold casts between strided memref forms.
3522 int64_t sourceOffset, resultOffset;
3523 SmallVector<int64_t, 4> sourceStrides, resultStrides;
3524 if (failed(getStridesAndOffset(sourceType, sourceStrides, sourceOffset)) ||
3525 failed(getStridesAndOffset(resultType, resultStrides, resultOffset)))
3526 return false;
3527
3528 // If cast is towards more static sizes along any dimension, don't fold.
3529 for (auto it : llvm::zip(sourceType.getShape(), resultType.getShape())) {
3530 auto ss = std::get<0>(it), st = std::get<1>(it);
3531 if (ss != st)
3532 if (MemRefType::isDynamic(ss) && !MemRefType::isDynamic(st))
3533 return false;
3534 }
3535
3536 // If cast is towards more static offset along any dimension, don't fold.
3537 if (sourceOffset != resultOffset)
3538 if (MemRefType::isDynamicStrideOrOffset(sourceOffset) &&
3539 !MemRefType::isDynamicStrideOrOffset(resultOffset))
3540 return false;
3541
3542 // If cast is towards more static strides along any dimension, don't fold.
3543 for (auto it : llvm::zip(sourceStrides, resultStrides)) {
3544 auto ss = std::get<0>(it), st = std::get<1>(it);
3545 if (ss != st)
3546 if (MemRefType::isDynamicStrideOrOffset(ss) &&
3547 !MemRefType::isDynamicStrideOrOffset(st))
3548 return false;
3549 }
3550
3551 return true;
3552 }
3553
3554 /// Counterpart of `canFoldIntoConsumerOp(MemRefCastOp castOp)` for tensors.
3555 /// Determines whether TensorCastOp casts to a more dynamic version of the
3556 /// source tensor. This is useful to fold a tensor_cast into a consuming op and
3557 /// implement canonicalization patterns for ops in different dialects that may
3558 /// consume the results of tensor_cast operations. Such foldable tensor_cast
3559 /// operations are typically inserted when `subtensor` ops are canonicalized,
3560 /// to preserve the type compatibility of their uses.
3561 ///
3562 /// Returns true when all conditions are met:
3563 /// 1. source and result are ranked tensors with same element type and rank.
3564 /// 2. the tensor type has more static information than the result
3565 ///
3566 /// Example:
3567 /// ```mlir
3568 /// %1 = tensor_cast %0 : tensor<8x16xf32> to tensor<?x?xf32>
3569 /// %2 = consumer %1 ... : tensor<?x?xf32> ...
3570 /// ```
3571 ///
3572 /// folds into:
3573 ///
3574 /// ```mlir
3575 /// %2 = consumer %0 ... : tensor<8x16xf32> ...
3576 /// ```
3577 bool mlir::canFoldIntoConsumerOp(TensorCastOp castOp) {
3578 if (!castOp)
3579 return false;
3580
3581 RankedTensorType sourceType =
3582 castOp.source().getType().dyn_cast<RankedTensorType>();
3583 RankedTensorType resultType = castOp.getType().dyn_cast<RankedTensorType>();
3584
3585 // Requires RankedTensorType.
3586 if (!sourceType || !resultType)
3587 return false;
3588
3589 // Requires same elemental type.
3590 if (sourceType.getElementType() != resultType.getElementType())
3591 return false;
3592
3593 // Requires same rank.
3594 if (sourceType.getRank() != resultType.getRank())
3595 return false;
3596
3597 // If cast is towards more static sizes along any dimension, don't fold.
3598 for (auto it : llvm::zip(sourceType.getShape(), resultType.getShape())) {
3599 auto ss = std::get<0>(it), st = std::get<1>(it);
3600 if (ss != st)
3601 if (ShapedType::isDynamic(ss) && !ShapedType::isDynamic(st))
3602 return false;
3603 }
3604
3605 return true;
3606 }
3607
3608 namespace {
3609 /// Pattern to rewrite a subview op with MemRefCast arguments.
3610 /// This essentially pushes memref_cast past its consuming subview when
3611 /// `canFoldIntoConsumerOp` is true.
3612 ///
3613 /// Example:
3614 /// ```
3615 /// %0 = memref_cast %V : memref<16x16xf32> to memref<?x?xf32>
3616 /// %1 = subview %0[0, 0][3, 4][1, 1] :
3617 /// memref<?x?xf32> to memref<3x4xf32, offset:?, strides:[?, 1]>
3618 /// ```
3619 /// is rewritten into:
3620 /// ```
3621 /// %0 = subview %V: memref<16x16xf32> to memref<3x4xf32, #[[map0]]>
3622 /// %1 = memref_cast %0: memref<3x4xf32, offset:0, strides:[16, 1]> to
3623 /// memref<3x4xf32, offset:?, strides:[?, 1]>
3624 /// ```
3625 class SubViewOpMemRefCastFolder final : public OpRewritePattern<SubViewOp> {
3626 public:
3627 using OpRewritePattern<SubViewOp>::OpRewritePattern;
3628
3629 LogicalResult matchAndRewrite(SubViewOp subViewOp,
3630 PatternRewriter &rewriter) const override {
3631 // Any constant operand, just return to let the constant-argument folder kick in.
3632 if (llvm::any_of(subViewOp.getOperands(), [](Value operand) {
3633 return matchPattern(operand, m_ConstantIndex());
3634 }))
3635 return failure();
3636
3637 auto castOp = subViewOp.source().getDefiningOp<MemRefCastOp>();
3638 if (!castOp)
3639 return failure();
3640
3641 if (!canFoldIntoConsumerOp(castOp))
3642 return failure();
3643
3644 /// Deduce the resultType of the SubViewOp using `SubViewOp::inferResultType` on
3645 /// the cast source operand type and the SubViewOp static information. This
3646 /// is the resulting type if the MemRefCastOp were folded.
3647 Type resultType = SubViewOp::inferResultType(
3648 castOp.source().getType().cast<MemRefType>(),
3649 extractFromI64ArrayAttr(subViewOp.static_offsets()),
3650 extractFromI64ArrayAttr(subViewOp.static_sizes()),
3651 extractFromI64ArrayAttr(subViewOp.static_strides()));
3652 Value newSubView = rewriter.create<SubViewOp>(
3653 subViewOp.getLoc(), resultType, castOp.source(), subViewOp.offsets(),
3654 subViewOp.sizes(), subViewOp.strides(), subViewOp.static_offsets(),
3655 subViewOp.static_sizes(), subViewOp.static_strides());
3656 rewriter.replaceOpWithNewOp<MemRefCastOp>(subViewOp, subViewOp.getType(),
3657 newSubView);
3658 return success();
3659 }
3660 };
3661 } // namespace
3662
3663 void SubViewOp::getCanonicalizationPatterns(OwningRewritePatternList &results,
3664 MLIRContext *context) {
3665 results.insert<OpWithOffsetSizesAndStridesConstantArgumentFolder<SubViewOp>,
3666 SubViewOpMemRefCastFolder>(context);
3667 }
3668
3669 OpFoldResult SubViewOp::fold(ArrayRef<Attribute> operands) {
3670 if (getResult().getType().cast<ShapedType>().getRank() == 0 &&
3671 source().getType().cast<ShapedType>().getRank() == 0)
3672 return getViewSource();
3673
3674 return {};
3675 }
3676
3677 //===----------------------------------------------------------------------===//
3678 // SubTensorOp
3679 //===----------------------------------------------------------------------===//
3680
3681 /// Print a subtensor op of the form:
3682 /// ```
3683 /// `subtensor` ssa-name
3684 /// `[` offset-list `]` `[` size-list `]` `[` stride-list `]`
3685 /// `:` ranked-tensor-type `to` ranked-tensor-type
3686 /// ```
3687 static void print(OpAsmPrinter &p, SubTensorOp op) {
3688 int stdDotLen = StandardOpsDialect::getDialectNamespace().size() + 1;
3689 p << op->getName().getStringRef().drop_front(stdDotLen) << ' ';
3690 p << op.source();
3691 printOffsetsSizesAndStrides(p, op);
3692 p << " : " << op.getSourceType() << " to " << op.getType();
3693 }
3694
3695 /// Parse a subtensor op of the form:
3696 /// ```
3697 /// `subtensor` ssa-name
3698 /// `[` offset-list `]` `[` size-list `]` `[` stride-list `]`
3699 /// `:` ranked-tensor-type `to` ranked-tensor-type
3700 /// ```
3701 static ParseResult parseSubTensorOp(OpAsmParser &parser,
3702 OperationState &result) {
3703 OpAsmParser::OperandType srcInfo;
3704 if (parser.parseOperand(srcInfo))
3705 return failure();
3706 Type srcType, dstType;
3707 auto preResolutionFn = [&](OpAsmParser &parser, OperationState &result) {
3708 return failure(parser.parseOptionalAttrDict(result.attributes) ||
3709 parser.parseColonType(srcType) ||
3710 parser.parseKeywordType("to", dstType) ||
3711 parser.resolveOperand(srcInfo, srcType, result.operands));
3712 };
3713
3714 if (failed(parseOffsetsSizesAndStrides(parser, result,
3715 /*segmentSizes=*/{1}, // source tensor
3716 preResolutionFn)))
3717 return failure();
3718 return parser.addTypeToList(dstType, result.types);
3719 }
3720
3721 /// A subtensor result type can be fully inferred from the source type and the
3722 /// static representation of offsets, sizes and strides. Special sentinels
3723 /// encode the dynamic case.
3724 Type SubTensorOp::inferResultType(RankedTensorType sourceRankedTensorType,
3725 ArrayRef<int64_t> staticOffsets,
3726 ArrayRef<int64_t> staticSizes,
3727 ArrayRef<int64_t> staticStrides) {
3728 unsigned rank = sourceRankedTensorType.getRank();
3729 (void)rank;
3730 assert(staticOffsets.size() == rank &&
3731 "unexpected staticOffsets size mismatch");
3732 assert(staticSizes.size() == rank && "unexpected staticSizes size mismatch");
3733 assert(staticStrides.size() == rank &&
3734 "unexpected staticStrides size mismatch");
3735 return RankedTensorType::get(staticSizes,
3736 sourceRankedTensorType.getElementType());
3737 }
3738
3739 void mlir::SubTensorOp::build(OpBuilder &b, OperationState &result,
3740 Value source, ArrayRef<int64_t> staticOffsets,
3741 ArrayRef<int64_t> staticSizes,
3742 ArrayRef<int64_t> staticStrides,
3743 ValueRange offsets, ValueRange sizes,
3744 ValueRange strides,
3745 ArrayRef<NamedAttribute> attrs) {
3746 auto sourceRankedTensorType = source.getType().cast<RankedTensorType>();
3747 auto resultType = inferResultType(sourceRankedTensorType, staticOffsets,
3748 staticSizes, staticStrides);
3749 build(b, result, resultType, source, offsets, sizes, strides,
3750 b.getI64ArrayAttr(staticOffsets), b.getI64ArrayAttr(staticSizes),
3751 b.getI64ArrayAttr(staticStrides));
3752 result.addAttributes(attrs);
3753 }
3754
3755 /// Build a SubTensorOp with all dynamic entries: `staticOffsets`, `staticSizes`
3756 /// and `staticStrides` are automatically filled with sentinel values that
3757 /// encode dynamic entries.
3758 void mlir::SubTensorOp::build(OpBuilder &b, OperationState &result,
3759 Value source, ValueRange offsets,
3760 ValueRange sizes, ValueRange strides,
3761 ArrayRef<NamedAttribute> attrs) {
3762 auto sourceRankedTensorType = source.getType().cast<RankedTensorType>();
3763 unsigned rank = sourceRankedTensorType.getRank();
3764 SmallVector<int64_t, 4> staticOffsetsVector(
3765 rank, ShapedType::kDynamicStrideOrOffset);
3766 SmallVector<int64_t, 4> staticSizesVector(rank, ShapedType::kDynamicSize);
3767 SmallVector<int64_t, 4> staticStridesVector(
3768 rank, ShapedType::kDynamicStrideOrOffset);
3769 build(b, result, source, staticOffsetsVector, staticSizesVector,
3770 staticStridesVector, offsets, sizes, strides, attrs);
3771 }
3772
3773 /// Verifier for SubTensorOp.
3774 static LogicalResult verify(SubTensorOp op) {
3775 // Verify result type against inferred type.
3776 auto expectedType = SubTensorOp::inferResultType(
3777 op.getSourceType(), extractFromI64ArrayAttr(op.static_offsets()),
3778 extractFromI64ArrayAttr(op.static_sizes()),
3779 extractFromI64ArrayAttr(op.static_strides()));
3780 auto result = isRankReducedType(expectedType, op.getType());
3781 return produceSubViewErrorMsg(result, op, expectedType);
3782 }
3783
3784 void SubTensorOp::getCanonicalizationPatterns(OwningRewritePatternList &results,
3785 MLIRContext *context) {
3786 results
3787 .insert<OpWithOffsetSizesAndStridesConstantArgumentFolder<SubTensorOp>>(
3788 context);
3789 }
3790
3791 //===----------------------------------------------------------------------===//
3792 // SubTensorInsertOp
3793 //===----------------------------------------------------------------------===//
3794
3795 /// Print a subtensor_insert op of the form:
3796 /// ```
3797 /// `subtensor_insert` ssa-name `into` ssa-name
3798 /// `[` offset-list `]` `[` size-list `]` `[` stride-list `]`
3799 /// `:` ranked-tensor-type `into` ranked-tensor-type
3800 /// ```
3801 static void print(OpAsmPrinter &p, SubTensorInsertOp op) {
3802 int stdDotLen = StandardOpsDialect::getDialectNamespace().size() + 1;
3803 p << op->getName().getStringRef().drop_front(stdDotLen) << ' ';
3804 p << op.source() << " into " << op.dest();
3805 printOffsetsSizesAndStrides(p, op);
3806 p << " : " << op.getSourceType() << " into " << op.getType();
3807 }
3808
3809 /// Parse a subtensor_insert op of the form:
3810 /// ```
3811 /// `subtensor_insert` ssa-name `into` ssa-name
3812 /// `[` offset-list `]` `[` size-list `]` `[` stride-list `]`
3813 /// `:` ranked-tensor-type `into` ranked-tensor-type
3814 /// ```
3815 static ParseResult parseSubTensorInsertOp(OpAsmParser &parser,
3816 OperationState &result) {
3817 OpAsmParser::OperandType srcInfo, dstInfo;
3818 if (parser.parseOperand(srcInfo) || parser.parseKeyword("into") ||
3819 parser.parseOperand(dstInfo))
3820 return failure();
3821 Type srcType, dstType;
3822 auto preResolutionFn = [&](OpAsmParser &parser, OperationState &result) {
3823 return failure(parser.parseOptionalAttrDict(result.attributes) ||
3824 parser.parseColonType(srcType) ||
3825 parser.parseKeywordType("into", dstType) ||
3826 parser.resolveOperand(srcInfo, srcType, result.operands) ||
3827 parser.resolveOperand(dstInfo, dstType, result.operands));
3828 };
3829
3830 if (failed(parseOffsetsSizesAndStrides(
3831 parser, result,
3832 /*segmentSizes=*/{1, 1}, // source tensor, destination tensor
3833 preResolutionFn)))
3834 return failure();
3835 return parser.addTypeToList(dstType, result.types);
3836 }
3837
3838 void mlir::SubTensorInsertOp::build(
3839 OpBuilder &b, OperationState &result, Value source, Value dest,
3840 ArrayRef<int64_t> staticOffsets, ArrayRef<int64_t> staticSizes,
3841 ArrayRef<int64_t> staticStrides, ValueRange offsets, ValueRange sizes,
3842 ValueRange strides, ArrayRef<NamedAttribute> attrs) {
3843 build(b, result, dest.getType(), source, dest, offsets, sizes, strides,
3844 b.getI64ArrayAttr(staticOffsets), b.getI64ArrayAttr(staticSizes),
3845 b.getI64ArrayAttr(staticStrides));
3846 result.addAttributes(attrs);
3847 }
3848
3849 /// Build a SubTensorInsertOp with all dynamic entries: `staticOffsets`,
3850 /// `staticSizes` and `staticStrides` are automatically filled with
3851 /// source-tensor-rank sentinel values that encode dynamic entries.
3852 void mlir::SubTensorInsertOp::build(OpBuilder &b, OperationState &result,
3853 Value source, Value dest,
3854 ValueRange offsets, ValueRange sizes,
3855 ValueRange strides,
3856 ArrayRef<NamedAttribute> attrs) {
3857 auto sourceRankedTensorType = source.getType().cast<RankedTensorType>();
3858 unsigned rank = sourceRankedTensorType.getRank();
3859 SmallVector<int64_t, 4> staticOffsetsVector(
3860 rank, ShapedType::kDynamicStrideOrOffset);
3861 SmallVector<int64_t, 4> staticSizesVector(rank, ShapedType::kDynamicSize);
3862 SmallVector<int64_t, 4> staticStridesVector(
3863 rank, ShapedType::kDynamicStrideOrOffset);
3864 build(b, result, source, dest, staticOffsetsVector, staticSizesVector,
3865 staticStridesVector, offsets, sizes, strides, attrs);
3866 }
3867
3868 /// Verifier for SubTensorInsertOp.
3869 static LogicalResult verify(SubTensorInsertOp op) {
3870 if (op.getType() != op.dest().getType())
3871 return op.emitError("expected result type to be ") << op.dest().getType();
3872 return success();
3873 }
3874
3875 //===----------------------------------------------------------------------===//
3876 // TensorCastOp
3877 //===----------------------------------------------------------------------===//
3878
3879 bool TensorCastOp::areCastCompatible(Type a, Type b) {
3880 auto aT = a.dyn_cast<TensorType>();
3881 auto bT = b.dyn_cast<TensorType>();
3882 if (!aT || !bT)
3883 return false;
3884
3885 if (aT.getElementType() != bT.getElementType())
3886 return false;
3887
3888 return succeeded(verifyCompatibleShape(aT, bT));
3889 }
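// For example: tensor<4xf32> and tensor<?xf32> are cast compatible in either
// direction, while tensor<4xf32> and tensor<8xf32> are not (conflicting
// static sizes) and tensor<4xf32> and tensor<4xi32> are not (different
// element types).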
3890
3891 OpFoldResult TensorCastOp::fold(ArrayRef<Attribute> operands) {
3892 return impl::foldCastOp(*this);
3893 }
3894
3895 /// Compute a TensorType that has the joined shape knowledge of the two
3896 /// given TensorTypes. The element types need to match.
3897 static TensorType joinShapes(TensorType one, TensorType two) {
3898 assert(one.getElementType() == two.getElementType());
3899
3900 if (!one.hasRank())
3901 return two;
3902 if (!two.hasRank())
3903 return one;
3904
3905 int64_t rank = one.getRank();
3906 if (rank != two.getRank())
3907 return {};
3908
3909 SmallVector<int64_t, 4> join;
3910 join.reserve(rank);
3911 for (int64_t i = 0; i < rank; ++i) {
3912 if (one.isDynamicDim(i)) {
3913 join.push_back(two.getDimSize(i));
3914 continue;
3915 }
3916 if (two.isDynamicDim(i)) {
3917 join.push_back(one.getDimSize(i));
3918 continue;
3919 }
3920 if (one.getDimSize(i) != two.getDimSize(i))
3921 return {};
3922 join.push_back(one.getDimSize(i));
3923 }
3924 return RankedTensorType::get(join, one.getElementType());
3925 }
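// For example: joinShapes(tensor<?x8xf32>, tensor<4x?xf32>) is
// tensor<4x8xf32>, joining with an unranked tensor returns the other operand
// unchanged, and joinShapes(tensor<4xf32>, tensor<8xf32>) has no join and
// returns a null type.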
3926
3927 namespace {
3928
3929 /// Replaces chains of two tensor_cast operations by a single tensor_cast
3930 /// operation if doing so does not remove runtime constraints.
3931 struct ChainedTensorCast : public OpRewritePattern<TensorCastOp> {
3932 using OpRewritePattern<TensorCastOp>::OpRewritePattern;
3933
3934 LogicalResult matchAndRewrite(TensorCastOp tensorCast,
3935 PatternRewriter &rewriter) const final {
3936 auto tensorCastOperand =
3937 tensorCast.getOperand().getDefiningOp<TensorCastOp>();
3938
3939 if (!tensorCastOperand)
3940 return failure();
3941
3942 auto sourceType =
3943 tensorCastOperand.getOperand().getType().cast<TensorType>();
3944 auto intermediateType = tensorCastOperand.getType().cast<TensorType>();
3945 auto resultType = tensorCast.getType().cast<TensorType>();
3946
3947 // We can remove the intermediate cast if joining all three produces the
3948 // same result as just joining the source and result shapes.
3949 auto firstJoin =
3950 joinShapes(joinShapes(sourceType, intermediateType), resultType);
3951
3952 // The join might not exist if the cast sequence would fail at runtime.
3953 if (!firstJoin)
3954 return failure();
3955
3956 // The newJoin always exists if the above join exists; it might just contain
3957 // less information. If so, we cannot drop the intermediate cast, as doing
3958 // so would remove runtime checks.
3959 auto newJoin = joinShapes(sourceType, resultType);
3960 if (firstJoin != newJoin)
3961 return failure();
3962
3963 rewriter.replaceOpWithNewOp<TensorCastOp>(tensorCast, resultType,
3964 tensorCastOperand.getOperand());
3965 return success();
3966 }
3967 };
3968
3969 } // namespace
3970
3971 void TensorCastOp::getCanonicalizationPatterns(
3972 OwningRewritePatternList &results, MLIRContext *context) {
3973 results.insert<ChainedTensorCast>(context);
3974 }
3975
3976 //===----------------------------------------------------------------------===//
3977 // TensorLoadOp
3978 //===----------------------------------------------------------------------===//
3979
3980 OpFoldResult TensorLoadOp::fold(ArrayRef<Attribute>) {
3981 if (auto tensorToMemref = memref().getDefiningOp<TensorToMemrefOp>())
3982 return tensorToMemref.tensor();
3983 return {};
3984 }
3985
3986 //===----------------------------------------------------------------------===//
3987 // TensorToMemrefOp
3988 //===----------------------------------------------------------------------===//
3989
3990 OpFoldResult TensorToMemrefOp::fold(ArrayRef<Attribute>) {
3991 if (auto tensorLoad = tensor().getDefiningOp<TensorLoadOp>())
3992 if (tensorLoad.memref().getType() == getType())
3993 return tensorLoad.memref();
3994 return {};
3995 }
3996
3997 //===----------------------------------------------------------------------===//
3998 // TransposeOp
3999 //===----------------------------------------------------------------------===//
4000
4001 /// Build a strided memref type by applying `permutationMap` to `memRefType`.
4002 static MemRefType inferTransposeResultType(MemRefType memRefType,
4003 AffineMap permutationMap) {
4004 auto rank = memRefType.getRank();
4005 auto originalSizes = memRefType.getShape();
4006 // Compute permuted sizes.
4007 SmallVector<int64_t, 4> sizes(rank, 0);
4008 for (auto en : llvm::enumerate(permutationMap.getResults()))
4009 sizes[en.index()] =
4010 originalSizes[en.value().cast<AffineDimExpr>().getPosition()];
4011
4012 // Compute permuted strides.
4013 int64_t offset;
4014 SmallVector<int64_t, 4> strides;
4015 auto res = getStridesAndOffset(memRefType, strides, offset);
4016 assert(succeeded(res) && strides.size() == static_cast<unsigned>(rank));
4017 (void)res;
4018 auto map =
4019 makeStridedLinearLayoutMap(strides, offset, memRefType.getContext());
4020 map = permutationMap ? map.compose(permutationMap) : map;
4021 return MemRefType::Builder(memRefType).setShape(sizes).setAffineMaps(map);
4022 }
4023
build(OpBuilder & b,OperationState & result,Value in,AffineMapAttr permutation,ArrayRef<NamedAttribute> attrs)4024 void TransposeOp::build(OpBuilder &b, OperationState &result, Value in,
4025 AffineMapAttr permutation,
4026 ArrayRef<NamedAttribute> attrs) {
4027 auto permutationMap = permutation.getValue();
4028 assert(permutationMap);
4029
4030 auto memRefType = in.getType().cast<MemRefType>();
4031 // Compute result type.
4032 MemRefType resultType = inferTransposeResultType(memRefType, permutationMap);
4033
4034 build(b, result, resultType, in, attrs);
4035 result.addAttribute(TransposeOp::getPermutationAttrName(), permutation);
4036 }
4037
4038 // transpose $in $permutation attr-dict : type($in) `to` type(results)
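// An illustrative (assumed) example of the textual form, transposing a 2-D
// memref with the permutation (d0, d1) -> (d1, d0):
//
//   %1 = transpose %0 (d0, d1) -> (d1, d0)
//       : memref<?x16xf32> to memref<16x?xf32, #transposed_layout>
//
// where #transposed_layout stands for the strided layout produced by
// inferTransposeResultType above.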
static void print(OpAsmPrinter &p, TransposeOp op) {
  p << "transpose " << op.in() << " " << op.permutation();
  p.printOptionalAttrDict(op.getAttrs(),
                          {TransposeOp::getPermutationAttrName()});
  p << " : " << op.in().getType() << " to " << op.getType();
}

static ParseResult parseTransposeOp(OpAsmParser &parser,
                                    OperationState &result) {
  OpAsmParser::OperandType in;
  AffineMap permutation;
  MemRefType srcType, dstType;
  if (parser.parseOperand(in) || parser.parseAffineMap(permutation) ||
      parser.parseOptionalAttrDict(result.attributes) ||
      parser.parseColonType(srcType) ||
      parser.resolveOperand(in, srcType, result.operands) ||
      parser.parseKeywordType("to", dstType) ||
      parser.addTypeToList(dstType, result.types))
    return failure();

  result.addAttribute(TransposeOp::getPermutationAttrName(),
                      AffineMapAttr::get(permutation));
  return success();
}

static LogicalResult verify(TransposeOp op) {
  if (!op.permutation().isPermutation())
    return op.emitOpError("expected a permutation map");
  if (op.permutation().getNumDims() != op.getShapedType().getRank())
    return op.emitOpError(
        "expected a permutation map of same rank as the input");

  auto srcType = op.in().getType().cast<MemRefType>();
  auto dstType = op.getType().cast<MemRefType>();
  auto transposedType = inferTransposeResultType(srcType, op.permutation());
  if (dstType != transposedType)
    return op.emitOpError("output type ")
           << dstType << " does not match transposed input type " << srcType
           << ", " << transposedType;
  return success();
}

OpFoldResult TransposeOp::fold(ArrayRef<Attribute>) {
  if (succeeded(foldMemRefCast(*this)))
    return getResult();
  return {};
}

//===----------------------------------------------------------------------===//
// TruncateIOp
//===----------------------------------------------------------------------===//

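// trunci requires the result integer type to be strictly narrower than the
// operand integer type; index types are rejected on both sides. For example
// (illustrative), `trunci %x : i32 to i16` is valid, while
// `trunci %x : i16 to i32` is not.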
static LogicalResult verify(TruncateIOp op) {
  auto srcType = getElementTypeOrSelf(op.getOperand().getType());
  auto dstType = getElementTypeOrSelf(op.getType());

  if (srcType.isa<IndexType>())
    return op.emitError() << srcType << " is not a valid operand type";
  if (dstType.isa<IndexType>())
    return op.emitError() << dstType << " is not a valid result type";

  if (srcType.cast<IntegerType>().getWidth() <=
      dstType.cast<IntegerType>().getWidth())
    return op.emitError("operand type ")
           << srcType << " must be wider than result type " << dstType;

  return success();
}

//===----------------------------------------------------------------------===//
// UnsignedDivIOp
//===----------------------------------------------------------------------===//

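/// Fold divi_unsigned: constant-fold when both operands are constant integers
/// (or splats), fold `x / 1` to `x`, and bail out rather than fold a division
/// by zero. Illustrative example: divi_unsigned of constants 7 and 2 folds to
/// the constant 3.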
OpFoldResult UnsignedDivIOp::fold(ArrayRef<Attribute> operands) {
  assert(operands.size() == 2 && "binary operation takes two operands");

  // Don't fold if it would require a division by zero.
  bool div0 = false;
  auto result = constFoldBinaryOp<IntegerAttr>(operands, [&](APInt a, APInt b) {
    if (div0 || !b) {
      div0 = true;
      return a;
    }
    return a.udiv(b);
  });

  // Fold out division by one. Assumes all tensors of all ones are splats.
  if (auto rhs = operands[1].dyn_cast_or_null<IntegerAttr>()) {
    if (rhs.getValue() == 1)
      return lhs();
  } else if (auto rhs = operands[1].dyn_cast_or_null<SplatElementsAttr>()) {
    if (rhs.getSplatValue<IntegerAttr>().getValue() == 1)
      return lhs();
  }

  return div0 ? Attribute() : result;
}

//===----------------------------------------------------------------------===//
// UnsignedRemIOp
//===----------------------------------------------------------------------===//

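/// Fold remi_unsigned: `x % 1` folds to 0, a constant remainder is computed
/// with APInt::urem, and a zero divisor is left unfolded so the runtime
/// behavior is preserved.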
OpFoldResult UnsignedRemIOp::fold(ArrayRef<Attribute> operands) {
  assert(operands.size() == 2 && "remi_unsigned takes two operands");

  auto rhs = operands.back().dyn_cast_or_null<IntegerAttr>();
  if (!rhs)
    return {};
  auto rhsValue = rhs.getValue();

  // x % 1 = 0
  if (rhsValue.isOneValue())
    return IntegerAttr::get(rhs.getType(), APInt(rhsValue.getBitWidth(), 0));

  // Don't fold if it requires division by zero.
  if (rhsValue.isNullValue())
    return {};

  auto lhs = operands.front().dyn_cast_or_null<IntegerAttr>();
  if (!lhs)
    return {};
  return IntegerAttr::get(lhs.getType(), lhs.getValue().urem(rhsValue));
}

//===----------------------------------------------------------------------===//
// ViewOp
//===----------------------------------------------------------------------===//

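// The custom form parsed and printed below is
//   view %src[%byte_shift][%sizes] : memref-type to memref-type
// e.g. (illustrative):
//   %1 = view %0[%offset][%size0, %size1]
//       : memref<2048xi8> to memref<?x?xf32>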
static ParseResult parseViewOp(OpAsmParser &parser, OperationState &result) {
  OpAsmParser::OperandType srcInfo;
  SmallVector<OpAsmParser::OperandType, 1> offsetInfo;
  SmallVector<OpAsmParser::OperandType, 4> sizesInfo;
  auto indexType = parser.getBuilder().getIndexType();
  Type srcType, dstType;
  llvm::SMLoc offsetLoc;
  if (parser.parseOperand(srcInfo) || parser.getCurrentLocation(&offsetLoc) ||
      parser.parseOperandList(offsetInfo, OpAsmParser::Delimiter::Square))
    return failure();

  if (offsetInfo.size() != 1)
    return parser.emitError(offsetLoc) << "expects 1 offset operand";

  return failure(
      parser.parseOperandList(sizesInfo, OpAsmParser::Delimiter::Square) ||
      parser.parseOptionalAttrDict(result.attributes) ||
      parser.parseColonType(srcType) ||
      parser.resolveOperand(srcInfo, srcType, result.operands) ||
      parser.resolveOperands(offsetInfo, indexType, result.operands) ||
      parser.resolveOperands(sizesInfo, indexType, result.operands) ||
      parser.parseKeywordType("to", dstType) ||
      parser.addTypeToList(dstType, result.types));
}

static void print(OpAsmPrinter &p, ViewOp op) {
  p << op.getOperationName() << ' ' << op.getOperand(0) << '[';
  p.printOperand(op.byte_shift());
  p << "][" << op.sizes() << ']';
  p.printOptionalAttrDict(op.getAttrs());
  p << " : " << op.getOperand(0).getType() << " to " << op.getType();
}

static LogicalResult verify(ViewOp op) {
  auto baseType = op.getOperand(0).getType().cast<MemRefType>();
  auto viewType = op.getType();

  // The base memref should have an identity layout map (or none).
  if (baseType.getAffineMaps().size() > 1 ||
      (baseType.getAffineMaps().size() == 1 &&
       !baseType.getAffineMaps()[0].isIdentity()))
    return op.emitError("unsupported map for base memref type ") << baseType;

  // The result memref should have an identity layout map (or none).
  if (viewType.getAffineMaps().size() > 1 ||
      (viewType.getAffineMaps().size() == 1 &&
       !viewType.getAffineMaps()[0].isIdentity()))
    return op.emitError("unsupported map for result memref type ") << viewType;

  // The base memref and the view memref should be in the same memory space.
  if (baseType.getMemorySpace() != viewType.getMemorySpace())
    return op.emitError("different memory spaces specified for base memref "
                        "type ")
           << baseType << " and view memref type " << viewType;

  // Verify that we have the correct number of sizes for the result type.
  unsigned numDynamicDims = viewType.getNumDynamicDims();
  if (op.sizes().size() != numDynamicDims)
    return op.emitError("incorrect number of size operands for type ")
           << viewType;

  return success();
}

Value ViewOp::getViewSource() { return source(); }

namespace {

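/// Canonicalize a view op whose dynamic size operands are defined by
/// constant_index ops: the constants are folded into the result memref type
/// and a memref_cast is inserted back to the original type. Illustrative
/// (assumed) example:
///
///   %c16 = constant 16 : index
///   %1 = view %0[%offset][%c16] : memref<2048xi8> to memref<?xf32>
///
/// becomes a view producing memref<16xf32> followed by a memref_cast back to
/// memref<?xf32>.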
struct ViewOpShapeFolder : public OpRewritePattern<ViewOp> {
  using OpRewritePattern<ViewOp>::OpRewritePattern;

  LogicalResult matchAndRewrite(ViewOp viewOp,
                                PatternRewriter &rewriter) const override {
    // Return if none of the operands are constants.
    if (llvm::none_of(viewOp.getOperands(), [](Value operand) {
          return matchPattern(operand, m_ConstantIndex());
        }))
      return failure();

    // Get result memref type.
    auto memrefType = viewOp.getType();

    // Get offset from old memref view type `memrefType`.
    int64_t oldOffset;
    SmallVector<int64_t, 4> oldStrides;
    if (failed(getStridesAndOffset(memrefType, oldStrides, oldOffset)))
      return failure();
    assert(oldOffset == 0 && "Expected 0 offset");

    SmallVector<Value, 4> newOperands;

    // Offset cannot be folded into the result type.

    // Fold any dynamic dim operands which are produced by a constant.
    SmallVector<int64_t, 4> newShapeConstants;
    newShapeConstants.reserve(memrefType.getRank());

    unsigned dynamicDimPos = 0;
    unsigned rank = memrefType.getRank();
    for (unsigned dim = 0, e = rank; dim < e; ++dim) {
      int64_t dimSize = memrefType.getDimSize(dim);
      // If this is already a static dimension, keep it.
      if (!ShapedType::isDynamic(dimSize)) {
        newShapeConstants.push_back(dimSize);
        continue;
      }
      auto *defOp = viewOp.sizes()[dynamicDimPos].getDefiningOp();
      if (auto constantIndexOp = dyn_cast_or_null<ConstantIndexOp>(defOp)) {
        // Dynamic shape dimension will be folded.
        newShapeConstants.push_back(constantIndexOp.getValue());
      } else {
        // Dynamic shape dimension not folded; copy operand from old memref.
        newShapeConstants.push_back(dimSize);
        newOperands.push_back(viewOp.sizes()[dynamicDimPos]);
      }
      dynamicDimPos++;
    }

    // Create new memref type with constant folded dims.
    MemRefType newMemRefType =
        MemRefType::Builder(memrefType).setShape(newShapeConstants);
    // Nothing new, don't fold.
    if (newMemRefType == memrefType)
      return failure();

    // Create new ViewOp.
    auto newViewOp = rewriter.create<ViewOp>(viewOp.getLoc(), newMemRefType,
                                             viewOp.getOperand(0),
                                             viewOp.byte_shift(), newOperands);
    // Insert a cast so we have the same type as the old memref type.
    rewriter.replaceOpWithNewOp<MemRefCastOp>(viewOp, newViewOp,
                                              viewOp.getType());
    return success();
  }
};

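/// Canonicalize a view op whose source memref comes from a memref_cast of an
/// alloc: the view is re-created directly on the allocated memref, bypassing
/// the cast.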
struct ViewOpMemrefCastFolder : public OpRewritePattern<ViewOp> {
  using OpRewritePattern<ViewOp>::OpRewritePattern;

  LogicalResult matchAndRewrite(ViewOp viewOp,
                                PatternRewriter &rewriter) const override {
    Value memrefOperand = viewOp.getOperand(0);
    MemRefCastOp memrefCastOp = memrefOperand.getDefiningOp<MemRefCastOp>();
    if (!memrefCastOp)
      return failure();
    Value allocOperand = memrefCastOp.getOperand();
    AllocOp allocOp = allocOperand.getDefiningOp<AllocOp>();
    if (!allocOp)
      return failure();
    rewriter.replaceOpWithNewOp<ViewOp>(viewOp, viewOp.getType(), allocOperand,
                                        viewOp.byte_shift(), viewOp.sizes());
    return success();
  }
};

} // end anonymous namespace

void ViewOp::getCanonicalizationPatterns(OwningRewritePatternList &results,
                                         MLIRContext *context) {
  results.insert<ViewOpShapeFolder, ViewOpMemrefCastFolder>(context);
}

//===----------------------------------------------------------------------===//
// XOrOp
//===----------------------------------------------------------------------===//

OpFoldResult XOrOp::fold(ArrayRef<Attribute> operands) {
  /// xor(x, 0) -> x
  if (matchPattern(rhs(), m_Zero()))
    return lhs();
  /// xor(x, x) -> 0
  if (lhs() == rhs())
    return Builder(getContext()).getZeroAttr(getType());

  return constFoldBinaryOp<IntegerAttr>(operands,
                                        [](APInt a, APInt b) { return a ^ b; });
}

//===----------------------------------------------------------------------===//
// ZeroExtendIOp
//===----------------------------------------------------------------------===//

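// zexti requires the result integer type to be strictly wider than the
// operand integer type; index types are rejected on both sides. For example
// (illustrative), `zexti %x : i16 to i32` is valid, while
// `zexti %x : i32 to i32` is not.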
static LogicalResult verify(ZeroExtendIOp op) {
  auto srcType = getElementTypeOrSelf(op.getOperand().getType());
  auto dstType = getElementTypeOrSelf(op.getType());

  if (srcType.isa<IndexType>())
    return op.emitError() << srcType << " is not a valid operand type";
  if (dstType.isa<IndexType>())
    return op.emitError() << dstType << " is not a valid result type";

  if (srcType.cast<IntegerType>().getWidth() >=
      dstType.cast<IntegerType>().getWidth())
    return op.emitError("result type ")
           << dstType << " must be wider than operand type " << srcType;

  return success();
}

//===----------------------------------------------------------------------===//
// TableGen'd op method definitions
//===----------------------------------------------------------------------===//

#define GET_OP_CLASSES
#include "mlir/Dialect/StandardOps/IR/Ops.cpp.inc"