1 //===- QuantizationUtilsTest.cpp - unit tests for quantization utils ------===//
2 //
3 // Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions.
4 // See https://llvm.org/LICENSE.txt for license information.
5 // SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception
6 //
7 //===----------------------------------------------------------------------===//
8
9 #include "mlir/Dialect/Quant/QuantOps.h"
10 #include "mlir/Dialect/Quant/QuantizeUtils.h"
11 #include "mlir/Dialect/Quant/UniformSupport.h"
12 #include "mlir/IR/Attributes.h"
13 #include "mlir/IR/BuiltinTypes.h"
14 #include "gmock/gmock.h"
15 #include "gtest/gtest.h"
16
17 using namespace mlir;
18 using namespace mlir::quant;
19
20 namespace {
21
22 // Test UniformQuantizedValueConverter converts all APFloat to a magic number 5.
23 class TestUniformQuantizedValueConverter
24 : public UniformQuantizedValueConverter {
25 public:
TestUniformQuantizedValueConverter(UniformQuantizedType type)26 TestUniformQuantizedValueConverter(UniformQuantizedType type)
27 : UniformQuantizedValueConverter(type), qtype(type) {}
quantizeFloatToInt(APFloat expressedValue) const28 APInt quantizeFloatToInt(APFloat expressedValue) const {
29 return APInt(qtype.getStorageType().cast<IntegerType>().getWidth(), 5L);
30 }
31
32 private:
33 UniformQuantizedType qtype;
34 };
35
getTestFloatAttr(double value,MLIRContext * ctx)36 Attribute getTestFloatAttr(double value, MLIRContext *ctx) {
37 return FloatAttr::get(FloatType::getF32(ctx), value);
38 }
39
40 template <typename ConcreteAttrClass, typename... Arg>
getTestElementsAttr(MLIRContext * ctx,ArrayRef<int64_t> shape,Arg...value)41 ConcreteAttrClass getTestElementsAttr(MLIRContext *ctx, ArrayRef<int64_t> shape,
42 Arg... value) {
43 auto eleType = FloatType::getF32(ctx);
44 ShapedType tensorType;
45 if (shape.size() == 1 && shape[0] == -1) {
46 tensorType = UnrankedTensorType::get(eleType);
47 } else {
48 tensorType = RankedTensorType::get(shape, eleType);
49 }
50 return ConcreteAttrClass::get(tensorType, value...);
51 }
52
getTestSparseElementsAttr(MLIRContext * ctx,ArrayRef<int64_t> shape)53 ElementsAttr getTestSparseElementsAttr(MLIRContext *ctx,
54 ArrayRef<int64_t> shape) {
55 auto eleType = FloatType::getF32(ctx);
56 ShapedType tensorType;
57 if (shape.size() == 1 && shape[0] == -1) {
58 tensorType = UnrankedTensorType::get(eleType);
59 } else {
60 tensorType = RankedTensorType::get(shape, eleType);
61 }
62 auto indicesType = RankedTensorType::get({1, 2}, IntegerType::get(64, ctx));
63 auto indices =
64 DenseIntElementsAttr::get(indicesType, {APInt(64, 0), APInt(64, 0)});
65 auto valuesType = RankedTensorType::get({1}, eleType);
66 auto values = DenseFPElementsAttr::get(valuesType, {APFloat(0.0f)});
67 return SparseElementsAttr::get(tensorType, indices, values);
68 }
69
getTestQuantizedType(Type storageType,MLIRContext * ctx)70 UniformQuantizedType getTestQuantizedType(Type storageType, MLIRContext *ctx) {
71 return UniformQuantizedType::get(/*flags=*/false, storageType,
72 FloatType::getF32(ctx), /*scale=*/1.0,
73 /*zeroPoint=*/0, /*storageTypeMin=*/0,
74 /*storageTypeMax=*/255);
75 }
76
// Quantizing a scalar FloatAttr must yield an IntegerAttr produced by the test
// converter (value 5) whose type is the quantized storage type.
TEST(QuantizationUtilsTest, convertFloatAttrUniform) {
  MLIRContext ctx;
  ctx.getOrLoadDialect<QuantizationDialect>();
  IntegerType convertedType = IntegerType::get(8, &ctx);
  auto quantizedType = getTestQuantizedType(convertedType, &ctx);
  TestUniformQuantizedValueConverter converter(quantizedType);

  auto realValue = getTestFloatAttr(1.0, &ctx);
  Type typeResult;
  auto valueResult =
      quantizeAttrUniform(realValue, quantizedType, converter, typeResult);

  // Stop on a failed or unexpected conversion: the casts below would
  // otherwise assert (or be UB without assertions) instead of reporting a
  // test failure.
  ASSERT_TRUE(valueResult && valueResult.isa<IntegerAttr>());

  auto intResult = valueResult.cast<IntegerAttr>();
  EXPECT_EQ(intResult.getInt(), 5);
  EXPECT_EQ(intResult.getType().cast<IntegerType>().getWidth(),
            convertedType.getWidth());
  // The out-parameter should report the storage type converted to.
  EXPECT_EQ(typeResult, convertedType);
}
94
// Quantizing a ranked, non-splat dense f32 attribute must keep the shape,
// switch the element type to the storage type, and produce integer elements.
TEST(QuantizationUtilsTest, convertRankedDenseAttrUniform) {
  MLIRContext ctx;
  ctx.getOrLoadDialect<QuantizationDialect>();
  IntegerType convertedType = IntegerType::get(8, &ctx);
  auto quantizedType = getTestQuantizedType(convertedType, &ctx);
  TestUniformQuantizedValueConverter converter(quantizedType);
  auto realValue = getTestElementsAttr<DenseElementsAttr, ArrayRef<Attribute>>(
      &ctx, {1, 2}, {getTestFloatAttr(1.0, &ctx), getTestFloatAttr(2.0, &ctx)});

  Type returnedType;
  auto returnedValue =
      quantizeAttrUniform(realValue, quantizedType, converter, returnedType);

  // Stop on a failed or unexpected conversion: the casts below would
  // otherwise assert (or be UB without assertions) instead of reporting a
  // test failure.
  ASSERT_TRUE(returnedValue && returnedValue.isa<DenseIntElementsAttr>());
  ASSERT_TRUE(returnedType && returnedType.isa<TensorType>());

  // Check Elements attribute shape and kind are not changed.
  auto tensorType = returnedType.cast<TensorType>();
  auto expectedTensorType = realValue.getType().cast<TensorType>();
  EXPECT_EQ(tensorType.getShape(), expectedTensorType.getShape());
  EXPECT_EQ(tensorType.getElementType(), convertedType);

  // Check Elements attribute element value is expected.
  auto firstValue = returnedValue.cast<ElementsAttr>().getValue({0, 0});
  ASSERT_TRUE(firstValue && firstValue.isa<IntegerAttr>());
  EXPECT_EQ(firstValue.cast<IntegerAttr>().getInt(), 5);
}
119
// Quantizing a ranked splat f32 attribute must keep the shape, switch the
// element type to the storage type, and still produce a splat.
TEST(QuantizationUtilsTest, convertRankedSplatAttrUniform) {
  MLIRContext ctx;
  ctx.getOrLoadDialect<QuantizationDialect>();
  IntegerType convertedType = IntegerType::get(8, &ctx);
  auto quantizedType = getTestQuantizedType(convertedType, &ctx);
  TestUniformQuantizedValueConverter converter(quantizedType);
  auto realValue = getTestElementsAttr<DenseElementsAttr, Attribute>(
      &ctx, {1, 2}, getTestFloatAttr(1.0, &ctx));

  Type returnedType;
  auto returnedValue =
      quantizeAttrUniform(realValue, quantizedType, converter, returnedType);

  // Stop on a failed or unexpected conversion: the casts below would
  // otherwise assert (or be UB without assertions) instead of reporting a
  // test failure.
  ASSERT_TRUE(returnedValue && returnedValue.isa<SplatElementsAttr>());
  ASSERT_TRUE(returnedType && returnedType.isa<TensorType>());

  // Check Elements attribute shape and kind are not changed.
  auto tensorType = returnedType.cast<TensorType>();
  auto expectedTensorType = realValue.getType().cast<TensorType>();
  EXPECT_EQ(tensorType.getShape(), expectedTensorType.getShape());
  EXPECT_EQ(tensorType.getElementType(), convertedType);

  // Check Elements attribute element value is expected.
  auto firstValue = returnedValue.cast<ElementsAttr>().getValue({0, 0});
  ASSERT_TRUE(firstValue && firstValue.isa<IntegerAttr>());
  EXPECT_EQ(firstValue.cast<IntegerAttr>().getInt(), 5);
}
144
// Quantizing a ranked sparse f32 attribute must keep the shape, switch the
// element type to the storage type, and still produce a sparse attribute.
TEST(QuantizationUtilsTest, convertRankedSparseAttrUniform) {
  MLIRContext ctx;
  ctx.getOrLoadDialect<QuantizationDialect>();
  IntegerType convertedType = IntegerType::get(8, &ctx);
  auto quantizedType = getTestQuantizedType(convertedType, &ctx);
  TestUniformQuantizedValueConverter converter(quantizedType);
  auto realValue = getTestSparseElementsAttr(&ctx, {1, 2});

  Type returnedType;
  auto returnedValue =
      quantizeAttrUniform(realValue, quantizedType, converter, returnedType);

  // Stop on a failed or unexpected conversion: the casts below would
  // otherwise assert (or be UB without assertions) instead of reporting a
  // test failure.
  ASSERT_TRUE(returnedValue && returnedValue.isa<SparseElementsAttr>());
  ASSERT_TRUE(returnedType && returnedType.isa<TensorType>());

  // Check Elements attribute shape and kind are not changed.
  auto tensorType = returnedType.cast<TensorType>();
  auto expectedTensorType = realValue.getType().cast<TensorType>();
  EXPECT_EQ(tensorType.getShape(), expectedTensorType.getShape());
  EXPECT_EQ(tensorType.getElementType(), convertedType);

  // Check Elements attribute element value is expected.
  auto firstValue = returnedValue.cast<ElementsAttr>().getValue({0, 0});
  ASSERT_TRUE(firstValue && firstValue.isa<IntegerAttr>());
  EXPECT_EQ(firstValue.cast<IntegerAttr>().getInt(), 5);
}
168
169 } // end namespace
170