//===- Ops.cpp - Standard MLIR Operations ---------------------------------===//
//
// Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions.
// See https://llvm.org/LICENSE.txt for license information.
// SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception
//
//===----------------------------------------------------------------------===//

#include "mlir/Dialect/StandardOps/IR/Ops.h"

#include "mlir/Dialect/CommonFolders.h"
#include "mlir/IR/AffineExpr.h"
#include "mlir/IR/AffineMap.h"
#include "mlir/IR/BlockAndValueMapping.h"
#include "mlir/IR/Builders.h"
#include "mlir/IR/BuiltinOps.h"
#include "mlir/IR/BuiltinTypes.h"
#include "mlir/IR/Matchers.h"
#include "mlir/IR/OpImplementation.h"
#include "mlir/IR/PatternMatch.h"
#include "mlir/IR/TypeUtilities.h"
#include "mlir/IR/Value.h"
#include "mlir/Support/MathExtras.h"
#include "mlir/Transforms/InliningUtils.h"
#include "llvm/ADT/STLExtras.h"
#include "llvm/ADT/StringSwitch.h"
#include "llvm/Support/FormatVariadic.h"
#include "llvm/Support/raw_ostream.h"

// Pull in all enum type definitions and utility function declarations.
#include "mlir/Dialect/StandardOps/IR/OpsEnums.cpp.inc"

using namespace mlir;

//===----------------------------------------------------------------------===//
// StandardOpsDialect Interfaces
//===----------------------------------------------------------------------===//
namespace {
/// This class defines the interface for handling inlining with standard
/// operations.
struct StdInlinerInterface : public DialectInlinerInterface {
  using DialectInlinerInterface::DialectInlinerInterface;

  //===--------------------------------------------------------------------===//
  // Analysis Hooks
  //===--------------------------------------------------------------------===//

  /// All call operations within standard ops can be inlined.
  bool isLegalToInline(Operation *call, Operation *callable,
                       bool wouldBeCloned) const final {
    return true;
  }

  /// All operations within standard ops can be inlined.
  bool isLegalToInline(Operation *, Region *, bool,
                       BlockAndValueMapping &) const final {
    return true;
  }

  //===--------------------------------------------------------------------===//
  // Transformation Hooks
  //===--------------------------------------------------------------------===//

  /// Handle the given inlined terminator by replacing it with a new operation
  /// as necessary.
  void handleTerminator(Operation *op, Block *newDest) const final {
    // Only "std.return" needs to be handled here.
    auto returnOp = dyn_cast<ReturnOp>(op);
    if (!returnOp)
      return;

    // Replace the return with a branch to the dest.
    OpBuilder builder(op);
    builder.create<BranchOp>(op->getLoc(), newDest, returnOp.getOperands());
    op->erase();
  }

  /// Handle the given inlined terminator by replacing it with a new operation
  /// as necessary.
  void handleTerminator(Operation *op,
                        ArrayRef<Value> valuesToRepl) const final {
    // Only "std.return" needs to be handled here.
    auto returnOp = cast<ReturnOp>(op);

    // Replace the values directly with the return operands.
    assert(returnOp.getNumOperands() == valuesToRepl.size());
    for (const auto &it : llvm::enumerate(returnOp.getOperands()))
      valuesToRepl[it.index()].replaceAllUsesWith(it.value());
  }
};
} // end anonymous namespace

//===----------------------------------------------------------------------===//
// StandardOpsDialect
//===----------------------------------------------------------------------===//

/// A custom unary operation printer that omits the "std." prefix from the
/// operation names.
static void printStandardUnaryOp(Operation *op, OpAsmPrinter &p) {
  assert(op->getNumOperands() == 1 && "unary op should have one operand");
  assert(op->getNumResults() == 1 && "unary op should have one result");

  int stdDotLen = StandardOpsDialect::getDialectNamespace().size() + 1;
  p << op->getName().getStringRef().drop_front(stdDotLen) << ' '
    << op->getOperand(0);
  p.printOptionalAttrDict(op->getAttrs());
  p << " : " << op->getOperand(0).getType();
}

/// A custom binary operation printer that omits the "std." prefix from the
/// operation names.
static void printStandardBinaryOp(Operation *op, OpAsmPrinter &p) {
  assert(op->getNumOperands() == 2 && "binary op should have two operands");
  assert(op->getNumResults() == 1 && "binary op should have one result");

  // If not all the operand and result types are the same, just use the
  // generic assembly form to avoid omitting information in printing.
  auto resultType = op->getResult(0).getType();
  if (op->getOperand(0).getType() != resultType ||
      op->getOperand(1).getType() != resultType) {
    p.printGenericOp(op);
    return;
  }

  int stdDotLen = StandardOpsDialect::getDialectNamespace().size() + 1;
  p << op->getName().getStringRef().drop_front(stdDotLen) << ' '
    << op->getOperand(0) << ", " << op->getOperand(1);
  p.printOptionalAttrDict(op->getAttrs());

  // Now we can output only one type for all operands and the result.
  p << " : " << op->getResult(0).getType();
}

/// A custom cast operation printer that omits the "std." prefix from the
/// operation names.
static void printStandardCastOp(Operation *op, OpAsmPrinter &p) {
  int stdDotLen = StandardOpsDialect::getDialectNamespace().size() + 1;
  p << op->getName().getStringRef().drop_front(stdDotLen) << ' '
    << op->getOperand(0) << " : " << op->getOperand(0).getType() << " to "
    << op->getResult(0).getType();
}

/// A custom cast operation verifier.
template <typename T>
static LogicalResult verifyCastOp(T op) {
  auto opType = op.getOperand().getType();
  auto resType = op.getType();
  if (!T::areCastCompatible(opType, resType))
    return op.emitError("operand type ") << opType << " and result type "
                                         << resType << " are cast incompatible";

  return success();
}

void StandardOpsDialect::initialize() {
  addOperations<DmaStartOp, DmaWaitOp,
#define GET_OP_LIST
#include "mlir/Dialect/StandardOps/IR/Ops.cpp.inc"
                >();
  addInterfaces<StdInlinerInterface>();
}

/// Materialize a single constant operation from a given attribute value with
/// the desired resultant type.
Operation *StandardOpsDialect::materializeConstant(OpBuilder &builder,
                                                   Attribute value, Type type,
                                                   Location loc) {
  return builder.create<ConstantOp>(loc, type, value);
}

/// Matches a ConstantIndexOp.
/// TODO: This should probably just be a general matcher that uses m_Constant
/// and checks the operation for an index type.
static detail::op_matcher<ConstantIndexOp> m_ConstantIndex() {
  return detail::op_matcher<ConstantIndexOp>();
}

//===----------------------------------------------------------------------===//
// Common canonicalization pattern support logic
//===----------------------------------------------------------------------===//

/// This is a common class used for patterns of the form
/// "someop(memrefcast) -> someop".  It folds the source of any memref_cast
/// into the root operation directly.
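///
/// For example (illustrative IR; the specific values and types are
/// hypothetical):
///
///   %0 = memref_cast %src : memref<16xf32> to memref<?xf32>
///   dealloc %0 : memref<?xf32>
///
/// folds the cast away so the root op uses the source directly:
///
///   dealloc %src : memref<16xf32>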
static LogicalResult foldMemRefCast(Operation *op) {
  bool folded = false;
  for (OpOperand &operand : op->getOpOperands()) {
    auto cast = operand.get().getDefiningOp<MemRefCastOp>();
    if (cast && !cast.getOperand().getType().isa<UnrankedMemRefType>()) {
      operand.set(cast.getOperand());
      folded = true;
    }
  }
  return success(folded);
}

//===----------------------------------------------------------------------===//
// Common cast compatibility check for vector types.
//===----------------------------------------------------------------------===//

/// This method checks for cast compatibility of vector types.
/// If 'a' and 'b' are vector types, and they are cast compatible,
/// it calls the 'areElementsCastCompatible' function to check for
/// element cast compatibility.
/// Returns 'true' if the vector types are cast compatible, and 'false'
/// otherwise.
static bool areVectorCastSimpleCompatible(
    Type a, Type b, function_ref<bool(Type, Type)> areElementsCastCompatible) {
  if (auto va = a.dyn_cast<VectorType>())
    if (auto vb = b.dyn_cast<VectorType>())
      return va.getShape().equals(vb.getShape()) &&
             areElementsCastCompatible(va.getElementType(),
                                       vb.getElementType());
  return false;
}

//===----------------------------------------------------------------------===//
// Helpers for Tensor[Load|Store]Op, TensorToMemrefOp, and GlobalMemrefOp
//===----------------------------------------------------------------------===//

static Type getTensorTypeFromMemRefType(Type type) {
  if (auto memref = type.dyn_cast<MemRefType>())
    return RankedTensorType::get(memref.getShape(), memref.getElementType());
  if (auto memref = type.dyn_cast<UnrankedMemRefType>())
    return UnrankedTensorType::get(memref.getElementType());
  return NoneType::get(type.getContext());
}

//===----------------------------------------------------------------------===//
// AddFOp
//===----------------------------------------------------------------------===//

OpFoldResult AddFOp::fold(ArrayRef<Attribute> operands) {
  return constFoldBinaryOp<FloatAttr>(
      operands, [](APFloat a, APFloat b) { return a + b; });
}

//===----------------------------------------------------------------------===//
// AddIOp
//===----------------------------------------------------------------------===//

OpFoldResult AddIOp::fold(ArrayRef<Attribute> operands) {
  /// addi(x, 0) -> x
  if (matchPattern(rhs(), m_Zero()))
    return lhs();

  return constFoldBinaryOp<IntegerAttr>(operands,
                                        [](APInt a, APInt b) { return a + b; });
}

/// Extract int64_t values from the assumed ArrayAttr of IntegerAttr.
static SmallVector<int64_t, 4> extractFromI64ArrayAttr(Attribute attr) {
  return llvm::to_vector<4>(
      llvm::map_range(attr.cast<ArrayAttr>(), [](Attribute a) -> int64_t {
        return a.cast<IntegerAttr>().getInt();
      }));
}

//===----------------------------------------------------------------------===//
// AllocOp / AllocaOp
//===----------------------------------------------------------------------===//

template <typename AllocLikeOp>
static LogicalResult verifyAllocLikeOp(AllocLikeOp op) {
  static_assert(llvm::is_one_of<AllocLikeOp, AllocOp, AllocaOp>::value,
                "applies to only alloc or alloca");
  auto memRefType = op.getResult().getType().template dyn_cast<MemRefType>();
  if (!memRefType)
    return op.emitOpError("result must be a memref");

  if (static_cast<int64_t>(op.dynamicSizes().size()) !=
      memRefType.getNumDynamicDims())
    return op.emitOpError("dimension operand count does not equal memref "
                          "dynamic dimension count");

  unsigned numSymbols = 0;
  if (!memRefType.getAffineMaps().empty())
    numSymbols = memRefType.getAffineMaps().front().getNumSymbols();
  if (op.symbolOperands().size() != numSymbols)
    return op.emitOpError(
        "symbol operand count does not equal memref symbol count");

  return success();
}

static LogicalResult verify(AllocOp op) { return verifyAllocLikeOp(op); }

static LogicalResult verify(AllocaOp op) {
  // An alloca op needs to have an ancestor with an allocation scope trait.
  if (!op->getParentWithTrait<OpTrait::AutomaticAllocationScope>())
    return op.emitOpError(
        "requires an ancestor op with AutomaticAllocationScope trait");

  return verifyAllocLikeOp(op);
}

namespace {
/// Fold constant dimensions into an alloc like operation.
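///
/// For example (illustrative IR; names and types are hypothetical):
///
///   %c42 = constant 42 : index
///   %0 = alloc(%c42, %n) : memref<?x?xf32>
///
/// is rewritten into an alloc with a static dimension plus a cast back to the
/// original type:
///
///   %1 = alloc(%n) : memref<42x?xf32>
///   %0 = memref_cast %1 : memref<42x?xf32> to memref<?x?xf32>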
template <typename AllocLikeOp>
struct SimplifyAllocConst : public OpRewritePattern<AllocLikeOp> {
  using OpRewritePattern<AllocLikeOp>::OpRewritePattern;

  LogicalResult matchAndRewrite(AllocLikeOp alloc,
                                PatternRewriter &rewriter) const override {
    // Check to see if any dimension operands are constants.  If so, we can
    // substitute and drop them.
    if (llvm::none_of(alloc.getOperands(), [](Value operand) {
          return matchPattern(operand, m_ConstantIndex());
        }))
      return failure();

    auto memrefType = alloc.getType();

    // Ok, we have one or more constant operands.  Collect the non-constant ones
    // and keep track of the resultant memref type to build.
    SmallVector<int64_t, 4> newShapeConstants;
    newShapeConstants.reserve(memrefType.getRank());
    SmallVector<Value, 4> newOperands;

    unsigned dynamicDimPos = 0;
    for (unsigned dim = 0, e = memrefType.getRank(); dim < e; ++dim) {
      int64_t dimSize = memrefType.getDimSize(dim);
      // If this is already a static dimension, keep it.
      if (dimSize != -1) {
        newShapeConstants.push_back(dimSize);
        continue;
      }
      auto *defOp = alloc.getOperand(dynamicDimPos).getDefiningOp();
      if (auto constantIndexOp = dyn_cast_or_null<ConstantIndexOp>(defOp)) {
        // Dynamic shape dimension will be folded.
        newShapeConstants.push_back(constantIndexOp.getValue());
      } else {
        // Dynamic shape dimension not folded; copy operand from old memref.
        newShapeConstants.push_back(-1);
        newOperands.push_back(alloc.getOperand(dynamicDimPos));
      }
      dynamicDimPos++;
    }

    // Create new memref type (which will have fewer dynamic dimensions).
    MemRefType newMemRefType =
        MemRefType::Builder(memrefType).setShape(newShapeConstants);
    assert(static_cast<int64_t>(newOperands.size()) ==
           newMemRefType.getNumDynamicDims());

    // Create and insert the alloc op for the new memref.
    auto newAlloc = rewriter.create<AllocLikeOp>(alloc.getLoc(), newMemRefType,
                                                 newOperands, IntegerAttr());
    // Insert a cast so we have the same type as the old alloc.
    auto resultCast = rewriter.create<MemRefCastOp>(alloc.getLoc(), newAlloc,
                                                    alloc.getType());

    rewriter.replaceOp(alloc, {resultCast});
    return success();
  }
};

/// Fold alloc operations with no uses. Alloc has side effects on the heap,
/// but can still be deleted if it has zero uses.
struct SimplifyDeadAlloc : public OpRewritePattern<AllocOp> {
  using OpRewritePattern<AllocOp>::OpRewritePattern;

  LogicalResult matchAndRewrite(AllocOp alloc,
                                PatternRewriter &rewriter) const override {
    if (alloc.use_empty()) {
      rewriter.eraseOp(alloc);
      return success();
    }
    return failure();
  }
};
} // end anonymous namespace.

void AllocOp::getCanonicalizationPatterns(OwningRewritePatternList &results,
                                          MLIRContext *context) {
  results.insert<SimplifyAllocConst<AllocOp>, SimplifyDeadAlloc>(context);
}

void AllocaOp::getCanonicalizationPatterns(OwningRewritePatternList &results,
                                           MLIRContext *context) {
  results.insert<SimplifyAllocConst<AllocaOp>>(context);
}

//===----------------------------------------------------------------------===//
// AndOp
//===----------------------------------------------------------------------===//

OpFoldResult AndOp::fold(ArrayRef<Attribute> operands) {
  /// and(x, 0) -> 0
  if (matchPattern(rhs(), m_Zero()))
    return rhs();
  /// and(x, allOnes) -> x
  APInt intValue;
  if (matchPattern(rhs(), m_ConstantInt(&intValue)) &&
      intValue.isAllOnesValue())
    return lhs();
  /// and(x,x) -> x
  if (lhs() == rhs())
    return rhs();

  return constFoldBinaryOp<IntegerAttr>(operands,
                                        [](APInt a, APInt b) { return a & b; });
}

//===----------------------------------------------------------------------===//
// AssertOp
//===----------------------------------------------------------------------===//

namespace {
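/// Erase assert operations whose condition is statically known to hold,
/// e.g. (illustrative):
///
///   %true = constant true
///   assert %true, "this assertion always holds"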
struct EraseRedundantAssertions : public OpRewritePattern<AssertOp> {
  using OpRewritePattern<AssertOp>::OpRewritePattern;

  LogicalResult matchAndRewrite(AssertOp op,
                                PatternRewriter &rewriter) const override {
    // Erase assertion if argument is constant true.
    if (matchPattern(op.arg(), m_One())) {
      rewriter.eraseOp(op);
      return success();
    }
    return failure();
  }
};
} // namespace

void AssertOp::getCanonicalizationPatterns(OwningRewritePatternList &patterns,
                                           MLIRContext *context) {
  patterns.insert<EraseRedundantAssertions>(context);
}

//===----------------------------------------------------------------------===//
// AssumeAlignmentOp
//===----------------------------------------------------------------------===//

static LogicalResult verify(AssumeAlignmentOp op) {
  unsigned alignment = op.alignment();
  if (!llvm::isPowerOf2_32(alignment))
    return op.emitOpError("alignment must be power of 2");
  return success();
}

//===----------------------------------------------------------------------===//
// AtomicRMWOp
//===----------------------------------------------------------------------===//

static LogicalResult verify(AtomicRMWOp op) {
  if (op.getMemRefType().getRank() != op.getNumOperands() - 2)
    return op.emitOpError(
        "expects the number of subscripts to be equal to memref rank");
  switch (op.kind()) {
  case AtomicRMWKind::addf:
  case AtomicRMWKind::maxf:
  case AtomicRMWKind::minf:
  case AtomicRMWKind::mulf:
    if (!op.value().getType().isa<FloatType>())
      return op.emitOpError()
             << "with kind '" << stringifyAtomicRMWKind(op.kind())
             << "' expects a floating-point type";
    break;
  case AtomicRMWKind::addi:
  case AtomicRMWKind::maxs:
  case AtomicRMWKind::maxu:
  case AtomicRMWKind::mins:
  case AtomicRMWKind::minu:
  case AtomicRMWKind::muli:
    if (!op.value().getType().isa<IntegerType>())
      return op.emitOpError()
             << "with kind '" << stringifyAtomicRMWKind(op.kind())
             << "' expects an integer type";
    break;
  default:
    break;
  }
  return success();
}

//===----------------------------------------------------------------------===//
// GenericAtomicRMWOp
//===----------------------------------------------------------------------===//

void GenericAtomicRMWOp::build(OpBuilder &builder, OperationState &result,
                               Value memref, ValueRange ivs) {
  result.addOperands(memref);
  result.addOperands(ivs);

  if (auto memrefType = memref.getType().dyn_cast<MemRefType>()) {
    Type elementType = memrefType.getElementType();
    result.addTypes(elementType);

    Region *bodyRegion = result.addRegion();
    bodyRegion->push_back(new Block());
    bodyRegion->addArgument(elementType);
  }
}

static LogicalResult verify(GenericAtomicRMWOp op) {
  auto &body = op.body();
  if (body.getNumArguments() != 1)
    return op.emitOpError("expected single number of entry block arguments");

  if (op.getResult().getType() != body.getArgument(0).getType())
    return op.emitOpError(
        "expected block argument of the same type result type");

  bool hasSideEffects =
      body.walk([&](Operation *nestedOp) {
            if (MemoryEffectOpInterface::hasNoEffect(nestedOp))
              return WalkResult::advance();
            nestedOp->emitError("body of 'generic_atomic_rmw' should contain "
                                "only operations with no side effects");
            return WalkResult::interrupt();
          })
          .wasInterrupted();
  return hasSideEffects ? failure() : success();
}

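/// Illustrative example of the custom assembly form handled by the parser and
/// printer below (the specific values and types are hypothetical):
///
///   %x = generic_atomic_rmw %I[%i] : memref<10xf32> {
///   ^bb0(%current_value : f32):
///     %c1 = constant 1.0 : f32
///     %inc = addf %c1, %current_value : f32
///     atomic_yield %inc : f32
///   }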
static ParseResult parseGenericAtomicRMWOp(OpAsmParser &parser,
                                           OperationState &result) {
  OpAsmParser::OperandType memref;
  Type memrefType;
  SmallVector<OpAsmParser::OperandType, 4> ivs;

  Type indexType = parser.getBuilder().getIndexType();
  if (parser.parseOperand(memref) ||
      parser.parseOperandList(ivs, OpAsmParser::Delimiter::Square) ||
      parser.parseColonType(memrefType) ||
      parser.resolveOperand(memref, memrefType, result.operands) ||
      parser.resolveOperands(ivs, indexType, result.operands))
    return failure();

  Region *body = result.addRegion();
  if (parser.parseRegion(*body, llvm::None, llvm::None) ||
      parser.parseOptionalAttrDict(result.attributes))
    return failure();
  result.types.push_back(memrefType.cast<MemRefType>().getElementType());
  return success();
}

static void print(OpAsmPrinter &p, GenericAtomicRMWOp op) {
  p << op.getOperationName() << ' ' << op.memref() << "[" << op.indices()
    << "] : " << op.memref().getType();
  p.printRegion(op.body());
  p.printOptionalAttrDict(op.getAttrs());
}

//===----------------------------------------------------------------------===//
// AtomicYieldOp
//===----------------------------------------------------------------------===//

static LogicalResult verify(AtomicYieldOp op) {
  Type parentType = op->getParentOp()->getResultTypes().front();
  Type resultType = op.result().getType();
  if (parentType != resultType)
    return op.emitOpError() << "types mismatch between yield op: " << resultType
                            << " and its parent: " << parentType;
  return success();
}

//===----------------------------------------------------------------------===//
// BranchOp
//===----------------------------------------------------------------------===//

/// Given a successor, try to collapse it to a new destination if it only
/// contains a passthrough unconditional branch. If the successor is
/// collapsible, `successor` and `successorOperands` are updated to reference
/// the new destination and values. `argStorage` is an optional storage to use
/// if operands to the collapsed successor need to be remapped.
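///
/// For example (illustrative), given a successor of the form:
///
///   ^bb1(%arg : i32):
///     br ^bb2(%arg, %other : i32, i32)
///
/// a branch `br ^bb1(%v : i32)` can be redirected to `^bb2(%v, %other)`.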
static LogicalResult collapseBranch(Block *&successor,
                                    ValueRange &successorOperands,
                                    SmallVectorImpl<Value> &argStorage) {
  // Check that the successor only contains an unconditional branch.
  if (std::next(successor->begin()) != successor->end())
    return failure();
  // Check that the terminator is an unconditional branch.
  BranchOp successorBranch = dyn_cast<BranchOp>(successor->getTerminator());
  if (!successorBranch)
    return failure();
  // Check that the arguments are only used within the terminator.
  for (BlockArgument arg : successor->getArguments()) {
    for (Operation *user : arg.getUsers())
      if (user != successorBranch)
        return failure();
  }
  // Don't try to collapse branches to infinite loops.
  Block *successorDest = successorBranch.getDest();
  if (successorDest == successor)
    return failure();

  // Update the operands to the successor. If the branch parent has no
  // arguments, we can use the branch operands directly.
  OperandRange operands = successorBranch.getOperands();
  if (successor->args_empty()) {
    successor = successorDest;
    successorOperands = operands;
    return success();
  }

  // Otherwise, we need to remap any argument operands.
  for (Value operand : operands) {
    BlockArgument argOperand = operand.dyn_cast<BlockArgument>();
    if (argOperand && argOperand.getOwner() == successor)
      argStorage.push_back(successorOperands[argOperand.getArgNumber()]);
    else
      argStorage.push_back(operand);
  }
  successor = successorDest;
  successorOperands = argStorage;
  return success();
}

namespace {
/// Simplify a branch to a block that has a single predecessor. This effectively
/// merges the two blocks.
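///
/// For example (illustrative):
///
///     br ^bb1(%v : i32)
///   ^bb1(%arg : i32):  // single predecessor
///     "foo.use"(%arg) : (i32) -> ()
///
/// the blocks are merged and %arg is replaced by %v.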
struct SimplifyBrToBlockWithSinglePred : public OpRewritePattern<BranchOp> {
  using OpRewritePattern<BranchOp>::OpRewritePattern;

  LogicalResult matchAndRewrite(BranchOp op,
                                PatternRewriter &rewriter) const override {
    // Check that the successor block has a single predecessor.
    Block *succ = op.getDest();
    Block *opParent = op->getBlock();
    if (succ == opParent || !llvm::hasSingleElement(succ->getPredecessors()))
      return failure();

    // Merge the successor into the current block and erase the branch.
    rewriter.mergeBlocks(succ, opParent, op.getOperands());
    rewriter.eraseOp(op);
    return success();
  }
};

///   br ^bb1
/// ^bb1
///   br ^bbN(...)
///
///  -> br ^bbN(...)
///
struct SimplifyPassThroughBr : public OpRewritePattern<BranchOp> {
  using OpRewritePattern<BranchOp>::OpRewritePattern;

  LogicalResult matchAndRewrite(BranchOp op,
                                PatternRewriter &rewriter) const override {
    Block *dest = op.getDest();
    ValueRange destOperands = op.getOperands();
    SmallVector<Value, 4> destOperandStorage;

    // Try to collapse the successor if it points somewhere other than this
    // block.
    if (dest == op->getBlock() ||
        failed(collapseBranch(dest, destOperands, destOperandStorage)))
      return failure();

    // Create a new branch with the collapsed successor.
    rewriter.replaceOpWithNewOp<BranchOp>(op, dest, destOperands);
    return success();
  }
};
} // end anonymous namespace.

Block *BranchOp::getDest() { return getSuccessor(); }

void BranchOp::setDest(Block *block) { return setSuccessor(block); }

void BranchOp::eraseOperand(unsigned index) { (*this)->eraseOperand(index); }

void BranchOp::getCanonicalizationPatterns(OwningRewritePatternList &results,
                                           MLIRContext *context) {
  results.insert<SimplifyBrToBlockWithSinglePred, SimplifyPassThroughBr>(
      context);
}

Optional<MutableOperandRange>
BranchOp::getMutableSuccessorOperands(unsigned index) {
  assert(index == 0 && "invalid successor index");
  return destOperandsMutable();
}

Block *BranchOp::getSuccessorForOperands(ArrayRef<Attribute>) { return dest(); }

//===----------------------------------------------------------------------===//
// CallOp
//===----------------------------------------------------------------------===//

LogicalResult CallOp::verifySymbolUses(SymbolTableCollection &symbolTable) {
  // Check that the callee attribute was specified.
  auto fnAttr = (*this)->getAttrOfType<FlatSymbolRefAttr>("callee");
  if (!fnAttr)
    return emitOpError("requires a 'callee' symbol reference attribute");
  FuncOp fn = symbolTable.lookupNearestSymbolFrom<FuncOp>(*this, fnAttr);
  if (!fn)
    return emitOpError() << "'" << fnAttr.getValue()
                         << "' does not reference a valid function";

  // Verify that the operand and result types match the callee.
  auto fnType = fn.getType();
  if (fnType.getNumInputs() != getNumOperands())
    return emitOpError("incorrect number of operands for callee");

  for (unsigned i = 0, e = fnType.getNumInputs(); i != e; ++i)
    if (getOperand(i).getType() != fnType.getInput(i))
      return emitOpError("operand type mismatch: expected operand type ")
             << fnType.getInput(i) << ", but provided "
             << getOperand(i).getType() << " for operand number " << i;

  if (fnType.getNumResults() != getNumResults())
    return emitOpError("incorrect number of results for callee");

  for (unsigned i = 0, e = fnType.getNumResults(); i != e; ++i)
    if (getResult(i).getType() != fnType.getResult(i))
      return emitOpError("result type mismatch");

  return success();
}

FunctionType CallOp::getCalleeType() {
  return FunctionType::get(getOperandTypes(), getResultTypes(), getContext());
}

//===----------------------------------------------------------------------===//
// CallIndirectOp
//===----------------------------------------------------------------------===//
namespace {
/// Fold indirect calls that have a constant function as the callee operand.
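///
/// For example (illustrative):
///
///   %f = constant @callee : (i32) -> i32
///   %r = call_indirect %f(%arg) : (i32) -> i32
///
/// is rewritten into a direct call:
///
///   %r = call @callee(%arg) : (i32) -> i32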
struct SimplifyIndirectCallWithKnownCallee
    : public OpRewritePattern<CallIndirectOp> {
  using OpRewritePattern<CallIndirectOp>::OpRewritePattern;

  LogicalResult matchAndRewrite(CallIndirectOp indirectCall,
                                PatternRewriter &rewriter) const override {
    // Check that the callee is a constant callee.
    SymbolRefAttr calledFn;
    if (!matchPattern(indirectCall.getCallee(), m_Constant(&calledFn)))
      return failure();

    // Replace with a direct call.
    rewriter.replaceOpWithNewOp<CallOp>(indirectCall, calledFn,
                                        indirectCall.getResultTypes(),
                                        indirectCall.getArgOperands());
    return success();
  }
};
} // end anonymous namespace.

void CallIndirectOp::getCanonicalizationPatterns(
    OwningRewritePatternList &results, MLIRContext *context) {
  results.insert<SimplifyIndirectCallWithKnownCallee>(context);
}

//===----------------------------------------------------------------------===//
// General helpers for comparison ops
//===----------------------------------------------------------------------===//

// Return the type of the same shape (scalar, vector or tensor) containing i1.
static Type getI1SameShape(Type type) {
  auto i1Type = IntegerType::get(1, type.getContext());
  if (auto tensorType = type.dyn_cast<RankedTensorType>())
    return RankedTensorType::get(tensorType.getShape(), i1Type);
  if (type.isa<UnrankedTensorType>())
    return UnrankedTensorType::get(i1Type);
  if (auto vectorType = type.dyn_cast<VectorType>())
    return VectorType::get(vectorType.getShape(), i1Type);
  return i1Type;
}

//===----------------------------------------------------------------------===//
// CmpIOp
//===----------------------------------------------------------------------===//

static void buildCmpIOp(OpBuilder &build, OperationState &result,
                        CmpIPredicate predicate, Value lhs, Value rhs) {
  result.addOperands({lhs, rhs});
  result.types.push_back(getI1SameShape(lhs.getType()));
  result.addAttribute(CmpIOp::getPredicateAttrName(),
                      build.getI64IntegerAttr(static_cast<int64_t>(predicate)));
}

// Compute `lhs` `pred` `rhs`, where `pred` is one of the known integer
// comparison predicates.
bool mlir::applyCmpPredicate(CmpIPredicate predicate, const APInt &lhs,
                             const APInt &rhs) {
  switch (predicate) {
  case CmpIPredicate::eq:
    return lhs.eq(rhs);
  case CmpIPredicate::ne:
    return lhs.ne(rhs);
  case CmpIPredicate::slt:
    return lhs.slt(rhs);
  case CmpIPredicate::sle:
    return lhs.sle(rhs);
  case CmpIPredicate::sgt:
    return lhs.sgt(rhs);
  case CmpIPredicate::sge:
    return lhs.sge(rhs);
  case CmpIPredicate::ult:
    return lhs.ult(rhs);
  case CmpIPredicate::ule:
    return lhs.ule(rhs);
  case CmpIPredicate::ugt:
    return lhs.ugt(rhs);
  case CmpIPredicate::uge:
    return lhs.uge(rhs);
  }
  llvm_unreachable("unknown comparison predicate");
}

// Returns true if the predicate is true for two equal operands.
static bool applyCmpPredicateToEqualOperands(CmpIPredicate predicate) {
  switch (predicate) {
  case CmpIPredicate::eq:
  case CmpIPredicate::sle:
  case CmpIPredicate::sge:
  case CmpIPredicate::ule:
  case CmpIPredicate::uge:
    return true;
  case CmpIPredicate::ne:
  case CmpIPredicate::slt:
  case CmpIPredicate::sgt:
  case CmpIPredicate::ult:
  case CmpIPredicate::ugt:
    return false;
  }
  llvm_unreachable("unknown comparison predicate");
}

// Constant folding hook for comparisons.
OpFoldResult CmpIOp::fold(ArrayRef<Attribute> operands) {
  assert(operands.size() == 2 && "cmpi takes two arguments");

  if (lhs() == rhs()) {
    auto val = applyCmpPredicateToEqualOperands(getPredicate());
    return BoolAttr::get(val, getContext());
  }

  auto lhs = operands.front().dyn_cast_or_null<IntegerAttr>();
  auto rhs = operands.back().dyn_cast_or_null<IntegerAttr>();
  if (!lhs || !rhs)
    return {};

  auto val = applyCmpPredicate(getPredicate(), lhs.getValue(), rhs.getValue());
  return BoolAttr::get(val, getContext());
}

//===----------------------------------------------------------------------===//
// CmpFOp
//===----------------------------------------------------------------------===//

static void buildCmpFOp(OpBuilder &build, OperationState &result,
                        CmpFPredicate predicate, Value lhs, Value rhs) {
  result.addOperands({lhs, rhs});
  result.types.push_back(getI1SameShape(lhs.getType()));
  result.addAttribute(CmpFOp::getPredicateAttrName(),
                      build.getI64IntegerAttr(static_cast<int64_t>(predicate)));
}

/// Compute `lhs` `pred` `rhs`, where `pred` is one of the known floating point
/// comparison predicates.
bool mlir::applyCmpPredicate(CmpFPredicate predicate, const APFloat &lhs,
                             const APFloat &rhs) {
  auto cmpResult = lhs.compare(rhs);
  switch (predicate) {
  case CmpFPredicate::AlwaysFalse:
    return false;
  case CmpFPredicate::OEQ:
    return cmpResult == APFloat::cmpEqual;
  case CmpFPredicate::OGT:
    return cmpResult == APFloat::cmpGreaterThan;
  case CmpFPredicate::OGE:
    return cmpResult == APFloat::cmpGreaterThan ||
           cmpResult == APFloat::cmpEqual;
  case CmpFPredicate::OLT:
    return cmpResult == APFloat::cmpLessThan;
  case CmpFPredicate::OLE:
    return cmpResult == APFloat::cmpLessThan || cmpResult == APFloat::cmpEqual;
  case CmpFPredicate::ONE:
    return cmpResult != APFloat::cmpUnordered && cmpResult != APFloat::cmpEqual;
  case CmpFPredicate::ORD:
    return cmpResult != APFloat::cmpUnordered;
  case CmpFPredicate::UEQ:
    return cmpResult == APFloat::cmpUnordered || cmpResult == APFloat::cmpEqual;
  case CmpFPredicate::UGT:
    return cmpResult == APFloat::cmpUnordered ||
           cmpResult == APFloat::cmpGreaterThan;
  case CmpFPredicate::UGE:
    return cmpResult == APFloat::cmpUnordered ||
           cmpResult == APFloat::cmpGreaterThan ||
           cmpResult == APFloat::cmpEqual;
  case CmpFPredicate::ULT:
    return cmpResult == APFloat::cmpUnordered ||
           cmpResult == APFloat::cmpLessThan;
  case CmpFPredicate::ULE:
    return cmpResult == APFloat::cmpUnordered ||
           cmpResult == APFloat::cmpLessThan || cmpResult == APFloat::cmpEqual;
  case CmpFPredicate::UNE:
    return cmpResult != APFloat::cmpEqual;
  case CmpFPredicate::UNO:
    return cmpResult == APFloat::cmpUnordered;
  case CmpFPredicate::AlwaysTrue:
    return true;
  }
  llvm_unreachable("unknown comparison predicate");
}

// Constant folding hook for comparisons.
OpFoldResult CmpFOp::fold(ArrayRef<Attribute> operands) {
  assert(operands.size() == 2 && "cmpf takes two arguments");

  auto lhs = operands.front().dyn_cast_or_null<FloatAttr>();
  auto rhs = operands.back().dyn_cast_or_null<FloatAttr>();

  // TODO: We could do some intelligent things here if we know only one of
  // the operands and it is inf or NaN.
  if (!lhs || !rhs)
    return {};

  auto val = applyCmpPredicate(getPredicate(), lhs.getValue(), rhs.getValue());
  return IntegerAttr::get(IntegerType::get(1, getContext()), APInt(1, val));
}

//===----------------------------------------------------------------------===//
// CondBranchOp
//===----------------------------------------------------------------------===//

namespace {
/// cond_br true, ^bb1, ^bb2
///  -> br ^bb1
/// cond_br false, ^bb1, ^bb2
///  -> br ^bb2
///
struct SimplifyConstCondBranchPred : public OpRewritePattern<CondBranchOp> {
  using OpRewritePattern<CondBranchOp>::OpRewritePattern;

  LogicalResult matchAndRewrite(CondBranchOp condbr,
                                PatternRewriter &rewriter) const override {
    if (matchPattern(condbr.getCondition(), m_NonZero())) {
      // True branch taken.
      rewriter.replaceOpWithNewOp<BranchOp>(condbr, condbr.getTrueDest(),
                                            condbr.getTrueOperands());
      return success();
    } else if (matchPattern(condbr.getCondition(), m_Zero())) {
      // False branch taken.
      rewriter.replaceOpWithNewOp<BranchOp>(condbr, condbr.getFalseDest(),
                                            condbr.getFalseOperands());
      return success();
    }
    return failure();
  }
};

///   cond_br %cond, ^bb1, ^bb2
/// ^bb1
///   br ^bbN(...)
/// ^bb2
///   br ^bbK(...)
///
///  -> cond_br %cond, ^bbN(...), ^bbK(...)
///
struct SimplifyPassThroughCondBranch : public OpRewritePattern<CondBranchOp> {
  using OpRewritePattern<CondBranchOp>::OpRewritePattern;

  LogicalResult matchAndRewrite(CondBranchOp condbr,
                                PatternRewriter &rewriter) const override {
    Block *trueDest = condbr.trueDest(), *falseDest = condbr.falseDest();
    ValueRange trueDestOperands = condbr.getTrueOperands();
    ValueRange falseDestOperands = condbr.getFalseOperands();
    SmallVector<Value, 4> trueDestOperandStorage, falseDestOperandStorage;

    // Try to collapse one of the current successors.
    LogicalResult collapsedTrue =
        collapseBranch(trueDest, trueDestOperands, trueDestOperandStorage);
    LogicalResult collapsedFalse =
        collapseBranch(falseDest, falseDestOperands, falseDestOperandStorage);
    if (failed(collapsedTrue) && failed(collapsedFalse))
      return failure();

    // Create a new branch with the collapsed successors.
    rewriter.replaceOpWithNewOp<CondBranchOp>(condbr, condbr.getCondition(),
                                              trueDest, trueDestOperands,
                                              falseDest, falseDestOperands);
    return success();
  }
};

/// cond_br %cond, ^bb1(A, ..., N), ^bb1(A, ..., N)
///  -> br ^bb1(A, ..., N)
///
/// cond_br %cond, ^bb1(A), ^bb1(B)
///  -> %select = select %cond, A, B
///     br ^bb1(%select)
///
struct SimplifyCondBranchIdenticalSuccessors
    : public OpRewritePattern<CondBranchOp> {
  using OpRewritePattern<CondBranchOp>::OpRewritePattern;

  LogicalResult matchAndRewrite(CondBranchOp condbr,
                                PatternRewriter &rewriter) const override {
    // Check that the true and false destinations are the same and have the same
    // operands.
    Block *trueDest = condbr.trueDest();
    if (trueDest != condbr.falseDest())
      return failure();

    // If all of the operands match, no selects need to be generated.
    OperandRange trueOperands = condbr.getTrueOperands();
    OperandRange falseOperands = condbr.getFalseOperands();
    if (trueOperands == falseOperands) {
      rewriter.replaceOpWithNewOp<BranchOp>(condbr, trueDest, trueOperands);
      return success();
    }

    // Otherwise, if the current block is the only predecessor, insert selects
    // for any mismatched branch operands.
    if (trueDest->getUniquePredecessor() != condbr->getBlock())
      return failure();

    // Generate a select for any operands that differ between the two.
    SmallVector<Value, 8> mergedOperands;
    mergedOperands.reserve(trueOperands.size());
    Value condition = condbr.getCondition();
    for (auto it : llvm::zip(trueOperands, falseOperands)) {
      if (std::get<0>(it) == std::get<1>(it))
        mergedOperands.push_back(std::get<0>(it));
      else
        mergedOperands.push_back(rewriter.create<SelectOp>(
            condbr.getLoc(), condition, std::get<0>(it), std::get<1>(it)));
    }

    rewriter.replaceOpWithNewOp<BranchOp>(condbr, trueDest, mergedOperands);
    return success();
  }
};

///   ...
///   cond_br %cond, ^bb1(...), ^bb2(...)
/// ...
/// ^bb1: // has single predecessor
///   ...
///   cond_br %cond, ^bb3(...), ^bb4(...)
///
/// ->
///
///   ...
///   cond_br %cond, ^bb1(...), ^bb2(...)
/// ...
/// ^bb1: // has single predecessor
///   ...
///   br ^bb3(...)
///
struct SimplifyCondBranchFromCondBranchOnSameCondition
    : public OpRewritePattern<CondBranchOp> {
  using OpRewritePattern<CondBranchOp>::OpRewritePattern;

  LogicalResult matchAndRewrite(CondBranchOp condbr,
                                PatternRewriter &rewriter) const override {
    // Check that we have a single distinct predecessor.
    Block *currentBlock = condbr->getBlock();
    Block *predecessor = currentBlock->getSinglePredecessor();
    if (!predecessor)
      return failure();

    // Check that the predecessor terminates with a conditional branch to this
    // block and that it branches on the same condition.
    auto predBranch = dyn_cast<CondBranchOp>(predecessor->getTerminator());
    if (!predBranch || condbr.getCondition() != predBranch.getCondition())
      return failure();

    // Fold this branch to an unconditional branch.
    if (currentBlock == predBranch.trueDest())
      rewriter.replaceOpWithNewOp<BranchOp>(condbr, condbr.trueDest(),
                                            condbr.trueDestOperands());
    else
      rewriter.replaceOpWithNewOp<BranchOp>(condbr, condbr.falseDest(),
                                            condbr.falseDestOperands());
    return success();
  }
};
} // end anonymous namespace

void CondBranchOp::getCanonicalizationPatterns(
    OwningRewritePatternList &results, MLIRContext *context) {
  results.insert<SimplifyConstCondBranchPred, SimplifyPassThroughCondBranch,
                 SimplifyCondBranchIdenticalSuccessors,
                 SimplifyCondBranchFromCondBranchOnSameCondition>(context);
}

Optional<MutableOperandRange>
CondBranchOp::getMutableSuccessorOperands(unsigned index) {
  assert(index < getNumSuccessors() && "invalid successor index");
  return index == trueIndex ? trueDestOperandsMutable()
                            : falseDestOperandsMutable();
}

Block *CondBranchOp::getSuccessorForOperands(ArrayRef<Attribute> operands) {
  if (IntegerAttr condAttr = operands.front().dyn_cast_or_null<IntegerAttr>())
    return condAttr.getValue().isOneValue() ? trueDest() : falseDest();
  return nullptr;
}

//===----------------------------------------------------------------------===//
// Constant*Op
//===----------------------------------------------------------------------===//

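/// Illustrative printed forms handled by the custom printer and parser below
/// (names are hypothetical):
///
///   %cst = constant 4.2 : f32
///   %c42_i32 = constant 42 : i32
///   %f = constant @some_func : (i32) -> i32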
print(OpAsmPrinter & p,ConstantOp & op)1101 static void print(OpAsmPrinter &p, ConstantOp &op) {
1102   p << "constant ";
1103   p.printOptionalAttrDict(op.getAttrs(), /*elidedAttrs=*/{"value"});
1104 
1105   if (op.getAttrs().size() > 1)
1106     p << ' ';
1107   p << op.getValue();
1108 
1109   // If the value is a symbol reference, print a trailing type.
1110   if (op.getValue().isa<SymbolRefAttr>())
1111     p << " : " << op.getType();
1112 }
1113 
parseConstantOp(OpAsmParser & parser,OperationState & result)1114 static ParseResult parseConstantOp(OpAsmParser &parser,
1115                                    OperationState &result) {
1116   Attribute valueAttr;
1117   if (parser.parseOptionalAttrDict(result.attributes) ||
1118       parser.parseAttribute(valueAttr, "value", result.attributes))
1119     return failure();
1120 
1121   // If the attribute is a symbol reference, then we expect a trailing type.
1122   Type type;
1123   if (!valueAttr.isa<SymbolRefAttr>())
1124     type = valueAttr.getType();
1125   else if (parser.parseColonType(type))
1126     return failure();
1127 
1128   // Add the attribute type to the list.
1129   return parser.addTypeToList(type, result.types);
1130 }
1131 
1132 /// The constant op requires an attribute, and furthermore requires that it
1133 /// matches the return type.
verify(ConstantOp & op)1134 static LogicalResult verify(ConstantOp &op) {
1135   auto value = op.getValue();
1136   if (!value)
1137     return op.emitOpError("requires a 'value' attribute");
1138 
1139   auto type = op.getType();
1140   if (!value.getType().isa<NoneType>() && type != value.getType())
1141     return op.emitOpError() << "requires attribute's type (" << value.getType()
1142                             << ") to match op's return type (" << type << ")";
1143 
1144   if (type.isa<IndexType>() || value.isa<BoolAttr>())
1145     return success();
1146 
1147   if (auto intAttr = value.dyn_cast<IntegerAttr>()) {
1148     // If the type has a known bitwidth we verify that the value can be
1149     // represented with the given bitwidth.
1150     auto bitwidth = type.cast<IntegerType>().getWidth();
1151     auto intVal = intAttr.getValue();
1152     if (!intVal.isSignedIntN(bitwidth) && !intVal.isIntN(bitwidth))
1153       return op.emitOpError("requires 'value' to be an integer within the "
1154                             "range of the integer result type");
1155     return success();
1156   }
1157 
1158   if (type.isa<FloatType>()) {
1159     if (!value.isa<FloatAttr>())
1160       return op.emitOpError("requires 'value' to be a floating point constant");
1161     return success();
1162   }
1163 
1164   if (type.isa<ShapedType>()) {
1165     if (!value.isa<ElementsAttr>())
1166       return op.emitOpError("requires 'value' to be a shaped constant");
1167     return success();
1168   }
1169 
1170   if (type.isa<FunctionType>()) {
1171     auto fnAttr = value.dyn_cast<FlatSymbolRefAttr>();
1172     if (!fnAttr)
1173       return op.emitOpError("requires 'value' to be a function reference");
1174 
1175     // Try to find the referenced function.
1176     auto fn =
1177         op->getParentOfType<ModuleOp>().lookupSymbol<FuncOp>(fnAttr.getValue());
1178     if (!fn)
1179       return op.emitOpError()
1180              << "reference to undefined function '" << fnAttr.getValue() << "'";
1181 
1182     // Check that the referenced function has the correct type.
1183     if (fn.getType() != type)
1184       return op.emitOpError("reference to function with mismatched type");
1185 
1186     return success();
1187   }
1188 
1189   if (type.isa<NoneType>() && value.isa<UnitAttr>())
1190     return success();
1191 
1192   return op.emitOpError("unsupported 'value' attribute: ") << value;
1193 }
1194 
fold(ArrayRef<Attribute> operands)1195 OpFoldResult ConstantOp::fold(ArrayRef<Attribute> operands) {
1196   assert(operands.empty() && "constant has no operands");
1197   return getValue();
1198 }
1199 
getAsmResultNames(function_ref<void (Value,StringRef)> setNameFn)1200 void ConstantOp::getAsmResultNames(
1201     function_ref<void(Value, StringRef)> setNameFn) {
1202   Type type = getType();
1203   if (auto intCst = getValue().dyn_cast<IntegerAttr>()) {
1204     IntegerType intTy = type.dyn_cast<IntegerType>();
1205 
1206     // Sugar i1 constants with 'true' and 'false'.
1207     if (intTy && intTy.getWidth() == 1)
1208       return setNameFn(getResult(), (intCst.getInt() ? "true" : "false"));
1209 
1210     // Otherwise, build a complex name with the value and type.
1211     SmallString<32> specialNameBuffer;
1212     llvm::raw_svector_ostream specialName(specialNameBuffer);
1213     specialName << 'c' << intCst.getInt();
1214     if (intTy)
1215       specialName << '_' << type;
1216     setNameFn(getResult(), specialName.str());
1217 
1218   } else if (type.isa<FunctionType>()) {
1219     setNameFn(getResult(), "f");
1220   } else {
1221     setNameFn(getResult(), "cst");
1222   }
1223 }
1224 
1225 /// Returns true if a constant operation can be built with the given value and
1226 /// result type.
isBuildableWith(Attribute value,Type type)1227 bool ConstantOp::isBuildableWith(Attribute value, Type type) {
1228   // SymbolRefAttr can only be used with a function type.
1229   if (value.isa<SymbolRefAttr>())
1230     return type.isa<FunctionType>();
1231   // Otherwise, the attribute must have the same type as 'type'.
1232   if (value.getType() != type)
1233     return false;
1234   // Finally, check that the attribute kind is handled.
1235   return value.isa<IntegerAttr, FloatAttr, ElementsAttr, UnitAttr>();
1236 }
1237 
build(OpBuilder & builder,OperationState & result,const APFloat & value,FloatType type)1238 void ConstantFloatOp::build(OpBuilder &builder, OperationState &result,
1239                             const APFloat &value, FloatType type) {
1240   ConstantOp::build(builder, result, type, builder.getFloatAttr(type, value));
1241 }
1242 
classof(Operation * op)1243 bool ConstantFloatOp::classof(Operation *op) {
1244   return ConstantOp::classof(op) && op->getResult(0).getType().isa<FloatType>();
1245 }
1246 
1247 /// ConstantIntOp only matches values whose result type is an IntegerType.
classof(Operation * op)1248 bool ConstantIntOp::classof(Operation *op) {
1249   return ConstantOp::classof(op) &&
1250          op->getResult(0).getType().isSignlessInteger();
1251 }
1252 
build(OpBuilder & builder,OperationState & result,int64_t value,unsigned width)1253 void ConstantIntOp::build(OpBuilder &builder, OperationState &result,
1254                           int64_t value, unsigned width) {
1255   Type type = builder.getIntegerType(width);
1256   ConstantOp::build(builder, result, type, builder.getIntegerAttr(type, value));
1257 }
1258 
1259 /// Build a constant int op producing an integer with the specified type,
1260 /// which must be an integer type.
build(OpBuilder & builder,OperationState & result,int64_t value,Type type)1261 void ConstantIntOp::build(OpBuilder &builder, OperationState &result,
1262                           int64_t value, Type type) {
1263   assert(type.isSignlessInteger() &&
1264          "ConstantIntOp can only have signless integer type");
1265   ConstantOp::build(builder, result, type, builder.getIntegerAttr(type, value));
1266 }
1267 
1268 /// ConstantIndexOp only matches values whose result type is Index.
classof(Operation * op)1269 bool ConstantIndexOp::classof(Operation *op) {
1270   return ConstantOp::classof(op) && op->getResult(0).getType().isIndex();
1271 }
1272 
1273 void ConstantIndexOp::build(OpBuilder &builder, OperationState &result,
1274                             int64_t value) {
1275   Type type = builder.getIndexType();
1276   ConstantOp::build(builder, result, type, builder.getIntegerAttr(type, value));
1277 }
1278 
1279 //===----------------------------------------------------------------------===//
1280 // DeallocOp
1281 //===----------------------------------------------------------------------===//
1282 namespace {
1283 /// Fold Dealloc operations that are deallocating an AllocOp that is only used
1284 /// by other Dealloc operations.
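/// For example (illustrative IR), in
///
///   %0 = alloc() : memref<8xf32>
///   dealloc %0 : memref<8xf32>
///
/// the alloc's only user is a dealloc, so the dealloc can be erased.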
1285 struct SimplifyDeadDealloc : public OpRewritePattern<DeallocOp> {
1286   using OpRewritePattern<DeallocOp>::OpRewritePattern;
1287 
1288   LogicalResult matchAndRewrite(DeallocOp dealloc,
1289                                 PatternRewriter &rewriter) const override {
1290     // Check that the memref operand's defining operation is an AllocOp.
1291     Value memref = dealloc.memref();
1292     if (!isa_and_nonnull<AllocOp>(memref.getDefiningOp()))
1293       return failure();
1294 
1295     // Check that all of the uses of the AllocOp are other DeallocOps.
1296     for (auto *user : memref.getUsers())
1297       if (!isa<DeallocOp>(user))
1298         return failure();
1299 
1300     // Erase the dealloc operation.
1301     rewriter.eraseOp(dealloc);
1302     return success();
1303   }
1304 };
1305 } // end anonymous namespace.
1306 
1307 static LogicalResult verify(DeallocOp op) {
1308   if (!op.memref().getType().isa<MemRefType>())
1309     return op.emitOpError("operand must be a memref");
1310   return success();
1311 }
1312 
1313 void DeallocOp::getCanonicalizationPatterns(OwningRewritePatternList &results,
1314                                             MLIRContext *context) {
1315   results.insert<SimplifyDeadDealloc>(context);
1316 }
1317 
1318 LogicalResult DeallocOp::fold(ArrayRef<Attribute> cstOperands,
1319                               SmallVectorImpl<OpFoldResult> &results) {
1320   /// dealloc(memrefcast) -> dealloc
1321   return foldMemRefCast(*this);
1322 }
1323 
1324 //===----------------------------------------------------------------------===//
1325 // DimOp
1326 //===----------------------------------------------------------------------===//
1327 
1328 void DimOp::build(OpBuilder &builder, OperationState &result,
1329                   Value memrefOrTensor, int64_t index) {
1330   auto loc = result.location;
1331   Value indexValue = builder.create<ConstantIndexOp>(loc, index);
1332   build(builder, result, memrefOrTensor, indexValue);
1333 }
1334 
1335 void DimOp::build(OpBuilder &builder, OperationState &result,
1336                   Value memrefOrTensor, Value index) {
1337   auto indexTy = builder.getIndexType();
1338   build(builder, result, indexTy, memrefOrTensor, index);
1339 }
1340 
1341 Optional<int64_t> DimOp::getConstantIndex() {
1342   if (auto constantOp = index().getDefiningOp<ConstantOp>())
1343     return constantOp.getValue().cast<IntegerAttr>().getInt();
1344   return {};
1345 }
1346 
1347 static LogicalResult verify(DimOp op) {
1348   // Assume unknown index to be in range.
1349   Optional<int64_t> index = op.getConstantIndex();
1350   if (!index.hasValue())
1351     return success();
1352 
1353   // Check that the constant index is not known to be out of range.
1354   auto type = op.memrefOrTensor().getType();
1355   if (auto tensorType = type.dyn_cast<RankedTensorType>()) {
1356     if (index.getValue() >= tensorType.getRank())
1357       return op.emitOpError("index is out of range");
1358   } else if (auto memrefType = type.dyn_cast<MemRefType>()) {
1359     if (index.getValue() >= memrefType.getRank())
1360       return op.emitOpError("index is out of range");
1361   } else if (type.isa<UnrankedTensorType>() || type.isa<UnrankedMemRefType>()) {
1362     // Assume index to be in range.
1363   } else {
1364     llvm_unreachable("expected operand with tensor or memref type");
1365   }
1366 
1367   return success();
1368 }
1369 
1370 OpFoldResult DimOp::fold(ArrayRef<Attribute> operands) {
1371   auto index = operands[1].dyn_cast_or_null<IntegerAttr>();
1372 
1373   // All forms of folding require a known index.
1374   if (!index)
1375     return {};
1376 
1377   auto argTy = memrefOrTensor().getType();
1378   // Fold if the shape extent along the given index is known.
1379   if (auto shapedTy = argTy.dyn_cast<ShapedType>()) {
1380     // Folding for unranked types (UnrankedMemRefType, UnrankedTensorType) is
1381     // not supported.
1382     if (!shapedTy.hasRank())
1383       return {};
1384     if (!shapedTy.isDynamicDim(index.getInt())) {
1385       Builder builder(getContext());
1386       return builder.getIndexAttr(shapedTy.getShape()[index.getInt()]);
1387     }
1388   }
1389 
1390   Operation *definingOp = memrefOrTensor().getDefiningOp();
1391   // dim(tensor_load(memref)) -> dim(memref)
1392   if (auto tensorLoadOp = dyn_cast_or_null<TensorLoadOp>(definingOp)) {
1393     setOperand(0, tensorLoadOp.memref());
1394     return getResult();
1395   }
1396 
1397   // Fold dim to the operand of dynamic_tensor_from_elements.
1398   if (auto fromElements =
1399           dyn_cast_or_null<DynamicTensorFromElementsOp>(definingOp)) {
1400     auto resultType =
1401         fromElements.getResult().getType().cast<RankedTensorType>();
1402     // The case where the type encodes the size of the dimension is handled
1403     // above.
1404     assert(resultType.getShape()[index.getInt()] ==
1405            RankedTensorType::kDynamicSize);
1406 
1407     // Find the operand of the fromElements that corresponds to this index.
1408     auto dynExtents = fromElements.dynamicExtents().begin();
1409     for (auto dim : resultType.getShape().take_front(index.getInt()))
1410       if (dim == RankedTensorType::kDynamicSize)
1411         dynExtents++;
1412 
1413     return Value{*dynExtents};
1414   }
1415 
1416   // Fold dim to the size argument for an `AllocOp`, `ViewOp`, or `SubViewOp`.
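  // For example (illustrative): given
  //   %m = alloc(%sz) : memref<?xf32>
  //   %d = dim %m, %c0 : memref<?xf32>
  // the dim folds to %sz.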
1417   auto memrefType = argTy.dyn_cast<MemRefType>();
1418   if (!memrefType)
1419     return {};
1420 
1421   // The size at the given index is now known to be a dynamic size of a memref.
1422   unsigned unsignedIndex = index.getValue().getZExtValue();
1423   if (auto alloc = dyn_cast_or_null<AllocOp>(definingOp))
1424     return *(alloc.getDynamicSizes().begin() +
1425              memrefType.getDynamicDimIndex(unsignedIndex));
1426 
1427   if (auto view = dyn_cast_or_null<ViewOp>(definingOp))
1428     return *(view.getDynamicSizes().begin() +
1429              memrefType.getDynamicDimIndex(unsignedIndex));
1430 
1431   if (auto subview = dyn_cast_or_null<SubViewOp>(definingOp)) {
1432     assert(subview.isDynamicSize(unsignedIndex) &&
1433            "Expected dynamic subview size");
1434     return subview.getDynamicSize(unsignedIndex);
1435   }
1436 
1437   // dim(memrefcast) -> dim
1438   if (succeeded(foldMemRefCast(*this)))
1439     return getResult();
1440 
1441   return {};
1442 }
1443 
1444 namespace {
1445 /// Fold dim of a memref reshape operation to a load into the reshape's shape
1446 /// operand.
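/// For example (illustrative), `dim(memref_reshape(%src, %shape), %idx)` is
/// rewritten to a `load %shape[%idx]` placed directly after the reshape.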
1447 struct DimOfMemRefReshape : public OpRewritePattern<DimOp> {
1448   using OpRewritePattern<DimOp>::OpRewritePattern;
1449 
1450   LogicalResult matchAndRewrite(DimOp dim,
1451                                 PatternRewriter &rewriter) const override {
1452     auto reshape = dim.memrefOrTensor().getDefiningOp<MemRefReshapeOp>();
1453 
1454     if (!reshape)
1455       return failure();
1456 
1457     // Place the load directly after the reshape to ensure that the shape memref
1458     // was not mutated.
1459     rewriter.setInsertionPointAfter(reshape);
1460     rewriter.replaceOpWithNewOp<LoadOp>(dim, reshape.shape(),
1461                                         llvm::makeArrayRef({dim.index()}));
1462     return success();
1463   }
1464 };
1465 } // end anonymous namespace.
1466 
1467 void DimOp::getCanonicalizationPatterns(OwningRewritePatternList &results,
1468                                         MLIRContext *context) {
1469   results.insert<DimOfMemRefReshape>(context);
1470 }
1471 
1472 // ---------------------------------------------------------------------------
1473 // DmaStartOp
1474 // ---------------------------------------------------------------------------
1475 
1476 void DmaStartOp::build(OpBuilder &builder, OperationState &result,
1477                        Value srcMemRef, ValueRange srcIndices, Value destMemRef,
1478                        ValueRange destIndices, Value numElements,
1479                        Value tagMemRef, ValueRange tagIndices, Value stride,
1480                        Value elementsPerStride) {
1481   result.addOperands(srcMemRef);
1482   result.addOperands(srcIndices);
1483   result.addOperands(destMemRef);
1484   result.addOperands(destIndices);
1485   result.addOperands({numElements, tagMemRef});
1486   result.addOperands(tagIndices);
1487   if (stride)
1488     result.addOperands({stride, elementsPerStride});
1489 }
1490 
1491 void DmaStartOp::print(OpAsmPrinter &p) {
1492   p << "dma_start " << getSrcMemRef() << '[' << getSrcIndices() << "], "
1493     << getDstMemRef() << '[' << getDstIndices() << "], " << getNumElements()
1494     << ", " << getTagMemRef() << '[' << getTagIndices() << ']';
1495   if (isStrided())
1496     p << ", " << getStride() << ", " << getNumElementsPerStride();
1497 
1498   p.printOptionalAttrDict(getAttrs());
1499   p << " : " << getSrcMemRef().getType() << ", " << getDstMemRef().getType()
1500     << ", " << getTagMemRef().getType();
1501 }
1502 
1503 // Parse DmaStartOp.
1504 // Ex:
1505 //   %dma_id = dma_start %src[%i, %j], %dst[%k, %l], %size,
1506 //                       %tag[%index], %stride, %num_elt_per_stride
1507 //                     : memref<3076 x f32, 0>,
1508 //                       memref<1024 x f32, 2>,
1509 //                       memref<1 x i32>
1510 //
1511 ParseResult DmaStartOp::parse(OpAsmParser &parser, OperationState &result) {
1512   OpAsmParser::OperandType srcMemRefInfo;
1513   SmallVector<OpAsmParser::OperandType, 4> srcIndexInfos;
1514   OpAsmParser::OperandType dstMemRefInfo;
1515   SmallVector<OpAsmParser::OperandType, 4> dstIndexInfos;
1516   OpAsmParser::OperandType numElementsInfo;
1517   OpAsmParser::OperandType tagMemrefInfo;
1518   SmallVector<OpAsmParser::OperandType, 4> tagIndexInfos;
1519   SmallVector<OpAsmParser::OperandType, 2> strideInfo;
1520 
1521   SmallVector<Type, 3> types;
1522   auto indexType = parser.getBuilder().getIndexType();
1523 
1524   // Parse and resolve the following list of operands:
1525   // *) source memref followed by its indices (in square brackets).
1526   // *) destination memref followed by its indices (in square brackets).
1527   // *) dma size (number of elements to transfer).
1528   if (parser.parseOperand(srcMemRefInfo) ||
1529       parser.parseOperandList(srcIndexInfos, OpAsmParser::Delimiter::Square) ||
1530       parser.parseComma() || parser.parseOperand(dstMemRefInfo) ||
1531       parser.parseOperandList(dstIndexInfos, OpAsmParser::Delimiter::Square) ||
1532       parser.parseComma() || parser.parseOperand(numElementsInfo) ||
1533       parser.parseComma() || parser.parseOperand(tagMemrefInfo) ||
1534       parser.parseOperandList(tagIndexInfos, OpAsmParser::Delimiter::Square))
1535     return failure();
1536 
1537   // Parse optional stride and elements per stride.
1538   if (parser.parseTrailingOperandList(strideInfo))
1539     return failure();
1540 
1541   bool isStrided = strideInfo.size() == 2;
1542   if (!strideInfo.empty() && !isStrided) {
1543     return parser.emitError(parser.getNameLoc(),
1544                             "expected two stride related operands");
1545   }
1546 
1547   if (parser.parseColonTypeList(types))
1548     return failure();
1549   if (types.size() != 3)
1550     return parser.emitError(parser.getNameLoc(), "fewer/more types expected");
1551 
1552   if (parser.resolveOperand(srcMemRefInfo, types[0], result.operands) ||
1553       parser.resolveOperands(srcIndexInfos, indexType, result.operands) ||
1554       parser.resolveOperand(dstMemRefInfo, types[1], result.operands) ||
1555       parser.resolveOperands(dstIndexInfos, indexType, result.operands) ||
1556       // size should be an index.
1557       parser.resolveOperand(numElementsInfo, indexType, result.operands) ||
1558       parser.resolveOperand(tagMemrefInfo, types[2], result.operands) ||
1559       // tag indices should be index.
1560       parser.resolveOperands(tagIndexInfos, indexType, result.operands))
1561     return failure();
1562 
1563   if (isStrided) {
1564     if (parser.resolveOperands(strideInfo, indexType, result.operands))
1565       return failure();
1566   }
1567 
1568   return success();
1569 }
1570 
1571 LogicalResult DmaStartOp::verify() {
1572   unsigned numOperands = getNumOperands();
1573 
1574   // Mandatory non-variadic operands are: src memref, dst memref, tag memref and
1575   // the number of elements.
1576   if (numOperands < 4)
1577     return emitOpError("expected at least 4 operands");
1578 
1579   // Check types of operands. The order of these calls is important: the later
1580   // calls rely on some type properties to compute the operand position.
1581   // 1. Source memref.
1582   if (!getSrcMemRef().getType().isa<MemRefType>())
1583     return emitOpError("expected source to be of memref type");
1584   if (numOperands < getSrcMemRefRank() + 4)
1585     return emitOpError() << "expected at least " << getSrcMemRefRank() + 4
1586                          << " operands";
1587   if (!getSrcIndices().empty() &&
1588       !llvm::all_of(getSrcIndices().getTypes(),
1589                     [](Type t) { return t.isIndex(); }))
1590     return emitOpError("expected source indices to be of index type");
1591 
1592   // 2. Destination memref.
1593   if (!getDstMemRef().getType().isa<MemRefType>())
1594     return emitOpError("expected destination to be of memref type");
1595   unsigned numExpectedOperands = getSrcMemRefRank() + getDstMemRefRank() + 4;
1596   if (numOperands < numExpectedOperands)
1597     return emitOpError() << "expected at least " << numExpectedOperands
1598                          << " operands";
1599   if (!getDstIndices().empty() &&
1600       !llvm::all_of(getDstIndices().getTypes(),
1601                     [](Type t) { return t.isIndex(); }))
1602     return emitOpError("expected destination indices to be of index type");
1603 
1604   // 3. Number of elements.
1605   if (!getNumElements().getType().isIndex())
1606     return emitOpError("expected num elements to be of index type");
1607 
1608   // 4. Tag memref.
1609   if (!getTagMemRef().getType().isa<MemRefType>())
1610     return emitOpError("expected tag to be of memref type");
1611   numExpectedOperands += getTagMemRefRank();
1612   if (numOperands < numExpectedOperands)
1613     return emitOpError() << "expected at least " << numExpectedOperands
1614                          << " operands";
1615   if (!getTagIndices().empty() &&
1616       !llvm::all_of(getTagIndices().getTypes(),
1617                     [](Type t) { return t.isIndex(); }))
1618     return emitOpError("expected tag indices to be of index type");
1619 
1620   // Only DMAs between different memory spaces are supported.
1621   if (getSrcMemorySpace() == getDstMemorySpace())
1622     return emitOpError("DMA should be between different memory spaces");
1623 
1624   // Optional stride-related operands must be either both present or both
1625   // absent.
1626   if (numOperands != numExpectedOperands &&
1627       numOperands != numExpectedOperands + 2)
1628     return emitOpError("incorrect number of operands");
1629 
1630   // 5. Strides.
1631   if (isStrided()) {
1632     if (!getStride().getType().isIndex() ||
1633         !getNumElementsPerStride().getType().isIndex())
1634       return emitOpError(
1635           "expected stride and num elements per stride to be of type index");
1636   }
1637 
1638   return success();
1639 }
1640 
1641 LogicalResult DmaStartOp::fold(ArrayRef<Attribute> cstOperands,
1642                                SmallVectorImpl<OpFoldResult> &results) {
1643   /// dma_start(memrefcast) -> dma_start
1644   return foldMemRefCast(*this);
1645 }
1646 
1647 // ---------------------------------------------------------------------------
1648 // DmaWaitOp
1649 // ---------------------------------------------------------------------------
1650 
1651 void DmaWaitOp::build(OpBuilder &builder, OperationState &result,
1652                       Value tagMemRef, ValueRange tagIndices,
1653                       Value numElements) {
1654   result.addOperands(tagMemRef);
1655   result.addOperands(tagIndices);
1656   result.addOperands(numElements);
1657 }
1658 
1659 void DmaWaitOp::print(OpAsmPrinter &p) {
1660   p << "dma_wait " << getTagMemRef() << '[' << getTagIndices() << "], "
1661     << getNumElements();
1662   p.printOptionalAttrDict(getAttrs());
1663   p << " : " << getTagMemRef().getType();
1664 }
1665 
1666 // Parse DmaWaitOp.
1667 // Eg:
1668 //   dma_wait %tag[%index], %num_elements : memref<1 x i32, (d0) -> (d0), 4>
1669 //
1670 ParseResult DmaWaitOp::parse(OpAsmParser &parser, OperationState &result) {
1671   OpAsmParser::OperandType tagMemrefInfo;
1672   SmallVector<OpAsmParser::OperandType, 2> tagIndexInfos;
1673   Type type;
1674   auto indexType = parser.getBuilder().getIndexType();
1675   OpAsmParser::OperandType numElementsInfo;
1676 
1677   // Parse tag memref, its indices, and dma size.
1678   if (parser.parseOperand(tagMemrefInfo) ||
1679       parser.parseOperandList(tagIndexInfos, OpAsmParser::Delimiter::Square) ||
1680       parser.parseComma() || parser.parseOperand(numElementsInfo) ||
1681       parser.parseColonType(type) ||
1682       parser.resolveOperand(tagMemrefInfo, type, result.operands) ||
1683       parser.resolveOperands(tagIndexInfos, indexType, result.operands) ||
1684       parser.resolveOperand(numElementsInfo, indexType, result.operands))
1685     return failure();
1686 
1687   return success();
1688 }
1689 
1690 LogicalResult DmaWaitOp::fold(ArrayRef<Attribute> cstOperands,
1691                               SmallVectorImpl<OpFoldResult> &results) {
1692   /// dma_wait(memrefcast) -> dma_wait
1693   return foldMemRefCast(*this);
1694 }
1695 
1696 LogicalResult DmaWaitOp::verify() {
1697   // Mandatory non-variadic operands are tag and the number of elements.
1698   if (getNumOperands() < 2)
1699     return emitOpError() << "expected at least 2 operands";
1700 
1701   // Check types of operands. The order of these calls is important: the later
1702   // calls rely on some type properties to compute the operand position.
1703   if (!getTagMemRef().getType().isa<MemRefType>())
1704     return emitOpError() << "expected tag to be of memref type";
1705 
1706   if (getNumOperands() != 2 + getTagMemRefRank())
1707     return emitOpError() << "expected " << 2 + getTagMemRefRank()
1708                          << " operands";
1709 
1710   if (!getTagIndices().empty() &&
1711       !llvm::all_of(getTagIndices().getTypes(),
1712                     [](Type t) { return t.isIndex(); }))
1713     return emitOpError() << "expected tag indices to be of index type";
1714 
1715   if (!getNumElements().getType().isIndex())
1716     return emitOpError()
1717            << "expected the number of elements to be of index type";
1718 
1719   return success();
1720 }
1721 
1722 //===----------------------------------------------------------------------===//
1723 // DynamicTensorFromElementsOp
1724 //===----------------------------------------------------------------------===//
1725 
1726 static ParseResult parseDynamicTensorFromElementsOp(OpAsmParser &parser,
1727                                                     OperationState &result) {
1728   // Parse operands.
1729   SmallVector<OpAsmParser::OperandType, 4> dynamicExtents;
1730   Type indexTy = parser.getBuilder().getIndexType();
1731   if (parser.parseOperandList(dynamicExtents) ||
1732       parser.resolveOperands(dynamicExtents, indexTy, result.operands))
1733     return failure();
1734 
1735   // Parse body.
1736   Region *body = result.addRegion();
1737   if (parser.parseRegion(*body, {}, {}))
1738     return failure();
1739 
1740   // Parse result type.
1741   Type resultType;
1742   if (parser.parseOptionalAttrDict(result.attributes) ||
1743       parser.parseColonType(resultType))
1744     return failure();
1745   result.addTypes(resultType);
1746 
1747   return success();
1748 }
1749 
1750 static void print(OpAsmPrinter &p, DynamicTensorFromElementsOp op) {
1751   p << "dynamic_tensor_from_elements " << op.dynamicExtents();
1752   p.printRegion(op.body());
1753   p.printOptionalAttrDict(op.getAttrs());
1754   p << " : " << op.getType();
1755 }
1756 
1757 static LogicalResult verify(DynamicTensorFromElementsOp op) {
1758   // Ensure that the tensor type has as many dynamic dimensions as are specified
1759   // by the operands.
1760   RankedTensorType resultTy = op.getType().cast<RankedTensorType>();
1761   if (op.getNumOperands() != resultTy.getNumDynamicDims())
1762     return op.emitError("must have as many index operands as dynamic extents "
1763                         "in the result type");
1764 
1765   // Ensure that region arguments span the index space.
1766   if (!llvm::all_of(op.body().getArgumentTypes(),
1767                     [](Type ty) { return ty.isIndex(); }))
1768     return op.emitError("all body arguments must be index");
1769   if (op.body().getNumArguments() != resultTy.getRank())
1770     return op.emitError("must have one body argument per input dimension");
1771 
1772   // Ensure that the region yields an element of the right type.
1773   auto yieldOp =
1774       llvm::cast<YieldOp>(op.body().getBlocks().front().getTerminator());
1775   if (yieldOp.value().getType() != resultTy.getElementType())
1776     return op.emitOpError(
1777         "body must be terminated with a `yield` operation of the tensor "
1778         "element type");
1779 
1780   return success();
1781 }
1782 
1783 void DynamicTensorFromElementsOp::build(
1784     OpBuilder &b, OperationState &result, Type resultTy,
1785     ValueRange dynamicExtents,
1786     function_ref<void(OpBuilder &, Location, ValueRange)> bodyBuilder) {
1787   build(b, result, resultTy, dynamicExtents);
1788 
1789   // Build and populate body.
1790   OpBuilder::InsertionGuard guard(b);
1791   Region *bodyRegion = result.regions.front().get();
1792   auto rank = resultTy.cast<RankedTensorType>().getRank();
1793   SmallVector<Type, 2> argumentTypes(rank, b.getIndexType());
1794   Block *bodyBlock =
1795       b.createBlock(bodyRegion, bodyRegion->end(), argumentTypes);
1796   bodyBuilder(b, result.location, bodyBlock->getArguments());
1797 }
1798 
1799 namespace {
1800 
1801 /// Canonicalizes dynamic_tensor_from_elements operations whose dynamic
1802 /// extent operands are constants into an equivalent operation that expresses
1803 /// those extents in the result type instead. A tensor_cast is inserted so
1804 /// that the resulting IR remains well-typed for existing users.
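/// For example (illustrative), if one extent operand is a constant 16, the op
/// is rebuilt with a result type such as tensor<?x16xf32>, and a tensor_cast
/// back to tensor<?x?xf32> preserves the type seen by existing users.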
1805 struct StaticDynamicTensorFromElements
1806     : public OpRewritePattern<DynamicTensorFromElementsOp> {
1807   using OpRewritePattern<DynamicTensorFromElementsOp>::OpRewritePattern;
1808 
1809   LogicalResult matchAndRewrite(DynamicTensorFromElementsOp tensorFromElements,
1810                                 PatternRewriter &rewriter) const final {
1811     auto resultType =
1812         tensorFromElements.getResult().getType().cast<RankedTensorType>();
1813 
1814     if (resultType.hasStaticShape())
1815       return failure();
1816 
1817     SmallVector<Value, 4> newOperands;
1818     SmallVector<int64_t, 4> newShape;
1819     auto operandsIt = tensorFromElements.dynamicExtents().begin();
1820 
1821     for (int64_t dim : resultType.getShape()) {
1822       if (dim != RankedTensorType::kDynamicSize) {
1823         newShape.push_back(dim);
1824         continue;
1825       }
1826       APInt index;
1827       if (!matchPattern(*operandsIt, m_ConstantInt(&index))) {
1828         newShape.push_back(RankedTensorType::kDynamicSize);
1829         newOperands.push_back(*operandsIt++);
1830         continue;
1831       }
1832       newShape.push_back(index.getSExtValue());
1833       operandsIt++;
1834     }
1835 
1836     if (newOperands.size() == tensorFromElements.dynamicExtents().size())
1837       return failure();
1838 
1839     auto loc = tensorFromElements.getLoc();
1840     auto newOp = rewriter.create<DynamicTensorFromElementsOp>(
1841         loc, RankedTensorType::get(newShape, resultType.getElementType()),
1842         newOperands);
1843     rewriter.inlineRegionBefore(tensorFromElements.body(), newOp.body(),
1844                                 newOp.body().begin());
1845     rewriter.replaceOpWithNewOp<TensorCastOp>(tensorFromElements, resultType,
1846                                               newOp);
1847     return success();
1848   }
1849 };
1850 
1851 /// Canonicalizes the pattern of the form
1852 ///
1853 /// %tensor = dynamic_tensor_from_elements %x {
1854 ///   ^bb0(%arg0: index):  // no predecessors
1855 ///   <computation>
1856 ///   yield %1 : index
1857 /// } : tensor<?xindex>
1858 /// %extracted_element = extract_element %tensor[%c0] : tensor<?xindex>
1859 ///
1860 /// to just <computation> with %arg0 replaced by %c0. We only do this if the
1861 /// dynamic_tensor_from_elements operation has no side-effects.
1862 struct ExtractElementFromDynamicTensorFromElements
1863     : public OpRewritePattern<ExtractElementOp> {
1864   using OpRewritePattern<ExtractElementOp>::OpRewritePattern;
1865 
1866   LogicalResult matchAndRewrite(ExtractElementOp extract,
1867                                 PatternRewriter &rewriter) const final {
1868     auto tensorFromElements =
1869         extract.aggregate().getDefiningOp<DynamicTensorFromElementsOp>();
1870     if (!tensorFromElements || !wouldOpBeTriviallyDead(tensorFromElements))
1871       return failure();
1872 
1873     BlockAndValueMapping mapping;
1874     Block *body = tensorFromElements.getBody();
1875     mapping.map(body->getArguments(), extract.indices());
1876     for (auto &op : body->without_terminator())
1877       rewriter.clone(op, mapping);
1878 
1879     auto yield = cast<YieldOp>(body->getTerminator());
1880 
1881     rewriter.replaceOp(extract, mapping.lookupOrDefault(yield.value()));
1882     return success();
1883   }
1884 };
1885 
1886 /// Canonicalizes the pattern of the form
1887 ///
1888 /// %val = tensor_cast %source : tensor<?xi32> to tensor<2xi32>
1889 /// %extracted_element = extract_element %val[%c0] : tensor<2xi32>
1890 ///
1891 /// to
1892 ///
1893 /// %extracted_element = extract_element %source[%c0] : tensor<?xi32>
1894 struct ExtractElementFromTensorCast
1895     : public OpRewritePattern<ExtractElementOp> {
1896   using OpRewritePattern<ExtractElementOp>::OpRewritePattern;
1897 
1898   LogicalResult matchAndRewrite(ExtractElementOp extract,
1899                                 PatternRewriter &rewriter) const final {
1900     auto tensorCast = extract.aggregate().getDefiningOp<TensorCastOp>();
1901     if (!tensorCast)
1902       return failure();
1903 
1904     rewriter.replaceOpWithNewOp<ExtractElementOp>(extract, tensorCast.source(),
1905                                                   extract.getIndices());
1906     return success();
1907   }
1908 };
1909 
1910 } // namespace
1911 
1912 void DynamicTensorFromElementsOp::getCanonicalizationPatterns(
1913     OwningRewritePatternList &results, MLIRContext *context) {
1914   results.insert<ExtractElementFromDynamicTensorFromElements,
1915                  ExtractElementFromTensorCast, StaticDynamicTensorFromElements>(
1916       context);
1917 }
1918 
1919 //===----------------------------------------------------------------------===//
1920 // ExtractElementOp
1921 //===----------------------------------------------------------------------===//
1922 
1923 static LogicalResult verify(ExtractElementOp op) {
1924   // Verify the # indices match if we have a ranked type.
1925   auto aggregateType = op.getAggregate().getType().cast<ShapedType>();
1926   if (aggregateType.hasRank() &&
1927       aggregateType.getRank() != op.getNumOperands() - 1)
1928     return op.emitOpError("incorrect number of indices for extract_element");
1929 
1930   return success();
1931 }
1932 
1933 OpFoldResult ExtractElementOp::fold(ArrayRef<Attribute> operands) {
1934   assert(!operands.empty() && "extract_element takes at least one operand");
1935 
1936   // The aggregate operand must be a known constant.
1937   Attribute aggregate = operands.front();
1938   if (!aggregate)
1939     return {};
1940 
1941   // If this is a splat elements attribute, simply return the value. All of the
1942   // elements of a splat attribute are the same.
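  // For example (illustrative), extracting from a dense<1.0> splat constant
  // folds to 1.0 regardless of the index values.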
1943   if (auto splatAggregate = aggregate.dyn_cast<SplatElementsAttr>())
1944     return splatAggregate.getSplatValue();
1945 
1946   // Otherwise, collect the constant indices into the aggregate.
1947   SmallVector<uint64_t, 8> indices;
1948   for (Attribute indice : llvm::drop_begin(operands, 1)) {
1949     if (!indice || !indice.isa<IntegerAttr>())
1950       return {};
1951     indices.push_back(indice.cast<IntegerAttr>().getInt());
1952   }
1953 
1954   // If this is an elements attribute, query the value at the given indices.
1955   auto elementsAttr = aggregate.dyn_cast<ElementsAttr>();
1956   if (elementsAttr && elementsAttr.isValidIndex(indices))
1957     return elementsAttr.getValue(indices);
1958   return {};
1959 }
1960 
1961 //===----------------------------------------------------------------------===//
1962 // TensorFromElementsOp
1963 //===----------------------------------------------------------------------===//
1964 
1965 void TensorFromElementsOp::build(OpBuilder &builder, OperationState &result,
1966                                  Type elementType, ValueRange elements) {
1967   Type resultTy = RankedTensorType::get({static_cast<int64_t>(elements.size())},
1968                                         elementType);
1969   result.addOperands(elements);
1970   result.addTypes(resultTy);
1971 }
1972 
1973 void TensorFromElementsOp::build(OpBuilder &builder, OperationState &result,
1974                                  ValueRange elements) {
1975   assert(!elements.empty() && "expected at least one element");
1976   build(builder, result, elements.front().getType(), elements);
1977 }
1978 
1979 namespace {
1980 
1981 // Canonicalizes the pattern of the form
1982 //
1983 // %tensor = tensor_from_elements(%element) : (i32) -> tensor<1xi32>
1984 // %extracted_element = extract_element %tensor[%c0] : tensor<1xi32>
1985 //
1986 // to just %element.
1987 struct ExtractElementFromTensorFromElements
1988     : public OpRewritePattern<ExtractElementOp> {
1989   using OpRewritePattern<ExtractElementOp>::OpRewritePattern;
1990 
1991   LogicalResult matchAndRewrite(ExtractElementOp extract,
1992                                 PatternRewriter &rewriter) const final {
1993     if (extract.indices().size() != 1)
1994       return failure();
1995 
1996     auto tensorFromElements = dyn_cast_or_null<TensorFromElementsOp>(
1997         extract.aggregate().getDefiningOp());
1998     if (tensorFromElements == nullptr)
1999       return failure();
2000 
2001     APInt index;
2002     if (!matchPattern(*extract.indices().begin(), m_ConstantInt(&index)))
2003       return failure();
2004     rewriter.replaceOp(extract,
2005                        tensorFromElements.getOperand(index.getZExtValue()));
2006     return success();
2007   }
2008 };
2009 
2010 } // namespace
2011 
2012 void TensorFromElementsOp::getCanonicalizationPatterns(
2013     OwningRewritePatternList &results, MLIRContext *context) {
2014   results.insert<ExtractElementFromTensorFromElements>(context);
2015 }
2016 
2017 //===----------------------------------------------------------------------===//
2018 // FPExtOp
2019 //===----------------------------------------------------------------------===//
2020 
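// FPExt is a widening float-to-float cast; for example (illustrative),
// f16 -> f32 is compatible while f32 -> f16 is not.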
2021 bool FPExtOp::areCastCompatible(Type a, Type b) {
2022   if (auto fa = a.dyn_cast<FloatType>())
2023     if (auto fb = b.dyn_cast<FloatType>())
2024       return fa.getWidth() < fb.getWidth();
2025   return areVectorCastSimpleCompatible(a, b, areCastCompatible);
2026 }
2027 
2028 //===----------------------------------------------------------------------===//
2029 // FPToSIOp
2030 //===----------------------------------------------------------------------===//
2031 
2032 bool FPToSIOp::areCastCompatible(Type a, Type b) {
2033   if (a.isa<FloatType>() && b.isSignlessInteger())
2034     return true;
2035   return areVectorCastSimpleCompatible(a, b, areCastCompatible);
2036 }
2037 
2038 //===----------------------------------------------------------------------===//
2039 // FPToUIOp
2040 //===----------------------------------------------------------------------===//
2041 
2042 bool FPToUIOp::areCastCompatible(Type a, Type b) {
2043   if (a.isa<FloatType>() && b.isSignlessInteger())
2044     return true;
2045   return areVectorCastSimpleCompatible(a, b, areCastCompatible);
2046 }
2047 
2048 //===----------------------------------------------------------------------===//
2049 // FPTruncOp
2050 //===----------------------------------------------------------------------===//
2051 
2052 bool FPTruncOp::areCastCompatible(Type a, Type b) {
2053   if (auto fa = a.dyn_cast<FloatType>())
2054     if (auto fb = b.dyn_cast<FloatType>())
2055       return fa.getWidth() > fb.getWidth();
2056   return areVectorCastSimpleCompatible(a, b, areCastCompatible);
2057 }
2058 
2059 //===----------------------------------------------------------------------===//
2060 // GlobalMemrefOp
2061 //===----------------------------------------------------------------------===//
2062 
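// The custom type-and-initializer portion handled below prints as, e.g.
// (illustrative), `memref<2xf32> = dense<[0.0, 1.0]>`, `memref<2xf32> =
// uninitialized`, or just `memref<2xf32>` for an external global.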
2063 static void printGlobalMemrefOpTypeAndInitialValue(OpAsmPrinter &p,
2064                                                    GlobalMemrefOp op,
2065                                                    TypeAttr type,
2066                                                    Attribute initialValue) {
2067   p << type;
2068   if (!op.isExternal()) {
2069     p << " = ";
2070     if (op.isUninitialized())
2071       p << "uninitialized";
2072     else
2073       p.printAttributeWithoutType(initialValue);
2074   }
2075 }
2076 
2077 static ParseResult
2078 parseGlobalMemrefOpTypeAndInitialValue(OpAsmParser &parser, TypeAttr &typeAttr,
2079                                        Attribute &initialValue) {
2080   Type type;
2081   if (parser.parseType(type))
2082     return failure();
2083 
2084   auto memrefType = type.dyn_cast<MemRefType>();
2085   if (!memrefType || !memrefType.hasStaticShape())
2086     return parser.emitError(parser.getNameLoc())
2087            << "type should be static shaped memref, but got " << type;
2088   typeAttr = TypeAttr::get(type);
2089 
2090   if (parser.parseOptionalEqual())
2091     return success();
2092 
2093   if (succeeded(parser.parseOptionalKeyword("uninitialized"))) {
2094     initialValue = UnitAttr::get(parser.getBuilder().getContext());
2095     return success();
2096   }
2097 
2098   Type tensorType = getTensorTypeFromMemRefType(memrefType);
2099   if (parser.parseAttribute(initialValue, tensorType))
2100     return failure();
2101   if (!initialValue.isa<ElementsAttr>())
2102     return parser.emitError(parser.getNameLoc())
2103            << "initial value should be a unit or elements attribute";
2104   return success();
2105 }
2106 
2107 static LogicalResult verify(GlobalMemrefOp op) {
2108   auto memrefType = op.type().dyn_cast<MemRefType>();
2109   if (!memrefType || !memrefType.hasStaticShape())
2110     return op.emitOpError("type should be static shaped memref, but got ")
2111            << op.type();
2112 
2113   // Verify that the initial value, if present, is either a unit attribute or
2114   // an elements attribute.
2115   if (op.initial_value().hasValue()) {
2116     Attribute initValue = op.initial_value().getValue();
2117     if (!initValue.isa<UnitAttr>() && !initValue.isa<ElementsAttr>())
2118       return op.emitOpError("initial value should be a unit or elements "
2119                             "attribute, but got ")
2120              << initValue;
2121 
2122     // Check that the type of the initial value is compatible with the type of
2123     // the global variable.
2124     if (initValue.isa<ElementsAttr>()) {
2125       Type initType = initValue.getType();
2126       Type tensorType = getTensorTypeFromMemRefType(memrefType);
2127       if (initType != tensorType)
2128         return op.emitOpError("initial value expected to be of type ")
2129                << tensorType << ", but was of type " << initType;
2130     }
2131   }
2132 
2133   // TODO: verify visibility for declarations.
2134   return success();
2135 }
2136 
2137 //===----------------------------------------------------------------------===//
2138 // GetGlobalMemrefOp
2139 //===----------------------------------------------------------------------===//
2140 
2141 LogicalResult
2142 GetGlobalMemrefOp::verifySymbolUses(SymbolTableCollection &symbolTable) {
2143   // Verify that the result type is the same as the type of the referenced
2144   // global_memref op.
2145   auto global =
2146       symbolTable.lookupNearestSymbolFrom<GlobalMemrefOp>(*this, nameAttr());
2147   if (!global)
2148     return emitOpError("'")
2149            << name() << "' does not reference a valid global memref";
2150 
2151   Type resultType = result().getType();
2152   if (global.type() != resultType)
2153     return emitOpError("result type ")
2154            << resultType << " does not match type " << global.type()
2155            << " of the global memref @" << name();
2156   return success();
2157 }
2158 
2159 //===----------------------------------------------------------------------===//
2160 // IndexCastOp
2161 //===----------------------------------------------------------------------===//
2162 
2163 // Index cast is applicable from index to integer and backwards.
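// For example (illustrative), `index_cast %i : index to i64` and
// `index_cast %x : i32 to index` are both compatible, as is the elementwise
// form on shaped types with matching shapes.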
2164 bool IndexCastOp::areCastCompatible(Type a, Type b) {
2165   if (a.isa<ShapedType>() && b.isa<ShapedType>()) {
2166     auto aShaped = a.cast<ShapedType>();
2167     auto bShaped = b.cast<ShapedType>();
2168 
2169     return (aShaped.getShape() == bShaped.getShape()) &&
2170            areCastCompatible(aShaped.getElementType(),
2171                              bShaped.getElementType());
2172   }
2173 
2174   return (a.isIndex() && b.isSignlessInteger()) ||
2175          (a.isSignlessInteger() && b.isIndex());
2176 }
2177 
2178 OpFoldResult IndexCastOp::fold(ArrayRef<Attribute> cstOperands) {
2179   // Fold IndexCast(IndexCast(x)) -> x
2180   auto cast = getOperand().getDefiningOp<IndexCastOp>();
2181   if (cast && cast.getOperand().getType() == getType())
2182     return cast.getOperand();
2183 
2184   // Fold IndexCast(constant) -> constant
2185   // This is a small hack: we rebuild the attribute from its integer value,
2186   // since otherwise the bit width of the constant might need to change.
2187   if (auto value = cstOperands[0].dyn_cast_or_null<IntegerAttr>())
2188     return IntegerAttr::get(getType(), value.getInt());
2189 
2190   return {};
2191 }
2192 
2193 //===----------------------------------------------------------------------===//
2194 // LoadOp
2195 //===----------------------------------------------------------------------===//
2196 
2197 static LogicalResult verify(LoadOp op) {
2198   if (op.getNumOperands() != 1 + op.getMemRefType().getRank())
2199     return op.emitOpError("incorrect number of indices for load");
2200   return success();
2201 }
2202 
2203 OpFoldResult LoadOp::fold(ArrayRef<Attribute> cstOperands) {
2204   /// load(memrefcast) -> load
2205   if (succeeded(foldMemRefCast(*this)))
2206     return getResult();
2207   return OpFoldResult();
2208 }
2209 
2210 namespace {
2211 /// Fold a load on a tensor_to_memref operation into an extract_element on the
2212 /// corresponding tensor.
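/// For example (illustrative), `load(tensor_to_memref(%t))[%i]` becomes
/// `extract_element %t[%i]`.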
2213 struct LoadOfTensorToMemref : public OpRewritePattern<LoadOp> {
2214   using OpRewritePattern<LoadOp>::OpRewritePattern;
2215 
2216   LogicalResult matchAndRewrite(LoadOp load,
2217                                 PatternRewriter &rewriter) const override {
2218     auto tensorToMemref = load.memref().getDefiningOp<TensorToMemrefOp>();
2219     if (!tensorToMemref)
2220       return failure();
2221 
2222     rewriter.replaceOpWithNewOp<ExtractElementOp>(load, tensorToMemref.tensor(),
2223                                                   load.indices());
2224     return success();
2225   }
2226 };
2227 } // end anonymous namespace.
2228 
2229 void LoadOp::getCanonicalizationPatterns(OwningRewritePatternList &results,
2230                                          MLIRContext *context) {
2231   results.insert<LoadOfTensorToMemref>(context);
2232 }
2233 
2234 //===----------------------------------------------------------------------===//
2235 // MemRefCastOp
2236 //===----------------------------------------------------------------------===//
2237 
2238 Value MemRefCastOp::getViewSource() { return source(); }
2239 
2240 bool MemRefCastOp::areCastCompatible(Type a, Type b) {
2241   auto aT = a.dyn_cast<MemRefType>();
2242   auto bT = b.dyn_cast<MemRefType>();
2243 
2244   auto uaT = a.dyn_cast<UnrankedMemRefType>();
2245   auto ubT = b.dyn_cast<UnrankedMemRefType>();
2246 
2247   if (aT && bT) {
2248     if (aT.getElementType() != bT.getElementType())
2249       return false;
2250     if (aT.getAffineMaps() != bT.getAffineMaps()) {
2251       int64_t aOffset, bOffset;
2252       SmallVector<int64_t, 4> aStrides, bStrides;
2253       if (failed(getStridesAndOffset(aT, aStrides, aOffset)) ||
2254           failed(getStridesAndOffset(bT, bStrides, bOffset)) ||
2255           aStrides.size() != bStrides.size())
2256         return false;
2257 
2258       // Strides along a dimension/offset are compatible if the value in the
2259       // source memref is static and the value in the target memref is the
2260       // same. They are also compatible if either one is dynamic (see
2261       // description of MemRefCastOp for details).
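      // For example (illustrative), a static stride of 4 is compatible with a
      // dynamic stride or with another static stride of 4, but not with a
      // static stride of 8.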
2262       auto checkCompatible = [](int64_t a, int64_t b) {
2263         return (a == MemRefType::getDynamicStrideOrOffset() ||
2264                 b == MemRefType::getDynamicStrideOrOffset() || a == b);
2265       };
2266       if (!checkCompatible(aOffset, bOffset))
2267         return false;
2268       for (auto aStride : enumerate(aStrides))
2269         if (!checkCompatible(aStride.value(), bStrides[aStride.index()]))
2270           return false;
2271     }
2272     if (aT.getMemorySpace() != bT.getMemorySpace())
2273       return false;
2274 
2275     // They must have the same rank, and any specified dimensions must match.
2276     if (aT.getRank() != bT.getRank())
2277       return false;
2278 
2279     for (unsigned i = 0, e = aT.getRank(); i != e; ++i) {
2280       int64_t aDim = aT.getDimSize(i), bDim = bT.getDimSize(i);
2281       if (aDim != -1 && bDim != -1 && aDim != bDim)
2282         return false;
2283     }
2284     return true;
2285   } else {
2286     if (!aT && !uaT)
2287       return false;
2288     if (!bT && !ubT)
2289       return false;
2290     // Unranked to unranked casting is unsupported
2291     if (uaT && ubT)
2292       return false;
2293 
2294     auto aEltType = (aT) ? aT.getElementType() : uaT.getElementType();
2295     auto bEltType = (bT) ? bT.getElementType() : ubT.getElementType();
2296     if (aEltType != bEltType)
2297       return false;
2298 
2299     auto aMemSpace = (aT) ? aT.getMemorySpace() : uaT.getMemorySpace();
2300     auto bMemSpace = (bT) ? bT.getMemorySpace() : ubT.getMemorySpace();
2301     if (aMemSpace != bMemSpace)
2302       return false;
2303 
2304     return true;
2305   }
2306 
2307   return false;
2308 }
2309 
2310 OpFoldResult MemRefCastOp::fold(ArrayRef<Attribute> operands) {
2311   if (Value folded = impl::foldCastOp(*this))
2312     return folded;
2313   return succeeded(foldMemRefCast(*this)) ? getResult() : Value();
2314 }
2315 
2316 //===----------------------------------------------------------------------===//
2317 // MemRefReinterpretCastOp
2318 //===----------------------------------------------------------------------===//
2319 
2320 void mlir::MemRefReinterpretCastOp::build(
2321     OpBuilder &b, OperationState &result, MemRefType resultType, Value source,
2322     int64_t staticOffset, ArrayRef<int64_t> staticSizes,
2323     ArrayRef<int64_t> staticStrides, ValueRange offset, ValueRange sizes,
2324     ValueRange strides, ArrayRef<NamedAttribute> attrs) {
2325   build(b, result, resultType, source, offset, sizes, strides,
2326         b.getI64ArrayAttr(staticOffset), b.getI64ArrayAttr(staticSizes),
2327         b.getI64ArrayAttr(staticStrides));
2328   result.addAttributes(attrs);
2329 }
2330 
2331 /// Build a MemRefReinterpretCastOp with all dynamic entries: `staticOffset`,
2332 /// `staticSizes` and `staticStrides` are automatically filled with
2333 /// result-memref-rank sentinel values that encode dynamic entries.
2334 void mlir::MemRefReinterpretCastOp::build(OpBuilder &b, OperationState &result,
2335                                           MemRefType resultType, Value source,
2336                                           Value offset, ValueRange sizes,
2337                                           ValueRange strides,
2338                                           ArrayRef<NamedAttribute> attrs) {
2339   unsigned rank = resultType.getRank();
2340   SmallVector<int64_t, 4> staticSizesVector(rank, ShapedType::kDynamicSize);
2341   SmallVector<int64_t, 4> staticStridesVector(
2342       rank, ShapedType::kDynamicStrideOrOffset);
2343   build(b, result, resultType, source,
2344         /*staticOffset=*/ShapedType::kDynamicStrideOrOffset, staticSizesVector,
2345         staticStridesVector, offset, sizes, strides, attrs);
2346 }
2347 
2348 /// Print a memref_reinterpret_cast op of the form:
2349 /// ```
2350 ///   `memref_reinterpret_cast` ssa-name to
2351 ///       offset: `[` offset `]`
2352 ///       sizes: `[` size-list `]`
2353 ///       strides:`[` stride-list `]`
2354 ///   `:` any-memref-type to strided-memref-type
2355 /// ```
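///
/// For example (illustrative):
///   %dst = memref_reinterpret_cast %src to offset: [0], sizes: [%size],
///     strides: [1] : memref<*xf32> to memref<?xf32>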
2356 static void print(OpAsmPrinter &p, MemRefReinterpretCastOp op) {
2357   int stdDotLen = StandardOpsDialect::getDialectNamespace().size() + 1;
2358   p << op->getName().getStringRef().drop_front(stdDotLen) << ' ';
2359   p << op.source() << " ";
2360   printOffsetsSizesAndStrides(
2361       p, op, /*offsetPrefix=*/"to offset: ", /*sizePrefix=*/", sizes: ",
2362       /*stridePrefix=*/", strides: ");
2363   p << ": " << op.source().getType() << " to " << op.getType();
2364 }
2365 
2366 /// Parse a memref_reinterpret_cast op of the form:
2367 /// ```
2368 ///   `memref_reinterpret_cast` ssa-name to
2369 ///       offset: `[` offset `]`
2370 ///       sizes: `[` size-list `]`
2371 ///       strides:`[` stride-list `]`
2372 ///   `:` any-memref-type to strided-memref-type
2373 /// ```
2374 static ParseResult parseMemRefReinterpretCastOp(OpAsmParser &parser,
2375                                                 OperationState &result) {
2376   // Parse `operand`
2377   OpAsmParser::OperandType srcInfo;
2378   if (parser.parseOperand(srcInfo))
2379     return failure();
2380 
2381   auto parseOffsetPrefix = [](OpAsmParser &parser) {
2382     return failure(parser.parseKeyword("to") || parser.parseKeyword("offset") ||
2383                    parser.parseColon());
2384   };
2385   auto parseSizePrefix = [](OpAsmParser &parser) {
2386     return failure(parser.parseComma() || parser.parseKeyword("sizes") ||
2387                    parser.parseColon());
2388   };
2389   auto parseStridePrefix = [](OpAsmParser &parser) {
2390     return failure(parser.parseComma() || parser.parseKeyword("strides") ||
2391                    parser.parseColon());
2392   };
2393 
2394   Type srcType, dstType;
2395   auto preResolutionFn = [&](OpAsmParser &parser, OperationState &result) {
2396     return failure(parser.parseOptionalAttrDict(result.attributes) ||
2397                    parser.parseColonType(srcType) ||
2398                    parser.parseKeywordType("to", dstType) ||
2399                    parser.resolveOperand(srcInfo, srcType, result.operands));
2400   };
2401   if (failed(parseOffsetsSizesAndStrides(parser, result,
2402                                          /*segmentSizes=*/{1}, // source memref
2403                                          preResolutionFn, parseOffsetPrefix,
2404                                          parseSizePrefix, parseStridePrefix)))
2405     return failure();
2406   return parser.addTypeToList(dstType, result.types);
2407 }
2408 
2409 static LogicalResult verify(MemRefReinterpretCastOp op) {
2410   // The source and result memrefs should be in the same memory space.
2411   auto srcType = op.source().getType().cast<BaseMemRefType>();
2412   auto resultType = op.getType().cast<MemRefType>();
2413   if (srcType.getMemorySpace() != resultType.getMemorySpace())
2414     return op.emitError("different memory spaces specified for source type ")
2415            << srcType << " and result memref type " << resultType;
2416   if (srcType.getElementType() != resultType.getElementType())
2417     return op.emitError("different element types specified for source type ")
2418            << srcType << " and result memref type " << resultType;
2419 
2420   // Match sizes in result memref type and in static_sizes attribute.
2421   for (auto &en :
2422        llvm::enumerate(llvm::zip(resultType.getShape(),
2423                                  extractFromI64ArrayAttr(op.static_sizes())))) {
2424     int64_t resultSize = std::get<0>(en.value());
2425     int64_t expectedSize = std::get<1>(en.value());
2426     if (resultSize != expectedSize)
2427       return op.emitError("expected result type with size = ")
2428              << expectedSize << " instead of " << resultSize
2429              << " in dim = " << en.index();
2430   }
2431 
2432   // Match offset and strides in static_offset and static_strides attributes if
2433   // result memref type has an affine map specified.
2434   if (!resultType.getAffineMaps().empty()) {
2435     int64_t resultOffset;
2436     SmallVector<int64_t, 4> resultStrides;
2437     if (failed(getStridesAndOffset(resultType, resultStrides, resultOffset)))
2438       return failure();
2439 
2440     // Match offset in result memref type and in static_offsets attribute.
2441     int64_t expectedOffset =
2442         extractFromI64ArrayAttr(op.static_offsets()).front();
2443     if (resultOffset != expectedOffset)
2444       return op.emitError("expected result type with offset = ")
2445              << resultOffset << " instead of " << expectedOffset;
2446 
2447     // Match strides in result memref type and in static_strides attribute.
2448     for (auto &en : llvm::enumerate(llvm::zip(
2449              resultStrides, extractFromI64ArrayAttr(op.static_strides())))) {
2450       int64_t resultStride = std::get<0>(en.value());
2451       int64_t expectedStride = std::get<1>(en.value());
2452       if (resultStride != expectedStride)
2453         return op.emitError("expected result type with stride = ")
2454                << expectedStride << " instead of " << resultStride
2455                << " in dim = " << en.index();
2456     }
2457   }
2458   return success();
2459 }
2460 
2461 //===----------------------------------------------------------------------===//
2462 // MemRefReshapeOp
2463 //===----------------------------------------------------------------------===//
2464 
2465 static LogicalResult verify(MemRefReshapeOp op) {
2466   Type operandType = op.source().getType();
2467   Type resultType = op.result().getType();
2468 
2469   Type operandElementType = operandType.cast<ShapedType>().getElementType();
2470   Type resultElementType = resultType.cast<ShapedType>().getElementType();
2471   if (operandElementType != resultElementType)
2472     return op.emitOpError("element types of source and destination memref "
2473                           "types should be the same");
2474 
2475   if (auto operandMemRefType = operandType.dyn_cast<MemRefType>())
2476     if (!operandMemRefType.getAffineMaps().empty())
2477       return op.emitOpError(
2478           "source memref type should have identity affine map");
2479 
2480   int64_t shapeSize = op.shape().getType().cast<MemRefType>().getDimSize(0);
2481   auto resultMemRefType = resultType.dyn_cast<MemRefType>();
2482   if (resultMemRefType) {
2483     if (!resultMemRefType.getAffineMaps().empty())
2484       return op.emitOpError(
2485           "result memref type should have identity affine map");
2486     if (shapeSize == ShapedType::kDynamicSize)
2487       return op.emitOpError("cannot use shape operand with dynamic length to "
2488                             "reshape to statically-ranked memref type");
2489     if (shapeSize != resultMemRefType.getRank())
2490       return op.emitOpError(
2491           "length of shape operand differs from the result's memref rank");
2492   }
2493   return success();
2494 }
2495 
2496 //===----------------------------------------------------------------------===//
2497 // MulFOp
2498 //===----------------------------------------------------------------------===//
2499 
2500 OpFoldResult MulFOp::fold(ArrayRef<Attribute> operands) {
2501   return constFoldBinaryOp<FloatAttr>(
2502       operands, [](APFloat a, APFloat b) { return a * b; });
2503 }
2504 
2505 //===----------------------------------------------------------------------===//
2506 // MulIOp
2507 //===----------------------------------------------------------------------===//
2508 
2509 OpFoldResult MulIOp::fold(ArrayRef<Attribute> operands) {
2510   /// muli(x, 0) -> 0
2511   if (matchPattern(rhs(), m_Zero()))
2512     return rhs();
2513   /// muli(x, 1) -> x
2514   if (matchPattern(rhs(), m_One()))
2515     return getOperand(0);
2516 
2517   // TODO: Handle the overflow case.
2518   return constFoldBinaryOp<IntegerAttr>(operands,
2519                                         [](APInt a, APInt b) { return a * b; });
2520 }
2521 
2522 //===----------------------------------------------------------------------===//
2523 // OrOp
2524 //===----------------------------------------------------------------------===//
2525 
2526 OpFoldResult OrOp::fold(ArrayRef<Attribute> operands) {
2527   /// or(x, 0) -> x
2528   if (matchPattern(rhs(), m_Zero()))
2529     return lhs();
2530   /// or(x,x) -> x
2531   if (lhs() == rhs())
2532     return rhs();
2533 
2534   return constFoldBinaryOp<IntegerAttr>(operands,
2535                                         [](APInt a, APInt b) { return a | b; });
2536 }
2537 
2538 //===----------------------------------------------------------------------===//
2539 // PrefetchOp
2540 //===----------------------------------------------------------------------===//
2541 
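// Illustrative custom form handled by the printer/parser below (hand-written
// example, not taken from the original file):
//   std.prefetch %0[%i, %j], read, locality<3>, data : memref<400x400xi32>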
2542 static void print(OpAsmPrinter &p, PrefetchOp op) {
2543   p << PrefetchOp::getOperationName() << " " << op.memref() << '[';
2544   p.printOperands(op.indices());
2545   p << ']' << ", " << (op.isWrite() ? "write" : "read");
2546   p << ", locality<" << op.localityHint();
2547   p << ">, " << (op.isDataCache() ? "data" : "instr");
2548   p.printOptionalAttrDict(
2549       op.getAttrs(),
2550       /*elidedAttrs=*/{"localityHint", "isWrite", "isDataCache"});
2551   p << " : " << op.getMemRefType();
2552 }
2553 
2554 static ParseResult parsePrefetchOp(OpAsmParser &parser,
2555                                    OperationState &result) {
2556   OpAsmParser::OperandType memrefInfo;
2557   SmallVector<OpAsmParser::OperandType, 4> indexInfo;
2558   IntegerAttr localityHint;
2559   MemRefType type;
2560   StringRef readOrWrite, cacheType;
2561 
2562   auto indexTy = parser.getBuilder().getIndexType();
2563   auto i32Type = parser.getBuilder().getIntegerType(32);
2564   if (parser.parseOperand(memrefInfo) ||
2565       parser.parseOperandList(indexInfo, OpAsmParser::Delimiter::Square) ||
2566       parser.parseComma() || parser.parseKeyword(&readOrWrite) ||
2567       parser.parseComma() || parser.parseKeyword("locality") ||
2568       parser.parseLess() ||
2569       parser.parseAttribute(localityHint, i32Type, "localityHint",
2570                             result.attributes) ||
2571       parser.parseGreater() || parser.parseComma() ||
2572       parser.parseKeyword(&cacheType) || parser.parseColonType(type) ||
2573       parser.resolveOperand(memrefInfo, type, result.operands) ||
2574       parser.resolveOperands(indexInfo, indexTy, result.operands))
2575     return failure();
2576 
2577   if (!readOrWrite.equals("read") && !readOrWrite.equals("write"))
2578     return parser.emitError(parser.getNameLoc(),
2579                             "rw specifier has to be 'read' or 'write'");
2580   result.addAttribute(
2581       PrefetchOp::getIsWriteAttrName(),
2582       parser.getBuilder().getBoolAttr(readOrWrite.equals("write")));
2583 
2584   if (!cacheType.equals("data") && !cacheType.equals("instr"))
2585     return parser.emitError(parser.getNameLoc(),
2586                             "cache type has to be 'data' or 'instr'");
2587 
2588   result.addAttribute(
2589       PrefetchOp::getIsDataCacheAttrName(),
2590       parser.getBuilder().getBoolAttr(cacheType.equals("data")));
2591 
2592   return success();
2593 }
2594 
2595 static LogicalResult verify(PrefetchOp op) {
2596   if (op.getNumOperands() != 1 + op.getMemRefType().getRank())
2597     return op.emitOpError("too few indices");
2598 
2599   return success();
2600 }
2601 
2602 LogicalResult PrefetchOp::fold(ArrayRef<Attribute> cstOperands,
2603                                SmallVectorImpl<OpFoldResult> &results) {
2604   // prefetch(memrefcast) -> prefetch
2605   return foldMemRefCast(*this);
2606 }
2607 
2608 //===----------------------------------------------------------------------===//
2609 // RankOp
2610 //===----------------------------------------------------------------------===//
2611 
2612 OpFoldResult RankOp::fold(ArrayRef<Attribute> operands) {
2613   // Constant fold rank when the rank of the operand is known.
2614   auto type = getOperand().getType();
2615   if (auto shapedType = type.dyn_cast<ShapedType>())
2616     if (shapedType.hasRank())
2617       return IntegerAttr::get(IndexType::get(getContext()),
2618                               shapedType.getRank());
2619   return IntegerAttr();
2620 }
2621 
2622 //===----------------------------------------------------------------------===//
2623 // ReturnOp
2624 //===----------------------------------------------------------------------===//
2625 
2626 static LogicalResult verify(ReturnOp op) {
2627   auto function = cast<FuncOp>(op->getParentOp());
2628 
2629   // The operand number and types must match the function signature.
2630   const auto &results = function.getType().getResults();
2631   if (op.getNumOperands() != results.size())
2632     return op.emitOpError("has ")
2633            << op.getNumOperands() << " operands, but enclosing function (@"
2634            << function.getName() << ") returns " << results.size();
2635 
2636   for (unsigned i = 0, e = results.size(); i != e; ++i)
2637     if (op.getOperand(i).getType() != results[i])
2638       return op.emitError()
2639              << "type of return operand " << i << " ("
2640              << op.getOperand(i).getType()
2641              << ") doesn't match function result type (" << results[i] << ")"
2642              << " in function @" << function.getName();
2643 
2644   return success();
2645 }
2646 
2647 //===----------------------------------------------------------------------===//
2648 // SelectOp
2649 //===----------------------------------------------------------------------===//
2650 
2651 OpFoldResult SelectOp::fold(ArrayRef<Attribute> operands) {
2652   auto condition = getCondition();
2653 
2654   // select true, %0, %1 => %0
2655   if (matchPattern(condition, m_One()))
2656     return getTrueValue();
2657 
2658   // select false, %0, %1 => %1
2659   if (matchPattern(condition, m_Zero()))
2660     return getFalseValue();
2661   return nullptr;
2662 }
2663 
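// Illustrative custom forms handled by the printer/parser below (hand-written
// examples, not taken from the original file):
//   select %cond, %a, %b : i32
//   select %mask, %va, %vb : vector<4xi1>, vector<4xf32>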
2664 static void print(OpAsmPrinter &p, SelectOp op) {
2665   p << "select " << op.getOperands();
2666   p.printOptionalAttrDict(op.getAttrs());
2667   p << " : ";
2668   if (ShapedType condType = op.getCondition().getType().dyn_cast<ShapedType>())
2669     p << condType << ", ";
2670   p << op.getType();
2671 }
2672 
2673 static ParseResult parseSelectOp(OpAsmParser &parser, OperationState &result) {
2674   Type conditionType, resultType;
2675   SmallVector<OpAsmParser::OperandType, 3> operands;
2676   if (parser.parseOperandList(operands, /*requiredOperandCount=*/3) ||
2677       parser.parseOptionalAttrDict(result.attributes) ||
2678       parser.parseColonType(resultType))
2679     return failure();
2680 
2681   // Check for the explicit condition type if this is a masked tensor or vector.
2682   if (succeeded(parser.parseOptionalComma())) {
2683     conditionType = resultType;
2684     if (parser.parseType(resultType))
2685       return failure();
2686   } else {
2687     conditionType = parser.getBuilder().getI1Type();
2688   }
2689 
2690   result.addTypes(resultType);
2691   return parser.resolveOperands(operands,
2692                                 {conditionType, resultType, resultType},
2693                                 parser.getNameLoc(), result.operands);
2694 }
2695 
2696 static LogicalResult verify(SelectOp op) {
2697   Type conditionType = op.getCondition().getType();
2698   if (conditionType.isSignlessInteger(1))
2699     return success();
2700 
2701   // If the result type is a vector or tensor, the condition can be a vector
2702   // or tensor of i1 with the same shape as the result.
2703   Type resultType = op.getType();
2704   if (!resultType.isa<TensorType, VectorType>())
2705     return op.emitOpError()
2706            << "expected condition to be a signless i1, but got "
2707            << conditionType;
2708   Type shapedConditionType = getI1SameShape(resultType);
2709   if (conditionType != shapedConditionType)
2710     return op.emitOpError()
2711            << "expected condition type to have the same shape "
2712               "as the result type, expected "
2713            << shapedConditionType << ", but got " << conditionType;
2714   return success();
2715 }
2716 
2717 //===----------------------------------------------------------------------===//
2718 // SignExtendIOp
2719 //===----------------------------------------------------------------------===//
2720 
2721 static LogicalResult verify(SignExtendIOp op) {
2722   // Get the scalar type (which is either directly the type of the operand
2723   // or the vector's/tensor's element type).
2724   auto srcType = getElementTypeOrSelf(op.getOperand().getType());
2725   auto dstType = getElementTypeOrSelf(op.getType());
2726 
2727   // For now, index is forbidden for the source and the destination type.
2728   if (srcType.isa<IndexType>())
2729     return op.emitError() << srcType << " is not a valid operand type";
2730   if (dstType.isa<IndexType>())
2731     return op.emitError() << dstType << " is not a valid result type";
2732 
2733   if (srcType.cast<IntegerType>().getWidth() >=
2734       dstType.cast<IntegerType>().getWidth())
2735     return op.emitError("result type ")
2736            << dstType << " must be wider than operand type " << srcType;
2737 
2738   return success();
2739 }
2740 
2741 //===----------------------------------------------------------------------===//
2742 // SignedDivIOp
2743 //===----------------------------------------------------------------------===//
2744 
2745 OpFoldResult SignedDivIOp::fold(ArrayRef<Attribute> operands) {
2746   assert(operands.size() == 2 && "binary operation takes two operands");
2747 
2748   // Don't fold if it would overflow or if it requires a division by zero.
2749   bool overflowOrDiv0 = false;
2750   auto result = constFoldBinaryOp<IntegerAttr>(operands, [&](APInt a, APInt b) {
2751     if (overflowOrDiv0 || !b) {
2752       overflowOrDiv0 = true;
2753       return a;
2754     }
2755     return a.sdiv_ov(b, overflowOrDiv0);
2756   });
2757 
2758   // Fold out division by one. Assumes all tensors of all ones are splats.
2759   if (auto rhs = operands[1].dyn_cast_or_null<IntegerAttr>()) {
2760     if (rhs.getValue() == 1)
2761       return lhs();
2762   } else if (auto rhs = operands[1].dyn_cast_or_null<SplatElementsAttr>()) {
2763     if (rhs.getSplatValue<IntegerAttr>().getValue() == 1)
2764       return lhs();
2765   }
2766 
2767   return overflowOrDiv0 ? Attribute() : result;
2768 }
2769 
2770 //===----------------------------------------------------------------------===//
2771 // SignedFloorDivIOp
2772 //===----------------------------------------------------------------------===//
2773 
2774 static APInt signedCeilNonnegInputs(APInt a, APInt b, bool &overflow) {
2775   // Returns (a-1)/b + 1
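  // Added worked check (not in the original source): for a = 7, b = 3 this
  // computes (7 - 1) / 3 + 1 == 3 == ceil(7 / 3); for a = 6, b = 3 it
  // computes (6 - 1) / 3 + 1 == 2 == ceil(6 / 3).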
2776   APInt one(a.getBitWidth(), 1, true); // Signed value 1.
2777   APInt val = a.ssub_ov(one, overflow).sdiv_ov(b, overflow);
2778   return val.sadd_ov(one, overflow);
2779 }
2780 
2781 OpFoldResult SignedFloorDivIOp::fold(ArrayRef<Attribute> operands) {
2782   assert(operands.size() == 2 && "binary operation takes two operands");
2783 
2784   // Don't fold if it would overflow or if it requires a division by zero.
2785   bool overflowOrDiv0 = false;
2786   auto result = constFoldBinaryOp<IntegerAttr>(operands, [&](APInt a, APInt b) {
2787     if (overflowOrDiv0 || !b) {
2788       overflowOrDiv0 = true;
2789       return a;
2790     }
2791     unsigned bits = a.getBitWidth();
2792     APInt zero = APInt::getNullValue(bits);
2793     if (a.sge(zero) && b.sgt(zero)) {
2794       // Both positive (or a is zero), return a / b.
2795       return a.sdiv_ov(b, overflowOrDiv0);
2796     } else if (a.sle(zero) && b.slt(zero)) {
2797       // Both negative (or a is zero), return -a / -b.
2798       APInt posA = zero.ssub_ov(a, overflowOrDiv0);
2799       APInt posB = zero.ssub_ov(b, overflowOrDiv0);
2800       return posA.sdiv_ov(posB, overflowOrDiv0);
2801     } else if (a.slt(zero) && b.sgt(zero)) {
2802       // A is negative, b is positive, return - ceil(-a, b).
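      // Added example (not in the original source): a = -7, b = 3 gives
      // -ceil(7, 3) = -3, which matches floor(-7 / 3).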
2803       APInt posA = zero.ssub_ov(a, overflowOrDiv0);
2804       APInt ceil = signedCeilNonnegInputs(posA, b, overflowOrDiv0);
2805       return zero.ssub_ov(ceil, overflowOrDiv0);
2806     } else {
2807       // A is positive, b is negative, return - ceil(a, -b).
2808       APInt posB = zero.ssub_ov(b, overflowOrDiv0);
2809       APInt ceil = signedCeilNonnegInputs(a, posB, overflowOrDiv0);
2810       return zero.ssub_ov(ceil, overflowOrDiv0);
2811     }
2812   });
2813 
2814   // Fold out floor division by one. Assumes all tensors of all ones are
2815   // splats.
2816   if (auto rhs = operands[1].dyn_cast_or_null<IntegerAttr>()) {
2817     if (rhs.getValue() == 1)
2818       return lhs();
2819   } else if (auto rhs = operands[1].dyn_cast_or_null<SplatElementsAttr>()) {
2820     if (rhs.getSplatValue<IntegerAttr>().getValue() == 1)
2821       return lhs();
2822   }
2823 
2824   return overflowOrDiv0 ? Attribute() : result;
2825 }
2826 
2827 //===----------------------------------------------------------------------===//
2828 // SignedCeilDivIOp
2829 //===----------------------------------------------------------------------===//
2830 
2831 OpFoldResult SignedCeilDivIOp::fold(ArrayRef<Attribute> operands) {
2832   assert(operands.size() == 2 && "binary operation takes two operands");
2833 
2834   // Don't fold if it would overflow or if it requires a division by zero.
2835   bool overflowOrDiv0 = false;
2836   auto result = constFoldBinaryOp<IntegerAttr>(operands, [&](APInt a, APInt b) {
2837     if (overflowOrDiv0 || !b) {
2838       overflowOrDiv0 = true;
2839       return a;
2840     }
2841     unsigned bits = a.getBitWidth();
2842     APInt zero = APInt::getNullValue(bits);
2843     if (a.sgt(zero) && b.sgt(zero)) {
2844       // Both positive, return ceil(a, b).
2845       return signedCeilNonnegInputs(a, b, overflowOrDiv0);
2846     } else if (a.slt(zero) && b.slt(zero)) {
2847       // Both negative, return ceil(-a, -b).
2848       APInt posA = zero.ssub_ov(a, overflowOrDiv0);
2849       APInt posB = zero.ssub_ov(b, overflowOrDiv0);
2850       return signedCeilNonnegInputs(posA, posB, overflowOrDiv0);
2851     } else if (a.slt(zero) && b.sgt(zero)) {
2852       // A is negative, b is positive, return - ( -a / b).
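      // Added example (not in the original source): a = -7, b = 3 gives
      // -(7 / 3) = -2, which matches ceil(-7 / 3).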
2853       APInt posA = zero.ssub_ov(a, overflowOrDiv0);
2854       APInt div = posA.sdiv_ov(b, overflowOrDiv0);
2855       return zero.ssub_ov(div, overflowOrDiv0);
2856     } else {
2857       // A is positive (or zero), b is negative, return - (a / -b).
2858       APInt posB = zero.ssub_ov(b, overflowOrDiv0);
2859       APInt div = a.sdiv_ov(posB, overflowOrDiv0);
2860       return zero.ssub_ov(div, overflowOrDiv0);
2861     }
2862   });
2863 
2864   // Fold out ceil division by one. Assumes all tensors of all ones are
2865   // splats.
2866   if (auto rhs = operands[1].dyn_cast_or_null<IntegerAttr>()) {
2867     if (rhs.getValue() == 1)
2868       return lhs();
2869   } else if (auto rhs = operands[1].dyn_cast_or_null<SplatElementsAttr>()) {
2870     if (rhs.getSplatValue<IntegerAttr>().getValue() == 1)
2871       return lhs();
2872   }
2873 
2874   return overflowOrDiv0 ? Attribute() : result;
2875 }
2876 
2877 //===----------------------------------------------------------------------===//
2878 // SignedRemIOp
2879 //===----------------------------------------------------------------------===//
2880 
2881 OpFoldResult SignedRemIOp::fold(ArrayRef<Attribute> operands) {
2882   assert(operands.size() == 2 && "remi_signed takes two operands");
2883 
2884   auto rhs = operands.back().dyn_cast_or_null<IntegerAttr>();
2885   if (!rhs)
2886     return {};
2887   auto rhsValue = rhs.getValue();
2888 
2889   // x % 1 = 0
2890   if (rhsValue.isOneValue())
2891     return IntegerAttr::get(rhs.getType(), APInt(rhsValue.getBitWidth(), 0));
2892 
2893   // Don't fold if it requires division by zero.
2894   if (rhsValue.isNullValue())
2895     return {};
2896 
2897   auto lhs = operands.front().dyn_cast_or_null<IntegerAttr>();
2898   if (!lhs)
2899     return {};
2900   return IntegerAttr::get(lhs.getType(), lhs.getValue().srem(rhsValue));
2901 }
2902 
2903 //===----------------------------------------------------------------------===//
2904 // SIToFPOp
2905 //===----------------------------------------------------------------------===//
2906 
2907 // sitofp is applicable from integer types to float types.
2908 bool SIToFPOp::areCastCompatible(Type a, Type b) {
2909   if (a.isSignlessInteger() && b.isa<FloatType>())
2910     return true;
2911   return areVectorCastSimpleCompatible(a, b, areCastCompatible);
2912 }
2913 
2914 //===----------------------------------------------------------------------===//
2915 // SplatOp
2916 //===----------------------------------------------------------------------===//
2917 
2918 static LogicalResult verify(SplatOp op) {
2919   // TODO: we could replace this by a trait.
2920   if (op.getOperand().getType() !=
2921       op.getType().cast<ShapedType>().getElementType())
2922     return op.emitError("operand should be of elemental type of result type");
2923 
2924   return success();
2925 }
2926 
2927 // Constant folding hook for SplatOp.
2928 OpFoldResult SplatOp::fold(ArrayRef<Attribute> operands) {
2929   assert(operands.size() == 1 && "splat takes one operand");
2930 
2931   auto constOperand = operands.front();
2932   if (!constOperand || !constOperand.isa<IntegerAttr, FloatAttr>())
2933     return {};
2934 
2935   auto shapedType = getType().cast<ShapedType>();
2936   assert(shapedType.getElementType() == constOperand.getType() &&
2937          "incorrect input attribute type for folding");
2938 
2939   // SplatElementsAttr::get treats a single value for the second arg as a splat.
2940   return SplatElementsAttr::get(shapedType, {constOperand});
2941 }
2942 
2943 //===----------------------------------------------------------------------===//
2944 // StoreOp
2945 //===----------------------------------------------------------------------===//
2946 
2947 static LogicalResult verify(StoreOp op) {
2948   if (op.getNumOperands() != 2 + op.getMemRefType().getRank())
2949     return op.emitOpError("store index operand count not equal to memref rank");
2950 
2951   return success();
2952 }
2953 
2954 LogicalResult StoreOp::fold(ArrayRef<Attribute> cstOperands,
2955                             SmallVectorImpl<OpFoldResult> &results) {
2956   /// store(memrefcast) -> store
2957   return foldMemRefCast(*this);
2958 }
2959 
2960 //===----------------------------------------------------------------------===//
2961 // SubFOp
2962 //===----------------------------------------------------------------------===//
2963 
2964 OpFoldResult SubFOp::fold(ArrayRef<Attribute> operands) {
2965   return constFoldBinaryOp<FloatAttr>(
2966       operands, [](APFloat a, APFloat b) { return a - b; });
2967 }
2968 
2969 //===----------------------------------------------------------------------===//
2970 // SubIOp
2971 //===----------------------------------------------------------------------===//
2972 
2973 OpFoldResult SubIOp::fold(ArrayRef<Attribute> operands) {
2974   // subi(x,x) -> 0
2975   if (getOperand(0) == getOperand(1))
2976     return Builder(getContext()).getZeroAttr(getType());
2977   // subi(x,0) -> x
2978   if (matchPattern(rhs(), m_Zero()))
2979     return lhs();
2980 
2981   return constFoldBinaryOp<IntegerAttr>(operands,
2982                                         [](APInt a, APInt b) { return a - b; });
2983 }
2984 
2985 //===----------------------------------------------------------------------===//
2986 // UIToFPOp
2987 //===----------------------------------------------------------------------===//
2988 
2989 // uitofp is applicable from integer types to float types.
2990 bool UIToFPOp::areCastCompatible(Type a, Type b) {
2991   if (a.isSignlessInteger() && b.isa<FloatType>())
2992     return true;
2993   return areVectorCastSimpleCompatible(a, b, areCastCompatible);
2994 }
2995 
2996 //===----------------------------------------------------------------------===//
2997 // SubViewOp
2998 //===----------------------------------------------------------------------===//
2999 
3000 namespace {
3001 /// Helpers to write more idiomatic operations.
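/// Note added for clarity (not in the original source): arithmetic on Wrapper
/// saturates to the dynamic sentinel, i.e. if either operand is
/// ShapedType::kDynamicStrideOrOffset, the result is dynamic as well.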
3002 namespace saturated_arith {
3003 struct Wrapper {
3004   explicit Wrapper(int64_t v) : v(v) {}
3005   operator int64_t() { return v; }
3006   int64_t v;
3007 };
3008 Wrapper operator+(Wrapper a, int64_t b) {
3009   if (ShapedType::isDynamicStrideOrOffset(a) ||
3010       ShapedType::isDynamicStrideOrOffset(b))
3011     return Wrapper(ShapedType::kDynamicStrideOrOffset);
3012   return Wrapper(a.v + b);
3013 }
3014 Wrapper operator*(Wrapper a, int64_t b) {
3015   if (ShapedType::isDynamicStrideOrOffset(a) ||
3016       ShapedType::isDynamicStrideOrOffset(b))
3017     return Wrapper(ShapedType::kDynamicStrideOrOffset);
3018   return Wrapper(a.v * b);
3019 }
3020 } // end namespace saturated_arith
3021 } // end namespace
3022 
3023 /// A subview result type can be fully inferred from the source type and the
3024 /// static representation of offsets, sizes and strides. Special sentinels
3025 /// encode the dynamic case.
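/// Illustrative example (added, not part of the original source): for a source
/// type of memref<8x16xf32> (strides [16, 1], offset 0), staticOffsets [2, 4],
/// staticSizes [4, 4] and staticStrides [1, 2], the inferred type is equivalent
/// to memref<4x4xf32, offset: 36, strides: [16, 2]>, since
/// 0 + 2 * 16 + 4 * 1 == 36 and [16 * 1, 1 * 2] == [16, 2].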
3026 Type SubViewOp::inferResultType(MemRefType sourceMemRefType,
3027                                 ArrayRef<int64_t> staticOffsets,
3028                                 ArrayRef<int64_t> staticSizes,
3029                                 ArrayRef<int64_t> staticStrides) {
3030   unsigned rank = sourceMemRefType.getRank();
3031   (void)rank;
3032   assert(staticOffsets.size() == rank &&
3033          "unexpected staticOffsets size mismatch");
3034   assert(staticSizes.size() == rank && "unexpected staticSizes size mismatch");
3035   assert(staticStrides.size() == rank &&
3036          "unexpected staticStrides size mismatch");
3037 
3038   // Extract source offset and strides.
3039   int64_t sourceOffset;
3040   SmallVector<int64_t, 4> sourceStrides;
3041   auto res = getStridesAndOffset(sourceMemRefType, sourceStrides, sourceOffset);
3042   assert(succeeded(res) && "SubViewOp expected strided memref type");
3043   (void)res;
3044 
3045   // Compute target offset whose value is:
3046   //   `sourceOffset + sum_i(staticOffset_i * sourceStrides_i)`.
3047   int64_t targetOffset = sourceOffset;
3048   for (auto it : llvm::zip(staticOffsets, sourceStrides)) {
3049     auto staticOffset = std::get<0>(it), targetStride = std::get<1>(it);
3050     using namespace saturated_arith;
3051     targetOffset = Wrapper(targetOffset) + Wrapper(staticOffset) * targetStride;
3052   }
3053 
3054   // Compute target stride whose value is:
3055   //   `sourceStrides_i * staticStrides_i`.
3056   SmallVector<int64_t, 4> targetStrides;
3057   targetStrides.reserve(staticOffsets.size());
3058   for (auto it : llvm::zip(sourceStrides, staticStrides)) {
3059     auto sourceStride = std::get<0>(it), staticStride = std::get<1>(it);
3060     using namespace saturated_arith;
3061     targetStrides.push_back(Wrapper(sourceStride) * staticStride);
3062   }
3063 
3064   // The type is now known.
3065   return MemRefType::get(
3066       staticSizes, sourceMemRefType.getElementType(),
3067       makeStridedLinearLayoutMap(targetStrides, targetOffset,
3068                                  sourceMemRefType.getContext()),
3069       sourceMemRefType.getMemorySpace());
3070 }
3071 
3072 /// Print a subview op of the form:
3073 /// ```
3074 ///   `subview` ssa-name
3075 ///     `[` offset-list `]` `[` size-list `]` `[` stride-list `]`
3076 ///     `:` strided-memref-type `to` strided-memref-type
3077 /// ```
3078 static void print(OpAsmPrinter &p, SubViewOp op) {
3079   int stdDotLen = StandardOpsDialect::getDialectNamespace().size() + 1;
3080   p << op->getName().getStringRef().drop_front(stdDotLen) << ' ';
3081   p << op.source();
3082   printOffsetsSizesAndStrides(p, op);
3083   p << " : " << op.getSourceType() << " to " << op.getType();
3084 }
3085 
3086 /// Parse a subview op of the form:
3087 /// ```
3088 ///   `subview` ssa-name
3089 ///     `[` offset-list `]` `[` size-list `]` `[` stride-list `]`
3090 ///     `:` strided-memref-type `to` strided-memref-type
3091 /// ```
3092 static ParseResult parseSubViewOp(OpAsmParser &parser, OperationState &result) {
3093   OpAsmParser::OperandType srcInfo;
3094   if (parser.parseOperand(srcInfo))
3095     return failure();
3096   Type srcType, dstType;
3097   auto preResolutionFn = [&](OpAsmParser &parser, OperationState &result) {
3098     return failure(parser.parseOptionalAttrDict(result.attributes) ||
3099                    parser.parseColonType(srcType) ||
3100                    parser.parseKeywordType("to", dstType) ||
3101                    parser.resolveOperand(srcInfo, srcType, result.operands));
3102   };
3103 
3104   if (failed(parseOffsetsSizesAndStrides(parser, result,
3105                                          /*segmentSizes=*/{1}, // source memref
3106                                          preResolutionFn)))
3107     return failure();
3108   return parser.addTypeToList(dstType, result.types);
3109 }
3110 
3111 void mlir::SubViewOp::build(OpBuilder &b, OperationState &result, Value source,
3112                             ArrayRef<int64_t> staticOffsets,
3113                             ArrayRef<int64_t> staticSizes,
3114                             ArrayRef<int64_t> staticStrides, ValueRange offsets,
3115                             ValueRange sizes, ValueRange strides,
3116                             ArrayRef<NamedAttribute> attrs) {
3117   auto sourceMemRefType = source.getType().cast<MemRefType>();
3118   auto resultType = inferResultType(sourceMemRefType, staticOffsets,
3119                                     staticSizes, staticStrides);
3120   build(b, result, resultType, source, offsets, sizes, strides,
3121         b.getI64ArrayAttr(staticOffsets), b.getI64ArrayAttr(staticSizes),
3122         b.getI64ArrayAttr(staticStrides));
3123   result.addAttributes(attrs);
3124 }
3125 
3126 /// Build a SubViewOp with all dynamic entries: `staticOffsets`, `staticSizes`
3127 /// and `staticStrides` are automatically filled with source-memref-rank
3128 /// sentinel values that encode dynamic entries.
3129 void mlir::SubViewOp::build(OpBuilder &b, OperationState &result, Value source,
3130                             ValueRange offsets, ValueRange sizes,
3131                             ValueRange strides,
3132                             ArrayRef<NamedAttribute> attrs) {
3133   auto sourceMemRefType = source.getType().cast<MemRefType>();
3134   unsigned rank = sourceMemRefType.getRank();
3135   SmallVector<int64_t, 4> staticOffsetsVector;
3136   staticOffsetsVector.assign(rank, ShapedType::kDynamicStrideOrOffset);
3137   SmallVector<int64_t, 4> staticSizesVector;
3138   staticSizesVector.assign(rank, ShapedType::kDynamicSize);
3139   SmallVector<int64_t, 4> staticStridesVector;
3140   staticStridesVector.assign(rank, ShapedType::kDynamicStrideOrOffset);
3141   build(b, result, source, staticOffsetsVector, staticSizesVector,
3142         staticStridesVector, offsets, sizes, strides, attrs);
3143 }
3144 
3145 /// Build a SubViewOp as above but with custom result type.
3146 void mlir::SubViewOp::build(OpBuilder &b, OperationState &result,
3147                             MemRefType resultType, Value source,
3148                             ArrayRef<int64_t> staticOffsets,
3149                             ArrayRef<int64_t> staticSizes,
3150                             ArrayRef<int64_t> staticStrides, ValueRange offsets,
3151                             ValueRange sizes, ValueRange strides,
3152                             ArrayRef<NamedAttribute> attrs) {
3153   build(b, result, resultType, source, offsets, sizes, strides,
3154         b.getI64ArrayAttr(staticOffsets), b.getI64ArrayAttr(staticSizes),
3155         b.getI64ArrayAttr(staticStrides));
3156   result.addAttributes(attrs);
3157 }
3158 
3159 /// Build a SubViewOp as above but with custom result type.
3160 void mlir::SubViewOp::build(OpBuilder &b, OperationState &result,
3161                             MemRefType resultType, Value source,
3162                             ValueRange offsets, ValueRange sizes,
3163                             ValueRange strides,
3164                             ArrayRef<NamedAttribute> attrs) {
3165   auto sourceMemRefType = source.getType().cast<MemRefType>();
3166   unsigned rank = sourceMemRefType.getRank();
3167   SmallVector<int64_t, 4> staticOffsetsVector;
3168   staticOffsetsVector.assign(rank, ShapedType::kDynamicStrideOrOffset);
3169   SmallVector<int64_t, 4> staticSizesVector;
3170   staticSizesVector.assign(rank, ShapedType::kDynamicSize);
3171   SmallVector<int64_t, 4> staticStridesVector;
3172   staticStridesVector.assign(rank, ShapedType::kDynamicStrideOrOffset);
3173   build(b, result, resultType, source, staticOffsetsVector, staticSizesVector,
3174         staticStridesVector, offsets, sizes, strides, attrs);
3175 }
3176 
3177 /// For ViewLikeOpInterface.
3178 Value SubViewOp::getViewSource() { return source(); }
3179 
3180 llvm::Optional<SmallVector<bool, 4>>
3181 mlir::computeRankReductionMask(ArrayRef<int64_t> originalShape,
3182                                ArrayRef<int64_t> reducedShape) {
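  // Added example (not in the original source): originalShape = {4, 1, 3} and
  // reducedShape = {4, 3} produce the mask {true, false, true}; any dropped
  // dimension must have size 1, otherwise llvm::None is returned.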
3183   size_t originalRank = originalShape.size(), reducedRank = reducedShape.size();
3184   SmallVector<bool, 4> mask(originalRank);
3185   unsigned reducedIdx = 0;
3186   for (unsigned originalIdx = 0; originalIdx < originalRank; ++originalIdx) {
3187     // Skip matching dims greedily.
3188     mask[originalIdx] =
3189         (reducedIdx < reducedRank) &&
3190         (originalShape[originalIdx] == reducedShape[reducedIdx]);
3191     if (mask[originalIdx])
3192       reducedIdx++;
3193     // A non-matching dimension is only allowed if its original size is 1.
3194     else if (originalShape[originalIdx] != 1)
3195       return {};
3196   }
3197 
3198   if (reducedIdx != reducedRank)
3199     return {};
3200 
3201   return mask;
3202 }
3203 
3204 enum SubViewVerificationResult {
3205   Success,
3206   RankTooLarge,
3207   SizeMismatch,
3208   StrideMismatch,
3209   ElemTypeMismatch,
3210   MemSpaceMismatch,
3211   AffineMapMismatch
3212 };
3213 
3214 /// Checks if the `original` type can be rank-reduced to the `reduced` type.
3215 /// This function is a slight variant of the `is subsequence` algorithm where
3216 /// every non-matching dimension must be 1.
3217 static SubViewVerificationResult isRankReducedType(Type originalType,
3218                                                    Type reducedType) {
3219   if (originalType == reducedType)
3220     return SubViewVerificationResult::Success;
3221   if (!originalType.isa<RankedTensorType>() && !originalType.isa<MemRefType>())
3222     return SubViewVerificationResult::Success;
3223   if (originalType.isa<RankedTensorType>() &&
3224       !reducedType.isa<RankedTensorType>())
3225     return SubViewVerificationResult::Success;
3226   if (originalType.isa<MemRefType>() && !reducedType.isa<MemRefType>())
3227     return SubViewVerificationResult::Success;
3228 
3229   ShapedType originalShapedType = originalType.cast<ShapedType>();
3230   ShapedType reducedShapedType = reducedType.cast<ShapedType>();
3231 
3232   // Rank and size logic is valid for all ShapedTypes.
3233   ArrayRef<int64_t> originalShape = originalShapedType.getShape();
3234   ArrayRef<int64_t> reducedShape = reducedShapedType.getShape();
3235   unsigned originalRank = originalShape.size(),
3236            reducedRank = reducedShape.size();
3237   if (reducedRank > originalRank)
3238     return SubViewVerificationResult::RankTooLarge;
3239 
3240   auto optionalMask = computeRankReductionMask(originalShape, reducedShape);
3241 
3242   // Sizes cannot be matched if no rank-reduction mask could be computed.
3243   if (!optionalMask.hasValue())
3244     return SubViewVerificationResult::SizeMismatch;
3245 
3246   // We are done for the tensor case.
3247   if (originalType.isa<RankedTensorType>())
3248     return SubViewVerificationResult::Success;
3249 
3250   // Strided layout logic is relevant for MemRefType only.
3251   MemRefType original = originalType.cast<MemRefType>();
3252   MemRefType reduced = reducedType.cast<MemRefType>();
3253   MLIRContext *c = original.getContext();
3254   int64_t originalOffset, reducedOffset;
3255   SmallVector<int64_t, 4> originalStrides, reducedStrides, keepStrides;
3256   SmallVector<bool, 4> keepMask = optionalMask.getValue();
3257   getStridesAndOffset(original, originalStrides, originalOffset);
3258   getStridesAndOffset(reduced, reducedStrides, reducedOffset);
3259 
3260   // Filter strides based on the mask and check that they are the same
3261   // as reduced ones.
3262   unsigned reducedIdx = 0;
3263   for (unsigned originalIdx = 0; originalIdx < originalRank; ++originalIdx) {
3264     if (keepMask[originalIdx]) {
3265       if (originalStrides[originalIdx] != reducedStrides[reducedIdx++])
3266         return SubViewVerificationResult::StrideMismatch;
3267       keepStrides.push_back(originalStrides[originalIdx]);
3268     }
3269   }
3270 
3271   if (original.getElementType() != reduced.getElementType())
3272     return SubViewVerificationResult::ElemTypeMismatch;
3273 
3274   if (original.getMemorySpace() != reduced.getMemorySpace())
3275     return SubViewVerificationResult::MemSpaceMismatch;
3276 
3277   auto reducedMap = makeStridedLinearLayoutMap(keepStrides, originalOffset, c);
3278   if (!reduced.getAffineMaps().empty() &&
3279       reducedMap != reduced.getAffineMaps().front())
3280     return SubViewVerificationResult::AffineMapMismatch;
3281 
3282   return SubViewVerificationResult::Success;
3283 }
3284 
3285 template <typename OpTy>
3286 static LogicalResult produceSubViewErrorMsg(SubViewVerificationResult result,
3287                                             OpTy op, Type expectedType) {
3288   auto memrefType = expectedType.cast<ShapedType>();
3289   switch (result) {
3290   case SubViewVerificationResult::Success:
3291     return success();
3292   case SubViewVerificationResult::RankTooLarge:
3293     return op.emitError("expected result rank to be smaller or equal to ")
3294            << "the source rank.";
3295   case SubViewVerificationResult::SizeMismatch:
3296     return op.emitError("expected result type to be ")
3297            << expectedType
3298            << " or a rank-reduced version. (mismatch of result sizes)";
3299   case SubViewVerificationResult::StrideMismatch:
3300     return op.emitError("expected result type to be ")
3301            << expectedType
3302            << " or a rank-reduced version. (mismatch of result strides)";
3303   case SubViewVerificationResult::ElemTypeMismatch:
3304     return op.emitError("expected result element type to be ")
3305            << memrefType.getElementType();
3306   case SubViewVerificationResult::MemSpaceMismatch:
3307     return op.emitError("expected result and source memory spaces to match.");
3308   case SubViewVerificationResult::AffineMapMismatch:
3309     return op.emitError("expected result type to be ")
3310            << expectedType
3311            << " or a rank-reduced version. (mismatch of result affine map)";
3312   }
3313   llvm_unreachable("unexpected subview verification result");
3314 }
3315 
3316 /// Verifier for SubViewOp.
3317 static LogicalResult verify(SubViewOp op) {
3318   MemRefType baseType = op.getSourceType();
3319   MemRefType subViewType = op.getType();
3320 
3321   // The base memref and the view memref should be in the same memory space.
3322   if (baseType.getMemorySpace() != subViewType.getMemorySpace())
3323     return op.emitError("different memory spaces specified for base memref "
3324                         "type ")
3325            << baseType << " and subview memref type " << subViewType;
3326 
3327   // Verify that the base memref type has a strided layout map.
3328   if (!isStrided(baseType))
3329     return op.emitError("base type ") << baseType << " is not strided";
3330 
3331   // Verify result type against inferred type.
3332   auto expectedType = SubViewOp::inferResultType(
3333       baseType, extractFromI64ArrayAttr(op.static_offsets()),
3334       extractFromI64ArrayAttr(op.static_sizes()),
3335       extractFromI64ArrayAttr(op.static_strides()));
3336 
3337   auto result = isRankReducedType(expectedType, subViewType);
3338   return produceSubViewErrorMsg(result, op, expectedType);
3339 }
3340 
3341 raw_ostream &mlir::operator<<(raw_ostream &os, Range &range) {
3342   return os << "range " << range.offset << ":" << range.size << ":"
3343             << range.stride;
3344 }
3345 
3346 /// Return the list of Range (i.e. offset, size, stride). Each Range
3347 /// entry contains either the dynamic value or a ConstantIndexOp constructed
3348 /// with `b` at location `loc`.
3349 SmallVector<Range, 8> mlir::getOrCreateRanges(OffsetSizeAndStrideOpInterface op,
3350                                               OpBuilder &b, Location loc) {
3351   std::array<unsigned, 3> ranks = op.getArrayAttrRanks();
3352   assert(ranks[0] == ranks[1] && "expected offset and sizes of equal ranks");
3353   assert(ranks[1] == ranks[2] && "expected sizes and strides of equal ranks");
3354   SmallVector<Range, 8> res;
3355   unsigned rank = ranks[0];
3356   res.reserve(rank);
3357   for (unsigned idx = 0; idx < rank; ++idx) {
3358     Value offset =
3359         op.isDynamicOffset(idx)
3360             ? op.getDynamicOffset(idx)
3361             : b.create<ConstantIndexOp>(loc, op.getStaticOffset(idx));
3362     Value size = op.isDynamicSize(idx)
3363                      ? op.getDynamicSize(idx)
3364                      : b.create<ConstantIndexOp>(loc, op.getStaticSize(idx));
3365     Value stride =
3366         op.isDynamicStride(idx)
3367             ? op.getDynamicStride(idx)
3368             : b.create<ConstantIndexOp>(loc, op.getStaticStride(idx));
3369     res.emplace_back(Range{offset, size, stride});
3370   }
3371   return res;
3372 }
3373 
3374 namespace {
3375 
3376 /// Take a list of `values` with potential new constants to extract and a list
3377 /// of `constantValues` containing `values.size()` sentinels that evaluate to
3378 /// true when `isDynamic` is applied.
3379 /// Detects the `values` produced by a ConstantIndexOp and places the new
3380 /// constant in place of the corresponding sentinel value.
3381 void canonicalizeSubViewPart(SmallVectorImpl<Value> &values,
3382                              SmallVectorImpl<int64_t> &constantValues,
3383                              llvm::function_ref<bool(int64_t)> isDynamic) {
3384   bool hasNewStaticValue = llvm::any_of(
3385       values, [](Value val) { return matchPattern(val, m_ConstantIndex()); });
3386   if (hasNewStaticValue) {
3387     for (unsigned cstIdx = 0, valIdx = 0, e = constantValues.size();
3388          cstIdx != e; ++cstIdx) {
3389       // Was already static, skip.
3390       if (!isDynamic(constantValues[cstIdx]))
3391         continue;
3392       // Newly static, move from Value to constant.
3393       if (matchPattern(values[valIdx], m_ConstantIndex())) {
3394         constantValues[cstIdx] =
3395             cast<ConstantIndexOp>(values[valIdx].getDefiningOp()).getValue();
3396         // Erase for implementation simplicity; use a reverse iterator if we must.
3397         values.erase(std::next(values.begin(), valIdx));
3398         continue;
3399       }
3400       // Remains dynamic; move to the next value.
3401       ++valIdx;
3402     }
3403   }
3404 }
3405 
3406 static void replaceWithNewOp(PatternRewriter &rewriter, SubViewOp op,
3407                              SubViewOp newOp) {
3408   rewriter.replaceOpWithNewOp<MemRefCastOp>(op, newOp, op.getType());
3409 }
3410 
3411 static void replaceWithNewOp(PatternRewriter &rewriter, SubTensorOp op,
3412                              SubTensorOp newOp) {
3413   rewriter.replaceOpWithNewOp<TensorCastOp>(op, newOp, op.getType());
3414 }
3415 
3416 /// Pattern to rewrite a subview op with constant arguments.
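/// Added illustration (not in the original source): if an offset operand is
/// produced by `constant 1 : index`, the pattern moves the value 1 into the
/// static_offsets attribute, drops the now-redundant operand, and casts the
/// more static result (via memref_cast or tensor_cast) back to the op's
/// original result type.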
3417 template <typename OpType>
3418 class OpWithOffsetSizesAndStridesConstantArgumentFolder final
3419     : public OpRewritePattern<OpType> {
3420 public:
3421   using OpRewritePattern<OpType>::OpRewritePattern;
3422 
3423   LogicalResult matchAndRewrite(OpType op,
3424                                 PatternRewriter &rewriter) const override {
3425     // No constant operand, just return.
3426     if (llvm::none_of(op.getOperands(), [](Value operand) {
3427           return matchPattern(operand, m_ConstantIndex());
3428         }))
3429       return failure();
3430 
3431     // At least one of offsets/sizes/strides is a new constant.
3432     // Form the new list of operands and constant attributes from the existing.
3433     SmallVector<Value, 8> newOffsets(op.offsets());
3434     SmallVector<int64_t, 8> newStaticOffsets =
3435         extractFromI64ArrayAttr(op.static_offsets());
3436     std::array<unsigned, 3> ranks = op.getArrayAttrRanks();
3437     (void)ranks;
3438     assert(newStaticOffsets.size() == ranks[0]);
3439     canonicalizeSubViewPart(newOffsets, newStaticOffsets,
3440                             ShapedType::isDynamicStrideOrOffset);
3441 
3442     SmallVector<Value, 8> newSizes(op.sizes());
3443     SmallVector<int64_t, 8> newStaticSizes =
3444         extractFromI64ArrayAttr(op.static_sizes());
3445     assert(newStaticSizes.size() == ranks[1]);
3446     canonicalizeSubViewPart(newSizes, newStaticSizes, ShapedType::isDynamic);
3447 
3448     SmallVector<Value, 8> newStrides(op.strides());
3449     SmallVector<int64_t, 8> newStaticStrides =
3450         extractFromI64ArrayAttr(op.static_strides());
3451     assert(newStaticStrides.size() == ranks[2]);
3452     canonicalizeSubViewPart(newStrides, newStaticStrides,
3453                             ShapedType::isDynamicStrideOrOffset);
3454 
3455     // Create the new op in canonical form.
3456     auto newOp = rewriter.create<OpType>(
3457         op.getLoc(), op.source(), newStaticOffsets, newStaticSizes,
3458         newStaticStrides, newOffsets, newSizes, newStrides);
3459 
3460     replaceWithNewOp(rewriter, op, newOp);
3461 
3462     return success();
3463   }
3464 };
3465 
3466 } // end anonymous namespace
3467 
3468 /// Determines whether MemRefCastOp casts to a more dynamic version of the
3469 /// source memref. This is useful to fold a memref_cast into a consuming op
3470 /// and implement canonicalization patterns for ops in different dialects that
3471 /// may consume the results of memref_cast operations. Such foldable memref_cast
3472 /// operations are typically inserted when `view` and `subview` ops are
3473 /// canonicalized, to preserve the type compatibility of their uses.
3474 ///
3475 /// Returns true when all conditions are met:
3476 /// 1. source and result are ranked memrefs with strided semantics and same
3477 /// element type and rank.
3478 /// 2. each of the source's size, offset or stride has more static information
3479 /// than the corresponding result's size, offset or stride.
3480 ///
3481 /// Example 1:
3482 /// ```mlir
3483 ///   %1 = memref_cast %0 : memref<8x16xf32> to memref<?x?xf32>
3484 ///   %2 = consumer %1 ... : memref<?x?xf32> ...
3485 /// ```
3486 ///
3487 /// may fold into:
3488 ///
3489 /// ```mlir
3490 ///   %2 = consumer %0 ... : memref<8x16xf32> ...
3491 /// ```
3492 ///
3493 /// Example 2:
3494 /// ```
3495 ///   %1 = memref_cast %0 : memref<?x16xf32, affine_map<(i, j)->(16 * i + j)>>
3496 ///          to memref<?x?xf32>
3497 ///   consumer %1 : memref<?x?xf32> ...
3498 /// ```
3499 ///
3500 /// may fold into:
3501 ///
3502 /// ```
3503 ///   consumer %0 ... : memref<?x16xf32, affine_map<(i, j)->(16 * i + j)>>
3504 /// ```
3505 bool mlir::canFoldIntoConsumerOp(MemRefCastOp castOp) {
3506   MemRefType sourceType = castOp.source().getType().dyn_cast<MemRefType>();
3507   MemRefType resultType = castOp.getType().dyn_cast<MemRefType>();
3508 
3509   // Requires ranked MemRefType.
3510   if (!sourceType || !resultType)
3511     return false;
3512 
3513   // Requires same elemental type.
3514   if (sourceType.getElementType() != resultType.getElementType())
3515     return false;
3516 
3517   // Requires same rank.
3518   if (sourceType.getRank() != resultType.getRank())
3519     return false;
3520 
3521   // Only fold casts between strided memref forms.
3522   int64_t sourceOffset, resultOffset;
3523   SmallVector<int64_t, 4> sourceStrides, resultStrides;
3524   if (failed(getStridesAndOffset(sourceType, sourceStrides, sourceOffset)) ||
3525       failed(getStridesAndOffset(resultType, resultStrides, resultOffset)))
3526     return false;
3527 
3528   // If cast is towards more static sizes along any dimension, don't fold.
3529   for (auto it : llvm::zip(sourceType.getShape(), resultType.getShape())) {
3530     auto ss = std::get<0>(it), st = std::get<1>(it);
3531     if (ss != st)
3532       if (MemRefType::isDynamic(ss) && !MemRefType::isDynamic(st))
3533         return false;
3534   }
3535 
3536   // If cast is towards more static offset along any dimension, don't fold.
3537   if (sourceOffset != resultOffset)
3538     if (MemRefType::isDynamicStrideOrOffset(sourceOffset) &&
3539         !MemRefType::isDynamicStrideOrOffset(resultOffset))
3540       return false;
3541 
3542   // If cast is towards more static strides along any dimension, don't fold.
3543   for (auto it : llvm::zip(sourceStrides, resultStrides)) {
3544     auto ss = std::get<0>(it), st = std::get<1>(it);
3545     if (ss != st)
3546       if (MemRefType::isDynamicStrideOrOffset(ss) &&
3547           !MemRefType::isDynamicStrideOrOffset(st))
3548         return false;
3549   }
3550 
3551   return true;
3552 }
3553 
3554 /// Counterpart of `canFoldIntoConsumerOp(MemRefCastOp castOp)` for tensors.
3555 /// Determines whether TensorCastOp casts to a more dynamic version of the
3556 /// source tensor. This is useful to fold a tensor_cast into a consuming op and
3557 /// implement canonicalization patterns for ops in different dialects that may
3558 /// consume the results of tensor_cast operations. Such foldable tensor_cast
3559 /// operations are typically inserted when `subtensor` ops are canonicalized,
3560 /// to preserve the type compatibility of their uses.
3561 ///
3562 /// Returns true when all conditions are met:
3563 /// 1. source and result are ranked tensors with same element type and rank.
3564 /// 2. the source tensor type has more static information than the result.
3565 ///
3566 /// Example:
3567 /// ```mlir
3568 ///   %1 = tensor_cast %0 : tensor<8x16xf32> to tensor<?x?xf32>
3569 ///   %2 = consumer %1 ... : tensor<?x?xf32> ...
3570 /// ```
3571 ///
3572 /// folds into:
3573 ///
3574 /// ```mlir
3575 ///   %2 = consumer %0 ... : tensor<8x16xf32> ...
3576 /// ```
3577 bool mlir::canFoldIntoConsumerOp(TensorCastOp castOp) {
3578   if (!castOp)
3579     return false;
3580 
3581   RankedTensorType sourceType =
3582       castOp.source().getType().dyn_cast<RankedTensorType>();
3583   RankedTensorType resultType = castOp.getType().dyn_cast<RankedTensorType>();
3584 
3585   // Requires RankedTensorType.
3586   if (!sourceType || !resultType)
3587     return false;
3588 
3589   // Requires same elemental type.
3590   if (sourceType.getElementType() != resultType.getElementType())
3591     return false;
3592 
3593   // Requires same rank.
3594   if (sourceType.getRank() != resultType.getRank())
3595     return false;
3596 
3597   // If cast is towards more static sizes along any dimension, don't fold.
3598   for (auto it : llvm::zip(sourceType.getShape(), resultType.getShape())) {
3599     auto ss = std::get<0>(it), st = std::get<1>(it);
3600     if (ss != st)
3601       if (ShapedType::isDynamic(ss) && !ShapedType::isDynamic(st))
3602         return false;
3603   }
3604 
3605   return true;
3606 }
3607 
3608 namespace {
3609 /// Pattern to rewrite a subview op with MemRefCast arguments.
3610 /// This essentially pushes memref_cast past its consuming subview when
3611 /// `canFoldIntoConsumerOp` is true.
3612 ///
3613 /// Example:
3614 /// ```
3615 ///   %0 = memref_cast %V : memref<16x16xf32> to memref<?x?xf32>
3616 ///   %1 = subview %0[0, 0][3, 4][1, 1] :
3617 ///     memref<?x?xf32> to memref<3x4xf32, offset:?, strides:[?, 1]>
3618 /// ```
3619 /// is rewritten into:
3620 /// ```
3621 ///   %0 = subview %V: memref<16x16xf32> to memref<3x4xf32, #[[map0]]>
3622 ///   %1 = memref_cast %0: memref<3x4xf32, offset:0, strides:[16, 1]> to
3623 ///     memref<3x4xf32, offset:?, strides:[?, 1]>
3624 /// ```
3625 class SubViewOpMemRefCastFolder final : public OpRewritePattern<SubViewOp> {
3626 public:
3627   using OpRewritePattern<SubViewOp>::OpRewritePattern;
3628 
3629   LogicalResult matchAndRewrite(SubViewOp subViewOp,
3630                                 PatternRewriter &rewriter) const override {
3631     // If any operand is constant, just return and let the constant-argument folder kick in.
3632     if (llvm::any_of(subViewOp.getOperands(), [](Value operand) {
3633           return matchPattern(operand, m_ConstantIndex());
3634         }))
3635       return failure();
3636 
3637     auto castOp = subViewOp.source().getDefiningOp<MemRefCastOp>();
3638     if (!castOp)
3639       return failure();
3640 
3641     if (!canFoldIntoConsumerOp(castOp))
3642       return failure();
3643 
3644     /// Deduce the resultType of the SubViewOp using `SubViewOp::inferResultType` on
3645     /// the cast source operand type and the SubViewOp static information. This
3646     /// is the resulting type if the MemRefCastOp were folded.
3647     Type resultType = SubViewOp::inferResultType(
3648         castOp.source().getType().cast<MemRefType>(),
3649         extractFromI64ArrayAttr(subViewOp.static_offsets()),
3650         extractFromI64ArrayAttr(subViewOp.static_sizes()),
3651         extractFromI64ArrayAttr(subViewOp.static_strides()));
3652     Value newSubView = rewriter.create<SubViewOp>(
3653         subViewOp.getLoc(), resultType, castOp.source(), subViewOp.offsets(),
3654         subViewOp.sizes(), subViewOp.strides(), subViewOp.static_offsets(),
3655         subViewOp.static_sizes(), subViewOp.static_strides());
3656     rewriter.replaceOpWithNewOp<MemRefCastOp>(subViewOp, subViewOp.getType(),
3657                                               newSubView);
3658     return success();
3659   }
3660 };
3661 } // namespace
3662 
3663 void SubViewOp::getCanonicalizationPatterns(OwningRewritePatternList &results,
3664                                             MLIRContext *context) {
3665   results.insert<OpWithOffsetSizesAndStridesConstantArgumentFolder<SubViewOp>,
3666                  SubViewOpMemRefCastFolder>(context);
3667 }
3668 
3669 OpFoldResult SubViewOp::fold(ArrayRef<Attribute> operands) {
3670   if (getResult().getType().cast<ShapedType>().getRank() == 0 &&
3671       source().getType().cast<ShapedType>().getRank() == 0)
3672     return getViewSource();
3673 
3674   return {};
3675 }
3676 
3677 //===----------------------------------------------------------------------===//
3678 // SubTensorOp
3679 //===----------------------------------------------------------------------===//
3680 
3681 /// Print a subtensor op of the form:
3682 /// ```
3683 ///   `subtensor` ssa-name
3684 ///     `[` offset-list `]` `[` size-list `]` `[` stride-list `]`
3685 ///     `:` ranked-tensor-type `to` ranked-tensor-type
3686 /// ```
3687 static void print(OpAsmPrinter &p, SubTensorOp op) {
3688   int stdDotLen = StandardOpsDialect::getDialectNamespace().size() + 1;
3689   p << op->getName().getStringRef().drop_front(stdDotLen) << ' ';
3690   p << op.source();
3691   printOffsetsSizesAndStrides(p, op);
3692   p << " : " << op.getSourceType() << " to " << op.getType();
3693 }
3694 
3695 /// Parse a subtensor op of the form:
3696 /// ```
3697 ///   `subtensor` ssa-name
3698 ///     `[` offset-list `]` `[` size-list `]` `[` stride-list `]`
3699 ///     `:` ranked-tensor-type `to` ranked-tensor-type
3700 /// ```
3701 static ParseResult parseSubTensorOp(OpAsmParser &parser,
3702                                     OperationState &result) {
3703   OpAsmParser::OperandType srcInfo;
3704   if (parser.parseOperand(srcInfo))
3705     return failure();
3706   Type srcType, dstType;
3707   auto preResolutionFn = [&](OpAsmParser &parser, OperationState &result) {
3708     return failure(parser.parseOptionalAttrDict(result.attributes) ||
3709                    parser.parseColonType(srcType) ||
3710                    parser.parseKeywordType("to", dstType) ||
3711                    parser.resolveOperand(srcInfo, srcType, result.operands));
3712   };
3713 
3714   if (failed(parseOffsetsSizesAndStrides(parser, result,
3715                                          /*segmentSizes=*/{1}, // source tensor
3716                                          preResolutionFn)))
3717     return failure();
3718   return parser.addTypeToList(dstType, result.types);
3719 }
3720 
3721 /// A subtensor result type can be fully inferred from the source type and the
3722 /// static representation of offsets, sizes and strides. Special sentinels
3723 /// encode the dynamic case.
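///
/// For example, a tensor<8x16x4xf32> source with staticSizes [4, 4, 4] infers
/// a tensor<4x4x4xf32> result, and a size entry equal to
/// ShapedType::kDynamicSize produces a `?` dimension instead.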
3724 Type SubTensorOp::inferResultType(RankedTensorType sourceRankedTensorType,
3725                                   ArrayRef<int64_t> staticOffsets,
3726                                   ArrayRef<int64_t> staticSizes,
3727                                   ArrayRef<int64_t> staticStrides) {
3728   unsigned rank = sourceRankedTensorType.getRank();
3729   (void)rank;
3730   assert(staticOffsets.size() == rank &&
3731          "unexpected staticOffsets size mismatch");
3732   assert(staticSizes.size() == rank && "unexpected staticSizes size mismatch");
3733   assert(staticStrides.size() == rank &&
3734          "unexpected staticStrides size mismatch");
3735   return RankedTensorType::get(staticSizes,
3736                                sourceRankedTensorType.getElementType());
3737 }
3738 
3739 void mlir::SubTensorOp::build(OpBuilder &b, OperationState &result,
3740                               Value source, ArrayRef<int64_t> staticOffsets,
3741                               ArrayRef<int64_t> staticSizes,
3742                               ArrayRef<int64_t> staticStrides,
3743                               ValueRange offsets, ValueRange sizes,
3744                               ValueRange strides,
3745                               ArrayRef<NamedAttribute> attrs) {
3746   auto sourceRankedTensorType = source.getType().cast<RankedTensorType>();
3747   auto resultType = inferResultType(sourceRankedTensorType, staticOffsets,
3748                                     staticSizes, staticStrides);
3749   build(b, result, resultType, source, offsets, sizes, strides,
3750         b.getI64ArrayAttr(staticOffsets), b.getI64ArrayAttr(staticSizes),
3751         b.getI64ArrayAttr(staticStrides));
3752   result.addAttributes(attrs);
3753 }
3754 
3755 /// Build a SubTensorOp with all dynamic entries: `staticOffsets`, `staticSizes`
3756 /// and `staticStrides` are automatically filled with sentinel values that
3757 /// encode dynamic entries.
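///
/// A minimal usage sketch (assuming index-typed SSA values `offsets`, `sizes`
/// and `strides` of source-tensor rank are already available):
///
///   b.create<SubTensorOp>(loc, source, offsets, sizes, strides,
///                         ArrayRef<NamedAttribute>{});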
3758 void mlir::SubTensorOp::build(OpBuilder &b, OperationState &result,
3759                               Value source, ValueRange offsets,
3760                               ValueRange sizes, ValueRange strides,
3761                               ArrayRef<NamedAttribute> attrs) {
3762   auto sourceRankedTensorType = source.getType().cast<RankedTensorType>();
3763   unsigned rank = sourceRankedTensorType.getRank();
3764   SmallVector<int64_t, 4> staticOffsetsVector(
3765       rank, ShapedType::kDynamicStrideOrOffset);
3766   SmallVector<int64_t, 4> staticSizesVector(rank, ShapedType::kDynamicSize);
3767   SmallVector<int64_t, 4> staticStridesVector(
3768       rank, ShapedType::kDynamicStrideOrOffset);
3769   build(b, result, source, staticOffsetsVector, staticSizesVector,
3770         staticStridesVector, offsets, sizes, strides, attrs);
3771 }
3772 
3773 /// Verifier for SubTensorOp.
3774 static LogicalResult verify(SubTensorOp op) {
3775   // Verify result type against inferred type.
3776   auto expectedType = SubTensorOp::inferResultType(
3777       op.getSourceType(), extractFromI64ArrayAttr(op.static_offsets()),
3778       extractFromI64ArrayAttr(op.static_sizes()),
3779       extractFromI64ArrayAttr(op.static_strides()));
3780   auto result = isRankReducedType(expectedType, op.getType());
3781   return produceSubViewErrorMsg(result, op, expectedType);
3782 }
3783 
3784 void SubTensorOp::getCanonicalizationPatterns(OwningRewritePatternList &results,
3785                                               MLIRContext *context) {
3786   results
3787       .insert<OpWithOffsetSizesAndStridesConstantArgumentFolder<SubTensorOp>>(
3788           context);
3789 }
3790 
3791 //===----------------------------------------------------------------------===//
3792 // SubTensorInsertOp
3793 //===----------------------------------------------------------------------===//
3794 
3795 /// Print a subtensor_insert op of the form:
3796 /// ```
3797 ///   `subtensor_insert` ssa-name `into` ssa-name
3798 ///     `[` offset-list `]` `[` size-list `]` `[` stride-list `]`
3799 ///     `:` ranked-tensor-type `into` ranked-tensor-type
3800 /// ```
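///
/// For example, an illustrative instance of this form:
/// ```
///   %2 = subtensor_insert %1 into %0[0, 8] [4, 4] [1, 1]
///     : tensor<4x4xf32> into tensor<8x16xf32>
/// ```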
3801 static void print(OpAsmPrinter &p, SubTensorInsertOp op) {
3802   int stdDotLen = StandardOpsDialect::getDialectNamespace().size() + 1;
3803   p << op->getName().getStringRef().drop_front(stdDotLen) << ' ';
3804   p << op.source() << " into " << op.dest();
3805   printOffsetsSizesAndStrides(p, op);
3806   p << " : " << op.getSourceType() << " into " << op.getType();
3807 }
3808 
3809 /// Parse a subtensor_insert op of the form:
3810 /// ```
3811 ///   `subtensor_insert` ssa-name `into` ssa-name
3812 ///     `[` offset-list `]` `[` size-list `]` `[` stride-list `]`
3813 ///     `:` ranked-tensor-type `into` ranked-tensor-type
3814 /// ```
3815 static ParseResult parseSubTensorInsertOp(OpAsmParser &parser,
3816                                           OperationState &result) {
3817   OpAsmParser::OperandType srcInfo, dstInfo;
3818   if (parser.parseOperand(srcInfo) || parser.parseKeyword("into") ||
3819       parser.parseOperand(dstInfo))
3820     return failure();
3821   Type srcType, dstType;
3822   auto preResolutionFn = [&](OpAsmParser &parser, OperationState &result) {
3823     return failure(parser.parseOptionalAttrDict(result.attributes) ||
3824                    parser.parseColonType(srcType) ||
3825                    parser.parseKeywordType("into", dstType) ||
3826                    parser.resolveOperand(srcInfo, srcType, result.operands) ||
3827                    parser.resolveOperand(dstInfo, dstType, result.operands));
3828   };
3829 
3830   if (failed(parseOffsetsSizesAndStrides(
3831           parser, result,
3832           /*segmentSizes=*/{1, 1}, // source tensor, destination tensor
3833           preResolutionFn)))
3834     return failure();
3835   return parser.addTypeToList(dstType, result.types);
3836 }
3837 
3838 void mlir::SubTensorInsertOp::build(
3839     OpBuilder &b, OperationState &result, Value source, Value dest,
3840     ArrayRef<int64_t> staticOffsets, ArrayRef<int64_t> staticSizes,
3841     ArrayRef<int64_t> staticStrides, ValueRange offsets, ValueRange sizes,
3842     ValueRange strides, ArrayRef<NamedAttribute> attrs) {
3843   build(b, result, dest.getType(), source, dest, offsets, sizes, strides,
3844         b.getI64ArrayAttr(staticOffsets), b.getI64ArrayAttr(staticSizes),
3845         b.getI64ArrayAttr(staticStrides));
3846   result.addAttributes(attrs);
3847 }
3848 
3849 /// Build a SubTensorInsertOp with all dynamic entries: `staticOffsets`,
3850 /// `staticSizes` and `staticStrides` are automatically filled with
3851 /// source-tensor-rank sentinel values that encode dynamic entries.
3852 void mlir::SubTensorInsertOp::build(OpBuilder &b, OperationState &result,
3853                                     Value source, Value dest,
3854                                     ValueRange offsets, ValueRange sizes,
3855                                     ValueRange strides,
3856                                     ArrayRef<NamedAttribute> attrs) {
3857   auto sourceRankedTensorType = source.getType().cast<RankedTensorType>();
3858   unsigned rank = sourceRankedTensorType.getRank();
3859   SmallVector<int64_t, 4> staticOffsetsVector(
3860       rank, ShapedType::kDynamicStrideOrOffset);
3861   SmallVector<int64_t, 4> staticSizesVector(rank, ShapedType::kDynamicSize);
3862   SmallVector<int64_t, 4> staticStridesVector(
3863       rank, ShapedType::kDynamicStrideOrOffset);
3864   build(b, result, source, dest, staticOffsetsVector, staticSizesVector,
3865         staticStridesVector, offsets, sizes, strides, attrs);
3866 }
3867 
3868 /// Verifier for SubTensorInsertOp.
3869 static LogicalResult verify(SubTensorInsertOp op) {
3870   if (op.getType() != op.dest().getType())
3871     return op.emitError("expected result type to be ") << op.dest().getType();
3872   return success();
3873 }
3874 
3875 //===----------------------------------------------------------------------===//
3876 // TensorCastOp
3877 //===----------------------------------------------------------------------===//
3878 
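/// Two tensor types are cast-compatible if their element types match and their
/// shapes are compatible, i.e. equal wherever both dimensions are static. For
/// example, tensor<?x4xf32> and tensor<2x4xf32> are cast-compatible, while
/// tensor<2x4xf32> and tensor<3x4xf32> are not.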
3879 bool TensorCastOp::areCastCompatible(Type a, Type b) {
3880   auto aT = a.dyn_cast<TensorType>();
3881   auto bT = b.dyn_cast<TensorType>();
3882   if (!aT || !bT)
3883     return false;
3884 
3885   if (aT.getElementType() != bT.getElementType())
3886     return false;
3887 
3888   return succeeded(verifyCompatibleShape(aT, bT));
3889 }
3890 
3891 OpFoldResult TensorCastOp::fold(ArrayRef<Attribute> operands) {
3892   return impl::foldCastOp(*this);
3893 }
3894 
3895 /// Compute a TensorType that has the joined shape knowledge of the two
3896 /// given TensorTypes. The element types need to match.
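/// For example, joining tensor<?x4xf32> with tensor<2x?xf32> yields
/// tensor<2x4xf32>, while tensor<2x4xf32> and tensor<3x4xf32> have no join and
/// a null type is returned.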
3897 static TensorType joinShapes(TensorType one, TensorType two) {
3898   assert(one.getElementType() == two.getElementType());
3899 
3900   if (!one.hasRank())
3901     return two;
3902   if (!two.hasRank())
3903     return one;
3904 
3905   int64_t rank = one.getRank();
3906   if (rank != two.getRank())
3907     return {};
3908 
3909   SmallVector<int64_t, 4> join;
3910   join.reserve(rank);
3911   for (int64_t i = 0; i < rank; ++i) {
3912     if (one.isDynamicDim(i)) {
3913       join.push_back(two.getDimSize(i));
3914       continue;
3915     }
3916     if (two.isDynamicDim(i)) {
3917       join.push_back(one.getDimSize(i));
3918       continue;
3919     }
3920     if (one.getDimSize(i) != two.getDimSize(i))
3921       return {};
3922     join.push_back(one.getDimSize(i));
3923   }
3924   return RankedTensorType::get(join, one.getElementType());
3925 }
3926 
3927 namespace {
3928 
3929 /// Replaces chains of two tensor_cast operations by a single tensor_cast
3930 /// operation if doing so does not remove runtime constraints.
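///
/// For example, casting tensor<2x?xf32> to tensor<?x?xf32> and then to
/// tensor<?x4xf32> folds to a single cast from tensor<2x?xf32> to
/// tensor<?x4xf32>. Casting tensor<?x?xf32> to tensor<4x4xf32> and back to
/// tensor<?x?xf32> is not folded, since that would drop the runtime 4x4 shape
/// check.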
3931 struct ChainedTensorCast : public OpRewritePattern<TensorCastOp> {
3932   using OpRewritePattern<TensorCastOp>::OpRewritePattern;
3933 
3934   LogicalResult matchAndRewrite(TensorCastOp tensorCast,
3935                                 PatternRewriter &rewriter) const final {
3936     auto tensorCastOperand =
3937         tensorCast.getOperand().getDefiningOp<TensorCastOp>();
3938 
3939     if (!tensorCastOperand)
3940       return failure();
3941 
3942     auto sourceType =
3943         tensorCastOperand.getOperand().getType().cast<TensorType>();
3944     auto intermediateType = tensorCastOperand.getType().cast<TensorType>();
3945     auto resultType = tensorCast.getType().cast<TensorType>();
3946 
3947     // We can remove the intermediate cast if joining all three produces the
3948     // same result as just joining the source and result shapes.
3949     auto firstJoin =
3950         joinShapes(joinShapes(sourceType, intermediateType), resultType);
3951 
3952     // The join might not exist if the cast sequence would fail at runtime.
3953     if (!firstJoin)
3954       return failure();
3955 
3956     // The newJoin always exists if the above join exists; it might just contain
3957     // less information. If so, we cannot drop the intermediate cast, as doing
3958     // so would remove runtime checks.
3959     auto newJoin = joinShapes(sourceType, resultType);
3960     if (firstJoin != newJoin)
3961       return failure();
3962 
3963     rewriter.replaceOpWithNewOp<TensorCastOp>(tensorCast, resultType,
3964                                               tensorCastOperand.getOperand());
3965     return success();
3966   }
3967 };
3968 
3969 } // namespace
3970 
3971 void TensorCastOp::getCanonicalizationPatterns(
3972     OwningRewritePatternList &results, MLIRContext *context) {
3973   results.insert<ChainedTensorCast>(context);
3974 }
3975 
3976 //===----------------------------------------------------------------------===//
3977 // TensorLoadOp
3978 //===----------------------------------------------------------------------===//
3979 
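/// Fold tensor_load(tensor_to_memref(x)) -> x.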
3980 OpFoldResult TensorLoadOp::fold(ArrayRef<Attribute>) {
3981   if (auto tensorToMemref = memref().getDefiningOp<TensorToMemrefOp>())
3982     return tensorToMemref.tensor();
3983   return {};
3984 }
3985 
3986 //===----------------------------------------------------------------------===//
3987 // TensorToMemrefOp
3988 //===----------------------------------------------------------------------===//
3989 
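/// Fold tensor_to_memref(tensor_load(m)) -> m when the memref types match.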
3990 OpFoldResult TensorToMemrefOp::fold(ArrayRef<Attribute>) {
3991   if (auto tensorLoad = tensor().getDefiningOp<TensorLoadOp>())
3992     if (tensorLoad.memref().getType() == getType())
3993       return tensorLoad.memref();
3994   return {};
3995 }
3996 
3997 //===----------------------------------------------------------------------===//
3998 // TransposeOp
3999 //===----------------------------------------------------------------------===//
4000 
4001 /// Build a strided memref type by applying `permutationMap` to `memRefType`.
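/// For example, applying the permutation map (d0, d1) -> (d1, d0) to
/// memref<4x8xf32> (identity layout, strides [8, 1]) produces an 8x4 memref
/// whose layout map is (d0, d1) -> (d1 * 8 + d0).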
4002 static MemRefType inferTransposeResultType(MemRefType memRefType,
4003                                            AffineMap permutationMap) {
4004   auto rank = memRefType.getRank();
4005   auto originalSizes = memRefType.getShape();
4006   // Compute permuted sizes.
4007   SmallVector<int64_t, 4> sizes(rank, 0);
4008   for (auto en : llvm::enumerate(permutationMap.getResults()))
4009     sizes[en.index()] =
4010         originalSizes[en.value().cast<AffineDimExpr>().getPosition()];
4011 
4012   // Compute permuted strides.
4013   int64_t offset;
4014   SmallVector<int64_t, 4> strides;
4015   auto res = getStridesAndOffset(memRefType, strides, offset);
4016   assert(succeeded(res) && strides.size() == static_cast<unsigned>(rank));
4017   (void)res;
4018   auto map =
4019       makeStridedLinearLayoutMap(strides, offset, memRefType.getContext());
4020   map = permutationMap ? map.compose(permutationMap) : map;
4021   return MemRefType::Builder(memRefType).setShape(sizes).setAffineMaps(map);
4022 }
4023 
4024 void TransposeOp::build(OpBuilder &b, OperationState &result, Value in,
4025                         AffineMapAttr permutation,
4026                         ArrayRef<NamedAttribute> attrs) {
4027   auto permutationMap = permutation.getValue();
4028   assert(permutationMap);
4029 
4030   auto memRefType = in.getType().cast<MemRefType>();
4031   // Compute result type.
4032   MemRefType resultType = inferTransposeResultType(memRefType, permutationMap);
4033 
4034   build(b, result, resultType, in, attrs);
4035   result.addAttribute(TransposeOp::getPermutationAttrName(), permutation);
4036 }
4037 
4038 // transpose $in $permutation attr-dict : type($in) `to` type(results)
4039 static void print(OpAsmPrinter &p, TransposeOp op) {
4040   p << "transpose " << op.in() << " " << op.permutation();
4041   p.printOptionalAttrDict(op.getAttrs(),
4042                           {TransposeOp::getPermutationAttrName()});
4043   p << " : " << op.in().getType() << " to " << op.getType();
4044 }
4045 
4046 static ParseResult parseTransposeOp(OpAsmParser &parser,
4047                                     OperationState &result) {
4048   OpAsmParser::OperandType in;
4049   AffineMap permutation;
4050   MemRefType srcType, dstType;
4051   if (parser.parseOperand(in) || parser.parseAffineMap(permutation) ||
4052       parser.parseOptionalAttrDict(result.attributes) ||
4053       parser.parseColonType(srcType) ||
4054       parser.resolveOperand(in, srcType, result.operands) ||
4055       parser.parseKeywordType("to", dstType) ||
4056       parser.addTypeToList(dstType, result.types))
4057     return failure();
4058 
4059   result.addAttribute(TransposeOp::getPermutationAttrName(),
4060                       AffineMapAttr::get(permutation));
4061   return success();
4062 }
4063 
4064 static LogicalResult verify(TransposeOp op) {
4065   if (!op.permutation().isPermutation())
4066     return op.emitOpError("expected a permutation map");
4067   if (op.permutation().getNumDims() != op.getShapedType().getRank())
4068     return op.emitOpError(
4069         "expected a permutation map of same rank as the input");
4070 
4071   auto srcType = op.in().getType().cast<MemRefType>();
4072   auto dstType = op.getType().cast<MemRefType>();
4073   auto transposedType = inferTransposeResultType(srcType, op.permutation());
4074   if (dstType != transposedType)
4075     return op.emitOpError("output type ")
4076            << dstType << " does not match transposed input type " << srcType
4077            << " (expected " << transposedType << ")";
4078   return success();
4079 }
4080 
4081 OpFoldResult TransposeOp::fold(ArrayRef<Attribute>) {
4082   if (succeeded(foldMemRefCast(*this)))
4083     return getResult();
4084   return {};
4085 }
4086 
4087 //===----------------------------------------------------------------------===//
4088 // TruncateIOp
4089 //===----------------------------------------------------------------------===//
4090 
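/// The operand and result must be integer (non-index) types, and the operand
/// must be strictly wider than the result; e.g. truncating i32 to i16 is
/// valid, while i16 to i32 is not.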
4091 static LogicalResult verify(TruncateIOp op) {
4092   auto srcType = getElementTypeOrSelf(op.getOperand().getType());
4093   auto dstType = getElementTypeOrSelf(op.getType());
4094 
4095   if (srcType.isa<IndexType>())
4096     return op.emitError() << srcType << " is not a valid operand type";
4097   if (dstType.isa<IndexType>())
4098     return op.emitError() << dstType << " is not a valid result type";
4099 
4100   if (srcType.cast<IntegerType>().getWidth() <=
4101       dstType.cast<IntegerType>().getWidth())
4102     return op.emitError("operand type ")
4103            << srcType << " must be wider than result type " << dstType;
4104 
4105   return success();
4106 }
4107 
4108 //===----------------------------------------------------------------------===//
4109 // UnsignedDivIOp
4110 //===----------------------------------------------------------------------===//
4111 
4112 OpFoldResult UnsignedDivIOp::fold(ArrayRef<Attribute> operands) {
4113   assert(operands.size() == 2 && "binary operation takes two operands");
4114 
4115   // Don't fold if it would require a division by zero.
4116   bool div0 = false;
4117   auto result = constFoldBinaryOp<IntegerAttr>(operands, [&](APInt a, APInt b) {
4118     if (div0 || !b) {
4119       div0 = true;
4120       return a;
4121     }
4122     return a.udiv(b);
4123   });
4124 
4125   // Fold out division by one. Assumes all tensors of all ones are splats.
4126   if (auto rhs = operands[1].dyn_cast_or_null<IntegerAttr>()) {
4127     if (rhs.getValue() == 1)
4128       return lhs();
4129   } else if (auto rhs = operands[1].dyn_cast_or_null<SplatElementsAttr>()) {
4130     if (rhs.getSplatValue<IntegerAttr>().getValue() == 1)
4131       return lhs();
4132   }
4133 
4134   return div0 ? Attribute() : result;
4135 }
4136 
4137 //===----------------------------------------------------------------------===//
4138 // UnsignedRemIOp
4139 //===----------------------------------------------------------------------===//
4140 
4141 OpFoldResult UnsignedRemIOp::fold(ArrayRef<Attribute> operands) {
4142   assert(operands.size() == 2 && "remi_unsigned takes two operands");
4143 
4144   auto rhs = operands.back().dyn_cast_or_null<IntegerAttr>();
4145   if (!rhs)
4146     return {};
4147   auto rhsValue = rhs.getValue();
4148 
4149   // x % 1 = 0
4150   if (rhsValue.isOneValue())
4151     return IntegerAttr::get(rhs.getType(), APInt(rhsValue.getBitWidth(), 0));
4152 
4153   // Don't fold if it requires division by zero.
4154   if (rhsValue.isNullValue())
4155     return {};
4156 
4157   auto lhs = operands.front().dyn_cast_or_null<IntegerAttr>();
4158   if (!lhs)
4159     return {};
4160   return IntegerAttr::get(lhs.getType(), lhs.getValue().urem(rhsValue));
4161 }
4162 
4163 //===----------------------------------------------------------------------===//
4164 // ViewOp
4165 //===----------------------------------------------------------------------===//
4166 
4167 static ParseResult parseViewOp(OpAsmParser &parser, OperationState &result) {
4168   OpAsmParser::OperandType srcInfo;
4169   SmallVector<OpAsmParser::OperandType, 1> offsetInfo;
4170   SmallVector<OpAsmParser::OperandType, 4> sizesInfo;
4171   auto indexType = parser.getBuilder().getIndexType();
4172   Type srcType, dstType;
4173   llvm::SMLoc offsetLoc;
4174   if (parser.parseOperand(srcInfo) || parser.getCurrentLocation(&offsetLoc) ||
4175       parser.parseOperandList(offsetInfo, OpAsmParser::Delimiter::Square))
4176     return failure();
4177 
4178   if (offsetInfo.size() != 1)
4179     return parser.emitError(offsetLoc) << "expects 1 offset operand";
4180 
4181   return failure(
4182       parser.parseOperandList(sizesInfo, OpAsmParser::Delimiter::Square) ||
4183       parser.parseOptionalAttrDict(result.attributes) ||
4184       parser.parseColonType(srcType) ||
4185       parser.resolveOperand(srcInfo, srcType, result.operands) ||
4186       parser.resolveOperands(offsetInfo, indexType, result.operands) ||
4187       parser.resolveOperands(sizesInfo, indexType, result.operands) ||
4188       parser.parseKeywordType("to", dstType) ||
4189       parser.addTypeToList(dstType, result.types));
4190 }
4191 
4192 static void print(OpAsmPrinter &p, ViewOp op) {
4193   p << op.getOperationName() << ' ' << op.getOperand(0) << '[';
4194   p.printOperand(op.byte_shift());
4195   p << "][" << op.sizes() << ']';
4196   p.printOptionalAttrDict(op.getAttrs());
4197   p << " : " << op.getOperand(0).getType() << " to " << op.getType();
4198 }
4199 
4200 static LogicalResult verify(ViewOp op) {
4201   auto baseType = op.getOperand(0).getType().cast<MemRefType>();
4202   auto viewType = op.getType();
4203 
4204   // The base memref should have identity layout map (or none).
4205   if (baseType.getAffineMaps().size() > 1 ||
4206       (baseType.getAffineMaps().size() == 1 &&
4207        !baseType.getAffineMaps()[0].isIdentity()))
4208     return op.emitError("unsupported map for base memref type ") << baseType;
4209 
4210   // The result memref should have identity layout map (or none).
4211   if (viewType.getAffineMaps().size() > 1 ||
4212       (viewType.getAffineMaps().size() == 1 &&
4213        !viewType.getAffineMaps()[0].isIdentity()))
4214     return op.emitError("unsupported map for result memref type ") << viewType;
4215 
4216   // The base memref and the view memref should be in the same memory space.
4217   if (baseType.getMemorySpace() != viewType.getMemorySpace())
4218     return op.emitError("different memory spaces specified for base memref "
4219                         "type ")
4220            << baseType << " and view memref type " << viewType;
4221 
4222   // Verify that we have the correct number of sizes for the result type.
4223   unsigned numDynamicDims = viewType.getNumDynamicDims();
4224   if (op.sizes().size() != numDynamicDims)
4225     return op.emitError("incorrect number of size operands for type ")
4226            << viewType;
4227 
4228   return success();
4229 }
4230 
4231 Value ViewOp::getViewSource() { return source(); }
4232 
4233 namespace {
4234 
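/// Folds dynamic size operands of a ViewOp that are produced by constants into
/// the result memref type, then casts the new view back to the original result
/// type with a memref_cast.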
4235 struct ViewOpShapeFolder : public OpRewritePattern<ViewOp> {
4236   using OpRewritePattern<ViewOp>::OpRewritePattern;
4237 
4238   LogicalResult matchAndRewrite(ViewOp viewOp,
4239                                 PatternRewriter &rewriter) const override {
4240     // Return if none of the operands are constants.
4241     if (llvm::none_of(viewOp.getOperands(), [](Value operand) {
4242           return matchPattern(operand, m_ConstantIndex());
4243         }))
4244       return failure();
4245 
4246     // Get result memref type.
4247     auto memrefType = viewOp.getType();
4248 
4249     // Get offset from old memref view type 'memrefType'.
4250     int64_t oldOffset;
4251     SmallVector<int64_t, 4> oldStrides;
4252     if (failed(getStridesAndOffset(memrefType, oldStrides, oldOffset)))
4253       return failure();
4254     assert(oldOffset == 0 && "Expected 0 offset");
4255 
4256     SmallVector<Value, 4> newOperands;
4257 
4258     // Offset cannot be folded into result type.
4259 
4260     // Fold any dynamic dim operands which are produced by a constant.
4261     SmallVector<int64_t, 4> newShapeConstants;
4262     newShapeConstants.reserve(memrefType.getRank());
4263 
4264     unsigned dynamicDimPos = 0;
4265     unsigned rank = memrefType.getRank();
4266     for (unsigned dim = 0, e = rank; dim < e; ++dim) {
4267       int64_t dimSize = memrefType.getDimSize(dim);
4268       // If this is already a static dimension, keep it.
4269       if (!ShapedType::isDynamic(dimSize)) {
4270         newShapeConstants.push_back(dimSize);
4271         continue;
4272       }
4273       auto *defOp = viewOp.sizes()[dynamicDimPos].getDefiningOp();
4274       if (auto constantIndexOp = dyn_cast_or_null<ConstantIndexOp>(defOp)) {
4275         // Dynamic shape dimension will be folded.
4276         newShapeConstants.push_back(constantIndexOp.getValue());
4277       } else {
4278         // Dynamic shape dimension not folded; copy operand from old memref.
4279         newShapeConstants.push_back(dimSize);
4280         newOperands.push_back(viewOp.sizes()[dynamicDimPos]);
4281       }
4282       dynamicDimPos++;
4283     }
4284 
4285     // Create new memref type with constant folded dims.
4286     MemRefType newMemRefType =
4287         MemRefType::Builder(memrefType).setShape(newShapeConstants);
4288     // Nothing new, don't fold.
4289     if (newMemRefType == memrefType)
4290       return failure();
4291 
4292     // Create new ViewOp.
4293     auto newViewOp = rewriter.create<ViewOp>(viewOp.getLoc(), newMemRefType,
4294                                              viewOp.getOperand(0),
4295                                              viewOp.byte_shift(), newOperands);
4296     // Insert a cast so we have the same type as the old memref type.
4297     rewriter.replaceOpWithNewOp<MemRefCastOp>(viewOp, newViewOp,
4298                                               viewOp.getType());
4299     return success();
4300   }
4301 };
4302 
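/// Folds view(memref_cast(alloc)) by rebuilding the view directly on the
/// result of the alloc.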
4303 struct ViewOpMemrefCastFolder : public OpRewritePattern<ViewOp> {
4304   using OpRewritePattern<ViewOp>::OpRewritePattern;
4305 
4306   LogicalResult matchAndRewrite(ViewOp viewOp,
4307                                 PatternRewriter &rewriter) const override {
4308     Value memrefOperand = viewOp.getOperand(0);
4309     MemRefCastOp memrefCastOp = memrefOperand.getDefiningOp<MemRefCastOp>();
4310     if (!memrefCastOp)
4311       return failure();
4312     Value allocOperand = memrefCastOp.getOperand();
4313     AllocOp allocOp = allocOperand.getDefiningOp<AllocOp>();
4314     if (!allocOp)
4315       return failure();
4316     rewriter.replaceOpWithNewOp<ViewOp>(viewOp, viewOp.getType(), allocOperand,
4317                                         viewOp.byte_shift(), viewOp.sizes());
4318     return success();
4319   }
4320 };
4321 
4322 } // end anonymous namespace
4323 
4324 void ViewOp::getCanonicalizationPatterns(OwningRewritePatternList &results,
4325                                          MLIRContext *context) {
4326   results.insert<ViewOpShapeFolder, ViewOpMemrefCastFolder>(context);
4327 }
4328 
4329 //===----------------------------------------------------------------------===//
4330 // XOrOp
4331 //===----------------------------------------------------------------------===//
4332 
4333 OpFoldResult XOrOp::fold(ArrayRef<Attribute> operands) {
4334   /// xor(x, 0) -> x
4335   if (matchPattern(rhs(), m_Zero()))
4336     return lhs();
4337   /// xor(x,x) -> 0
4338   if (lhs() == rhs())
4339     return Builder(getContext()).getZeroAttr(getType());
4340 
4341   return constFoldBinaryOp<IntegerAttr>(operands,
4342                                         [](APInt a, APInt b) { return a ^ b; });
4343 }
4344 
4345 //===----------------------------------------------------------------------===//
4346 // ZeroExtendIOp
4347 //===----------------------------------------------------------------------===//
4348 
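/// The operand and result must be integer (non-index) types, and the result
/// must be strictly wider than the operand; e.g. zero-extending i16 to i32 is
/// valid, while i32 to i16 is not.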
4349 static LogicalResult verify(ZeroExtendIOp op) {
4350   auto srcType = getElementTypeOrSelf(op.getOperand().getType());
4351   auto dstType = getElementTypeOrSelf(op.getType());
4352 
4353   if (srcType.isa<IndexType>())
4354     return op.emitError() << srcType << " is not a valid operand type";
4355   if (dstType.isa<IndexType>())
4356     return op.emitError() << dstType << " is not a valid result type";
4357 
4358   if (srcType.cast<IntegerType>().getWidth() >=
4359       dstType.cast<IntegerType>().getWidth())
4360     return op.emitError("result type ")
4361            << dstType << " must be wider than operand type " << srcType;
4362 
4363   return success();
4364 }
4365 
4366 //===----------------------------------------------------------------------===//
4367 // TableGen'd op method definitions
4368 //===----------------------------------------------------------------------===//
4369 
4370 #define GET_OP_CLASSES
4371 #include "mlir/Dialect/StandardOps/IR/Ops.cpp.inc"
4372