//===- TosaOps.cpp - MLIR Dialect for TOSA --------------------------------===// // // Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions. // See https://llvm.org/LICENSE.txt for license information. // SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception // //===----------------------------------------------------------------------===// // // \file // This file implements the TOSA Specification: // https://developer.mlplatform.org/w/tosa/ // //===----------------------------------------------------------------------===// #include "mlir/Dialect/Tosa/IR/TosaOps.h" #include "mlir/Dialect/StandardOps/IR/Ops.h" #include "mlir/Dialect/Tosa/Utils/QuantUtils.h" #include "mlir/IR/BuiltinTypes.h" #include "mlir/Transforms/FoldUtils.h" #include "mlir/Transforms/InliningUtils.h" #include "mlir/Transforms/RegionUtils.h" using namespace mlir; using namespace mlir::tosa; //===----------------------------------------------------------------------===// // Tosa dialect structs and interface includes. //===----------------------------------------------------------------------===// #include "mlir/Dialect/Tosa/IR/TosaInterfaces.cpp.inc" #include "mlir/Dialect/Tosa/IR/TosaStructs.cpp.inc" namespace { //===----------------------------------------------------------------------===// // Dialect Function Inliner Interface. //===----------------------------------------------------------------------===// struct TosaInlinerInterface : public DialectInlinerInterface { using DialectInlinerInterface::DialectInlinerInterface; //===--------------------------------------------------------------------===// // Analysis Hooks. //===--------------------------------------------------------------------===// /// All operations can be inlined by default. bool isLegalToInline(Operation *op, Region *region, bool wouldBeCloned, BlockAndValueMapping &map) const final { return true; } /// All regions with If and While parent operators can be inlined. bool isLegalToInline(Region *dest, Region *src, bool wouldBeCloned, BlockAndValueMapping &map) const final { return (isa(dest->getParentOp()) || isa(dest->getParentOp())); } }; } // end anonymous namespace //===----------------------------------------------------------------------===// // TOSA control flow support. //===----------------------------------------------------------------------===// /// Returns the while loop body. Region &tosa::WhileOp::getLoopBody() { return body(); } bool tosa::WhileOp::isDefinedOutsideOfLoop(Value value) { return !body().isAncestor(value.getParentRegion()); } LogicalResult WhileOp::moveOutOfLoop(ArrayRef ops) { if (ops.empty()) return success(); Operation *tosaWhileOp = this->getOperation(); for (auto *op : ops) op->moveBefore(tosaWhileOp); return success(); } //===----------------------------------------------------------------------===// // Tosa dialect initialization. //===----------------------------------------------------------------------===// void TosaDialect::initialize() { addOperations< #define GET_OP_LIST #include "mlir/Dialect/Tosa/IR/TosaOps.cpp.inc" >(); addInterfaces(); } Operation *TosaDialect::materializeConstant(OpBuilder &builder, Attribute value, Type type, Location loc) { // Tosa dialect constants only support ElementsAttr unlike standard dialect // constant which supports all attributes. if (value.isa()) return builder.create(loc, type, value.cast()); return nullptr; } //===----------------------------------------------------------------------===// // Operator Folders. //===----------------------------------------------------------------------===// OpFoldResult ConstOp::fold(ArrayRef operands) { assert(operands.empty() && "constant has no operands"); return valueAttr(); } //===----------------------------------------------------------------------===// // TOSA Operator Verifiers. //===----------------------------------------------------------------------===// template static LogicalResult verifyConvOp(T op) { // All TOSA conv ops have an input() and weight(). auto inputType = op.input().getType().template dyn_cast(); auto weightType = op.weight().getType().template dyn_cast(); // Must be ranked tensor types if (!inputType || !weightType) return failure(); auto inputQType = inputType.getElementType().template isa(); auto weightQType = weightType.getElementType().template isa(); // Either both must be quantized or both unquantized. if (inputQType != weightQType) return failure(); // Quantized type must have constructed the quantizationattr, and unquantized // types should not have a quantizationattr. if ((inputQType && !op.quantization_info()) || (!inputQType && op.quantization_info())) return failure(); return success(); } //===----------------------------------------------------------------------===// // TOSA Operator Quantization Builders. //===----------------------------------------------------------------------===// /// This builder is called on all convolution operators except TransposeConv, /// which has specialized output shape semantics. The builder also defines the /// bitwidth of the output given the bit width of the input & weight content. static void buildConvOpWithQuantInfo(OpBuilder &builder, OperationState &result, Type outputType, Value input, Value weight, Value bias, ArrayAttr pad, ArrayAttr stride, ArrayAttr dilation) { result.addOperands({input, weight, bias}); result.addAttribute("pad", pad); result.addAttribute("stride", stride); result.addAttribute("dilation", dilation); auto quantAttr = buildConvOpQuantizationAttr(builder, input, weight); if (quantAttr) { result.addAttribute("quantization_info", quantAttr); result.addTypes( buildConvOpResultTypeInfo(builder, outputType, input, weight)); } else { result.addTypes(outputType); } } /// Handles tosa.transpose_conv2d which has outpad and output shape attributes. static void buildTransConvOpWithQuantInfo(OpBuilder &builder, OperationState &result, Type outputType, Value input, Value weight, Value bias, ArrayAttr outpad, ArrayAttr stride, ArrayAttr dilation, ArrayAttr outputShape) { result.addOperands({input, weight, bias}); result.addAttribute("out_pad", outpad); result.addAttribute("stride", stride); result.addAttribute("dilation", dilation); result.addAttribute("out_shape", outputShape); auto quantAttr = ::buildConvOpQuantizationAttr(builder, input, weight); if (quantAttr) { result.addAttribute("quantization_info", quantAttr); result.addTypes( buildConvOpResultTypeInfo(builder, outputType, input, weight)); } else { result.addTypes(outputType); } } /// The tosa.fully_connected op has its own builder as it does not have /// strides/dilation/padding. static void buildFCOpWithQuantInfo(OpBuilder &builder, OperationState &result, Type outputType, Value input, Value weight, Value bias) { result.addOperands({input, weight, bias}); auto quantAttr = ::buildConvOpQuantizationAttr(builder, input, weight); if (quantAttr) { result.addAttribute("quantization_info", quantAttr); result.addTypes( buildConvOpResultTypeInfo(builder, outputType, input, weight)); } else { result.addTypes(outputType); } } /// The tosa.matmul op is also intended to be generated where a fully_connected /// op must be constructed where the weight is not a constant. In this case, /// the fully_connected op must be expressed using matmul. /// TODO: Add link to the leglization document explaining this. static void buildMatMulOpWithQuantInfo(OpBuilder &builder, OperationState &result, Type outputType, Value a, Value b) { result.addOperands({a, b}); auto quantAttr = ::buildMatMulOpQuantizationAttr(builder, a, b); if (quantAttr) { result.addAttribute("quantization_info", quantAttr); auto inputType = a.getType().dyn_cast(); assert(inputType && "Input must be a ranked tensor type!"); auto inputQType = inputType.getElementType() .dyn_cast(); assert(inputQType && "Tensor must have quantized datatype!"); unsigned inputBits = inputQType.getStorageTypeIntegralWidth(); auto outputShapedType = outputType.dyn_cast(); assert(outputShapedType && "Output must be a ranked tensor type"); auto outputShape = outputShapedType.getShape(); IntegerType accElementType; if (inputBits == 16) accElementType = builder.getIntegerType(48); else accElementType = builder.getI32Type(); auto accType = RankedTensorType::get(outputShape, accElementType); result.addTypes(accType); } else { result.addTypes(outputType); } } /// Both the tosa.avg_pool2d and unary ops use the same UnaruOpQuantizationAttr /// but avg_pool operator has its own builder as it has additional parameters /// not part of the unary ops. static void buildAvgPool2dOpWithQuantInfo(OpBuilder &builder, OperationState &result, Type outputType, Value input, ArrayAttr kernel, ArrayAttr stride, ArrayAttr pad) { result.addOperands(input); result.addAttribute("kernel", kernel); result.addAttribute("stride", stride); result.addAttribute("pad", pad); auto quantAttr = buildUnaryOpQuantizationAttr(builder, input, outputType); if (quantAttr) result.addAttribute("quantization_info", quantAttr); result.types.push_back(outputType); } /// This builder is called on single-parameter unary operators that have scale /// relationship between their input and output, expressed by the /// UnaryOpQuantizationAttr. static void buildUnaryOpWithQuantInfo(OpBuilder &builder, OperationState &result, Type outputType, Value input) { result.addOperands(input); auto quantAttr = buildUnaryOpQuantizationAttr(builder, input, outputType); if (quantAttr) result.addAttribute("quantization_info", quantAttr); result.types.push_back(outputType); } /// This builder is called on TOSA pad operator that needs to create its own /// OptionalAttr quantization_attr parameter to scale the padding values /// correctly. static void buildPadOpWithQuantInfo(OpBuilder &builder, OperationState &result, Type outputType, Value input, Value paddings) { result.addOperands({input, paddings}); auto quantAttr = buildPadOpQuantizationAttr(builder, input); if (quantAttr) result.addAttribute("quantization_info", quantAttr); result.types.push_back(outputType); } //===----------------------------------------------------------------------===// // TOSA Operator Definitions. //===----------------------------------------------------------------------===// #define GET_OP_CLASSES #include "mlir/Dialect/Tosa/IR/TosaOps.cpp.inc"