/* Copyright 2018 The TensorFlow Authors. All Rights Reserved.

Licensed under the Apache License, Version 2.0 (the "License");
you may not use this file except in compliance with the License.
You may obtain a copy of the License at

    http://www.apache.org/licenses/LICENSE-2.0

Unless required by applicable law or agreed to in writing, software
distributed under the License is distributed on an "AS IS" BASIS,
WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
See the License for the specific language governing permissions and
limitations under the License.
==============================================================================*/
#include "tensorflow/lite/tools/optimize/quantize_weights.h"

#include <algorithm>
#include <memory>
#include <string>
#include <unordered_map>
#include <vector>

#include "flatbuffers/flexbuffers.h"
#include "absl/memory/memory.h"
#include "tensorflow/core/platform/logging.h"
#include "tensorflow/lite/context.h"
#include "tensorflow/lite/kernels/internal/tensor_utils.h"
#include "tensorflow/lite/model.h"
#include "tensorflow/lite/schema/schema_generated.h"
#include "tensorflow/lite/tools/optimize/quantization_utils.h"
namespace tflite {
namespace optimize {

namespace {

typedef struct {
  OperatorT* op;
  // The index of the op in the operators vector.
  int32_t op_idx;
  // The index of the tensor to quantize in subgraph->tensors.
  int32_t op_input_idx;
} ConsumerOpInfo;

// The default minimum number of elements a weights array must have to be
// quantized by this transformation.
const int kWeightsMinNumElementsDefault = 1024;

// Gets the operators that consume tensor_idx.
std::vector<ConsumerOpInfo> GetTensorConsumers(const ModelT* model,
                                               const SubGraphT* subgraph,
                                               int32_t tensor_idx) {
  // TODO(suharshs): If this proves to be too slow, avoid calling it per tensor,
  // instead doing one sweep for the entire model.
  std::vector<ConsumerOpInfo> consumer_ops;
  for (size_t op_idx = 0; op_idx < subgraph->operators.size(); ++op_idx) {
    OperatorT* op = subgraph->operators[op_idx].get();
    if (op == nullptr) {
      continue;
    }
    for (size_t i = 0; i < op->inputs.size(); ++i) {
      if (op->inputs[i] == tensor_idx) {
        consumer_ops.push_back(
            {op, static_cast<int>(op_idx), static_cast<int>(i)});
      }
    }
  }
  return consumer_ops;
}

// Gets the list of op->inputs indices of the weights inputs to be quantized
// for the provided op.
std::vector<int32_t> GetWeightInputIndices(const BuiltinOperator& op_code) {
  if (op_code == BuiltinOperator_CONV_2D ||
      op_code == BuiltinOperator_DEPTHWISE_CONV_2D ||
      op_code == BuiltinOperator_FULLY_CONNECTED ||
      op_code == BuiltinOperator_EMBEDDING_LOOKUP) {
    return {1};
  } else if (op_code == BuiltinOperator_SVDF) {
    // https://www.tensorflow.org/code/tensorflow/lite/kernels/svdf.cc
    return {1, 2};
  } else if (op_code == BuiltinOperator_LSTM ||
             op_code == BuiltinOperator_UNIDIRECTIONAL_SEQUENCE_LSTM) {
    // https://www.tensorflow.org/code/tensorflow/lite/kernels/lstm.cc
    // https://www.tensorflow.org/code/tensorflow/lite/kernels/unidirectional_sequence_lstm.cc
    return {1, 2, 3, 4, 5, 6, 7, 8, 9, 10, 11, 16};
  } else if (op_code == BuiltinOperator_RNN ||
             op_code == BuiltinOperator_UNIDIRECTIONAL_SEQUENCE_RNN) {
    // https://www.tensorflow.org/code/tensorflow/lite/kernels/basic_rnn.cc
    // https://www.tensorflow.org/code/tensorflow/lite/kernels/unidirectional_sequence_rnn.cc
    return {1, 2};
  } else if (op_code == BuiltinOperator_BIDIRECTIONAL_SEQUENCE_LSTM) {
    // https://www.tensorflow.org/code/tensorflow/lite/kernels/bidirectional_sequence_lstm.cc
    return {1, 2, 3, 4, 5, 6, 7, 8, 9, 10, 11, 16,
            18, 19, 20, 21, 22, 23, 24, 25, 26, 27, 28, 33};
  } else if (op_code == BuiltinOperator_BIDIRECTIONAL_SEQUENCE_RNN) {
    // https://www.tensorflow.org/code/tensorflow/lite/kernels/bidirectional_sequence_rnn.cc
    return {1, 2, 4, 5};
  }
  return {};
}

// Returns true if the operator supports hybrid evaluation.
bool IsHybridEvaluationOp(const OperatorT* op, const BuiltinOperator& op_code) {
  // Operations that support hybrid evaluation.
  bool eval_hybrid = false;
  if (op_code == BuiltinOperator_FULLY_CONNECTED ||
      op_code == BuiltinOperator_CONV_2D || op_code == BuiltinOperator_SVDF ||
      op_code == BuiltinOperator_EMBEDDING_LOOKUP ||
      op_code == BuiltinOperator_RNN ||
      op_code == BuiltinOperator_BIDIRECTIONAL_SEQUENCE_LSTM ||
      op_code == BuiltinOperator_BIDIRECTIONAL_SEQUENCE_RNN ||
      op_code == BuiltinOperator_UNIDIRECTIONAL_SEQUENCE_LSTM ||
      op_code == BuiltinOperator_UNIDIRECTIONAL_SEQUENCE_RNN) {
    eval_hybrid = true;
  } else if (op_code == BuiltinOperator_LSTM) {
    const LSTMOptionsT* options = op->builtin_options.AsLSTMOptions();
    // Only the LSTM kernel_type FULL supports hybrid evaluation.
    if (options->kernel_type == LSTMKernelType_FULL) {
      eval_hybrid = true;
    }
  }
  return eval_hybrid;
}

// Returns true if all of the op's weight inputs are quantized.
bool CheckAllOpInputsQuantized(const SubGraphT* subgraph, const OperatorT* op,
                               const BuiltinOperator& op_code) {
  std::vector<int32_t> op_input_indices = GetWeightInputIndices(op_code);
  for (const int32_t op_input_idx : op_input_indices) {
    int32_t tensor_idx = op->inputs[op_input_idx];

    if (tensor_idx == -1) {
      // Optional tensor.
      continue;
    }

    TensorT* tensor = subgraph->tensors[tensor_idx].get();

    if (tensor->type != TensorType_INT8) {
      return false;
    }
  }
  return true;
}

// Inserts into tensor_map the input tensors of op that should be quantized.
TfLiteStatus InsertQuantizableInputTensorsFromOperator(
    const ModelT* model, const OperatorT* op, uint64_t weights_min_num_elements,
    std::unordered_map<int32_t, TensorT*>* tensor_map) {
  SubGraphT* subgraph = model->subgraphs.at(0).get();
  const BuiltinOperator op_code =
      model->operator_codes[op->opcode_index]->builtin_code;

  std::vector<int32_t> op_input_indices = GetWeightInputIndices(op_code);
  for (const int32_t op_input_idx : op_input_indices) {
    int32_t tensor_idx = op->inputs[op_input_idx];
    if (tensor_idx == -1) {
      LOG(INFO) << "Skipping optional tensor input " << op_input_idx
                << " of operation " << EnumNameBuiltinOperator(op_code);
      continue;
    }

    TensorT* tensor = subgraph->tensors[tensor_idx].get();
    if (tensor->type != TensorType_FLOAT32) {
      LOG(INFO) << "Skipping quantization of tensor " << tensor->name
                << " because it is not of type float.";
      continue;
    }

    uint64_t num_elements;
    TF_LITE_ENSURE_STATUS(utils::NumElements(*tensor, &num_elements));
    if (num_elements < weights_min_num_elements) {
      LOG(INFO) << "Skipping quantization of tensor " << tensor->name
                << " because it has fewer than " << weights_min_num_elements
                << " elements (" << num_elements << ").";
      continue;
    }

    // Some tensors may have a null buffer vector, indicating an intermediate
    // array.
    if (model->buffers[tensor->buffer]->data.data() == nullptr) {
      LOG(INFO) << "Skipping quantization of tensor " << tensor->name
                << " because it has no allocated buffer.";
      continue;
    }

    tensor_map->insert({tensor_idx, tensor});
  }

  return kTfLiteOk;
}

// Returns the index of the Dequantize op_code.
// If a Dequantize op_code doesn't exist, adds it and returns its index.
int32_t GetOrInsertDequantizeOpCodeIndex(ModelT* model) {
  for (size_t i = 0; i < model->operator_codes.size(); ++i) {
    if (model->operator_codes[i]->builtin_code == BuiltinOperator_DEQUANTIZE) {
      return i;
    }
  }
  model->operator_codes.push_back(absl::make_unique<OperatorCodeT>());
  int op_code_idx = model->operator_codes.size() - 1;
  model->operator_codes[op_code_idx]->builtin_code = BuiltinOperator_DEQUANTIZE;
  // Versions 2 and above of the Dequantize op support INT8 inputs.
  model->operator_codes[op_code_idx]->version = 2;

  // Return the index of the newly placed OperatorCodeT.
  return op_code_idx;
}

// Creates a Dequantize OperatorT object.
void MakeDequantizeOperator(ModelT* model, std::unique_ptr<OperatorT>* op,
                            int32_t input, int32_t output) {
  OperatorT* op_raw = new OperatorT;
  op_raw->opcode_index = GetOrInsertDequantizeOpCodeIndex(model);
  op_raw->inputs = {input};
  op_raw->outputs = {output};

  op->reset(op_raw);
}

// Creates a new TensorT object.
void MakeTensor(const string& name, const std::vector<int32_t>& shape,
                std::unique_ptr<TensorT>* tensor) {
  TensorT* tensor_raw = new TensorT;
  tensor_raw->name = name;
  tensor_raw->shape = shape;

  tensor->reset(tensor_raw);
}

// Updates operator code versions for the operators with INT8 inputs.
void UpdateInt8OperatorVersions(ModelT* model) {
  for (size_t i = 0; i < model->operator_codes.size(); ++i) {
    const BuiltinOperator& op_code = model->operator_codes[i]->builtin_code;
    if (op_code == BuiltinOperator_CONV_2D || op_code == BuiltinOperator_SVDF ||
        op_code == BuiltinOperator_EMBEDDING_LOOKUP ||
        op_code == BuiltinOperator_RNN ||
        op_code == BuiltinOperator_BIDIRECTIONAL_SEQUENCE_RNN ||
        op_code == BuiltinOperator_UNIDIRECTIONAL_SEQUENCE_LSTM ||
        op_code == BuiltinOperator_UNIDIRECTIONAL_SEQUENCE_RNN) {
      model->operator_codes[i]->version = 2;
    } else if (op_code == BuiltinOperator_FULLY_CONNECTED ||
               op_code == BuiltinOperator_BIDIRECTIONAL_SEQUENCE_LSTM ||
               op_code == BuiltinOperator_LSTM) {
      model->operator_codes[i]->version = 3;
    }
  }
}

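// Quantizes the eligible float32 weight tensors of input_model (those with at
// least weights_min_num_elements elements), inserts Dequantize operators for
// consumers that cannot use the quantized weights directly (or for every
// consumer when use_hybrid_evaluation is false), and serializes the resulting
// model into builder.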
TfLiteStatus QuantizeWeightsInternal(flatbuffers::FlatBufferBuilder* builder,
                                     const Model* input_model,
                                     bool use_hybrid_evaluation,
                                     uint64_t weights_min_num_elements) {
  std::unique_ptr<ModelT> model;
  model.reset(input_model->UnPack());

  // TODO(suharshs): When models support multiple subgraphs, add support.
  if (model->subgraphs.size() != 1) {
    LOG(ERROR) << "Quantize weights tool only supports tflite models with one "
                  "subgraph.";
    return kTfLiteError;
  }

  SubGraphT* subgraph = model->subgraphs.at(0).get();

  std::vector<std::unique_ptr<OperatorT>> new_operators;
  std::unordered_map<int32_t, TensorT*> tensor_map;
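  // Collect the quantizable weight tensors of every operator in the subgraph.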
  for (size_t i = 0; i < subgraph->operators.size(); ++i) {
    OperatorT* op = subgraph->operators[i].get();
    TF_LITE_ENSURE_STATUS(InsertQuantizableInputTensorsFromOperator(
        model.get(), op, weights_min_num_elements, &tensor_map));
  }

  // The unordered_map ensures that we quantize each tensor exactly once.
  // TODO(suharshs): This map key isn't sufficient when we support multiple
  // subgraphs.
  for (std::pair<int32_t, TensorT*> tensor_pair : tensor_map) {
    // Quantize the tensor.
    TF_LITE_ENSURE_STATUS(
        utils::SymmetricQuantizeTensor(model.get(), tensor_pair.second));
  }

  // Examine the tensor consumers to determine which require dequantize ops.
  for (const auto& tensor_pair : tensor_map) {
    const int32_t tensor_idx = tensor_pair.first;
    TensorT* tensor = tensor_pair.second;
    std::vector<ConsumerOpInfo> consumer_op_infos =
        GetTensorConsumers(model.get(), subgraph, tensor_idx);

    std::vector<ConsumerOpInfo> dequant_op_infos;  // Ops that need dequants.
    for (ConsumerOpInfo& consumer_op_info : consumer_op_infos) {
      OperatorT* consumer_op = consumer_op_info.op;
      const BuiltinOperator consumer_op_code =
          model->operator_codes[consumer_op->opcode_index]->builtin_code;
      // If the op is a hybrid op and all the required tensors are quantized,
      // there is no further work to do; every other consumer needs a
      // Dequantize op inserted in front of it.
      bool eval_hybrid =
          use_hybrid_evaluation &&
          IsHybridEvaluationOp(consumer_op, consumer_op_code) &&
          CheckAllOpInputsQuantized(subgraph, consumer_op, consumer_op_code);
      if (!eval_hybrid) {
        dequant_op_infos.push_back(consumer_op_info);
      }
    }

    // If no ops require dequant, we are done for this tensor.
    if (dequant_op_infos.empty()) {
      continue;
    }

    // Create a new tensor to be the output of the dequantize op.
    std::unique_ptr<TensorT> dequantize_output;
    const string dequant_name = tensor->name + "_dequantize";
    MakeTensor(dequant_name, tensor->shape, &dequantize_output);
    const int32_t dequantize_output_idx = subgraph->tensors.size();
    subgraph->tensors.push_back(std::move(dequantize_output));

    // Create the Dequantize operation.
    std::unique_ptr<OperatorT> dequantize_op;
    MakeDequantizeOperator(model.get(), &dequantize_op, tensor_idx,
                           dequantize_output_idx);

    LOG(INFO) << "Creating Dequantize op with name " << dequant_name << ".";

    // Update the inputs of all the ops that will consume the newly created
    // dequantize output.
    int32_t min_op_idx = static_cast<int32_t>(subgraph->operators.size());
    for (ConsumerOpInfo& dequant_op_info : dequant_op_infos) {
      dequant_op_info.op->inputs[dequant_op_info.op_input_idx] =
          dequantize_output_idx;
      min_op_idx = std::min(dequant_op_info.op_idx, min_op_idx);
    }

    // Insert the newly created Dequantize operation before the earliest
    // consumer, since TFLite requires operators to be topologically sorted.
    subgraph->operators.insert(subgraph->operators.begin() + min_op_idx,
                               std::move(dequantize_op));
  }

  // Update the modified operator code versions.
  UpdateInt8OperatorVersions(model.get());

  flatbuffers::Offset<Model> output_model_location =
      Model::Pack(*builder, model.get());
  FinishModelBuffer(*builder, output_model_location);

  return kTfLiteOk;
}

}  // namespace

namespace internal {
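// Quantizes weights as QuantizeWeights does, but additionally exposes
// use_hybrid_evaluation; when it is false, a Dequantize op is inserted for
// every consumer of a quantized weight, even if that consumer supports hybrid
// evaluation.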
TfLiteStatus QuantizeWeights(flatbuffers::FlatBufferBuilder* builder,
                             const Model* input_model,
                             uint64_t weights_min_num_elements,
                             bool use_hybrid_evaluation) {
  // Only weights with at least weights_min_num_elements elements are
  // quantized.
  return QuantizeWeightsInternal(builder, input_model, use_hybrid_evaluation,
                                 weights_min_num_elements);
}
}  // namespace internal

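// Quantizes weights with hybrid evaluation enabled, using the caller-provided
// minimum element count as the quantization threshold.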
TfLiteStatus QuantizeWeights(flatbuffers::FlatBufferBuilder* builder,
                             const Model* input_model,
                             uint64_t weights_min_num_elements) {
  return QuantizeWeightsInternal(builder, input_model, true,
                                 weights_min_num_elements);
}

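// Quantizes weights with hybrid evaluation enabled and the default element
// count threshold. Illustrative usage sketch (the model path below is a
// hypothetical placeholder, not part of this tool):
//
//   std::unique_ptr<FlatBufferModel> input =
//       FlatBufferModel::BuildFromFile("/path/to/model.tflite");
//   flatbuffers::FlatBufferBuilder builder;
//   if (QuantizeWeights(&builder, input->GetModel()) == kTfLiteOk) {
//     // builder.GetBufferPointer() and builder.GetSize() now describe the
//     // weight-quantized model.
//   }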
TfLiteStatus QuantizeWeights(flatbuffers::FlatBufferBuilder* builder,
                             const Model* input_model) {
  // By default, only weights with at least kWeightsMinNumElementsDefault
  // elements are quantized.
  return QuantizeWeightsInternal(builder, input_model, true,
                                 kWeightsMinNumElementsDefault);
}

}  // namespace optimize
}  // namespace tflite