1 //===- FakeQuantSupport.cpp - Support utilities for FakeQuant ops ---------===//
2 //
3 // Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions.
4 // See https://llvm.org/LICENSE.txt for license information.
5 // SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception
6 //
7 //===----------------------------------------------------------------------===//
8 
#include "mlir/Dialect/Quant/FakeQuantSupport.h"
#include "mlir/Dialect/Quant/QuantTypes.h"

#include <cassert>
#include <cmath>
#include <limits>
11 
12 using namespace mlir;
13 using namespace mlir::quant;
14 
getDefaultStorageParams(unsigned numBits,bool narrowRange,bool isSigned,MLIRContext * ctx,Type & storageType,int64_t & qmin,int64_t & qmax)15 static bool getDefaultStorageParams(unsigned numBits, bool narrowRange,
16                                     bool isSigned, MLIRContext *ctx,
17                                     Type &storageType, int64_t &qmin,
18                                     int64_t &qmax) {
19   // Hard-coded type mapping from TFLite.
20   if (numBits <= 8) {
21     storageType = IntegerType::get(8, ctx);
22     if (isSigned) {
23       qmin = -128;
24       qmax = 127;
25     } else {
26       qmin = 0;
27       qmax = 255;
28     }
29   } else if (numBits <= 16) {
30     storageType = IntegerType::get(16, ctx);
31     if (isSigned) {
32       qmin = -32768;
33       qmax = 32767;
34     } else {
35       qmin = 0;
36       qmax = 65535;
37     }
38   } else if (numBits <= 32) {
39     storageType = IntegerType::get(32, ctx);
40     if (isSigned) {
41       qmin = std::numeric_limits<int32_t>::min();
42       qmax = std::numeric_limits<int32_t>::max();
43     } else {
44       qmin = std::numeric_limits<uint32_t>::min();
45       qmax = std::numeric_limits<uint32_t>::max();
46     }
47   } else {
48     return true;
49   }
50 
51   // Handle narrowRange.
52   if (narrowRange) {
53     qmin += 1;
54   }
55   return false;
56 }
57 
// This is a specific implementation of nudging:
// If 0.0 < rmin < rmax or rmin < rmax < 0.0, the range will be shifted
// to include 0.0, but the range width (rmax - rmin) isn't changed. The zero
// point is derived from the shifted range, and the scale isn't changed. As
// a consequence, some values that lie inside the original [rmin, rmax] range
// will be outside the shifted range and be clamped during quantization.
// TODO: we should nudge the scale as well, but that requires the
// fake quant op used in the training to use the nudged scale as well.
//
// Outputs:
//   scale           - (rmax - rmin) / (qmax - qmin); never adjusted here.
//   nudgedZeroPoint - integer zero point, always within [qmin, qmax].
//
// If the real range is degenerate or non-finite (NaN/inf in rmin or rmax, or
// rmin == rmax), the computed zero point can be NaN. Converting NaN to an
// integer is undefined behavior, so in that case the zero point is pinned to
// qmin. `scale` is then itself zero, infinite or NaN; the callers in this
// file pass it on to getChecked, which is presumably what rejects it.
// NOTE(review): confirm against the QuantTypes verifier.
static void getNudgedScaleAndZeroPoint(int64_t qmin, int64_t qmax, double rmin,
                                       double rmax, double &scale,
                                       int64_t &nudgedZeroPoint) {
  // Determine the scale.
  const double qminDouble = qmin;
  const double qmaxDouble = qmax;
  scale = (rmax - rmin) / (qmaxDouble - qminDouble);

  // Zero point computation.
  // In float, solve the affine equation for any known pair
  // (real value, corresponding quantized value), of which, two such pairs
  // are known: (rmin, qmin), (rmax, qmax).
  // The arithmetic error on the zero point computed from either pair will be
  // roughly machine_epsilon * (sum of absolute values of terms).
  // Use the variant that adds the smaller error.
  const double zeroPointFromMin = qminDouble - rmin / scale;
  const double zeroPointFromMinError =
      std::abs(qminDouble) + std::abs(rmin / scale);
  const double zeroPointFromMax = qmaxDouble - rmax / scale;
  const double zeroPointFromMaxError =
      std::abs(qmaxDouble) + std::abs(rmax / scale);

  const double zeroPointDouble = (zeroPointFromMinError < zeroPointFromMaxError)
                                     ? zeroPointFromMin
                                     : zeroPointFromMax;

  // Now nudge the zero point to be an integer in [qmin, qmax]. NaN compares
  // false against both bounds, so it must be caught explicitly before it
  // reaches the float-to-integer conversion.
  if (std::isnan(zeroPointDouble) || zeroPointDouble < qminDouble) {
    nudgedZeroPoint = qmin;
  } else if (zeroPointDouble > qmaxDouble) {
    nudgedZeroPoint = qmax;
  } else {
    // In range, so the rounded value always fits in int64_t.
    nudgedZeroPoint = std::llround(zeroPointDouble);
  }

  // By construction, the nudged zero point should always be in range.
  assert(nudgedZeroPoint >= qmin);
  assert(nudgedZeroPoint <= qmax);
}
106 
107 UniformQuantizedType
fakeQuantAttrsToType(Location loc,unsigned numBits,double rmin,double rmax,bool narrowRange,Type expressedType,bool isSigned)108 mlir::quant::fakeQuantAttrsToType(Location loc, unsigned numBits, double rmin,
109                                   double rmax, bool narrowRange,
110                                   Type expressedType, bool isSigned) {
111   MLIRContext *ctx = expressedType.getContext();
112   unsigned flags = isSigned ? QuantizationFlags::Signed : 0;
113   Type storageType;
114   int64_t qmin;
115   int64_t qmax;
116   if (getDefaultStorageParams(numBits, narrowRange, isSigned, ctx, storageType,
117                               qmin, qmax)) {
118     return (emitError(loc, "unsupported FakeQuant number of bits: ") << numBits,
119             nullptr);
120   }
121 
122   // Special case where min/max is close enough. The tensor contents are all
123   // 0.0s, so the scale is set to 1.0 and the tensor can be quantized to zero
124   // points and dequantized to 0.0.
125   if (std::fabs(rmax - rmin) < std::numeric_limits<double>::epsilon()) {
126     return UniformQuantizedType::getChecked(flags, storageType, expressedType,
127                                             1.0, qmin, qmin, qmax, loc);
128   }
129 
130   double scale;
131   int64_t nudgedZeroPoint;
132   getNudgedScaleAndZeroPoint(qmin, qmax, rmin, rmax, scale, nudgedZeroPoint);
133 
134   return UniformQuantizedType::getChecked(flags, storageType, expressedType,
135                                           scale, nudgedZeroPoint, qmin, qmax,
136                                           loc);
137 }
138 
fakeQuantAttrsToType(Location loc,unsigned numBits,int32_t quantizedDimension,ArrayRef<double> rmins,ArrayRef<double> rmaxs,bool narrowRange,Type expressedType,bool isSigned)139 UniformQuantizedPerAxisType mlir::quant::fakeQuantAttrsToType(
140     Location loc, unsigned numBits, int32_t quantizedDimension,
141     ArrayRef<double> rmins, ArrayRef<double> rmaxs, bool narrowRange,
142     Type expressedType, bool isSigned) {
143   size_t axis_size = rmins.size();
144   if (axis_size != rmaxs.size()) {
145     return (emitError(loc, "mismatched per-axis min and max size: ")
146                 << axis_size << " vs. " << rmaxs.size(),
147             nullptr);
148   }
149 
150   MLIRContext *ctx = expressedType.getContext();
151   Type storageType;
152   int64_t qmin;
153   int64_t qmax;
154   if (getDefaultStorageParams(numBits, narrowRange, isSigned, ctx, storageType,
155                               qmin, qmax)) {
156     return (emitError(loc, "unsupported FakeQuant number of bits: ") << numBits,
157             nullptr);
158   }
159 
160   SmallVector<double, 4> scales;
161   SmallVector<int64_t, 4> zeroPoints;
162   scales.reserve(axis_size);
163   zeroPoints.reserve(axis_size);
164   for (size_t axis = 0; axis != axis_size; ++axis) {
165     double rmin = rmins[axis];
166     double rmax = rmaxs[axis];
167     if (std::fabs(rmax - rmin) < std::numeric_limits<double>::epsilon()) {
168       scales.push_back(1.0);
169       zeroPoints.push_back(qmin);
170       continue;
171     }
172 
173     double scale;
174     int64_t nudgedZeroPoint;
175     getNudgedScaleAndZeroPoint(qmin, qmax, rmin, rmax, scale, nudgedZeroPoint);
176     scales.push_back(scale);
177     zeroPoints.push_back(nudgedZeroPoint);
178   }
179 
180   unsigned flags = isSigned ? QuantizationFlags::Signed : 0;
181   return UniformQuantizedPerAxisType::getChecked(
182       flags, storageType, expressedType, scales, zeroPoints, quantizedDimension,
183       qmin, qmax, loc);
184 }
185