1 //===- TosaOps.cpp - MLIR Dialect for TOSA --------------------------------===//
2 //
3 // Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions.
4 // See https://llvm.org/LICENSE.txt for license information.
5 // SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception
6 //
7 //===----------------------------------------------------------------------===//
8 //
9 // \file
10 // This file implements the TOSA Specification:
11 // https://developer.mlplatform.org/w/tosa/
12 //
13 //===----------------------------------------------------------------------===//
14
15 #include "mlir/Dialect/Tosa/IR/TosaOps.h"
16 #include "mlir/Dialect/StandardOps/IR/Ops.h"
17 #include "mlir/Dialect/Tosa/Utils/QuantUtils.h"
18 #include "mlir/IR/BuiltinTypes.h"
19 #include "mlir/Transforms/FoldUtils.h"
20 #include "mlir/Transforms/InliningUtils.h"
21 #include "mlir/Transforms/RegionUtils.h"
22
23 using namespace mlir;
24 using namespace mlir::tosa;
25
26 //===----------------------------------------------------------------------===//
27 // Tosa dialect structs and interface includes.
28 //===----------------------------------------------------------------------===//
29 #include "mlir/Dialect/Tosa/IR/TosaInterfaces.cpp.inc"
30 #include "mlir/Dialect/Tosa/IR/TosaStructs.cpp.inc"
31
32 namespace {
33 //===----------------------------------------------------------------------===//
34 // Dialect Function Inliner Interface.
35 //===----------------------------------------------------------------------===//
36 struct TosaInlinerInterface : public DialectInlinerInterface {
37 using DialectInlinerInterface::DialectInlinerInterface;
38
39 //===--------------------------------------------------------------------===//
40 // Analysis Hooks.
41 //===--------------------------------------------------------------------===//
42
43 /// All operations can be inlined by default.
isLegalToInline__anon6494a0970111::TosaInlinerInterface44 bool isLegalToInline(Operation *op, Region *region, bool wouldBeCloned,
45 BlockAndValueMapping &map) const final {
46 return true;
47 }
48
49 /// All regions with If and While parent operators can be inlined.
isLegalToInline__anon6494a0970111::TosaInlinerInterface50 bool isLegalToInline(Region *dest, Region *src, bool wouldBeCloned,
51 BlockAndValueMapping &map) const final {
52 return (isa<tosa::IfOp>(dest->getParentOp()) ||
53 isa<tosa::WhileOp>(dest->getParentOp()));
54 }
55 };
56 } // end anonymous namespace
57
58 //===----------------------------------------------------------------------===//
59 // TOSA control flow support.
60 //===----------------------------------------------------------------------===//
61
62 /// Returns the while loop body.
getLoopBody()63 Region &tosa::WhileOp::getLoopBody() { return body(); }
64
isDefinedOutsideOfLoop(Value value)65 bool tosa::WhileOp::isDefinedOutsideOfLoop(Value value) {
66 return !body().isAncestor(value.getParentRegion());
67 }
68
moveOutOfLoop(ArrayRef<mlir::Operation * > ops)69 LogicalResult WhileOp::moveOutOfLoop(ArrayRef<mlir::Operation *> ops) {
70 if (ops.empty())
71 return success();
72
73 Operation *tosaWhileOp = this->getOperation();
74 for (auto *op : ops)
75 op->moveBefore(tosaWhileOp);
76
77 return success();
78 }
79
80 //===----------------------------------------------------------------------===//
81 // Tosa dialect initialization.
82 //===----------------------------------------------------------------------===//
83
/// Populates the TOSA dialect: registers the operations listed by the
/// generated TosaOps.cpp.inc include, then registers the dialect's inliner
/// interface (TosaInlinerInterface).
void TosaDialect::initialize() {
  // The include below is expanded in place as the template argument list of
  // addOperations. GET_OP_LIST is defined right before it, presumably to
  // select the op-list section of the generated file.
  addOperations<
#define GET_OP_LIST
#include "mlir/Dialect/Tosa/IR/TosaOps.cpp.inc"
      >();
  addInterfaces<TosaInlinerInterface>();
}
91
materializeConstant(OpBuilder & builder,Attribute value,Type type,Location loc)92 Operation *TosaDialect::materializeConstant(OpBuilder &builder, Attribute value,
93 Type type, Location loc) {
94 // Tosa dialect constants only support ElementsAttr unlike standard dialect
95 // constant which supports all attributes.
96 if (value.isa<ElementsAttr>())
97 return builder.create<tosa::ConstOp>(loc, type, value.cast<ElementsAttr>());
98 return nullptr;
99 }
100
101 //===----------------------------------------------------------------------===//
102 // Operator Folders.
103 //===----------------------------------------------------------------------===//
104
fold(ArrayRef<Attribute> operands)105 OpFoldResult ConstOp::fold(ArrayRef<Attribute> operands) {
106 assert(operands.empty() && "constant has no operands");
107 return valueAttr();
108 }
109
110 //===----------------------------------------------------------------------===//
111 // TOSA Operator Verifiers.
112 //===----------------------------------------------------------------------===//
113
114 template <typename T>
verifyConvOp(T op)115 static LogicalResult verifyConvOp(T op) {
116 // All TOSA conv ops have an input() and weight().
117 auto inputType = op.input().getType().template dyn_cast<RankedTensorType>();
118 auto weightType = op.weight().getType().template dyn_cast<RankedTensorType>();
119
120 // Must be ranked tensor types
121 if (!inputType || !weightType)
122 return failure();
123
124 auto inputQType =
125 inputType.getElementType().template isa<mlir::quant::QuantizedType>();
126 auto weightQType =
127 weightType.getElementType().template isa<mlir::quant::QuantizedType>();
128
129 // Either both must be quantized or both unquantized.
130 if (inputQType != weightQType)
131 return failure();
132
133 // Quantized type must have constructed the quantizationattr, and unquantized
134 // types should not have a quantizationattr.
135 if ((inputQType && !op.quantization_info()) ||
136 (!inputQType && op.quantization_info()))
137 return failure();
138
139 return success();
140 }
141
142 //===----------------------------------------------------------------------===//
143 // TOSA Operator Quantization Builders.
144 //===----------------------------------------------------------------------===//
145
146 /// This builder is called on all convolution operators except TransposeConv,
147 /// which has specialized output shape semantics. The builder also defines the
148 /// bitwidth of the output given the bit width of the input & weight content.
buildConvOpWithQuantInfo(OpBuilder & builder,OperationState & result,Type outputType,Value input,Value weight,Value bias,ArrayAttr pad,ArrayAttr stride,ArrayAttr dilation)149 static void buildConvOpWithQuantInfo(OpBuilder &builder, OperationState &result,
150 Type outputType, Value input, Value weight,
151 Value bias, ArrayAttr pad,
152 ArrayAttr stride, ArrayAttr dilation) {
153
154 result.addOperands({input, weight, bias});
155 result.addAttribute("pad", pad);
156 result.addAttribute("stride", stride);
157 result.addAttribute("dilation", dilation);
158
159 auto quantAttr = buildConvOpQuantizationAttr(builder, input, weight);
160 if (quantAttr) {
161 result.addAttribute("quantization_info", quantAttr);
162 result.addTypes(
163 buildConvOpResultTypeInfo(builder, outputType, input, weight));
164 } else {
165 result.addTypes(outputType);
166 }
167 }
168
169 /// Handles tosa.transpose_conv2d which has outpad and output shape attributes.
170 static void
buildTransConvOpWithQuantInfo(OpBuilder & builder,OperationState & result,Type outputType,Value input,Value weight,Value bias,ArrayAttr outpad,ArrayAttr stride,ArrayAttr dilation,ArrayAttr outputShape)171 buildTransConvOpWithQuantInfo(OpBuilder &builder, OperationState &result,
172 Type outputType, Value input, Value weight,
173 Value bias, ArrayAttr outpad, ArrayAttr stride,
174 ArrayAttr dilation, ArrayAttr outputShape) {
175 result.addOperands({input, weight, bias});
176 result.addAttribute("out_pad", outpad);
177 result.addAttribute("stride", stride);
178 result.addAttribute("dilation", dilation);
179 result.addAttribute("out_shape", outputShape);
180 auto quantAttr = ::buildConvOpQuantizationAttr(builder, input, weight);
181
182 if (quantAttr) {
183 result.addAttribute("quantization_info", quantAttr);
184 result.addTypes(
185 buildConvOpResultTypeInfo(builder, outputType, input, weight));
186 } else {
187 result.addTypes(outputType);
188 }
189 }
190
191 /// The tosa.fully_connected op has its own builder as it does not have
192 /// strides/dilation/padding.
buildFCOpWithQuantInfo(OpBuilder & builder,OperationState & result,Type outputType,Value input,Value weight,Value bias)193 static void buildFCOpWithQuantInfo(OpBuilder &builder, OperationState &result,
194 Type outputType, Value input, Value weight,
195 Value bias) {
196
197 result.addOperands({input, weight, bias});
198 auto quantAttr = ::buildConvOpQuantizationAttr(builder, input, weight);
199 if (quantAttr) {
200 result.addAttribute("quantization_info", quantAttr);
201 result.addTypes(
202 buildConvOpResultTypeInfo(builder, outputType, input, weight));
203 } else {
204 result.addTypes(outputType);
205 }
206 }
207
208 /// The tosa.matmul op is also intended to be generated where a fully_connected
209 /// op must be constructed where the weight is not a constant. In this case,
210 /// the fully_connected op must be expressed using matmul.
211 /// TODO: Add link to the leglization document explaining this.
buildMatMulOpWithQuantInfo(OpBuilder & builder,OperationState & result,Type outputType,Value a,Value b)212 static void buildMatMulOpWithQuantInfo(OpBuilder &builder,
213 OperationState &result, Type outputType,
214 Value a, Value b) {
215 result.addOperands({a, b});
216 auto quantAttr = ::buildMatMulOpQuantizationAttr(builder, a, b);
217
218 if (quantAttr) {
219 result.addAttribute("quantization_info", quantAttr);
220
221 auto inputType = a.getType().dyn_cast<RankedTensorType>();
222 assert(inputType && "Input must be a ranked tensor type!");
223
224 auto inputQType = inputType.getElementType()
225 .dyn_cast<mlir::quant::UniformQuantizedType>();
226 assert(inputQType && "Tensor must have quantized datatype!");
227
228 unsigned inputBits = inputQType.getStorageTypeIntegralWidth();
229
230 auto outputShapedType = outputType.dyn_cast<RankedTensorType>();
231 assert(outputShapedType && "Output must be a ranked tensor type");
232
233 auto outputShape = outputShapedType.getShape();
234
235 IntegerType accElementType;
236 if (inputBits == 16)
237 accElementType = builder.getIntegerType(48);
238 else
239 accElementType = builder.getI32Type();
240 auto accType = RankedTensorType::get(outputShape, accElementType);
241 result.addTypes(accType);
242 } else {
243 result.addTypes(outputType);
244 }
245 }
246
247 /// Both the tosa.avg_pool2d and unary ops use the same UnaruOpQuantizationAttr
248 /// but avg_pool operator has its own builder as it has additional parameters
249 /// not part of the unary ops.
buildAvgPool2dOpWithQuantInfo(OpBuilder & builder,OperationState & result,Type outputType,Value input,ArrayAttr kernel,ArrayAttr stride,ArrayAttr pad)250 static void buildAvgPool2dOpWithQuantInfo(OpBuilder &builder,
251 OperationState &result,
252 Type outputType, Value input,
253 ArrayAttr kernel, ArrayAttr stride,
254 ArrayAttr pad) {
255 result.addOperands(input);
256 result.addAttribute("kernel", kernel);
257 result.addAttribute("stride", stride);
258 result.addAttribute("pad", pad);
259 auto quantAttr = buildUnaryOpQuantizationAttr(builder, input, outputType);
260 if (quantAttr)
261 result.addAttribute("quantization_info", quantAttr);
262 result.types.push_back(outputType);
263 }
264
265 /// This builder is called on single-parameter unary operators that have scale
266 /// relationship between their input and output, expressed by the
267 /// UnaryOpQuantizationAttr.
buildUnaryOpWithQuantInfo(OpBuilder & builder,OperationState & result,Type outputType,Value input)268 static void buildUnaryOpWithQuantInfo(OpBuilder &builder,
269 OperationState &result, Type outputType,
270 Value input) {
271 result.addOperands(input);
272 auto quantAttr = buildUnaryOpQuantizationAttr(builder, input, outputType);
273 if (quantAttr)
274 result.addAttribute("quantization_info", quantAttr);
275 result.types.push_back(outputType);
276 }
277
278 /// This builder is called on TOSA pad operator that needs to create its own
279 /// OptionalAttr quantization_attr parameter to scale the padding values
280 /// correctly.
buildPadOpWithQuantInfo(OpBuilder & builder,OperationState & result,Type outputType,Value input,Value paddings)281 static void buildPadOpWithQuantInfo(OpBuilder &builder, OperationState &result,
282 Type outputType, Value input,
283 Value paddings) {
284 result.addOperands({input, paddings});
285 auto quantAttr = buildPadOpQuantizationAttr(builder, input);
286 if (quantAttr)
287 result.addAttribute("quantization_info", quantAttr);
288 result.types.push_back(outputType);
289 }
290
291 //===----------------------------------------------------------------------===//
292 // TOSA Operator Definitions.
293 //===----------------------------------------------------------------------===//
294
295 #define GET_OP_CLASSES
296 #include "mlir/Dialect/Tosa/IR/TosaOps.cpp.inc"
297