1 /* Copyright 2020 The TensorFlow Authors. All Rights Reserved.
2 
3 Licensed under the Apache License, Version 2.0 (the "License");
4 you may not use this file except in compliance with the License.
5 You may obtain a copy of the License at
6 
7     http://www.apache.org/licenses/LICENSE-2.0
8 
9 Unless required by applicable law or agreed to in writing, software
10 distributed under the License is distributed on an "AS IS" BASIS,
11 WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
12 See the License for the specific language governing permissions and
13 limitations under the License.
14 ==============================================================================*/
15 
// Helper rewrite patterns and utilities for the prepare-quantize pass
// (LSTM and SVDF ops).
17 
18 #ifndef TENSORFLOW_COMPILER_MLIR_LITE_TRANSFORMS_PREPARE_QUANTIZE_HELPER
19 #define TENSORFLOW_COMPILER_MLIR_LITE_TRANSFORMS_PREPARE_QUANTIZE_HELPER
20 
#include <algorithm>
#include <cmath>
#include <limits>
#include <memory>
#include <string>
#include <utility>
#include <vector>
25 
26 #include "absl/container/flat_hash_set.h"
27 #include "llvm/ADT/SmallVector.h"
28 #include "llvm/Support/Casting.h"
29 #include "llvm/Support/MathExtras.h"
30 #include "mlir/Dialect/Quant/FakeQuantSupport.h"  // from @llvm-project
31 #include "mlir/Dialect/Quant/QuantOps.h"  // from @llvm-project
32 #include "mlir/Dialect/Quant/QuantTypes.h"  // from @llvm-project
33 #include "mlir/Dialect/StandardOps/IR/Ops.h"  // from @llvm-project
34 #include "mlir/IR/Attributes.h"  // from @llvm-project
35 #include "mlir/IR/BuiltinTypes.h"  // from @llvm-project
36 #include "mlir/IR/OpDefinition.h"  // from @llvm-project
37 #include "mlir/IR/PatternMatch.h"  // from @llvm-project
38 #include "mlir/IR/TypeUtilities.h"  // from @llvm-project
39 #include "mlir/IR/Value.h"  // from @llvm-project
40 #include "mlir/Support/LogicalResult.h"  // from @llvm-project
41 #include "tensorflow/compiler/mlir/lite/ir/tfl_ops.h"
42 #include "tensorflow/compiler/mlir/lite/quantization/quantization_config.h"
43 #include "tensorflow/compiler/mlir/lite/quantization/quantization_traits.h"
44 #include "tensorflow/compiler/mlir/lite/quantization/quantization_utils.h"
45 #include "tensorflow/core/framework/types.pb.h"
46 #include "tensorflow/lite/schema/schema_generated.h"
47 #include "tensorflow/lite/tools/optimize/operator_property.h"
48 
49 //===----------------------------------------------------------------------===//
50 // The prepare-quantize Pass for LSTM.
51 //
52 namespace mlir {
53 namespace TFL {
54 
// 2^15.
constexpr double power_of_two_scale = 32768.0;

// Attribute names of the LSTM intermediate tensors, indexed by intermediate
// tensor index. The order must stay identical to the one declared in
// //tensorflow/compiler/mlir/lite/ir/tfl_ops.td.
constexpr const char* intermediate_attributes[] = {
    "input_to_input_intermediate",          // index 0
    "input_to_forget_intermediate",         // index 1
    "input_to_cell_intermediate",           // index 2
    "input_to_output_intermediate",         // index 3
    "effective_hidden_scale_intermediate",  // index 4
};
62 
// Returns the smallest power of two that is not less than `value`.
//
// For finite positive inputs the bound is derived exactly from the binary
// exponent of `value` (std::frexp / std::ldexp). The formulation
// pow(2, ceil(log2(value))) can undershoot: for a value just above a power of
// two, log2 rounds down to the integer exponent and the resulting "bound" is
// smaller than `value`.
//
// Non-positive, infinite and NaN inputs keep the log2-based result (0 maps to
// 0, negative values and NaN map to NaN, +inf maps to +inf).
inline double PowerOfTwoBound(double value) {
  if (!(value > 0.0) || std::isinf(value)) {
    return std::pow(2, std::ceil(std::log2(value)));
  }
  int exponent = 0;
  // value == mantissa * 2^exponent with mantissa in [0.5, 1).
  const double mantissa = std::frexp(value, &exponent);
  // A mantissa of exactly 0.5 means `value` is already a power of two.
  return mantissa == 0.5 ? value : std::ldexp(1.0, exponent);
}
67 
68 // Returns the element type of LSTM's intermediate tensor designated by the
69 // index.
70 template <typename LstmOp>
GetIntermediateElementType(LstmOp op,int tensor_index)71 inline QuantizedType GetIntermediateElementType(LstmOp op, int tensor_index) {
72   if (tensor_index < 0 || tensor_index > 4) return nullptr;
73   TypeAttr attr = op->template getAttrOfType<TypeAttr>(
74       intermediate_attributes[tensor_index]);
75   if (!attr) {
76     return nullptr;
77   }
78   return QuantizedType::getQuantizedElementType(attr.getValue());
79 }
80 
81 namespace operator_property = ::tflite::optimize::operator_property;
82 using Q = quant::QuantizeCastOp;
83 using DQ = quant::DequantizeCastOp;
84 
85 template <typename LstmOp>
GetLstmProperty(LstmOp op,operator_property::OpVariant * lstm_variant,operator_property::OperatorProperty * op_property)86 LogicalResult GetLstmProperty(
87     LstmOp op, operator_property::OpVariant* lstm_variant,
88     operator_property::OperatorProperty* op_property) {
89   if (llvm::isa<TFL::LSTMOp>(op.getOperation())) {
90     lstm_variant->op_code = tflite::BuiltinOperator_LSTM;
91   } else if (llvm::isa<TFL::UnidirectionalSequenceLSTMOp>(op.getOperation())) {
92     lstm_variant->op_code =
93         tflite::BuiltinOperator_UNIDIRECTIONAL_SEQUENCE_LSTM;
94   } else {
95     op.emitError("ConvertLstmStatsToQDQs pass only supports LSTMs.");
96     return failure();
97   }
98   lstm_variant->use_projection =
99       !op.projection_weights().getType().template isa<NoneType>();
100   lstm_variant->use_peephole =
101       !op.cell_to_output_weights().getType().template isa<NoneType>();
102   lstm_variant->use_layer_norm =
103       !op.forget_layer_norm_coefficients().getType().template isa<NoneType>();
104 
105   *op_property = operator_property::GetOperatorProperty(*lstm_variant);
106 
107   // TODO(b/176258587) move this to operator_property.cc if this is needed in
108   // other components, too.
109   bool use_cifg =
110       op.input_to_input_weights().getType().template isa<NoneType>();
111   if (use_cifg) {
112     const absl::flat_hash_set<int> cifg_non_inputs = {1, 5, 9, 12, 20};
113     const int cifg_non_intermediate = 0;
114     op_property->inputs.erase(
115         std::remove_if(
116             op_property->inputs.begin(), op_property->inputs.end(),
117             [&](std::pair<int, operator_property::TensorProperty> input) {
118               return cifg_non_inputs.find(input.first) != cifg_non_inputs.end();
119             }),
120         op_property->inputs.end());
121     op_property->intermediates.erase(
122         std::remove_if(op_property->intermediates.begin(),
123                        op_property->intermediates.end(),
124                        [&](std::pair<int, operator_property::TensorProperty>
125                                intermediate) {
126                          return intermediate.first == cifg_non_intermediate;
127                        }),
128         op_property->intermediates.end());
129   }
130   return success();
131 }
132 
133 template <typename SourceOp>
134 struct PrepareLstmOutputScale : public OpRewritePattern<SourceOp> {
135  public:
PrepareLstmOutputScalePrepareLstmOutputScale136   explicit PrepareLstmOutputScale(MLIRContext* context)
137       : OpRewritePattern<SourceOp>(context) {}
matchAndRewritePrepareLstmOutputScale138   LogicalResult matchAndRewrite(SourceOp op,
139                                 PatternRewriter& rewriter) const override {
140     operator_property::OpVariant lstm_variant;
141     operator_property::OperatorProperty lstm_property;
142 
143     if (failed(GetLstmProperty(op, &lstm_variant, &lstm_property))) {
144       return failure();
145     }
146     if (lstm_property.restrict_scale.size() != 1) {
147       op.emitError() << "The LSTM's operator property expects exactly one "
148                      << "restrict scale requirement. Got "
149                      << lstm_property.restrict_scale.size()
150                      << " restrict scale requirements.";
151       return failure();
152     }
153 
154     // Use same scale for input and output specified in restrict_scale.
155     const std::vector<int>& tensors = lstm_property.restrict_scale[0];
156     if (tensors.size() != 2) {
157       op.emitError(
158           "Unexpected restricted_scale from operator property."
159           " Should only have a pair of indices.");
160       return failure();
161     }
162     return processRestrictScale(op, tensors[0], tensors[1], rewriter);
163   }
164 
165  private:
166   // For LSTM's recurrent input activation and output, they are quantized with
167   // the collective range of both tensors, because theoretically the input
168   // activation value for the very first inference is not reflected in the
169   // output and the input activation is not captured.
processRestrictScalePrepareLstmOutputScale170   LogicalResult processRestrictScale(SourceOp op, int input_index,
171                                      int output_index,
172                                      PatternRewriter& rewriter) const {
173     assert(output_index == 0);
174     if (!op.getResult().hasOneUse()) {
175       op.emitError()
176           << "output " << output_index
177           << " should have only one use, which should be quant.stats.";
178       return failure();
179     }
180 
181     llvm::SmallVector<quant::StatisticsOp, 2> stats_ops = {
182         llvm::dyn_cast_or_null<quant::StatisticsOp>(
183             op.getOperand(input_index).getDefiningOp()),
184         llvm::dyn_cast_or_null<quant::StatisticsOp>(
185             *op.getResult().getUsers().begin()),
186     };
187 
188     if (!stats_ops[0] || !stats_ops[1]) {
189       return failure();  // Already converted to Q-DQ pair.
190     }
191 
192     llvm::SmallVector<llvm::APFloat, 4> min_max_values;
193 
194     for (auto& stats_op : stats_ops) {
195       auto values = stats_op.layerStats()
196                         .dyn_cast<DenseFPElementsAttr>()
197                         .getValues<llvm::APFloat>();
198       min_max_values.insert(min_max_values.end(), values.begin(), values.end());
199     }
200 
201     // min and max values of two stats are already the same.
202     if (min_max_values[0] == min_max_values[2] &&
203         min_max_values[1] == min_max_values[3]) {
204       return failure();
205     }
206 
207     mlir::ElementsAttr layer_stats = mlir::DenseFPElementsAttr::get(
208         mlir::RankedTensorType::get({2}, rewriter.getF32Type()),
209         {llvm::minimum(min_max_values[0], min_max_values[2]),
210          llvm::maximum(min_max_values[1], min_max_values[3])});
211     mlir::ElementsAttr axis_stats;
212     mlir::IntegerAttr axis;
213     for (auto& stats_op : stats_ops) {
214       rewriter.setInsertionPointAfter(stats_op);
215       rewriter.replaceOpWithNewOp<quant::StatisticsOp>(
216           stats_op, stats_op.arg(), layer_stats, axis_stats, axis);
217     }
218     return success();
219   }
220 };
221 
222 template <typename SourceOp>
223 struct ConvertOpStatsToQDQs : public OpRewritePattern<SourceOp> {
224  public:
225   explicit ConvertOpStatsToQDQs(MLIRContext* context,
226                                 const QuantizationSpecs& quant_specs,
227                                 PatternBenefit benefit = 1)
228       : OpRewritePattern<SourceOp>(context, benefit),
229         quant_specs(quant_specs) {}
230 
231  protected:
232   QuantizationSpecs quant_specs;
233 
processInputsConvertOpStatsToQDQs234   LogicalResult processInputs(
235       SourceOp op, const operator_property::OpVariant& op_variant,
236       const operator_property::OperatorProperty& op_property,
237       PatternRewriter& rewriter) const {
238     for (auto& enumerated_inputs : op_property.inputs) {
239       int index = enumerated_inputs.first;
240       auto& tensor_property = enumerated_inputs.second;
241 
242       Value input = op.getOperand(index);
243 
244       if (input.getDefiningOp() == nullptr) continue;
245 
246       // TODO(b/172517537): make this work with non-PTQ case.
247       if (llvm::isa<ConstantOp, TFL::ConstOp>(input.getDefiningOp())) {
248         // Tensors with derived scale are biases, and handled in propagation.
249         if (tensor_property.use_derived_scale) continue;
250         // For weights, use quantization scale inferred from the values.
251         if (failed(processConstantOp(op, input.getDefiningOp(), index,
252                                      tensor_property, rewriter))) {
253           return failure();
254         }
255       } else {
256         if (auto stats_op =
257                 llvm::dyn_cast<quant::StatisticsOp>(input.getDefiningOp())) {
258           if (failed(replaceStatsOp(op, stats_op, index, tensor_property,
259                                     rewriter))) {
260             return failure();
261           }
262         } else if (!llvm::isa<DQ>(input.getDefiningOp()) &&
263                    !llvm::isa<SameScalesOpInterface>(input.getDefiningOp())) {
264           // Continue if StatisticsOp is already converted to Q-DQ pair, or
265           // stats op is not immediately available to the input because it's
266           // connected to ops with same scale requirements.
267           // TODO(b/172517537): make this work with non-PTQ case.
268           op.emitError() << "Input " << index
269                          << " should be from DequantizeCast, Statistics, "
270                          << ", or ops with same scale requirement.";
271           input.getDefiningOp()->emitError();
272           return failure();
273         }
274       }
275     }
276     return success();
277   }
278 
processConstantOpConvertOpStatsToQDQs279   LogicalResult processConstantOp(
280       SourceOp op, Operation* const_op, int input_index,
281       const operator_property::TensorProperty& tensor_property,
282       PatternRewriter& rewriter) const {
283     // Non-float tensors are neither weights nor require quantization.
284     auto type = const_op->getResult(0).getType().dyn_cast<ShapedType>();
285     if (!type || !type.getElementType().isa<FloatType>()) return success();
286 
287     DenseFPElementsAttr attr;
288     if (!matchPattern(const_op->getResult(0), m_Constant(&attr))) {
289       const_op->emitError("Not a constant op.");
290       return failure();
291     }
292 
293     UniformQuantizedType quant_type = nullptr;
294     // When the number of bits is 10 (instead of 16), quantize the tensor to
295     // [-512, 512], instead of [-32767, 32767].
296     // For now this behavior is specific for SVDF, where 6 bits are reserved for
297     // the reduce operation after element-wise multiplication between state and
298     // time weights.
299     if (tensor_property.number_of_bits == 10) {
300       SmallVector<double, 4> mins(1, std::numeric_limits<double>::max());
301       SmallVector<double, 4> maxs(1, std::numeric_limits<double>::min());
302       // Computes the effective min/max values of the attribute values.
303       quant::ExtractMinMaxFromAttr(attr, /*dim_size=*/1, /*slice_size=*/1,
304                                    /*symmetric=*/true, mins, maxs);
305       double scale = maxs[0] / -llvm::minIntN(tensor_property.number_of_bits);
306       quant_type = UniformQuantizedType::getChecked(
307           quant::QuantizationFlags::Signed, rewriter.getIntegerType(16),
308           attr.getType().getElementType(), scale, /*zeroPoint=*/0,
309           llvm::minIntN(10), -llvm::minIntN(10), const_op->getLoc());
310     } else {
311       quant_type =
312           quant::GetUniformQuantizedTypeForWeight(
313               attr, /*symmetric=*/true,
314               /*num_bits=*/tensor_property.number_of_bits, /*is_signed=*/true,
315               /*narrow_range=*/true, quant_specs.legacy_float_scale)
316               .template dyn_cast<quant::UniformQuantizedType>();
317     }
318     if (!quant_type) {
319       const_op->emitError("Failed to get quantized type");
320       return failure();
321     }
322 
323     // TODO(b/172517537): duplicate the constant when the bias is shared.
324     Type expressed_type = const_op->getResult(0).getType();
325     Type cast_type = quant_type.castFromExpressedType(expressed_type);
326     rewriter.setInsertionPointAfter(const_op);
327     auto q = rewriter.create<Q>(const_op->getLoc(), cast_type,
328                                 const_op->getResult(0));
329     auto dq = rewriter.create<DQ>(const_op->getLoc(), expressed_type, q);
330     op.setOperand(input_index, dq.getResult());
331     return success();
332   }
333 
replaceStatsOpConvertOpStatsToQDQs334   LogicalResult replaceStatsOp(
335       SourceOp op, quant::StatisticsOp stats_op, int input_index,
336       const operator_property::TensorProperty& tensor_property,
337       PatternRewriter& rewriter) const {
338     if (tensor_property.state_tensor && !stats_op.getResult().hasOneUse()) {
339       // TODO(b/172517537): check if other tensors should go through this
340       // check too.
341       op.emitError() << "Input tensor [" << input_index
342                      << "] is a state tensor, but has more than one use.";
343       return failure();
344     }
345     auto stats = stats_op.layerStats().dyn_cast<DenseFPElementsAttr>();
346     if (!stats || stats.getNumElements() != 2) {
347       stats_op.emitError("Stats should have 2 values.");
348       return failure();
349     }
350     quant::QuantizedType quant_type;
351     double min = FloatAttr::getValueAsDouble(stats.getValue<APFloat>({0}));
352     double max = FloatAttr::getValueAsDouble(stats.getValue<APFloat>({1}));
353     // Make sure the range includes zero.
354     min = std::min(min, 0.0);
355     max = std::max(max, 0.0);
356     Type expressed = getElementTypeOrSelf(stats_op.getType());
357 
358     if (tensor_property.extend_to_power_of_two) {
359       if (tensor_property.number_of_bits != 16) {
360         op.emitError(
361             "extended power of 2 scale is only supported for 16-bit"
362             " quantization.");
363         return failure();
364       }
365 
366       double bound = PowerOfTwoBound(std::max(std::abs(min), std::abs(max)));
367       // Set flags to 1 for signed type.
368       quant_type = UniformQuantizedType::getChecked(
369           quant::QuantizationFlags::Signed,
370           rewriter.getIntegerType(tensor_property.number_of_bits), expressed,
371           /*scale=*/bound / -llvm::minIntN(tensor_property.number_of_bits),
372           /*zeroPoint=*/0, llvm::minIntN(tensor_property.number_of_bits),
373           llvm::maxIntN(tensor_property.number_of_bits), op.getLoc());
374     } else {
375       // int16 uses range [-32767, 32767]
376       if (tensor_property.number_of_bits == 16) {
377         max = std::max(std::abs(min), std::abs(max));
378         min = -max;
379         quant_type = quant::fakeQuantAttrsToType(
380             op.getLoc(), tensor_property.number_of_bits, min, max,
381             /*narrowRange=*/true, expressed,
382             /*isSigned=*/true);
383       } else {
384         quant_type = quant::fakeQuantAttrsToType(
385             op.getLoc(), tensor_property.number_of_bits, min, max,
386             /*narrowRange=*/false, expressed,
387             /*isSigned=*/true);
388       }
389       if (quant_specs.legacy_float_scale) {
390         quant_type = quant::DownCastScale(quant_type, min, max, op.getLoc());
391       }
392     }
393     rewriter.setInsertionPointAfter(stats_op);
394     Type result_type = quant_type.castFromExpressedType(stats_op.getType());
395     auto q = rewriter.create<Q>(stats_op.getLoc(), result_type, stats_op.arg());
396     rewriter.replaceOpWithNewOp<DQ>(stats_op, stats_op.getType(), q);
397     return success();
398   }
399 };
400 
401 // Quantize LSTM according to its quantization recipe.
402 template <typename SourceOp>
403 struct ConvertLstmStatsToQDQs : public ConvertOpStatsToQDQs<SourceOp> {
404  public:
ConvertLstmStatsToQDQsConvertLstmStatsToQDQs405   ConvertLstmStatsToQDQs(MLIRContext* context,
406                          const QuantizationSpecs& quant_specs)
407 
408       : ConvertOpStatsToQDQs<SourceOp>(context, quant_specs) {}
matchAndRewriteConvertLstmStatsToQDQs409   LogicalResult matchAndRewrite(SourceOp op,
410                                 PatternRewriter& rewriter) const override {
411     operator_property::OpVariant lstm_variant;
412     operator_property::OperatorProperty lstm_property;
413     if (failed(GetLstmProperty(op, &lstm_variant, &lstm_property))) {
414       return failure();
415     }
416 
417     if (failed(processIntermediates(op, lstm_variant, lstm_property)) ||
418         failed(ConvertOpStatsToQDQs<SourceOp>::processInputs(
419             op, lstm_variant, lstm_property, rewriter))) {
420       return failure();
421     }
422 
423     return success();
424   }
425 
426  private:
processIntermediatesConvertLstmStatsToQDQs427   LogicalResult processIntermediates(
428       SourceOp op, const operator_property::OpVariant& lstm_variant,
429       const operator_property::OperatorProperty& lstm_property) const {
430     for (auto& enumerated_intermediates : lstm_property.intermediates) {
431       int index = enumerated_intermediates.first;
432       auto& tensor_property = enumerated_intermediates.second;
433       // intermediate tensors 0, 1, 2, 3 are only used with layer normalization.
434       if (!lstm_variant.use_layer_norm && index != 4) {
435         continue;
436       }
437 
438       TypeAttr attr =
439           op->template getAttrOfType<TypeAttr>(intermediate_attributes[index]);
440       auto quant_type = GetIntermediateElementType<SourceOp>(op, index);
441       if (!quant_type) {
442         // intermediate tensor 4 is optional, unless the LSTM uses projection.
443         if (index == 4 && !lstm_variant.use_projection) {
444           return success();
445         }
446         op.emitError() << intermediate_attributes[index]
447                        << " is not quantized.";
448         return failure();
449       }
450       auto calibrated_type =
451           quant_type.template dyn_cast<quant::CalibratedQuantizedType>();
452       if (!calibrated_type) {
453         int num_storage_bits = quant_type.getStorageTypeIntegralWidth();
454         if (tensor_property.number_of_bits != num_storage_bits) {
455           op.emitError() << intermediate_attributes[index]
456                          << " is expected to be quantized with "
457                          << tensor_property.number_of_bits << " bits, but got "
458                          << num_storage_bits << " bits instead.";
459           return failure();
460         }
461         continue;  // skip if it is already quantized.
462       }
463       quant::UniformQuantizedType qtype;
464       if (tensor_property.number_of_bits == 8) {
465         qtype = quant::fakeQuantAttrsToType(
466             op.getLoc(), tensor_property.number_of_bits,
467             calibrated_type.getMin(), calibrated_type.getMax(),
468             /*narrowRange=*/false, calibrated_type.getExpressedType(),
469             /*isSigned=*/this->quant_specs.IsSignedInferenceType());
470         if (this->quant_specs.legacy_float_scale) {
471           qtype = quant::DownCastScale(qtype, calibrated_type.getMin(),
472                                        calibrated_type.getMax(), op.getLoc())
473                       .template cast<UniformQuantizedType>();
474         }
475       } else if (tensor_property.number_of_bits == 16) {
476         double max = std::max(std::abs(calibrated_type.getMin()),
477                               std::abs(calibrated_type.getMax()));
478         qtype = quant::fakeQuantAttrsToType(
479             op.getLoc(), tensor_property.number_of_bits, -max, max,
480             /*narrowRange=*/true, calibrated_type.getExpressedType(),
481             /*isSigned=*/true);
482       } else {
483         op.emitError() << "Unsupported quantization bits: "
484                        << tensor_property.number_of_bits;
485         return failure();
486       }
487       op->setAttr(intermediate_attributes[index],
488                   TypeAttr::get(qtype.castFromExpressedType(
489                       qtype.castToExpressedType(attr.getValue()))));
490     }
491     return success();
492   }
493 };
494 
495 // Returns a function that returns the quantized type of a bias input.
496 // The scale of bias is a multiplication of given scale and scales from the
497 // quantization type of other operands.
GetUniformQuantizedTypeForBiasWithScale(double scale)498 inline quant::AccumulatorScaleFunc GetUniformQuantizedTypeForBiasWithScale(
499     double scale) {
500   return [=](const std::vector<quant::QuantParams>& quant_params,
501              bool legacy_float_scale) -> quant::QuantParams {
502     if (auto qtype =
503             GetUniformQuantizedTypeForBias(quant_params, legacy_float_scale)
504                 .dyn_cast_or_null<UniformQuantizedType>()) {
505       return quant::UniformQuantizedType::get(
506           qtype.getFlags(), qtype.getStorageType(), qtype.getExpressedType(),
507           qtype.getScale() * scale, qtype.getZeroPoint(),
508           qtype.getStorageTypeMin(), qtype.getStorageTypeMax());
509     }
510     return {};
511   };
512 }
513 
514 // Returns quantization spec for LSTMs based on their operator properties.
515 template <typename LstmOp>
GetLstmOpQuantSpec(LstmOp op)516 std::unique_ptr<quant::OpQuantSpec> GetLstmOpQuantSpec(LstmOp op) {
517   operator_property::OpVariant lstm_variant;
518   operator_property::OperatorProperty lstm_property;
519   if (failed(GetLstmProperty(op, &lstm_variant, &lstm_property))) {
520     return nullptr;
521   }
522 
523   auto spec = absl::make_unique<quant::OpQuantSpec>();
524 
525   for (const auto& enumerated_inputs : lstm_property.inputs) {
526     int index = enumerated_inputs.first;
527     auto& tensor_property = enumerated_inputs.second;
528     if (tensor_property.use_derived_scale) {
529       double scale = 1.0;
530       for (int tensor_index :
531            tensor_property.derived_scale.intermediate_tensors) {
532         auto quant_type = GetIntermediateElementType<LstmOp>(op, tensor_index);
533         if (!quant_type ||
534             !quant_type.template isa<quant::UniformQuantizedType>()) {
535           op->emitError() << "While processing derived scale, intermediate "
536                           << intermediate_attributes[tensor_index]
537                           << " is not quantized.";
538           return nullptr;
539         }
540         scale *= quant_type.template dyn_cast<quant::UniformQuantizedType>()
541                      .getScale();
542       }
543       for (float factor : tensor_property.derived_scale.factors) {
544         scale *= factor;
545       }
546       spec->biases_params.emplace(
547           index,
548           std::make_pair(tensor_property.derived_scale.input_tensors,
549                          GetUniformQuantizedTypeForBiasWithScale(scale)));
550     }
551   }
552   return spec;
553 }
554 
555 struct ConvertSvdfStatsToQDQs : public ConvertOpStatsToQDQs<TFL::SVDFOp> {
556  public:
ConvertSvdfStatsToQDQsConvertSvdfStatsToQDQs557   explicit ConvertSvdfStatsToQDQs(MLIRContext* context,
558                                   const QuantizationSpecs& quant_specs_param)
559       : ConvertOpStatsToQDQs<TFL::SVDFOp>(context, quant_specs_param) {}
matchAndRewriteConvertSvdfStatsToQDQs560   LogicalResult matchAndRewrite(TFL::SVDFOp op,
561                                 PatternRewriter& rewriter) const override {
562     operator_property::OpVariant op_variant;
563     op_variant.op_code = tflite::BuiltinOperator_SVDF;
564     auto op_property = operator_property::GetOperatorProperty(op_variant);
565     return ConvertOpStatsToQDQs<TFL::SVDFOp>::processInputs(
566         op, op_variant, op_property, rewriter);
567   }
568 };
569 
570 }  // namespace TFL
571 }  // namespace mlir
572 
573 #endif  // TENSORFLOW_COMPILER_MLIR_LITE_TRANSFORMS_PREPARE_QUANTIZE_HELPER
574