/* Copyright 2018 The TensorFlow Authors. All Rights Reserved.

Licensed under the Apache License, Version 2.0 (the "License");
you may not use this file except in compliance with the License.
You may obtain a copy of the License at

    http://www.apache.org/licenses/LICENSE-2.0

Unless required by applicable law or agreed to in writing, software
distributed under the License is distributed on an "AS IS" BASIS,
WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
See the License for the specific language governing permissions and
limitations under the License.
==============================================================================*/
#include "tensorflow/lite/tools/optimize/quantize_weights.h"

#include <algorithm>
#include <memory>
#include <string>
#include <unordered_map>
#include <utility>
#include <vector>
#include "flatbuffers/flexbuffers.h"
#include "absl/memory/memory.h"
#include "tensorflow/core/platform/logging.h"
#include "tensorflow/lite/context.h"
#include "tensorflow/lite/kernels/internal/tensor_utils.h"
#include "tensorflow/lite/model.h"
#include "tensorflow/lite/schema/schema_generated.h"
#include "tensorflow/lite/tools/optimize/quantization_utils.h"

namespace tflite {
namespace optimize {

namespace {

struct ConsumerOpInfo {
  OperatorT* op;
  // The index of the op in the subgraph's operators vector.
  int32_t op_idx;
  // The index into op->inputs of the tensor to quantize.
  int32_t op_input_idx;
};

// The default minimum number of elements a weights array must have to be
// quantized by this transformation.
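// Quantizing smaller arrays saves little space while still adding a
// Dequantize op or hybrid-kernel overhead for each consumer.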
const int kWeightsMinNumElementsDefault = 1024;

// Gets the operators that consume tensor_idx.
std::vector<ConsumerOpInfo> GetTensorConsumers(const ModelT* model,
                                               const SubGraphT* subgraph,
                                               int32_t tensor_idx) {
  // TODO(suharshs): If this proves to be too slow, avoid calling it per
  // tensor, instead doing one sweep for the entire model.
  std::vector<ConsumerOpInfo> consumer_ops;
  for (size_t op_idx = 0; op_idx < subgraph->operators.size(); ++op_idx) {
    OperatorT* op = subgraph->operators[op_idx].get();
    if (op == nullptr) {
      continue;
    }
    for (size_t i = 0; i < op->inputs.size(); ++i) {
      if (op->inputs[i] == tensor_idx) {
        consumer_ops.push_back(
            {op, static_cast<int32_t>(op_idx), static_cast<int32_t>(i)});
      }
    }
  }
  return consumer_ops;
}

// Gets the list of op->inputs indices of the weights inputs to be quantized
// for the provided op.
std::vector<int32_t> GetWeightInputIndices(const BuiltinOperator& op_code) {
  if (op_code == BuiltinOperator_CONV_2D ||
      op_code == BuiltinOperator_DEPTHWISE_CONV_2D ||
      op_code == BuiltinOperator_FULLY_CONNECTED ||
      op_code == BuiltinOperator_EMBEDDING_LOOKUP) {
    return {1};
  } else if (op_code == BuiltinOperator_SVDF) {
    // https://www.tensorflow.org/code/tensorflow/lite/kernels/svdf.cc
    return {1, 2};
  } else if (op_code == BuiltinOperator_LSTM ||
             op_code == BuiltinOperator_UNIDIRECTIONAL_SEQUENCE_LSTM) {
    // https://www.tensorflow.org/code/tensorflow/lite/kernels/lstm.cc
    // https://www.tensorflow.org/code/tensorflow/lite/kernels/unidirectional_sequence_lstm.cc
    return {1, 2, 3, 4, 5, 6, 7, 8, 9, 10, 11, 16};
  } else if (op_code == BuiltinOperator_RNN ||
             op_code == BuiltinOperator_UNIDIRECTIONAL_SEQUENCE_RNN) {
    // https://www.tensorflow.org/code/tensorflow/lite/kernels/basic_rnn.cc
    // https://www.tensorflow.org/code/tensorflow/lite/kernels/unidirectional_sequence_rnn.cc
    return {1, 2};
  } else if (op_code == BuiltinOperator_BIDIRECTIONAL_SEQUENCE_LSTM) {
    // https://www.tensorflow.org/code/tensorflow/lite/kernels/bidirectional_sequence_lstm.cc
    return {1,  2,  3,  4,  5,  6,  7,  8,  9,  10, 11, 16,
            18, 19, 20, 21, 22, 23, 24, 25, 26, 27, 28, 33};
  } else if (op_code == BuiltinOperator_BIDIRECTIONAL_SEQUENCE_RNN) {
    // https://www.tensorflow.org/code/tensorflow/lite/kernels/bidirectional_sequence_rnn.cc
    return {1, 2, 4, 5};
  }
  return {};
}

// Returns true if the operator supports hybrid evaluation.
bool IsHybridEvaluationOp(const OperatorT* op, const BuiltinOperator& op_code) {
  // Operations that support hybrid evaluation.
  bool eval_hybrid = false;
  if (op_code == BuiltinOperator_FULLY_CONNECTED ||
      op_code == BuiltinOperator_CONV_2D || op_code == BuiltinOperator_SVDF ||
      op_code == BuiltinOperator_EMBEDDING_LOOKUP ||
      op_code == BuiltinOperator_RNN ||
      op_code == BuiltinOperator_BIDIRECTIONAL_SEQUENCE_LSTM ||
      op_code == BuiltinOperator_BIDIRECTIONAL_SEQUENCE_RNN ||
      op_code == BuiltinOperator_UNIDIRECTIONAL_SEQUENCE_LSTM ||
      op_code == BuiltinOperator_UNIDIRECTIONAL_SEQUENCE_RNN) {
    eval_hybrid = true;
  } else if (op_code == BuiltinOperator_LSTM) {
    const LSTMOptionsT* options = op->builtin_options.AsLSTMOptions();
    // Only the FULL LSTM kernel type supports hybrid evaluation.
    if (options->kernel_type == LSTMKernelType_FULL) {
      eval_hybrid = true;
    }
  }
  return eval_hybrid;
}

// Returns true if all of the op's weight inputs are quantized; optional
// (missing) inputs are ignored.
bool CheckAllOpInputsQuantized(const SubGraphT* subgraph, const OperatorT* op,
                               const BuiltinOperator& op_code) {
  std::vector<int32_t> op_input_indices = GetWeightInputIndices(op_code);
  for (const int32_t op_input_idx : op_input_indices) {
    int32_t tensor_idx = op->inputs[op_input_idx];

    if (tensor_idx == -1) {
      // Optional tensor.
      continue;
    }

    TensorT* tensor = subgraph->tensors[tensor_idx].get();

    if (tensor->type != TensorType_INT8) {
      return false;
    }
  }
  return true;
}

// Inserts into tensor_map each input tensor of op that should be quantized.
TfLiteStatus InsertQuantizableInputTensorsFromOperator(
    const ModelT* model, const OperatorT* op, uint64_t weights_min_num_elements,
    std::unordered_map<int32_t, TensorT*>* tensor_map) {
  SubGraphT* subgraph = model->subgraphs.at(0).get();
  const BuiltinOperator op_code =
      model->operator_codes[op->opcode_index]->builtin_code;

  std::vector<int32_t> op_input_indices = GetWeightInputIndices(op_code);
  for (const int32_t op_input_idx : op_input_indices) {
    int32_t tensor_idx = op->inputs[op_input_idx];
    if (tensor_idx == -1) {
      LOG(INFO) << "Skipping optional tensor input " << op_input_idx
                << " of operation " << EnumNameBuiltinOperator(op_code);
      continue;
    }

    TensorT* tensor = subgraph->tensors[tensor_idx].get();
    if (tensor->type != TensorType_FLOAT32) {
      LOG(INFO) << "Skipping quantization of tensor " << tensor->name
                << " because it is not of type float.";
      continue;
    }

    uint64_t num_elements;
    TF_LITE_ENSURE_STATUS(utils::NumElements(*tensor, &num_elements));
    if (num_elements < weights_min_num_elements) {
      LOG(INFO) << "Skipping quantization of tensor " << tensor->name
                << " because it has fewer than " << weights_min_num_elements
                << " elements (" << num_elements << ").";
      continue;
    }

    // Some tensors may have a null buffer vector, indicating an intermediate
    // array.
    if (model->buffers[tensor->buffer]->data.data() == nullptr) {
      LOG(INFO) << "Skipping quantization of tensor " << tensor->name
                << " because it has no allocated buffer.";
      continue;
    }

    tensor_map->insert({tensor_idx, tensor});
  }

  return kTfLiteOk;
}

// Returns the index of the Dequantize op_code.
// If a Dequantize op_code doesn't exist, adds it and returns its index.
int32_t GetOrInsertDequantizeOpCodeIndex(ModelT* model) {
  for (size_t i = 0; i < model->operator_codes.size(); ++i) {
    if (model->operator_codes[i]->builtin_code == BuiltinOperator_DEQUANTIZE) {
      return static_cast<int32_t>(i);
    }
  }
  model->operator_codes.push_back(absl::make_unique<OperatorCodeT>());
  int op_code_idx = model->operator_codes.size() - 1;
  model->operator_codes[op_code_idx]->builtin_code = BuiltinOperator_DEQUANTIZE;
  // Versions 2 and onward support INT8 inputs.
  model->operator_codes[op_code_idx]->version = 2;

  // Return the index of the newly placed OperatorCodeT.
  return op_code_idx;
}

// Creates a Dequantize OperatorT object.
void MakeDequantizeOperator(ModelT* model, std::unique_ptr<OperatorT>* op,
                            int32_t input, int32_t output) {
  OperatorT* op_raw = new OperatorT;
  op_raw->opcode_index = GetOrInsertDequantizeOpCodeIndex(model);
  op_raw->inputs = {input};
  op_raw->outputs = {output};

  op->reset(op_raw);
}

// Creates a new TensorT object.
void MakeTensor(const string& name, const std::vector<int32_t>& shape,
                std::unique_ptr<TensorT>* tensor) {
  TensorT* tensor_raw = new TensorT;
  tensor_raw->name = name;
  tensor_raw->shape = shape;
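  // Note: |type| is left at the schema default (FLOAT32); the only caller
  // uses this tensor as the float output of a Dequantize op.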

  tensor->reset(tensor_raw);
}

// Updates operator code versions for the operators with INT8 inputs.
void UpdateInt8OperatorVersions(ModelT* model) {
  for (size_t i = 0; i < model->operator_codes.size(); ++i) {
    const BuiltinOperator& op_code = model->operator_codes[i]->builtin_code;
    if (op_code == BuiltinOperator_CONV_2D || op_code == BuiltinOperator_SVDF ||
        op_code == BuiltinOperator_EMBEDDING_LOOKUP ||
        op_code == BuiltinOperator_RNN ||
        op_code == BuiltinOperator_BIDIRECTIONAL_SEQUENCE_RNN ||
        op_code == BuiltinOperator_UNIDIRECTIONAL_SEQUENCE_LSTM ||
        op_code == BuiltinOperator_UNIDIRECTIONAL_SEQUENCE_RNN) {
      model->operator_codes[i]->version = 2;
    } else if (op_code == BuiltinOperator_FULLY_CONNECTED ||
               op_code == BuiltinOperator_BIDIRECTIONAL_SEQUENCE_LSTM ||
               op_code == BuiltinOperator_LSTM) {
      model->operator_codes[i]->version = 3;
    }
  }
}

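// Quantizes the eligible float32 weights of |input_model| and writes the
// result to |builder| in three steps: collect the quantizable weight tensors,
// symmetrically quantize each of them to INT8, and then either let
// hybrid-capable consumers read the INT8 data directly or insert a Dequantize
// op in front of consumers that cannot.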
TfLiteStatus QuantizeWeightsInternal(flatbuffers::FlatBufferBuilder* builder,
                                     const Model* input_model,
                                     bool use_hybrid_evaluation,
                                     uint64_t weights_min_num_elements) {
  std::unique_ptr<ModelT> model;
  model.reset(input_model->UnPack());

  // TODO(suharshs): When models support multiple subgraphs, add support.
  if (model->subgraphs.size() != 1) {
    LOG(ERROR) << "Quantize weights tool only supports tflite models with one "
                  "subgraph.";
    return kTfLiteError;
  }

  SubGraphT* subgraph = model->subgraphs.at(0).get();

  std::unordered_map<int32_t, TensorT*> tensor_map;
  for (size_t i = 0; i < subgraph->operators.size(); ++i) {
    OperatorT* op = subgraph->operators[i].get();
    TF_LITE_ENSURE_STATUS(InsertQuantizableInputTensorsFromOperator(
        model.get(), op, weights_min_num_elements, &tensor_map));
  }

  // The unordered_map ensures that we quantize each tensor exactly once.
  // TODO(suharshs): This map key isn't sufficient when we support multiple
  // subgraphs.
  for (std::pair<int32_t, TensorT*> tensor_pair : tensor_map) {
    // Quantize the tensor.
    TF_LITE_ENSURE_STATUS(
        utils::SymmetricQuantizeTensor(model.get(), tensor_pair.second));
  }

  // Examine the tensor consumers to determine which require dequantize ops.
  for (const auto& tensor_pair : tensor_map) {
    const int32_t tensor_idx = tensor_pair.first;
    TensorT* tensor = tensor_pair.second;
    std::vector<ConsumerOpInfo> consumer_op_infos =
        GetTensorConsumers(model.get(), subgraph, tensor_idx);

    std::vector<ConsumerOpInfo> dequant_op_infos;  // Ops that need dequants.
    for (ConsumerOpInfo& consumer_op_info : consumer_op_infos) {
      OperatorT* consumer_op = consumer_op_info.op;
      const BuiltinOperator consumer_op_code =
          model->operator_codes[consumer_op->opcode_index]->builtin_code;
      // If the consumer supports hybrid evaluation and all of its weight
      // inputs are quantized, it can consume the quantized tensor directly.
      // Otherwise, a Dequantize op must be added for it.
      bool eval_hybrid =
          use_hybrid_evaluation &&
          IsHybridEvaluationOp(consumer_op, consumer_op_code) &&
          CheckAllOpInputsQuantized(subgraph, consumer_op, consumer_op_code);
      if (!eval_hybrid) {
        dequant_op_infos.push_back(consumer_op_info);
      }
    }

    // If no ops require dequant, we are done for this tensor.
    if (dequant_op_infos.empty()) {
      continue;
    }

    // Create a new tensor to be the output of the dequantize op.
    std::unique_ptr<TensorT> dequantize_output;
    const string dequant_name = tensor->name + "_dequantize";
    MakeTensor(dequant_name, tensor->shape, &dequantize_output);
    const int32_t dequantize_output_idx = subgraph->tensors.size();
    subgraph->tensors.push_back(std::move(dequantize_output));

    // Create the Dequantize operation.
    std::unique_ptr<OperatorT> dequantize_op;
    MakeDequantizeOperator(model.get(), &dequantize_op, tensor_idx,
                           dequantize_output_idx);

    LOG(INFO) << "Creating Dequantize op with name " << dequant_name << ".";

    // Update the op_input of all the ops that need the created dequantize
    // operation.
    int32_t min_op_idx = static_cast<int32_t>(subgraph->operators.size());
    for (ConsumerOpInfo& dequant_op_info : dequant_op_infos) {
      dequant_op_info.op->inputs[dequant_op_info.op_input_idx] =
          dequantize_output_idx;
      min_op_idx = std::min(dequant_op_info.op_idx, min_op_idx);
    }

    // Insert the newly created Dequantize operation before the earliest
    // consumer, since TFLite requires operators to be topologically sorted.
    subgraph->operators.insert(subgraph->operators.begin() + min_op_idx,
                               std::move(dequantize_op));
  }

  // Update the modified operator code versions.
  UpdateInt8OperatorVersions(model.get());

  flatbuffers::Offset<Model> output_model_location =
      Model::Pack(*builder, model.get());
  FinishModelBuffer(*builder, output_model_location);

  return kTfLiteOk;
}

}  // namespace

namespace internal {
TfLiteStatus QuantizeWeights(flatbuffers::FlatBufferBuilder* builder,
                             const Model* input_model,
                             uint64_t weights_min_num_elements,
                             bool use_hybrid_evaluation) {
  // The caller supplies both the minimum number of elements a weights array
  // must have to be quantized and whether to use hybrid evaluation.
  return QuantizeWeightsInternal(builder, input_model, use_hybrid_evaluation,
                                 weights_min_num_elements);
}
}  // namespace internal

TfLiteStatus QuantizeWeights(flatbuffers::FlatBufferBuilder* builder,
                             const Model* input_model,
                             uint64_t weights_min_num_elements) {
  return QuantizeWeightsInternal(builder, input_model, true,
                                 weights_min_num_elements);
}

TfLiteStatus QuantizeWeights(flatbuffers::FlatBufferBuilder* builder,
                             const Model* input_model) {
  // By default, only weights with at least kWeightsMinNumElementsDefault
  // elements are quantized.
  return QuantizeWeightsInternal(builder, input_model, true,
                                 kWeightsMinNumElementsDefault);
}
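
// Example usage (a minimal sketch; |model_buffer| is a hypothetical
// std::string containing a serialized tflite model):
//
//   const Model* input_model = GetModel(model_buffer.data());
//   flatbuffers::FlatBufferBuilder builder;
//   if (QuantizeWeights(&builder, input_model) == kTfLiteOk) {
//     // builder.GetBufferPointer() and builder.GetSize() now describe the
//     // weight-quantized model.
//   }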

}  // namespace optimize
}  // namespace tflite