//===- TypeParser.cpp - Quantization Type Parser ----------------*- C++ -*-===//
2 //
3 // Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions.
4 // See https://llvm.org/LICENSE.txt for license information.
5 // SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception
6 //
7 //===----------------------------------------------------------------------===//
8 
9 #include "mlir/Dialect/Quant/QuantOps.h"
10 #include "mlir/Dialect/Quant/QuantTypes.h"
11 #include "mlir/IR/BuiltinTypes.h"
12 #include "mlir/IR/DialectImplementation.h"
13 #include "mlir/IR/Location.h"
14 #include "mlir/IR/Types.h"
15 #include "llvm/ADT/APFloat.h"
16 #include "llvm/ADT/StringSwitch.h"
17 #include "llvm/Support/Format.h"
18 #include "llvm/Support/MathExtras.h"
19 #include "llvm/Support/SourceMgr.h"
20 #include "llvm/Support/raw_ostream.h"
21 
22 using namespace mlir;
23 using namespace quant;
24 
parseStorageType(DialectAsmParser & parser,bool & isSigned)25 static IntegerType parseStorageType(DialectAsmParser &parser, bool &isSigned) {
26   auto typeLoc = parser.getCurrentLocation();
27   IntegerType type;
28 
29   // Parse storage type (alpha_ident, integer_literal).
30   StringRef identifier;
31   unsigned storageTypeWidth = 0;
32   if (failed(parser.parseOptionalKeyword(&identifier))) {
33     // If we didn't parse a keyword, this must be a signed type.
34     if (parser.parseType(type))
35       return nullptr;
36     isSigned = true;
37     storageTypeWidth = type.getWidth();
38 
39     // Otherwise, this must be an unsigned integer (`u` integer-literal).
40   } else {
41     if (!identifier.consume_front("u")) {
42       parser.emitError(typeLoc, "illegal storage type prefix");
43       return nullptr;
44     }
45     if (identifier.getAsInteger(10, storageTypeWidth)) {
46       parser.emitError(typeLoc, "expected storage type width");
47       return nullptr;
48     }
49     isSigned = false;
50     type = parser.getBuilder().getIntegerType(storageTypeWidth);
51   }
52 
53   if (storageTypeWidth == 0 ||
54       storageTypeWidth > QuantizedType::MaxStorageBits) {
55     parser.emitError(typeLoc, "illegal storage type size: ")
56         << storageTypeWidth;
57     return nullptr;
58   }
59 
60   return type;
61 }
62 
parseStorageRange(DialectAsmParser & parser,IntegerType storageType,bool isSigned,int64_t & storageTypeMin,int64_t & storageTypeMax)63 static ParseResult parseStorageRange(DialectAsmParser &parser,
64                                      IntegerType storageType, bool isSigned,
65                                      int64_t &storageTypeMin,
66                                      int64_t &storageTypeMax) {
67   int64_t defaultIntegerMin = QuantizedType::getDefaultMinimumForInteger(
68       isSigned, storageType.getWidth());
69   int64_t defaultIntegerMax = QuantizedType::getDefaultMaximumForInteger(
70       isSigned, storageType.getWidth());
71   if (failed(parser.parseOptionalLess())) {
72     storageTypeMin = defaultIntegerMin;
73     storageTypeMax = defaultIntegerMax;
74     return success();
75   }
76 
77   // Explicit storage min and storage max.
78   llvm::SMLoc minLoc = parser.getCurrentLocation(), maxLoc;
79   if (parser.parseInteger(storageTypeMin) || parser.parseColon() ||
80       parser.getCurrentLocation(&maxLoc) ||
81       parser.parseInteger(storageTypeMax) || parser.parseGreater())
82     return failure();
83   if (storageTypeMin < defaultIntegerMin) {
84     return parser.emitError(minLoc, "illegal storage type minimum: ")
85            << storageTypeMin;
86   }
87   if (storageTypeMax > defaultIntegerMax) {
88     return parser.emitError(maxLoc, "illegal storage type maximum: ")
89            << storageTypeMax;
90   }
91   return success();
92 }
93 
parseExpressedTypeAndRange(DialectAsmParser & parser,double & min,double & max)94 static FloatType parseExpressedTypeAndRange(DialectAsmParser &parser,
95                                             double &min, double &max) {
96   auto typeLoc = parser.getCurrentLocation();
97   FloatType type;
98 
99   if (failed(parser.parseType(type))) {
100     parser.emitError(typeLoc, "expecting float expressed type");
101     return nullptr;
102   }
103 
104   // Calibrated min and max values.
105   if (parser.parseLess() || parser.parseFloat(min) || parser.parseColon() ||
106       parser.parseFloat(max) || parser.parseGreater()) {
107     parser.emitError(typeLoc, "calibrated values must be present");
108     return nullptr;
109   }
110   return type;
111 }
112 
113 /// Parses an AnyQuantizedType.
114 ///
115 ///   any ::= `any<` storage-spec (expressed-type-spec)?`>`
116 ///   storage-spec ::= storage-type (`<` storage-range `>`)?
117 ///   storage-range ::= integer-literal `:` integer-literal
118 ///   storage-type ::= (`i` | `u`) integer-literal
119 ///   expressed-type-spec ::= `:` `f` integer-literal
parseAnyType(DialectAsmParser & parser,Location loc)120 static Type parseAnyType(DialectAsmParser &parser, Location loc) {
121   IntegerType storageType;
122   FloatType expressedType;
123   unsigned typeFlags = 0;
124   int64_t storageTypeMin;
125   int64_t storageTypeMax;
126 
127   // Type specification.
128   if (parser.parseLess())
129     return nullptr;
130 
131   // Storage type.
132   bool isSigned = false;
133   storageType = parseStorageType(parser, isSigned);
134   if (!storageType) {
135     return nullptr;
136   }
137   if (isSigned) {
138     typeFlags |= QuantizationFlags::Signed;
139   }
140 
141   // Storage type range.
142   if (parseStorageRange(parser, storageType, isSigned, storageTypeMin,
143                         storageTypeMax)) {
144     return nullptr;
145   }
146 
147   // Optional expressed type.
148   if (succeeded(parser.parseOptionalColon())) {
149     if (parser.parseType(expressedType)) {
150       return nullptr;
151     }
152   }
153 
154   if (parser.parseGreater()) {
155     return nullptr;
156   }
157 
158   return AnyQuantizedType::getChecked(typeFlags, storageType, expressedType,
159                                       storageTypeMin, storageTypeMax, loc);
160 }
161 
parseQuantParams(DialectAsmParser & parser,double & scale,int64_t & zeroPoint)162 static ParseResult parseQuantParams(DialectAsmParser &parser, double &scale,
163                                     int64_t &zeroPoint) {
164   // scale[:zeroPoint]?
165   // scale.
166   if (parser.parseFloat(scale))
167     return failure();
168 
169   // zero point.
170   zeroPoint = 0;
171   if (failed(parser.parseOptionalColon())) {
172     // Default zero point.
173     return success();
174   }
175 
176   return parser.parseInteger(zeroPoint);
177 }
178 
179 /// Parses a UniformQuantizedType.
180 ///
181 ///   uniform_type ::= uniform_per_layer
182 ///                  | uniform_per_axis
183 ///   uniform_per_layer ::= `uniform<` storage-spec expressed-type-spec
184 ///                          `,` scale-zero `>`
185 ///   uniform_per_axis ::= `uniform<` storage-spec expressed-type-spec
186 ///                        axis-spec `,` scale-zero-list `>`
187 ///   storage-spec ::= storage-type (`<` storage-range `>`)?
188 ///   storage-range ::= integer-literal `:` integer-literal
189 ///   storage-type ::= (`i` | `u`) integer-literal
190 ///   expressed-type-spec ::= `:` `f` integer-literal
191 ///   axis-spec ::= `:` integer-literal
192 ///   scale-zero ::= float-literal `:` integer-literal
193 ///   scale-zero-list ::= `{` scale-zero (`,` scale-zero)* `}`
parseUniformType(DialectAsmParser & parser,Location loc)194 static Type parseUniformType(DialectAsmParser &parser, Location loc) {
195   IntegerType storageType;
196   FloatType expressedType;
197   unsigned typeFlags = 0;
198   int64_t storageTypeMin;
199   int64_t storageTypeMax;
200   bool isPerAxis = false;
201   int32_t quantizedDimension;
202   SmallVector<double, 1> scales;
203   SmallVector<int64_t, 1> zeroPoints;
204 
205   // Type specification.
206   if (parser.parseLess()) {
207     return nullptr;
208   }
209 
210   // Storage type.
211   bool isSigned = false;
212   storageType = parseStorageType(parser, isSigned);
213   if (!storageType) {
214     return nullptr;
215   }
216   if (isSigned) {
217     typeFlags |= QuantizationFlags::Signed;
218   }
219 
220   // Storage type range.
221   if (parseStorageRange(parser, storageType, isSigned, storageTypeMin,
222                         storageTypeMax)) {
223     return nullptr;
224   }
225 
226   // Expressed type.
227   if (parser.parseColon() || parser.parseType(expressedType)) {
228     return nullptr;
229   }
230 
231   // Optionally parse quantized dimension for per-axis quantization.
232   if (succeeded(parser.parseOptionalColon())) {
233     if (parser.parseInteger(quantizedDimension))
234       return nullptr;
235     isPerAxis = true;
236   }
237 
238   // Comma leading into range_spec.
239   if (parser.parseComma()) {
240     return nullptr;
241   }
242 
243   // Parameter specification.
244   // For per-axis, ranges are in a {} delimitted list.
245   if (isPerAxis) {
246     if (parser.parseLBrace()) {
247       return nullptr;
248     }
249   }
250 
251   // Parse scales/zeroPoints.
252   llvm::SMLoc scaleZPLoc = parser.getCurrentLocation();
253   do {
254     scales.resize(scales.size() + 1);
255     zeroPoints.resize(zeroPoints.size() + 1);
256     if (parseQuantParams(parser, scales.back(), zeroPoints.back())) {
257       return nullptr;
258     }
259   } while (isPerAxis && succeeded(parser.parseOptionalComma()));
260 
261   if (isPerAxis) {
262     if (parser.parseRBrace()) {
263       return nullptr;
264     }
265   }
266 
267   if (parser.parseGreater()) {
268     return nullptr;
269   }
270 
271   if (!isPerAxis && scales.size() > 1) {
272     return (parser.emitError(scaleZPLoc,
273                              "multiple scales/zeroPoints provided, but "
274                              "quantizedDimension wasn't specified"),
275             nullptr);
276   }
277 
278   if (isPerAxis) {
279     ArrayRef<double> scalesRef(scales.begin(), scales.end());
280     ArrayRef<int64_t> zeroPointsRef(zeroPoints.begin(), zeroPoints.end());
281     return UniformQuantizedPerAxisType::getChecked(
282         typeFlags, storageType, expressedType, scalesRef, zeroPointsRef,
283         quantizedDimension, storageTypeMin, storageTypeMax, loc);
284   }
285 
286   return UniformQuantizedType::getChecked(typeFlags, storageType, expressedType,
287                                           scales.front(), zeroPoints.front(),
288                                           storageTypeMin, storageTypeMax, loc);
289 }
290 
291 /// Parses an CalibratedQuantizedType.
292 ///
293 ///   calibrated ::= `calibrated<` expressed-spec `>`
294 ///   expressed-spec ::= expressed-type `<` calibrated-range `>`
295 ///   expressed-type ::= `f` integer-literal
296 ///   calibrated-range ::= float-literal `:` float-literal
parseCalibratedType(DialectAsmParser & parser,Location loc)297 static Type parseCalibratedType(DialectAsmParser &parser, Location loc) {
298   FloatType expressedType;
299   double min;
300   double max;
301 
302   // Type specification.
303   if (parser.parseLess())
304     return nullptr;
305 
306   // Expressed type.
307   expressedType = parseExpressedTypeAndRange(parser, min, max);
308   if (!expressedType) {
309     return nullptr;
310   }
311 
312   if (parser.parseGreater()) {
313     return nullptr;
314   }
315 
316   return CalibratedQuantizedType::getChecked(expressedType, min, max, loc);
317 }
318 
319 /// Parse a type registered to this dialect.
parseType(DialectAsmParser & parser) const320 Type QuantizationDialect::parseType(DialectAsmParser &parser) const {
321   Location loc = parser.getEncodedSourceLoc(parser.getNameLoc());
322 
323   // All types start with an identifier that we switch on.
324   StringRef typeNameSpelling;
325   if (failed(parser.parseKeyword(&typeNameSpelling)))
326     return nullptr;
327 
328   if (typeNameSpelling == "uniform")
329     return parseUniformType(parser, loc);
330   if (typeNameSpelling == "any")
331     return parseAnyType(parser, loc);
332   if (typeNameSpelling == "calibrated")
333     return parseCalibratedType(parser, loc);
334 
335   parser.emitError(parser.getNameLoc(),
336                    "unknown quantized type " + typeNameSpelling);
337   return nullptr;
338 }
339 
printStorageType(QuantizedType type,DialectAsmPrinter & out)340 static void printStorageType(QuantizedType type, DialectAsmPrinter &out) {
341   // storage type
342   unsigned storageWidth = type.getStorageTypeIntegralWidth();
343   bool isSigned = type.isSigned();
344   if (isSigned) {
345     out << "i" << storageWidth;
346   } else {
347     out << "u" << storageWidth;
348   }
349 
350   // storageTypeMin and storageTypeMax if not default.
351   int64_t defaultIntegerMin =
352       QuantizedType::getDefaultMinimumForInteger(isSigned, storageWidth);
353   int64_t defaultIntegerMax =
354       QuantizedType::getDefaultMaximumForInteger(isSigned, storageWidth);
355   if (defaultIntegerMin != type.getStorageTypeMin() ||
356       defaultIntegerMax != type.getStorageTypeMax()) {
357     out << "<" << type.getStorageTypeMin() << ":" << type.getStorageTypeMax()
358         << ">";
359   }
360 }
361 
printQuantParams(double scale,int64_t zeroPoint,DialectAsmPrinter & out)362 static void printQuantParams(double scale, int64_t zeroPoint,
363                              DialectAsmPrinter &out) {
364   out << scale;
365   if (zeroPoint != 0) {
366     out << ":" << zeroPoint;
367   }
368 }
369 
370 /// Helper that prints a AnyQuantizedType.
printAnyQuantizedType(AnyQuantizedType type,DialectAsmPrinter & out)371 static void printAnyQuantizedType(AnyQuantizedType type,
372                                   DialectAsmPrinter &out) {
373   out << "any<";
374   printStorageType(type, out);
375   if (Type expressedType = type.getExpressedType()) {
376     out << ":" << expressedType;
377   }
378   out << ">";
379 }
380 
381 /// Helper that prints a UniformQuantizedType.
printUniformQuantizedType(UniformQuantizedType type,DialectAsmPrinter & out)382 static void printUniformQuantizedType(UniformQuantizedType type,
383                                       DialectAsmPrinter &out) {
384   out << "uniform<";
385   printStorageType(type, out);
386   out << ":" << type.getExpressedType() << ", ";
387 
388   // scheme specific parameters
389   printQuantParams(type.getScale(), type.getZeroPoint(), out);
390   out << ">";
391 }
392 
393 /// Helper that prints a UniformQuantizedPerAxisType.
printUniformQuantizedPerAxisType(UniformQuantizedPerAxisType type,DialectAsmPrinter & out)394 static void printUniformQuantizedPerAxisType(UniformQuantizedPerAxisType type,
395                                              DialectAsmPrinter &out) {
396   out << "uniform<";
397   printStorageType(type, out);
398   out << ":" << type.getExpressedType() << ":";
399   out << type.getQuantizedDimension();
400   out << ", ";
401 
402   // scheme specific parameters
403   ArrayRef<double> scales = type.getScales();
404   ArrayRef<int64_t> zeroPoints = type.getZeroPoints();
405   out << "{";
406   llvm::interleave(
407       llvm::seq<size_t>(0, scales.size()), out,
408       [&](size_t index) {
409         printQuantParams(scales[index], zeroPoints[index], out);
410       },
411       ",");
412   out << "}>";
413 }
414 
415 /// Helper that prints a CalibratedQuantizedType.
printCalibratedQuantizedType(CalibratedQuantizedType type,DialectAsmPrinter & out)416 static void printCalibratedQuantizedType(CalibratedQuantizedType type,
417                                          DialectAsmPrinter &out) {
418   out << "calibrated<" << type.getExpressedType();
419   out << "<" << type.getMin() << ":" << type.getMax() << ">";
420   out << ">";
421 }
422 
423 /// Print a type registered to this dialect.
printType(Type type,DialectAsmPrinter & os) const424 void QuantizationDialect::printType(Type type, DialectAsmPrinter &os) const {
425   if (auto anyType = type.dyn_cast<AnyQuantizedType>())
426     printAnyQuantizedType(anyType, os);
427   else if (auto uniformType = type.dyn_cast<UniformQuantizedType>())
428     printUniformQuantizedType(uniformType, os);
429   else if (auto perAxisType = type.dyn_cast<UniformQuantizedPerAxisType>())
430     printUniformQuantizedPerAxisType(perAxisType, os);
431   else if (auto calibratedType = type.dyn_cast<CalibratedQuantizedType>())
432     printCalibratedQuantizedType(calibratedType, os);
433   else
434     llvm_unreachable("Unhandled quantized type");
435 }
436