1 //===- TosaOps.cpp - MLIR Dialect for TOSA --------------------------------===//
2 //
3 // Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions.
4 // See https://llvm.org/LICENSE.txt for license information.
5 // SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception
6 //
7 //===----------------------------------------------------------------------===//
8 //
9 // \file
10 // This file implements the TOSA Specification:
11 // https://developer.mlplatform.org/w/tosa/
12 //
13 //===----------------------------------------------------------------------===//
14 
15 #include "mlir/Dialect/Tosa/IR/TosaOps.h"
16 #include "mlir/Dialect/StandardOps/IR/Ops.h"
17 #include "mlir/Dialect/Tosa/Utils/QuantUtils.h"
18 #include "mlir/IR/BuiltinTypes.h"
19 #include "mlir/Transforms/FoldUtils.h"
20 #include "mlir/Transforms/InliningUtils.h"
21 #include "mlir/Transforms/RegionUtils.h"
22 
23 using namespace mlir;
24 using namespace mlir::tosa;
25 
26 //===----------------------------------------------------------------------===//
27 // Tosa dialect structs and interface includes.
28 //===----------------------------------------------------------------------===//
29 #include "mlir/Dialect/Tosa/IR/TosaInterfaces.cpp.inc"
30 #include "mlir/Dialect/Tosa/IR/TosaStructs.cpp.inc"
31 
32 namespace {
33 //===----------------------------------------------------------------------===//
34 // Dialect Function Inliner Interface.
35 //===----------------------------------------------------------------------===//
36 struct TosaInlinerInterface : public DialectInlinerInterface {
37   using DialectInlinerInterface::DialectInlinerInterface;
38 
39   //===--------------------------------------------------------------------===//
40   // Analysis Hooks.
41   //===--------------------------------------------------------------------===//
42 
43   /// All operations can be inlined by default.
isLegalToInline__anon6494a0970111::TosaInlinerInterface44   bool isLegalToInline(Operation *op, Region *region, bool wouldBeCloned,
45                        BlockAndValueMapping &map) const final {
46     return true;
47   }
48 
49   /// All regions with If and While parent operators can be inlined.
isLegalToInline__anon6494a0970111::TosaInlinerInterface50   bool isLegalToInline(Region *dest, Region *src, bool wouldBeCloned,
51                        BlockAndValueMapping &map) const final {
52     return (isa<tosa::IfOp>(dest->getParentOp()) ||
53             isa<tosa::WhileOp>(dest->getParentOp()));
54   }
55 };
56 } // end anonymous namespace
57 
58 //===----------------------------------------------------------------------===//
59 // TOSA control flow support.
60 //===----------------------------------------------------------------------===//
61 
62 /// Returns the while loop body.
getLoopBody()63 Region &tosa::WhileOp::getLoopBody() { return body(); }
64 
isDefinedOutsideOfLoop(Value value)65 bool tosa::WhileOp::isDefinedOutsideOfLoop(Value value) {
66   return !body().isAncestor(value.getParentRegion());
67 }
68 
moveOutOfLoop(ArrayRef<mlir::Operation * > ops)69 LogicalResult WhileOp::moveOutOfLoop(ArrayRef<mlir::Operation *> ops) {
70   if (ops.empty())
71     return success();
72 
73   Operation *tosaWhileOp = this->getOperation();
74   for (auto *op : ops)
75     op->moveBefore(tosaWhileOp);
76 
77   return success();
78 }
79 
80 //===----------------------------------------------------------------------===//
81 // Tosa dialect initialization.
82 //===----------------------------------------------------------------------===//
83 
/// Registers the dialect's contents. The op classes come from the list
/// produced by including TosaOps.cpp.inc with GET_OP_LIST defined. The dialect
/// interface registered here is the TosaInlinerInterface from this file.
void TosaDialect::initialize() {
  // The included file expands to a comma-separated list of op classes when
  // GET_OP_LIST is defined, filling in the template argument list below.
  addOperations<
#define GET_OP_LIST
#include "mlir/Dialect/Tosa/IR/TosaOps.cpp.inc"
      >();
  addInterfaces<TosaInlinerInterface>();
}
91 
materializeConstant(OpBuilder & builder,Attribute value,Type type,Location loc)92 Operation *TosaDialect::materializeConstant(OpBuilder &builder, Attribute value,
93                                             Type type, Location loc) {
94   // Tosa dialect constants only support ElementsAttr unlike standard dialect
95   // constant which supports all attributes.
96   if (value.isa<ElementsAttr>())
97     return builder.create<tosa::ConstOp>(loc, type, value.cast<ElementsAttr>());
98   return nullptr;
99 }
100 
101 //===----------------------------------------------------------------------===//
102 // Operator Folders.
103 //===----------------------------------------------------------------------===//
104 
fold(ArrayRef<Attribute> operands)105 OpFoldResult ConstOp::fold(ArrayRef<Attribute> operands) {
106   assert(operands.empty() && "constant has no operands");
107   return valueAttr();
108 }
109 
110 //===----------------------------------------------------------------------===//
111 // TOSA Operator Verifiers.
112 //===----------------------------------------------------------------------===//
113 
114 template <typename T>
verifyConvOp(T op)115 static LogicalResult verifyConvOp(T op) {
116   // All TOSA conv ops have an input() and weight().
117   auto inputType = op.input().getType().template dyn_cast<RankedTensorType>();
118   auto weightType = op.weight().getType().template dyn_cast<RankedTensorType>();
119 
120   // Must be ranked tensor types
121   if (!inputType || !weightType)
122     return failure();
123 
124   auto inputQType =
125       inputType.getElementType().template isa<mlir::quant::QuantizedType>();
126   auto weightQType =
127       weightType.getElementType().template isa<mlir::quant::QuantizedType>();
128 
129   // Either both must be quantized or both unquantized.
130   if (inputQType != weightQType)
131     return failure();
132 
133   // Quantized type must have constructed the quantizationattr, and unquantized
134   // types should not have a quantizationattr.
135   if ((inputQType && !op.quantization_info()) ||
136       (!inputQType && op.quantization_info()))
137     return failure();
138 
139   return success();
140 }
141 
142 //===----------------------------------------------------------------------===//
143 // TOSA Operator Quantization Builders.
144 //===----------------------------------------------------------------------===//
145 
146 /// This builder is called on all convolution operators except TransposeConv,
147 /// which has specialized output shape semantics. The builder also defines the
148 /// bitwidth of the output given the bit width of the input & weight content.
buildConvOpWithQuantInfo(OpBuilder & builder,OperationState & result,Type outputType,Value input,Value weight,Value bias,ArrayAttr pad,ArrayAttr stride,ArrayAttr dilation)149 static void buildConvOpWithQuantInfo(OpBuilder &builder, OperationState &result,
150                                      Type outputType, Value input, Value weight,
151                                      Value bias, ArrayAttr pad,
152                                      ArrayAttr stride, ArrayAttr dilation) {
153 
154   result.addOperands({input, weight, bias});
155   result.addAttribute("pad", pad);
156   result.addAttribute("stride", stride);
157   result.addAttribute("dilation", dilation);
158 
159   auto quantAttr = buildConvOpQuantizationAttr(builder, input, weight);
160   if (quantAttr) {
161     result.addAttribute("quantization_info", quantAttr);
162     result.addTypes(
163         buildConvOpResultTypeInfo(builder, outputType, input, weight));
164   } else {
165     result.addTypes(outputType);
166   }
167 }
168 
169 /// Handles tosa.transpose_conv2d which has outpad and output shape attributes.
170 static void
buildTransConvOpWithQuantInfo(OpBuilder & builder,OperationState & result,Type outputType,Value input,Value weight,Value bias,ArrayAttr outpad,ArrayAttr stride,ArrayAttr dilation,ArrayAttr outputShape)171 buildTransConvOpWithQuantInfo(OpBuilder &builder, OperationState &result,
172                               Type outputType, Value input, Value weight,
173                               Value bias, ArrayAttr outpad, ArrayAttr stride,
174                               ArrayAttr dilation, ArrayAttr outputShape) {
175   result.addOperands({input, weight, bias});
176   result.addAttribute("out_pad", outpad);
177   result.addAttribute("stride", stride);
178   result.addAttribute("dilation", dilation);
179   result.addAttribute("out_shape", outputShape);
180   auto quantAttr = ::buildConvOpQuantizationAttr(builder, input, weight);
181 
182   if (quantAttr) {
183     result.addAttribute("quantization_info", quantAttr);
184     result.addTypes(
185         buildConvOpResultTypeInfo(builder, outputType, input, weight));
186   } else {
187     result.addTypes(outputType);
188   }
189 }
190 
191 /// The tosa.fully_connected op has its own builder as it does not have
192 /// strides/dilation/padding.
buildFCOpWithQuantInfo(OpBuilder & builder,OperationState & result,Type outputType,Value input,Value weight,Value bias)193 static void buildFCOpWithQuantInfo(OpBuilder &builder, OperationState &result,
194                                    Type outputType, Value input, Value weight,
195                                    Value bias) {
196 
197   result.addOperands({input, weight, bias});
198   auto quantAttr = ::buildConvOpQuantizationAttr(builder, input, weight);
199   if (quantAttr) {
200     result.addAttribute("quantization_info", quantAttr);
201     result.addTypes(
202         buildConvOpResultTypeInfo(builder, outputType, input, weight));
203   } else {
204     result.addTypes(outputType);
205   }
206 }
207 
208 /// The tosa.matmul op is also intended to be generated where a fully_connected
209 /// op must be constructed where the weight is not a constant. In this case,
210 /// the fully_connected op must be expressed using matmul.
211 /// TODO: Add link to the leglization document explaining this.
buildMatMulOpWithQuantInfo(OpBuilder & builder,OperationState & result,Type outputType,Value a,Value b)212 static void buildMatMulOpWithQuantInfo(OpBuilder &builder,
213                                        OperationState &result, Type outputType,
214                                        Value a, Value b) {
215   result.addOperands({a, b});
216   auto quantAttr = ::buildMatMulOpQuantizationAttr(builder, a, b);
217 
218   if (quantAttr) {
219     result.addAttribute("quantization_info", quantAttr);
220 
221     auto inputType = a.getType().dyn_cast<RankedTensorType>();
222     assert(inputType && "Input must be a ranked tensor type!");
223 
224     auto inputQType = inputType.getElementType()
225                           .dyn_cast<mlir::quant::UniformQuantizedType>();
226     assert(inputQType && "Tensor must have quantized datatype!");
227 
228     unsigned inputBits = inputQType.getStorageTypeIntegralWidth();
229 
230     auto outputShapedType = outputType.dyn_cast<RankedTensorType>();
231     assert(outputShapedType && "Output must be a ranked tensor type");
232 
233     auto outputShape = outputShapedType.getShape();
234 
235     IntegerType accElementType;
236     if (inputBits == 16)
237       accElementType = builder.getIntegerType(48);
238     else
239       accElementType = builder.getI32Type();
240     auto accType = RankedTensorType::get(outputShape, accElementType);
241     result.addTypes(accType);
242   } else {
243     result.addTypes(outputType);
244   }
245 }
246 
247 /// Both the tosa.avg_pool2d and unary ops use the same UnaruOpQuantizationAttr
248 /// but avg_pool operator has its own builder as it has additional parameters
249 /// not part of the unary ops.
buildAvgPool2dOpWithQuantInfo(OpBuilder & builder,OperationState & result,Type outputType,Value input,ArrayAttr kernel,ArrayAttr stride,ArrayAttr pad)250 static void buildAvgPool2dOpWithQuantInfo(OpBuilder &builder,
251                                           OperationState &result,
252                                           Type outputType, Value input,
253                                           ArrayAttr kernel, ArrayAttr stride,
254                                           ArrayAttr pad) {
255   result.addOperands(input);
256   result.addAttribute("kernel", kernel);
257   result.addAttribute("stride", stride);
258   result.addAttribute("pad", pad);
259   auto quantAttr = buildUnaryOpQuantizationAttr(builder, input, outputType);
260   if (quantAttr)
261     result.addAttribute("quantization_info", quantAttr);
262   result.types.push_back(outputType);
263 }
264 
265 /// This builder is called on single-parameter unary operators that have scale
266 /// relationship between their input and output, expressed by the
267 /// UnaryOpQuantizationAttr.
buildUnaryOpWithQuantInfo(OpBuilder & builder,OperationState & result,Type outputType,Value input)268 static void buildUnaryOpWithQuantInfo(OpBuilder &builder,
269                                       OperationState &result, Type outputType,
270                                       Value input) {
271   result.addOperands(input);
272   auto quantAttr = buildUnaryOpQuantizationAttr(builder, input, outputType);
273   if (quantAttr)
274     result.addAttribute("quantization_info", quantAttr);
275   result.types.push_back(outputType);
276 }
277 
278 /// This builder is called on TOSA pad operator that needs to create its own
279 /// OptionalAttr quantization_attr parameter to scale the padding values
280 /// correctly.
buildPadOpWithQuantInfo(OpBuilder & builder,OperationState & result,Type outputType,Value input,Value paddings)281 static void buildPadOpWithQuantInfo(OpBuilder &builder, OperationState &result,
282                                     Type outputType, Value input,
283                                     Value paddings) {
284   result.addOperands({input, paddings});
285   auto quantAttr = buildPadOpQuantizationAttr(builder, input);
286   if (quantAttr)
287     result.addAttribute("quantization_info", quantAttr);
288   result.types.push_back(outputType);
289 }
290 
291 //===----------------------------------------------------------------------===//
292 // TOSA Operator Definitions.
293 //===----------------------------------------------------------------------===//
294 
295 #define GET_OP_CLASSES
296 #include "mlir/Dialect/Tosa/IR/TosaOps.cpp.inc"
297