1 //===- ViewLikeInterface.cpp - View-like operations in MLIR ---------------===//
2 //
3 // Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions.
4 // See https://llvm.org/LICENSE.txt for license information.
5 // SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception
6 //
7 //===----------------------------------------------------------------------===//
8 
9 #include "mlir/Interfaces/ViewLikeInterface.h"
10 
11 using namespace mlir;
12 
13 //===----------------------------------------------------------------------===//
14 // ViewLike Interfaces
15 //===----------------------------------------------------------------------===//
16 
/// Include the definitions of the view-like interfaces.
18 #include "mlir/Interfaces/ViewLikeInterface.cpp.inc"
19 
verifyOpWithOffsetSizesAndStridesPart(OffsetSizeAndStrideOpInterface op,StringRef name,unsigned expectedNumElements,StringRef attrName,ArrayAttr attr,llvm::function_ref<bool (int64_t)> isDynamic,ValueRange values)20 static LogicalResult verifyOpWithOffsetSizesAndStridesPart(
21     OffsetSizeAndStrideOpInterface op, StringRef name,
22     unsigned expectedNumElements, StringRef attrName, ArrayAttr attr,
23     llvm::function_ref<bool(int64_t)> isDynamic, ValueRange values) {
24   /// Check static and dynamic offsets/sizes/strides breakdown.
25   if (attr.size() != expectedNumElements)
26     return op.emitError("expected ")
27            << expectedNumElements << " " << name << " values";
28   unsigned expectedNumDynamicEntries =
29       llvm::count_if(attr.getValue(), [&](Attribute attr) {
30         return isDynamic(attr.cast<IntegerAttr>().getInt());
31       });
32   if (values.size() != expectedNumDynamicEntries)
33     return op.emitError("expected ")
34            << expectedNumDynamicEntries << " dynamic " << name << " values";
35   return success();
36 }
37 
verify(OffsetSizeAndStrideOpInterface op)38 LogicalResult mlir::verify(OffsetSizeAndStrideOpInterface op) {
39   std::array<unsigned, 3> ranks = op.getArrayAttrRanks();
40   if (failed(verifyOpWithOffsetSizesAndStridesPart(
41           op, "offset", ranks[0],
42           OffsetSizeAndStrideOpInterface::getStaticOffsetsAttrName(),
43           op.static_offsets(), ShapedType::isDynamicStrideOrOffset,
44           op.offsets())))
45     return failure();
46   if (failed(verifyOpWithOffsetSizesAndStridesPart(
47           op, "size", ranks[1],
48           OffsetSizeAndStrideOpInterface::getStaticSizesAttrName(),
49           op.static_sizes(), ShapedType::isDynamic, op.sizes())))
50     return failure();
51   if (failed(verifyOpWithOffsetSizesAndStridesPart(
52           op, "stride", ranks[2],
53           OffsetSizeAndStrideOpInterface::getStaticStridesAttrName(),
54           op.static_strides(), ShapedType::isDynamicStrideOrOffset,
55           op.strides())))
56     return failure();
57   return success();
58 }
59 
60 /// Print a list with either (1) the static integer value in `arrayAttr` if
61 /// `isDynamic` evaluates to false or (2) the next value otherwise.
62 /// This allows idiomatic printing of mixed value and integer attributes in a
63 /// list. E.g. `[%arg0, 7, 42, %arg42]`.
64 static void
printListOfOperandsOrIntegers(OpAsmPrinter & p,ValueRange values,ArrayAttr arrayAttr,llvm::function_ref<bool (int64_t)> isDynamic)65 printListOfOperandsOrIntegers(OpAsmPrinter &p, ValueRange values,
66                               ArrayAttr arrayAttr,
67                               llvm::function_ref<bool(int64_t)> isDynamic) {
68   p << '[';
69   unsigned idx = 0;
70   llvm::interleaveComma(arrayAttr, p, [&](Attribute a) {
71     int64_t val = a.cast<IntegerAttr>().getInt();
72     if (isDynamic(val))
73       p << values[idx++];
74     else
75       p << val;
76   });
77   p << ']';
78 }
79 
printOffsetsSizesAndStrides(OpAsmPrinter & p,OffsetSizeAndStrideOpInterface op,StringRef offsetPrefix,StringRef sizePrefix,StringRef stridePrefix,ArrayRef<StringRef> elidedAttrs)80 void mlir::printOffsetsSizesAndStrides(OpAsmPrinter &p,
81                                        OffsetSizeAndStrideOpInterface op,
82                                        StringRef offsetPrefix,
83                                        StringRef sizePrefix,
84                                        StringRef stridePrefix,
85                                        ArrayRef<StringRef> elidedAttrs) {
86   p << offsetPrefix;
87   printListOfOperandsOrIntegers(p, op.offsets(), op.static_offsets(),
88                                 ShapedType::isDynamicStrideOrOffset);
89   p << sizePrefix;
90   printListOfOperandsOrIntegers(p, op.sizes(), op.static_sizes(),
91                                 ShapedType::isDynamic);
92   p << stridePrefix;
93   printListOfOperandsOrIntegers(p, op.strides(), op.static_strides(),
94                                 ShapedType::isDynamicStrideOrOffset);
95   p.printOptionalAttrDict(op.getAttrs(), elidedAttrs);
96 }
97 
98 /// Parse a mixed list with either (1) static integer values or (2) SSA values.
99 /// Fill `result` with the integer ArrayAttr named `attrName` where `dynVal`
100 /// encode the position of SSA values. Add the parsed SSA values to `ssa`
101 /// in-order.
102 //
103 /// E.g. after parsing "[%arg0, 7, 42, %arg42]":
104 ///   1. `result` is filled with the i64 ArrayAttr "[`dynVal`, 7, 42, `dynVal`]"
105 ///   2. `ssa` is filled with "[%arg0, %arg1]".
106 static ParseResult
parseListOfOperandsOrIntegers(OpAsmParser & parser,OperationState & result,StringRef attrName,int64_t dynVal,SmallVectorImpl<OpAsmParser::OperandType> & ssa)107 parseListOfOperandsOrIntegers(OpAsmParser &parser, OperationState &result,
108                               StringRef attrName, int64_t dynVal,
109                               SmallVectorImpl<OpAsmParser::OperandType> &ssa) {
110   if (failed(parser.parseLSquare()))
111     return failure();
112   // 0-D.
113   if (succeeded(parser.parseOptionalRSquare())) {
114     result.addAttribute(attrName, parser.getBuilder().getArrayAttr({}));
115     return success();
116   }
117 
118   SmallVector<int64_t, 4> attrVals;
119   while (true) {
120     OpAsmParser::OperandType operand;
121     auto res = parser.parseOptionalOperand(operand);
122     if (res.hasValue() && succeeded(res.getValue())) {
123       ssa.push_back(operand);
124       attrVals.push_back(dynVal);
125     } else {
126       IntegerAttr attr;
127       if (failed(parser.parseAttribute<IntegerAttr>(attr)))
128         return parser.emitError(parser.getNameLoc())
129                << "expected SSA value or integer";
130       attrVals.push_back(attr.getInt());
131     }
132 
133     if (succeeded(parser.parseOptionalComma()))
134       continue;
135     if (failed(parser.parseRSquare()))
136       return failure();
137     break;
138   }
139 
140   auto arrayAttr = parser.getBuilder().getI64ArrayAttr(attrVals);
141   result.addAttribute(attrName, arrayAttr);
142   return success();
143 }
144 
parseOffsetsSizesAndStrides(OpAsmParser & parser,OperationState & result,ArrayRef<int> segmentSizes,llvm::function_ref<ParseResult (OpAsmParser &)> parseOptionalOffsetPrefix,llvm::function_ref<ParseResult (OpAsmParser &)> parseOptionalSizePrefix,llvm::function_ref<ParseResult (OpAsmParser &)> parseOptionalStridePrefix)145 ParseResult mlir::parseOffsetsSizesAndStrides(
146     OpAsmParser &parser, OperationState &result, ArrayRef<int> segmentSizes,
147     llvm::function_ref<ParseResult(OpAsmParser &)> parseOptionalOffsetPrefix,
148     llvm::function_ref<ParseResult(OpAsmParser &)> parseOptionalSizePrefix,
149     llvm::function_ref<ParseResult(OpAsmParser &)> parseOptionalStridePrefix) {
150   return parseOffsetsSizesAndStrides(
151       parser, result, segmentSizes, nullptr, parseOptionalOffsetPrefix,
152       parseOptionalSizePrefix, parseOptionalStridePrefix);
153 }
154 
parseOffsetsSizesAndStrides(OpAsmParser & parser,OperationState & result,ArrayRef<int> segmentSizes,llvm::function_ref<ParseResult (OpAsmParser &,OperationState &)> preResolutionFn,llvm::function_ref<ParseResult (OpAsmParser &)> parseOptionalOffsetPrefix,llvm::function_ref<ParseResult (OpAsmParser &)> parseOptionalSizePrefix,llvm::function_ref<ParseResult (OpAsmParser &)> parseOptionalStridePrefix)155 ParseResult mlir::parseOffsetsSizesAndStrides(
156     OpAsmParser &parser, OperationState &result, ArrayRef<int> segmentSizes,
157     llvm::function_ref<ParseResult(OpAsmParser &, OperationState &)>
158         preResolutionFn,
159     llvm::function_ref<ParseResult(OpAsmParser &)> parseOptionalOffsetPrefix,
160     llvm::function_ref<ParseResult(OpAsmParser &)> parseOptionalSizePrefix,
161     llvm::function_ref<ParseResult(OpAsmParser &)> parseOptionalStridePrefix) {
162   SmallVector<OpAsmParser::OperandType, 4> offsetsInfo, sizesInfo, stridesInfo;
163   auto indexType = parser.getBuilder().getIndexType();
164   if ((parseOptionalOffsetPrefix && parseOptionalOffsetPrefix(parser)) ||
165       parseListOfOperandsOrIntegers(
166           parser, result,
167           OffsetSizeAndStrideOpInterface::getStaticOffsetsAttrName(),
168           ShapedType::kDynamicStrideOrOffset, offsetsInfo) ||
169       (parseOptionalSizePrefix && parseOptionalSizePrefix(parser)) ||
170       parseListOfOperandsOrIntegers(
171           parser, result,
172           OffsetSizeAndStrideOpInterface::getStaticSizesAttrName(),
173           ShapedType::kDynamicSize, sizesInfo) ||
174       (parseOptionalStridePrefix && parseOptionalStridePrefix(parser)) ||
175       parseListOfOperandsOrIntegers(
176           parser, result,
177           OffsetSizeAndStrideOpInterface::getStaticStridesAttrName(),
178           ShapedType::kDynamicStrideOrOffset, stridesInfo))
179     return failure();
180   // Add segment sizes to result
181   SmallVector<int, 4> segmentSizesFinal(segmentSizes.begin(),
182                                         segmentSizes.end());
183   segmentSizesFinal.append({static_cast<int>(offsetsInfo.size()),
184                             static_cast<int>(sizesInfo.size()),
185                             static_cast<int>(stridesInfo.size())});
186   result.addAttribute(
187       OpTrait::AttrSizedOperandSegments<void>::getOperandSegmentSizeAttr(),
188       parser.getBuilder().getI32VectorAttr(segmentSizesFinal));
189   return failure(
190       (preResolutionFn && preResolutionFn(parser, result)) ||
191       parser.resolveOperands(offsetsInfo, indexType, result.operands) ||
192       parser.resolveOperands(sizesInfo, indexType, result.operands) ||
193       parser.resolveOperands(stridesInfo, indexType, result.operands));
194 }
195