1//===- QuantOps.td - Quantization operation definition -----*- tablegen -*-===//
2//
3// Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions.
4// See https://llvm.org/LICENSE.txt for license information.
5// SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception
6//
7//===----------------------------------------------------------------------===//
8//
9// This is the operation definition file for Quantization.
10//
11//===----------------------------------------------------------------------===//
12
13#ifndef DIALECT_QUANT_QUANT_OPS_
14#define DIALECT_QUANT_QUANT_OPS_
15
16include "mlir/Dialect/Quant/QuantOpsBase.td"
17include "mlir/Interfaces/SideEffectInterfaces.td"
18
19//===----------------------------------------------------------------------===//
20// Base classes
21//===----------------------------------------------------------------------===//
22
// Base class for all ops in the Quantization dialect. `mnemonic` is the op
// name following the dialect prefix and `traits` is the op's trait list.
class quant_Op<string mnemonic, list<OpTrait> traits> :
    Op<Quantization_Dialect, mnemonic, traits>;
25
26//===----------------------------------------------------------------------===//
27// Quantization casts
28//===----------------------------------------------------------------------===//
29// A QuantizeCast (qcast) represents a potential type shift from a quantizable
30// type to a quantized type.
31//
32// At runtime, a qcast will apply the transformation expressed by its
33// operand and result type. For flexibility during transformation, it is also
34// possible to have a qcast that performs no transformation (both its
35// operand and result type are quantizable).
36//
37// A qcast will typically originate from either:
38//   a) An expressed or implied constraint in the source dialect which signals
39//      that a certain level of quantization is possible or required.
40//   b) An inference made by a quantization algorithm indicating that a
41//      quantized representation may be acceptable.
42//
43// Especially early in transformation, it is common to have pairs of
44// qcast/dcast at points where a transition to a quantized type is
45// required. In addition, it is also common to have an identity qcast
46// (where the operand and result type are not quantized) at all points where
47// it is legal to use a quantized representation (but is not known to be
48// acceptable).
def quant_QuantizeCastOp : quant_Op<"qcast", [NoSideEffect]> {
  // Added for consistency: every other op in this file defines a summary.
  let summary = "convert a quantizable type to a quantized type";

  // Operand and result both use quant_RealValueType so that identity qcasts
  // (neither side quantized) remain representable during transformation.
  let arguments = (ins quant_RealValueType:$arg);
  let results = (outs quant_RealValueType);
}
53
54// A DequantizeCast op (dcast) represents the inverse of a qcast,
55// converting back from a quantized to quantizable (expressed) type.
56//
57// Like qcasts, a dcast is allowed to have both its operand and result
58// as non quantized types. This facilitates transformations and marks edges
59// where the computation must be carried out in the expressed type.
60//
61// Especially early in transformation, it is common to have dcasts on
62// all operands to ops that must operate with the expressed type (typically
63// math ops prior to lowering to target-specific, quantized kernels).
def quant_DequantizeCastOp : quant_Op<"dcast", [NoSideEffect]> {
  // Added for consistency: every other op in this file defines a summary.
  let summary = "convert back from a quantized to a quantizable (expressed) type";

  // Like qcast, both sides may be non-quantized, marking edges where the
  // computation must be carried out in the expressed type.
  let arguments = (ins quant_RealValueType:$arg);
  let results = (outs quant_RealValueType);
}
68
69// A StorageCast (scast) represents a cast from or to a type based on the
70// storage type and a type based on a corresponding quantized type.
71//
// This op exists to ensure type coherency between parts of the computation
73// which are operating directly on an underlying storage type and those which
74// operate on quantized values.
75//
76// Examples from storage to quantized type:
77//   i8 -> !quant<"uniform[i8:f32]{1.0}">
78//   tensor<4xi8> -> tensor<4x!quant<"uniform[i8:f32]{1.0}">>
79//   vector<4xi8> -> vector<4x!quant<"uniform[i8:f32]{1.0}">>
def quant_StorageCastOp : quant_Op<"scast", [NoSideEffect]> {
  // Added for consistency: every other op in this file defines a summary.
  let summary = "cast between a quantized type and its underlying storage type";

  let arguments = (ins quant_RealOrStorageValueType:$arg);
  let results = (outs quant_RealOrStorageValueType);
  // Folding (e.g. collapsing scast pairs) is implemented in the dialect's
  // C++ sources via the generated fold() hook.
  let hasFolder = 1;
}
85
86// A QuantizeRegion (region) represents a quantization unit which wraps
87// high-precision ops with quantization specifications for all the inputs
88// and outputs. Some quantization specifications can be undetermined and
89// derived from other ports by the target specification of the kernel.
def quant_QuantizeRegionOp : quant_Op<"region", [
    NoSideEffect,
    IsolatedFromAbove,
    SingleBlockImplicitTerminator<"ReturnOp">]> {
  let summary = [{
    The `region` operation wraps high-precision ops as a logical low-precision
    quantized kernel.
  }];

  // `input_specs`/`output_specs` carry quantization spec types for the
  // region's ports (some may be underdetermined and derived by the target —
  // see the comment above this def); `logical_kernel` names the kernel this
  // region represents.
  let arguments = (ins Variadic<AnyType>:$inputs,
                    TypeArrayAttr:$input_specs,
                    TypeArrayAttr:$output_specs,
                    StrAttr:$logical_kernel);
  let results = (outs Variadic<AnyType>:$outputs);
  // Exactly one block, implicitly terminated by quant.return (see traits).
  let regions = (region SizedRegion<1>:$body);
  // Custom verification lives in the dialect's C++ sources.
  let verifier = [{ return verifyRegionOp(*this); }];
}
107
def quant_ReturnOp : quant_Op<"return", [Terminator]> {
  let summary = [{
    The `return` operation terminates a quantize region and returns values.
  }];

  // The returned values become the results of the enclosing quant.region op.
  let arguments = (ins Variadic<AnyTensor>:$results);
}
115
116//===----------------------------------------------------------------------===//
117// Training integration and instrumentation ops
118//===----------------------------------------------------------------------===//
119
def quant_ConstFakeQuant : quant_Op<"const_fake_quant",
                                    [SameOperandsAndResultType, NoSideEffect]> {
  let summary = [{
    Simulates the effect of uniform quantization with const range.
  }];

  let description = [{
    Given a const min, max, num_bits and narrow_range attribute, applies the
    same uniform quantization simulation as is done by the TensorFlow
    fake_quant_with_min_max_args op. See the fakeQuantAttrsToType() utility
    method and the quant-convert-simulated-quantization pass for further details.
  }];

  let arguments = (ins
    F32Tensor:$inputs,
    // The constant range [min, max] to quantize against.
    F32Attr:$min,
    F32Attr:$max,
    // The bitwidth of the quantization; between 2 and 16, inclusive.
    I64Attr:$num_bits,
    // Quantization range starts from 0 or 1; starts from 1 if true.
    DefaultValuedAttr<BoolAttr, "false">:$narrow_range,
    // The sign of the quantization.
    DefaultValuedAttr<BoolAttr, "false">:$is_signed
  );

  // Same shape/type as $inputs (enforced by SameOperandsAndResultType).
  let results = (outs
    F32Tensor:$outputs
  );
}
149
def quant_ConstFakeQuantPerAxis : quant_Op<"const_fake_quant_per_axis",
                                    [SameOperandsAndResultType, NoSideEffect]> {
  let summary = [{
    Simulates the effect of per axis uniform quantization with const range.
  }];

  let description = [{
    Given a const min, max, num_bits and narrow_range attribute, applies the
    same per axis uniform quantization simulation as is done by the TensorFlow
    fake_quant_with_min_max_vars_per_channel op. See the fakeQuantAttrsToType()
    utility method and the quant-convert-simulated-quantization pass for further
    details.
  }];

  let arguments = (ins
    F32Tensor:$inputs,
    // Per-axis constant ranges; presumably one [min, max] entry per slice
    // along $axis — TODO confirm against fakeQuantAttrsToType().
    F32ArrayAttr:$min,
    F32ArrayAttr:$max,
    // The quantized dimension of the inputs tensor.
    I64Attr:$axis,
    // The bitwidth of the quantization; between 2 and 16, inclusive.
    I64Attr:$num_bits,
    // Quantization range starts from 0 or 1; starts from 1 if true.
    DefaultValuedAttr<BoolAttr, "false">:$narrow_range,
    // The sign of the quantization.
    DefaultValuedAttr<BoolAttr, "false">:$is_signed
  );

  // Same shape/type as $inputs (enforced by SameOperandsAndResultType).
  let results = (outs
    F32Tensor:$outputs
  );
}
182
def quant_StatisticsRefOp : quant_Op<"stats_ref", [SameOperandsAndResultType]> {
  let summary = "Indicates that statistics are resolved by reference.";

  // Fixed doubled word in the description ("about about").
  let description = [{
    This op acts as an identity that, when encountered at runtime, should result
    in statistics being collected about the value of its operand/result.
    Such statistics will be stored with the provided key, allowing this node
    to later be converted to a 'stats' op if statistics with that key have been
    encountered.
  }];

  let arguments = (ins
    quant_RealValueType:$arg,
    // Key under which collected statistics are stored and later looked up.
    StrAttr:$statsKey
  );
  let results = (outs quant_RealValueType);
}
200
def quant_StatisticsOp : quant_Op<"stats", [SameOperandsAndResultType]> {
  let summary = "Identity op which associates statistics with the value.";

  // Fixed "splitted" -> "split" in the description; verifier logic unchanged.
  let description = [{
    Associates statistics about the runtime ranges of values observed for
    evaluations of this node.

    Statistics about the entire type are reported in the 'layerStats' attribute
    and those for each axis, in the (optional) `axisStats` attribute. The
    interpretation of each is determined by the last dimension of its shape.
    Currently, only dim=2 is supported, which is interpreted as [min, max].

    `layerStats` must be a rank 1 tensor: [2]
    `axisStats` must be a rank 2 tensor: [N, 2], where N=the slice size
      split by the `axis` dimension. For example:

    ```
    <?x?x3x2>, axis=3 => N=2
    <?x?x3x2>, axis=2 => N=6
    ```
  }];

  let arguments = (ins
    quant_RealValueType:$arg,
    // Whole-tensor [min, max] statistics; must be a float tensor of shape [2].
    ElementsAttr:$layerStats,
    // Optional per-axis statistics of shape [N, 2]; requires $axis.
    OptionalAttr<ElementsAttr>:$axisStats,
    OptionalAttr<I64Attr>:$axis);
  let results = (outs quant_RealValueType);

  let verifier = [{
    auto tensorArg = arg().getType().dyn_cast<TensorType>();
    if (!tensorArg) return emitOpError("arg needs to be tensor type.");

    // Verify layerStats attribute: float element type and shape [2].
    {
      auto layerStatsType = layerStats().getType();
      if (!layerStatsType.getElementType().isa<FloatType>()) {
        return emitOpError(
            "layerStats must have a floating point element type");
      }
      if (layerStatsType.getRank() != 1 || layerStatsType.getDimSize(0) != 2) {
        return emitOpError("layerStats must have shape [2]");
      }
    }
    // Verify axisStats (optional) attribute.
    if (axisStats()) {
      if (!axis()) return emitOpError("axis must be specified for axisStats");

      // N = product of the arg's dimensions starting at `axis`; axisStats
      // must carry exactly one [min, max] row per such slice.
      auto shape = tensorArg.getShape();
      auto argSliceSize = std::accumulate(std::next(shape.begin(),
        *axis()), shape.end(), 1, std::multiplies<int64_t>());

      auto axisStatsType = axisStats()->getType();
      if (!axisStatsType.getElementType().isa<FloatType>()) {
        return emitOpError("axisStats must have a floating point element type");
      }
      if (axisStatsType.getRank() != 2 ||
          axisStatsType.getDimSize(1) != 2 ||
          axisStatsType.getDimSize(0) != argSliceSize) {
        return emitOpError("axisStats must have shape [N,2] "
                           "where N = the slice size defined by the axis dim");
      }
    }
    return success();
  }];
}
267
def quant_CoupledRefOp : quant_Op<"coupled_ref", [SameOperandsAndResultType]> {
  let summary = [{
    Indicates that one point of the computation is coupled to another.
  }];

  // Fixed subject-verb agreement in the description ("relationships ... are").
  let description = [{
    Ordinarily, relationships between ops for the purposes of determining
    compatible quantized types are explicit based on the use-def chain. However,
    in some situations, a use may be separated from its def by arbitrary
    external connections. In such a case, during analysis, all coupled_ref
    nodes in a module which share a coupledKey will be considered to be
    directly connected as via an identity op for the purpose of type inference.
  }];

  let arguments = (ins
    quant_RealValueType:$arg,
    // Key linking this coupled_ref to its peers for type inference.
    StrAttr:$coupledKey);
  let results = (outs quant_RealValueType);
}
287
288#endif // DIALECT_QUANT_QUANT_OPS_
289