1 /* Copyright 2017 The TensorFlow Authors. All Rights Reserved.
2
3 Licensed under the Apache License, Version 2.0 (the "License");
4 you may not use this file except in compliance with the License.
5 You may obtain a copy of the License at
6
7 http://www.apache.org/licenses/LICENSE-2.0
8
9 Unless required by applicable law or agreed to in writing, software
10 distributed under the License is distributed on an "AS IS" BASIS,
11 WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
12 See the License for the specific language governing permissions and
13 limitations under the License.
14 ==============================================================================*/
15 #include "tensorflow/lite/toco/tflite/import.h"
16
17 #include "flatbuffers/flexbuffers.h"
18 #include "tensorflow/lite/model.h"
19 #include "tensorflow/lite/schema/schema_generated.h"
20 #include "tensorflow/lite/toco/tflite/operator.h"
21 #include "tensorflow/lite/toco/tflite/types.h"
22 #include "tensorflow/lite/toco/tooling_util.h"
23 #include "tensorflow/lite/tools/verifier.h"
24
25 namespace toco {
26
27 namespace tflite {
28
29 namespace details {
LoadTensorsTable(const::tflite::Model & input_model,TensorsTable * tensors_table)30 void LoadTensorsTable(const ::tflite::Model& input_model,
31 TensorsTable* tensors_table) {
32 // TODO(aselle): add support to toco for multiple subgraphs.
33 auto tensors = (*input_model.subgraphs())[0]->tensors();
34 if (!tensors) return;
35 for (const auto* tensor : *tensors) {
36 tensors_table->push_back(tensor->name()->c_str());
37 }
38 }
39
LoadOperatorsTable(const::tflite::Model & input_model,OperatorsTable * operators_table)40 void LoadOperatorsTable(const ::tflite::Model& input_model,
41 OperatorsTable* operators_table) {
42 auto opcodes = input_model.operator_codes();
43 if (!opcodes) return;
44 for (const auto* opcode : *opcodes) {
45 if (opcode->builtin_code() != ::tflite::BuiltinOperator_CUSTOM) {
46 operators_table->push_back(
47 EnumNameBuiltinOperator(opcode->builtin_code()));
48 } else {
49 operators_table->push_back(opcode->custom_code()->c_str());
50 }
51 }
52 }
53 } // namespace details
54
ImportTensors(const::tflite::Model & input_model,Model * model)55 void ImportTensors(const ::tflite::Model& input_model, Model* model) {
56 auto tensors = (*input_model.subgraphs())[0]->tensors();
57 auto* buffers = input_model.buffers();
58 // auto tensors = input_model.tensors();
59 if (!tensors) return;
60 for (const auto* input_tensor : *tensors) {
61 Array& array = model->GetOrCreateArray(input_tensor->name()->c_str());
62 array.data_type = DataType::Deserialize(input_tensor->type());
63 int buffer_index = input_tensor->buffer();
64 auto* buffer = buffers->Get(buffer_index);
65 DataBuffer::Deserialize(*input_tensor, *buffer, &array);
66
67 auto shape = input_tensor->shape();
68 if (shape) {
69 // If the shape is 0-dimensional, make sure to record it as such,
70 // as oppose to leaving the array without a shape.
71 array.mutable_shape()->mutable_dims()->clear();
72 for (int i = 0; i < shape->Length(); ++i) {
73 auto d = shape->Get(i);
74 array.mutable_shape()->mutable_dims()->push_back(d);
75 }
76 }
77
78 auto quantization = input_tensor->quantization();
79 if (quantization) {
80 // Note that tf.mini only supports a single quantization parameters for
81 // the whole array.
82 if (quantization->min() && quantization->max()) {
83 CHECK_EQ(1, quantization->min()->Length());
84 CHECK_EQ(1, quantization->max()->Length());
85 MinMax& minmax = array.GetOrCreateMinMax();
86 minmax.min = quantization->min()->Get(0);
87 minmax.max = quantization->max()->Get(0);
88 }
89 if (quantization->scale() && quantization->zero_point()) {
90 CHECK_EQ(1, quantization->scale()->Length());
91 CHECK_EQ(1, quantization->zero_point()->Length());
92 QuantizationParams& q = array.GetOrCreateQuantizationParams();
93 q.scale = quantization->scale()->Get(0);
94 q.zero_point = quantization->zero_point()->Get(0);
95 }
96 }
97 }
98 }
99
ImportOperators(const::tflite::Model & input_model,const std::map<string,std::unique_ptr<BaseOperator>> & ops_by_name,const details::TensorsTable & tensors_table,const details::OperatorsTable & operators_table,Model * model)100 void ImportOperators(
101 const ::tflite::Model& input_model,
102 const std::map<string, std::unique_ptr<BaseOperator>>& ops_by_name,
103 const details::TensorsTable& tensors_table,
104 const details::OperatorsTable& operators_table, Model* model) {
105 // TODO(aselle): add support for multiple subgraphs.
106 auto ops = (*input_model.subgraphs())[0]->operators();
107
108 if (!ops) return;
109 for (const auto* input_op : *ops) {
110 int index = input_op->opcode_index();
111 if (index < 0 || index > operators_table.size()) {
112 LOG(FATAL) << "Index " << index << " must be between zero and "
113 << operators_table.size();
114 }
115 string opname = operators_table.at(index);
116
117 // Find and use the appropriate operator deserialization factory.
118 std::unique_ptr<Operator> new_op = nullptr;
119 if (ops_by_name.count(opname) == 0) {
120 string effective_opname = "TENSORFLOW_UNSUPPORTED";
121 if (ops_by_name.count(effective_opname) == 0) {
122 LOG(FATAL) << "Internal logic error: TENSORFLOW_UNSUPPORTED not found.";
123 }
124 new_op = ops_by_name.at(effective_opname)
125 ->Deserialize(input_op->builtin_options(),
126 input_op->custom_options());
127 if (new_op->type == OperatorType::kUnsupported) {
128 auto* unsupported_op =
129 static_cast<TensorFlowUnsupportedOperator*>(new_op.get());
130 unsupported_op->tensorflow_op = opname;
131 // TODO(b/109932940): Remove this when quantized is removed.
132 // For now, we assume all ops are quantized.
133 unsupported_op->quantized = true;
134 } else {
135 LOG(FATAL) << "Expected a TensorFlowUnsupportedOperator";
136 }
137 } else {
138 new_op = ops_by_name.at(opname)->Deserialize(input_op->builtin_options(),
139 input_op->custom_options());
140 }
141 model->operators.emplace_back(new_op.release());
142 auto* op = model->operators.back().get();
143
144 // Make sure all the inputs and outputs are hooked up.
145 auto inputs = input_op->inputs();
146 for (int i = 0; i < inputs->Length(); i++) {
147 auto input_index = inputs->Get(i);
148 // input_index == -1 indicates optional tensor.
149 if (input_index != -1) {
150 const string& input_name = tensors_table.at(input_index);
151 op->inputs.push_back(input_name);
152 } else {
153 const string& tensor_name =
154 toco::AvailableArrayName(*model, "OptionalTensor");
155 model->CreateOptionalArray(tensor_name);
156 op->inputs.push_back(tensor_name);
157 }
158 }
159 auto outputs = input_op->outputs();
160 for (int i = 0; i < outputs->Length(); i++) {
161 auto output_index = outputs->Get(i);
162 const string& output_name = tensors_table.at(output_index);
163 op->outputs.push_back(output_name);
164 }
165 }
166 }
167
ImportIOTensors(const ModelFlags & model_flags,const::tflite::Model & input_model,const details::TensorsTable & tensors_table,Model * model)168 void ImportIOTensors(const ModelFlags& model_flags,
169 const ::tflite::Model& input_model,
170 const details::TensorsTable& tensors_table, Model* model) {
171 // Import from the first subgraph if input arrays have not been specified.
172 if (model_flags.input_arrays().empty()) {
173 auto inputs = (*input_model.subgraphs())[0]->inputs();
174 if (inputs) {
175 for (int input : *inputs) {
176 const string& input_name = tensors_table.at(input);
177 model->flags.add_input_arrays()->set_name(input_name);
178 }
179 }
180 }
181
182 // Import from the first subgraph if output arrays have not been specified.
183 if (model_flags.output_arrays().empty()) {
184 auto outputs = (*input_model.subgraphs())[0]->outputs();
185 if (outputs) {
186 for (int output : *outputs) {
187 const string& output_name = tensors_table.at(output);
188 model->flags.add_output_arrays(output_name);
189 }
190 }
191 }
192 }
193
194 namespace {
Verify(const void * buf,size_t len)195 bool Verify(const void* buf, size_t len) {
196 ::flatbuffers::Verifier verifier(static_cast<const uint8_t*>(buf), len);
197 return ::tflite::VerifyModelBuffer(verifier);
198 }
199 } // namespace
200
Import(const ModelFlags & model_flags,const string & input_file_contents)201 std::unique_ptr<Model> Import(const ModelFlags& model_flags,
202 const string& input_file_contents) {
203 ::tflite::AlwaysTrueResolver r;
204 if (!::tflite::Verify(input_file_contents.data(), input_file_contents.size(),
205 r, ::tflite::DefaultErrorReporter())) {
206 LOG(FATAL) << "Invalid flatbuffer.";
207 }
208 const ::tflite::Model* input_model =
209 ::tflite::GetModel(input_file_contents.data());
210
211 // Full list of all known operators.
212 const auto ops_by_name = BuildOperatorByNameMap();
213
214 if (!input_model->subgraphs() || input_model->subgraphs()->size() != 1) {
215 LOG(FATAL) << "Number of subgraphs in tflite should be exactly 1.";
216 }
217 std::unique_ptr<Model> model;
218 model.reset(new Model);
219
220 details::TensorsTable tensors_table;
221 details::LoadTensorsTable(*input_model, &tensors_table);
222
223 details::OperatorsTable operators_table;
224 details::LoadOperatorsTable(*input_model, &operators_table);
225
226 ImportTensors(*input_model, model.get());
227 ImportOperators(*input_model, ops_by_name, tensors_table, operators_table,
228 model.get());
229
230 ImportIOTensors(model_flags, *input_model, tensors_table, model.get());
231
232 UndoWeightsShuffling(model.get());
233
234 return model;
235 }
236
237 } // namespace tflite
238
239 } // namespace toco
240