1 //===- QuantTypes.h - Quantization Ops and Types ----------------*- C++ -*-===//
2 //
3 // Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions.
4 // See https://llvm.org/LICENSE.txt for license information.
5 // SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception
6 //
7 //===----------------------------------------------------------------------===//
8 
9 #ifndef MLIR_DIALECT_QUANT_QUANT_TYPES_H_
10 #define MLIR_DIALECT_QUANT_QUANT_TYPES_H_
11 
12 #include "mlir/IR/Attributes.h"
13 #include "mlir/IR/Builders.h"
14 #include "mlir/IR/BuiltinTypes.h"
15 #include "mlir/IR/Dialect.h"
16 #include "mlir/IR/OpDefinition.h"
17 #include "mlir/IR/Types.h"
18 #include "llvm/Support/MathExtras.h"
19 
20 namespace mlir {
21 namespace quant {
22 
// NOTE(review): QuantizedIntegerType is not referenced anywhere else in this
// header; the forward declaration is kept so existing users keep compiling.
class QuantizedIntegerType;

namespace detail {
// Storage (impl) classes for the quantized types declared below. Their
// definitions do not live in this header.
struct AnyQuantizedTypeStorage;
struct CalibratedQuantizedTypeStorage;
struct QuantizedTypeStorage;
struct UniformQuantizedPerAxisTypeStorage;
struct UniformQuantizedTypeStorage;
} // namespace detail
34 
/// Bit-mapped flags that describe properties of a quantized type. Values are
/// meant to be OR'ed together into the `flags` word passed to the type
/// constructors.
namespace QuantizationFlags {
enum FlagValue {
  // When set, the storage type is interpreted as a signed integer; when
  // clear (the default), it is interpreted as an unsigned value.
  Signed = 1 << 0,
};
} // namespace QuantizationFlags
43 
44 /// Base class for all quantized types known to this dialect.
45 /// All quantized types have:
46 ///   - storageType: The (narrower) numeric type that is being used to
47 ///     approximate some expressed type.
48 ///   - expressedType: The type that is being approximated.
49 ///
50 /// The base class provides generic support for manipulating the types based
51 /// on these fields.
52 class QuantizedType : public Type {
53 public:
54   using ImplType = detail::QuantizedTypeStorage;
55   using Type::Type;
56 
57   /// The maximum number of bits supported for storage types.
58   static constexpr unsigned MaxStorageBits = 32;
59 
60   static LogicalResult
61   verifyConstructionInvariants(Location loc, unsigned flags, Type storageType,
62                                Type expressedType, int64_t storageTypeMin,
63                                int64_t storageTypeMax);
64 
65   /// Support method to enable LLVM-style type casting.
66   static bool classof(Type type);
67 
68   /// Gets the minimum possible stored by a storageType. storageTypeMin must
69   /// be greater than or equal to this value.
getDefaultMinimumForInteger(bool isSigned,unsigned integralWidth)70   static int64_t getDefaultMinimumForInteger(bool isSigned,
71                                              unsigned integralWidth) {
72     if (isSigned) {
73       return llvm::minIntN(integralWidth);
74     }
75     return 0;
76   }
77 
78   /// Gets the maximum possible stored by a storageType. storageTypeMax must
79   /// be less than or equal to this value.
getDefaultMaximumForInteger(bool isSigned,unsigned integralWidth)80   static int64_t getDefaultMaximumForInteger(bool isSigned,
81                                              unsigned integralWidth) {
82     if (isSigned) {
83       return llvm::maxIntN(integralWidth);
84     }
85     return llvm::maxUIntN(integralWidth);
86   }
87 
88   /// Gets the original expressed type that this quantized type approximates.
89   /// Note that this presumes that the quantized type was always derived from
90   /// a floating point type, which in the broadest definition, is not true (i.e.
91   /// it could be some form of integral, fixed type or affine type in its own
92   /// right); however, at the high level, no examples of such usage are
93   /// presently known and the restriction serves some useful purposes (such as
94   /// always being able to reverse a transformation or measure error). In most
95   /// cases, this will be f32.
96   Type getExpressedType() const;
97 
98   /// Gets the flags associated with this type. Typically a more specific
99   /// accessor is appropriate.
100   unsigned getFlags() const;
101 
102   // Convenience helpers.
103   /// Whether the storage type should be interpreted as a signed quantity
104   /// (true) or an unsigned value (false).
isSigned()105   bool isSigned() const {
106     return (getFlags() & QuantizationFlags::Signed) ==
107            QuantizationFlags::Signed;
108   }
109 
110   /// Gets the underlying type used for to store values. Note that this may
111   /// be signed or unsigned. Use the isSigned() accessor to differentiate.
112   Type getStorageType() const;
113 
114   /// The minimum value that storageType can take.
115   int64_t getStorageTypeMin() const;
116 
117   /// The maximum value that storageType can take.
118   int64_t getStorageTypeMax() const;
119 
120   /// Gets the integral bit width that the underlying storage type can exactly
121   /// represent. For integral storage types, this will just be their width.
122   unsigned getStorageTypeIntegralWidth() const;
123 
124   /// Returns whether the candidateExpressedType is a match for this
125   /// QuantizedType. This will be true if the candidate type is either a
126   /// primitive type or a container type whose element type equals this
127   /// QuantizedType's expressed type.
128   /// Examples of compatible candidateExpressedType:
129   ///   !quant.uniform<i8:f32, 1.0> =~ f32
130   ///   !quant.uniform<i8:f32, 1.0> =~ tensor<4xf32>
131   bool isCompatibleExpressedType(Type candidateExpressedType);
132 
133   /// Returns the element type as a QuantizedType or nullptr if it is not
134   /// a quantized type. If the type is primitive, returns that. If it is a
135   /// container (vector/tensor), return the element type.
136   /// Examples:
137   ///   !quant.uniform<i8:f32, 1.0> -> !quant.uniform<i8:f32, 1.0>
138   ///   tensor<4x!quant.uniform<i8:f32, 1.0> -> quant.uniform<i8:f32, 1.0>
139   static QuantizedType getQuantizedElementType(Type primitiveOrContainerType);
140 
141   /// Casts from a type based on the storageType to a corresponding type based
142   /// on this type (returns nullptr if the cast is not valid).
143   /// Examples:
144   ///   i8 -> !quant.uniform<i8:f32, 1.0>
145   ///   tensor<4xi8> -> tensor<4x!quant.uniform<i8:f32, 1.0}>>
146   ///   vector<4xi8> -> vector<4x!quant.uniform<i8:f32, 1.0>>
147   Type castFromStorageType(Type candidateType);
148 
149   /// Casts from a type based on a QuantizedType to a corresponding type based
150   /// on the storageType (returns nullptr if the cast is not valid).
151   /// This is the inverse of castFromStorageType().
152   static Type castToStorageType(Type quantizedType);
153 
154   /// Casts from a type based on the expressedType to a corresponding type based
155   /// on this type (returns nullptr if the cast is not valid).
156   /// Examples:
157   ///   f32 -> !quant.uniform<i8:f32, 1.0>
158   ///   tensor<4xf32> -> tensor<4x!quant.uniform<i8:f32, 1.0>>
159   ///   vector<4xf32> -> vector<4x!quant.uniform<i8:f32, 1.0>>
160   Type castFromExpressedType(Type candidateType);
161 
162   /// Casts from a type based on QuantizedType to a corresponding type based
163   /// on the expressedType (returns nullptr if the cast is not valid).
164   /// This is the inverse of castFromExpressedType.
165   static Type castToExpressedType(Type quantizedType);
166 
167   /// Casts from a type based on the expressedType to the equivalent type
168   /// based on storageType by way of this QuantizedType. Equivalent to:
169   ///   QuantizedType::castToStorageType(castFromExpressedType(candidateType))
170   /// (but with validity checks).
171   /// Example (for this = !quant.uniform<i8:f32, 1.0>):
172   ///   tensor<4xf32> -> tensor<4xi8>
173   Type castExpressedToStorageType(Type candidateType);
174 
175 private:
176   /// Hide the following methods inherited from `Type`. It is almost certainly
177   /// a bug to call them from a `QuantizedType` object. Users should call
178   /// `getStorageType` or `getExpressedType` to get the underlying types
179   /// they want to inspect.
180   using Type::isBF16;
181   using Type::isF16;
182   using Type::isF32;
183   using Type::isF64;
184   using Type::isIndex;
185   using Type::isInteger;
186 };
187 
188 /// A quantized type that maps storage to/from expressed types in an
189 /// unspecified way.
190 ///
191 /// Typical syntax:
192 ///   quant.any<i8:f32>
193 ///   quant.any<i8>
194 ///   quant.any<i8<-16,15>>
195 ///
196 /// Note that for the any type, the expressed type is optional.
197 class AnyQuantizedType
198     : public Type::TypeBase<AnyQuantizedType, QuantizedType,
199                             detail::AnyQuantizedTypeStorage> {
200 public:
201   using Base::Base;
202 
203   /// Gets an instance of the type with all parameters specified but not
204   /// checked.
205   static AnyQuantizedType get(unsigned flags, Type storageType,
206                               Type expressedType, int64_t storageTypeMin,
207                               int64_t storageTypeMax);
208 
209   /// Gets an instance of the type with all specified parameters checked.
210   /// Returns a nullptr convertible type on failure.
211   static AnyQuantizedType getChecked(unsigned flags, Type storageType,
212                                      Type expressedType, int64_t storageTypeMin,
213                                      int64_t storageTypeMax, Location location);
214 
215   /// Verifies construction invariants and issues errors/warnings.
216   static LogicalResult
217   verifyConstructionInvariants(Location loc, unsigned flags, Type storageType,
218                                Type expressedType, int64_t storageTypeMin,
219                                int64_t storageTypeMax);
220 };
221 
222 /// Represents a family of uniform, quantized types.
223 ///
224 /// Each instance of this type expresses a mapping between real values (most
225 /// often expressed in floating point f32) and quantized values (either fixed
226 /// point or affine).
227 ///
228 /// The relationship is:
229 ///     real_value = scale * (quantized_value - zero_point)
230 ///
231 /// It is used as part of high level graph transformations that have the goal
232 /// of re-expressing parts of a computation in terms of this common form for
233 /// more efficient execution at runtime. In addition, it is designed to be
234 /// expressive enough to facilitate lowering to precise types and operations
235 /// in target hardware.
236 ///
237 /// As a high-level type, focused on intermediate passes, this type holds
238 /// opinions consistent with high-level usage. If lowering math kernels below
239 /// the high level arithmetic ops (i.e. to LLVM IR or hardware specific
240 /// instruction sets), it is expected that the information expressed here
241 /// will be used to drive low level codegen and target specific type selection,
242 /// but this type will likely be erased in the process.
243 ///
244 /// Syntax synopsis:
245 ///   Per-layer, all parameters expressed:
246 ///     !quant<uniform[StorageType:ExpressedType]{Scale:ZeroPoint}>
247 ///   Per-layer, optional parameters omitted:
248 ///     !quant<uniform[StorageType]{Scale}>
249 ///
250 ///   StorageType: 'i'|'u' NumBits
251 ///   ExpressedType: 'f16', 'f32', 'bf16', 'f64'
252 ///   Scale: A legal double value
253 ///   ZeroPoint: An integer value
254 class UniformQuantizedType
255     : public Type::TypeBase<UniformQuantizedType, QuantizedType,
256                             detail::UniformQuantizedTypeStorage> {
257 public:
258   using Base::Base;
259 
260   /// Gets an instance of the type with all parameters specified but not
261   /// checked.
262   static UniformQuantizedType get(unsigned flags, Type storageType,
263                                   Type expressedType, double scale,
264                                   int64_t zeroPoint, int64_t storageTypeMin,
265                                   int64_t storageTypeMax);
266 
267   /// Gets an instance of the type with all specified parameters checked.
268   /// Returns a nullptr convertible type on failure.
269   static UniformQuantizedType
270   getChecked(unsigned flags, Type storageType, Type expressedType, double scale,
271              int64_t zeroPoint, int64_t storageTypeMin, int64_t storageTypeMax,
272              Location location);
273 
274   /// Verifies construction invariants and issues errors/warnings.
275   static LogicalResult
276   verifyConstructionInvariants(Location loc, unsigned flags, Type storageType,
277                                Type expressedType, double scale,
278                                int64_t zeroPoint, int64_t storageTypeMin,
279                                int64_t storageTypeMax);
280 
281   /// Gets the scale term. The scale designates the difference between the real
282   /// values corresponding to consecutive quantized values differing by 1.
283   double getScale() const;
284 
285   /// Gets the storage value corresponding to the real value 0 in the affine
286   /// equation.
287   int64_t getZeroPoint() const;
288 
289   // Fixed point values are real numbers divided by a scale.
290   // Currently, only signed storage types are treated as fixed point.
291   // A fixed point value can be obtained from an affine value by subtracting
292   // the zeroPoint.
293   // In the future, this may be explicit versus implied by type and zeroPoint.
isFixedPoint()294   bool isFixedPoint() const { return isSigned() && getZeroPoint() == 0; }
295 };
296 
297 /// Represents per-axis (also known as per-channel quantization).
298 ///
299 /// Syntax synopsis:
300 ///   Per-axis, all parameters expressed:
301 ///     !quant<uniform[StorageType:ExpressedType:QuantizedDim]{QuantParams}>
302 ///   Per-axis, optional parameters omitted:
303 ///     !quant<uniform[StorageType]{Scale}>
304 ///
305 ///   StorageType: 'i'|'u' NumBits
306 ///   ExpressedType: 'f16', 'f32', 'bf16', 'f64'
307 ///   QuantizedDim: An integer value
308 ///   QuantParams: (Scale ':' ZeroPoint)+
309 ///   Scale: A legal double value
310 ///   ZeroPoint: An integer value
311 class UniformQuantizedPerAxisType
312     : public Type::TypeBase<UniformQuantizedPerAxisType, QuantizedType,
313                             detail::UniformQuantizedPerAxisTypeStorage> {
314 public:
315   using Base::Base;
316 
317   /// Gets an instance of the type with all parameters specified but not
318   /// checked.
319   static UniformQuantizedPerAxisType
320   get(unsigned flags, Type storageType, Type expressedType,
321       ArrayRef<double> scales, ArrayRef<int64_t> zeroPoints,
322       int32_t quantizedDimension, int64_t storageTypeMin,
323       int64_t storageTypeMax);
324 
325   /// Gets an instance of the type with all specified parameters checked.
326   /// Returns a nullptr convertible type on failure.
327   static UniformQuantizedPerAxisType
328   getChecked(unsigned flags, Type storageType, Type expressedType,
329              ArrayRef<double> scales, ArrayRef<int64_t> zeroPoints,
330              int32_t quantizedDimension, int64_t storageTypeMin,
331              int64_t storageTypeMax, Location location);
332 
333   /// Verifies construction invariants and issues errors/warnings.
334   static LogicalResult
335   verifyConstructionInvariants(Location loc, unsigned flags, Type storageType,
336                                Type expressedType, ArrayRef<double> scales,
337                                ArrayRef<int64_t> zeroPoints,
338                                int32_t quantizedDimension,
339                                int64_t storageTypeMin, int64_t storageTypeMax);
340 
341   /// Gets the quantization scales. The scales designate the difference between
342   /// the real values corresponding to consecutive quantized values differing
343   /// by 1. The ith scale corresponds to the ith slice in the
344   /// quantized_dimension.
345   ArrayRef<double> getScales() const;
346 
347   /// Gets the storage values corresponding to the real value 0 in the affine
348   /// equation. The ith zero point corresponds to the ith slice in the
349   /// quantized_dimension.
350   ArrayRef<int64_t> getZeroPoints() const;
351 
352   /// Specifies the dimension of the Tensor's shape that the scales and
353   /// zero_points correspond to. For example, a tensor t, with dims=[4, 3, 2, 1]
354   /// with quantization params:
355   ///   scales=[1.0, 2.0, 3.0], zeroPoints=[1, 2, 3], quantizedDimension=1
356   /// will be quantized across the second dimension of t.
357   ///   t[:, 0, :, :] will have scale[0]=1.0, zero_point[0]=1
358   ///   t[:, 1, :, :] will have scale[1]=2.0, zero_point[0]=2
359   ///   t[:, 2, :, :] will have scale[2]=3.0, zero_point[0]=3
360   int32_t getQuantizedDimension() const;
361 
362   /// Fixed point values are real numbers divided by a scale.
363   /// Currently, only signed storage types are treated as fixed point.
364   /// A fixed point value can be obtained from an affine value by subtracting
365   /// the zeroPoint.
366   /// In the future, this may be explicit versus implied by type and zeroPoint.
isFixedPoint()367   bool isFixedPoint() const {
368     if (!isSigned())
369       return false;
370     return llvm::all_of(getZeroPoints(),
371                         [](int64_t zeroPoint) { return zeroPoint != 0; });
372   }
373 };
374 
375 /// A quantized type that infers its range from given min/max values.
376 ///
377 /// Typical syntax:
378 ///   quant.calibrated<f32<-0.922,0.981>>
379 class CalibratedQuantizedType
380     : public Type::TypeBase<CalibratedQuantizedType, QuantizedType,
381                             detail::CalibratedQuantizedTypeStorage> {
382 public:
383   using Base::Base;
384 
385   /// Gets an instance of the type with all parameters specified but not
386   /// checked.
387   static CalibratedQuantizedType get(Type expressedType, double min,
388                                      double max);
389 
390   /// Gets an instance of the type with all specified parameters checked.
391   /// Returns a nullptr convertible type on failure.
392   static CalibratedQuantizedType getChecked(Type expressedType, double min,
393                                             double max, Location location);
394 
395   /// Verifies construction invariants and issues errors/warnings.
396   static LogicalResult verifyConstructionInvariants(Location loc,
397                                                     Type expressedType,
398                                                     double min, double max);
399   double getMin() const;
400   double getMax() const;
401 };
402 
403 } // namespace quant
404 } // namespace mlir
405 
406 #endif // MLIR_DIALECT_QUANT_QUANT_TYPES_H_
407