1 /* Copyright 2017 The TensorFlow Authors. All Rights Reserved.
2 
3 Licensed under the Apache License, Version 2.0 (the "License");
4 you may not use this file except in compliance with the License.
5 You may obtain a copy of the License at
6 
7     http://www.apache.org/licenses/LICENSE-2.0
8 
9 Unless required by applicable law or agreed to in writing, software
10 distributed under the License is distributed on an "AS IS" BASIS,
11 WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
12 See the License for the specific language governing permissions and
13 limitations under the License.
14 ==============================================================================*/
15 #include "tensorflow/lite/toco/import_tensorflow.h"
16 
17 #include <memory>
18 #include <string>
19 #include <utility>
20 #include <vector>
21 
22 #include "google/protobuf/map.h"
23 #include "google/protobuf/text_format.h"
24 #include "absl/memory/memory.h"
25 #include "absl/strings/match.h"
26 #include "absl/strings/numbers.h"
27 #include "absl/strings/str_cat.h"
28 #include "absl/strings/str_split.h"
29 #include "absl/strings/strip.h"
30 #include "tensorflow/core/common_runtime/device_factory.h"
31 #include "tensorflow/core/common_runtime/function.h"
32 #include "tensorflow/core/common_runtime/graph_constructor.h"
33 #include "tensorflow/core/common_runtime/process_function_library_runtime.h"
34 #include "tensorflow/core/framework/attr_value.pb.h"
35 #include "tensorflow/core/framework/function.pb.h"
36 #include "tensorflow/core/framework/graph.pb.h"
37 #include "tensorflow/core/framework/node_def.pb.h"
38 #include "tensorflow/core/framework/tensor.pb.h"
39 #include "tensorflow/core/framework/tensor_shape.pb.h"
40 #include "tensorflow/core/framework/types.pb.h"
41 #include "tensorflow/core/lib/core/errors.h"
42 #include "tensorflow/core/lib/core/status.h"
43 #include "tensorflow/core/platform/logging.h"
44 #include "tensorflow/core/public/session_options.h"
45 #include "tensorflow/core/public/version.h"
46 #include "tensorflow/lite/toco/model.h"
47 #include "tensorflow/lite/toco/model_flags.pb.h"
48 #include "tensorflow/lite/toco/tensorflow_graph_matching/resolve_cluster.h"
49 #include "tensorflow/lite/toco/tensorflow_util.h"
50 #include "tensorflow/lite/toco/tooling_util.h"
51 
52 using tensorflow::AttrValue;
53 using tensorflow::DT_BOOL;
54 using tensorflow::DT_COMPLEX64;
55 using tensorflow::DT_FLOAT;
56 using tensorflow::DT_INT32;
57 using tensorflow::DT_INT64;
58 using tensorflow::DT_QUINT8;
59 using tensorflow::DT_STRING;
60 using tensorflow::DT_UINT32;
61 using tensorflow::DT_UINT8;
62 using tensorflow::GraphDef;
63 using tensorflow::NodeDef;
64 using tensorflow::OpRegistry;
65 using tensorflow::TensorProto;
66 using tensorflow::TensorShapeProto;
67 
68 namespace toco {
69 
70 namespace {
HasAttr(const NodeDef & node,const std::string & attr_name)71 bool HasAttr(const NodeDef& node, const std::string& attr_name) {
72   return node.attr().count(attr_name) > 0;
73 }
74 
HasWildcardDimension(const TensorShapeProto & shape)75 bool HasWildcardDimension(const TensorShapeProto& shape) {
76   for (const auto& dim : shape.dim()) {
77     if (dim.size() == -1) return true;
78   }
79   return false;
80 }
81 
GetStringAttr(const NodeDef & node,const std::string & attr_name)82 const std::string& GetStringAttr(const NodeDef& node,
83                                  const std::string& attr_name) {
84   CHECK(HasAttr(node, attr_name));
85   const auto& attr = node.attr().at(attr_name);
86   CHECK_EQ(attr.value_case(), AttrValue::kS);
87   return attr.s();
88 }
89 
GetIntAttr(const NodeDef & node,const std::string & attr_name)90 int64 GetIntAttr(const NodeDef& node, const std::string& attr_name) {
91   CHECK(HasAttr(node, attr_name)) << attr_name << " not found in:\n"
92                                   << node.DebugString();
93   const auto& attr = node.attr().at(attr_name);
94   CHECK_EQ(attr.value_case(), AttrValue::kI);
95   return attr.i();
96 }
97 
GetFloatAttr(const NodeDef & node,const std::string & attr_name)98 float GetFloatAttr(const NodeDef& node, const std::string& attr_name) {
99   CHECK(HasAttr(node, attr_name));
100   const auto& attr = node.attr().at(attr_name);
101   CHECK_EQ(attr.value_case(), AttrValue::kF);
102   return attr.f();
103 }
104 
GetBoolAttr(const NodeDef & node,const std::string & attr_name)105 bool GetBoolAttr(const NodeDef& node, const std::string& attr_name) {
106   CHECK(HasAttr(node, attr_name));
107   const auto& attr = node.attr().at(attr_name);
108   CHECK_EQ(attr.value_case(), AttrValue::kB);
109   return attr.b();
110 }
111 
GetDataTypeAttr(const NodeDef & node,const std::string & attr_name)112 tensorflow::DataType GetDataTypeAttr(const NodeDef& node,
113                                      const std::string& attr_name) {
114   CHECK(HasAttr(node, attr_name));
115   const auto& attr = node.attr().at(attr_name);
116   CHECK_EQ(attr.value_case(), AttrValue::kType);
117   return attr.type();
118 }
119 
GetShapeAttr(const NodeDef & node,const std::string & attr_name)120 const TensorShapeProto& GetShapeAttr(const NodeDef& node,
121                                      const std::string& attr_name) {
122   CHECK(HasAttr(node, attr_name));
123   const auto& attr = node.attr().at(attr_name);
124   CHECK_EQ(attr.value_case(), AttrValue::kShape);
125   return attr.shape();
126 }
127 
GetTensorAttr(const NodeDef & node,const std::string & attr_name)128 const TensorProto& GetTensorAttr(const NodeDef& node,
129                                  const std::string& attr_name) {
130   CHECK(HasAttr(node, attr_name)) << "No attr named '" << attr_name << "'";
131   const auto& attr = node.attr().at(attr_name);
132   CHECK_EQ(attr.value_case(), AttrValue::kTensor);
133   return attr.tensor();
134 }
135 
GetListAttr(const NodeDef & node,const std::string & attr_name)136 const AttrValue::ListValue& GetListAttr(const NodeDef& node,
137                                         const std::string& attr_name) {
138   CHECK(HasAttr(node, attr_name));
139   const auto& attr = node.attr().at(attr_name);
140   CHECK_EQ(attr.value_case(), AttrValue::kList);
141   return attr.list();
142 }
143 
CheckOptionalAttr(const NodeDef & node,const std::string & attr_name,const std::string & expected_value)144 tensorflow::Status CheckOptionalAttr(const NodeDef& node,
145                                      const std::string& attr_name,
146                                      const std::string& expected_value) {
147   if (HasAttr(node, attr_name)) {
148     const std::string& value = GetStringAttr(node, attr_name);
149     if (value != expected_value) {
150       return tensorflow::errors::InvalidArgument(
151           "Unexpected value for attribute '" + attr_name + "'. Expected '" +
152           expected_value + "'");
153     }
154   }
155   return tensorflow::Status::OK();
156 }
157 
// Validates an optional DataType attribute. If `attr_name` is absent, returns
// OK. If it is present, its type must equal `expected_value`; otherwise this
// returns InvalidArgument naming the expected type.
tensorflow::Status CheckOptionalAttr(
    const NodeDef& node, const std::string& attr_name,
    const tensorflow::DataType& expected_value) {
  if (!HasAttr(node, attr_name)) {
    return tensorflow::Status::OK();
  }
  const tensorflow::DataType actual_type = GetDataTypeAttr(node, attr_name);
  if (actual_type == expected_value) {
    return tensorflow::Status::OK();
  }
  return tensorflow::errors::InvalidArgument(
      "Unexpected value for attribute '" + attr_name + "'. Expected '" +
      tensorflow::DataType_Name(expected_value) + "'");
}
171 
172 template <typename T1, typename T2>
ExpectValue(const T1 & v1,const T2 & v2,const std::string & description)173 tensorflow::Status ExpectValue(const T1& v1, const T2& v2,
174                                const std::string& description) {
175   if (v1 == v2) return tensorflow::Status::OK();
176   return tensorflow::errors::InvalidArgument(absl::StrCat(
177       "Unexpected ", description, ": got ", v1, ", expected ", v2));
178 }
179 
ConvertDataType(tensorflow::DataType dtype)180 ArrayDataType ConvertDataType(tensorflow::DataType dtype) {
181   if (dtype == DT_UINT8)
182     return ArrayDataType::kUint8;
183   else if (dtype == DT_FLOAT)
184     return ArrayDataType::kFloat;
185   else if (dtype == DT_BOOL)
186     return ArrayDataType::kBool;
187   else if (dtype == DT_INT32)
188     return ArrayDataType::kInt32;
189   else if (dtype == DT_UINT32)
190     return ArrayDataType::kUint32;
191   else if (dtype == DT_INT64)
192     return ArrayDataType::kInt64;
193   else if (dtype == DT_STRING)
194     return ArrayDataType::kString;
195   else if (dtype == DT_COMPLEX64)
196     return ArrayDataType::kComplex64;
197   else
198     LOG(INFO) << "Unsupported data type in placeholder op: " << dtype;
199   return ArrayDataType::kNone;
200 }
201 
ImportShape(const TFLITE_PROTO_NS::RepeatedPtrField<tensorflow::TensorShapeProto_Dim> & input_dims,int * input_flat_size,Shape * shape)202 tensorflow::Status ImportShape(
203     const TFLITE_PROTO_NS::RepeatedPtrField<tensorflow::TensorShapeProto_Dim>&
204         input_dims,
205     int* input_flat_size, Shape* shape) {
206   std::vector<int> input_dims_only_sizes;
207   bool zero_sized_shape = false;
208   for (auto& d : input_dims) {
209     // TensorFlow's shapes use int64s, while TOCO uses ints.
210     if (d.size() > std::numeric_limits<int>::max()) {
211       return tensorflow::errors::InvalidArgument("Shape element overflows");
212     }
213     if (d.size() == 0) {
214       zero_sized_shape = true;
215     }
216     input_dims_only_sizes.push_back(d.size());
217   }
218 
219   // Note that up to this point we were OK with the input shape containing
220   // elements valued -1 or 0, which are perfectly legal in tensorflow. However
221   // our CheckValidShapeDimensions() insists on them being >= 1, with the
222   // exception of the "scalar" shape [0]. The main issue with zero-values shape
223   // elements is that the corresponding arrays don't contain any data and the
224   // allocation code gets a bit confused. It seems that the code expects an
225   // empty shape for zero-sized shapes, so we will do just that, except for the
226   // [0] case.
227   // TODO(b/119325030): In order to correctly import the "scalar" shapes the
228   // following test must include "&& input_dims_only_sizes.size() > 1", but
229   // that seems to slow everything down a lot.
230   if (zero_sized_shape) {
231     shape->mutable_dims()->clear();
232     if (input_flat_size != nullptr) *input_flat_size = 0;
233     return tensorflow::Status::OK();
234   }
235 
236   *shape->mutable_dims() = input_dims_only_sizes;
237 
238   if (input_flat_size == nullptr) return tensorflow::Status::OK();
239 
240   return NumElements(input_dims_only_sizes, input_flat_size);
241 }
242 
// Per-element-type helpers for reading values out of a TensorProto. Each
// specialization provides the typed repeated field (`size`, `get`,
// `accessor_name`), a human-readable `type_name`, and `CopyFromContent` for
// the raw `tensor_content` bytes. The primary template is only declared; just
// the explicit specializations below exist.
// TODO(b/80208043): simply use tensorflow::Tensor::FromProto() instead.
template <class T>
struct TensorTraits;
247 
248 template <>
249 struct TensorTraits<float> {
sizetoco::__anon2e11eff70111::TensorTraits250   static int size(const TensorProto& p) { return p.float_val_size(); }
gettoco::__anon2e11eff70111::TensorTraits251   static float get(const TensorProto& p, int i) { return p.float_val(i); }
accessor_nametoco::__anon2e11eff70111::TensorTraits252   static std::string accessor_name() { return "float_val"; }
type_nametoco::__anon2e11eff70111::TensorTraits253   static std::string type_name() { return "float"; }
CopyFromContenttoco::__anon2e11eff70111::TensorTraits254   static void CopyFromContent(const TensorProto& p, std::vector<float>* data) {
255     toco::port::CopyToBuffer(p.tensor_content(),
256                              reinterpret_cast<char*>(data->data()));
257   }
258 };
259 
260 template <>
261 struct TensorTraits<uint8_t> {
sizetoco::__anon2e11eff70111::TensorTraits262   static int size(const TensorProto& p) { return p.int_val_size(); }
gettoco::__anon2e11eff70111::TensorTraits263   static uint8_t get(const TensorProto& p, int i) { return p.int_val(i); }
accessor_nametoco::__anon2e11eff70111::TensorTraits264   static std::string accessor_name() { return "int_val"; }
type_nametoco::__anon2e11eff70111::TensorTraits265   static std::string type_name() { return "uint8"; }
CopyFromContenttoco::__anon2e11eff70111::TensorTraits266   static void CopyFromContent(const TensorProto& p,
267                               std::vector<uint8_t>* data) {
268     toco::port::CopyToBuffer(p.tensor_content(),
269                              reinterpret_cast<char*>(data->data()));
270   }
271 };
272 
273 template <>
274 struct TensorTraits<std::complex<float>> {
sizetoco::__anon2e11eff70111::TensorTraits275   static int size(const TensorProto& p) { return p.scomplex_val_size() / 2; }
gettoco::__anon2e11eff70111::TensorTraits276   static std::complex<float> get(const TensorProto& p, int i) {
277     return std::complex<float>(p.scomplex_val(2 * i),
278                                p.scomplex_val(2 * i + 1));
279   }
accessor_nametoco::__anon2e11eff70111::TensorTraits280   static std::string accessor_name() { return "scomplex_val"; }
type_nametoco::__anon2e11eff70111::TensorTraits281   static std::string type_name() { return "complex64"; }
CopyFromContenttoco::__anon2e11eff70111::TensorTraits282   static void CopyFromContent(const TensorProto& p,
283                               std::vector<std::complex<float>>* data) {
284     toco::port::CopyToBuffer(p.tensor_content(),
285                              reinterpret_cast<char*>(data->data()));
286   }
287 };
288 
289 template <>
290 struct TensorTraits<int32> {
sizetoco::__anon2e11eff70111::TensorTraits291   static int size(const TensorProto& p) { return p.int_val_size(); }
gettoco::__anon2e11eff70111::TensorTraits292   static int32 get(const TensorProto& p, int i) { return p.int_val(i); }
accessor_nametoco::__anon2e11eff70111::TensorTraits293   static std::string accessor_name() { return "int_val"; }
type_nametoco::__anon2e11eff70111::TensorTraits294   static std::string type_name() { return "int32"; }
CopyFromContenttoco::__anon2e11eff70111::TensorTraits295   static void CopyFromContent(const TensorProto& p, std::vector<int32>* data) {
296     toco::port::CopyToBuffer(p.tensor_content(),
297                              reinterpret_cast<char*>(data->data()));
298   }
299 };
300 
301 template <>
302 struct TensorTraits<uint32> {
sizetoco::__anon2e11eff70111::TensorTraits303   static int size(const TensorProto& p) { return p.uint32_val_size(); }
gettoco::__anon2e11eff70111::TensorTraits304   static int32 get(const TensorProto& p, int i) { return p.uint32_val(i); }
accessor_nametoco::__anon2e11eff70111::TensorTraits305   static std::string accessor_name() { return "uint32_val"; }
type_nametoco::__anon2e11eff70111::TensorTraits306   static std::string type_name() { return "uint32"; }
CopyFromContenttoco::__anon2e11eff70111::TensorTraits307   static void CopyFromContent(const TensorProto& p, std::vector<uint32>* data) {
308     toco::port::CopyToBuffer(p.tensor_content(),
309                              reinterpret_cast<char*>(data->data()));
310   }
311 };
312 
313 template <>
314 struct TensorTraits<int64> {
sizetoco::__anon2e11eff70111::TensorTraits315   static int size(const TensorProto& p) { return p.int64_val_size(); }
gettoco::__anon2e11eff70111::TensorTraits316   static int64 get(const TensorProto& p, int i) { return p.int64_val(i); }
accessor_nametoco::__anon2e11eff70111::TensorTraits317   static std::string accessor_name() { return "int64_val"; }
type_nametoco::__anon2e11eff70111::TensorTraits318   static std::string type_name() { return "int64"; }
CopyFromContenttoco::__anon2e11eff70111::TensorTraits319   static void CopyFromContent(const TensorProto& p, std::vector<int64>* data) {
320     toco::port::CopyToBuffer(p.tensor_content(),
321                              reinterpret_cast<char*>(data->data()));
322   }
323 };
324 
325 template <>
326 struct TensorTraits<bool> {
sizetoco::__anon2e11eff70111::TensorTraits327   static int size(const TensorProto& p) { return p.bool_val_size(); }
gettoco::__anon2e11eff70111::TensorTraits328   static bool get(const TensorProto& p, int i) { return p.bool_val(i); }
accessor_nametoco::__anon2e11eff70111::TensorTraits329   static std::string accessor_name() { return "bool_val"; }
type_nametoco::__anon2e11eff70111::TensorTraits330   static std::string type_name() { return "bool"; }
CopyFromContenttoco::__anon2e11eff70111::TensorTraits331   static void CopyFromContent(const TensorProto& p, std::vector<bool>* data) {
332     std::vector<char> buf(p.tensor_content().size());
333     toco::port::CopyToBuffer(p.tensor_content(), buf.data());
334     for (int i = 0; i < p.tensor_content().size(); i++) {
335       (*data)[i] = static_cast<bool>(buf[i]);
336     }
337   }
338 };
339 
340 template <typename T>
ImportTensorData(const TensorProto & input_tensor,int input_flat_size,std::vector<T> * output_data)341 tensorflow::Status ImportTensorData(const TensorProto& input_tensor,
342                                     int input_flat_size,
343                                     std::vector<T>* output_data) {
344   CHECK_GE(output_data->size(), input_flat_size);
345   int num_elements_in_tensor = TensorTraits<T>::size(input_tensor);
346   if (num_elements_in_tensor == input_flat_size) {
347     for (int i = 0; i < num_elements_in_tensor; i++) {
348       (*output_data)[i] = TensorTraits<T>::get(input_tensor, i);
349     }
350   } else if (input_tensor.tensor_content().size() ==
351              input_flat_size * sizeof(T)) {
352     TensorTraits<T>::CopyFromContent(input_tensor, output_data);
353   } else if (num_elements_in_tensor > 0 &&
354              num_elements_in_tensor < input_flat_size) {
355     // TODO(b/80208043): use tensorflow::Tensor::FromProto() which is the
356     // official way to import tensor data. This particular else-if handles a
357     // grappler optimization where the last few elements in a tensor are
358     // omitted if they are repeated.
359     int i = 0;
360     for (; i < num_elements_in_tensor; ++i) {
361       (*output_data)[i] = TensorTraits<T>::get(input_tensor, i);
362     }
363     auto last = (*output_data)[i - 1];
364     for (; i < input_flat_size; ++i) {
365       (*output_data)[i] = last;
366     }
367   } else {
368     std::string accessor_name = TensorTraits<T>::accessor_name();
369     std::string type_name = TensorTraits<T>::type_name();
370     return tensorflow::errors::InvalidArgument(
371         absl::StrCat("Neither input_content (",
372                      input_tensor.tensor_content().size() / sizeof(T), ") nor ",
373                      accessor_name, " (", num_elements_in_tensor,
374                      ") have the right dimensions (", input_flat_size,
375                      ") for this ", type_name, " tensor"));
376   }
377   return tensorflow::Status::OK();
378 }
379 
ImportFloatArray(const TensorProto & input_tensor,Array * output_array)380 tensorflow::Status ImportFloatArray(const TensorProto& input_tensor,
381                                     Array* output_array) {
382   CHECK_EQ(input_tensor.dtype(), DT_FLOAT);
383   const auto& input_shape = input_tensor.tensor_shape();
384   CHECK_LE(input_shape.dim_size(), 6);
385   int input_flat_size;
386   auto status = ImportShape(input_shape.dim(), &input_flat_size,
387                             output_array->mutable_shape());
388   if (!status.ok()) return status;
389 
390   auto& output_float_data =
391       output_array->GetMutableBuffer<ArrayDataType::kFloat>().data;
392   output_float_data.resize(RequiredBufferSizeForShape(output_array->shape()),
393                            0.f);
394   return ImportTensorData<float>(input_tensor, input_flat_size,
395                                  &output_float_data);
396 }
397 
ImportComplex64Array(const TensorProto & input_tensor,Array * output_array)398 tensorflow::Status ImportComplex64Array(const TensorProto& input_tensor,
399                                         Array* output_array) {
400   CHECK_EQ(input_tensor.dtype(), DT_COMPLEX64);
401   const auto& input_shape = input_tensor.tensor_shape();
402   CHECK_LE(input_shape.dim_size(), 4);
403   int input_flat_size;
404   auto status = ImportShape(input_shape.dim(), &input_flat_size,
405                             output_array->mutable_shape());
406   if (!status.ok()) return status;
407 
408   auto& output_complex_data =
409       output_array->GetMutableBuffer<ArrayDataType::kComplex64>().data;
410   output_complex_data.resize(RequiredBufferSizeForShape(output_array->shape()),
411                              std::complex<float>(0.f, 0.f));
412   return ImportTensorData<std::complex<float>>(input_tensor, input_flat_size,
413                                                &output_complex_data);
414 }
415 
ImportQuint8Array(const TensorProto & input_tensor,Array * output_array)416 tensorflow::Status ImportQuint8Array(const TensorProto& input_tensor,
417                                      Array* output_array) {
418   CHECK_EQ(input_tensor.dtype(), DT_QUINT8);
419   const auto& input_shape = input_tensor.tensor_shape();
420   CHECK_LE(input_shape.dim_size(), 6);
421   int input_flat_size;
422   auto status = ImportShape(input_shape.dim(), &input_flat_size,
423                             output_array->mutable_shape());
424   if (!status.ok()) return status;
425 
426   auto& output_int_data =
427       output_array->GetMutableBuffer<ArrayDataType::kUint8>().data;
428   output_int_data.resize(RequiredBufferSizeForShape(output_array->shape()), 0);
429   return ImportTensorData<uint8_t>(input_tensor, input_flat_size,
430                                    &output_int_data);
431 }
432 
ImportInt32Array(const TensorProto & input_tensor,Array * output_array)433 tensorflow::Status ImportInt32Array(const TensorProto& input_tensor,
434                                     Array* output_array) {
435   CHECK_EQ(input_tensor.dtype(), DT_INT32);
436   const auto& input_shape = input_tensor.tensor_shape();
437   CHECK_LE(input_shape.dim_size(), 6);
438   int input_flat_size;
439   auto status = ImportShape(input_shape.dim(), &input_flat_size,
440                             output_array->mutable_shape());
441   if (!status.ok()) return status;
442 
443   auto& output_int_data =
444       output_array->GetMutableBuffer<ArrayDataType::kInt32>().data;
445   output_int_data.resize(RequiredBufferSizeForShape(output_array->shape()), 0);
446   return ImportTensorData<int32>(input_tensor, input_flat_size,
447                                  &output_int_data);
448 }
449 
ImportUint32Array(const TensorProto & input_tensor,Array * output_array)450 tensorflow::Status ImportUint32Array(const TensorProto& input_tensor,
451                                      Array* output_array) {
452   CHECK_EQ(input_tensor.dtype(), DT_UINT32);
453   const auto& input_shape = input_tensor.tensor_shape();
454   CHECK_LE(input_shape.dim_size(), 6);
455   int input_flat_size;
456   auto status = ImportShape(input_shape.dim(), &input_flat_size,
457                             output_array->mutable_shape());
458   if (!status.ok()) return status;
459 
460   auto& output_int_data =
461       output_array->GetMutableBuffer<ArrayDataType::kUint32>().data;
462   output_int_data.resize(RequiredBufferSizeForShape(output_array->shape()), 0);
463   return ImportTensorData<uint32>(input_tensor, input_flat_size,
464                                   &output_int_data);
465 }
466 
ImportInt64Array(const TensorProto & input_tensor,Array * output_array)467 tensorflow::Status ImportInt64Array(const TensorProto& input_tensor,
468                                     Array* output_array) {
469   CHECK_EQ(input_tensor.dtype(), DT_INT64);
470   const auto& input_shape = input_tensor.tensor_shape();
471   CHECK_LE(input_shape.dim_size(), 6);
472   int input_flat_size;
473   auto status = ImportShape(input_shape.dim(), &input_flat_size,
474                             output_array->mutable_shape());
475   if (!status.ok()) return status;
476 
477   auto& output_int_data =
478       output_array->GetMutableBuffer<ArrayDataType::kInt64>().data;
479   output_int_data.resize(RequiredBufferSizeForShape(output_array->shape()), 0);
480   return ImportTensorData<int64>(input_tensor, input_flat_size,
481                                  &output_int_data);
482 }
483 
ImportBoolArray(const TensorProto & input_tensor,Array * output_array)484 tensorflow::Status ImportBoolArray(const TensorProto& input_tensor,
485                                    Array* output_array) {
486   CHECK_EQ(input_tensor.dtype(), DT_BOOL);
487   const auto& input_shape = input_tensor.tensor_shape();
488   CHECK_LE(input_shape.dim_size(), 6);
489   int input_flat_size;
490   auto status = ImportShape(input_shape.dim(), &input_flat_size,
491                             output_array->mutable_shape());
492   if (!status.ok()) return status;
493 
494   auto& output_bool_data =
495       output_array->GetMutableBuffer<ArrayDataType::kBool>().data;
496   output_bool_data.resize(RequiredBufferSizeForShape(output_array->shape()),
497                           false);
498   status =
499       ImportTensorData<bool>(input_tensor, input_flat_size, &output_bool_data);
500   if (!status.ok() && output_bool_data.size() == 1) {
501     // Some graphs have bool const nodes without actual value...
502     // assuming that 'false' is implied.
503     // So far only encountered that in an array with 1 entry, let's
504     // require that until we encounter a graph where that's not the case.
505     output_bool_data[0] = false;
506     return tensorflow::Status::OK();
507   }
508   return status;
509 }
510 
ImportStringArray(const TensorProto & input_tensor,Array * output_array)511 tensorflow::Status ImportStringArray(const TensorProto& input_tensor,
512                                      Array* output_array) {
513   CHECK_EQ(input_tensor.dtype(), DT_STRING);
514   const auto& input_shape = input_tensor.tensor_shape();
515   CHECK_LE(input_shape.dim_size(), 6);
516   int input_flat_size;
517   auto status = ImportShape(input_shape.dim(), &input_flat_size,
518                             output_array->mutable_shape());
519   if (!status.ok()) return status;
520 
521   if (input_flat_size != input_tensor.string_val_size()) {
522     return tensorflow::errors::InvalidArgument(
523         "Input_content string_val doesn't have the right dimensions "
524         "for this string tensor");
525   }
526 
527   auto& output_string_data =
528       output_array->GetMutableBuffer<ArrayDataType::kString>().data;
529   output_string_data.resize(RequiredBufferSizeForShape(output_array->shape()));
530   CHECK_GE(output_string_data.size(), input_flat_size);
531   for (int i = 0; i < input_flat_size; ++i) {
532     output_string_data[i] = input_tensor.string_val(i);
533   }
534   return tensorflow::Status::OK();
535 }
536 
537 // Count the number of inputs of a given node. If
538 // `tf_import_flags.drop_control_dependency` is true, count the number of
539 // non-control-dependency inputs.
GetInputsCount(const NodeDef & node,const TensorFlowImportFlags & tf_import_flags)540 int GetInputsCount(const NodeDef& node,
541                    const TensorFlowImportFlags& tf_import_flags) {
542   if (tf_import_flags.drop_control_dependency) {
543     for (size_t i = 0; i < node.input_size(); ++i) {
544       if (node.input(i)[0] == '^') {
545         return i;
546       }
547     }
548   }
549   return node.input_size();
550 }
551 
CheckInputsCount(const NodeDef & node,const TensorFlowImportFlags & tf_import_flags,int expected_input_count)552 tensorflow::Status CheckInputsCount(
553     const NodeDef& node, const TensorFlowImportFlags& tf_import_flags,
554     int expected_input_count) {
555   if (GetInputsCount(node, tf_import_flags) != expected_input_count) {
556     return tensorflow::errors::FailedPrecondition(
557         node.op(), " node expects ", expected_input_count,
558         " input(s) other than control dependencies: ", node.DebugString());
559   }
560   return tensorflow::Status::OK();
561 }
562 
563 template <ArrayDataType T>
CreateConstArray(Model * model,std::string const & name,std::vector<typename toco::DataType<T>> const & data)564 std::string CreateConstArray(
565     Model* model, std::string const& name,
566     std::vector<typename toco::DataType<T>> const& data) {
567   // Utility function to create a const 1D array, useful for input parameters.
568   std::string array_name = toco::AvailableArrayName(*model, name);
569   auto& array = model->GetOrCreateArray(array_name);
570   array.data_type = T;
571   array.mutable_shape()->mutable_dims()->emplace_back(
572       static_cast<int>(data.size()));
573   array.GetMutableBuffer<T>().data = data;
574   return array_name;
575 }
576 
577 // Retain TensorFlow NodeDef in Toco Operator.
578 //
579 // If an op is supported by Toco but not supported by TFLite, TFLite exporter
580 // will use the retained NodeDef to populate a Flex op when Flex mode is
581 // enabled.
582 //
583 // This can't be easily applied to all operations, because a TensorFlow node
584 // may become multiple Toco operators. Thus we need to call this function in
585 // operator conversion functions one by one whenever feasible.
586 //
587 // This may cause problems if a graph transformation rule changes parameters
588 // of the node. When calling this function, please check if any existing
589 // graph transformation rule will change an existing operator with the same
590 // type.
591 //
592 // This provides a route to handle Toco-supported & TFLite-unsupported ops
593 // in Flex mode. However it's not a solid solution. Eventually we should
594 // get rid of this.
595 // TODO(b/117327937): Implement all Toco-supported ops in TFLite, and remove
596 // this function.
RetainTensorFlowNodeDef(const NodeDef & node,Operator * op)597 void RetainTensorFlowNodeDef(const NodeDef& node, Operator* op) {
598   node.SerializeToString(&op->tensorflow_node_def);
599 }
600 
GetOutputNamesFromNodeDef(const NodeDef & node,const tensorflow::OpDef & op_def,TensorFlowUnsupportedOperator * op)601 void GetOutputNamesFromNodeDef(const NodeDef& node,
602                                const tensorflow::OpDef& op_def,
603                                TensorFlowUnsupportedOperator* op) {
604   int next_output = 0;
605   auto add_output = [&node, &next_output, op]() {
606     if (next_output == 0) {
607       op->outputs.push_back(node.name());  // Implicit :0.
608     } else {
609       op->outputs.push_back(absl::StrCat(node.name(), ":", next_output));
610     }
611     ++next_output;
612   };
613   for (int i = 0; i < op_def.output_arg_size(); ++i) {
614     std::string multiples = op_def.output_arg(i).number_attr();
615     if (!multiples.empty()) {
616       CHECK(HasAttr(node, multiples)) << "No attr named " << multiples;
617       int num_outputs = GetIntAttr(node, multiples);
618       for (int j = 0; j < num_outputs; ++j) {
619         add_output();
620       }
621     } else {
622       std::string list = op_def.output_arg(i).type_list_attr();
623       if (!list.empty()) {
624         CHECK(HasAttr(node, list)) << "No attr named " << list;
625         const AttrValue::ListValue& list_value = GetListAttr(node, list);
626         for (int j = 0; j < list_value.type_size(); ++j) {
627           add_output();
628         }
629       } else {
630         add_output();
631       }
632     }
633   }
634 }
635 
GetOutputTypesFromNodeDef(const NodeDef & node,const tensorflow::OpDef & op_def,TensorFlowUnsupportedOperator * op)636 void GetOutputTypesFromNodeDef(const NodeDef& node,
637                                const tensorflow::OpDef& op_def,
638                                TensorFlowUnsupportedOperator* op) {
639   // The given type to the op, or clear the types if invalid.
640   auto add_type = [&node, op](tensorflow::DataType type) {
641     if (type == tensorflow::DT_INVALID) {
642       LOG(WARNING) << "Op node missing output type attribute: " << node.name();
643       op->output_data_types.clear();
644     } else {
645       op->output_data_types.push_back(ConvertDataType(type));
646     }
647   };
648 
649   // Retrieve the data type according to the OpDef definition: either the
650   // "type" or "type_attr" field will be set.
651   auto get_type = [&node](const tensorflow::OpDef::ArgDef& a) {
652     if (a.type() != tensorflow::DT_INVALID) {
653       return a.type();
654     } else if (HasAttr(node, a.type_attr())) {
655       return GetDataTypeAttr(node, a.type_attr());
656     } else {
657       return tensorflow::DT_INVALID;
658     }
659   };
660 
661   for (int i = 0; i < op_def.output_arg_size(); ++i) {
662     std::string multiples = op_def.output_arg(i).number_attr();
663     if (!multiples.empty()) {
664       CHECK(HasAttr(node, multiples)) << "No attr named " << multiples;
665       int num_outputs = GetIntAttr(node, multiples);
666       auto type = get_type(op_def.output_arg(i));
667       for (int j = 0; j < num_outputs; ++j) {
668         add_type(type);
669       }
670     } else {
671       std::string list = op_def.output_arg(i).type_list_attr();
672       if (!list.empty()) {
673         CHECK(HasAttr(node, list)) << "No attr named " << list;
674         const AttrValue::ListValue& list_value = GetListAttr(node, list);
675         for (int j = 0; j < list_value.type_size(); ++j) {
676           add_type(list_value.type(j));
677         }
678       } else {
679         add_type(get_type(op_def.output_arg(i)));
680       }
681     }
682   }
683 }
684 
ConvertUnsupportedOperator(const NodeDef & node,const TensorFlowImportFlags & tf_import_flags,const ModelFlags & model_flags,Model * model)685 tensorflow::Status ConvertUnsupportedOperator(
686     const NodeDef& node, const TensorFlowImportFlags& tf_import_flags,
687     const ModelFlags& model_flags, Model* model) {
688   // Names of special attributes in TF graph that are used by Toco.
689   static constexpr char kAttrOutputQuantized[] = "_output_quantized";
690   static constexpr char kAttrOutputTypes[] = "_output_types";
691   static constexpr char kAttrOutputShapes[] = "_output_shapes";
692   static constexpr char kAttrSupportOutputTypeFloatInQuantizedOp[] =
693       "_support_output_type_float_in_quantized_op";
694 
695   LOG(INFO) << "Converting unsupported operation: " << node.op();
696 
697   auto* op = new TensorFlowUnsupportedOperator;
698   op->tensorflow_op = node.op();
699 
700   // For Flex mode. Please read the comments of the function.
701   RetainTensorFlowNodeDef(node, op);
702 
703   model->operators.emplace_back(op);
704 
705   // Parse inputs.
706   const int num_inputs = GetInputsCount(node, tf_import_flags);
707   for (int i = 0; i < num_inputs; ++i) {
708     op->inputs.push_back(node.input(i));
709   }
710 
711   // Parse outputs. Name them after the node's name, plus an ordinal suffix.
712   // Note that some outputs are to be multiplied by a named attribute.
713   const tensorflow::OpDef* op_def = nullptr;
714   if (tensorflow::OpRegistry::Global()->LookUpOpDef(node.op(), &op_def).ok()) {
715     GetOutputNamesFromNodeDef(node, *op_def, op);
716   } else {
717     op->outputs.push_back(node.name());  // Implicit :0.
718   }
719 
720   // Parse if the op supports quantization
721   if (HasAttr(node, kAttrOutputQuantized)) {
722     op->quantized = GetBoolAttr(node, kAttrOutputQuantized);
723   }
724   // Parse if the quantized op allows output arrays of type float
725   if (HasAttr(node, kAttrSupportOutputTypeFloatInQuantizedOp)) {
726     op->support_output_type_float_in_quantized_op =
727         GetBoolAttr(node, kAttrSupportOutputTypeFloatInQuantizedOp);
728   }
729 
730   // Parse output type(s).
731   if (HasAttr(node, kAttrOutputTypes)) {
732     const auto& output_types = GetListAttr(node, kAttrOutputTypes);
733     for (int i = 0; i < output_types.type_size(); ++i) {
734       op->output_data_types.push_back(ConvertDataType(output_types.type(i)));
735     }
736   } else if (HasAttr(node, "Tout")) {
737     const auto& output_type = GetDataTypeAttr(node, "Tout");
738     op->output_data_types.push_back(ConvertDataType(output_type));
739   } else if (op_def != nullptr) {
740     GetOutputTypesFromNodeDef(node, *op_def, op);
741   } else {
742     // TODO(b/113613439): Figure out how to propagate types for custom ops
743     // that have no OpDef.
744     LOG(INFO) << "Unable to determine output type for op: " << node.op();
745   }
746 
747   // Parse output shape(s).
748   if (HasAttr(node, kAttrOutputShapes)) {
749     const auto& output_shapes = GetListAttr(node, kAttrOutputShapes);
750     Shape output_shape;
751     for (int i = 0; i < output_shapes.shape_size(); ++i) {
752       const auto& shape = output_shapes.shape(i);
753       // TOCO doesn't yet properly handle shapes with wildcard dimensions.
754       // TODO(b/113613439): Handle shape inference for unsupported ops that have
755       // shapes with wildcard dimensions.
756       if (HasWildcardDimension(shape)) {
757         LOG(INFO) << "Skipping wildcard output shape(s) for node: "
758                   << node.name();
759         op->output_shapes.clear();
760         break;
761       }
762       const auto status =
763           ImportShape(shape.dim(), /*input_flat_size=*/nullptr, &output_shape);
764       if (!status.ok()) {
765         return status;
766       }
767       op->output_shapes.push_back(output_shape);
768     }
769   }
770   return tensorflow::Status::OK();
771 }
772 
ConvertConstOperator(const NodeDef & node,const TensorFlowImportFlags & tf_import_flags,const ModelFlags & model_flags,Model * model)773 tensorflow::Status ConvertConstOperator(
774     const NodeDef& node, const TensorFlowImportFlags& tf_import_flags,
775     const ModelFlags& model_flags, Model* model) {
776   CHECK_EQ(node.op(), "Const");
777   const auto& tensor = GetTensorAttr(node, "value");
778   const auto dtype = GetDataTypeAttr(node, "dtype");
779 
780   tensorflow::Status status = tensorflow::Status::OK();
781 
782   auto& array = model->GetOrCreateArray(node.name());
783   switch (dtype) {
784     case DT_FLOAT:
785       array.data_type = ArrayDataType::kFloat;
786       status = ImportFloatArray(tensor, &array);
787       break;
788     case DT_INT32:
789       array.data_type = ArrayDataType::kInt32;
790       status = ImportInt32Array(tensor, &array);
791       break;
792     case DT_UINT32:
793       array.data_type = ArrayDataType::kUint32;
794       status = ImportUint32Array(tensor, &array);
795       break;
796     case DT_QUINT8:
797       array.data_type = ArrayDataType::kUint8;
798       status = ImportQuint8Array(tensor, &array);
799       break;
800     case DT_INT64:
801       array.data_type = ArrayDataType::kInt64;
802       status = ImportInt64Array(tensor, &array);
803       break;
804     case DT_STRING:
805       array.data_type = ArrayDataType::kString;
806       status = ImportStringArray(tensor, &array);
807       break;
808     case DT_BOOL:
809       array.data_type = ArrayDataType::kBool;
810       status = ImportBoolArray(tensor, &array);
811       break;
812     case DT_COMPLEX64:
813       array.data_type = ArrayDataType::kComplex64;
814       status = ImportComplex64Array(tensor, &array);
815       break;
816     default:
817       array.data_type = ArrayDataType::kNone;
818       // do nothing, silently ignore the Const data.
819       // We just make a dummy buffer to indicate that
820       // this array does not rely on external input.
821       array.GetMutableBuffer<ArrayDataType::kNone>();
822       break;
823   }
824   TF_RETURN_WITH_CONTEXT_IF_ERROR(
825       status, " (while processing node '" + node.name() + "')");
826   return tensorflow::Status::OK();
827 }
828 
ConvertConvOperator(const NodeDef & node,const TensorFlowImportFlags & tf_import_flags,const ModelFlags & model_flags,Model * model)829 tensorflow::Status ConvertConvOperator(
830     const NodeDef& node, const TensorFlowImportFlags& tf_import_flags,
831     const ModelFlags& model_flags, Model* model) {
832   CHECK_EQ(node.op(), "Conv2D");
833   TF_RETURN_IF_ERROR(CheckInputsCount(node, tf_import_flags, 2));
834 
835   // We only support NHWC, which is the default data_format.
836   // So if data_format is not defined, we're all good.
837   TF_RETURN_IF_ERROR(CheckOptionalAttr(node, "data_format", "NHWC"));
838   TF_RETURN_IF_ERROR(CheckOptionalAttr(node, "T", DT_FLOAT));
839 
840   const auto& input_name = node.input(0);
841   const auto& weights_name = node.input(1);
842   const auto& reordered_weights_name =
843       AvailableArrayName(*model, weights_name + "_reordered");
844   // Check if a ReorderAxesOperator was already created for these weights
845   // (that happens when multiple layers share the same weights).
846   const Operator* existing_reorder =
847       GetOpWithOutput(*model, reordered_weights_name);
848   if (existing_reorder) {
849     // Check that it is safe to rely on the _reordered naming of the output
850     // array!
851     CHECK(existing_reorder->type == OperatorType::kReorderAxes);
852   } else {
853     // Create a new ReorderAxesOperator
854     auto* reorder = new ReorderAxesOperator;
855     reorder->inputs = {weights_name};
856     reorder->outputs = {reordered_weights_name};
857     reorder->input_axes_order = AxesOrder::kHWIO;
858     reorder->output_axes_order = AxesOrder::kOHWI;
859     model->operators.emplace_back(reorder);
860   }
861   if (!HasAttr(node, "strides")) {
862     return tensorflow::errors::InvalidArgument("Missing attribute 'strides'");
863   }
864   const auto& strides = GetListAttr(node, "strides");
865   TF_RETURN_IF_ERROR(ExpectValue(strides.i_size(), 4, "number of strides"));
866   TF_RETURN_IF_ERROR(ExpectValue(strides.i(0), 1, "strides(0)"));
867   TF_RETURN_IF_ERROR(ExpectValue(strides.i(3), 1, "strides(3)"));
868   int dilation_height_factor;
869   int dilation_width_factor;
870   if (HasAttr(node, "dilations")) {
871     const auto& dilations = GetListAttr(node, "dilations");
872     TF_RETURN_IF_ERROR(
873         ExpectValue(dilations.i_size(), 4, "number of dilations"));
874     if (dilations.i(0) != 1 || dilations.i(3) != 1) {
875       return tensorflow::errors::InvalidArgument(absl::StrCat(
876           "Can only import Conv ops with dilation along the height "
877           "(1st) or width (2nd) axis. TensorFlow op \"",
878           node.name(), "\" had dilations:[ ", dilations.i(0), ", ",
879           dilations.i(1), ", ", dilations.i(2), ", ", dilations.i(3), "]."));
880     }
881     dilation_height_factor = dilations.i(1);
882     dilation_width_factor = dilations.i(2);
883   } else {
884     dilation_height_factor = 1;
885     dilation_width_factor = 1;
886   }
887   const auto& padding = GetStringAttr(node, "padding");
888   PaddingType padding_type;
889   if (padding == "SAME") {
890     padding_type = PaddingType::kSame;
891   } else if (padding == "VALID") {
892     padding_type = PaddingType::kValid;
893   } else {
894     return tensorflow::errors::InvalidArgument(
895         "Bad padding (only SAME and VALID are supported)");
896   }
897   auto* conv = new ConvOperator;
898   conv->inputs = {input_name, reordered_weights_name};
899   conv->outputs = {node.name()};
900   conv->stride_height = strides.i(1);
901   conv->stride_width = strides.i(2);
902   conv->dilation_height_factor = dilation_height_factor;
903   conv->dilation_width_factor = dilation_width_factor;
904   conv->padding.type = padding_type;
905   model->operators.emplace_back(conv);
906 
907   return tensorflow::Status::OK();
908 }
909 
ConvertDepthwiseConvOperator(const NodeDef & node,const TensorFlowImportFlags & tf_import_flags,const ModelFlags & model_flags,Model * model)910 tensorflow::Status ConvertDepthwiseConvOperator(
911     const NodeDef& node, const TensorFlowImportFlags& tf_import_flags,
912     const ModelFlags& model_flags, Model* model) {
913   CHECK_EQ(node.op(), "DepthwiseConv2dNative");
914   TF_QCHECK_OK(CheckInputsCount(node, tf_import_flags, 2));
915 
916   // We only support NHWC, which is the default data_format.
917   // So if data_format is not defined, we're all good.
918   if (HasAttr(node, "data_format")) {
919     CHECK_EQ(GetStringAttr(node, "data_format"), "NHWC");
920   }
921   CHECK_EQ(GetDataTypeAttr(node, "T"), DT_FLOAT);
922 
923   const auto& input_name = node.input(0);
924   const auto& weights_name = node.input(1);
925   const auto& reordered_weights_name = weights_name + "_reordered";
926   // Check if a ReorderAxesOperator was already created for these weights
927   // (that happens when multiple layers share the same weights).
928   const Operator* existing_reorder =
929       GetOpWithOutput(*model, reordered_weights_name);
930   if (existing_reorder) {
931     // Check that it is safe to rely on the _reordered naming of the output
932     // array!
933     CHECK(existing_reorder->type == OperatorType::kReorderAxes);
934   } else {
935     // Create a new ReorderAxesOperator
936     auto* reorder = new ReorderAxesOperator;
937     reorder->inputs = {weights_name};
938     reorder->outputs = {reordered_weights_name};
939     reorder->input_axes_order = AxesOrder::kHWIM;
940     reorder->output_axes_order = AxesOrder::k1HWO;
941     model->operators.emplace_back(reorder);
942   }
943   const auto& strides = GetListAttr(node, "strides");
944   TF_RETURN_IF_ERROR(ExpectValue(strides.i_size(), 4, "number of strides"));
945   TF_RETURN_IF_ERROR(ExpectValue(strides.i(0), 1, "strides(0)"));
946   TF_RETURN_IF_ERROR(ExpectValue(strides.i(3), 1, "strides(3)"));
947   int dilation_height_factor;
948   int dilation_width_factor;
949   if (HasAttr(node, "dilations")) {
950     const auto& dilations = GetListAttr(node, "dilations");
951     TF_RETURN_IF_ERROR(
952         ExpectValue(dilations.i_size(), 4, "number of dilations"));
953     if (dilations.i(0) != 1 || dilations.i(3) != 1) {
954       return tensorflow::errors::InvalidArgument(absl::StrCat(
955           "Can only import Conv ops with dilation along the height "
956           "(1st) or width (2nd) axis. TensorFlow op \"",
957           node.name(), "\" had dilations:[ ", dilations.i(0), ", ",
958           dilations.i(1), ", ", dilations.i(2), ", ", dilations.i(3), "]."));
959     }
960     dilation_height_factor = dilations.i(1);
961     dilation_width_factor = dilations.i(2);
962   } else {
963     dilation_height_factor = 1;
964     dilation_width_factor = 1;
965   }
966   const auto& padding = GetStringAttr(node, "padding");
967   PaddingType padding_type;
968   if (padding == "SAME") {
969     padding_type = PaddingType::kSame;
970   } else if (padding == "VALID") {
971     padding_type = PaddingType::kValid;
972   } else {
973     return tensorflow::errors::InvalidArgument(
974         "Bad padding (only SAME and VALID are supported)");
975   }
976   auto* conv = new DepthwiseConvOperator;
977   conv->inputs = {input_name, reordered_weights_name};
978   conv->outputs = {node.name()};
979   conv->stride_height = strides.i(1);
980   conv->stride_width = strides.i(2);
981   conv->dilation_height_factor = dilation_height_factor;
982   conv->dilation_width_factor = dilation_width_factor;
983   conv->padding.type = padding_type;
984   model->operators.emplace_back(conv);
985   return tensorflow::Status::OK();
986 }
987 
ConvertDepthToSpaceOperator(const NodeDef & node,const TensorFlowImportFlags & tf_import_flags,const ModelFlags & model_flags,Model * model)988 tensorflow::Status ConvertDepthToSpaceOperator(
989     const NodeDef& node, const TensorFlowImportFlags& tf_import_flags,
990     const ModelFlags& model_flags, Model* model) {
991   CHECK_EQ(node.op(), "DepthToSpace");
992   TF_QCHECK_OK(CheckInputsCount(node, tf_import_flags, 1));
993 
994   tensorflow::DataType dtype = GetDataTypeAttr(node, "T");
995   if (dtype != DT_FLOAT && dtype != DT_UINT8 && dtype != DT_INT32 &&
996       dtype != DT_INT64) {
997     const auto* enum_descriptor = tensorflow::DataType_descriptor();
998     LOG(FATAL) << "TFLite does not support DepthToSpace with type T:"
999                << enum_descriptor->FindValueByNumber(dtype)->name() << ". "
1000                << "T must be one of {DT_FLOAT, DT_UINT8, DT_INT32, DT_INT64}.";
1001   }
1002   auto* op = new DepthToSpaceOperator;
1003   op->inputs.push_back(node.input(0));
1004   op->outputs.push_back(node.name());
1005   op->block_size = GetIntAttr(node, "block_size");
1006   QCHECK_GE(op->block_size, 2);
1007   model->operators.emplace_back(op);
1008   return tensorflow::Status::OK();
1009 }
1010 
ConvertSpaceToDepthOperator(const NodeDef & node,const TensorFlowImportFlags & tf_import_flags,const ModelFlags & model_flags,Model * model)1011 tensorflow::Status ConvertSpaceToDepthOperator(
1012     const NodeDef& node, const TensorFlowImportFlags& tf_import_flags,
1013     const ModelFlags& model_flags, Model* model) {
1014   CHECK_EQ(node.op(), "SpaceToDepth");
1015   TF_QCHECK_OK(CheckInputsCount(node, tf_import_flags, 1));
1016 
1017   tensorflow::DataType dtype = GetDataTypeAttr(node, "T");
1018   if (dtype != DT_FLOAT && dtype != DT_UINT8 && dtype != DT_INT32 &&
1019       dtype != DT_INT64) {
1020     const auto* enum_descriptor = tensorflow::DataType_descriptor();
1021     LOG(FATAL) << "TFLite does not support SpaceToDepth with type T:"
1022                << enum_descriptor->FindValueByNumber(dtype)->name() << ". "
1023                << "T must be one of {DT_FLOAT, DT_UINT8, DT_INT32, DT_INT64}.";
1024   }
1025   auto* op = new SpaceToDepthOperator;
1026   op->inputs.push_back(node.input(0));
1027   op->outputs.push_back(node.name());
1028   op->block_size = GetIntAttr(node, "block_size");
1029   QCHECK_GE(op->block_size, 2);
1030   model->operators.emplace_back(op);
1031   return tensorflow::Status::OK();
1032 }
1033 
ConvertBiasAddOperator(const NodeDef & node,const TensorFlowImportFlags & tf_import_flags,const ModelFlags & model_flags,Model * model)1034 tensorflow::Status ConvertBiasAddOperator(
1035     const NodeDef& node, const TensorFlowImportFlags& tf_import_flags,
1036     const ModelFlags& model_flags, Model* model) {
1037   CHECK_EQ(node.op(), "BiasAdd");
1038   TF_QCHECK_OK(CheckInputsCount(node, tf_import_flags, 2));
1039 
1040   const auto& input_name = node.input(0);
1041   const auto& bias_name = node.input(1);
1042   CHECK_EQ(GetDataTypeAttr(node, "T"), DT_FLOAT);
1043   auto* biasadd = new AddOperator;
1044   biasadd->inputs.push_back(input_name);
1045   biasadd->inputs.push_back(bias_name);
1046   biasadd->outputs.push_back(node.name());
1047   model->operators.emplace_back(biasadd);
1048   return tensorflow::Status::OK();
1049 }
1050 
ConvertRandomUniform(const NodeDef & node,const TensorFlowImportFlags & tf_import_flags,const ModelFlags & model_flags,Model * model)1051 tensorflow::Status ConvertRandomUniform(
1052     const NodeDef& node, const TensorFlowImportFlags& tf_import_flags,
1053     const ModelFlags& model_flags, Model* model) {
1054   CHECK_EQ(node.op(), "RandomUniform");
1055   TF_QCHECK_OK(CheckInputsCount(node, tf_import_flags, 1));
1056 
1057   CHECK_EQ(GetDataTypeAttr(node, "T"), DT_INT32);
1058   auto op = absl::make_unique<RandomUniformOperator>();
1059   op->inputs.push_back(node.input(0));
1060   op->outputs.push_back(node.name());
1061   op->dtype = ConvertDataType(GetDataTypeAttr(node, "dtype"));
1062   op->seed = GetIntAttr(node, "seed");
1063   op->seed2 = GetIntAttr(node, "seed2");
1064   CHECK(model != nullptr);
1065   model->operators.emplace_back(std::move(op));
1066   return tensorflow::Status::OK();
1067 }
1068 
ConvertIdentityOperator(const NodeDef & node,const TensorFlowImportFlags & tf_import_flags,const ModelFlags & model_flags,Model * model)1069 tensorflow::Status ConvertIdentityOperator(
1070     const NodeDef& node, const TensorFlowImportFlags& tf_import_flags,
1071     const ModelFlags& model_flags, Model* model) {
1072   CHECK(node.op() == "Identity" || node.op() == "CheckNumerics" ||
1073         node.op() == "PlaceholderWithDefault" || node.op() == "StopGradient" ||
1074         node.op() == "Snapshot");
1075   auto* op = new TensorFlowIdentityOperator;
1076   // Amazingly, some TensorFlow graphs (at least rajeev_lstm.pb) have
1077   // identity nodes with multiple inputs, but the other inputs seem
1078   // to be gratuitous (in the case of rajeev_lstm.pb, these are
1079   // enumerating the LSTM state arrays). We will just ignore extra
1080   // inputs beyond the first input.
1081   QCHECK_GE(node.input_size(), 1)
1082       << node.op()
1083       << " node expects at least 1 input other than control dependencies: "
1084       << node.DebugString();
1085   const auto& input_name = node.input(0);
1086   op->inputs.push_back(input_name);
1087   op->outputs.push_back(node.name());
1088   model->operators.emplace_back(op);
1089   return tensorflow::Status::OK();
1090 }
1091 
ConvertIdentityNOperator(const NodeDef & node,const TensorFlowImportFlags & tf_import_flags,const ModelFlags & model_flags,Model * model)1092 tensorflow::Status ConvertIdentityNOperator(
1093     const NodeDef& node, const TensorFlowImportFlags& tf_import_flags,
1094     const ModelFlags& model_flags, Model* model) {
1095   CHECK_EQ(node.op(), "IdentityN");
1096   for (int i = 0; i < node.input_size(); ++i) {
1097     auto* op = new TensorFlowIdentityOperator;
1098     const auto& input_name = node.input(i);
1099     std::string output_name = node.name();
1100     if (i > 0) {
1101       output_name = output_name + ":" + std::to_string(i);
1102     }
1103     op->inputs.push_back(input_name);
1104     op->outputs.push_back(output_name);
1105     model->operators.emplace_back(op);
1106   }
1107   return tensorflow::Status::OK();
1108 }
1109 
ConvertFakeQuantWithMinMaxArgs(const NodeDef & node,const TensorFlowImportFlags & tf_import_flags,const ModelFlags & model_flags,Model * model)1110 tensorflow::Status ConvertFakeQuantWithMinMaxArgs(
1111     const NodeDef& node, const TensorFlowImportFlags& tf_import_flags,
1112     const ModelFlags& model_flags, Model* model) {
1113   CHECK_EQ(node.op(), "FakeQuantWithMinMaxArgs");
1114   TF_QCHECK_OK(CheckInputsCount(node, tf_import_flags, 1));
1115   auto* op = new FakeQuantOperator;
1116   op->inputs.push_back(node.input(0));
1117   op->minmax.reset(new MinMax);
1118   auto& minmax = *op->minmax;
1119   minmax.min = GetFloatAttr(node, "min");
1120   minmax.max = GetFloatAttr(node, "max");
1121   op->outputs.push_back(node.name());
1122   // tf.fake_quant_with_min_max_args num_bits defaults to 8.
1123   op->num_bits = HasAttr(node, "num_bits") ? GetIntAttr(node, "num_bits") : 8;
1124   if (HasAttr(node, "narrow_range")) {
1125     op->narrow_range = GetBoolAttr(node, "narrow_range");
1126   }
1127   model->operators.emplace_back(op);
1128   return tensorflow::Status::OK();
1129 }
1130 
ConvertFakeQuantWithMinMaxVars(const NodeDef & node,const TensorFlowImportFlags & tf_import_flags,const ModelFlags & model_flags,Model * model)1131 tensorflow::Status ConvertFakeQuantWithMinMaxVars(
1132     const NodeDef& node, const TensorFlowImportFlags& tf_import_flags,
1133     const ModelFlags& model_flags, Model* model) {
1134   CHECK_EQ(node.op(), "FakeQuantWithMinMaxVars");
1135   const int num_inputs = GetInputsCount(node, tf_import_flags);
1136   QCHECK(num_inputs == 3 || num_inputs == 4)
1137       << "FakeQuantWithMinMaxVars node expects 3 or 4 inputs other than "
1138          "control dependencies: "
1139       << node.DebugString();
1140   auto* op = new FakeQuantOperator;
1141   for (int i = 0; i < 3; i++) {
1142     op->inputs.push_back(node.input(i));
1143   }
1144   op->outputs.push_back(node.name());
1145   op->num_bits = HasAttr(node, "num_bits") ? GetIntAttr(node, "num_bits") : 8;
1146   if (HasAttr(node, "narrow_range")) {
1147     op->narrow_range = GetBoolAttr(node, "narrow_range");
1148   }
1149   model->operators.emplace_back(op);
1150   return tensorflow::Status::OK();
1151 }
1152 
ConvertSqueezeOperator(const NodeDef & node,const TensorFlowImportFlags & tf_import_flags,const ModelFlags & model_flags,Model * model)1153 tensorflow::Status ConvertSqueezeOperator(
1154     const NodeDef& node, const TensorFlowImportFlags& tf_import_flags,
1155     const ModelFlags& model_flags, Model* model) {
1156   CHECK_EQ(node.op(), "Squeeze");
1157   TF_QCHECK_OK(CheckInputsCount(node, tf_import_flags, 1));
1158   auto* op = new SqueezeOperator;
1159   op->inputs.push_back(node.input(0));
1160   op->outputs.push_back(node.name());
1161 
1162   // When omitted we are to squeeze all dimensions == 1.
1163   if (HasAttr(node, "squeeze_dims")) {
1164     const auto& squeeze_dims = GetListAttr(node, "squeeze_dims");
1165     for (int i = 0; i < squeeze_dims.i_size(); ++i) {
1166       op->squeeze_dims.push_back(squeeze_dims.i(i));
1167     }
1168   }
1169 
1170   model->operators.emplace_back(op);
1171   return tensorflow::Status::OK();
1172 }
1173 
ConvertSplitOperator(const NodeDef & node,const TensorFlowImportFlags & tf_import_flags,const ModelFlags & model_flags,Model * model)1174 tensorflow::Status ConvertSplitOperator(
1175     const NodeDef& node, const TensorFlowImportFlags& tf_import_flags,
1176     const ModelFlags& model_flags, Model* model) {
1177   CHECK_EQ(node.op(), "Split");
1178   TF_QCHECK_OK(CheckInputsCount(node, tf_import_flags, 2));
1179   auto* op = new TensorFlowSplitOperator;
1180   op->inputs.push_back(node.input(0));
1181   op->inputs.push_back(node.input(1));
1182   const int num_split = GetIntAttr(node, "num_split");
1183   op->outputs.push_back(node.name());
1184   for (int i = 1; i < num_split; i++) {
1185     op->outputs.push_back(absl::StrCat(node.name(), ":", i));
1186   }
1187   op->num_split = num_split;
1188   model->operators.emplace_back(op);
1189   return tensorflow::Status::OK();
1190 }
1191 
ConvertSplitVOperator(const NodeDef & node,const TensorFlowImportFlags & tf_import_flags,const ModelFlags & model_flags,Model * model)1192 tensorflow::Status ConvertSplitVOperator(
1193     const NodeDef& node, const TensorFlowImportFlags& tf_import_flags,
1194     const ModelFlags& model_flags, Model* model) {
1195   CHECK_EQ(node.op(), "SplitV");
1196   TF_QCHECK_OK(CheckInputsCount(node, tf_import_flags, 3));
1197   auto* op = new TensorFlowSplitVOperator;
1198   op->inputs.push_back(node.input(0));
1199   op->inputs.push_back(node.input(1));
1200   op->inputs.push_back(node.input(2));
1201   const int num_split = GetIntAttr(node, "num_split");
1202   op->outputs.push_back(node.name());
1203   for (int i = 1; i < num_split; i++) {
1204     op->outputs.push_back(absl::StrCat(node.name(), ":", i));
1205   }
1206   op->num_split = num_split;
1207   model->operators.emplace_back(op);
1208   return tensorflow::Status::OK();
1209 }
1210 
ConvertSwitchOperator(const NodeDef & node,const TensorFlowImportFlags & tf_import_flags,const ModelFlags & model_flags,Model * model)1211 tensorflow::Status ConvertSwitchOperator(
1212     const NodeDef& node, const TensorFlowImportFlags& tf_import_flags,
1213     const ModelFlags& model_flags, Model* model) {
1214   CHECK_EQ(node.op(), "Switch");
1215   TF_QCHECK_OK(CheckInputsCount(node, tf_import_flags, 2));
1216   auto* op = new TensorFlowSwitchOperator;
1217   op->inputs.push_back(node.input(0));
1218   op->inputs.push_back(node.input(1));
1219   op->outputs.push_back(node.name());
1220   // Switch operators have two outputs: "name" and "name:1".
1221   op->outputs.push_back(node.name() + ":1");
1222   model->operators.emplace_back(op);
1223   return tensorflow::Status::OK();
1224 }
1225 
ConvertSoftmaxOperator(const NodeDef & node,const TensorFlowImportFlags & tf_import_flags,const ModelFlags & model_flags,Model * model)1226 tensorflow::Status ConvertSoftmaxOperator(
1227     const NodeDef& node, const TensorFlowImportFlags& tf_import_flags,
1228     const ModelFlags& model_flags, Model* model) {
1229   CHECK_EQ(node.op(), "Softmax");
1230   TF_QCHECK_OK(CheckInputsCount(node, tf_import_flags, 1));
1231   const auto& input_name = node.input(0);
1232   auto* softmax = new SoftmaxOperator;
1233   softmax->inputs.push_back(input_name);
1234   softmax->outputs.push_back(node.name());
1235   // TensorFlow's Softmax doesn't seem to admit a 'beta' parameter.
1236   CHECK(!node.attr().count("beta"));  // Stab in the dark, just in case.
1237   if (node.attr().count("_softmax_beta")) {
1238     softmax->beta = GetFloatAttr(node, "_softmax_beta");
1239   } else {
1240     softmax->beta = 1.f;
1241   }
1242   model->operators.emplace_back(softmax);
1243   return tensorflow::Status::OK();
1244 }
1245 
ConvertLRNOperator(const NodeDef & node,const TensorFlowImportFlags & tf_import_flags,const ModelFlags & model_flags,Model * model)1246 tensorflow::Status ConvertLRNOperator(
1247     const NodeDef& node, const TensorFlowImportFlags& tf_import_flags,
1248     const ModelFlags& model_flags, Model* model) {
1249   CHECK_EQ(node.op(), "LRN");
1250   TF_QCHECK_OK(CheckInputsCount(node, tf_import_flags, 1));
1251   const auto& input_name = node.input(0);
1252   auto* lrn = new LocalResponseNormalizationOperator;
1253   lrn->inputs.push_back(input_name);
1254   lrn->outputs.push_back(node.name());
1255   lrn->range = GetIntAttr(node, "depth_radius");
1256   lrn->bias = GetFloatAttr(node, "bias");
1257   lrn->alpha = GetFloatAttr(node, "alpha");
1258   lrn->beta = GetFloatAttr(node, "beta");
1259   model->operators.emplace_back(lrn);
1260   return tensorflow::Status::OK();
1261 }
1262 
ConvertMaxPoolOperator(const NodeDef & node,const TensorFlowImportFlags & tf_import_flags,const ModelFlags & model_flags,Model * model)1263 tensorflow::Status ConvertMaxPoolOperator(
1264     const NodeDef& node, const TensorFlowImportFlags& tf_import_flags,
1265     const ModelFlags& model_flags, Model* model) {
1266   CHECK_EQ(node.op(), "MaxPool");
1267   TF_QCHECK_OK(CheckInputsCount(node, tf_import_flags, 1));
1268   const auto& input_name = node.input(0);
1269   // We only support NHWC, which is the default data_format.
1270   // So if data_format is not defined, we're all good.
1271   if (node.attr().count("data_format")) {
1272     CHECK_EQ(GetStringAttr(node, "data_format"), "NHWC");
1273   }
1274   if (HasAttr(node, "T")) {
1275     CHECK_EQ(GetDataTypeAttr(node, "T"), DT_FLOAT);
1276   } else {
1277     LOG(WARNING) << "Found MaxPool operator missing 'T' attribute";
1278   }
1279   auto* maxpool = new MaxPoolOperator;
1280   maxpool->inputs.push_back(input_name);
1281   maxpool->outputs.push_back(node.name());
1282   const auto& strides = GetListAttr(node, "strides");
1283   CHECK_EQ(strides.i_size(), 4);
1284   CHECK_EQ(strides.i(0), 1);
1285   CHECK_EQ(strides.i(3), 1);
1286   maxpool->stride_height = strides.i(1);
1287   maxpool->stride_width = strides.i(2);
1288   const auto& ksize = GetListAttr(node, "ksize");
1289   CHECK_EQ(ksize.i_size(), 4);
1290   CHECK_EQ(ksize.i(0), 1);
1291   CHECK_EQ(ksize.i(3), 1);
1292   maxpool->kheight = ksize.i(1);
1293   maxpool->kwidth = ksize.i(2);
1294   const auto& padding = GetStringAttr(node, "padding");
1295   if (padding == "SAME") {
1296     maxpool->padding.type = PaddingType::kSame;
1297   } else if (padding == "VALID") {
1298     maxpool->padding.type = PaddingType::kValid;
1299   } else {
1300     LOG(FATAL) << "Bad padding (only SAME and VALID are supported)";
1301   }
1302   model->operators.emplace_back(maxpool);
1303   return tensorflow::Status::OK();
1304 }
1305 
ConvertAvgPoolOperator(const NodeDef & node,const TensorFlowImportFlags & tf_import_flags,const ModelFlags & model_flags,Model * model)1306 tensorflow::Status ConvertAvgPoolOperator(
1307     const NodeDef& node, const TensorFlowImportFlags& tf_import_flags,
1308     const ModelFlags& model_flags, Model* model) {
1309   CHECK_EQ(node.op(), "AvgPool");
1310   TF_QCHECK_OK(CheckInputsCount(node, tf_import_flags, 1));
1311   const auto& input_name = node.input(0);
1312   // We only support NHWC, which is the default data_format.
1313   // So if data_format is not defined, we're all good.
1314   if (node.attr().count("data_format")) {
1315     CHECK_EQ(GetStringAttr(node, "data_format"), "NHWC");
1316   }
1317   CHECK_EQ(GetDataTypeAttr(node, "T"), DT_FLOAT);
1318   auto* avgpool = new AveragePoolOperator;
1319   avgpool->inputs.push_back(input_name);
1320   avgpool->outputs.push_back(node.name());
1321   const auto& strides = GetListAttr(node, "strides");
1322   CHECK_EQ(strides.i_size(), 4);
1323   CHECK_EQ(strides.i(0), 1);
1324   CHECK_EQ(strides.i(3), 1);
1325   avgpool->stride_height = strides.i(1);
1326   avgpool->stride_width = strides.i(2);
1327   const auto& ksize = GetListAttr(node, "ksize");
1328   CHECK_EQ(ksize.i_size(), 4);
1329   CHECK_EQ(ksize.i(0), 1);
1330   CHECK_EQ(ksize.i(3), 1);
1331   avgpool->kheight = ksize.i(1);
1332   avgpool->kwidth = ksize.i(2);
1333   const auto& padding = GetStringAttr(node, "padding");
1334   if (padding == "SAME") {
1335     avgpool->padding.type = PaddingType::kSame;
1336   } else if (padding == "VALID") {
1337     avgpool->padding.type = PaddingType::kValid;
1338   } else {
1339     LOG(FATAL) << "Bad padding (only SAME and VALID are supported)";
1340   }
1341   model->operators.emplace_back(avgpool);
1342   return tensorflow::Status::OK();
1343 }
1344 
ConvertBatchMatMulOperator(const NodeDef & node,const TensorFlowImportFlags & tf_import_flags,const ModelFlags & model_flags,Model * model)1345 tensorflow::Status ConvertBatchMatMulOperator(
1346     const NodeDef& node, const TensorFlowImportFlags& tf_import_flags,
1347     const ModelFlags& model_flags, Model* model) {
1348   TF_QCHECK_OK(CheckInputsCount(node, tf_import_flags, 2));
1349 
1350   auto* batch_matmul = new BatchMatMulOperator;
1351   // https://www.tensorflow.org/versions/r0.12/api_docs/python/math_ops/matrix_math_functions
1352   if (HasAttr(node, "adj_x")) {
1353     batch_matmul->adj_x = GetBoolAttr(node, "adj_x");
1354   }
1355   if (HasAttr(node, "adj_y")) {
1356     batch_matmul->adj_y = GetBoolAttr(node, "adj_y");
1357   }
1358   batch_matmul->inputs = {node.input(0), node.input(1)};
1359   batch_matmul->outputs = {node.name()};
1360 
1361   // For Flex mode. Please read the comments of the function.
1362   RetainTensorFlowNodeDef(node, batch_matmul);
1363 
1364   model->operators.emplace_back(batch_matmul);
1365   return tensorflow::Status::OK();
1366 }
1367 
ConvertMatMulOperator(const NodeDef & node,const TensorFlowImportFlags & tf_import_flags,const ModelFlags & model_flags,Model * model)1368 tensorflow::Status ConvertMatMulOperator(
1369     const NodeDef& node, const TensorFlowImportFlags& tf_import_flags,
1370     const ModelFlags& model_flags, Model* model) {
1371   TF_QCHECK_OK(CheckInputsCount(node, tf_import_flags, 2));
1372 
1373   CHECK(!HasAttr(node, "adjoint_a") ||
1374         (GetBoolAttr(node, "adjoint_a") == false));
1375   CHECK(!HasAttr(node, "adjoint_b") ||
1376         (GetBoolAttr(node, "adjoint_b") == false));
1377 
1378   auto* matmul = new TensorFlowMatMulOperator;
1379   if (HasAttr(node, "transpose_a")) {
1380     matmul->transpose_a = GetBoolAttr(node, "transpose_a");
1381   }
1382   if (HasAttr(node, "transpose_b")) {
1383     matmul->transpose_b = GetBoolAttr(node, "transpose_b");
1384   }
1385 
1386   matmul->inputs = {node.input(0), node.input(1)};
1387   matmul->outputs = {node.name()};
1388   model->operators.emplace_back(matmul);
1389   return tensorflow::Status::OK();
1390 }
1391 
ConvertConcatOperator(const NodeDef & node,const TensorFlowImportFlags & tf_import_flags,const ModelFlags & model_flags,Model * model)1392 tensorflow::Status ConvertConcatOperator(
1393     const NodeDef& node, const TensorFlowImportFlags& tf_import_flags,
1394     const ModelFlags& model_flags, Model* model) {
1395   Operator* op = nullptr;
1396   if (node.op() == "Concat") {
1397     op = new TensorFlowConcatOperator;
1398   } else if (node.op() == "ConcatV2") {
1399     op = new TensorFlowConcatV2Operator;
1400   } else {
1401     LOG(FATAL) << "Expected Concat or ConcatV2";
1402   }
1403   const int num_inputs = GetInputsCount(node, tf_import_flags);
1404   QCHECK_GE(num_inputs, 2)
1405       << node.op()
1406       << " node expects at least 2 inputs other than control dependencies: "
1407       << node.DebugString();
1408   CHECK_EQ(num_inputs, 1 + GetIntAttr(node, "N"));
1409   for (int i = 0; i < num_inputs; ++i) {
1410     op->inputs.push_back(node.input(i));
1411   }
1412   op->outputs.push_back(node.name());
1413   model->operators.emplace_back(op);
1414   return tensorflow::Status::OK();
1415 }
1416 
ConvertMirrorPadOperator(const NodeDef & node,const TensorFlowImportFlags & tf_import_flags,const ModelFlags & model_flags,Model * model)1417 tensorflow::Status ConvertMirrorPadOperator(
1418     const NodeDef& node, const TensorFlowImportFlags& tf_import_flags,
1419     const ModelFlags& model_flags, Model* model) {
1420   if (node.op() != "MirrorPad") {
1421     LOG(FATAL) << "Expected MirrorPad.";
1422   }
1423   const int num_inputs = GetInputsCount(node, tf_import_flags);
1424   CHECK_EQ(num_inputs, 2);
1425   auto* op = new MirrorPadOperator;
1426   for (int i = 0; i < num_inputs; ++i) {
1427     op->inputs.push_back(node.input(i));
1428   }
1429   op->outputs.push_back(node.name());
1430   const auto mode = GetStringAttr(node, "mode");
1431   if (mode == "REFLECT") {
1432     op->mode = toco::MirrorPadMode::kReflect;
1433   } else if (mode == "SYMMETRIC") {
1434     op->mode = toco::MirrorPadMode::kSymmetric;
1435   }
1436 
1437   model->operators.emplace_back(op);
1438 
1439   return tensorflow::Status::OK();
1440 }
1441 
// Sentinel value for the NumInputs template argument of the simple-operator
// converters: when used, no input-count check is performed.
constexpr int kAnyNumInputs = -1;

// Selects whether a simple-operator conversion retains the TensorFlow NodeDef
// on the resulting operator (making it eligible for flex export).
enum FlexSupport {
  kFlexOk = 0,     // Retain the NodeDef.
  kFlexNotOk = 1,  // Do not retain the NodeDef.
};
1445 
1446 // This method supports simple operators without additional attributes.
1447 // Converts a simple operator that takes no attributes. The list of inputs is
1448 // taken from the given NodeDef, and its number must match NumInputs, unless
1449 // kAnyNumInputs is passed in. If kFlexOk is passed in the resulting operator
1450 // will be eligible for being exported as a flex op.
1451 template <typename Op, int NumInputs, int NumOutputs, FlexSupport flex>
ConvertSimpleOperatorGeneric(const NodeDef & node,const TensorFlowImportFlags & tf_import_flags,const ModelFlags & model_flags,Model * model)1452 tensorflow::Status ConvertSimpleOperatorGeneric(
1453     const NodeDef& node, const TensorFlowImportFlags& tf_import_flags,
1454     const ModelFlags& model_flags, Model* model) {
1455   if (NumInputs != kAnyNumInputs) {
1456     TF_QCHECK_OK(CheckInputsCount(node, tf_import_flags, NumInputs));
1457   }
1458   auto* op = new Op;
1459   const int num_inputs = GetInputsCount(node, tf_import_flags);
1460   for (int i = 0; i < num_inputs; ++i) {
1461     op->inputs.push_back(node.input(i));
1462   }
1463   op->outputs.push_back(node.name());
1464   if (NumOutputs > 1) {
1465     for (int i = 1; i < NumOutputs; ++i) {
1466       op->outputs.push_back(node.name() + ":" + std::to_string(i));
1467     }
1468   }
1469 
1470   if (flex == kFlexOk) {
1471     RetainTensorFlowNodeDef(node, op);
1472   }
1473 
1474   model->operators.emplace_back(op);
1475   return tensorflow::Status::OK();
1476 }
1477 
1478 // Convert a simple operator which is not valid as a flex op.
1479 template <typename Op, int NumInputs, int NumOutputs>
ConvertSimpleOperator(const NodeDef & node,const TensorFlowImportFlags & tf_import_flags,const ModelFlags & model_flags,Model * model)1480 tensorflow::Status ConvertSimpleOperator(
1481     const NodeDef& node, const TensorFlowImportFlags& tf_import_flags,
1482     const ModelFlags& model_flags, Model* model) {
1483   return ConvertSimpleOperatorGeneric<Op, NumInputs, NumOutputs, kFlexNotOk>(
1484       node, tf_import_flags, model_flags, model);
1485 }
1486 
1487 // Convert a simple operator which is valid as a flex op.
1488 template <typename Op, int NumInputs, int NumOutputs>
ConvertSimpleOperatorFlexOk(const NodeDef & node,const TensorFlowImportFlags & tf_import_flags,const ModelFlags & model_flags,Model * model)1489 tensorflow::Status ConvertSimpleOperatorFlexOk(
1490     const NodeDef& node, const TensorFlowImportFlags& tf_import_flags,
1491     const ModelFlags& model_flags, Model* model) {
1492   return ConvertSimpleOperatorGeneric<Op, NumInputs, NumOutputs, kFlexOk>(
1493       node, tf_import_flags, model_flags, model);
1494 }
1495 
1496 // Same as ConvertConstOperator, but revert to ConvertUnsupportedOperator if
1497 // the types are not supported. Converting Const operators here avoids
1498 // expensive copies of the protocol buffers downstream in the flex delegate.
ConditionallyConvertConstOperator(const NodeDef & node,const TensorFlowImportFlags & tf_import_flags,const ModelFlags & model_flags,Model * model)1499 tensorflow::Status ConditionallyConvertConstOperator(
1500     const NodeDef& node, const TensorFlowImportFlags& tf_import_flags,
1501     const ModelFlags& model_flags, Model* model) {
1502   // We avoid incomplete and zero shapes because the resulting arrays
1503   // are not completely compatible with Eager/TensorFlow.
1504   const auto& tensor = GetTensorAttr(node, "value");
1505   const auto& shape = tensor.tensor_shape();
1506   for (const auto& dim : shape.dim()) {
1507     if (dim.size() <= 0) {
1508       return ConvertUnsupportedOperator(node, tf_import_flags, model_flags,
1509                                         model);
1510     }
1511   }
1512   switch (GetDataTypeAttr(node, "dtype")) {
1513     case DT_FLOAT:
1514     case DT_INT32:
1515     case DT_QUINT8:
1516     case DT_INT64:
1517     case DT_STRING:
1518     case DT_BOOL:
1519     case DT_COMPLEX64:
1520       return ConvertConstOperator(node, tf_import_flags, model_flags, model);
1521     default:
1522       return ConvertUnsupportedOperator(node, tf_import_flags, model_flags,
1523                                         model);
1524   }
1525 }
1526 
ConvertStridedSliceOperator(const NodeDef & node,const TensorFlowImportFlags & tf_import_flags,const ModelFlags & model_flags,Model * model)1527 tensorflow::Status ConvertStridedSliceOperator(
1528     const NodeDef& node, const TensorFlowImportFlags& tf_import_flags,
1529     const ModelFlags& model_flags, Model* model) {
1530   CHECK_EQ(node.op(), "StridedSlice");
1531   // TODO(soroosh): The 4th input (strides) should be e optional, to be
1532   // consistent with TF.
1533   TF_QCHECK_OK(CheckInputsCount(node, tf_import_flags, 4));
1534 
1535   auto* op = new StridedSliceOperator;
1536   for (const auto& input : node.input()) {
1537     op->inputs.push_back(input);
1538   }
1539   op->outputs.push_back(node.name());
1540 
1541   op->begin_mask =
1542       HasAttr(node, "begin_mask") ? GetIntAttr(node, "begin_mask") : 0;
1543   op->ellipsis_mask =
1544       HasAttr(node, "ellipsis_mask") ? GetIntAttr(node, "ellipsis_mask") : 0;
1545   op->end_mask = HasAttr(node, "end_mask") ? GetIntAttr(node, "end_mask") : 0;
1546   op->new_axis_mask =
1547       HasAttr(node, "new_axis_mask") ? GetIntAttr(node, "new_axis_mask") : 0;
1548   op->shrink_axis_mask = HasAttr(node, "shrink_axis_mask")
1549                              ? GetIntAttr(node, "shrink_axis_mask")
1550                              : 0;
1551 
1552   model->operators.emplace_back(op);
1553   return tensorflow::Status::OK();
1554 }
1555 
ConvertPlaceholderOperator(const NodeDef & node,const TensorFlowImportFlags & tf_import_flags,const ModelFlags & model_flags,Model * model)1556 tensorflow::Status ConvertPlaceholderOperator(
1557     const NodeDef& node, const TensorFlowImportFlags& tf_import_flags,
1558     const ModelFlags& model_flags, Model* model) {
1559   CHECK(node.op() == "Placeholder" || node.op() == "LegacyFedInput");
1560   if (node.op() == "Placeholder") {
1561     TF_QCHECK_OK(CheckInputsCount(node, tf_import_flags, 0));
1562   }
1563 
1564   bool inside_input_arrays = false;
1565   for (const auto& input_array : model_flags.input_arrays()) {
1566     if (node.name() == input_array.name()) {
1567       inside_input_arrays = true;
1568       break;
1569     }
1570   }
1571 
1572   if (!inside_input_arrays) {
1573     model->AddInvalidInputArray(node.name());
1574   }
1575 
1576   auto& array = model->GetOrCreateArray(node.name());
1577   if (node.attr().count("dtype")) {
1578     array.data_type = ConvertDataType(GetDataTypeAttr(node, "dtype"));
1579   }
1580   if (node.attr().count("shape")) {
1581     const auto& shape = GetShapeAttr(node, "shape");
1582     auto num_dims = shape.dim_size();
1583     // TODO(b/62716978): This logic needs to be revisited.  During dims
1584     // refactoring it is an interim fix.
1585     if (num_dims > 0 && !HasWildcardDimension(shape)) {
1586       auto& dst_array_dims = *array.mutable_shape()->mutable_dims();
1587       dst_array_dims.resize(num_dims);
1588       for (std::size_t i = 0; i < num_dims; i++) {
1589         dst_array_dims[i] = shape.dim(i).size();
1590       }
1591     }
1592   }
1593   return tensorflow::Status::OK();
1594 }
1595 
ConvertNoOpOperator(const NodeDef & node,const TensorFlowImportFlags & tf_import_flags,const ModelFlags & model_flags,Model * model)1596 tensorflow::Status ConvertNoOpOperator(
1597     const NodeDef& node, const TensorFlowImportFlags& tf_import_flags,
1598     const ModelFlags& model_flags, Model* model) {
1599   return tensorflow::Status::OK();
1600 }
1601 
ConvertCastOperator(const NodeDef & node,const TensorFlowImportFlags & tf_import_flags,const ModelFlags & model_flags,Model * model)1602 tensorflow::Status ConvertCastOperator(
1603     const NodeDef& node, const TensorFlowImportFlags& tf_import_flags,
1604     const ModelFlags& model_flags, Model* model) {
1605   CHECK_EQ(node.op(), "Cast");
1606   TF_QCHECK_OK(CheckInputsCount(node, tf_import_flags, 1));
1607   const auto tf_src_dtype = GetDataTypeAttr(node, "SrcT");
1608   const auto tf_dst_dtype = GetDataTypeAttr(node, "DstT");
1609   auto* op = new CastOperator;
1610   op->src_data_type = ConvertDataType(tf_src_dtype);
1611   op->dst_data_type = ConvertDataType(tf_dst_dtype);
1612   op->inputs.push_back(node.input(0));
1613   op->outputs.push_back(node.name());
1614   model->operators.emplace_back(op);
1615   return tensorflow::Status::OK();
1616 }
1617 
ConvertFloorOperator(const NodeDef & node,const TensorFlowImportFlags & tf_import_flags,const ModelFlags & model_flags,Model * model)1618 tensorflow::Status ConvertFloorOperator(
1619     const NodeDef& node, const TensorFlowImportFlags& tf_import_flags,
1620     const ModelFlags& model_flags, Model* model) {
1621   CHECK_EQ(node.op(), "Floor");
1622   TF_QCHECK_OK(CheckInputsCount(node, tf_import_flags, 1));
1623   const auto data_type = GetDataTypeAttr(node, "T");
1624   CHECK(data_type == DT_FLOAT);
1625   auto* op = new FloorOperator;
1626   op->inputs.push_back(node.input(0));
1627   op->outputs.push_back(node.name());
1628   model->operators.emplace_back(op);
1629   return tensorflow::Status::OK();
1630 }
1631 
ConvertCeilOperator(const NodeDef & node,const TensorFlowImportFlags & tf_import_flags,const ModelFlags & model_flags,Model * model)1632 tensorflow::Status ConvertCeilOperator(
1633     const NodeDef& node, const TensorFlowImportFlags& tf_import_flags,
1634     const ModelFlags& model_flags, Model* model) {
1635   CHECK_EQ(node.op(), "Ceil");
1636   TF_QCHECK_OK(CheckInputsCount(node, tf_import_flags, 1));
1637   const auto data_type = GetDataTypeAttr(node, "T");
1638   CHECK(data_type == DT_FLOAT);
1639   auto* op = new CeilOperator;
1640   op->inputs.push_back(node.input(0));
1641   op->outputs.push_back(node.name());
1642   model->operators.emplace_back(op);
1643   return tensorflow::Status::OK();
1644 }
1645 
ConvertRoundOperator(const NodeDef & node,const TensorFlowImportFlags & tf_import_flags,const ModelFlags & model_flags,Model * model)1646 tensorflow::Status ConvertRoundOperator(
1647     const NodeDef& node, const TensorFlowImportFlags& tf_import_flags,
1648     const ModelFlags& model_flags, Model* model) {
1649   CHECK_EQ(node.op(), "Round");
1650   TF_QCHECK_OK(CheckInputsCount(node, tf_import_flags, 1));
1651   const auto data_type = GetDataTypeAttr(node, "T");
1652   CHECK(data_type == DT_FLOAT);
1653   auto* op = new RoundOperator;
1654   op->inputs.push_back(node.input(0));
1655   op->outputs.push_back(node.name());
1656   model->operators.emplace_back(op);
1657   return tensorflow::Status::OK();
1658 }
1659 
ConvertGatherOperator(const NodeDef & node,const TensorFlowImportFlags & tf_import_flags,const ModelFlags & model_flags,Model * model)1660 tensorflow::Status ConvertGatherOperator(
1661     const NodeDef& node, const TensorFlowImportFlags& tf_import_flags,
1662     const ModelFlags& model_flags, Model* model) {
1663   CHECK(node.op() == "Gather" || node.op() == "GatherV2");
1664   if (node.op() == "Gather")
1665     TF_QCHECK_OK(CheckInputsCount(node, tf_import_flags, 2));
1666   if (node.op() == "GatherV2")
1667     TF_QCHECK_OK(CheckInputsCount(node, tf_import_flags, 3));
1668   const auto indices_data_type = GetDataTypeAttr(node, "Tindices");
1669   CHECK(indices_data_type == DT_INT32 || indices_data_type == DT_INT64);
1670   auto* op = new GatherOperator;
1671   op->inputs.push_back(node.input(0));
1672   op->inputs.push_back(node.input(1));
1673   if (node.input_size() >= 3) {
1674     // GatherV2 form where we are provided an axis. It may be either a constant
1675     // or runtime defined value, so we just wire up the array and let
1676     // ResolveGatherAttributes take care of it later on.
1677     const auto axis_data_type = GetDataTypeAttr(node, "Taxis");
1678     CHECK(axis_data_type == DT_INT32 || axis_data_type == DT_INT64);
1679     op->inputs.push_back(node.input(2));
1680   } else {
1681     // Gather form that assumes axis=0.
1682     op->axis = {0};
1683   }
1684   op->outputs.push_back(node.name());
1685   model->operators.emplace_back(op);
1686   return tensorflow::Status::OK();
1687 }
1688 
ConvertGatherNdOperator(const NodeDef & node,const TensorFlowImportFlags & tf_import_flags,const ModelFlags & model_flags,Model * model)1689 tensorflow::Status ConvertGatherNdOperator(
1690     const NodeDef& node, const TensorFlowImportFlags& tf_import_flags,
1691     const ModelFlags& model_flags, Model* model) {
1692   CHECK_EQ(node.op(), "GatherNd");
1693   TF_QCHECK_OK(CheckInputsCount(node, tf_import_flags, 2));
1694   const auto indices_data_type = GetDataTypeAttr(node, "Tindices");
1695   CHECK(indices_data_type == DT_INT32 || indices_data_type == DT_INT64);
1696   auto* op = new GatherNdOperator;
1697   op->inputs.push_back(node.input(0));
1698   op->inputs.push_back(node.input(1));
1699   op->outputs.push_back(node.name());
1700   model->operators.emplace_back(op);
1701   return tensorflow::Status::OK();
1702 }
1703 
1704 template <typename Op>
ConvertArgMinMaxOperator(const NodeDef & node,const TensorFlowImportFlags & tf_import_flags,const ModelFlags & model_flags,Model * model)1705 tensorflow::Status ConvertArgMinMaxOperator(
1706     const NodeDef& node, const TensorFlowImportFlags& tf_import_flags,
1707     const ModelFlags& model_flags, Model* model) {
1708   TF_QCHECK_OK(CheckInputsCount(node, tf_import_flags, 2));
1709   const auto axis_data_type =
1710       HasAttr(node, "Tidx") ? GetDataTypeAttr(node, "Tidx") : DT_INT32;
1711   const auto output_type = HasAttr(node, "output_type")
1712                                ? GetDataTypeAttr(node, "output_type")
1713                                : DT_INT64;
1714   CHECK(axis_data_type == DT_INT64 || axis_data_type == DT_INT32);
1715   CHECK(output_type == DT_INT64 || output_type == DT_INT32);
1716   auto* op = new Op;
1717   op->output_data_type = ConvertDataType(output_type);
1718   op->inputs.push_back(node.input(0));
1719   op->inputs.push_back(node.input(1));
1720   op->outputs.push_back(node.name());
1721   model->operators.emplace_back(op);
1722   return tensorflow::Status::OK();
1723 }
1724 
ConvertArgMaxOperator(const NodeDef & node,const TensorFlowImportFlags & tf_import_flags,const ModelFlags & model_flags,Model * model)1725 tensorflow::Status ConvertArgMaxOperator(
1726     const NodeDef& node, const TensorFlowImportFlags& tf_import_flags,
1727     const ModelFlags& model_flags, Model* model) {
1728   CHECK_EQ(node.op(), "ArgMax");
1729   return ConvertArgMinMaxOperator<ArgMaxOperator>(node, tf_import_flags,
1730                                                   model_flags, model);
1731 }
1732 
ConvertArgMinOperator(const NodeDef & node,const TensorFlowImportFlags & tf_import_flags,const ModelFlags & model_flags,Model * model)1733 tensorflow::Status ConvertArgMinOperator(
1734     const NodeDef& node, const TensorFlowImportFlags& tf_import_flags,
1735     const ModelFlags& model_flags, Model* model) {
1736   CHECK_EQ(node.op(), "ArgMin");
1737   return ConvertArgMinMaxOperator<ArgMinOperator>(node, tf_import_flags,
1738                                                   model_flags, model);
1739 }
1740 
ConvertResizeBilinearOperator(const NodeDef & node,const TensorFlowImportFlags & tf_import_flags,const ModelFlags & model_flags,Model * model)1741 tensorflow::Status ConvertResizeBilinearOperator(
1742     const NodeDef& node, const TensorFlowImportFlags& tf_import_flags,
1743     const ModelFlags& model_flags, Model* model) {
1744   CHECK_EQ(node.op(), "ResizeBilinear");
1745   TF_QCHECK_OK(CheckInputsCount(node, tf_import_flags, 2));
1746   auto* op = new ResizeBilinearOperator;
1747 
1748   op->align_corners = false;
1749   op->half_pixel_centers = false;
1750   if (HasAttr(node, "align_corners")) {
1751     op->align_corners = GetBoolAttr(node, "align_corners");
1752   }
1753   if (HasAttr(node, "half_pixel_centers")) {
1754     op->half_pixel_centers = GetBoolAttr(node, "half_pixel_centers");
1755   }
1756 
1757   op->inputs.push_back(node.input(0));
1758   op->inputs.push_back(node.input(1));
1759   op->outputs.push_back(node.name());
1760   model->operators.emplace_back(op);
1761   return tensorflow::Status::OK();
1762 }
1763 
ConvertResizeNearestNeighborOperator(const NodeDef & node,const TensorFlowImportFlags & tf_import_flags,const ModelFlags & model_flags,Model * model)1764 tensorflow::Status ConvertResizeNearestNeighborOperator(
1765     const NodeDef& node, const TensorFlowImportFlags& tf_import_flags,
1766     const ModelFlags& model_flags, Model* model) {
1767   CHECK_EQ(node.op(), "ResizeNearestNeighbor");
1768   TF_QCHECK_OK(CheckInputsCount(node, tf_import_flags, 2));
1769   auto* op = new ResizeNearestNeighborOperator;
1770 
1771   op->align_corners = false;
1772   op->half_pixel_centers = false;
1773   if (HasAttr(node, "align_corners")) {
1774     op->align_corners = GetBoolAttr(node, "align_corners");
1775   }
1776   if (HasAttr(node, "half_pixel_centers")) {
1777     op->half_pixel_centers = GetBoolAttr(node, "half_pixel_centers");
1778   }
1779 
1780   op->inputs.push_back(node.input(0));
1781   op->inputs.push_back(node.input(1));
1782   op->outputs.push_back(node.name());
1783   model->operators.emplace_back(op);
1784   return tensorflow::Status::OK();
1785 }
1786 
ConvertBatchNormWithGlobalNormalizationOperator(const NodeDef & node,const TensorFlowImportFlags & tf_import_flags,const ModelFlags & model_flags,Model * model)1787 tensorflow::Status ConvertBatchNormWithGlobalNormalizationOperator(
1788     const NodeDef& node, const TensorFlowImportFlags& tf_import_flags,
1789     const ModelFlags& model_flags, Model* model) {
1790   CHECK_EQ(node.op(), "BatchNormWithGlobalNormalization");
1791   TF_QCHECK_OK(CheckInputsCount(node, tf_import_flags, 5));
1792 
1793   // TODO(ahentz): to really match tensorflow we need to add variance_epsilon
1794   // to the input, before feeding it into TensorFlowRsqrtOperator.
1795   // CHECK_EQ(GetFloatAttr(node, "variance_epsilon"), 0.001f);
1796 
1797   std::string multiplier = node.name() + "_mul";
1798   if (GetBoolAttr(node, "scale_after_normalization")) {
1799     // Create graph:
1800     //   v -> RSQRT ->
1801     //                 MUL  -> multiplier
1802     //   gamma  ----->
1803     std::string rsqrt = node.name() + "_rsqrt";
1804 
1805     auto* rsqrt_op = new TensorFlowRsqrtOperator;
1806     rsqrt_op->inputs.push_back(node.input(2));
1807     rsqrt_op->outputs.push_back(rsqrt);
1808     model->operators.emplace_back(rsqrt_op);
1809 
1810     auto* mul_op = new MulOperator;
1811     mul_op->inputs.push_back(rsqrt);
1812     mul_op->inputs.push_back(node.input(4));
1813     mul_op->outputs.push_back(multiplier);
1814     model->operators.emplace_back(mul_op);
1815   } else {
1816     // Create graph:
1817     //   v -> RSQRT -> multiplier
1818     auto* rsqrt_op = new TensorFlowRsqrtOperator;
1819     rsqrt_op->inputs.push_back(node.input(2));
1820     rsqrt_op->outputs.push_back(multiplier);
1821     model->operators.emplace_back(rsqrt_op);
1822   }
1823 
1824   auto* op = new BatchNormalizationOperator;
1825   op->global_normalization = true;
1826 
1827   op->inputs.push_back(node.input(0));
1828   op->inputs.push_back(node.input(1));
1829   op->inputs.push_back(multiplier);
1830   op->inputs.push_back(node.input(3));
1831   op->outputs.push_back(node.name());
1832 
1833   model->operators.emplace_back(op);
1834   return tensorflow::Status::OK();
1835 }
1836 
ConvertFusedBatchNormOperator(const NodeDef & node,const TensorFlowImportFlags & tf_import_flags,const ModelFlags & model_flags,Model * model)1837 tensorflow::Status ConvertFusedBatchNormOperator(
1838     const NodeDef& node, const TensorFlowImportFlags& tf_import_flags,
1839     const ModelFlags& model_flags, Model* model) {
1840   CHECK((node.op() == "FusedBatchNorm") || (node.op() == "FusedBatchNormV3"));
1841   TF_QCHECK_OK(CheckInputsCount(node, tf_import_flags, 5));
1842 
1843   // Declare shortcuts for the inputs.
1844   const std::string& gamma_input = node.input(1);
1845   const std::string& beta_input = node.input(2);
1846   const std::string& moving_mean_input = node.input(3);
1847   const std::string& moving_variance_input = node.input(4);
1848 
1849   // Create an array holding the epsilon value (typically, 0.001).
1850   const std::string epsilon_array_name =
1851       CreateConstArray<ArrayDataType::kFloat>(model,
1852                                               node.name() + "_epsilon_array",
1853                                               {GetFloatAttr(node, "epsilon")});
1854 
1855   // Add epsilon to the moving variance.
1856   const std::string epsilon_add_op_name = node.name() + "_epsilon";
1857   auto* epsilon_add_op = new AddOperator;
1858   epsilon_add_op->inputs.push_back(moving_variance_input);
1859   epsilon_add_op->inputs.push_back(epsilon_array_name);
1860   epsilon_add_op->outputs.push_back(epsilon_add_op_name);
1861   model->operators.emplace_back(epsilon_add_op);
1862 
1863   // Take the inverse square root of the (variance + epsilon).
1864   const std::string rsqrt_op_name = node.name() + "_rsqrt";
1865   auto* rsqrt_op = new TensorFlowRsqrtOperator;
1866   rsqrt_op->inputs.push_back(epsilon_add_op_name);
1867   rsqrt_op->outputs.push_back(rsqrt_op_name);
1868   model->operators.emplace_back(rsqrt_op);
1869 
1870   // Multiply the result by gamma.
1871   const std::string multiplier = node.name() + "_mul";
1872   auto* mul_op = new MulOperator;
1873   mul_op->inputs.push_back(rsqrt_op_name);
1874   mul_op->inputs.push_back(gamma_input);
1875   mul_op->outputs.push_back(multiplier);
1876   model->operators.emplace_back(mul_op);
1877 
1878   // Now we have all required inputs for the BatchNormalizationOperator.
1879   auto* op = new BatchNormalizationOperator;
1880   op->global_normalization = true;
1881 
1882   op->inputs.push_back(node.input(0));
1883   op->inputs.push_back(moving_mean_input);
1884   op->inputs.push_back(multiplier);
1885   op->inputs.push_back(beta_input);
1886   op->outputs.push_back(node.name());
1887 
1888   model->operators.emplace_back(op);
1889   return tensorflow::Status::OK();
1890 }
1891 
ConvertSpaceToBatchNDOperator(const NodeDef & node,const TensorFlowImportFlags & tf_import_flags,const ModelFlags & model_flags,Model * model)1892 tensorflow::Status ConvertSpaceToBatchNDOperator(
1893     const NodeDef& node, const TensorFlowImportFlags& tf_import_flags,
1894     const ModelFlags& model_flags, Model* model) {
1895   CHECK_EQ(node.op(), "SpaceToBatchND");
1896   TF_QCHECK_OK(CheckInputsCount(node, tf_import_flags, 3));
1897   CHECK_EQ(GetDataTypeAttr(node, "Tblock_shape"), DT_INT32);
1898   CHECK_EQ(GetDataTypeAttr(node, "Tpaddings"), DT_INT32);
1899   auto* op = new SpaceToBatchNDOperator;
1900   op->inputs.push_back(node.input(0));
1901   op->inputs.push_back(node.input(1));
1902   op->inputs.push_back(node.input(2));
1903   op->outputs.push_back(node.name());
1904   model->operators.emplace_back(op);
1905   return tensorflow::Status::OK();
1906 }
1907 
ConvertBatchToSpaceNDOperator(const NodeDef & node,const TensorFlowImportFlags & tf_import_flags,const ModelFlags & model_flags,Model * model)1908 tensorflow::Status ConvertBatchToSpaceNDOperator(
1909     const NodeDef& node, const TensorFlowImportFlags& tf_import_flags,
1910     const ModelFlags& model_flags, Model* model) {
1911   CHECK_EQ(node.op(), "BatchToSpaceND");
1912   TF_QCHECK_OK(CheckInputsCount(node, tf_import_flags, 3));
1913   CHECK_EQ(GetDataTypeAttr(node, "Tblock_shape"), DT_INT32);
1914   CHECK_EQ(GetDataTypeAttr(node, "Tcrops"), DT_INT32);
1915   auto* op = new BatchToSpaceNDOperator;
1916   op->inputs.push_back(node.input(0));
1917   op->inputs.push_back(node.input(1));
1918   op->inputs.push_back(node.input(2));
1919   op->outputs.push_back(node.name());
1920   model->operators.emplace_back(op);
1921   return tensorflow::Status::OK();
1922 }
1923 
1924 template <typename T>
ConvertReduceOperator(const NodeDef & node,const TensorFlowImportFlags & tf_import_flags,const ModelFlags & model_flags,Model * model)1925 tensorflow::Status ConvertReduceOperator(
1926     const NodeDef& node, const TensorFlowImportFlags& tf_import_flags,
1927     const ModelFlags& model_flags, Model* model) {
1928   TF_QCHECK_OK(CheckInputsCount(node, tf_import_flags, 2));
1929   auto* op = new T;
1930   op->inputs.push_back(node.input(0));
1931   op->inputs.push_back(node.input(1));
1932   op->outputs.push_back(node.name());
1933   model->operators.emplace_back(op);
1934   if (HasAttr(node, "keepdims")) {
1935     op->keep_dims = GetBoolAttr(node, "keepdims");
1936   } else if (HasAttr(node, "keep_dims")) {
1937     op->keep_dims = GetBoolAttr(node, "keep_dims");
1938   }
1939   return tensorflow::Status::OK();
1940 }
1941 
1942 // TODO(b/139320642): Add test when fused op is supported.
ConvertSvdfOperator(const NodeDef & node,const TensorFlowImportFlags & tf_import_flags,const ModelFlags & model_flags,Model * model)1943 tensorflow::Status ConvertSvdfOperator(
1944     const NodeDef& node, const TensorFlowImportFlags& tf_import_flags,
1945     const ModelFlags& model_flags, Model* model) {
1946   CHECK_EQ(node.op(), "Svdf");
1947   const int input_size = GetInputsCount(node, tf_import_flags);
1948   QCHECK(input_size == 4 || input_size == 5)
1949       << "Svdf node expects 3 or 4 inputs other than control dependencies: "
1950       << node.DebugString();
1951   bool has_bias = (input_size == 5);
1952   auto* op = new SvdfOperator;
1953   int index = 0;
1954   op->inputs.push_back(node.input(index++));
1955   op->inputs.push_back(node.input(index++));
1956   op->inputs.push_back(node.input(index++));
1957   if (has_bias) {
1958     op->inputs.push_back(node.input(index++));
1959   }
1960   op->inputs.push_back(node.input(index));
1961   op->outputs.push_back(node.name());
1962   if (node.attr().at("ActivationFunction").s() == "Relu") {
1963     op->fused_activation_function = FusedActivationFunctionType::kRelu;
1964   } else {
1965     op->fused_activation_function = FusedActivationFunctionType::kNone;
1966   }
1967   op->rank = node.attr().at("Rank").i();
1968   model->operators.emplace_back(op);
1969   return tensorflow::Status::OK();
1970 }
1971 
1972 // This is just bare bones support to get the shapes to propagate.
ConvertTransposeConvOperator(const NodeDef & node,const TensorFlowImportFlags & tf_import_flags,const ModelFlags & model_flags,Model * model)1973 tensorflow::Status ConvertTransposeConvOperator(
1974     const NodeDef& node, const TensorFlowImportFlags& tf_import_flags,
1975     const ModelFlags& model_flags, Model* model) {
1976   CHECK_EQ(node.op(), "Conv2DBackpropInput");
1977   TF_QCHECK_OK(CheckInputsCount(node, tf_import_flags, 3));
1978   auto* op = new TransposeConvOperator;
1979   op->inputs.push_back(node.input(0));
1980   op->inputs.push_back(node.input(1));
1981   op->inputs.push_back(node.input(2));
1982   op->outputs.push_back(node.name());
1983   const auto& strides = GetListAttr(node, "strides");
1984   op->stride_height = strides.i(1);
1985   op->stride_width = strides.i(2);
1986   CHECK_EQ(strides.i_size(), 4)
1987       << "Can only import TransposeConv ops with 4D strides. TensorFlow op \""
1988       << node.name() << "\" has " << strides.i_size() << "D strides.";
1989   CHECK((strides.i(0) == 1) && (strides.i(3) == 1))
1990       << "Can only import TransposeConv ops with striding along the height "
1991          "(1st) or width (2nd) axis. TensorFlow op \""
1992       << node.name() << "\" had strides:[ " << strides.i(0) << ", "
1993       << strides.i(1) << ", " << strides.i(2) << ", " << strides.i(3) << "].";
1994   op->stride_height = strides.i(1);
1995   op->stride_width = strides.i(2);
1996   if (HasAttr(node, "dilations")) {
1997     const auto& dilations = GetListAttr(node, "dilations");
1998     CHECK_EQ(dilations.i_size(), 4)
1999         << "Dilation unsupported in TransposeConv. TensorFlow op \""
2000         << node.name() << "\" had dilations";
2001     CHECK((dilations.i(0) == 1) && (dilations.i(1) == 1) &&
2002           (dilations.i(2) == 1) && (dilations.i(3) == 1))
2003         << "Dilation unsupported in TransposeConv. TensorFlow op \""
2004         << node.name() << "\" had dilations:[ " << dilations.i(0) << ", "
2005         << dilations.i(1) << ", " << dilations.i(2) << ", " << dilations.i(3)
2006         << "].";
2007   }
2008 
2009   const std::string& weights_name = node.input(TransposeConvOperator::WEIGHTS);
2010   const std::string& transposed_weights_name = weights_name + "_transposed";
2011   // Check if a TransposeOperator was already created for these weights
2012   // (can happen when multiple layers share the same weights).
2013   const Operator* existing_transpose =
2014       GetOpWithOutput(*model, transposed_weights_name);
2015   if (existing_transpose) {
2016     CHECK(existing_transpose->type == OperatorType::kTranspose);
2017   } else {
2018     // Transpose weights from HWOI order to OHWI order, which is more efficient
2019     // for computation. (Note that TensorFlow considers the order as HWIO
2020     // because they consider this a backward conv, inverting the sense of
2021     // input/output.)
2022     TransposeOperator* transpose = new TransposeOperator;
2023     std::string perm_array = CreateConstArray<ArrayDataType::kInt32>(
2024         model, node.name() + "_transpose_perm", {2, 0, 1, 3});
2025     transpose->inputs = {weights_name, perm_array};
2026     transpose->outputs = {transposed_weights_name};
2027     model->operators.emplace_back(transpose);
2028   }
2029   op->inputs[1] = transposed_weights_name;
2030 
2031   auto const& padding = GetStringAttr(node, "padding");
2032   if (padding == "SAME") {
2033     op->padding.type = PaddingType::kSame;
2034   } else if (padding == "VALID") {
2035     op->padding.type = PaddingType::kValid;
2036   } else {
2037     LOG(FATAL) << "Only SAME and VALID padding supported on "
2038                   "Conv2DBackpropInput nodes.";
2039   }
2040   model->operators.emplace_back(op);
2041   return tensorflow::Status::OK();
2042 }
2043 
ConvertRangeOperator(const NodeDef & node,const TensorFlowImportFlags & tf_import_flags,const ModelFlags & model_flags,Model * model)2044 tensorflow::Status ConvertRangeOperator(
2045     const NodeDef& node, const TensorFlowImportFlags& tf_import_flags,
2046     const ModelFlags& model_flags, Model* model) {
2047   CHECK_EQ(node.op(), "Range");
2048   TF_QCHECK_OK(CheckInputsCount(node, tf_import_flags, 3));
2049   auto* op = new RangeOperator;
2050   if (HasAttr(node, "Tidx")) {
2051     const auto dtype = toco::GetDataTypeAttr(node, "Tidx");
2052     CHECK(dtype == DT_UINT8 || dtype == DT_INT32 || dtype == DT_INT64 ||
2053           dtype == DT_FLOAT);
2054     op->dtype = ConvertDataType(dtype);
2055   }
2056   op->inputs.push_back(node.input(0));
2057   op->inputs.push_back(node.input(1));
2058   op->inputs.push_back(node.input(2));
2059   op->outputs.push_back(node.name());
2060 
2061   model->operators.emplace_back(op);
2062   return tensorflow::Status::OK();
2063 }
2064 
2065 // Note that it's easy to confuse/conflate "Stack" and "Pack" operators, but
2066 // they aren't the same thing.  tf.stack results in a "Pack" operator.  "Stack"
2067 // operators also exist, but involve manipulating the TF runtime stack, and are
2068 // not directly related to tf.stack() usage.
ConvertPackOperator(const NodeDef & node,const TensorFlowImportFlags & tf_import_flags,const ModelFlags & model_flags,Model * model)2069 tensorflow::Status ConvertPackOperator(
2070     const NodeDef& node, const TensorFlowImportFlags& tf_import_flags,
2071     const ModelFlags& model_flags, Model* model) {
2072   CHECK_EQ(node.op(), "Pack");
2073   auto op = absl::make_unique<PackOperator>();
2074   const int num_inputs = GetInputsCount(node, tf_import_flags);
2075   QCHECK_GE(num_inputs, 1)
2076       << node.op()
2077       << " node expects at least 1 input other than control dependencies: "
2078       << node.DebugString();
2079   CHECK_EQ(num_inputs, GetIntAttr(node, "N"));
2080   for (int i = 0; i < num_inputs; ++i) {
2081     op->inputs.push_back(node.input(i));
2082   }
2083   op->values_count = HasAttr(node, "N") ? GetIntAttr(node, "N") : num_inputs;
2084   op->axis = HasAttr(node, "axis") ? GetIntAttr(node, "axis") : 0;
2085   op->dtype = ConvertDataType(toco::GetDataTypeAttr(node, "T"));
2086   op->outputs.push_back(node.name());
2087   model->operators.emplace_back(std::move(op));
2088   return tensorflow::Status::OK();
2089 }
2090 
ConvertUnpackOperator(const NodeDef & node,const TensorFlowImportFlags & tf_import_flags,const ModelFlags & model_flags,Model * model)2091 tensorflow::Status ConvertUnpackOperator(
2092     const NodeDef& node, const TensorFlowImportFlags& tf_import_flags,
2093     const ModelFlags& model_flags, Model* model) {
2094   CHECK_EQ(node.op(), "Unpack");
2095   auto op = absl::make_unique<UnpackOperator>();
2096   const int num_inputs = GetInputsCount(node, tf_import_flags);
2097   QCHECK_EQ(num_inputs, 1);
2098   op->inputs.push_back(node.input(0));
2099   op->num = GetIntAttr(node, "num");
2100   op->axis = HasAttr(node, "axis") ? GetIntAttr(node, "axis") : 0;
2101   op->dtype = ConvertDataType(toco::GetDataTypeAttr(node, "T"));
2102 
2103   op->outputs.push_back(node.name());  // Implicit :0.
2104   for (int i = 1; i < op->num; ++i) {
2105     op->outputs.push_back(node.name() + ":" + std::to_string(i));
2106   }
2107   model->operators.emplace_back(std::move(op));
2108   return tensorflow::Status::OK();
2109 }
2110 
2111 // Some TensorFlow ops only occur in graph cycles, representing
2112 // control flow. We do not currently support control flow, so we wouldn't
2113 // be able to fully support such graphs, including performing inference,
2114 // anyway. However, rather than erroring out early on graphs being cyclic,
2115 // it helps to at least support these just enough to allow getting a
2116 // graph visualization. This is not trivial, as we require graphs to be
2117 // acyclic aside from RNN back-edges. The solution is to special-case
2118 // such ops as RNN back-edges, which is technically incorrect (does not
2119 // allow representing the op's semantics) but good enough to get a
2120 // graph visualization.
ConvertOperatorSpecialCasedAsRNNBackEdge(const NodeDef & node,const TensorFlowImportFlags & tf_import_flags,const ModelFlags & model_flags,Model * model)2121 tensorflow::Status ConvertOperatorSpecialCasedAsRNNBackEdge(
2122     const NodeDef& node, const TensorFlowImportFlags& tf_import_flags,
2123     const ModelFlags& model_flags, Model* model) {
2124   // At the moment, the only type of operator special-cased in this way is
2125   // NextIteration, occurring only in control-flow cycles.
2126   CHECK_EQ(node.op(), "NextIteration");
2127   CHECK_EQ(node.input_size(), 1);
2128   auto* rnn_state = model->flags.add_rnn_states();
2129   // This RNN state is not explicitly created by the user, so it's
2130   // OK for some later graph transformation to discard it.
2131   rnn_state->set_discardable(true);
2132   rnn_state->set_state_array(node.name());
2133   rnn_state->set_back_edge_source_array(node.input(0));
2134   // TODO(tianjuny): Temporary set the size to 1 to avoid transient array
2135   // allocation crash. The real value should depend on the hidden_size of RNN.
2136   rnn_state->set_size(1);
2137   return tensorflow::Status::OK();
2138 }
2139 
ConvertShapeOperator(const NodeDef & node,const TensorFlowImportFlags & tf_import_flags,const ModelFlags & model_flags,Model * model)2140 tensorflow::Status ConvertShapeOperator(
2141     const NodeDef& node, const TensorFlowImportFlags& tf_import_flags,
2142     const ModelFlags& model_flags, Model* model) {
2143   CHECK_EQ(node.op(), "Shape");
2144   TF_QCHECK_OK(CheckInputsCount(node, tf_import_flags, 1));
2145   const auto out_type =
2146       HasAttr(node, "out_type") ? GetDataTypeAttr(node, "out_type") : DT_INT32;
2147   CHECK(out_type == DT_INT64 || out_type == DT_INT32);
2148   auto op = absl::make_unique<TensorFlowShapeOperator>();
2149   op->output_data_type = ConvertDataType(out_type);
2150   op->inputs.push_back(node.input(0));
2151   op->outputs.push_back(node.name());
2152   model->operators.push_back(std::move(op));
2153   return tensorflow::Status::OK();
2154 }
2155 
ConvertReverseSequenceOperator(const NodeDef & node,const TensorFlowImportFlags & tf_import_flags,const ModelFlags & model_flags,Model * model)2156 tensorflow::Status ConvertReverseSequenceOperator(
2157     const NodeDef& node, const TensorFlowImportFlags& tf_import_flags,
2158     const ModelFlags& model_flags, Model* model) {
2159   CHECK_EQ(node.op(), "ReverseSequence");
2160   TF_QCHECK_OK(CheckInputsCount(node, tf_import_flags, 2));
2161   auto op = absl::make_unique<ReverseSequenceOperator>();
2162   if (HasAttr(node, "seq_dim")) {
2163     op->seq_dim = GetIntAttr(node, "seq_dim");
2164   }
2165   // In tf.reverse_sequence, batch_dim defaults to 0.
2166   op->batch_dim =
2167       HasAttr(node, "batch_dim") ? GetIntAttr(node, "batch_dim") : 0;
2168   const int num_inputs = GetInputsCount(node, tf_import_flags);
2169   for (int i = 0; i < num_inputs; ++i) {
2170     op->inputs.push_back(node.input(i));
2171   }
2172   op->outputs.push_back(node.name());
2173   model->operators.push_back(std::move(op));
2174   return tensorflow::Status::OK();
2175 }
2176 
StripCaretFromArrayNames(Model * model)2177 void StripCaretFromArrayNames(Model* model) {
2178   for (auto& op : model->operators) {
2179     for (auto& input : op->inputs) {
2180       input = std::string(absl::StripPrefix(input, "^"));
2181     }
2182     for (auto& output : op->outputs) {
2183       output = std::string(absl::StripPrefix(output, "^"));
2184     }
2185   }
2186   for (auto& array : model->GetArrayMap()) {
2187     if (absl::StartsWith(array.first, "^")) {
2188       LOG(FATAL) << "What?";
2189     }
2190   }
2191 }
2192 
StripZeroOutputIndexFromInputs(NodeDef * node)2193 void StripZeroOutputIndexFromInputs(NodeDef* node) {
2194   for (auto& input : *node->mutable_input()) {
2195     input = std::string(absl::StripSuffix(input, ":0"));
2196   }
2197 }
2198 
2199 // In TensorFlow GraphDef, when a node has multiple outputs, they are named
2200 // name:0, name:1, ...
2201 // where 'name' is the node's name(). Just 'name' is an equivalent shorthand
2202 // form for name:0.
2203 // A TensorFlow GraphDef does not explicitly list all the outputs of each node
2204 // (unlike inputs), it being implied by the node's name and operator type
2205 // (the latter implies the number of outputs).
2206 // This makes it non-trivial for us to reconstruct the list of all arrays
2207 // present in the graph and, for each operator, the list of its outputs.
2208 // We do that by taking advantage of the fact that
2209 // at least each node lists explicitly its inputs, so after we've loaded
2210 // all nodes, we can use that information.
AddExtraOutputs(Model * model)2211 void AddExtraOutputs(Model* model) {
2212   // Construct the list of all arrays consumed by anything in the graph.
2213   std::vector<std::string> consumed_arrays;
2214   // Add arrays consumed by an op.
2215   for (const auto& consumer_op : model->operators) {
2216     for (const std::string& input : consumer_op->inputs) {
2217       consumed_arrays.push_back(input);
2218     }
2219   }
2220   // Add global outputs of the model.
2221   for (const std::string& output_array : model->flags.output_arrays()) {
2222     consumed_arrays.push_back(output_array);
2223   }
2224   // Add arrays consumed by a RNN back-edge.
2225   for (const auto& rnn_state : model->flags.rnn_states()) {
2226     consumed_arrays.push_back(rnn_state.back_edge_source_array());
2227   }
2228   // Now add operator outputs so that all arrays that are consumed,
2229   // are produced.
2230   for (const std::string& consumed_array : consumed_arrays) {
2231     // Test if consumed_array is already the output of some op.
2232     // This has occurred in a model where separate nodes had names of the form
2233     // foo:$i with the same base name foo.
2234     if (GetOpWithOutput(*model, consumed_array)) {
2235       continue;
2236     }
2237     // Split the consumed array name into the form name:output_index.
2238     const std::vector<std::string>& split = absl::StrSplit(consumed_array, ':');
2239     // If not of the form name:output_index, then this is not an additional
2240     // output of a node with multiple outputs, so nothing to do here.
2241     if (split.size() != 2) {
2242       continue;
2243     }
2244     int output_index = 0;
2245     if (!absl::SimpleAtoi(split[1], &output_index)) {
2246       continue;
2247     }
2248     // Each op is initially recorded as producing at least the array that
2249     // has its name. We use that to identify the producer node.
2250     auto* producer_op = GetOpWithOutput(*model, split[0]);
2251     if (!producer_op) {
2252       continue;
2253     }
2254     // Add extra outputs to that producer node, all the way to the
2255     // output_index.
2256     while (producer_op->outputs.size() <= output_index) {
2257       using toco::port::StringF;
2258       producer_op->outputs.push_back(
2259           StringF("%s:%d", split[0], producer_op->outputs.size()));
2260     }
2261   }
2262 }
2263 
InlineAllFunctions(GraphDef * graphdef)2264 bool InlineAllFunctions(GraphDef* graphdef) {
2265   if (graphdef->library().function().empty()) {
2266     VLOG(kLogLevelModelUnchanged) << "No functions to inline.";
2267     return false;
2268   }
2269 
2270   // Override "_noinline" attribute on all functions
2271   GraphDef graphdef_copy(*graphdef);
2272   for (auto& function :
2273        (*graphdef_copy.mutable_library()->mutable_function())) {
2274     auto* attributes = function.mutable_attr();
2275     if (attributes->count(tensorflow::kNoInlineAttr) != 0) {
2276       (*attributes)[tensorflow::kNoInlineAttr].set_b(false);
2277     }
2278   }
2279 
2280   // Construct minimum resources needed to use ExpandInlineFunctions().
2281   tensorflow::SessionOptions options;
2282   auto* device_count = options.config.mutable_device_count();
2283   device_count->insert({"CPU", 1});
2284   std::vector<std::unique_ptr<tensorflow::Device>> devices;
2285   TF_CHECK_OK(tensorflow::DeviceFactory::AddDevices(
2286       options, "/job:localhost/replica:0/task:0", &devices));
2287 
2288   tensorflow::FunctionLibraryDefinition fld(tensorflow::OpRegistry::Global(),
2289                                             graphdef_copy.library());
2290   tensorflow::StaticDeviceMgr device_mgr(std::move(devices));
2291   tensorflow::ProcessFunctionLibraryRuntime pflr(
2292       &device_mgr, tensorflow::Env::Default(), &options.config,
2293       TF_GRAPH_DEF_VERSION, &fld,
2294       options.config.graph_options().optimizer_options(), nullptr);
2295   tensorflow::FunctionLibraryRuntime* flr;
2296   flr = pflr.GetFLR("/job:localhost/replica:0/task:0/cpu:0");
2297 
2298   tensorflow::Graph graph(fld);
2299   tensorflow::ImportGraphDefOptions gc_opts;
2300   gc_opts.validate_shape = false;
2301   const auto& tf_convert_status = tensorflow::ImportGraphDef(
2302       gc_opts, graphdef_copy, &graph, nullptr, nullptr);
2303   if (!tf_convert_status.ok()) {
2304     LOG(ERROR) << "tensorflow::ImportGraphDef failed with status: "
2305                << tf_convert_status.ToString();
2306     return false;
2307   }
2308 
2309   // Iterate over the graph until there are no more nodes to be inlined.
2310   bool graph_modified = false;
2311   while (tensorflow::ExpandInlineFunctions(flr, &graph)) {
2312     graph_modified = true;
2313   }
2314 
2315   // Output inlined graph
2316   if (graph_modified) {
2317     LOG(INFO) << "Found and inlined TensorFlow functions.";
2318     graph.ToGraphDef(graphdef);
2319   }
2320   return graph_modified;
2321 }
2322 
ConvertTopKV2Operator(const NodeDef & node,const TensorFlowImportFlags & tf_import_flags,const ModelFlags & model_flags,Model * model)2323 tensorflow::Status ConvertTopKV2Operator(
2324     const NodeDef& node, const TensorFlowImportFlags& tf_import_flags,
2325     const ModelFlags& model_flags, Model* model) {
2326   CHECK((node.op() == "TopK") || (node.op() == "TopKV2"));
2327   auto op = absl::make_unique<TopKV2Operator>();
2328   op->inputs.push_back(node.input(0));
2329   // K can be encoded as attr (TopK) convert it to a const.
2330   if (HasAttr(node, "k")) {
2331     std::string k_array = CreateConstArray<ArrayDataType::kInt32>(
2332         model, node.name() + "k", {static_cast<int32>(GetIntAttr(node, "k"))});
2333     op->inputs.push_back(k_array);
2334   } else {
2335     TF_QCHECK_OK(CheckInputsCount(node, tf_import_flags, 2));
2336     op->inputs.push_back(node.input(1));
2337   }
2338   // The op has two outputs.
2339   op->outputs.push_back(node.name());
2340   op->outputs.push_back(node.name() + ":1");
2341   model->operators.emplace_back(op.release());
2342   return tensorflow::Status::OK();
2343 }
2344 
ConvertDynamicPartitionOperator(const NodeDef & node,const TensorFlowImportFlags & tf_import_flags,const ModelFlags & model_flags,Model * model)2345 tensorflow::Status ConvertDynamicPartitionOperator(
2346     const NodeDef& node, const TensorFlowImportFlags& tf_import_flags,
2347     const ModelFlags& model_flags, Model* model) {
2348   auto op = absl::make_unique<DynamicPartitionOperator>();
2349   CHECK(HasAttr(node, "num_partitions"));
2350   op->num_partitions = GetIntAttr(node, "num_partitions");
2351   TF_QCHECK_OK(CheckInputsCount(node, tf_import_flags, 2));
2352   op->inputs.push_back(node.input(0));
2353   op->inputs.push_back(node.input(1));
2354   CHECK_GT(op->num_partitions, 1);
2355   op->outputs.push_back(node.name());  // Implicit :0.
2356   for (int i = 1; i < op->num_partitions; ++i) {
2357     op->outputs.push_back(node.name() + ":" + std::to_string(i));
2358   }
2359   model->operators.emplace_back(op.release());
2360   return tensorflow::Status::OK();
2361 }
2362 
ConvertDynamicStitchOperator(const NodeDef & node,const TensorFlowImportFlags & tf_import_flags,const ModelFlags & model_flags,Model * model)2363 tensorflow::Status ConvertDynamicStitchOperator(
2364     const NodeDef& node, const TensorFlowImportFlags& tf_import_flags,
2365     const ModelFlags& model_flags, Model* model) {
2366   // The parallel and non-parallel variants are the same besides whether they
2367   // have a parallel loop; there are no behavioral differences.
2368   CHECK(node.op() == "DynamicStitch" || node.op() == "ParallelDynamicStitch");
2369   auto op = absl::make_unique<DynamicStitchOperator>();
2370   CHECK(HasAttr(node, "N"));
2371   op->num_partitions = GetIntAttr(node, "N");
2372   // Expect all ID partitions + all value partitions.
2373   TF_QCHECK_OK(CheckInputsCount(node, tf_import_flags, op->num_partitions * 2));
2374   for (int i = 0; i < op->num_partitions * 2; ++i) {
2375     op->inputs.push_back(node.input(i));
2376   }
2377   op->outputs.push_back(node.name());
2378   model->operators.emplace_back(op.release());
2379   return tensorflow::Status::OK();
2380 }
2381 
ConvertSparseToDenseOperator(const NodeDef & node,const TensorFlowImportFlags & tf_import_flags,const ModelFlags & model_flags,Model * model)2382 tensorflow::Status ConvertSparseToDenseOperator(
2383     const NodeDef& node, const TensorFlowImportFlags& tf_import_flags,
2384     const ModelFlags& model_flags, Model* model) {
2385   CHECK_EQ(node.op(), "SparseToDense");
2386   TF_QCHECK_OK(CheckInputsCount(node, tf_import_flags, 4));
2387 
2388   auto* op = new SparseToDenseOperator;
2389   for (const std::string& input : node.input()) {
2390     op->inputs.push_back(input);
2391   }
2392   op->outputs.push_back(node.name());
2393 
2394   op->validate_indices = HasAttr(node, "validate_indices")
2395                              ? GetBoolAttr(node, "validate_indices")
2396                              : true;
2397   model->operators.emplace_back(op);
2398   return tensorflow::Status::OK();
2399 }
2400 
ConvertOneHotOperator(const NodeDef & node,const TensorFlowImportFlags & tf_import_flags,const ModelFlags & model_flags,Model * model)2401 tensorflow::Status ConvertOneHotOperator(
2402     const NodeDef& node, const TensorFlowImportFlags& tf_import_flags,
2403     const ModelFlags& model_flags, Model* model) {
2404   CHECK_EQ(node.op(), "OneHot");
2405   TF_QCHECK_OK(CheckInputsCount(node, tf_import_flags, 4));
2406 
2407   const auto dtype = GetDataTypeAttr(node, "T");
2408   // TODO(b/111744875): Support DT_UINT8 and quantization.
2409   CHECK(dtype == DT_INT32 || dtype == DT_INT64 || dtype == DT_FLOAT ||
2410         dtype == DT_BOOL);
2411 
2412   auto op = absl::make_unique<OneHotOperator>();
2413   op->axis = HasAttr(node, "axis") ? GetIntAttr(node, "axis") : -1;
2414   for (const std::string& input : node.input()) {
2415     op->inputs.push_back(input);
2416   }
2417   op->outputs.push_back(node.name());
2418   model->operators.emplace_back(op.release());
2419   return tensorflow::Status::OK();
2420 }
2421 
ConvertCTCBeamSearchDecoderOperator(const NodeDef & node,const TensorFlowImportFlags & tf_import_flags,const ModelFlags & model_flags,Model * model)2422 tensorflow::Status ConvertCTCBeamSearchDecoderOperator(
2423     const NodeDef& node, const TensorFlowImportFlags& tf_import_flags,
2424     const ModelFlags& model_flags, Model* model) {
2425   CHECK_EQ(node.op(), "CTCBeamSearchDecoder");
2426   TF_QCHECK_OK(CheckInputsCount(node, tf_import_flags, 2));
2427 
2428   auto* op = new CTCBeamSearchDecoderOperator;
2429   for (const std::string& input : node.input()) {
2430     op->inputs.push_back(input);
2431   }
2432 
2433   op->beam_width =
2434       HasAttr(node, "beam_width") ? GetIntAttr(node, "beam_width") : 1;
2435   op->top_paths =
2436       HasAttr(node, "top_paths") ? GetIntAttr(node, "top_paths") : 1;
2437   op->merge_repeated = HasAttr(node, "merge_repeated")
2438                            ? GetBoolAttr(node, "merge_repeated")
2439                            : true;
2440 
2441   // There are top_paths + 1 outputs.
2442   op->outputs.push_back(node.name());  // Implicit :0.
2443   for (int i = 0; i < op->top_paths; ++i) {
2444     op->outputs.push_back(node.name() + ":" + std::to_string(i + 1));
2445   }
2446   model->operators.emplace_back(op);
2447   return tensorflow::Status::OK();
2448 }
2449 
2450 // This isn't a TensorFlow builtin op. Currently this node can only be generated
2451 // with TfLite OpHint API.
ConvertUnidirectionalSequenceLstm(const NodeDef & node,const TensorFlowImportFlags & tf_import_flags,const ModelFlags & model_flags,Model * model)2452 tensorflow::Status ConvertUnidirectionalSequenceLstm(
2453     const NodeDef& node, const TensorFlowImportFlags& tf_import_flags,
2454     const ModelFlags& model_flags, Model* model) {
2455   DCHECK_EQ(node.op(), "UnidirectionalSequenceLstm");
2456 
2457   const auto& indices = GetListAttr(node, "_tflite_input_indices");
2458 
2459   auto* op = new UnidirectionalSequenceLstmOperator();
2460 
2461   // The input size needs to be the same as the TfLite UniDirectionalSequence
2462   // Lstm implementation.
2463   const int kInputsSize = 20;
2464 
2465   op->inputs.resize(kInputsSize);
2466 
2467   if (indices.i_size() != node.input().size()) {
2468     // New version, the optional inputs are filled with constant nodes.
2469     int count = 0;
2470     for (int idx = 0; idx < kInputsSize; ++idx) {
2471       if (count < indices.i_size() && indices.i(count) == idx) {
2472         // Specified input.
2473         op->inputs[idx] = node.input(idx);
2474         count++;
2475       } else {
2476         // Optional input.
2477         std::string optional_name = node.name() + "_" + std::to_string(idx);
2478         model->CreateOptionalArray(optional_name);
2479         op->inputs[idx] = optional_name;
2480       }
2481     }
2482   } else {  // Legacy version.
2483     std::vector<bool> done(kInputsSize);
2484     int idx = 0;
2485     for (const std::string& input : node.input()) {
2486       int real_index = indices.i(idx);
2487       op->inputs[real_index] = (input);
2488       done[real_index] = true;
2489       idx++;
2490     }
2491 
2492     for (int idx = 0; idx < done.size(); idx++) {
2493       if (!done[idx]) {
2494         std::string optional_name = node.name() + "_" + std::to_string(idx);
2495         model->CreateOptionalArray(optional_name);
2496         op->inputs[idx] = optional_name;
2497       }
2498     }
2499   }
2500 
2501   // There're three outputs, only the last one is required.
2502   op->outputs.push_back(node.name() + ":2");
2503   model->operators.emplace_back(op);
2504 
2505   return tensorflow::Status::OK();
2506 }
2507 
ConvertLeakyReluOperator(const NodeDef & node,const TensorFlowImportFlags & tf_import_flags,const ModelFlags & model_flags,Model * model)2508 tensorflow::Status ConvertLeakyReluOperator(
2509     const NodeDef& node, const TensorFlowImportFlags& tf_import_flags,
2510     const ModelFlags& model_flags, Model* model) {
2511   CHECK_EQ(node.op(), "LeakyRelu");
2512   TF_QCHECK_OK(CheckInputsCount(node, tf_import_flags, 1));
2513   CHECK_EQ(GetDataTypeAttr(node, "T"), DT_FLOAT);
2514   const auto& input_name = node.input(0);
2515   auto* op = new LeakyReluOperator;
2516   op->inputs.push_back(input_name);
2517   op->outputs.push_back(node.name());
2518   op->alpha = GetFloatAttr(node, "alpha");
2519   model->operators.emplace_back(op);
2520   return tensorflow::Status::OK();
2521 }
2522 
ConvertUnidirectionalSequenceRnn(const NodeDef & node,const TensorFlowImportFlags & tf_import_flags,const ModelFlags & model_flags,Model * model)2523 tensorflow::Status ConvertUnidirectionalSequenceRnn(
2524     const NodeDef& node, const TensorFlowImportFlags& tf_import_flags,
2525     const ModelFlags& model_flags, Model* model) {
2526   DCHECK_EQ(node.op(), "UnidirectionalSequenceRnn");
2527 
2528   const auto& indices = GetListAttr(node, "_tflite_input_indices");
2529   if (indices.i_size() != node.input().size()) {
2530     return tensorflow::errors::InvalidArgument("Input size does not match.");
2531   }
2532 
2533   auto* op = new UnidirectionalSequenceRnnOperator();
2534   for (const std::string& input : node.input()) {
2535     op->inputs.push_back(input);
2536   }
2537   // Only use the last one as input.
2538   op->outputs.push_back(node.name() + ":1");
2539   model->operators.emplace_back(op);
2540 
2541   return tensorflow::Status::OK();
2542 }
2543 
2544 }  // namespace
2545 
2546 namespace internal {
2547 
// Signature shared by every node converter: it translates one TensorFlow
// NodeDef into toco operators and/or model state added to `model`, and
// returns a non-OK status on failure.
using ConverterType = tensorflow::Status (*)(
    const NodeDef& node, const TensorFlowImportFlags& tf_import_flags,
    const ModelFlags& model_flags, Model* model);
// Maps a TensorFlow op type name (e.g. "Range") to its converter.
using ConverterMapType = std::unordered_map<std::string, ConverterType>;
2552 
GetTensorFlowNodeConverterMapForFlex()2553 ConverterMapType GetTensorFlowNodeConverterMapForFlex() {
2554   return std::unordered_map<std::string, ConverterType>({
2555       // We need to let TOCO convert Placeholder information into
2556       // array data, so that the data types are correct.
2557       {"LegacyFedInput", ConvertPlaceholderOperator},
2558       {"Placeholder", ConvertPlaceholderOperator},
2559       {"Const", ConditionallyConvertConstOperator},
2560   });
2561 }
2562 
GetTensorFlowNodeConverterMap()2563 ConverterMapType GetTensorFlowNodeConverterMap() {
2564   return std::unordered_map<std::string, ConverterType>({
2565       {"Abs", ConvertSimpleOperator<AbsOperator, kAnyNumInputs, 1>},
2566       {"Add", ConvertSimpleOperator<AddOperator, 2, 1>},
2567       {"AddV2", ConvertSimpleOperator<AddOperator, 2, 1>},
2568       {"AddN", ConvertSimpleOperator<AddNOperator, kAnyNumInputs, 1>},
2569       {"All", ConvertSimpleOperator<TensorFlowAllOperator, kAnyNumInputs, 1>},
2570       {"Any", ConvertReduceOperator<TensorFlowAnyOperator>},
2571       {"ArgMax", ConvertArgMaxOperator},
2572       {"ArgMin", ConvertArgMinOperator},
2573       {"Assert",
2574        ConvertSimpleOperator<TensorFlowAssertOperator, kAnyNumInputs, 1>},
2575       {"AvgPool", ConvertAvgPoolOperator},
2576       {"BatchMatMul", ConvertBatchMatMulOperator},
2577       {"BatchMatMulV2", ConvertBatchMatMulOperator},
2578       {"BatchNormWithGlobalNormalization",
2579        ConvertBatchNormWithGlobalNormalizationOperator},
2580       {"BatchToSpaceND", ConvertBatchToSpaceNDOperator},
2581       {"BiasAdd", ConvertBiasAddOperator},
2582       {"Cast", ConvertCastOperator},
2583       {"Ceil", ConvertCeilOperator},
2584       {"CheckNumerics", ConvertIdentityOperator},
2585       {"Concat", ConvertConcatOperator},
2586       {"ConcatV2", ConvertConcatOperator},
2587       {"Const", ConvertConstOperator},
2588       {"Conv2D", ConvertConvOperator},
2589       {"Conv2DBackpropInput", ConvertTransposeConvOperator},
2590       {"Cos", ConvertSimpleOperator<CosOperator, 1, 1>},
2591       {"CTCBeamSearchDecoder", ConvertCTCBeamSearchDecoderOperator},
2592       {"DepthToSpace", ConvertDepthToSpaceOperator},
2593       {"DepthwiseConv2dNative", ConvertDepthwiseConvOperator},
2594       {"Div", ConvertSimpleOperator<DivOperator, 2, 1>},
2595       {"DynamicPartition", ConvertDynamicPartitionOperator},
2596       {"DynamicStitch", ConvertDynamicStitchOperator},
2597       {"Elu", ConvertSimpleOperator<EluOperator, 1, 1>},
2598       {"Equal", ConvertSimpleOperator<TensorFlowEqualOperator, 2, 1>},
2599       {"Exp", ConvertSimpleOperator<ExpOperator, 1, 1>},
2600       {"ExpandDims", ConvertSimpleOperator<ExpandDimsOperator, 2, 1>},
2601       {"FakeQuantWithMinMaxArgs", ConvertFakeQuantWithMinMaxArgs},
2602       {"FakeQuantWithMinMaxVars", ConvertFakeQuantWithMinMaxVars},
2603       {"Fill", ConvertSimpleOperator<FillOperator, 2, 1>},
2604       {"Floor", ConvertFloorOperator},
2605       {"FloorDiv", ConvertSimpleOperator<FloorDivOperator, 2, 1>},
2606       {"FloorMod", ConvertSimpleOperator<FloorModOperator, 2, 1>},
2607       {"FusedBatchNorm", ConvertFusedBatchNormOperator},
2608       {"FusedBatchNormV3", ConvertFusedBatchNormOperator},
2609       {"Gather", ConvertGatherOperator},
2610       {"GatherV2", ConvertGatherOperator},
2611       {"GatherNd", ConvertGatherNdOperator},
2612       {"Greater", ConvertSimpleOperator<TensorFlowGreaterOperator, 2, 1>},
2613       {"GreaterEqual",
2614        ConvertSimpleOperator<TensorFlowGreaterEqualOperator, 2, 1>},
2615       {"Identity", ConvertIdentityOperator},
2616       {"IdentityN", ConvertIdentityNOperator},
2617       {"LRN", ConvertLRNOperator},
2618       {"LeakyRelu", ConvertLeakyReluOperator},
2619       {"LegacyFedInput", ConvertPlaceholderOperator},
2620       {"Less", ConvertSimpleOperator<TensorFlowLessOperator, 2, 1>},
2621       {"LessEqual", ConvertSimpleOperator<TensorFlowLessEqualOperator, 2, 1>},
2622       {"Log", ConvertSimpleOperator<LogOperator, 1, 1>},
2623       {"LogicalAnd", ConvertSimpleOperator<LogicalAndOperator, 2, 1>},
2624       {"LogicalOr", ConvertSimpleOperator<LogicalOrOperator, 2, 1>},
2625       {"LogicalNot", ConvertSimpleOperator<LogicalNotOperator, 1, 1>},
2626       {"LogSoftmax", ConvertSimpleOperator<LogSoftmaxOperator, 1, 1>},
2627       {"MatMul", ConvertMatMulOperator},
2628       {"MatrixDiag", ConvertSimpleOperator<MatrixDiagOperator, 1, 1>},
2629       {"MatrixDiagV2", ConvertSimpleOperator<MatrixDiagV2Operator, 5, 1>},
2630       // `MatrixDiagV3` has an `align` attribute. However, Toco only converts
2631       // `MatrixDiagV3` to `MatrixDiag` with default `k, num_rows, num_cols,
2632       // padding_value` inputs. In this case, `align` can be ignored.
2633       {"MatrixDiagV3", ConvertSimpleOperator<MatrixDiagV3Operator, 5, 1>},
2634       {"MatrixSetDiag", ConvertSimpleOperator<MatrixSetDiagOperator, 2, 1>},
2635       {"MatrixSetDiagV2", ConvertSimpleOperator<MatrixSetDiagV2Operator, 3, 1>},
2636       // `MatrixSetDiagV3` has an `align` attribute. However, Toco only converts
2637       // `MatrixSetDiagV3` to `MatrixSetDiag` with default `k` inputs. In this
2638       // case, `align` can be ignored.
2639       {"MatrixSetDiagV3", ConvertSimpleOperator<MatrixSetDiagV3Operator, 3, 1>},
2640       {"Max", ConvertReduceOperator<TensorFlowMaxOperator>},
2641       {"MaxPool", ConvertMaxPoolOperator},
2642       {"Maximum", ConvertSimpleOperator<TensorFlowMaximumOperator, 2, 1>},
2643       {"Mean", ConvertReduceOperator<MeanOperator>},
2644       {"Merge",
2645        ConvertSimpleOperator<TensorFlowMergeOperator, kAnyNumInputs, 1>},
2646       {"Min", ConvertReduceOperator<TensorFlowMinOperator>},
2647       {"Minimum", ConvertSimpleOperator<TensorFlowMinimumOperator, 2, 1>},
2648       {"Mul", ConvertSimpleOperator<MulOperator, 2, 1>},
2649       {"Neg", ConvertSimpleOperator<NegOperator, 1, 1>},
2650       {"NextIteration", ConvertOperatorSpecialCasedAsRNNBackEdge},
2651       {"NoOp", ConvertNoOpOperator},
2652       {"NotEqual", ConvertSimpleOperator<TensorFlowNotEqualOperator, 2, 1>},
2653       {"OneHot", ConvertOneHotOperator},
2654       {"Pack", ConvertPackOperator},
2655       {"Pad", ConvertSimpleOperator<PadOperator, 2, 1>},
2656       {"PadV2", ConvertSimpleOperator<PadV2Operator, 3, 1>},
2657       {"ParallelDynamicStitch", ConvertDynamicStitchOperator},
2658       {"Placeholder", ConvertPlaceholderOperator},
2659       {"PlaceholderWithDefault", ConvertIdentityOperator},
2660       {"Pow", ConvertSimpleOperator<PowOperator, 2, 1>},
2661       {"Prod", ConvertReduceOperator<TensorFlowProdOperator>},
2662       {"RandomUniform", ConvertRandomUniform},
2663       {"Range", ConvertRangeOperator},
2664       {"Rank", ConvertSimpleOperator<TensorFlowRankOperator, 1, 1>},
2665       {"RealDiv", ConvertSimpleOperator<DivOperator, 2, 1>},
2666       {"Relu", ConvertSimpleOperator<ReluOperator, 1, 1>},
2667       {"Relu6", ConvertSimpleOperator<Relu6Operator, 1, 1>},
2668       {"Reshape", ConvertSimpleOperator<TensorFlowReshapeOperator, 2, 1>},
2669       {"ResizeBilinear", ConvertResizeBilinearOperator},
2670       {"ResizeNearestNeighbor", ConvertResizeNearestNeighborOperator},
2671       {"ReverseSequence", ConvertReverseSequenceOperator},
2672       {"ReverseV2", ConvertSimpleOperator<ReverseV2Operator, 2, 1>},
2673       {"Round", ConvertRoundOperator},
2674       {"Rsqrt", ConvertSimpleOperator<TensorFlowRsqrtOperator, 1, 1>},
2675       {"ScatterNd", ConvertSimpleOperator<ScatterNdOperator, 3, 1>},
2676       {"SegmentSum", ConvertSimpleOperator<SegmentSumOperator, 2, 1>},
2677       {"Select", ConvertSimpleOperator<SelectOperator, 3, 1>},
2678       {"SelectV2", ConvertSimpleOperator<SelectOperator, 3, 1>},
2679       {"Shape", ConvertShapeOperator},
2680       {"Sigmoid", ConvertSimpleOperator<LogisticOperator, 1, 1>},
2681       {"Sin", ConvertSimpleOperator<SinOperator, 1, 1>},
2682       {"Slice", ConvertSimpleOperator<SliceOperator, 3, 1>},
2683       {"Softmax", ConvertSoftmaxOperator},
2684       {"SpaceToBatchND", ConvertSpaceToBatchNDOperator},
2685       {"SpaceToDepth", ConvertSpaceToDepthOperator},
2686       {"SparseToDense", ConvertSparseToDenseOperator},
2687       {"Split", ConvertSplitOperator},
2688       {"SplitV", ConvertSplitVOperator},
2689       {"Sqrt", ConvertSimpleOperator<TensorFlowSqrtOperator, 1, 1>},
2690       {"Square", ConvertSimpleOperator<TensorFlowSquareOperator, 1, 1>},
2691       {"SquaredDifference",
2692        ConvertSimpleOperator<SquaredDifferenceOperator, 2, 1>},
2693       {"Snapshot", ConvertIdentityOperator},
2694       {"Squeeze", ConvertSqueezeOperator},
2695       {"StopGradient", ConvertIdentityOperator},
2696       {"StridedSlice", ConvertStridedSliceOperator},
2697       {"Sub", ConvertSimpleOperator<SubOperator, 2, 1>},
2698       {"Sum", ConvertReduceOperator<TensorFlowSumOperator>},
2699       {"Svdf", ConvertSvdfOperator},
2700       {"Switch", ConvertSwitchOperator},
2701       {"Tanh", ConvertSimpleOperator<TanhOperator, 1, 1>},
2702       {"Tile", ConvertSimpleOperator<TensorFlowTileOperator, 2, 1>},
2703       {"TopK", ConvertTopKV2Operator},
2704       {"TopKV2", ConvertTopKV2Operator},
2705       {"Transpose", ConvertSimpleOperator<TransposeOperator, 2, 1>},
2706       {"Unpack", ConvertUnpackOperator},
2707       {"ZerosLike", ConvertSimpleOperator<TensorFlowZerosLikeOperator, 1, 1>},
2708       {"UnidirectionalSequenceLstm", ConvertUnidirectionalSequenceLstm},
2709       {"UnidirectionalSequenceRnn", ConvertUnidirectionalSequenceRnn},
2710       {"MirrorPad", ConvertMirrorPadOperator},
2711       {"Unique", ConvertSimpleOperator<UniqueOperator, 1, 2>},
2712       {"Where", ConvertSimpleOperator<WhereOperator, 1, 1>},
2713   });
2714 }
2715 
// Imports one TensorFlow node into `model`.
//
// Looks up `node.op()` in `converter_map` and dispatches to the registered
// converter. Ops with no entry in the map are handed to
// ConvertUnsupportedOperator instead. The status produced by whichever
// converter ran is returned unchanged.
tensorflow::Status ImportTensorFlowNode(
    const tensorflow::NodeDef& node,
    const TensorFlowImportFlags& tf_import_flags, const ModelFlags& model_flags,
    Model* model, const ConverterMapType& converter_map) {
  const auto registered = converter_map.find(node.op());
  if (registered != converter_map.end()) {
    return registered->second(node, tf_import_flags, model_flags, model);
  }
  // No dedicated converter for this op type: use the generic fallback.
  return ConvertUnsupportedOperator(node, tf_import_flags, model_flags, model);
}
2728 }  // namespace internal
2729 
ImportTensorFlowGraphDef(const ModelFlags & model_flags,const TensorFlowImportFlags & tf_import_flags,const GraphDef & tf_graph)2730 std::unique_ptr<Model> ImportTensorFlowGraphDef(
2731     const ModelFlags& model_flags, const TensorFlowImportFlags& tf_import_flags,
2732     const GraphDef& tf_graph) {
2733   LogDumpGraphDef(kLogLevelModelChanged, "AT IMPORT", tf_graph);
2734 
2735   GraphDef inlined_graph(tf_graph);
2736   if (InlineAllFunctions(&inlined_graph)) {
2737     LogDumpGraphDef(kLogLevelModelChanged, "AFTER INLINING", inlined_graph);
2738   }
2739 
2740   // Check input and output specification.
2741   for (const auto& specified_input_array : model_flags.input_arrays()) {
2742     CHECK(!absl::EndsWith(specified_input_array.name(), ":0"))
2743         << "Unsupported explicit zero output index: "
2744         << specified_input_array.name();
2745   }
2746   for (const std::string& specified_output_array :
2747        model_flags.output_arrays()) {
2748     CHECK(!absl::EndsWith(specified_output_array, ":0"))
2749         << "Unsupported explicit zero output index: " << specified_output_array;
2750   }
2751 
2752   Model* model = new Model;
2753   internal::ConverterMapType converter_map;
2754 
2755   // This is used for the TFLite "Full Flex Mode" conversion. All the ops are
2756   // imported as `TensorFlowUnsupportedOperator`, and later all these ops are
2757   // converted to TFLite Flex ops.
2758   if (!tf_import_flags.import_all_ops_as_unsupported) {
2759     converter_map = internal::GetTensorFlowNodeConverterMap();
2760   } else {
2761     converter_map = internal::GetTensorFlowNodeConverterMapForFlex();
2762   }
2763 
2764   for (auto node : inlined_graph.node()) {
2765     StripZeroOutputIndexFromInputs(&node);
2766     auto status = internal::ImportTensorFlowNode(
2767         node, tf_import_flags, model_flags, model, converter_map);
2768     CHECK(status.ok()) << status.error_message();
2769   }
2770 
2771   ResolveModelFlags(model_flags, model);
2772 
2773   StripCaretFromArrayNames(model);
2774   AddExtraOutputs(model);
2775   FixNoMissingArray(model);
2776   FixNoOrphanedArray(model);
2777   FixOperatorOrdering(model);
2778   CheckInvariants(*model);
2779 
2780   // if rnn state arrays are constant, make them transient
2781   for (const auto& rnn_state : model->flags.rnn_states()) {
2782     model->GetArray(rnn_state.state_array()).buffer = nullptr;
2783   }
2784 
2785   return std::unique_ptr<Model>(model);
2786 }
2787 
ImportTensorFlowGraphDef(const ModelFlags & model_flags,const TensorFlowImportFlags & tf_import_flags,const std::string & input_file_contents)2788 std::unique_ptr<Model> ImportTensorFlowGraphDef(
2789     const ModelFlags& model_flags, const TensorFlowImportFlags& tf_import_flags,
2790     const std::string& input_file_contents) {
2791   std::unique_ptr<GraphDef> tf_graph(new GraphDef);
2792   CHECK(ParseFromStringEitherTextOrBinary(input_file_contents, tf_graph.get()));
2793 
2794   std::unique_ptr<GraphDef> pruned_graph =
2795       MaybeReplaceCompositeSubgraph(*tf_graph);
2796   if (pruned_graph) {
2797     tf_graph = std::move(pruned_graph);
2798   }
2799   return ImportTensorFlowGraphDef(model_flags, tf_import_flags, *tf_graph);
2800 }
2801 
GetPotentiallySupportedOps()2802 std::vector<std::string> GetPotentiallySupportedOps() {
2803   std::vector<std::string> supported_ops;
2804   const internal::ConverterMapType& converter_map =
2805       internal::GetTensorFlowNodeConverterMap();
2806 
2807   for (const auto& item : converter_map) {
2808     supported_ops.push_back(item.first);
2809   }
2810   return supported_ops;
2811 }
2812 
2813 }  // namespace toco
2814