1 //===- QuantizeUtils.cpp - Support utilities for quantization -------------===//
2 //
3 // Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions.
4 // See https://llvm.org/LICENSE.txt for license information.
5 // SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception
6 //
7 //===----------------------------------------------------------------------===//
8 
9 #include "mlir/Dialect/Quant/QuantizeUtils.h"
10 #include "mlir/Dialect/Quant/UniformSupport.h"
11 #include "mlir/IR/Attributes.h"
12 #include "mlir/IR/BuiltinTypes.h"
13 
14 using namespace mlir;
15 using namespace mlir::quant;
16 
17 /// Converts a possible primitive, real expressed value attribute to a
18 /// corresponding storage attribute (typically FloatAttr -> IntegerAttr).
19 /// quantizedElementType is the QuantizedType that describes the expressed
20 /// origValue.
21 /// Returns a converter Attribute or nullptr if conversion is not possible.
convertPrimitiveValueAttr(Attribute origRealValue,QuantizedType quantizedElementType,const UniformQuantizedValueConverter & converter,Type & outConvertedType)22 static Attribute convertPrimitiveValueAttr(
23     Attribute origRealValue, QuantizedType quantizedElementType,
24     const UniformQuantizedValueConverter &converter, Type &outConvertedType) {
25   if (origRealValue.isa<FloatAttr>()) {
26     FloatAttr floatAttr = origRealValue.cast<FloatAttr>();
27     outConvertedType = quantizedElementType.getStorageType();
28     return IntegerAttr::get(quantizedElementType.getStorageType(),
29                             converter.quantizeFloatToInt(floatAttr.getValue()));
30   }
31 
32   return nullptr;
33 }
34 
35 /// Converts a real expressed DenseFPElementsAttr to a corresponding
36 /// DenseElementsAttr (typically DenseIntElementsAttr) containing quantized
37 /// storage values assuming the given quantizedElementType and converter.
38 static DenseElementsAttr
convertDenseFPElementsAttr(DenseFPElementsAttr realFPElementsAttr,QuantizedType quantizedElementType,const UniformQuantizedValueConverter & converter)39 convertDenseFPElementsAttr(DenseFPElementsAttr realFPElementsAttr,
40                            QuantizedType quantizedElementType,
41                            const UniformQuantizedValueConverter &converter) {
42   // Convert to corresponding quantized value attributes.
43   SmallVector<APInt, 8> quantValues;
44   if (realFPElementsAttr.isSplat()) {
45     quantValues.push_back(
46         converter.quantizeFloatToInt(*realFPElementsAttr.begin()));
47   } else {
48     quantValues.reserve(realFPElementsAttr.getNumElements());
49     for (APFloat realVal : realFPElementsAttr) {
50       quantValues.push_back(converter.quantizeFloatToInt(realVal));
51     }
52   }
53 
54   // Cast from an expressed-type-based type to storage-type-based type,
55   // preserving the dense shape (i.e. tensor<4xf32> -> tensor<4xi8>).
56   ShapedType newDenseType =
57       quantizedElementType
58           .castExpressedToStorageType(realFPElementsAttr.getType())
59           .dyn_cast_or_null<ShapedType>();
60   if (!newDenseType) {
61     return nullptr;
62   }
63   return DenseIntElementsAttr::get(newDenseType, quantValues);
64 }
65 
66 /// Converts a real expressed SplatElementsAttr to a corresponding
67 /// SplatElementsAttr containing quantized storage values assuming the given
68 /// quantizedElementType and converter.
69 static SparseElementsAttr
convertSparseElementsAttr(SparseElementsAttr realSparseAttr,QuantizedType quantizedElementType,const UniformQuantizedValueConverter & converter)70 convertSparseElementsAttr(SparseElementsAttr realSparseAttr,
71                           QuantizedType quantizedElementType,
72                           const UniformQuantizedValueConverter &converter) {
73   DenseElementsAttr realDenseAttr = realSparseAttr.getValues();
74   if (!realDenseAttr.isa<DenseFPElementsAttr>()) {
75     return nullptr;
76   }
77   DenseElementsAttr quantDenseAttr =
78       convertDenseFPElementsAttr(realDenseAttr.cast<DenseFPElementsAttr>(),
79                                  quantizedElementType, converter);
80   if (!quantDenseAttr) {
81     return nullptr;
82   }
83 
84   // Cast from an expressed-type-based type to storage-type-based type,
85   // preserving the sparse shape (i.e. tensor<4xf32> -> tensor<4xi8>).
86   ShapedType newSparseType =
87       quantizedElementType.castExpressedToStorageType(realSparseAttr.getType())
88           .dyn_cast_or_null<ShapedType>();
89   if (!newSparseType) {
90     return nullptr;
91   }
92   return SparseElementsAttr::get(newSparseType, realSparseAttr.getIndices(),
93                                  quantDenseAttr);
94 }
95 
96 /// Converts a real expressed Attribute to a corresponding Attribute containing
97 /// quantized storage values assuming the given uniform quantizedElementType and
98 /// converter.
quantizeAttrUniform(Attribute realValue,UniformQuantizedType quantizedElementType,const UniformQuantizedValueConverter & converter,Type & outConvertedType)99 Attribute mlir::quant::quantizeAttrUniform(
100     Attribute realValue, UniformQuantizedType quantizedElementType,
101     const UniformQuantizedValueConverter &converter, Type &outConvertedType) {
102   // Fork to handle different variants of constants supported.
103   if (realValue.isa<DenseFPElementsAttr>()) {
104     // Dense tensor or vector constant.
105     auto converted = convertDenseFPElementsAttr(
106         realValue.cast<DenseFPElementsAttr>(), quantizedElementType, converter);
107     outConvertedType = converted.getType();
108     return converted;
109   } else if (realValue.isa<SparseElementsAttr>()) {
110     // Sparse tensor or vector constant.
111     auto converted = convertSparseElementsAttr(
112         realValue.cast<SparseElementsAttr>(), quantizedElementType, converter);
113     outConvertedType = converted.getType();
114     return converted;
115   } else {
116     // Nothing else matched: try to convert a primitive.
117     return convertPrimitiveValueAttr(realValue, quantizedElementType, converter,
118                                      outConvertedType);
119   }
120 }
121 
122 /// Convert an attribute from a type based on
123 /// quantizedElementType.getExpressedType() to one based on
124 /// quantizedElementType.getStorageType().
125 /// Returns nullptr if the conversion is not supported.
126 /// On success, stores the converted type in outConvertedType.
quantizeAttr(Attribute realValue,QuantizedType quantizedElementType,Type & outConvertedType)127 Attribute mlir::quant::quantizeAttr(Attribute realValue,
128                                     QuantizedType quantizedElementType,
129                                     Type &outConvertedType) {
130   if (auto uniformQuantized =
131           quantizedElementType.dyn_cast<UniformQuantizedType>()) {
132     UniformQuantizedValueConverter converter(uniformQuantized);
133     return quantizeAttrUniform(realValue, uniformQuantized, converter,
134                                outConvertedType);
135 
136   } else if (auto uniformQuantizedPerAxis =
137                  quantizedElementType.dyn_cast<UniformQuantizedPerAxisType>()) {
138     UniformQuantizedPerAxisValueConverter converter(uniformQuantizedPerAxis);
139     auto converted = converter.convert(realValue);
140     // TODO: why we need this outConvertedType? remove it?
141     if (converted) {
142       outConvertedType = converted.getType();
143     }
144     return converted;
145   } else {
146     return nullptr;
147   }
148 }
149