1 /* Copyright 2017 The TensorFlow Authors. All Rights Reserved.
2
3 Licensed under the Apache License, Version 2.0 (the "License");
4 you may not use this file except in compliance with the License.
5 You may obtain a copy of the License at
6
7 http://www.apache.org/licenses/LICENSE-2.0
8
9 Unless required by applicable law or agreed to in writing, software
10 distributed under the License is distributed on an "AS IS" BASIS,
11 WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
12 See the License for the specific language governing permissions and
13 limitations under the License.
14 ==============================================================================*/
15 #include "tensorflow/lite/toco/import_tensorflow.h"
16
17 #include <memory>
18 #include <string>
19 #include <utility>
20 #include <vector>
21
22 #include "google/protobuf/map.h"
23 #include "google/protobuf/text_format.h"
24 #include "absl/memory/memory.h"
25 #include "absl/strings/match.h"
26 #include "absl/strings/numbers.h"
27 #include "absl/strings/str_cat.h"
28 #include "absl/strings/str_split.h"
29 #include "absl/strings/strip.h"
30 #include "tensorflow/core/common_runtime/device_factory.h"
31 #include "tensorflow/core/common_runtime/function.h"
32 #include "tensorflow/core/common_runtime/graph_constructor.h"
33 #include "tensorflow/core/common_runtime/process_function_library_runtime.h"
34 #include "tensorflow/core/framework/attr_value.pb.h"
35 #include "tensorflow/core/framework/function.pb.h"
36 #include "tensorflow/core/framework/graph.pb.h"
37 #include "tensorflow/core/framework/node_def.pb.h"
38 #include "tensorflow/core/framework/tensor.pb.h"
39 #include "tensorflow/core/framework/tensor_shape.pb.h"
40 #include "tensorflow/core/framework/types.pb.h"
41 #include "tensorflow/core/lib/core/errors.h"
42 #include "tensorflow/core/lib/core/status.h"
43 #include "tensorflow/core/platform/logging.h"
44 #include "tensorflow/core/public/session_options.h"
45 #include "tensorflow/core/public/version.h"
46 #include "tensorflow/lite/toco/model.h"
47 #include "tensorflow/lite/toco/model_flags.pb.h"
48 #include "tensorflow/lite/toco/tensorflow_graph_matching/resolve_cluster.h"
49 #include "tensorflow/lite/toco/tensorflow_util.h"
50 #include "tensorflow/lite/toco/tooling_util.h"
51
52 using tensorflow::AttrValue;
53 using tensorflow::DT_BOOL;
54 using tensorflow::DT_COMPLEX64;
55 using tensorflow::DT_FLOAT;
56 using tensorflow::DT_INT32;
57 using tensorflow::DT_INT64;
58 using tensorflow::DT_QUINT8;
59 using tensorflow::DT_STRING;
60 using tensorflow::DT_UINT32;
61 using tensorflow::DT_UINT8;
62 using tensorflow::GraphDef;
63 using tensorflow::NodeDef;
64 using tensorflow::OpRegistry;
65 using tensorflow::TensorProto;
66 using tensorflow::TensorShapeProto;
67
68 namespace toco {
69
70 namespace {
HasAttr(const NodeDef & node,const std::string & attr_name)71 bool HasAttr(const NodeDef& node, const std::string& attr_name) {
72 return node.attr().count(attr_name) > 0;
73 }
74
HasWildcardDimension(const TensorShapeProto & shape)75 bool HasWildcardDimension(const TensorShapeProto& shape) {
76 for (const auto& dim : shape.dim()) {
77 if (dim.size() == -1) return true;
78 }
79 return false;
80 }
81
GetStringAttr(const NodeDef & node,const std::string & attr_name)82 const std::string& GetStringAttr(const NodeDef& node,
83 const std::string& attr_name) {
84 CHECK(HasAttr(node, attr_name));
85 const auto& attr = node.attr().at(attr_name);
86 CHECK_EQ(attr.value_case(), AttrValue::kS);
87 return attr.s();
88 }
89
GetIntAttr(const NodeDef & node,const std::string & attr_name)90 int64 GetIntAttr(const NodeDef& node, const std::string& attr_name) {
91 CHECK(HasAttr(node, attr_name)) << attr_name << " not found in:\n"
92 << node.DebugString();
93 const auto& attr = node.attr().at(attr_name);
94 CHECK_EQ(attr.value_case(), AttrValue::kI);
95 return attr.i();
96 }
97
GetFloatAttr(const NodeDef & node,const std::string & attr_name)98 float GetFloatAttr(const NodeDef& node, const std::string& attr_name) {
99 CHECK(HasAttr(node, attr_name));
100 const auto& attr = node.attr().at(attr_name);
101 CHECK_EQ(attr.value_case(), AttrValue::kF);
102 return attr.f();
103 }
104
GetBoolAttr(const NodeDef & node,const std::string & attr_name)105 bool GetBoolAttr(const NodeDef& node, const std::string& attr_name) {
106 CHECK(HasAttr(node, attr_name));
107 const auto& attr = node.attr().at(attr_name);
108 CHECK_EQ(attr.value_case(), AttrValue::kB);
109 return attr.b();
110 }
111
GetDataTypeAttr(const NodeDef & node,const std::string & attr_name)112 tensorflow::DataType GetDataTypeAttr(const NodeDef& node,
113 const std::string& attr_name) {
114 CHECK(HasAttr(node, attr_name));
115 const auto& attr = node.attr().at(attr_name);
116 CHECK_EQ(attr.value_case(), AttrValue::kType);
117 return attr.type();
118 }
119
GetShapeAttr(const NodeDef & node,const std::string & attr_name)120 const TensorShapeProto& GetShapeAttr(const NodeDef& node,
121 const std::string& attr_name) {
122 CHECK(HasAttr(node, attr_name));
123 const auto& attr = node.attr().at(attr_name);
124 CHECK_EQ(attr.value_case(), AttrValue::kShape);
125 return attr.shape();
126 }
127
GetTensorAttr(const NodeDef & node,const std::string & attr_name)128 const TensorProto& GetTensorAttr(const NodeDef& node,
129 const std::string& attr_name) {
130 CHECK(HasAttr(node, attr_name)) << "No attr named '" << attr_name << "'";
131 const auto& attr = node.attr().at(attr_name);
132 CHECK_EQ(attr.value_case(), AttrValue::kTensor);
133 return attr.tensor();
134 }
135
GetListAttr(const NodeDef & node,const std::string & attr_name)136 const AttrValue::ListValue& GetListAttr(const NodeDef& node,
137 const std::string& attr_name) {
138 CHECK(HasAttr(node, attr_name));
139 const auto& attr = node.attr().at(attr_name);
140 CHECK_EQ(attr.value_case(), AttrValue::kList);
141 return attr.list();
142 }
143
CheckOptionalAttr(const NodeDef & node,const std::string & attr_name,const std::string & expected_value)144 tensorflow::Status CheckOptionalAttr(const NodeDef& node,
145 const std::string& attr_name,
146 const std::string& expected_value) {
147 if (HasAttr(node, attr_name)) {
148 const std::string& value = GetStringAttr(node, attr_name);
149 if (value != expected_value) {
150 return tensorflow::errors::InvalidArgument(
151 "Unexpected value for attribute '" + attr_name + "'. Expected '" +
152 expected_value + "'");
153 }
154 }
155 return tensorflow::Status::OK();
156 }
157
// Validates an optional dtype attribute. A missing attribute is accepted;
// a present one must equal `expected_value`, otherwise InvalidArgument is
// returned naming the expected dtype.
tensorflow::Status CheckOptionalAttr(
    const NodeDef& node, const std::string& attr_name,
    const tensorflow::DataType& expected_value) {
  const bool acceptable =
      !HasAttr(node, attr_name) ||
      GetDataTypeAttr(node, attr_name) == expected_value;
  if (acceptable) {
    return tensorflow::Status::OK();
  }
  return tensorflow::errors::InvalidArgument(
      "Unexpected value for attribute '" + attr_name + "'. Expected '" +
      tensorflow::DataType_Name(expected_value) + "'");
}
171
172 template <typename T1, typename T2>
ExpectValue(const T1 & v1,const T2 & v2,const std::string & description)173 tensorflow::Status ExpectValue(const T1& v1, const T2& v2,
174 const std::string& description) {
175 if (v1 == v2) return tensorflow::Status::OK();
176 return tensorflow::errors::InvalidArgument(absl::StrCat(
177 "Unexpected ", description, ": got ", v1, ", expected ", v2));
178 }
179
ConvertDataType(tensorflow::DataType dtype)180 ArrayDataType ConvertDataType(tensorflow::DataType dtype) {
181 if (dtype == DT_UINT8)
182 return ArrayDataType::kUint8;
183 else if (dtype == DT_FLOAT)
184 return ArrayDataType::kFloat;
185 else if (dtype == DT_BOOL)
186 return ArrayDataType::kBool;
187 else if (dtype == DT_INT32)
188 return ArrayDataType::kInt32;
189 else if (dtype == DT_UINT32)
190 return ArrayDataType::kUint32;
191 else if (dtype == DT_INT64)
192 return ArrayDataType::kInt64;
193 else if (dtype == DT_STRING)
194 return ArrayDataType::kString;
195 else if (dtype == DT_COMPLEX64)
196 return ArrayDataType::kComplex64;
197 else
198 LOG(INFO) << "Unsupported data type in placeholder op: " << dtype;
199 return ArrayDataType::kNone;
200 }
201
ImportShape(const TFLITE_PROTO_NS::RepeatedPtrField<tensorflow::TensorShapeProto_Dim> & input_dims,int * input_flat_size,Shape * shape)202 tensorflow::Status ImportShape(
203 const TFLITE_PROTO_NS::RepeatedPtrField<tensorflow::TensorShapeProto_Dim>&
204 input_dims,
205 int* input_flat_size, Shape* shape) {
206 std::vector<int> input_dims_only_sizes;
207 bool zero_sized_shape = false;
208 for (auto& d : input_dims) {
209 // TensorFlow's shapes use int64s, while TOCO uses ints.
210 if (d.size() > std::numeric_limits<int>::max()) {
211 return tensorflow::errors::InvalidArgument("Shape element overflows");
212 }
213 if (d.size() == 0) {
214 zero_sized_shape = true;
215 }
216 input_dims_only_sizes.push_back(d.size());
217 }
218
219 // Note that up to this point we were OK with the input shape containing
220 // elements valued -1 or 0, which are perfectly legal in tensorflow. However
221 // our CheckValidShapeDimensions() insists on them being >= 1, with the
222 // exception of the "scalar" shape [0]. The main issue with zero-values shape
223 // elements is that the corresponding arrays don't contain any data and the
224 // allocation code gets a bit confused. It seems that the code expects an
225 // empty shape for zero-sized shapes, so we will do just that, except for the
226 // [0] case.
227 // TODO(b/119325030): In order to correctly import the "scalar" shapes the
228 // following test must include "&& input_dims_only_sizes.size() > 1", but
229 // that seems to slow everything down a lot.
230 if (zero_sized_shape) {
231 shape->mutable_dims()->clear();
232 if (input_flat_size != nullptr) *input_flat_size = 0;
233 return tensorflow::Status::OK();
234 }
235
236 *shape->mutable_dims() = input_dims_only_sizes;
237
238 if (input_flat_size == nullptr) return tensorflow::Status::OK();
239
240 return NumElements(input_dims_only_sizes, input_flat_size);
241 }
242
// Define ways to retrieve data from tensors of different types.
// TODO(b/80208043): simply use tensorflow::Tensor::FromProto() instead.
//
// The primary template is declared but never defined: only the explicit
// specializations that follow can be used. Instantiating ImportTensorData
// with an element type that has no specialization therefore fails at compile
// time. Each specialization provides size(), get(), accessor_name(),
// type_name() and CopyFromContent().
template <typename T>
struct TensorTraits;
247
248 template <>
249 struct TensorTraits<float> {
sizetoco::__anon2e11eff70111::TensorTraits250 static int size(const TensorProto& p) { return p.float_val_size(); }
gettoco::__anon2e11eff70111::TensorTraits251 static float get(const TensorProto& p, int i) { return p.float_val(i); }
accessor_nametoco::__anon2e11eff70111::TensorTraits252 static std::string accessor_name() { return "float_val"; }
type_nametoco::__anon2e11eff70111::TensorTraits253 static std::string type_name() { return "float"; }
CopyFromContenttoco::__anon2e11eff70111::TensorTraits254 static void CopyFromContent(const TensorProto& p, std::vector<float>* data) {
255 toco::port::CopyToBuffer(p.tensor_content(),
256 reinterpret_cast<char*>(data->data()));
257 }
258 };
259
260 template <>
261 struct TensorTraits<uint8_t> {
sizetoco::__anon2e11eff70111::TensorTraits262 static int size(const TensorProto& p) { return p.int_val_size(); }
gettoco::__anon2e11eff70111::TensorTraits263 static uint8_t get(const TensorProto& p, int i) { return p.int_val(i); }
accessor_nametoco::__anon2e11eff70111::TensorTraits264 static std::string accessor_name() { return "int_val"; }
type_nametoco::__anon2e11eff70111::TensorTraits265 static std::string type_name() { return "uint8"; }
CopyFromContenttoco::__anon2e11eff70111::TensorTraits266 static void CopyFromContent(const TensorProto& p,
267 std::vector<uint8_t>* data) {
268 toco::port::CopyToBuffer(p.tensor_content(),
269 reinterpret_cast<char*>(data->data()));
270 }
271 };
272
273 template <>
274 struct TensorTraits<std::complex<float>> {
sizetoco::__anon2e11eff70111::TensorTraits275 static int size(const TensorProto& p) { return p.scomplex_val_size() / 2; }
gettoco::__anon2e11eff70111::TensorTraits276 static std::complex<float> get(const TensorProto& p, int i) {
277 return std::complex<float>(p.scomplex_val(2 * i),
278 p.scomplex_val(2 * i + 1));
279 }
accessor_nametoco::__anon2e11eff70111::TensorTraits280 static std::string accessor_name() { return "scomplex_val"; }
type_nametoco::__anon2e11eff70111::TensorTraits281 static std::string type_name() { return "complex64"; }
CopyFromContenttoco::__anon2e11eff70111::TensorTraits282 static void CopyFromContent(const TensorProto& p,
283 std::vector<std::complex<float>>* data) {
284 toco::port::CopyToBuffer(p.tensor_content(),
285 reinterpret_cast<char*>(data->data()));
286 }
287 };
288
289 template <>
290 struct TensorTraits<int32> {
sizetoco::__anon2e11eff70111::TensorTraits291 static int size(const TensorProto& p) { return p.int_val_size(); }
gettoco::__anon2e11eff70111::TensorTraits292 static int32 get(const TensorProto& p, int i) { return p.int_val(i); }
accessor_nametoco::__anon2e11eff70111::TensorTraits293 static std::string accessor_name() { return "int_val"; }
type_nametoco::__anon2e11eff70111::TensorTraits294 static std::string type_name() { return "int32"; }
CopyFromContenttoco::__anon2e11eff70111::TensorTraits295 static void CopyFromContent(const TensorProto& p, std::vector<int32>* data) {
296 toco::port::CopyToBuffer(p.tensor_content(),
297 reinterpret_cast<char*>(data->data()));
298 }
299 };
300
301 template <>
302 struct TensorTraits<uint32> {
sizetoco::__anon2e11eff70111::TensorTraits303 static int size(const TensorProto& p) { return p.uint32_val_size(); }
gettoco::__anon2e11eff70111::TensorTraits304 static int32 get(const TensorProto& p, int i) { return p.uint32_val(i); }
accessor_nametoco::__anon2e11eff70111::TensorTraits305 static std::string accessor_name() { return "uint32_val"; }
type_nametoco::__anon2e11eff70111::TensorTraits306 static std::string type_name() { return "uint32"; }
CopyFromContenttoco::__anon2e11eff70111::TensorTraits307 static void CopyFromContent(const TensorProto& p, std::vector<uint32>* data) {
308 toco::port::CopyToBuffer(p.tensor_content(),
309 reinterpret_cast<char*>(data->data()));
310 }
311 };
312
313 template <>
314 struct TensorTraits<int64> {
sizetoco::__anon2e11eff70111::TensorTraits315 static int size(const TensorProto& p) { return p.int64_val_size(); }
gettoco::__anon2e11eff70111::TensorTraits316 static int64 get(const TensorProto& p, int i) { return p.int64_val(i); }
accessor_nametoco::__anon2e11eff70111::TensorTraits317 static std::string accessor_name() { return "int64_val"; }
type_nametoco::__anon2e11eff70111::TensorTraits318 static std::string type_name() { return "int64"; }
CopyFromContenttoco::__anon2e11eff70111::TensorTraits319 static void CopyFromContent(const TensorProto& p, std::vector<int64>* data) {
320 toco::port::CopyToBuffer(p.tensor_content(),
321 reinterpret_cast<char*>(data->data()));
322 }
323 };
324
325 template <>
326 struct TensorTraits<bool> {
sizetoco::__anon2e11eff70111::TensorTraits327 static int size(const TensorProto& p) { return p.bool_val_size(); }
gettoco::__anon2e11eff70111::TensorTraits328 static bool get(const TensorProto& p, int i) { return p.bool_val(i); }
accessor_nametoco::__anon2e11eff70111::TensorTraits329 static std::string accessor_name() { return "bool_val"; }
type_nametoco::__anon2e11eff70111::TensorTraits330 static std::string type_name() { return "bool"; }
CopyFromContenttoco::__anon2e11eff70111::TensorTraits331 static void CopyFromContent(const TensorProto& p, std::vector<bool>* data) {
332 std::vector<char> buf(p.tensor_content().size());
333 toco::port::CopyToBuffer(p.tensor_content(), buf.data());
334 for (int i = 0; i < p.tensor_content().size(); i++) {
335 (*data)[i] = static_cast<bool>(buf[i]);
336 }
337 }
338 };
339
340 template <typename T>
ImportTensorData(const TensorProto & input_tensor,int input_flat_size,std::vector<T> * output_data)341 tensorflow::Status ImportTensorData(const TensorProto& input_tensor,
342 int input_flat_size,
343 std::vector<T>* output_data) {
344 CHECK_GE(output_data->size(), input_flat_size);
345 int num_elements_in_tensor = TensorTraits<T>::size(input_tensor);
346 if (num_elements_in_tensor == input_flat_size) {
347 for (int i = 0; i < num_elements_in_tensor; i++) {
348 (*output_data)[i] = TensorTraits<T>::get(input_tensor, i);
349 }
350 } else if (input_tensor.tensor_content().size() ==
351 input_flat_size * sizeof(T)) {
352 TensorTraits<T>::CopyFromContent(input_tensor, output_data);
353 } else if (num_elements_in_tensor > 0 &&
354 num_elements_in_tensor < input_flat_size) {
355 // TODO(b/80208043): use tensorflow::Tensor::FromProto() which is the
356 // official way to import tensor data. This particular else-if handles a
357 // grappler optimization where the last few elements in a tensor are
358 // omitted if they are repeated.
359 int i = 0;
360 for (; i < num_elements_in_tensor; ++i) {
361 (*output_data)[i] = TensorTraits<T>::get(input_tensor, i);
362 }
363 auto last = (*output_data)[i - 1];
364 for (; i < input_flat_size; ++i) {
365 (*output_data)[i] = last;
366 }
367 } else {
368 std::string accessor_name = TensorTraits<T>::accessor_name();
369 std::string type_name = TensorTraits<T>::type_name();
370 return tensorflow::errors::InvalidArgument(
371 absl::StrCat("Neither input_content (",
372 input_tensor.tensor_content().size() / sizeof(T), ") nor ",
373 accessor_name, " (", num_elements_in_tensor,
374 ") have the right dimensions (", input_flat_size,
375 ") for this ", type_name, " tensor"));
376 }
377 return tensorflow::Status::OK();
378 }
379
ImportFloatArray(const TensorProto & input_tensor,Array * output_array)380 tensorflow::Status ImportFloatArray(const TensorProto& input_tensor,
381 Array* output_array) {
382 CHECK_EQ(input_tensor.dtype(), DT_FLOAT);
383 const auto& input_shape = input_tensor.tensor_shape();
384 CHECK_LE(input_shape.dim_size(), 6);
385 int input_flat_size;
386 auto status = ImportShape(input_shape.dim(), &input_flat_size,
387 output_array->mutable_shape());
388 if (!status.ok()) return status;
389
390 auto& output_float_data =
391 output_array->GetMutableBuffer<ArrayDataType::kFloat>().data;
392 output_float_data.resize(RequiredBufferSizeForShape(output_array->shape()),
393 0.f);
394 return ImportTensorData<float>(input_tensor, input_flat_size,
395 &output_float_data);
396 }
397
ImportComplex64Array(const TensorProto & input_tensor,Array * output_array)398 tensorflow::Status ImportComplex64Array(const TensorProto& input_tensor,
399 Array* output_array) {
400 CHECK_EQ(input_tensor.dtype(), DT_COMPLEX64);
401 const auto& input_shape = input_tensor.tensor_shape();
402 CHECK_LE(input_shape.dim_size(), 4);
403 int input_flat_size;
404 auto status = ImportShape(input_shape.dim(), &input_flat_size,
405 output_array->mutable_shape());
406 if (!status.ok()) return status;
407
408 auto& output_complex_data =
409 output_array->GetMutableBuffer<ArrayDataType::kComplex64>().data;
410 output_complex_data.resize(RequiredBufferSizeForShape(output_array->shape()),
411 std::complex<float>(0.f, 0.f));
412 return ImportTensorData<std::complex<float>>(input_tensor, input_flat_size,
413 &output_complex_data);
414 }
415
ImportQuint8Array(const TensorProto & input_tensor,Array * output_array)416 tensorflow::Status ImportQuint8Array(const TensorProto& input_tensor,
417 Array* output_array) {
418 CHECK_EQ(input_tensor.dtype(), DT_QUINT8);
419 const auto& input_shape = input_tensor.tensor_shape();
420 CHECK_LE(input_shape.dim_size(), 6);
421 int input_flat_size;
422 auto status = ImportShape(input_shape.dim(), &input_flat_size,
423 output_array->mutable_shape());
424 if (!status.ok()) return status;
425
426 auto& output_int_data =
427 output_array->GetMutableBuffer<ArrayDataType::kUint8>().data;
428 output_int_data.resize(RequiredBufferSizeForShape(output_array->shape()), 0);
429 return ImportTensorData<uint8_t>(input_tensor, input_flat_size,
430 &output_int_data);
431 }
432
ImportInt32Array(const TensorProto & input_tensor,Array * output_array)433 tensorflow::Status ImportInt32Array(const TensorProto& input_tensor,
434 Array* output_array) {
435 CHECK_EQ(input_tensor.dtype(), DT_INT32);
436 const auto& input_shape = input_tensor.tensor_shape();
437 CHECK_LE(input_shape.dim_size(), 6);
438 int input_flat_size;
439 auto status = ImportShape(input_shape.dim(), &input_flat_size,
440 output_array->mutable_shape());
441 if (!status.ok()) return status;
442
443 auto& output_int_data =
444 output_array->GetMutableBuffer<ArrayDataType::kInt32>().data;
445 output_int_data.resize(RequiredBufferSizeForShape(output_array->shape()), 0);
446 return ImportTensorData<int32>(input_tensor, input_flat_size,
447 &output_int_data);
448 }
449
ImportUint32Array(const TensorProto & input_tensor,Array * output_array)450 tensorflow::Status ImportUint32Array(const TensorProto& input_tensor,
451 Array* output_array) {
452 CHECK_EQ(input_tensor.dtype(), DT_UINT32);
453 const auto& input_shape = input_tensor.tensor_shape();
454 CHECK_LE(input_shape.dim_size(), 6);
455 int input_flat_size;
456 auto status = ImportShape(input_shape.dim(), &input_flat_size,
457 output_array->mutable_shape());
458 if (!status.ok()) return status;
459
460 auto& output_int_data =
461 output_array->GetMutableBuffer<ArrayDataType::kUint32>().data;
462 output_int_data.resize(RequiredBufferSizeForShape(output_array->shape()), 0);
463 return ImportTensorData<uint32>(input_tensor, input_flat_size,
464 &output_int_data);
465 }
466
ImportInt64Array(const TensorProto & input_tensor,Array * output_array)467 tensorflow::Status ImportInt64Array(const TensorProto& input_tensor,
468 Array* output_array) {
469 CHECK_EQ(input_tensor.dtype(), DT_INT64);
470 const auto& input_shape = input_tensor.tensor_shape();
471 CHECK_LE(input_shape.dim_size(), 6);
472 int input_flat_size;
473 auto status = ImportShape(input_shape.dim(), &input_flat_size,
474 output_array->mutable_shape());
475 if (!status.ok()) return status;
476
477 auto& output_int_data =
478 output_array->GetMutableBuffer<ArrayDataType::kInt64>().data;
479 output_int_data.resize(RequiredBufferSizeForShape(output_array->shape()), 0);
480 return ImportTensorData<int64>(input_tensor, input_flat_size,
481 &output_int_data);
482 }
483
ImportBoolArray(const TensorProto & input_tensor,Array * output_array)484 tensorflow::Status ImportBoolArray(const TensorProto& input_tensor,
485 Array* output_array) {
486 CHECK_EQ(input_tensor.dtype(), DT_BOOL);
487 const auto& input_shape = input_tensor.tensor_shape();
488 CHECK_LE(input_shape.dim_size(), 6);
489 int input_flat_size;
490 auto status = ImportShape(input_shape.dim(), &input_flat_size,
491 output_array->mutable_shape());
492 if (!status.ok()) return status;
493
494 auto& output_bool_data =
495 output_array->GetMutableBuffer<ArrayDataType::kBool>().data;
496 output_bool_data.resize(RequiredBufferSizeForShape(output_array->shape()),
497 false);
498 status =
499 ImportTensorData<bool>(input_tensor, input_flat_size, &output_bool_data);
500 if (!status.ok() && output_bool_data.size() == 1) {
501 // Some graphs have bool const nodes without actual value...
502 // assuming that 'false' is implied.
503 // So far only encountered that in an array with 1 entry, let's
504 // require that until we encounter a graph where that's not the case.
505 output_bool_data[0] = false;
506 return tensorflow::Status::OK();
507 }
508 return status;
509 }
510
ImportStringArray(const TensorProto & input_tensor,Array * output_array)511 tensorflow::Status ImportStringArray(const TensorProto& input_tensor,
512 Array* output_array) {
513 CHECK_EQ(input_tensor.dtype(), DT_STRING);
514 const auto& input_shape = input_tensor.tensor_shape();
515 CHECK_LE(input_shape.dim_size(), 6);
516 int input_flat_size;
517 auto status = ImportShape(input_shape.dim(), &input_flat_size,
518 output_array->mutable_shape());
519 if (!status.ok()) return status;
520
521 if (input_flat_size != input_tensor.string_val_size()) {
522 return tensorflow::errors::InvalidArgument(
523 "Input_content string_val doesn't have the right dimensions "
524 "for this string tensor");
525 }
526
527 auto& output_string_data =
528 output_array->GetMutableBuffer<ArrayDataType::kString>().data;
529 output_string_data.resize(RequiredBufferSizeForShape(output_array->shape()));
530 CHECK_GE(output_string_data.size(), input_flat_size);
531 for (int i = 0; i < input_flat_size; ++i) {
532 output_string_data[i] = input_tensor.string_val(i);
533 }
534 return tensorflow::Status::OK();
535 }
536
537 // Count the number of inputs of a given node. If
538 // `tf_import_flags.drop_control_dependency` is true, count the number of
539 // non-control-dependency inputs.
GetInputsCount(const NodeDef & node,const TensorFlowImportFlags & tf_import_flags)540 int GetInputsCount(const NodeDef& node,
541 const TensorFlowImportFlags& tf_import_flags) {
542 if (tf_import_flags.drop_control_dependency) {
543 for (size_t i = 0; i < node.input_size(); ++i) {
544 if (node.input(i)[0] == '^') {
545 return i;
546 }
547 }
548 }
549 return node.input_size();
550 }
551
CheckInputsCount(const NodeDef & node,const TensorFlowImportFlags & tf_import_flags,int expected_input_count)552 tensorflow::Status CheckInputsCount(
553 const NodeDef& node, const TensorFlowImportFlags& tf_import_flags,
554 int expected_input_count) {
555 if (GetInputsCount(node, tf_import_flags) != expected_input_count) {
556 return tensorflow::errors::FailedPrecondition(
557 node.op(), " node expects ", expected_input_count,
558 " input(s) other than control dependencies: ", node.DebugString());
559 }
560 return tensorflow::Status::OK();
561 }
562
563 template <ArrayDataType T>
CreateConstArray(Model * model,std::string const & name,std::vector<typename toco::DataType<T>> const & data)564 std::string CreateConstArray(
565 Model* model, std::string const& name,
566 std::vector<typename toco::DataType<T>> const& data) {
567 // Utility function to create a const 1D array, useful for input parameters.
568 std::string array_name = toco::AvailableArrayName(*model, name);
569 auto& array = model->GetOrCreateArray(array_name);
570 array.data_type = T;
571 array.mutable_shape()->mutable_dims()->emplace_back(
572 static_cast<int>(data.size()));
573 array.GetMutableBuffer<T>().data = data;
574 return array_name;
575 }
576
577 // Retain TensorFlow NodeDef in Toco Operator.
578 //
579 // If an op is supported by Toco but not supported by TFLite, TFLite exporter
580 // will use the retained NodeDef to populate a Flex op when Flex mode is
581 // enabled.
582 //
583 // This can't be easily applied to all operations, because a TensorFlow node
584 // may become multiple Toco operators. Thus we need to call this function in
585 // operator conversion functions one by one whenever feasible.
586 //
587 // This may cause problems if a graph transformation rule changes parameters
588 // of the node. When calling this function, please check if any existing
589 // graph transformation rule will change an existing operator with the same
590 // type.
591 //
592 // This provides a route to handle Toco-supported & TFLite-unsupported ops
593 // in Flex mode. However it's not a solid solution. Eventually we should
594 // get rid of this.
595 // TODO(b/117327937): Implement all Toco-supported ops in TFLite, and remove
596 // this function.
RetainTensorFlowNodeDef(const NodeDef & node,Operator * op)597 void RetainTensorFlowNodeDef(const NodeDef& node, Operator* op) {
598 node.SerializeToString(&op->tensorflow_node_def);
599 }
600
GetOutputNamesFromNodeDef(const NodeDef & node,const tensorflow::OpDef & op_def,TensorFlowUnsupportedOperator * op)601 void GetOutputNamesFromNodeDef(const NodeDef& node,
602 const tensorflow::OpDef& op_def,
603 TensorFlowUnsupportedOperator* op) {
604 int next_output = 0;
605 auto add_output = [&node, &next_output, op]() {
606 if (next_output == 0) {
607 op->outputs.push_back(node.name()); // Implicit :0.
608 } else {
609 op->outputs.push_back(absl::StrCat(node.name(), ":", next_output));
610 }
611 ++next_output;
612 };
613 for (int i = 0; i < op_def.output_arg_size(); ++i) {
614 std::string multiples = op_def.output_arg(i).number_attr();
615 if (!multiples.empty()) {
616 CHECK(HasAttr(node, multiples)) << "No attr named " << multiples;
617 int num_outputs = GetIntAttr(node, multiples);
618 for (int j = 0; j < num_outputs; ++j) {
619 add_output();
620 }
621 } else {
622 std::string list = op_def.output_arg(i).type_list_attr();
623 if (!list.empty()) {
624 CHECK(HasAttr(node, list)) << "No attr named " << list;
625 const AttrValue::ListValue& list_value = GetListAttr(node, list);
626 for (int j = 0; j < list_value.type_size(); ++j) {
627 add_output();
628 }
629 } else {
630 add_output();
631 }
632 }
633 }
634 }
635
GetOutputTypesFromNodeDef(const NodeDef & node,const tensorflow::OpDef & op_def,TensorFlowUnsupportedOperator * op)636 void GetOutputTypesFromNodeDef(const NodeDef& node,
637 const tensorflow::OpDef& op_def,
638 TensorFlowUnsupportedOperator* op) {
639 // The given type to the op, or clear the types if invalid.
640 auto add_type = [&node, op](tensorflow::DataType type) {
641 if (type == tensorflow::DT_INVALID) {
642 LOG(WARNING) << "Op node missing output type attribute: " << node.name();
643 op->output_data_types.clear();
644 } else {
645 op->output_data_types.push_back(ConvertDataType(type));
646 }
647 };
648
649 // Retrieve the data type according to the OpDef definition: either the
650 // "type" or "type_attr" field will be set.
651 auto get_type = [&node](const tensorflow::OpDef::ArgDef& a) {
652 if (a.type() != tensorflow::DT_INVALID) {
653 return a.type();
654 } else if (HasAttr(node, a.type_attr())) {
655 return GetDataTypeAttr(node, a.type_attr());
656 } else {
657 return tensorflow::DT_INVALID;
658 }
659 };
660
661 for (int i = 0; i < op_def.output_arg_size(); ++i) {
662 std::string multiples = op_def.output_arg(i).number_attr();
663 if (!multiples.empty()) {
664 CHECK(HasAttr(node, multiples)) << "No attr named " << multiples;
665 int num_outputs = GetIntAttr(node, multiples);
666 auto type = get_type(op_def.output_arg(i));
667 for (int j = 0; j < num_outputs; ++j) {
668 add_type(type);
669 }
670 } else {
671 std::string list = op_def.output_arg(i).type_list_attr();
672 if (!list.empty()) {
673 CHECK(HasAttr(node, list)) << "No attr named " << list;
674 const AttrValue::ListValue& list_value = GetListAttr(node, list);
675 for (int j = 0; j < list_value.type_size(); ++j) {
676 add_type(list_value.type(j));
677 }
678 } else {
679 add_type(get_type(op_def.output_arg(i)));
680 }
681 }
682 }
683 }
684
ConvertUnsupportedOperator(const NodeDef & node,const TensorFlowImportFlags & tf_import_flags,const ModelFlags & model_flags,Model * model)685 tensorflow::Status ConvertUnsupportedOperator(
686 const NodeDef& node, const TensorFlowImportFlags& tf_import_flags,
687 const ModelFlags& model_flags, Model* model) {
688 // Names of special attributes in TF graph that are used by Toco.
689 static constexpr char kAttrOutputQuantized[] = "_output_quantized";
690 static constexpr char kAttrOutputTypes[] = "_output_types";
691 static constexpr char kAttrOutputShapes[] = "_output_shapes";
692 static constexpr char kAttrSupportOutputTypeFloatInQuantizedOp[] =
693 "_support_output_type_float_in_quantized_op";
694
695 LOG(INFO) << "Converting unsupported operation: " << node.op();
696
697 auto* op = new TensorFlowUnsupportedOperator;
698 op->tensorflow_op = node.op();
699
700 // For Flex mode. Please read the comments of the function.
701 RetainTensorFlowNodeDef(node, op);
702
703 model->operators.emplace_back(op);
704
705 // Parse inputs.
706 const int num_inputs = GetInputsCount(node, tf_import_flags);
707 for (int i = 0; i < num_inputs; ++i) {
708 op->inputs.push_back(node.input(i));
709 }
710
711 // Parse outputs. Name them after the node's name, plus an ordinal suffix.
712 // Note that some outputs are to be multiplied by a named attribute.
713 const tensorflow::OpDef* op_def = nullptr;
714 if (tensorflow::OpRegistry::Global()->LookUpOpDef(node.op(), &op_def).ok()) {
715 GetOutputNamesFromNodeDef(node, *op_def, op);
716 } else {
717 op->outputs.push_back(node.name()); // Implicit :0.
718 }
719
720 // Parse if the op supports quantization
721 if (HasAttr(node, kAttrOutputQuantized)) {
722 op->quantized = GetBoolAttr(node, kAttrOutputQuantized);
723 }
724 // Parse if the quantized op allows output arrays of type float
725 if (HasAttr(node, kAttrSupportOutputTypeFloatInQuantizedOp)) {
726 op->support_output_type_float_in_quantized_op =
727 GetBoolAttr(node, kAttrSupportOutputTypeFloatInQuantizedOp);
728 }
729
730 // Parse output type(s).
731 if (HasAttr(node, kAttrOutputTypes)) {
732 const auto& output_types = GetListAttr(node, kAttrOutputTypes);
733 for (int i = 0; i < output_types.type_size(); ++i) {
734 op->output_data_types.push_back(ConvertDataType(output_types.type(i)));
735 }
736 } else if (HasAttr(node, "Tout")) {
737 const auto& output_type = GetDataTypeAttr(node, "Tout");
738 op->output_data_types.push_back(ConvertDataType(output_type));
739 } else if (op_def != nullptr) {
740 GetOutputTypesFromNodeDef(node, *op_def, op);
741 } else {
742 // TODO(b/113613439): Figure out how to propagate types for custom ops
743 // that have no OpDef.
744 LOG(INFO) << "Unable to determine output type for op: " << node.op();
745 }
746
747 // Parse output shape(s).
748 if (HasAttr(node, kAttrOutputShapes)) {
749 const auto& output_shapes = GetListAttr(node, kAttrOutputShapes);
750 Shape output_shape;
751 for (int i = 0; i < output_shapes.shape_size(); ++i) {
752 const auto& shape = output_shapes.shape(i);
753 // TOCO doesn't yet properly handle shapes with wildcard dimensions.
754 // TODO(b/113613439): Handle shape inference for unsupported ops that have
755 // shapes with wildcard dimensions.
756 if (HasWildcardDimension(shape)) {
757 LOG(INFO) << "Skipping wildcard output shape(s) for node: "
758 << node.name();
759 op->output_shapes.clear();
760 break;
761 }
762 const auto status =
763 ImportShape(shape.dim(), /*input_flat_size=*/nullptr, &output_shape);
764 if (!status.ok()) {
765 return status;
766 }
767 op->output_shapes.push_back(output_shape);
768 }
769 }
770 return tensorflow::Status::OK();
771 }
772
ConvertConstOperator(const NodeDef & node,const TensorFlowImportFlags & tf_import_flags,const ModelFlags & model_flags,Model * model)773 tensorflow::Status ConvertConstOperator(
774 const NodeDef& node, const TensorFlowImportFlags& tf_import_flags,
775 const ModelFlags& model_flags, Model* model) {
776 CHECK_EQ(node.op(), "Const");
777 const auto& tensor = GetTensorAttr(node, "value");
778 const auto dtype = GetDataTypeAttr(node, "dtype");
779
780 tensorflow::Status status = tensorflow::Status::OK();
781
782 auto& array = model->GetOrCreateArray(node.name());
783 switch (dtype) {
784 case DT_FLOAT:
785 array.data_type = ArrayDataType::kFloat;
786 status = ImportFloatArray(tensor, &array);
787 break;
788 case DT_INT32:
789 array.data_type = ArrayDataType::kInt32;
790 status = ImportInt32Array(tensor, &array);
791 break;
792 case DT_UINT32:
793 array.data_type = ArrayDataType::kUint32;
794 status = ImportUint32Array(tensor, &array);
795 break;
796 case DT_QUINT8:
797 array.data_type = ArrayDataType::kUint8;
798 status = ImportQuint8Array(tensor, &array);
799 break;
800 case DT_INT64:
801 array.data_type = ArrayDataType::kInt64;
802 status = ImportInt64Array(tensor, &array);
803 break;
804 case DT_STRING:
805 array.data_type = ArrayDataType::kString;
806 status = ImportStringArray(tensor, &array);
807 break;
808 case DT_BOOL:
809 array.data_type = ArrayDataType::kBool;
810 status = ImportBoolArray(tensor, &array);
811 break;
812 case DT_COMPLEX64:
813 array.data_type = ArrayDataType::kComplex64;
814 status = ImportComplex64Array(tensor, &array);
815 break;
816 default:
817 array.data_type = ArrayDataType::kNone;
818 // do nothing, silently ignore the Const data.
819 // We just make a dummy buffer to indicate that
820 // this array does not rely on external input.
821 array.GetMutableBuffer<ArrayDataType::kNone>();
822 break;
823 }
824 TF_RETURN_WITH_CONTEXT_IF_ERROR(
825 status, " (while processing node '" + node.name() + "')");
826 return tensorflow::Status::OK();
827 }
828
ConvertConvOperator(const NodeDef & node,const TensorFlowImportFlags & tf_import_flags,const ModelFlags & model_flags,Model * model)829 tensorflow::Status ConvertConvOperator(
830 const NodeDef& node, const TensorFlowImportFlags& tf_import_flags,
831 const ModelFlags& model_flags, Model* model) {
832 CHECK_EQ(node.op(), "Conv2D");
833 TF_RETURN_IF_ERROR(CheckInputsCount(node, tf_import_flags, 2));
834
835 // We only support NHWC, which is the default data_format.
836 // So if data_format is not defined, we're all good.
837 TF_RETURN_IF_ERROR(CheckOptionalAttr(node, "data_format", "NHWC"));
838 TF_RETURN_IF_ERROR(CheckOptionalAttr(node, "T", DT_FLOAT));
839
840 const auto& input_name = node.input(0);
841 const auto& weights_name = node.input(1);
842 const auto& reordered_weights_name =
843 AvailableArrayName(*model, weights_name + "_reordered");
844 // Check if a ReorderAxesOperator was already created for these weights
845 // (that happens when multiple layers share the same weights).
846 const Operator* existing_reorder =
847 GetOpWithOutput(*model, reordered_weights_name);
848 if (existing_reorder) {
849 // Check that it is safe to rely on the _reordered naming of the output
850 // array!
851 CHECK(existing_reorder->type == OperatorType::kReorderAxes);
852 } else {
853 // Create a new ReorderAxesOperator
854 auto* reorder = new ReorderAxesOperator;
855 reorder->inputs = {weights_name};
856 reorder->outputs = {reordered_weights_name};
857 reorder->input_axes_order = AxesOrder::kHWIO;
858 reorder->output_axes_order = AxesOrder::kOHWI;
859 model->operators.emplace_back(reorder);
860 }
861 if (!HasAttr(node, "strides")) {
862 return tensorflow::errors::InvalidArgument("Missing attribute 'strides'");
863 }
864 const auto& strides = GetListAttr(node, "strides");
865 TF_RETURN_IF_ERROR(ExpectValue(strides.i_size(), 4, "number of strides"));
866 TF_RETURN_IF_ERROR(ExpectValue(strides.i(0), 1, "strides(0)"));
867 TF_RETURN_IF_ERROR(ExpectValue(strides.i(3), 1, "strides(3)"));
868 int dilation_height_factor;
869 int dilation_width_factor;
870 if (HasAttr(node, "dilations")) {
871 const auto& dilations = GetListAttr(node, "dilations");
872 TF_RETURN_IF_ERROR(
873 ExpectValue(dilations.i_size(), 4, "number of dilations"));
874 if (dilations.i(0) != 1 || dilations.i(3) != 1) {
875 return tensorflow::errors::InvalidArgument(absl::StrCat(
876 "Can only import Conv ops with dilation along the height "
877 "(1st) or width (2nd) axis. TensorFlow op \"",
878 node.name(), "\" had dilations:[ ", dilations.i(0), ", ",
879 dilations.i(1), ", ", dilations.i(2), ", ", dilations.i(3), "]."));
880 }
881 dilation_height_factor = dilations.i(1);
882 dilation_width_factor = dilations.i(2);
883 } else {
884 dilation_height_factor = 1;
885 dilation_width_factor = 1;
886 }
887 const auto& padding = GetStringAttr(node, "padding");
888 PaddingType padding_type;
889 if (padding == "SAME") {
890 padding_type = PaddingType::kSame;
891 } else if (padding == "VALID") {
892 padding_type = PaddingType::kValid;
893 } else {
894 return tensorflow::errors::InvalidArgument(
895 "Bad padding (only SAME and VALID are supported)");
896 }
897 auto* conv = new ConvOperator;
898 conv->inputs = {input_name, reordered_weights_name};
899 conv->outputs = {node.name()};
900 conv->stride_height = strides.i(1);
901 conv->stride_width = strides.i(2);
902 conv->dilation_height_factor = dilation_height_factor;
903 conv->dilation_width_factor = dilation_width_factor;
904 conv->padding.type = padding_type;
905 model->operators.emplace_back(conv);
906
907 return tensorflow::Status::OK();
908 }
909
ConvertDepthwiseConvOperator(const NodeDef & node,const TensorFlowImportFlags & tf_import_flags,const ModelFlags & model_flags,Model * model)910 tensorflow::Status ConvertDepthwiseConvOperator(
911 const NodeDef& node, const TensorFlowImportFlags& tf_import_flags,
912 const ModelFlags& model_flags, Model* model) {
913 CHECK_EQ(node.op(), "DepthwiseConv2dNative");
914 TF_QCHECK_OK(CheckInputsCount(node, tf_import_flags, 2));
915
916 // We only support NHWC, which is the default data_format.
917 // So if data_format is not defined, we're all good.
918 if (HasAttr(node, "data_format")) {
919 CHECK_EQ(GetStringAttr(node, "data_format"), "NHWC");
920 }
921 CHECK_EQ(GetDataTypeAttr(node, "T"), DT_FLOAT);
922
923 const auto& input_name = node.input(0);
924 const auto& weights_name = node.input(1);
925 const auto& reordered_weights_name = weights_name + "_reordered";
926 // Check if a ReorderAxesOperator was already created for these weights
927 // (that happens when multiple layers share the same weights).
928 const Operator* existing_reorder =
929 GetOpWithOutput(*model, reordered_weights_name);
930 if (existing_reorder) {
931 // Check that it is safe to rely on the _reordered naming of the output
932 // array!
933 CHECK(existing_reorder->type == OperatorType::kReorderAxes);
934 } else {
935 // Create a new ReorderAxesOperator
936 auto* reorder = new ReorderAxesOperator;
937 reorder->inputs = {weights_name};
938 reorder->outputs = {reordered_weights_name};
939 reorder->input_axes_order = AxesOrder::kHWIM;
940 reorder->output_axes_order = AxesOrder::k1HWO;
941 model->operators.emplace_back(reorder);
942 }
943 const auto& strides = GetListAttr(node, "strides");
944 TF_RETURN_IF_ERROR(ExpectValue(strides.i_size(), 4, "number of strides"));
945 TF_RETURN_IF_ERROR(ExpectValue(strides.i(0), 1, "strides(0)"));
946 TF_RETURN_IF_ERROR(ExpectValue(strides.i(3), 1, "strides(3)"));
947 int dilation_height_factor;
948 int dilation_width_factor;
949 if (HasAttr(node, "dilations")) {
950 const auto& dilations = GetListAttr(node, "dilations");
951 TF_RETURN_IF_ERROR(
952 ExpectValue(dilations.i_size(), 4, "number of dilations"));
953 if (dilations.i(0) != 1 || dilations.i(3) != 1) {
954 return tensorflow::errors::InvalidArgument(absl::StrCat(
955 "Can only import Conv ops with dilation along the height "
956 "(1st) or width (2nd) axis. TensorFlow op \"",
957 node.name(), "\" had dilations:[ ", dilations.i(0), ", ",
958 dilations.i(1), ", ", dilations.i(2), ", ", dilations.i(3), "]."));
959 }
960 dilation_height_factor = dilations.i(1);
961 dilation_width_factor = dilations.i(2);
962 } else {
963 dilation_height_factor = 1;
964 dilation_width_factor = 1;
965 }
966 const auto& padding = GetStringAttr(node, "padding");
967 PaddingType padding_type;
968 if (padding == "SAME") {
969 padding_type = PaddingType::kSame;
970 } else if (padding == "VALID") {
971 padding_type = PaddingType::kValid;
972 } else {
973 return tensorflow::errors::InvalidArgument(
974 "Bad padding (only SAME and VALID are supported)");
975 }
976 auto* conv = new DepthwiseConvOperator;
977 conv->inputs = {input_name, reordered_weights_name};
978 conv->outputs = {node.name()};
979 conv->stride_height = strides.i(1);
980 conv->stride_width = strides.i(2);
981 conv->dilation_height_factor = dilation_height_factor;
982 conv->dilation_width_factor = dilation_width_factor;
983 conv->padding.type = padding_type;
984 model->operators.emplace_back(conv);
985 return tensorflow::Status::OK();
986 }
987
ConvertDepthToSpaceOperator(const NodeDef & node,const TensorFlowImportFlags & tf_import_flags,const ModelFlags & model_flags,Model * model)988 tensorflow::Status ConvertDepthToSpaceOperator(
989 const NodeDef& node, const TensorFlowImportFlags& tf_import_flags,
990 const ModelFlags& model_flags, Model* model) {
991 CHECK_EQ(node.op(), "DepthToSpace");
992 TF_QCHECK_OK(CheckInputsCount(node, tf_import_flags, 1));
993
994 tensorflow::DataType dtype = GetDataTypeAttr(node, "T");
995 if (dtype != DT_FLOAT && dtype != DT_UINT8 && dtype != DT_INT32 &&
996 dtype != DT_INT64) {
997 const auto* enum_descriptor = tensorflow::DataType_descriptor();
998 LOG(FATAL) << "TFLite does not support DepthToSpace with type T:"
999 << enum_descriptor->FindValueByNumber(dtype)->name() << ". "
1000 << "T must be one of {DT_FLOAT, DT_UINT8, DT_INT32, DT_INT64}.";
1001 }
1002 auto* op = new DepthToSpaceOperator;
1003 op->inputs.push_back(node.input(0));
1004 op->outputs.push_back(node.name());
1005 op->block_size = GetIntAttr(node, "block_size");
1006 QCHECK_GE(op->block_size, 2);
1007 model->operators.emplace_back(op);
1008 return tensorflow::Status::OK();
1009 }
1010
ConvertSpaceToDepthOperator(const NodeDef & node,const TensorFlowImportFlags & tf_import_flags,const ModelFlags & model_flags,Model * model)1011 tensorflow::Status ConvertSpaceToDepthOperator(
1012 const NodeDef& node, const TensorFlowImportFlags& tf_import_flags,
1013 const ModelFlags& model_flags, Model* model) {
1014 CHECK_EQ(node.op(), "SpaceToDepth");
1015 TF_QCHECK_OK(CheckInputsCount(node, tf_import_flags, 1));
1016
1017 tensorflow::DataType dtype = GetDataTypeAttr(node, "T");
1018 if (dtype != DT_FLOAT && dtype != DT_UINT8 && dtype != DT_INT32 &&
1019 dtype != DT_INT64) {
1020 const auto* enum_descriptor = tensorflow::DataType_descriptor();
1021 LOG(FATAL) << "TFLite does not support SpaceToDepth with type T:"
1022 << enum_descriptor->FindValueByNumber(dtype)->name() << ". "
1023 << "T must be one of {DT_FLOAT, DT_UINT8, DT_INT32, DT_INT64}.";
1024 }
1025 auto* op = new SpaceToDepthOperator;
1026 op->inputs.push_back(node.input(0));
1027 op->outputs.push_back(node.name());
1028 op->block_size = GetIntAttr(node, "block_size");
1029 QCHECK_GE(op->block_size, 2);
1030 model->operators.emplace_back(op);
1031 return tensorflow::Status::OK();
1032 }
1033
ConvertBiasAddOperator(const NodeDef & node,const TensorFlowImportFlags & tf_import_flags,const ModelFlags & model_flags,Model * model)1034 tensorflow::Status ConvertBiasAddOperator(
1035 const NodeDef& node, const TensorFlowImportFlags& tf_import_flags,
1036 const ModelFlags& model_flags, Model* model) {
1037 CHECK_EQ(node.op(), "BiasAdd");
1038 TF_QCHECK_OK(CheckInputsCount(node, tf_import_flags, 2));
1039
1040 const auto& input_name = node.input(0);
1041 const auto& bias_name = node.input(1);
1042 CHECK_EQ(GetDataTypeAttr(node, "T"), DT_FLOAT);
1043 auto* biasadd = new AddOperator;
1044 biasadd->inputs.push_back(input_name);
1045 biasadd->inputs.push_back(bias_name);
1046 biasadd->outputs.push_back(node.name());
1047 model->operators.emplace_back(biasadd);
1048 return tensorflow::Status::OK();
1049 }
1050
ConvertRandomUniform(const NodeDef & node,const TensorFlowImportFlags & tf_import_flags,const ModelFlags & model_flags,Model * model)1051 tensorflow::Status ConvertRandomUniform(
1052 const NodeDef& node, const TensorFlowImportFlags& tf_import_flags,
1053 const ModelFlags& model_flags, Model* model) {
1054 CHECK_EQ(node.op(), "RandomUniform");
1055 TF_QCHECK_OK(CheckInputsCount(node, tf_import_flags, 1));
1056
1057 CHECK_EQ(GetDataTypeAttr(node, "T"), DT_INT32);
1058 auto op = absl::make_unique<RandomUniformOperator>();
1059 op->inputs.push_back(node.input(0));
1060 op->outputs.push_back(node.name());
1061 op->dtype = ConvertDataType(GetDataTypeAttr(node, "dtype"));
1062 op->seed = GetIntAttr(node, "seed");
1063 op->seed2 = GetIntAttr(node, "seed2");
1064 CHECK(model != nullptr);
1065 model->operators.emplace_back(std::move(op));
1066 return tensorflow::Status::OK();
1067 }
1068
ConvertIdentityOperator(const NodeDef & node,const TensorFlowImportFlags & tf_import_flags,const ModelFlags & model_flags,Model * model)1069 tensorflow::Status ConvertIdentityOperator(
1070 const NodeDef& node, const TensorFlowImportFlags& tf_import_flags,
1071 const ModelFlags& model_flags, Model* model) {
1072 CHECK(node.op() == "Identity" || node.op() == "CheckNumerics" ||
1073 node.op() == "PlaceholderWithDefault" || node.op() == "StopGradient" ||
1074 node.op() == "Snapshot");
1075 auto* op = new TensorFlowIdentityOperator;
1076 // Amazingly, some TensorFlow graphs (at least rajeev_lstm.pb) have
1077 // identity nodes with multiple inputs, but the other inputs seem
1078 // to be gratuitous (in the case of rajeev_lstm.pb, these are
1079 // enumerating the LSTM state arrays). We will just ignore extra
1080 // inputs beyond the first input.
1081 QCHECK_GE(node.input_size(), 1)
1082 << node.op()
1083 << " node expects at least 1 input other than control dependencies: "
1084 << node.DebugString();
1085 const auto& input_name = node.input(0);
1086 op->inputs.push_back(input_name);
1087 op->outputs.push_back(node.name());
1088 model->operators.emplace_back(op);
1089 return tensorflow::Status::OK();
1090 }
1091
ConvertIdentityNOperator(const NodeDef & node,const TensorFlowImportFlags & tf_import_flags,const ModelFlags & model_flags,Model * model)1092 tensorflow::Status ConvertIdentityNOperator(
1093 const NodeDef& node, const TensorFlowImportFlags& tf_import_flags,
1094 const ModelFlags& model_flags, Model* model) {
1095 CHECK_EQ(node.op(), "IdentityN");
1096 for (int i = 0; i < node.input_size(); ++i) {
1097 auto* op = new TensorFlowIdentityOperator;
1098 const auto& input_name = node.input(i);
1099 std::string output_name = node.name();
1100 if (i > 0) {
1101 output_name = output_name + ":" + std::to_string(i);
1102 }
1103 op->inputs.push_back(input_name);
1104 op->outputs.push_back(output_name);
1105 model->operators.emplace_back(op);
1106 }
1107 return tensorflow::Status::OK();
1108 }
1109
ConvertFakeQuantWithMinMaxArgs(const NodeDef & node,const TensorFlowImportFlags & tf_import_flags,const ModelFlags & model_flags,Model * model)1110 tensorflow::Status ConvertFakeQuantWithMinMaxArgs(
1111 const NodeDef& node, const TensorFlowImportFlags& tf_import_flags,
1112 const ModelFlags& model_flags, Model* model) {
1113 CHECK_EQ(node.op(), "FakeQuantWithMinMaxArgs");
1114 TF_QCHECK_OK(CheckInputsCount(node, tf_import_flags, 1));
1115 auto* op = new FakeQuantOperator;
1116 op->inputs.push_back(node.input(0));
1117 op->minmax.reset(new MinMax);
1118 auto& minmax = *op->minmax;
1119 minmax.min = GetFloatAttr(node, "min");
1120 minmax.max = GetFloatAttr(node, "max");
1121 op->outputs.push_back(node.name());
1122 // tf.fake_quant_with_min_max_args num_bits defaults to 8.
1123 op->num_bits = HasAttr(node, "num_bits") ? GetIntAttr(node, "num_bits") : 8;
1124 if (HasAttr(node, "narrow_range")) {
1125 op->narrow_range = GetBoolAttr(node, "narrow_range");
1126 }
1127 model->operators.emplace_back(op);
1128 return tensorflow::Status::OK();
1129 }
1130
ConvertFakeQuantWithMinMaxVars(const NodeDef & node,const TensorFlowImportFlags & tf_import_flags,const ModelFlags & model_flags,Model * model)1131 tensorflow::Status ConvertFakeQuantWithMinMaxVars(
1132 const NodeDef& node, const TensorFlowImportFlags& tf_import_flags,
1133 const ModelFlags& model_flags, Model* model) {
1134 CHECK_EQ(node.op(), "FakeQuantWithMinMaxVars");
1135 const int num_inputs = GetInputsCount(node, tf_import_flags);
1136 QCHECK(num_inputs == 3 || num_inputs == 4)
1137 << "FakeQuantWithMinMaxVars node expects 3 or 4 inputs other than "
1138 "control dependencies: "
1139 << node.DebugString();
1140 auto* op = new FakeQuantOperator;
1141 for (int i = 0; i < 3; i++) {
1142 op->inputs.push_back(node.input(i));
1143 }
1144 op->outputs.push_back(node.name());
1145 op->num_bits = HasAttr(node, "num_bits") ? GetIntAttr(node, "num_bits") : 8;
1146 if (HasAttr(node, "narrow_range")) {
1147 op->narrow_range = GetBoolAttr(node, "narrow_range");
1148 }
1149 model->operators.emplace_back(op);
1150 return tensorflow::Status::OK();
1151 }
1152
ConvertSqueezeOperator(const NodeDef & node,const TensorFlowImportFlags & tf_import_flags,const ModelFlags & model_flags,Model * model)1153 tensorflow::Status ConvertSqueezeOperator(
1154 const NodeDef& node, const TensorFlowImportFlags& tf_import_flags,
1155 const ModelFlags& model_flags, Model* model) {
1156 CHECK_EQ(node.op(), "Squeeze");
1157 TF_QCHECK_OK(CheckInputsCount(node, tf_import_flags, 1));
1158 auto* op = new SqueezeOperator;
1159 op->inputs.push_back(node.input(0));
1160 op->outputs.push_back(node.name());
1161
1162 // When omitted we are to squeeze all dimensions == 1.
1163 if (HasAttr(node, "squeeze_dims")) {
1164 const auto& squeeze_dims = GetListAttr(node, "squeeze_dims");
1165 for (int i = 0; i < squeeze_dims.i_size(); ++i) {
1166 op->squeeze_dims.push_back(squeeze_dims.i(i));
1167 }
1168 }
1169
1170 model->operators.emplace_back(op);
1171 return tensorflow::Status::OK();
1172 }
1173
ConvertSplitOperator(const NodeDef & node,const TensorFlowImportFlags & tf_import_flags,const ModelFlags & model_flags,Model * model)1174 tensorflow::Status ConvertSplitOperator(
1175 const NodeDef& node, const TensorFlowImportFlags& tf_import_flags,
1176 const ModelFlags& model_flags, Model* model) {
1177 CHECK_EQ(node.op(), "Split");
1178 TF_QCHECK_OK(CheckInputsCount(node, tf_import_flags, 2));
1179 auto* op = new TensorFlowSplitOperator;
1180 op->inputs.push_back(node.input(0));
1181 op->inputs.push_back(node.input(1));
1182 const int num_split = GetIntAttr(node, "num_split");
1183 op->outputs.push_back(node.name());
1184 for (int i = 1; i < num_split; i++) {
1185 op->outputs.push_back(absl::StrCat(node.name(), ":", i));
1186 }
1187 op->num_split = num_split;
1188 model->operators.emplace_back(op);
1189 return tensorflow::Status::OK();
1190 }
1191
ConvertSplitVOperator(const NodeDef & node,const TensorFlowImportFlags & tf_import_flags,const ModelFlags & model_flags,Model * model)1192 tensorflow::Status ConvertSplitVOperator(
1193 const NodeDef& node, const TensorFlowImportFlags& tf_import_flags,
1194 const ModelFlags& model_flags, Model* model) {
1195 CHECK_EQ(node.op(), "SplitV");
1196 TF_QCHECK_OK(CheckInputsCount(node, tf_import_flags, 3));
1197 auto* op = new TensorFlowSplitVOperator;
1198 op->inputs.push_back(node.input(0));
1199 op->inputs.push_back(node.input(1));
1200 op->inputs.push_back(node.input(2));
1201 const int num_split = GetIntAttr(node, "num_split");
1202 op->outputs.push_back(node.name());
1203 for (int i = 1; i < num_split; i++) {
1204 op->outputs.push_back(absl::StrCat(node.name(), ":", i));
1205 }
1206 op->num_split = num_split;
1207 model->operators.emplace_back(op);
1208 return tensorflow::Status::OK();
1209 }
1210
ConvertSwitchOperator(const NodeDef & node,const TensorFlowImportFlags & tf_import_flags,const ModelFlags & model_flags,Model * model)1211 tensorflow::Status ConvertSwitchOperator(
1212 const NodeDef& node, const TensorFlowImportFlags& tf_import_flags,
1213 const ModelFlags& model_flags, Model* model) {
1214 CHECK_EQ(node.op(), "Switch");
1215 TF_QCHECK_OK(CheckInputsCount(node, tf_import_flags, 2));
1216 auto* op = new TensorFlowSwitchOperator;
1217 op->inputs.push_back(node.input(0));
1218 op->inputs.push_back(node.input(1));
1219 op->outputs.push_back(node.name());
1220 // Switch operators have two outputs: "name" and "name:1".
1221 op->outputs.push_back(node.name() + ":1");
1222 model->operators.emplace_back(op);
1223 return tensorflow::Status::OK();
1224 }
1225
ConvertSoftmaxOperator(const NodeDef & node,const TensorFlowImportFlags & tf_import_flags,const ModelFlags & model_flags,Model * model)1226 tensorflow::Status ConvertSoftmaxOperator(
1227 const NodeDef& node, const TensorFlowImportFlags& tf_import_flags,
1228 const ModelFlags& model_flags, Model* model) {
1229 CHECK_EQ(node.op(), "Softmax");
1230 TF_QCHECK_OK(CheckInputsCount(node, tf_import_flags, 1));
1231 const auto& input_name = node.input(0);
1232 auto* softmax = new SoftmaxOperator;
1233 softmax->inputs.push_back(input_name);
1234 softmax->outputs.push_back(node.name());
1235 // TensorFlow's Softmax doesn't seem to admit a 'beta' parameter.
1236 CHECK(!node.attr().count("beta")); // Stab in the dark, just in case.
1237 if (node.attr().count("_softmax_beta")) {
1238 softmax->beta = GetFloatAttr(node, "_softmax_beta");
1239 } else {
1240 softmax->beta = 1.f;
1241 }
1242 model->operators.emplace_back(softmax);
1243 return tensorflow::Status::OK();
1244 }
1245
ConvertLRNOperator(const NodeDef & node,const TensorFlowImportFlags & tf_import_flags,const ModelFlags & model_flags,Model * model)1246 tensorflow::Status ConvertLRNOperator(
1247 const NodeDef& node, const TensorFlowImportFlags& tf_import_flags,
1248 const ModelFlags& model_flags, Model* model) {
1249 CHECK_EQ(node.op(), "LRN");
1250 TF_QCHECK_OK(CheckInputsCount(node, tf_import_flags, 1));
1251 const auto& input_name = node.input(0);
1252 auto* lrn = new LocalResponseNormalizationOperator;
1253 lrn->inputs.push_back(input_name);
1254 lrn->outputs.push_back(node.name());
1255 lrn->range = GetIntAttr(node, "depth_radius");
1256 lrn->bias = GetFloatAttr(node, "bias");
1257 lrn->alpha = GetFloatAttr(node, "alpha");
1258 lrn->beta = GetFloatAttr(node, "beta");
1259 model->operators.emplace_back(lrn);
1260 return tensorflow::Status::OK();
1261 }
1262
ConvertMaxPoolOperator(const NodeDef & node,const TensorFlowImportFlags & tf_import_flags,const ModelFlags & model_flags,Model * model)1263 tensorflow::Status ConvertMaxPoolOperator(
1264 const NodeDef& node, const TensorFlowImportFlags& tf_import_flags,
1265 const ModelFlags& model_flags, Model* model) {
1266 CHECK_EQ(node.op(), "MaxPool");
1267 TF_QCHECK_OK(CheckInputsCount(node, tf_import_flags, 1));
1268 const auto& input_name = node.input(0);
1269 // We only support NHWC, which is the default data_format.
1270 // So if data_format is not defined, we're all good.
1271 if (node.attr().count("data_format")) {
1272 CHECK_EQ(GetStringAttr(node, "data_format"), "NHWC");
1273 }
1274 if (HasAttr(node, "T")) {
1275 CHECK_EQ(GetDataTypeAttr(node, "T"), DT_FLOAT);
1276 } else {
1277 LOG(WARNING) << "Found MaxPool operator missing 'T' attribute";
1278 }
1279 auto* maxpool = new MaxPoolOperator;
1280 maxpool->inputs.push_back(input_name);
1281 maxpool->outputs.push_back(node.name());
1282 const auto& strides = GetListAttr(node, "strides");
1283 CHECK_EQ(strides.i_size(), 4);
1284 CHECK_EQ(strides.i(0), 1);
1285 CHECK_EQ(strides.i(3), 1);
1286 maxpool->stride_height = strides.i(1);
1287 maxpool->stride_width = strides.i(2);
1288 const auto& ksize = GetListAttr(node, "ksize");
1289 CHECK_EQ(ksize.i_size(), 4);
1290 CHECK_EQ(ksize.i(0), 1);
1291 CHECK_EQ(ksize.i(3), 1);
1292 maxpool->kheight = ksize.i(1);
1293 maxpool->kwidth = ksize.i(2);
1294 const auto& padding = GetStringAttr(node, "padding");
1295 if (padding == "SAME") {
1296 maxpool->padding.type = PaddingType::kSame;
1297 } else if (padding == "VALID") {
1298 maxpool->padding.type = PaddingType::kValid;
1299 } else {
1300 LOG(FATAL) << "Bad padding (only SAME and VALID are supported)";
1301 }
1302 model->operators.emplace_back(maxpool);
1303 return tensorflow::Status::OK();
1304 }
1305
ConvertAvgPoolOperator(const NodeDef & node,const TensorFlowImportFlags & tf_import_flags,const ModelFlags & model_flags,Model * model)1306 tensorflow::Status ConvertAvgPoolOperator(
1307 const NodeDef& node, const TensorFlowImportFlags& tf_import_flags,
1308 const ModelFlags& model_flags, Model* model) {
1309 CHECK_EQ(node.op(), "AvgPool");
1310 TF_QCHECK_OK(CheckInputsCount(node, tf_import_flags, 1));
1311 const auto& input_name = node.input(0);
1312 // We only support NHWC, which is the default data_format.
1313 // So if data_format is not defined, we're all good.
1314 if (node.attr().count("data_format")) {
1315 CHECK_EQ(GetStringAttr(node, "data_format"), "NHWC");
1316 }
1317 CHECK_EQ(GetDataTypeAttr(node, "T"), DT_FLOAT);
1318 auto* avgpool = new AveragePoolOperator;
1319 avgpool->inputs.push_back(input_name);
1320 avgpool->outputs.push_back(node.name());
1321 const auto& strides = GetListAttr(node, "strides");
1322 CHECK_EQ(strides.i_size(), 4);
1323 CHECK_EQ(strides.i(0), 1);
1324 CHECK_EQ(strides.i(3), 1);
1325 avgpool->stride_height = strides.i(1);
1326 avgpool->stride_width = strides.i(2);
1327 const auto& ksize = GetListAttr(node, "ksize");
1328 CHECK_EQ(ksize.i_size(), 4);
1329 CHECK_EQ(ksize.i(0), 1);
1330 CHECK_EQ(ksize.i(3), 1);
1331 avgpool->kheight = ksize.i(1);
1332 avgpool->kwidth = ksize.i(2);
1333 const auto& padding = GetStringAttr(node, "padding");
1334 if (padding == "SAME") {
1335 avgpool->padding.type = PaddingType::kSame;
1336 } else if (padding == "VALID") {
1337 avgpool->padding.type = PaddingType::kValid;
1338 } else {
1339 LOG(FATAL) << "Bad padding (only SAME and VALID are supported)";
1340 }
1341 model->operators.emplace_back(avgpool);
1342 return tensorflow::Status::OK();
1343 }
1344
ConvertBatchMatMulOperator(const NodeDef & node,const TensorFlowImportFlags & tf_import_flags,const ModelFlags & model_flags,Model * model)1345 tensorflow::Status ConvertBatchMatMulOperator(
1346 const NodeDef& node, const TensorFlowImportFlags& tf_import_flags,
1347 const ModelFlags& model_flags, Model* model) {
1348 TF_QCHECK_OK(CheckInputsCount(node, tf_import_flags, 2));
1349
1350 auto* batch_matmul = new BatchMatMulOperator;
1351 // https://www.tensorflow.org/versions/r0.12/api_docs/python/math_ops/matrix_math_functions
1352 if (HasAttr(node, "adj_x")) {
1353 batch_matmul->adj_x = GetBoolAttr(node, "adj_x");
1354 }
1355 if (HasAttr(node, "adj_y")) {
1356 batch_matmul->adj_y = GetBoolAttr(node, "adj_y");
1357 }
1358 batch_matmul->inputs = {node.input(0), node.input(1)};
1359 batch_matmul->outputs = {node.name()};
1360
1361 // For Flex mode. Please read the comments of the function.
1362 RetainTensorFlowNodeDef(node, batch_matmul);
1363
1364 model->operators.emplace_back(batch_matmul);
1365 return tensorflow::Status::OK();
1366 }
1367
ConvertMatMulOperator(const NodeDef & node,const TensorFlowImportFlags & tf_import_flags,const ModelFlags & model_flags,Model * model)1368 tensorflow::Status ConvertMatMulOperator(
1369 const NodeDef& node, const TensorFlowImportFlags& tf_import_flags,
1370 const ModelFlags& model_flags, Model* model) {
1371 TF_QCHECK_OK(CheckInputsCount(node, tf_import_flags, 2));
1372
1373 CHECK(!HasAttr(node, "adjoint_a") ||
1374 (GetBoolAttr(node, "adjoint_a") == false));
1375 CHECK(!HasAttr(node, "adjoint_b") ||
1376 (GetBoolAttr(node, "adjoint_b") == false));
1377
1378 auto* matmul = new TensorFlowMatMulOperator;
1379 if (HasAttr(node, "transpose_a")) {
1380 matmul->transpose_a = GetBoolAttr(node, "transpose_a");
1381 }
1382 if (HasAttr(node, "transpose_b")) {
1383 matmul->transpose_b = GetBoolAttr(node, "transpose_b");
1384 }
1385
1386 matmul->inputs = {node.input(0), node.input(1)};
1387 matmul->outputs = {node.name()};
1388 model->operators.emplace_back(matmul);
1389 return tensorflow::Status::OK();
1390 }
1391
ConvertConcatOperator(const NodeDef & node,const TensorFlowImportFlags & tf_import_flags,const ModelFlags & model_flags,Model * model)1392 tensorflow::Status ConvertConcatOperator(
1393 const NodeDef& node, const TensorFlowImportFlags& tf_import_flags,
1394 const ModelFlags& model_flags, Model* model) {
1395 Operator* op = nullptr;
1396 if (node.op() == "Concat") {
1397 op = new TensorFlowConcatOperator;
1398 } else if (node.op() == "ConcatV2") {
1399 op = new TensorFlowConcatV2Operator;
1400 } else {
1401 LOG(FATAL) << "Expected Concat or ConcatV2";
1402 }
1403 const int num_inputs = GetInputsCount(node, tf_import_flags);
1404 QCHECK_GE(num_inputs, 2)
1405 << node.op()
1406 << " node expects at least 2 inputs other than control dependencies: "
1407 << node.DebugString();
1408 CHECK_EQ(num_inputs, 1 + GetIntAttr(node, "N"));
1409 for (int i = 0; i < num_inputs; ++i) {
1410 op->inputs.push_back(node.input(i));
1411 }
1412 op->outputs.push_back(node.name());
1413 model->operators.emplace_back(op);
1414 return tensorflow::Status::OK();
1415 }
1416
ConvertMirrorPadOperator(const NodeDef & node,const TensorFlowImportFlags & tf_import_flags,const ModelFlags & model_flags,Model * model)1417 tensorflow::Status ConvertMirrorPadOperator(
1418 const NodeDef& node, const TensorFlowImportFlags& tf_import_flags,
1419 const ModelFlags& model_flags, Model* model) {
1420 if (node.op() != "MirrorPad") {
1421 LOG(FATAL) << "Expected MirrorPad.";
1422 }
1423 const int num_inputs = GetInputsCount(node, tf_import_flags);
1424 CHECK_EQ(num_inputs, 2);
1425 auto* op = new MirrorPadOperator;
1426 for (int i = 0; i < num_inputs; ++i) {
1427 op->inputs.push_back(node.input(i));
1428 }
1429 op->outputs.push_back(node.name());
1430 const auto mode = GetStringAttr(node, "mode");
1431 if (mode == "REFLECT") {
1432 op->mode = toco::MirrorPadMode::kReflect;
1433 } else if (mode == "SYMMETRIC") {
1434 op->mode = toco::MirrorPadMode::kSymmetric;
1435 }
1436
1437 model->operators.emplace_back(op);
1438
1439 return tensorflow::Status::OK();
1440 }
1441
// Sentinel for the NumInputs template argument of the simple-operator
// converters below. When it is passed, the input-count check is skipped and
// every input counted by GetInputsCount is forwarded.
constexpr int kAnyNumInputs = -1;

// Whether a simple operator should keep its original TensorFlow NodeDef
// (making it eligible for export as a flex op) or not.
enum FlexSupport { kFlexOk = 0, kFlexNotOk = 1 };
1445
1446 // This method supports simple operators without additional attributes.
1447 // Converts a simple operator that takes no attributes. The list of inputs is
1448 // taken from the given NodeDef, and its number must match NumInputs, unless
1449 // kAnyNumInputs is passed in. If kFlexOk is passed in the resulting operator
1450 // will be eligible for being exported as a flex op.
1451 template <typename Op, int NumInputs, int NumOutputs, FlexSupport flex>
ConvertSimpleOperatorGeneric(const NodeDef & node,const TensorFlowImportFlags & tf_import_flags,const ModelFlags & model_flags,Model * model)1452 tensorflow::Status ConvertSimpleOperatorGeneric(
1453 const NodeDef& node, const TensorFlowImportFlags& tf_import_flags,
1454 const ModelFlags& model_flags, Model* model) {
1455 if (NumInputs != kAnyNumInputs) {
1456 TF_QCHECK_OK(CheckInputsCount(node, tf_import_flags, NumInputs));
1457 }
1458 auto* op = new Op;
1459 const int num_inputs = GetInputsCount(node, tf_import_flags);
1460 for (int i = 0; i < num_inputs; ++i) {
1461 op->inputs.push_back(node.input(i));
1462 }
1463 op->outputs.push_back(node.name());
1464 if (NumOutputs > 1) {
1465 for (int i = 1; i < NumOutputs; ++i) {
1466 op->outputs.push_back(node.name() + ":" + std::to_string(i));
1467 }
1468 }
1469
1470 if (flex == kFlexOk) {
1471 RetainTensorFlowNodeDef(node, op);
1472 }
1473
1474 model->operators.emplace_back(op);
1475 return tensorflow::Status::OK();
1476 }
1477
1478 // Convert a simple operator which is not valid as a flex op.
1479 template <typename Op, int NumInputs, int NumOutputs>
ConvertSimpleOperator(const NodeDef & node,const TensorFlowImportFlags & tf_import_flags,const ModelFlags & model_flags,Model * model)1480 tensorflow::Status ConvertSimpleOperator(
1481 const NodeDef& node, const TensorFlowImportFlags& tf_import_flags,
1482 const ModelFlags& model_flags, Model* model) {
1483 return ConvertSimpleOperatorGeneric<Op, NumInputs, NumOutputs, kFlexNotOk>(
1484 node, tf_import_flags, model_flags, model);
1485 }
1486
1487 // Convert a simple operator which is valid as a flex op.
1488 template <typename Op, int NumInputs, int NumOutputs>
ConvertSimpleOperatorFlexOk(const NodeDef & node,const TensorFlowImportFlags & tf_import_flags,const ModelFlags & model_flags,Model * model)1489 tensorflow::Status ConvertSimpleOperatorFlexOk(
1490 const NodeDef& node, const TensorFlowImportFlags& tf_import_flags,
1491 const ModelFlags& model_flags, Model* model) {
1492 return ConvertSimpleOperatorGeneric<Op, NumInputs, NumOutputs, kFlexOk>(
1493 node, tf_import_flags, model_flags, model);
1494 }
1495
1496 // Same as ConvertConstOperator, but revert to ConvertUnsupportedOperator if
1497 // the types are not supported. Converting Const operators here avoids
1498 // expensive copies of the protocol buffers downstream in the flex delegate.
ConditionallyConvertConstOperator(const NodeDef & node,const TensorFlowImportFlags & tf_import_flags,const ModelFlags & model_flags,Model * model)1499 tensorflow::Status ConditionallyConvertConstOperator(
1500 const NodeDef& node, const TensorFlowImportFlags& tf_import_flags,
1501 const ModelFlags& model_flags, Model* model) {
1502 // We avoid incomplete and zero shapes because the resulting arrays
1503 // are not completely compatible with Eager/TensorFlow.
1504 const auto& tensor = GetTensorAttr(node, "value");
1505 const auto& shape = tensor.tensor_shape();
1506 for (const auto& dim : shape.dim()) {
1507 if (dim.size() <= 0) {
1508 return ConvertUnsupportedOperator(node, tf_import_flags, model_flags,
1509 model);
1510 }
1511 }
1512 switch (GetDataTypeAttr(node, "dtype")) {
1513 case DT_FLOAT:
1514 case DT_INT32:
1515 case DT_QUINT8:
1516 case DT_INT64:
1517 case DT_STRING:
1518 case DT_BOOL:
1519 case DT_COMPLEX64:
1520 return ConvertConstOperator(node, tf_import_flags, model_flags, model);
1521 default:
1522 return ConvertUnsupportedOperator(node, tf_import_flags, model_flags,
1523 model);
1524 }
1525 }
1526
ConvertStridedSliceOperator(const NodeDef & node,const TensorFlowImportFlags & tf_import_flags,const ModelFlags & model_flags,Model * model)1527 tensorflow::Status ConvertStridedSliceOperator(
1528 const NodeDef& node, const TensorFlowImportFlags& tf_import_flags,
1529 const ModelFlags& model_flags, Model* model) {
1530 CHECK_EQ(node.op(), "StridedSlice");
1531 // TODO(soroosh): The 4th input (strides) should be e optional, to be
1532 // consistent with TF.
1533 TF_QCHECK_OK(CheckInputsCount(node, tf_import_flags, 4));
1534
1535 auto* op = new StridedSliceOperator;
1536 for (const auto& input : node.input()) {
1537 op->inputs.push_back(input);
1538 }
1539 op->outputs.push_back(node.name());
1540
1541 op->begin_mask =
1542 HasAttr(node, "begin_mask") ? GetIntAttr(node, "begin_mask") : 0;
1543 op->ellipsis_mask =
1544 HasAttr(node, "ellipsis_mask") ? GetIntAttr(node, "ellipsis_mask") : 0;
1545 op->end_mask = HasAttr(node, "end_mask") ? GetIntAttr(node, "end_mask") : 0;
1546 op->new_axis_mask =
1547 HasAttr(node, "new_axis_mask") ? GetIntAttr(node, "new_axis_mask") : 0;
1548 op->shrink_axis_mask = HasAttr(node, "shrink_axis_mask")
1549 ? GetIntAttr(node, "shrink_axis_mask")
1550 : 0;
1551
1552 model->operators.emplace_back(op);
1553 return tensorflow::Status::OK();
1554 }
1555
ConvertPlaceholderOperator(const NodeDef & node,const TensorFlowImportFlags & tf_import_flags,const ModelFlags & model_flags,Model * model)1556 tensorflow::Status ConvertPlaceholderOperator(
1557 const NodeDef& node, const TensorFlowImportFlags& tf_import_flags,
1558 const ModelFlags& model_flags, Model* model) {
1559 CHECK(node.op() == "Placeholder" || node.op() == "LegacyFedInput");
1560 if (node.op() == "Placeholder") {
1561 TF_QCHECK_OK(CheckInputsCount(node, tf_import_flags, 0));
1562 }
1563
1564 bool inside_input_arrays = false;
1565 for (const auto& input_array : model_flags.input_arrays()) {
1566 if (node.name() == input_array.name()) {
1567 inside_input_arrays = true;
1568 break;
1569 }
1570 }
1571
1572 if (!inside_input_arrays) {
1573 model->AddInvalidInputArray(node.name());
1574 }
1575
1576 auto& array = model->GetOrCreateArray(node.name());
1577 if (node.attr().count("dtype")) {
1578 array.data_type = ConvertDataType(GetDataTypeAttr(node, "dtype"));
1579 }
1580 if (node.attr().count("shape")) {
1581 const auto& shape = GetShapeAttr(node, "shape");
1582 auto num_dims = shape.dim_size();
1583 // TODO(b/62716978): This logic needs to be revisited. During dims
1584 // refactoring it is an interim fix.
1585 if (num_dims > 0 && !HasWildcardDimension(shape)) {
1586 auto& dst_array_dims = *array.mutable_shape()->mutable_dims();
1587 dst_array_dims.resize(num_dims);
1588 for (std::size_t i = 0; i < num_dims; i++) {
1589 dst_array_dims[i] = shape.dim(i).size();
1590 }
1591 }
1592 }
1593 return tensorflow::Status::OK();
1594 }
1595
ConvertNoOpOperator(const NodeDef & node,const TensorFlowImportFlags & tf_import_flags,const ModelFlags & model_flags,Model * model)1596 tensorflow::Status ConvertNoOpOperator(
1597 const NodeDef& node, const TensorFlowImportFlags& tf_import_flags,
1598 const ModelFlags& model_flags, Model* model) {
1599 return tensorflow::Status::OK();
1600 }
1601
ConvertCastOperator(const NodeDef & node,const TensorFlowImportFlags & tf_import_flags,const ModelFlags & model_flags,Model * model)1602 tensorflow::Status ConvertCastOperator(
1603 const NodeDef& node, const TensorFlowImportFlags& tf_import_flags,
1604 const ModelFlags& model_flags, Model* model) {
1605 CHECK_EQ(node.op(), "Cast");
1606 TF_QCHECK_OK(CheckInputsCount(node, tf_import_flags, 1));
1607 const auto tf_src_dtype = GetDataTypeAttr(node, "SrcT");
1608 const auto tf_dst_dtype = GetDataTypeAttr(node, "DstT");
1609 auto* op = new CastOperator;
1610 op->src_data_type = ConvertDataType(tf_src_dtype);
1611 op->dst_data_type = ConvertDataType(tf_dst_dtype);
1612 op->inputs.push_back(node.input(0));
1613 op->outputs.push_back(node.name());
1614 model->operators.emplace_back(op);
1615 return tensorflow::Status::OK();
1616 }
1617
ConvertFloorOperator(const NodeDef & node,const TensorFlowImportFlags & tf_import_flags,const ModelFlags & model_flags,Model * model)1618 tensorflow::Status ConvertFloorOperator(
1619 const NodeDef& node, const TensorFlowImportFlags& tf_import_flags,
1620 const ModelFlags& model_flags, Model* model) {
1621 CHECK_EQ(node.op(), "Floor");
1622 TF_QCHECK_OK(CheckInputsCount(node, tf_import_flags, 1));
1623 const auto data_type = GetDataTypeAttr(node, "T");
1624 CHECK(data_type == DT_FLOAT);
1625 auto* op = new FloorOperator;
1626 op->inputs.push_back(node.input(0));
1627 op->outputs.push_back(node.name());
1628 model->operators.emplace_back(op);
1629 return tensorflow::Status::OK();
1630 }
1631
ConvertCeilOperator(const NodeDef & node,const TensorFlowImportFlags & tf_import_flags,const ModelFlags & model_flags,Model * model)1632 tensorflow::Status ConvertCeilOperator(
1633 const NodeDef& node, const TensorFlowImportFlags& tf_import_flags,
1634 const ModelFlags& model_flags, Model* model) {
1635 CHECK_EQ(node.op(), "Ceil");
1636 TF_QCHECK_OK(CheckInputsCount(node, tf_import_flags, 1));
1637 const auto data_type = GetDataTypeAttr(node, "T");
1638 CHECK(data_type == DT_FLOAT);
1639 auto* op = new CeilOperator;
1640 op->inputs.push_back(node.input(0));
1641 op->outputs.push_back(node.name());
1642 model->operators.emplace_back(op);
1643 return tensorflow::Status::OK();
1644 }
1645
ConvertRoundOperator(const NodeDef & node,const TensorFlowImportFlags & tf_import_flags,const ModelFlags & model_flags,Model * model)1646 tensorflow::Status ConvertRoundOperator(
1647 const NodeDef& node, const TensorFlowImportFlags& tf_import_flags,
1648 const ModelFlags& model_flags, Model* model) {
1649 CHECK_EQ(node.op(), "Round");
1650 TF_QCHECK_OK(CheckInputsCount(node, tf_import_flags, 1));
1651 const auto data_type = GetDataTypeAttr(node, "T");
1652 CHECK(data_type == DT_FLOAT);
1653 auto* op = new RoundOperator;
1654 op->inputs.push_back(node.input(0));
1655 op->outputs.push_back(node.name());
1656 model->operators.emplace_back(op);
1657 return tensorflow::Status::OK();
1658 }
1659
ConvertGatherOperator(const NodeDef & node,const TensorFlowImportFlags & tf_import_flags,const ModelFlags & model_flags,Model * model)1660 tensorflow::Status ConvertGatherOperator(
1661 const NodeDef& node, const TensorFlowImportFlags& tf_import_flags,
1662 const ModelFlags& model_flags, Model* model) {
1663 CHECK(node.op() == "Gather" || node.op() == "GatherV2");
1664 if (node.op() == "Gather")
1665 TF_QCHECK_OK(CheckInputsCount(node, tf_import_flags, 2));
1666 if (node.op() == "GatherV2")
1667 TF_QCHECK_OK(CheckInputsCount(node, tf_import_flags, 3));
1668 const auto indices_data_type = GetDataTypeAttr(node, "Tindices");
1669 CHECK(indices_data_type == DT_INT32 || indices_data_type == DT_INT64);
1670 auto* op = new GatherOperator;
1671 op->inputs.push_back(node.input(0));
1672 op->inputs.push_back(node.input(1));
1673 if (node.input_size() >= 3) {
1674 // GatherV2 form where we are provided an axis. It may be either a constant
1675 // or runtime defined value, so we just wire up the array and let
1676 // ResolveGatherAttributes take care of it later on.
1677 const auto axis_data_type = GetDataTypeAttr(node, "Taxis");
1678 CHECK(axis_data_type == DT_INT32 || axis_data_type == DT_INT64);
1679 op->inputs.push_back(node.input(2));
1680 } else {
1681 // Gather form that assumes axis=0.
1682 op->axis = {0};
1683 }
1684 op->outputs.push_back(node.name());
1685 model->operators.emplace_back(op);
1686 return tensorflow::Status::OK();
1687 }
1688
ConvertGatherNdOperator(const NodeDef & node,const TensorFlowImportFlags & tf_import_flags,const ModelFlags & model_flags,Model * model)1689 tensorflow::Status ConvertGatherNdOperator(
1690 const NodeDef& node, const TensorFlowImportFlags& tf_import_flags,
1691 const ModelFlags& model_flags, Model* model) {
1692 CHECK_EQ(node.op(), "GatherNd");
1693 TF_QCHECK_OK(CheckInputsCount(node, tf_import_flags, 2));
1694 const auto indices_data_type = GetDataTypeAttr(node, "Tindices");
1695 CHECK(indices_data_type == DT_INT32 || indices_data_type == DT_INT64);
1696 auto* op = new GatherNdOperator;
1697 op->inputs.push_back(node.input(0));
1698 op->inputs.push_back(node.input(1));
1699 op->outputs.push_back(node.name());
1700 model->operators.emplace_back(op);
1701 return tensorflow::Status::OK();
1702 }
1703
1704 template <typename Op>
ConvertArgMinMaxOperator(const NodeDef & node,const TensorFlowImportFlags & tf_import_flags,const ModelFlags & model_flags,Model * model)1705 tensorflow::Status ConvertArgMinMaxOperator(
1706 const NodeDef& node, const TensorFlowImportFlags& tf_import_flags,
1707 const ModelFlags& model_flags, Model* model) {
1708 TF_QCHECK_OK(CheckInputsCount(node, tf_import_flags, 2));
1709 const auto axis_data_type =
1710 HasAttr(node, "Tidx") ? GetDataTypeAttr(node, "Tidx") : DT_INT32;
1711 const auto output_type = HasAttr(node, "output_type")
1712 ? GetDataTypeAttr(node, "output_type")
1713 : DT_INT64;
1714 CHECK(axis_data_type == DT_INT64 || axis_data_type == DT_INT32);
1715 CHECK(output_type == DT_INT64 || output_type == DT_INT32);
1716 auto* op = new Op;
1717 op->output_data_type = ConvertDataType(output_type);
1718 op->inputs.push_back(node.input(0));
1719 op->inputs.push_back(node.input(1));
1720 op->outputs.push_back(node.name());
1721 model->operators.emplace_back(op);
1722 return tensorflow::Status::OK();
1723 }
1724
ConvertArgMaxOperator(const NodeDef & node,const TensorFlowImportFlags & tf_import_flags,const ModelFlags & model_flags,Model * model)1725 tensorflow::Status ConvertArgMaxOperator(
1726 const NodeDef& node, const TensorFlowImportFlags& tf_import_flags,
1727 const ModelFlags& model_flags, Model* model) {
1728 CHECK_EQ(node.op(), "ArgMax");
1729 return ConvertArgMinMaxOperator<ArgMaxOperator>(node, tf_import_flags,
1730 model_flags, model);
1731 }
1732
ConvertArgMinOperator(const NodeDef & node,const TensorFlowImportFlags & tf_import_flags,const ModelFlags & model_flags,Model * model)1733 tensorflow::Status ConvertArgMinOperator(
1734 const NodeDef& node, const TensorFlowImportFlags& tf_import_flags,
1735 const ModelFlags& model_flags, Model* model) {
1736 CHECK_EQ(node.op(), "ArgMin");
1737 return ConvertArgMinMaxOperator<ArgMinOperator>(node, tf_import_flags,
1738 model_flags, model);
1739 }
1740
ConvertResizeBilinearOperator(const NodeDef & node,const TensorFlowImportFlags & tf_import_flags,const ModelFlags & model_flags,Model * model)1741 tensorflow::Status ConvertResizeBilinearOperator(
1742 const NodeDef& node, const TensorFlowImportFlags& tf_import_flags,
1743 const ModelFlags& model_flags, Model* model) {
1744 CHECK_EQ(node.op(), "ResizeBilinear");
1745 TF_QCHECK_OK(CheckInputsCount(node, tf_import_flags, 2));
1746 auto* op = new ResizeBilinearOperator;
1747
1748 op->align_corners = false;
1749 op->half_pixel_centers = false;
1750 if (HasAttr(node, "align_corners")) {
1751 op->align_corners = GetBoolAttr(node, "align_corners");
1752 }
1753 if (HasAttr(node, "half_pixel_centers")) {
1754 op->half_pixel_centers = GetBoolAttr(node, "half_pixel_centers");
1755 }
1756
1757 op->inputs.push_back(node.input(0));
1758 op->inputs.push_back(node.input(1));
1759 op->outputs.push_back(node.name());
1760 model->operators.emplace_back(op);
1761 return tensorflow::Status::OK();
1762 }
1763
ConvertResizeNearestNeighborOperator(const NodeDef & node,const TensorFlowImportFlags & tf_import_flags,const ModelFlags & model_flags,Model * model)1764 tensorflow::Status ConvertResizeNearestNeighborOperator(
1765 const NodeDef& node, const TensorFlowImportFlags& tf_import_flags,
1766 const ModelFlags& model_flags, Model* model) {
1767 CHECK_EQ(node.op(), "ResizeNearestNeighbor");
1768 TF_QCHECK_OK(CheckInputsCount(node, tf_import_flags, 2));
1769 auto* op = new ResizeNearestNeighborOperator;
1770
1771 op->align_corners = false;
1772 op->half_pixel_centers = false;
1773 if (HasAttr(node, "align_corners")) {
1774 op->align_corners = GetBoolAttr(node, "align_corners");
1775 }
1776 if (HasAttr(node, "half_pixel_centers")) {
1777 op->half_pixel_centers = GetBoolAttr(node, "half_pixel_centers");
1778 }
1779
1780 op->inputs.push_back(node.input(0));
1781 op->inputs.push_back(node.input(1));
1782 op->outputs.push_back(node.name());
1783 model->operators.emplace_back(op);
1784 return tensorflow::Status::OK();
1785 }
1786
ConvertBatchNormWithGlobalNormalizationOperator(const NodeDef & node,const TensorFlowImportFlags & tf_import_flags,const ModelFlags & model_flags,Model * model)1787 tensorflow::Status ConvertBatchNormWithGlobalNormalizationOperator(
1788 const NodeDef& node, const TensorFlowImportFlags& tf_import_flags,
1789 const ModelFlags& model_flags, Model* model) {
1790 CHECK_EQ(node.op(), "BatchNormWithGlobalNormalization");
1791 TF_QCHECK_OK(CheckInputsCount(node, tf_import_flags, 5));
1792
1793 // TODO(ahentz): to really match tensorflow we need to add variance_epsilon
1794 // to the input, before feeding it into TensorFlowRsqrtOperator.
1795 // CHECK_EQ(GetFloatAttr(node, "variance_epsilon"), 0.001f);
1796
1797 std::string multiplier = node.name() + "_mul";
1798 if (GetBoolAttr(node, "scale_after_normalization")) {
1799 // Create graph:
1800 // v -> RSQRT ->
1801 // MUL -> multiplier
1802 // gamma ----->
1803 std::string rsqrt = node.name() + "_rsqrt";
1804
1805 auto* rsqrt_op = new TensorFlowRsqrtOperator;
1806 rsqrt_op->inputs.push_back(node.input(2));
1807 rsqrt_op->outputs.push_back(rsqrt);
1808 model->operators.emplace_back(rsqrt_op);
1809
1810 auto* mul_op = new MulOperator;
1811 mul_op->inputs.push_back(rsqrt);
1812 mul_op->inputs.push_back(node.input(4));
1813 mul_op->outputs.push_back(multiplier);
1814 model->operators.emplace_back(mul_op);
1815 } else {
1816 // Create graph:
1817 // v -> RSQRT -> multiplier
1818 auto* rsqrt_op = new TensorFlowRsqrtOperator;
1819 rsqrt_op->inputs.push_back(node.input(2));
1820 rsqrt_op->outputs.push_back(multiplier);
1821 model->operators.emplace_back(rsqrt_op);
1822 }
1823
1824 auto* op = new BatchNormalizationOperator;
1825 op->global_normalization = true;
1826
1827 op->inputs.push_back(node.input(0));
1828 op->inputs.push_back(node.input(1));
1829 op->inputs.push_back(multiplier);
1830 op->inputs.push_back(node.input(3));
1831 op->outputs.push_back(node.name());
1832
1833 model->operators.emplace_back(op);
1834 return tensorflow::Status::OK();
1835 }
1836
ConvertFusedBatchNormOperator(const NodeDef & node,const TensorFlowImportFlags & tf_import_flags,const ModelFlags & model_flags,Model * model)1837 tensorflow::Status ConvertFusedBatchNormOperator(
1838 const NodeDef& node, const TensorFlowImportFlags& tf_import_flags,
1839 const ModelFlags& model_flags, Model* model) {
1840 CHECK((node.op() == "FusedBatchNorm") || (node.op() == "FusedBatchNormV3"));
1841 TF_QCHECK_OK(CheckInputsCount(node, tf_import_flags, 5));
1842
1843 // Declare shortcuts for the inputs.
1844 const std::string& gamma_input = node.input(1);
1845 const std::string& beta_input = node.input(2);
1846 const std::string& moving_mean_input = node.input(3);
1847 const std::string& moving_variance_input = node.input(4);
1848
1849 // Create an array holding the epsilon value (typically, 0.001).
1850 const std::string epsilon_array_name =
1851 CreateConstArray<ArrayDataType::kFloat>(model,
1852 node.name() + "_epsilon_array",
1853 {GetFloatAttr(node, "epsilon")});
1854
1855 // Add epsilon to the moving variance.
1856 const std::string epsilon_add_op_name = node.name() + "_epsilon";
1857 auto* epsilon_add_op = new AddOperator;
1858 epsilon_add_op->inputs.push_back(moving_variance_input);
1859 epsilon_add_op->inputs.push_back(epsilon_array_name);
1860 epsilon_add_op->outputs.push_back(epsilon_add_op_name);
1861 model->operators.emplace_back(epsilon_add_op);
1862
1863 // Take the inverse square root of the (variance + epsilon).
1864 const std::string rsqrt_op_name = node.name() + "_rsqrt";
1865 auto* rsqrt_op = new TensorFlowRsqrtOperator;
1866 rsqrt_op->inputs.push_back(epsilon_add_op_name);
1867 rsqrt_op->outputs.push_back(rsqrt_op_name);
1868 model->operators.emplace_back(rsqrt_op);
1869
1870 // Multiply the result by gamma.
1871 const std::string multiplier = node.name() + "_mul";
1872 auto* mul_op = new MulOperator;
1873 mul_op->inputs.push_back(rsqrt_op_name);
1874 mul_op->inputs.push_back(gamma_input);
1875 mul_op->outputs.push_back(multiplier);
1876 model->operators.emplace_back(mul_op);
1877
1878 // Now we have all required inputs for the BatchNormalizationOperator.
1879 auto* op = new BatchNormalizationOperator;
1880 op->global_normalization = true;
1881
1882 op->inputs.push_back(node.input(0));
1883 op->inputs.push_back(moving_mean_input);
1884 op->inputs.push_back(multiplier);
1885 op->inputs.push_back(beta_input);
1886 op->outputs.push_back(node.name());
1887
1888 model->operators.emplace_back(op);
1889 return tensorflow::Status::OK();
1890 }
1891
ConvertSpaceToBatchNDOperator(const NodeDef & node,const TensorFlowImportFlags & tf_import_flags,const ModelFlags & model_flags,Model * model)1892 tensorflow::Status ConvertSpaceToBatchNDOperator(
1893 const NodeDef& node, const TensorFlowImportFlags& tf_import_flags,
1894 const ModelFlags& model_flags, Model* model) {
1895 CHECK_EQ(node.op(), "SpaceToBatchND");
1896 TF_QCHECK_OK(CheckInputsCount(node, tf_import_flags, 3));
1897 CHECK_EQ(GetDataTypeAttr(node, "Tblock_shape"), DT_INT32);
1898 CHECK_EQ(GetDataTypeAttr(node, "Tpaddings"), DT_INT32);
1899 auto* op = new SpaceToBatchNDOperator;
1900 op->inputs.push_back(node.input(0));
1901 op->inputs.push_back(node.input(1));
1902 op->inputs.push_back(node.input(2));
1903 op->outputs.push_back(node.name());
1904 model->operators.emplace_back(op);
1905 return tensorflow::Status::OK();
1906 }
1907
ConvertBatchToSpaceNDOperator(const NodeDef & node,const TensorFlowImportFlags & tf_import_flags,const ModelFlags & model_flags,Model * model)1908 tensorflow::Status ConvertBatchToSpaceNDOperator(
1909 const NodeDef& node, const TensorFlowImportFlags& tf_import_flags,
1910 const ModelFlags& model_flags, Model* model) {
1911 CHECK_EQ(node.op(), "BatchToSpaceND");
1912 TF_QCHECK_OK(CheckInputsCount(node, tf_import_flags, 3));
1913 CHECK_EQ(GetDataTypeAttr(node, "Tblock_shape"), DT_INT32);
1914 CHECK_EQ(GetDataTypeAttr(node, "Tcrops"), DT_INT32);
1915 auto* op = new BatchToSpaceNDOperator;
1916 op->inputs.push_back(node.input(0));
1917 op->inputs.push_back(node.input(1));
1918 op->inputs.push_back(node.input(2));
1919 op->outputs.push_back(node.name());
1920 model->operators.emplace_back(op);
1921 return tensorflow::Status::OK();
1922 }
1923
1924 template <typename T>
ConvertReduceOperator(const NodeDef & node,const TensorFlowImportFlags & tf_import_flags,const ModelFlags & model_flags,Model * model)1925 tensorflow::Status ConvertReduceOperator(
1926 const NodeDef& node, const TensorFlowImportFlags& tf_import_flags,
1927 const ModelFlags& model_flags, Model* model) {
1928 TF_QCHECK_OK(CheckInputsCount(node, tf_import_flags, 2));
1929 auto* op = new T;
1930 op->inputs.push_back(node.input(0));
1931 op->inputs.push_back(node.input(1));
1932 op->outputs.push_back(node.name());
1933 model->operators.emplace_back(op);
1934 if (HasAttr(node, "keepdims")) {
1935 op->keep_dims = GetBoolAttr(node, "keepdims");
1936 } else if (HasAttr(node, "keep_dims")) {
1937 op->keep_dims = GetBoolAttr(node, "keep_dims");
1938 }
1939 return tensorflow::Status::OK();
1940 }
1941
1942 // TODO(b/139320642): Add test when fused op is supported.
ConvertSvdfOperator(const NodeDef & node,const TensorFlowImportFlags & tf_import_flags,const ModelFlags & model_flags,Model * model)1943 tensorflow::Status ConvertSvdfOperator(
1944 const NodeDef& node, const TensorFlowImportFlags& tf_import_flags,
1945 const ModelFlags& model_flags, Model* model) {
1946 CHECK_EQ(node.op(), "Svdf");
1947 const int input_size = GetInputsCount(node, tf_import_flags);
1948 QCHECK(input_size == 4 || input_size == 5)
1949 << "Svdf node expects 3 or 4 inputs other than control dependencies: "
1950 << node.DebugString();
1951 bool has_bias = (input_size == 5);
1952 auto* op = new SvdfOperator;
1953 int index = 0;
1954 op->inputs.push_back(node.input(index++));
1955 op->inputs.push_back(node.input(index++));
1956 op->inputs.push_back(node.input(index++));
1957 if (has_bias) {
1958 op->inputs.push_back(node.input(index++));
1959 }
1960 op->inputs.push_back(node.input(index));
1961 op->outputs.push_back(node.name());
1962 if (node.attr().at("ActivationFunction").s() == "Relu") {
1963 op->fused_activation_function = FusedActivationFunctionType::kRelu;
1964 } else {
1965 op->fused_activation_function = FusedActivationFunctionType::kNone;
1966 }
1967 op->rank = node.attr().at("Rank").i();
1968 model->operators.emplace_back(op);
1969 return tensorflow::Status::OK();
1970 }
1971
1972 // This is just bare bones support to get the shapes to propagate.
ConvertTransposeConvOperator(const NodeDef & node,const TensorFlowImportFlags & tf_import_flags,const ModelFlags & model_flags,Model * model)1973 tensorflow::Status ConvertTransposeConvOperator(
1974 const NodeDef& node, const TensorFlowImportFlags& tf_import_flags,
1975 const ModelFlags& model_flags, Model* model) {
1976 CHECK_EQ(node.op(), "Conv2DBackpropInput");
1977 TF_QCHECK_OK(CheckInputsCount(node, tf_import_flags, 3));
1978 auto* op = new TransposeConvOperator;
1979 op->inputs.push_back(node.input(0));
1980 op->inputs.push_back(node.input(1));
1981 op->inputs.push_back(node.input(2));
1982 op->outputs.push_back(node.name());
1983 const auto& strides = GetListAttr(node, "strides");
1984 op->stride_height = strides.i(1);
1985 op->stride_width = strides.i(2);
1986 CHECK_EQ(strides.i_size(), 4)
1987 << "Can only import TransposeConv ops with 4D strides. TensorFlow op \""
1988 << node.name() << "\" has " << strides.i_size() << "D strides.";
1989 CHECK((strides.i(0) == 1) && (strides.i(3) == 1))
1990 << "Can only import TransposeConv ops with striding along the height "
1991 "(1st) or width (2nd) axis. TensorFlow op \""
1992 << node.name() << "\" had strides:[ " << strides.i(0) << ", "
1993 << strides.i(1) << ", " << strides.i(2) << ", " << strides.i(3) << "].";
1994 op->stride_height = strides.i(1);
1995 op->stride_width = strides.i(2);
1996 if (HasAttr(node, "dilations")) {
1997 const auto& dilations = GetListAttr(node, "dilations");
1998 CHECK_EQ(dilations.i_size(), 4)
1999 << "Dilation unsupported in TransposeConv. TensorFlow op \""
2000 << node.name() << "\" had dilations";
2001 CHECK((dilations.i(0) == 1) && (dilations.i(1) == 1) &&
2002 (dilations.i(2) == 1) && (dilations.i(3) == 1))
2003 << "Dilation unsupported in TransposeConv. TensorFlow op \""
2004 << node.name() << "\" had dilations:[ " << dilations.i(0) << ", "
2005 << dilations.i(1) << ", " << dilations.i(2) << ", " << dilations.i(3)
2006 << "].";
2007 }
2008
2009 const std::string& weights_name = node.input(TransposeConvOperator::WEIGHTS);
2010 const std::string& transposed_weights_name = weights_name + "_transposed";
2011 // Check if a TransposeOperator was already created for these weights
2012 // (can happen when multiple layers share the same weights).
2013 const Operator* existing_transpose =
2014 GetOpWithOutput(*model, transposed_weights_name);
2015 if (existing_transpose) {
2016 CHECK(existing_transpose->type == OperatorType::kTranspose);
2017 } else {
2018 // Transpose weights from HWOI order to OHWI order, which is more efficient
2019 // for computation. (Note that TensorFlow considers the order as HWIO
2020 // because they consider this a backward conv, inverting the sense of
2021 // input/output.)
2022 TransposeOperator* transpose = new TransposeOperator;
2023 std::string perm_array = CreateConstArray<ArrayDataType::kInt32>(
2024 model, node.name() + "_transpose_perm", {2, 0, 1, 3});
2025 transpose->inputs = {weights_name, perm_array};
2026 transpose->outputs = {transposed_weights_name};
2027 model->operators.emplace_back(transpose);
2028 }
2029 op->inputs[1] = transposed_weights_name;
2030
2031 auto const& padding = GetStringAttr(node, "padding");
2032 if (padding == "SAME") {
2033 op->padding.type = PaddingType::kSame;
2034 } else if (padding == "VALID") {
2035 op->padding.type = PaddingType::kValid;
2036 } else {
2037 LOG(FATAL) << "Only SAME and VALID padding supported on "
2038 "Conv2DBackpropInput nodes.";
2039 }
2040 model->operators.emplace_back(op);
2041 return tensorflow::Status::OK();
2042 }
2043
ConvertRangeOperator(const NodeDef & node,const TensorFlowImportFlags & tf_import_flags,const ModelFlags & model_flags,Model * model)2044 tensorflow::Status ConvertRangeOperator(
2045 const NodeDef& node, const TensorFlowImportFlags& tf_import_flags,
2046 const ModelFlags& model_flags, Model* model) {
2047 CHECK_EQ(node.op(), "Range");
2048 TF_QCHECK_OK(CheckInputsCount(node, tf_import_flags, 3));
2049 auto* op = new RangeOperator;
2050 if (HasAttr(node, "Tidx")) {
2051 const auto dtype = toco::GetDataTypeAttr(node, "Tidx");
2052 CHECK(dtype == DT_UINT8 || dtype == DT_INT32 || dtype == DT_INT64 ||
2053 dtype == DT_FLOAT);
2054 op->dtype = ConvertDataType(dtype);
2055 }
2056 op->inputs.push_back(node.input(0));
2057 op->inputs.push_back(node.input(1));
2058 op->inputs.push_back(node.input(2));
2059 op->outputs.push_back(node.name());
2060
2061 model->operators.emplace_back(op);
2062 return tensorflow::Status::OK();
2063 }
2064
2065 // Note that it's easy to confuse/conflate "Stack" and "Pack" operators, but
2066 // they aren't the same thing. tf.stack results in a "Pack" operator. "Stack"
2067 // operators also exist, but involve manipulating the TF runtime stack, and are
2068 // not directly related to tf.stack() usage.
ConvertPackOperator(const NodeDef & node,const TensorFlowImportFlags & tf_import_flags,const ModelFlags & model_flags,Model * model)2069 tensorflow::Status ConvertPackOperator(
2070 const NodeDef& node, const TensorFlowImportFlags& tf_import_flags,
2071 const ModelFlags& model_flags, Model* model) {
2072 CHECK_EQ(node.op(), "Pack");
2073 auto op = absl::make_unique<PackOperator>();
2074 const int num_inputs = GetInputsCount(node, tf_import_flags);
2075 QCHECK_GE(num_inputs, 1)
2076 << node.op()
2077 << " node expects at least 1 input other than control dependencies: "
2078 << node.DebugString();
2079 CHECK_EQ(num_inputs, GetIntAttr(node, "N"));
2080 for (int i = 0; i < num_inputs; ++i) {
2081 op->inputs.push_back(node.input(i));
2082 }
2083 op->values_count = HasAttr(node, "N") ? GetIntAttr(node, "N") : num_inputs;
2084 op->axis = HasAttr(node, "axis") ? GetIntAttr(node, "axis") : 0;
2085 op->dtype = ConvertDataType(toco::GetDataTypeAttr(node, "T"));
2086 op->outputs.push_back(node.name());
2087 model->operators.emplace_back(std::move(op));
2088 return tensorflow::Status::OK();
2089 }
2090
ConvertUnpackOperator(const NodeDef & node,const TensorFlowImportFlags & tf_import_flags,const ModelFlags & model_flags,Model * model)2091 tensorflow::Status ConvertUnpackOperator(
2092 const NodeDef& node, const TensorFlowImportFlags& tf_import_flags,
2093 const ModelFlags& model_flags, Model* model) {
2094 CHECK_EQ(node.op(), "Unpack");
2095 auto op = absl::make_unique<UnpackOperator>();
2096 const int num_inputs = GetInputsCount(node, tf_import_flags);
2097 QCHECK_EQ(num_inputs, 1);
2098 op->inputs.push_back(node.input(0));
2099 op->num = GetIntAttr(node, "num");
2100 op->axis = HasAttr(node, "axis") ? GetIntAttr(node, "axis") : 0;
2101 op->dtype = ConvertDataType(toco::GetDataTypeAttr(node, "T"));
2102
2103 op->outputs.push_back(node.name()); // Implicit :0.
2104 for (int i = 1; i < op->num; ++i) {
2105 op->outputs.push_back(node.name() + ":" + std::to_string(i));
2106 }
2107 model->operators.emplace_back(std::move(op));
2108 return tensorflow::Status::OK();
2109 }
2110
2111 // Some TensorFlow ops only occur in graph cycles, representing
2112 // control flow. We do not currently support control flow, so we wouldn't
2113 // be able to fully support such graphs, including performing inference,
2114 // anyway. However, rather than erroring out early on graphs being cyclic,
2115 // it helps to at least support these just enough to allow getting a
2116 // graph visualization. This is not trivial, as we require graphs to be
2117 // acyclic aside from RNN back-edges. The solution is to special-case
2118 // such ops as RNN back-edges, which is technically incorrect (does not
2119 // allow representing the op's semantics) but good enough to get a
2120 // graph visualization.
ConvertOperatorSpecialCasedAsRNNBackEdge(const NodeDef & node,const TensorFlowImportFlags & tf_import_flags,const ModelFlags & model_flags,Model * model)2121 tensorflow::Status ConvertOperatorSpecialCasedAsRNNBackEdge(
2122 const NodeDef& node, const TensorFlowImportFlags& tf_import_flags,
2123 const ModelFlags& model_flags, Model* model) {
2124 // At the moment, the only type of operator special-cased in this way is
2125 // NextIteration, occurring only in control-flow cycles.
2126 CHECK_EQ(node.op(), "NextIteration");
2127 CHECK_EQ(node.input_size(), 1);
2128 auto* rnn_state = model->flags.add_rnn_states();
2129 // This RNN state is not explicitly created by the user, so it's
2130 // OK for some later graph transformation to discard it.
2131 rnn_state->set_discardable(true);
2132 rnn_state->set_state_array(node.name());
2133 rnn_state->set_back_edge_source_array(node.input(0));
2134 // TODO(tianjuny): Temporary set the size to 1 to avoid transient array
2135 // allocation crash. The real value should depend on the hidden_size of RNN.
2136 rnn_state->set_size(1);
2137 return tensorflow::Status::OK();
2138 }
2139
ConvertShapeOperator(const NodeDef & node,const TensorFlowImportFlags & tf_import_flags,const ModelFlags & model_flags,Model * model)2140 tensorflow::Status ConvertShapeOperator(
2141 const NodeDef& node, const TensorFlowImportFlags& tf_import_flags,
2142 const ModelFlags& model_flags, Model* model) {
2143 CHECK_EQ(node.op(), "Shape");
2144 TF_QCHECK_OK(CheckInputsCount(node, tf_import_flags, 1));
2145 const auto out_type =
2146 HasAttr(node, "out_type") ? GetDataTypeAttr(node, "out_type") : DT_INT32;
2147 CHECK(out_type == DT_INT64 || out_type == DT_INT32);
2148 auto op = absl::make_unique<TensorFlowShapeOperator>();
2149 op->output_data_type = ConvertDataType(out_type);
2150 op->inputs.push_back(node.input(0));
2151 op->outputs.push_back(node.name());
2152 model->operators.push_back(std::move(op));
2153 return tensorflow::Status::OK();
2154 }
2155
ConvertReverseSequenceOperator(const NodeDef & node,const TensorFlowImportFlags & tf_import_flags,const ModelFlags & model_flags,Model * model)2156 tensorflow::Status ConvertReverseSequenceOperator(
2157 const NodeDef& node, const TensorFlowImportFlags& tf_import_flags,
2158 const ModelFlags& model_flags, Model* model) {
2159 CHECK_EQ(node.op(), "ReverseSequence");
2160 TF_QCHECK_OK(CheckInputsCount(node, tf_import_flags, 2));
2161 auto op = absl::make_unique<ReverseSequenceOperator>();
2162 if (HasAttr(node, "seq_dim")) {
2163 op->seq_dim = GetIntAttr(node, "seq_dim");
2164 }
2165 // In tf.reverse_sequence, batch_dim defaults to 0.
2166 op->batch_dim =
2167 HasAttr(node, "batch_dim") ? GetIntAttr(node, "batch_dim") : 0;
2168 const int num_inputs = GetInputsCount(node, tf_import_flags);
2169 for (int i = 0; i < num_inputs; ++i) {
2170 op->inputs.push_back(node.input(i));
2171 }
2172 op->outputs.push_back(node.name());
2173 model->operators.push_back(std::move(op));
2174 return tensorflow::Status::OK();
2175 }
2176
StripCaretFromArrayNames(Model * model)2177 void StripCaretFromArrayNames(Model* model) {
2178 for (auto& op : model->operators) {
2179 for (auto& input : op->inputs) {
2180 input = std::string(absl::StripPrefix(input, "^"));
2181 }
2182 for (auto& output : op->outputs) {
2183 output = std::string(absl::StripPrefix(output, "^"));
2184 }
2185 }
2186 for (auto& array : model->GetArrayMap()) {
2187 if (absl::StartsWith(array.first, "^")) {
2188 LOG(FATAL) << "What?";
2189 }
2190 }
2191 }
2192
StripZeroOutputIndexFromInputs(NodeDef * node)2193 void StripZeroOutputIndexFromInputs(NodeDef* node) {
2194 for (auto& input : *node->mutable_input()) {
2195 input = std::string(absl::StripSuffix(input, ":0"));
2196 }
2197 }
2198
2199 // In TensorFlow GraphDef, when a node has multiple outputs, they are named
2200 // name:0, name:1, ...
2201 // where 'name' is the node's name(). Just 'name' is an equivalent shorthand
2202 // form for name:0.
2203 // A TensorFlow GraphDef does not explicitly list all the outputs of each node
2204 // (unlike inputs), it being implied by the node's name and operator type
2205 // (the latter implies the number of outputs).
2206 // This makes it non-trivial for us to reconstruct the list of all arrays
2207 // present in the graph and, for each operator, the list of its outputs.
// We do that by relying on the fact that every node does list its inputs
// explicitly: once all nodes are loaded, a consumed name of the form
// "name:i" tells us that node 'name' must have at least i+1 outputs.
AddExtraOutputs(Model * model)2211 void AddExtraOutputs(Model* model) {
2212 // Construct the list of all arrays consumed by anything in the graph.
2213 std::vector<std::string> consumed_arrays;
2214 // Add arrays consumed by an op.
2215 for (const auto& consumer_op : model->operators) {
2216 for (const std::string& input : consumer_op->inputs) {
2217 consumed_arrays.push_back(input);
2218 }
2219 }
2220 // Add global outputs of the model.
2221 for (const std::string& output_array : model->flags.output_arrays()) {
2222 consumed_arrays.push_back(output_array);
2223 }
2224 // Add arrays consumed by a RNN back-edge.
2225 for (const auto& rnn_state : model->flags.rnn_states()) {
2226 consumed_arrays.push_back(rnn_state.back_edge_source_array());
2227 }
2228 // Now add operator outputs so that all arrays that are consumed,
2229 // are produced.
2230 for (const std::string& consumed_array : consumed_arrays) {
2231 // Test if consumed_array is already the output of some op.
2232 // This has occurred in a model where separate nodes had names of the form
2233 // foo:$i with the same base name foo.
2234 if (GetOpWithOutput(*model, consumed_array)) {
2235 continue;
2236 }
2237 // Split the consumed array name into the form name:output_index.
2238 const std::vector<std::string>& split = absl::StrSplit(consumed_array, ':');
2239 // If not of the form name:output_index, then this is not an additional
2240 // output of a node with multiple outputs, so nothing to do here.
2241 if (split.size() != 2) {
2242 continue;
2243 }
2244 int output_index = 0;
2245 if (!absl::SimpleAtoi(split[1], &output_index)) {
2246 continue;
2247 }
2248 // Each op is initially recorded as producing at least the array that
2249 // has its name. We use that to identify the producer node.
2250 auto* producer_op = GetOpWithOutput(*model, split[0]);
2251 if (!producer_op) {
2252 continue;
2253 }
2254 // Add extra outputs to that producer node, all the way to the
2255 // output_index.
2256 while (producer_op->outputs.size() <= output_index) {
2257 using toco::port::StringF;
2258 producer_op->outputs.push_back(
2259 StringF("%s:%d", split[0], producer_op->outputs.size()));
2260 }
2261 }
2262 }
2263
InlineAllFunctions(GraphDef * graphdef)2264 bool InlineAllFunctions(GraphDef* graphdef) {
2265 if (graphdef->library().function().empty()) {
2266 VLOG(kLogLevelModelUnchanged) << "No functions to inline.";
2267 return false;
2268 }
2269
2270 // Override "_noinline" attribute on all functions
2271 GraphDef graphdef_copy(*graphdef);
2272 for (auto& function :
2273 (*graphdef_copy.mutable_library()->mutable_function())) {
2274 auto* attributes = function.mutable_attr();
2275 if (attributes->count(tensorflow::kNoInlineAttr) != 0) {
2276 (*attributes)[tensorflow::kNoInlineAttr].set_b(false);
2277 }
2278 }
2279
2280 // Construct minimum resources needed to use ExpandInlineFunctions().
2281 tensorflow::SessionOptions options;
2282 auto* device_count = options.config.mutable_device_count();
2283 device_count->insert({"CPU", 1});
2284 std::vector<std::unique_ptr<tensorflow::Device>> devices;
2285 TF_CHECK_OK(tensorflow::DeviceFactory::AddDevices(
2286 options, "/job:localhost/replica:0/task:0", &devices));
2287
2288 tensorflow::FunctionLibraryDefinition fld(tensorflow::OpRegistry::Global(),
2289 graphdef_copy.library());
2290 tensorflow::StaticDeviceMgr device_mgr(std::move(devices));
2291 tensorflow::ProcessFunctionLibraryRuntime pflr(
2292 &device_mgr, tensorflow::Env::Default(), &options.config,
2293 TF_GRAPH_DEF_VERSION, &fld,
2294 options.config.graph_options().optimizer_options(), nullptr);
2295 tensorflow::FunctionLibraryRuntime* flr;
2296 flr = pflr.GetFLR("/job:localhost/replica:0/task:0/cpu:0");
2297
2298 tensorflow::Graph graph(fld);
2299 tensorflow::ImportGraphDefOptions gc_opts;
2300 gc_opts.validate_shape = false;
2301 const auto& tf_convert_status = tensorflow::ImportGraphDef(
2302 gc_opts, graphdef_copy, &graph, nullptr, nullptr);
2303 if (!tf_convert_status.ok()) {
2304 LOG(ERROR) << "tensorflow::ImportGraphDef failed with status: "
2305 << tf_convert_status.ToString();
2306 return false;
2307 }
2308
2309 // Iterate over the graph until there are no more nodes to be inlined.
2310 bool graph_modified = false;
2311 while (tensorflow::ExpandInlineFunctions(flr, &graph)) {
2312 graph_modified = true;
2313 }
2314
2315 // Output inlined graph
2316 if (graph_modified) {
2317 LOG(INFO) << "Found and inlined TensorFlow functions.";
2318 graph.ToGraphDef(graphdef);
2319 }
2320 return graph_modified;
2321 }
2322
ConvertTopKV2Operator(const NodeDef & node,const TensorFlowImportFlags & tf_import_flags,const ModelFlags & model_flags,Model * model)2323 tensorflow::Status ConvertTopKV2Operator(
2324 const NodeDef& node, const TensorFlowImportFlags& tf_import_flags,
2325 const ModelFlags& model_flags, Model* model) {
2326 CHECK((node.op() == "TopK") || (node.op() == "TopKV2"));
2327 auto op = absl::make_unique<TopKV2Operator>();
2328 op->inputs.push_back(node.input(0));
2329 // K can be encoded as attr (TopK) convert it to a const.
2330 if (HasAttr(node, "k")) {
2331 std::string k_array = CreateConstArray<ArrayDataType::kInt32>(
2332 model, node.name() + "k", {static_cast<int32>(GetIntAttr(node, "k"))});
2333 op->inputs.push_back(k_array);
2334 } else {
2335 TF_QCHECK_OK(CheckInputsCount(node, tf_import_flags, 2));
2336 op->inputs.push_back(node.input(1));
2337 }
2338 // The op has two outputs.
2339 op->outputs.push_back(node.name());
2340 op->outputs.push_back(node.name() + ":1");
2341 model->operators.emplace_back(op.release());
2342 return tensorflow::Status::OK();
2343 }
2344
ConvertDynamicPartitionOperator(const NodeDef & node,const TensorFlowImportFlags & tf_import_flags,const ModelFlags & model_flags,Model * model)2345 tensorflow::Status ConvertDynamicPartitionOperator(
2346 const NodeDef& node, const TensorFlowImportFlags& tf_import_flags,
2347 const ModelFlags& model_flags, Model* model) {
2348 auto op = absl::make_unique<DynamicPartitionOperator>();
2349 CHECK(HasAttr(node, "num_partitions"));
2350 op->num_partitions = GetIntAttr(node, "num_partitions");
2351 TF_QCHECK_OK(CheckInputsCount(node, tf_import_flags, 2));
2352 op->inputs.push_back(node.input(0));
2353 op->inputs.push_back(node.input(1));
2354 CHECK_GT(op->num_partitions, 1);
2355 op->outputs.push_back(node.name()); // Implicit :0.
2356 for (int i = 1; i < op->num_partitions; ++i) {
2357 op->outputs.push_back(node.name() + ":" + std::to_string(i));
2358 }
2359 model->operators.emplace_back(op.release());
2360 return tensorflow::Status::OK();
2361 }
2362
ConvertDynamicStitchOperator(const NodeDef & node,const TensorFlowImportFlags & tf_import_flags,const ModelFlags & model_flags,Model * model)2363 tensorflow::Status ConvertDynamicStitchOperator(
2364 const NodeDef& node, const TensorFlowImportFlags& tf_import_flags,
2365 const ModelFlags& model_flags, Model* model) {
2366 // The parallel and non-parallel variants are the same besides whether they
2367 // have a parallel loop; there are no behavioral differences.
2368 CHECK(node.op() == "DynamicStitch" || node.op() == "ParallelDynamicStitch");
2369 auto op = absl::make_unique<DynamicStitchOperator>();
2370 CHECK(HasAttr(node, "N"));
2371 op->num_partitions = GetIntAttr(node, "N");
2372 // Expect all ID partitions + all value partitions.
2373 TF_QCHECK_OK(CheckInputsCount(node, tf_import_flags, op->num_partitions * 2));
2374 for (int i = 0; i < op->num_partitions * 2; ++i) {
2375 op->inputs.push_back(node.input(i));
2376 }
2377 op->outputs.push_back(node.name());
2378 model->operators.emplace_back(op.release());
2379 return tensorflow::Status::OK();
2380 }
2381
ConvertSparseToDenseOperator(const NodeDef & node,const TensorFlowImportFlags & tf_import_flags,const ModelFlags & model_flags,Model * model)2382 tensorflow::Status ConvertSparseToDenseOperator(
2383 const NodeDef& node, const TensorFlowImportFlags& tf_import_flags,
2384 const ModelFlags& model_flags, Model* model) {
2385 CHECK_EQ(node.op(), "SparseToDense");
2386 TF_QCHECK_OK(CheckInputsCount(node, tf_import_flags, 4));
2387
2388 auto* op = new SparseToDenseOperator;
2389 for (const std::string& input : node.input()) {
2390 op->inputs.push_back(input);
2391 }
2392 op->outputs.push_back(node.name());
2393
2394 op->validate_indices = HasAttr(node, "validate_indices")
2395 ? GetBoolAttr(node, "validate_indices")
2396 : true;
2397 model->operators.emplace_back(op);
2398 return tensorflow::Status::OK();
2399 }
2400
ConvertOneHotOperator(const NodeDef & node,const TensorFlowImportFlags & tf_import_flags,const ModelFlags & model_flags,Model * model)2401 tensorflow::Status ConvertOneHotOperator(
2402 const NodeDef& node, const TensorFlowImportFlags& tf_import_flags,
2403 const ModelFlags& model_flags, Model* model) {
2404 CHECK_EQ(node.op(), "OneHot");
2405 TF_QCHECK_OK(CheckInputsCount(node, tf_import_flags, 4));
2406
2407 const auto dtype = GetDataTypeAttr(node, "T");
2408 // TODO(b/111744875): Support DT_UINT8 and quantization.
2409 CHECK(dtype == DT_INT32 || dtype == DT_INT64 || dtype == DT_FLOAT ||
2410 dtype == DT_BOOL);
2411
2412 auto op = absl::make_unique<OneHotOperator>();
2413 op->axis = HasAttr(node, "axis") ? GetIntAttr(node, "axis") : -1;
2414 for (const std::string& input : node.input()) {
2415 op->inputs.push_back(input);
2416 }
2417 op->outputs.push_back(node.name());
2418 model->operators.emplace_back(op.release());
2419 return tensorflow::Status::OK();
2420 }
2421
ConvertCTCBeamSearchDecoderOperator(const NodeDef & node,const TensorFlowImportFlags & tf_import_flags,const ModelFlags & model_flags,Model * model)2422 tensorflow::Status ConvertCTCBeamSearchDecoderOperator(
2423 const NodeDef& node, const TensorFlowImportFlags& tf_import_flags,
2424 const ModelFlags& model_flags, Model* model) {
2425 CHECK_EQ(node.op(), "CTCBeamSearchDecoder");
2426 TF_QCHECK_OK(CheckInputsCount(node, tf_import_flags, 2));
2427
2428 auto* op = new CTCBeamSearchDecoderOperator;
2429 for (const std::string& input : node.input()) {
2430 op->inputs.push_back(input);
2431 }
2432
2433 op->beam_width =
2434 HasAttr(node, "beam_width") ? GetIntAttr(node, "beam_width") : 1;
2435 op->top_paths =
2436 HasAttr(node, "top_paths") ? GetIntAttr(node, "top_paths") : 1;
2437 op->merge_repeated = HasAttr(node, "merge_repeated")
2438 ? GetBoolAttr(node, "merge_repeated")
2439 : true;
2440
2441 // There are top_paths + 1 outputs.
2442 op->outputs.push_back(node.name()); // Implicit :0.
2443 for (int i = 0; i < op->top_paths; ++i) {
2444 op->outputs.push_back(node.name() + ":" + std::to_string(i + 1));
2445 }
2446 model->operators.emplace_back(op);
2447 return tensorflow::Status::OK();
2448 }
2449
2450 // This isn't a TensorFlow builtin op. Currently this node can only be generated
2451 // with TfLite OpHint API.
ConvertUnidirectionalSequenceLstm(const NodeDef & node,const TensorFlowImportFlags & tf_import_flags,const ModelFlags & model_flags,Model * model)2452 tensorflow::Status ConvertUnidirectionalSequenceLstm(
2453 const NodeDef& node, const TensorFlowImportFlags& tf_import_flags,
2454 const ModelFlags& model_flags, Model* model) {
2455 DCHECK_EQ(node.op(), "UnidirectionalSequenceLstm");
2456
2457 const auto& indices = GetListAttr(node, "_tflite_input_indices");
2458
2459 auto* op = new UnidirectionalSequenceLstmOperator();
2460
2461 // The input size needs to be the same as the TfLite UniDirectionalSequence
2462 // Lstm implementation.
2463 const int kInputsSize = 20;
2464
2465 op->inputs.resize(kInputsSize);
2466
2467 if (indices.i_size() != node.input().size()) {
2468 // New version, the optional inputs are filled with constant nodes.
2469 int count = 0;
2470 for (int idx = 0; idx < kInputsSize; ++idx) {
2471 if (count < indices.i_size() && indices.i(count) == idx) {
2472 // Specified input.
2473 op->inputs[idx] = node.input(idx);
2474 count++;
2475 } else {
2476 // Optional input.
2477 std::string optional_name = node.name() + "_" + std::to_string(idx);
2478 model->CreateOptionalArray(optional_name);
2479 op->inputs[idx] = optional_name;
2480 }
2481 }
2482 } else { // Legacy version.
2483 std::vector<bool> done(kInputsSize);
2484 int idx = 0;
2485 for (const std::string& input : node.input()) {
2486 int real_index = indices.i(idx);
2487 op->inputs[real_index] = (input);
2488 done[real_index] = true;
2489 idx++;
2490 }
2491
2492 for (int idx = 0; idx < done.size(); idx++) {
2493 if (!done[idx]) {
2494 std::string optional_name = node.name() + "_" + std::to_string(idx);
2495 model->CreateOptionalArray(optional_name);
2496 op->inputs[idx] = optional_name;
2497 }
2498 }
2499 }
2500
2501 // There're three outputs, only the last one is required.
2502 op->outputs.push_back(node.name() + ":2");
2503 model->operators.emplace_back(op);
2504
2505 return tensorflow::Status::OK();
2506 }
2507
ConvertLeakyReluOperator(const NodeDef & node,const TensorFlowImportFlags & tf_import_flags,const ModelFlags & model_flags,Model * model)2508 tensorflow::Status ConvertLeakyReluOperator(
2509 const NodeDef& node, const TensorFlowImportFlags& tf_import_flags,
2510 const ModelFlags& model_flags, Model* model) {
2511 CHECK_EQ(node.op(), "LeakyRelu");
2512 TF_QCHECK_OK(CheckInputsCount(node, tf_import_flags, 1));
2513 CHECK_EQ(GetDataTypeAttr(node, "T"), DT_FLOAT);
2514 const auto& input_name = node.input(0);
2515 auto* op = new LeakyReluOperator;
2516 op->inputs.push_back(input_name);
2517 op->outputs.push_back(node.name());
2518 op->alpha = GetFloatAttr(node, "alpha");
2519 model->operators.emplace_back(op);
2520 return tensorflow::Status::OK();
2521 }
2522
ConvertUnidirectionalSequenceRnn(const NodeDef & node,const TensorFlowImportFlags & tf_import_flags,const ModelFlags & model_flags,Model * model)2523 tensorflow::Status ConvertUnidirectionalSequenceRnn(
2524 const NodeDef& node, const TensorFlowImportFlags& tf_import_flags,
2525 const ModelFlags& model_flags, Model* model) {
2526 DCHECK_EQ(node.op(), "UnidirectionalSequenceRnn");
2527
2528 const auto& indices = GetListAttr(node, "_tflite_input_indices");
2529 if (indices.i_size() != node.input().size()) {
2530 return tensorflow::errors::InvalidArgument("Input size does not match.");
2531 }
2532
2533 auto* op = new UnidirectionalSequenceRnnOperator();
2534 for (const std::string& input : node.input()) {
2535 op->inputs.push_back(input);
2536 }
2537 // Only use the last one as input.
2538 op->outputs.push_back(node.name() + ":1");
2539 model->operators.emplace_back(op);
2540
2541 return tensorflow::Status::OK();
2542 }
2543
2544 } // namespace
2545
2546 namespace internal {
2547
// Signature shared by every TensorFlow-node converter: imports `node` into
// `model` and returns a non-OK status on failure.
using ConverterType = tensorflow::Status (*)(
    const NodeDef& node, const TensorFlowImportFlags& tf_import_flags,
    const ModelFlags& model_flags, Model* model);
// Maps a TensorFlow op name (NodeDef::op()) to its converter.
using ConverterMapType = std::unordered_map<std::string, ConverterType>;
2552
GetTensorFlowNodeConverterMapForFlex()2553 ConverterMapType GetTensorFlowNodeConverterMapForFlex() {
2554 return std::unordered_map<std::string, ConverterType>({
2555 // We need to let TOCO convert Placeholder information into
2556 // array data, so that the data types are correct.
2557 {"LegacyFedInput", ConvertPlaceholderOperator},
2558 {"Placeholder", ConvertPlaceholderOperator},
2559 {"Const", ConditionallyConvertConstOperator},
2560 });
2561 }
2562
GetTensorFlowNodeConverterMap()2563 ConverterMapType GetTensorFlowNodeConverterMap() {
2564 return std::unordered_map<std::string, ConverterType>({
2565 {"Abs", ConvertSimpleOperator<AbsOperator, kAnyNumInputs, 1>},
2566 {"Add", ConvertSimpleOperator<AddOperator, 2, 1>},
2567 {"AddV2", ConvertSimpleOperator<AddOperator, 2, 1>},
2568 {"AddN", ConvertSimpleOperator<AddNOperator, kAnyNumInputs, 1>},
2569 {"All", ConvertSimpleOperator<TensorFlowAllOperator, kAnyNumInputs, 1>},
2570 {"Any", ConvertReduceOperator<TensorFlowAnyOperator>},
2571 {"ArgMax", ConvertArgMaxOperator},
2572 {"ArgMin", ConvertArgMinOperator},
2573 {"Assert",
2574 ConvertSimpleOperator<TensorFlowAssertOperator, kAnyNumInputs, 1>},
2575 {"AvgPool", ConvertAvgPoolOperator},
2576 {"BatchMatMul", ConvertBatchMatMulOperator},
2577 {"BatchMatMulV2", ConvertBatchMatMulOperator},
2578 {"BatchNormWithGlobalNormalization",
2579 ConvertBatchNormWithGlobalNormalizationOperator},
2580 {"BatchToSpaceND", ConvertBatchToSpaceNDOperator},
2581 {"BiasAdd", ConvertBiasAddOperator},
2582 {"Cast", ConvertCastOperator},
2583 {"Ceil", ConvertCeilOperator},
2584 {"CheckNumerics", ConvertIdentityOperator},
2585 {"Concat", ConvertConcatOperator},
2586 {"ConcatV2", ConvertConcatOperator},
2587 {"Const", ConvertConstOperator},
2588 {"Conv2D", ConvertConvOperator},
2589 {"Conv2DBackpropInput", ConvertTransposeConvOperator},
2590 {"Cos", ConvertSimpleOperator<CosOperator, 1, 1>},
2591 {"CTCBeamSearchDecoder", ConvertCTCBeamSearchDecoderOperator},
2592 {"DepthToSpace", ConvertDepthToSpaceOperator},
2593 {"DepthwiseConv2dNative", ConvertDepthwiseConvOperator},
2594 {"Div", ConvertSimpleOperator<DivOperator, 2, 1>},
2595 {"DynamicPartition", ConvertDynamicPartitionOperator},
2596 {"DynamicStitch", ConvertDynamicStitchOperator},
2597 {"Elu", ConvertSimpleOperator<EluOperator, 1, 1>},
2598 {"Equal", ConvertSimpleOperator<TensorFlowEqualOperator, 2, 1>},
2599 {"Exp", ConvertSimpleOperator<ExpOperator, 1, 1>},
2600 {"ExpandDims", ConvertSimpleOperator<ExpandDimsOperator, 2, 1>},
2601 {"FakeQuantWithMinMaxArgs", ConvertFakeQuantWithMinMaxArgs},
2602 {"FakeQuantWithMinMaxVars", ConvertFakeQuantWithMinMaxVars},
2603 {"Fill", ConvertSimpleOperator<FillOperator, 2, 1>},
2604 {"Floor", ConvertFloorOperator},
2605 {"FloorDiv", ConvertSimpleOperator<FloorDivOperator, 2, 1>},
2606 {"FloorMod", ConvertSimpleOperator<FloorModOperator, 2, 1>},
2607 {"FusedBatchNorm", ConvertFusedBatchNormOperator},
2608 {"FusedBatchNormV3", ConvertFusedBatchNormOperator},
2609 {"Gather", ConvertGatherOperator},
2610 {"GatherV2", ConvertGatherOperator},
2611 {"GatherNd", ConvertGatherNdOperator},
2612 {"Greater", ConvertSimpleOperator<TensorFlowGreaterOperator, 2, 1>},
2613 {"GreaterEqual",
2614 ConvertSimpleOperator<TensorFlowGreaterEqualOperator, 2, 1>},
2615 {"Identity", ConvertIdentityOperator},
2616 {"IdentityN", ConvertIdentityNOperator},
2617 {"LRN", ConvertLRNOperator},
2618 {"LeakyRelu", ConvertLeakyReluOperator},
2619 {"LegacyFedInput", ConvertPlaceholderOperator},
2620 {"Less", ConvertSimpleOperator<TensorFlowLessOperator, 2, 1>},
2621 {"LessEqual", ConvertSimpleOperator<TensorFlowLessEqualOperator, 2, 1>},
2622 {"Log", ConvertSimpleOperator<LogOperator, 1, 1>},
2623 {"LogicalAnd", ConvertSimpleOperator<LogicalAndOperator, 2, 1>},
2624 {"LogicalOr", ConvertSimpleOperator<LogicalOrOperator, 2, 1>},
2625 {"LogicalNot", ConvertSimpleOperator<LogicalNotOperator, 1, 1>},
2626 {"LogSoftmax", ConvertSimpleOperator<LogSoftmaxOperator, 1, 1>},
2627 {"MatMul", ConvertMatMulOperator},
2628 {"MatrixDiag", ConvertSimpleOperator<MatrixDiagOperator, 1, 1>},
2629 {"MatrixDiagV2", ConvertSimpleOperator<MatrixDiagV2Operator, 5, 1>},
2630 // `MatrixDiagV3` has an `align` attribute. However, Toco only converts
2631 // `MatrixDiagV3` to `MatrixDiag` with default `k, num_rows, num_cols,
2632 // padding_value` inputs. In this case, `align` can be ignored.
2633 {"MatrixDiagV3", ConvertSimpleOperator<MatrixDiagV3Operator, 5, 1>},
2634 {"MatrixSetDiag", ConvertSimpleOperator<MatrixSetDiagOperator, 2, 1>},
2635 {"MatrixSetDiagV2", ConvertSimpleOperator<MatrixSetDiagV2Operator, 3, 1>},
2636 // `MatrixSetDiagV3` has an `align` attribute. However, Toco only converts
2637 // `MatrixSetDiagV3` to `MatrixSetDiag` with default `k` inputs. In this
2638 // case, `align` can be ignored.
2639 {"MatrixSetDiagV3", ConvertSimpleOperator<MatrixSetDiagV3Operator, 3, 1>},
2640 {"Max", ConvertReduceOperator<TensorFlowMaxOperator>},
2641 {"MaxPool", ConvertMaxPoolOperator},
2642 {"Maximum", ConvertSimpleOperator<TensorFlowMaximumOperator, 2, 1>},
2643 {"Mean", ConvertReduceOperator<MeanOperator>},
2644 {"Merge",
2645 ConvertSimpleOperator<TensorFlowMergeOperator, kAnyNumInputs, 1>},
2646 {"Min", ConvertReduceOperator<TensorFlowMinOperator>},
2647 {"Minimum", ConvertSimpleOperator<TensorFlowMinimumOperator, 2, 1>},
2648 {"Mul", ConvertSimpleOperator<MulOperator, 2, 1>},
2649 {"Neg", ConvertSimpleOperator<NegOperator, 1, 1>},
2650 {"NextIteration", ConvertOperatorSpecialCasedAsRNNBackEdge},
2651 {"NoOp", ConvertNoOpOperator},
2652 {"NotEqual", ConvertSimpleOperator<TensorFlowNotEqualOperator, 2, 1>},
2653 {"OneHot", ConvertOneHotOperator},
2654 {"Pack", ConvertPackOperator},
2655 {"Pad", ConvertSimpleOperator<PadOperator, 2, 1>},
2656 {"PadV2", ConvertSimpleOperator<PadV2Operator, 3, 1>},
2657 {"ParallelDynamicStitch", ConvertDynamicStitchOperator},
2658 {"Placeholder", ConvertPlaceholderOperator},
2659 {"PlaceholderWithDefault", ConvertIdentityOperator},
2660 {"Pow", ConvertSimpleOperator<PowOperator, 2, 1>},
2661 {"Prod", ConvertReduceOperator<TensorFlowProdOperator>},
2662 {"RandomUniform", ConvertRandomUniform},
2663 {"Range", ConvertRangeOperator},
2664 {"Rank", ConvertSimpleOperator<TensorFlowRankOperator, 1, 1>},
2665 {"RealDiv", ConvertSimpleOperator<DivOperator, 2, 1>},
2666 {"Relu", ConvertSimpleOperator<ReluOperator, 1, 1>},
2667 {"Relu6", ConvertSimpleOperator<Relu6Operator, 1, 1>},
2668 {"Reshape", ConvertSimpleOperator<TensorFlowReshapeOperator, 2, 1>},
2669 {"ResizeBilinear", ConvertResizeBilinearOperator},
2670 {"ResizeNearestNeighbor", ConvertResizeNearestNeighborOperator},
2671 {"ReverseSequence", ConvertReverseSequenceOperator},
2672 {"ReverseV2", ConvertSimpleOperator<ReverseV2Operator, 2, 1>},
2673 {"Round", ConvertRoundOperator},
2674 {"Rsqrt", ConvertSimpleOperator<TensorFlowRsqrtOperator, 1, 1>},
2675 {"ScatterNd", ConvertSimpleOperator<ScatterNdOperator, 3, 1>},
2676 {"SegmentSum", ConvertSimpleOperator<SegmentSumOperator, 2, 1>},
2677 {"Select", ConvertSimpleOperator<SelectOperator, 3, 1>},
2678 {"SelectV2", ConvertSimpleOperator<SelectOperator, 3, 1>},
2679 {"Shape", ConvertShapeOperator},
2680 {"Sigmoid", ConvertSimpleOperator<LogisticOperator, 1, 1>},
2681 {"Sin", ConvertSimpleOperator<SinOperator, 1, 1>},
2682 {"Slice", ConvertSimpleOperator<SliceOperator, 3, 1>},
2683 {"Softmax", ConvertSoftmaxOperator},
2684 {"SpaceToBatchND", ConvertSpaceToBatchNDOperator},
2685 {"SpaceToDepth", ConvertSpaceToDepthOperator},
2686 {"SparseToDense", ConvertSparseToDenseOperator},
2687 {"Split", ConvertSplitOperator},
2688 {"SplitV", ConvertSplitVOperator},
2689 {"Sqrt", ConvertSimpleOperator<TensorFlowSqrtOperator, 1, 1>},
2690 {"Square", ConvertSimpleOperator<TensorFlowSquareOperator, 1, 1>},
2691 {"SquaredDifference",
2692 ConvertSimpleOperator<SquaredDifferenceOperator, 2, 1>},
2693 {"Snapshot", ConvertIdentityOperator},
2694 {"Squeeze", ConvertSqueezeOperator},
2695 {"StopGradient", ConvertIdentityOperator},
2696 {"StridedSlice", ConvertStridedSliceOperator},
2697 {"Sub", ConvertSimpleOperator<SubOperator, 2, 1>},
2698 {"Sum", ConvertReduceOperator<TensorFlowSumOperator>},
2699 {"Svdf", ConvertSvdfOperator},
2700 {"Switch", ConvertSwitchOperator},
2701 {"Tanh", ConvertSimpleOperator<TanhOperator, 1, 1>},
2702 {"Tile", ConvertSimpleOperator<TensorFlowTileOperator, 2, 1>},
2703 {"TopK", ConvertTopKV2Operator},
2704 {"TopKV2", ConvertTopKV2Operator},
2705 {"Transpose", ConvertSimpleOperator<TransposeOperator, 2, 1>},
2706 {"Unpack", ConvertUnpackOperator},
2707 {"ZerosLike", ConvertSimpleOperator<TensorFlowZerosLikeOperator, 1, 1>},
2708 {"UnidirectionalSequenceLstm", ConvertUnidirectionalSequenceLstm},
2709 {"UnidirectionalSequenceRnn", ConvertUnidirectionalSequenceRnn},
2710 {"MirrorPad", ConvertMirrorPadOperator},
2711 {"Unique", ConvertSimpleOperator<UniqueOperator, 1, 2>},
2712 {"Where", ConvertSimpleOperator<WhereOperator, 1, 1>},
2713 });
2714 }
2715
tensorflow::Status ImportTensorFlowNode(
    const tensorflow::NodeDef& node,
    const TensorFlowImportFlags& tf_import_flags, const ModelFlags& model_flags,
    Model* model, const ConverterMapType& converter_map) {
  // Dispatches `node` to the converter registered for its op name. Ops with
  // no entry in `converter_map` are passed to ConvertUnsupportedOperator.
  const auto entry = converter_map.find(node.op());
  if (entry != converter_map.end()) {
    return entry->second(node, tf_import_flags, model_flags, model);
  }
  return ConvertUnsupportedOperator(node, tf_import_flags, model_flags, model);
}
2728 } // namespace internal
2729
ImportTensorFlowGraphDef(const ModelFlags & model_flags,const TensorFlowImportFlags & tf_import_flags,const GraphDef & tf_graph)2730 std::unique_ptr<Model> ImportTensorFlowGraphDef(
2731 const ModelFlags& model_flags, const TensorFlowImportFlags& tf_import_flags,
2732 const GraphDef& tf_graph) {
2733 LogDumpGraphDef(kLogLevelModelChanged, "AT IMPORT", tf_graph);
2734
2735 GraphDef inlined_graph(tf_graph);
2736 if (InlineAllFunctions(&inlined_graph)) {
2737 LogDumpGraphDef(kLogLevelModelChanged, "AFTER INLINING", inlined_graph);
2738 }
2739
2740 // Check input and output specification.
2741 for (const auto& specified_input_array : model_flags.input_arrays()) {
2742 CHECK(!absl::EndsWith(specified_input_array.name(), ":0"))
2743 << "Unsupported explicit zero output index: "
2744 << specified_input_array.name();
2745 }
2746 for (const std::string& specified_output_array :
2747 model_flags.output_arrays()) {
2748 CHECK(!absl::EndsWith(specified_output_array, ":0"))
2749 << "Unsupported explicit zero output index: " << specified_output_array;
2750 }
2751
2752 Model* model = new Model;
2753 internal::ConverterMapType converter_map;
2754
2755 // This is used for the TFLite "Full Flex Mode" conversion. All the ops are
2756 // imported as `TensorFlowUnsupportedOperator`, and later all these ops are
2757 // converted to TFLite Flex ops.
2758 if (!tf_import_flags.import_all_ops_as_unsupported) {
2759 converter_map = internal::GetTensorFlowNodeConverterMap();
2760 } else {
2761 converter_map = internal::GetTensorFlowNodeConverterMapForFlex();
2762 }
2763
2764 for (auto node : inlined_graph.node()) {
2765 StripZeroOutputIndexFromInputs(&node);
2766 auto status = internal::ImportTensorFlowNode(
2767 node, tf_import_flags, model_flags, model, converter_map);
2768 CHECK(status.ok()) << status.error_message();
2769 }
2770
2771 ResolveModelFlags(model_flags, model);
2772
2773 StripCaretFromArrayNames(model);
2774 AddExtraOutputs(model);
2775 FixNoMissingArray(model);
2776 FixNoOrphanedArray(model);
2777 FixOperatorOrdering(model);
2778 CheckInvariants(*model);
2779
2780 // if rnn state arrays are constant, make them transient
2781 for (const auto& rnn_state : model->flags.rnn_states()) {
2782 model->GetArray(rnn_state.state_array()).buffer = nullptr;
2783 }
2784
2785 return std::unique_ptr<Model>(model);
2786 }
2787
ImportTensorFlowGraphDef(const ModelFlags & model_flags,const TensorFlowImportFlags & tf_import_flags,const std::string & input_file_contents)2788 std::unique_ptr<Model> ImportTensorFlowGraphDef(
2789 const ModelFlags& model_flags, const TensorFlowImportFlags& tf_import_flags,
2790 const std::string& input_file_contents) {
2791 std::unique_ptr<GraphDef> tf_graph(new GraphDef);
2792 CHECK(ParseFromStringEitherTextOrBinary(input_file_contents, tf_graph.get()));
2793
2794 std::unique_ptr<GraphDef> pruned_graph =
2795 MaybeReplaceCompositeSubgraph(*tf_graph);
2796 if (pruned_graph) {
2797 tf_graph = std::move(pruned_graph);
2798 }
2799 return ImportTensorFlowGraphDef(model_flags, tf_import_flags, *tf_graph);
2800 }
2801
GetPotentiallySupportedOps()2802 std::vector<std::string> GetPotentiallySupportedOps() {
2803 std::vector<std::string> supported_ops;
2804 const internal::ConverterMapType& converter_map =
2805 internal::GetTensorFlowNodeConverterMap();
2806
2807 for (const auto& item : converter_map) {
2808 supported_ops.push_back(item.first);
2809 }
2810 return supported_ops;
2811 }
2812
2813 } // namespace toco
2814