1 /* Copyright 2017 The TensorFlow Authors. All Rights Reserved.
2 
3 Licensed under the Apache License, Version 2.0 (the "License");
4 you may not use this file except in compliance with the License.
5 You may obtain a copy of the License at
6 
7     http://www.apache.org/licenses/LICENSE-2.0
8 
9 Unless required by applicable law or agreed to in writing, software
10 distributed under the License is distributed on an "AS IS" BASIS,
11 WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
12 See the License for the specific language governing permissions and
13 limitations under the License.
14 ==============================================================================*/
15 #include <algorithm>
16 #include <memory>
17 #include <string>
18 #include <unordered_map>
19 #include <vector>
20 
21 #include "google/protobuf/map.h"
22 #include "google/protobuf/text_format.h"
23 #include "absl/memory/memory.h"
24 #include "absl/strings/string_view.h"
25 #include "tensorflow/core/framework/attr_value.pb.h"
26 #include "tensorflow/core/framework/graph.pb.h"
27 #include "tensorflow/core/framework/node_def.pb.h"
28 #include "tensorflow/core/framework/tensor.pb.h"
29 #include "tensorflow/core/framework/tensor_shape.pb.h"
30 #include "tensorflow/core/framework/types.pb.h"
31 #include "tensorflow/core/platform/logging.h"
32 #include "tensorflow/lite/toco/model.h"
33 #include "tensorflow/lite/toco/model_flags.pb.h"
34 #include "tensorflow/lite/toco/runtime/types.h"
35 #include "tensorflow/lite/toco/tensorflow_util.h"
36 #include "tensorflow/lite/toco/tooling_util.h"
37 
38 using tensorflow::DT_BOOL;
39 using tensorflow::DT_COMPLEX64;
40 using tensorflow::DT_FLOAT;
41 using tensorflow::DT_INT16;
42 using tensorflow::DT_INT32;
43 using tensorflow::DT_INT64;
44 using tensorflow::DT_UINT8;
45 using tensorflow::GraphDef;
46 using tensorflow::TensorProto;
47 
48 namespace toco {
49 namespace {
50 
GetTensorFlowDataType(ArrayDataType data_type,const string & error_location)51 tensorflow::DataType GetTensorFlowDataType(ArrayDataType data_type,
52                                            const string& error_location) {
53   switch (data_type) {
54     case ArrayDataType::kBool:
55       return tensorflow::DT_BOOL;
56     case ArrayDataType::kFloat:
57       return tensorflow::DT_FLOAT;
58     case ArrayDataType::kUint8:
59       return tensorflow::DT_UINT8;
60     case ArrayDataType::kInt32:
61       return tensorflow::DT_INT32;
62     case ArrayDataType::kInt64:
63       return tensorflow::DT_INT64;
64     case ArrayDataType::kString:
65       return tensorflow::DT_STRING;
66     case ArrayDataType::kComplex64:
67       return tensorflow::DT_COMPLEX64;
68     default:
69     case ArrayDataType::kNone:
70       LOG(FATAL) << "Unsupported data type '" << ArrayDataTypeName(data_type)
71                  << "' in " << error_location;
72       return tensorflow::DT_INVALID;
73   }
74 }
75 
GetTensorFlowDataTypeForOp(ArrayDataType data_type,const string & op_name)76 tensorflow::DataType GetTensorFlowDataTypeForOp(ArrayDataType data_type,
77                                                 const string& op_name) {
78   return GetTensorFlowDataType(data_type, "op '" + op_name + "'");
79 }
80 
GetTensorFlowDataType(const Model & model,const string & array_name)81 tensorflow::DataType GetTensorFlowDataType(const Model& model,
82                                            const string& array_name) {
83   return GetTensorFlowDataType(model.GetArray(array_name).data_type,
84                                "array '" + array_name + "'");
85 }
86 
// TensorFlow sometimes rejects what it calls "legacy scalars": 1-D shapes
// whose single dimension has size 1 (see OpKernel::IsLegacyScalar and
// OpKernel::allow_legacy_scalars). To stay clear of that, exported constants
// normally replace such a 1-D shape of size 1 with a 0-D shape.
//
// Bias vectors fed to BiasAdd nodes are the exception. TensorFlow requires a
// bias to be 1-D, so a bias of depth 1 must keep its 1-D shape even though
// that makes it a legacy scalar.
enum class LegacyScalarPolicy {
  // Collapse a size-1 1-D shape into a 0-D shape.
  kAvoidLegacyScalars,
  // Keep every dimension unconditionally, even if that yields a legacy scalar.
  kDoCreateLegacyScalars,
};
100 
ExportFloatArray(const Shape & input_shape,const float * input_data,TensorProto * output_tensor,LegacyScalarPolicy legacy_scalar_policy)101 void ExportFloatArray(const Shape& input_shape, const float* input_data,
102                       TensorProto* output_tensor,
103                       LegacyScalarPolicy legacy_scalar_policy) {
104   output_tensor->set_dtype(DT_FLOAT);
105   const int input_flat_size = RequiredBufferSizeForShape(input_shape);
106   auto* shape = output_tensor->mutable_tensor_shape();
107 
108   const int kDims = input_shape.dimensions_count();
109   if (legacy_scalar_policy == LegacyScalarPolicy::kDoCreateLegacyScalars ||
110       kDims > 1 || (kDims == 1 && input_shape.dims(0) > 1)) {
111     for (int i = 0; i < kDims; ++i) {
112       shape->add_dim()->set_size(input_shape.dims(i));
113     }
114   }
115   output_tensor->set_tensor_content(
116       string(reinterpret_cast<const char*>(input_data),
117              sizeof(*input_data) * input_flat_size));
118 }
119 
ExportFloatArray(AxesOrder input_axes_order,const Shape & input_shape,const float * input_data,AxesOrder output_axes_order,TensorProto * output_tensor,LegacyScalarPolicy legacy_scalar_policy)120 void ExportFloatArray(AxesOrder input_axes_order, const Shape& input_shape,
121                       const float* input_data, AxesOrder output_axes_order,
122                       TensorProto* output_tensor,
123                       LegacyScalarPolicy legacy_scalar_policy) {
124   CHECK_EQ(AxesCount(output_axes_order), AxesCount(input_axes_order));
125   output_tensor->set_dtype(DT_FLOAT);
126   CHECK_EQ(input_shape.dimensions_count(), AxesCount(input_axes_order));
127   const int input_flat_size = RequiredBufferSizeForShape(input_shape);
128 
129   Shape shuffled_shape;
130   ShuffleDims(input_shape, input_axes_order, output_axes_order,
131               &shuffled_shape);
132   std::vector<float> shuffled_data(input_flat_size);
133   ShuffleArray(input_shape, input_axes_order, output_axes_order, shuffled_shape,
134                input_data, shuffled_data.data());
135 
136   ExportFloatArray(shuffled_shape, shuffled_data.data(), output_tensor,
137                    legacy_scalar_policy);
138 }
139 
HasAlreadyExportedConst(const string & name,const GraphDef & tensorflow_graph)140 bool HasAlreadyExportedConst(const string& name,
141                              const GraphDef& tensorflow_graph) {
142   for (const auto& node : tensorflow_graph.node()) {
143     if (node.op() == "Const" && node.name() == name) {
144       return true;
145     }
146   }
147   return false;
148 }
149 
ConvertFloatTensorConst(const string & name,const Shape & input_shape,const float * input_data,AxesOrder input_axes_order,AxesOrder output_axes_order,GraphDef * tensorflow_graph,LegacyScalarPolicy legacy_scalar_policy)150 void ConvertFloatTensorConst(const string& name, const Shape& input_shape,
151                              const float* input_data,
152                              AxesOrder input_axes_order,
153                              AxesOrder output_axes_order,
154                              GraphDef* tensorflow_graph,
155                              LegacyScalarPolicy legacy_scalar_policy) {
156   if (HasAlreadyExportedConst(name, *tensorflow_graph)) {
157     return;
158   }
159   tensorflow::NodeDef* const_op = tensorflow_graph->add_node();
160   const_op->set_op("Const");
161   const_op->set_name(name);
162   (*const_op->mutable_attr())["dtype"].set_type(DT_FLOAT);
163   auto* tensor = (*const_op->mutable_attr())["value"].mutable_tensor();
164   ExportFloatArray(input_axes_order, input_shape, input_data, output_axes_order,
165                    tensor, legacy_scalar_policy);
166 }
167 
ConvertFloatTensorConst(const string & name,const Shape & input_shape,const float * input_data,AxesOrder input_axes_order,AxesOrder output_axes_order,GraphDef * tensorflow_graph)168 void ConvertFloatTensorConst(const string& name, const Shape& input_shape,
169                              const float* input_data,
170                              AxesOrder input_axes_order,
171                              AxesOrder output_axes_order,
172                              GraphDef* tensorflow_graph) {
173   ConvertFloatTensorConst(name, input_shape, input_data, input_axes_order,
174                           output_axes_order, tensorflow_graph,
175                           LegacyScalarPolicy::kAvoidLegacyScalars);
176 }
177 
ConvertFloatTensorConst(const Model & model,const string & name,AxesOrder input_axes_order,AxesOrder output_axes_order,GraphDef * tensorflow_graph)178 void ConvertFloatTensorConst(const Model& model, const string& name,
179                              AxesOrder input_axes_order,
180                              AxesOrder output_axes_order,
181                              GraphDef* tensorflow_graph) {
182   if (HasAlreadyExportedConst(name, *tensorflow_graph)) {
183     return;
184   }
185   CHECK(model.HasArray(name));
186   const auto& input_array = model.GetArray(name);
187   const auto& input_shape = input_array.shape();
188   CHECK(input_array.buffer);
189   CHECK(input_array.buffer->type == ArrayDataType::kFloat);
190   const float* input_data =
191       input_array.GetBuffer<ArrayDataType::kFloat>().data.data();
192   ConvertFloatTensorConst(name, input_shape, input_data, input_axes_order,
193                           output_axes_order, tensorflow_graph);
194 }
195 
ConvertFloatTensorConst(const Model & model,const string & name,GraphDef * tensorflow_graph)196 void ConvertFloatTensorConst(const Model& model, const string& name,
197                              GraphDef* tensorflow_graph) {
198   if (HasAlreadyExportedConst(name, *tensorflow_graph)) {
199     return;
200   }
201   tensorflow::NodeDef* const_op = tensorflow_graph->add_node();
202   const_op->set_op("Const");
203   const_op->set_name(name);
204   (*const_op->mutable_attr())["dtype"].set_type(DT_FLOAT);
205   auto* tensor = (*const_op->mutable_attr())["value"].mutable_tensor();
206   CHECK(model.HasArray(name));
207   const auto& input_array = model.GetArray(name);
208   const auto& input_shape = input_array.shape();
209   CHECK(input_array.buffer);
210   CHECK(input_array.buffer->type == ArrayDataType::kFloat);
211   const float* input_data =
212       input_array.GetBuffer<ArrayDataType::kFloat>().data.data();
213   ExportFloatArray(input_shape, input_data, tensor,
214                    LegacyScalarPolicy::kAvoidLegacyScalars);
215 }
216 
ConvertBoolTensorConst(const Model & model,const string & name,GraphDef * tensorflow_graph)217 void ConvertBoolTensorConst(const Model& model, const string& name,
218                             GraphDef* tensorflow_graph) {
219   if (HasAlreadyExportedConst(name, *tensorflow_graph)) {
220     return;
221   }
222   CHECK(model.HasArray(name));
223   const auto& array = model.GetArray(name);
224   tensorflow::NodeDef* const_op = tensorflow_graph->add_node();
225   const_op->set_op("Const");
226   const_op->set_name(name);
227   (*const_op->mutable_attr())["dtype"].set_type(DT_BOOL);
228   auto* tensor = (*const_op->mutable_attr())["value"].mutable_tensor();
229   tensor->set_dtype(DT_BOOL);
230   const auto& data = array.GetBuffer<ArrayDataType::kBool>().data;
231   for (auto index : data) {
232     tensor->add_bool_val(index);
233   }
234   const auto& array_shape = array.shape();
235   auto* shape = tensor->mutable_tensor_shape();
236   for (int i = 0; i < array_shape.dimensions_count(); i++) {
237     shape->add_dim()->set_size(array_shape.dims(i));
238   }
239 }
240 
ConvertIntTensorConst(const Model & model,const string & name,GraphDef * tensorflow_graph)241 void ConvertIntTensorConst(const Model& model, const string& name,
242                            GraphDef* tensorflow_graph) {
243   if (HasAlreadyExportedConst(name, *tensorflow_graph)) {
244     return;
245   }
246   CHECK(model.HasArray(name));
247   const auto& array = model.GetArray(name);
248   tensorflow::NodeDef* const_op = tensorflow_graph->add_node();
249   const_op->set_op("Const");
250   const_op->set_name(name);
251   (*const_op->mutable_attr())["dtype"].set_type(DT_INT32);
252   auto* tensor = (*const_op->mutable_attr())["value"].mutable_tensor();
253   tensor->set_dtype(DT_INT32);
254   const auto& data = array.GetBuffer<ArrayDataType::kInt32>().data;
255   for (auto index : data) {
256     tensor->add_int_val(index);
257   }
258   const auto& array_shape = array.shape();
259   auto* shape = tensor->mutable_tensor_shape();
260   for (int i = 0; i < array_shape.dimensions_count(); i++) {
261     shape->add_dim()->set_size(array_shape.dims(i));
262   }
263 }
264 
CreateIntTensorConst(const string & name,const std::vector<int32> & data,const std::vector<int32> & shape,GraphDef * tensorflow_graph)265 void CreateIntTensorConst(const string& name, const std::vector<int32>& data,
266                           const std::vector<int32>& shape,
267                           GraphDef* tensorflow_graph) {
268   if (HasAlreadyExportedConst(name, *tensorflow_graph)) {
269     return;
270   }
271   tensorflow::NodeDef* const_op = tensorflow_graph->add_node();
272   const_op->set_op("Const");
273   const_op->set_name(name);
274   (*const_op->mutable_attr())["dtype"].set_type(DT_INT32);
275   auto* tensor = (*const_op->mutable_attr())["value"].mutable_tensor();
276   tensor->set_dtype(DT_INT32);
277   for (auto index : data) {
278     tensor->add_int_val(index);
279   }
280   auto* tensor_shape = tensor->mutable_tensor_shape();
281   int num_elements = 1;
282   for (int size : shape) {
283     tensor_shape->add_dim()->set_size(size);
284     num_elements *= size;
285   }
286   CHECK_EQ(num_elements, data.size());
287 }
288 
ConvertComplex64TensorConst(const Model & model,const string & name,GraphDef * tensorflow_graph)289 void ConvertComplex64TensorConst(const Model& model, const string& name,
290                                  GraphDef* tensorflow_graph) {
291   if (HasAlreadyExportedConst(name, *tensorflow_graph)) {
292     return;
293   }
294   CHECK(model.HasArray(name));
295   const auto& array = model.GetArray(name);
296   tensorflow::NodeDef* const_op = tensorflow_graph->add_node();
297   const_op->set_op("Const");
298   const_op->set_name(name);
299   (*const_op->mutable_attr())["dtype"].set_type(DT_COMPLEX64);
300   auto* tensor = (*const_op->mutable_attr())["value"].mutable_tensor();
301   tensor->set_dtype(DT_COMPLEX64);
302   const auto& data = array.GetBuffer<ArrayDataType::kComplex64>().data;
303   for (auto index : data) {
304     tensor->add_scomplex_val(std::real(index));
305     tensor->add_scomplex_val(std::imag(index));
306   }
307   const auto& array_shape = array.shape();
308   auto* shape = tensor->mutable_tensor_shape();
309   for (int i = 0; i < array_shape.dimensions_count(); i++) {
310     shape->add_dim()->set_size(array_shape.dims(i));
311   }
312 }
313 
CreateMatrixShapeTensorConst(const string & name,int rows,int cols,GraphDef * tensorflow_graph)314 void CreateMatrixShapeTensorConst(const string& name, int rows, int cols,
315                                   GraphDef* tensorflow_graph) {
316   if (HasAlreadyExportedConst(name, *tensorflow_graph)) {
317     return;
318   }
319   tensorflow::NodeDef* const_op = tensorflow_graph->add_node();
320   const_op->set_op("Const");
321   const_op->set_name(name);
322   (*const_op->mutable_attr())["dtype"].set_type(DT_INT32);
323   auto* tensor = (*const_op->mutable_attr())["value"].mutable_tensor();
324   tensor->set_dtype(DT_INT32);
325   const int32 data[2] = {cols, rows};
326   tensor->set_tensor_content(
327       string(reinterpret_cast<const char*>(data), sizeof(data)));
328   auto* shape = tensor->mutable_tensor_shape();
329   shape->add_dim()->set_size(2);
330 }
331 
CreateDummyConcatDimTensorConst(const string & name,int dim,GraphDef * tensorflow_graph)332 void CreateDummyConcatDimTensorConst(const string& name, int dim,
333                                      GraphDef* tensorflow_graph) {
334   if (HasAlreadyExportedConst(name, *tensorflow_graph)) {
335     return;
336   }
337   tensorflow::NodeDef* const_op = tensorflow_graph->add_node();
338   const_op->set_op("Const");
339   const_op->set_name(name);
340   (*const_op->mutable_attr())["dtype"].set_type(DT_INT32);
341   auto* tensor = (*const_op->mutable_attr())["value"].mutable_tensor();
342   tensor->set_dtype(DT_INT32);
343   tensor->add_int_val(dim);
344 }
345 
CreateReshapeShapeTensorConst(const string & name,const std::vector<int32> & shape,GraphDef * tensorflow_graph)346 void CreateReshapeShapeTensorConst(const string& name,
347                                    const std::vector<int32>& shape,
348                                    GraphDef* tensorflow_graph) {
349   if (HasAlreadyExportedConst(name, *tensorflow_graph)) {
350     return;
351   }
352   tensorflow::NodeDef* const_op = tensorflow_graph->add_node();
353   const_op->set_op("Const");
354   const_op->set_name(name);
355   (*const_op->mutable_attr())["dtype"].set_type(DT_INT32);
356   auto* tensor = (*const_op->mutable_attr())["value"].mutable_tensor();
357   tensor->set_dtype(DT_INT32);
358   for (auto s : shape) {
359     tensor->add_int_val(s);
360   }
361   // TensorFlow sometimes forbids what it calls "legacy scalars",
362   // which are shapes of size 1 where the unique shape size is 1.
363   // See OpKernel::IsLegacyScalar and OpKernel::allow_legacy_scalars.
364   if (shape.size() > 1) {
365     auto* tensor_shape = tensor->mutable_tensor_shape();
366     tensor_shape->add_dim()->set_size(shape.size());
367   }
368 }
369 
// Returns the name of the constant array backing `name`. That is `name` itself
// when the array already has a buffer. Otherwise `name` must be produced by a
// FakeQuant operator whose first input has a buffer, and that input's name is
// returned. Any other situation is a CHECK failure.
string WalkUpToConstantArray(const Model& model, const string& name) {
  if (model.GetArray(name).buffer) {
    return name;
  }
  const auto* producer = GetOpWithOutput(model, name);
  CHECK(producer);
  CHECK(producer->type == OperatorType::kFakeQuant);
  const string& fakequant_input = producer->inputs[0];
  CHECK(model.GetArray(fakequant_input).buffer);
  return fakequant_input;
}
383 
ConvertConvOperator(const Model & model,const ConvOperator & src_op,GraphDef * tensorflow_graph)384 void ConvertConvOperator(const Model& model, const ConvOperator& src_op,
385                          GraphDef* tensorflow_graph) {
386   const bool has_bias = src_op.inputs.size() >= 3;
387   string conv_output = src_op.outputs[0];
388   if (has_bias) {
389     conv_output += "/conv";
390   }
391 
392   tensorflow::NodeDef* conv2d_op = tensorflow_graph->add_node();
393   conv2d_op->set_op("Conv2D");
394   conv2d_op->set_name(conv_output);
395   *conv2d_op->add_input() = src_op.inputs[0];
396   *conv2d_op->add_input() = src_op.inputs[1];
397   (*conv2d_op->mutable_attr())["T"].set_type(DT_FLOAT);
398   const string& weights_array_name =
399       WalkUpToConstantArray(model, src_op.inputs[1]);
400   const auto& weights_array = model.GetArray(weights_array_name);
401   CHECK(weights_array.buffer->type == ArrayDataType::kFloat);
402   ConvertFloatTensorConst(model, weights_array_name, AxesOrder::kOHWI,
403                           AxesOrder::kHWIO, tensorflow_graph);
404   auto& strides = (*conv2d_op->mutable_attr())["strides"];
405   strides.mutable_list()->add_i(1);
406   strides.mutable_list()->add_i(src_op.stride_height);
407   strides.mutable_list()->add_i(src_op.stride_width);
408   strides.mutable_list()->add_i(1);
409   if ((src_op.dilation_width_factor != 1) ||
410       (src_op.dilation_height_factor != 1)) {
411     auto& dilations = (*conv2d_op->mutable_attr())["dilations"];
412     dilations.mutable_list()->add_i(1);
413     dilations.mutable_list()->add_i(src_op.dilation_height_factor);
414     dilations.mutable_list()->add_i(src_op.dilation_width_factor);
415     dilations.mutable_list()->add_i(1);
416   }
417   string padding;
418   if (src_op.padding.type == PaddingType::kSame) {
419     padding = "SAME";
420   } else if (src_op.padding.type == PaddingType::kValid) {
421     padding = "VALID";
422   } else {
423     LOG(FATAL) << "Bad padding (only SAME and VALID are supported)";
424   }
425   (*conv2d_op->mutable_attr())["padding"].set_s(padding);
426 
427   if (has_bias) {
428     tensorflow::NodeDef* biasadd_op = tensorflow_graph->add_node();
429     biasadd_op->set_op("BiasAdd");
430     biasadd_op->set_name(src_op.outputs[0]);
431     biasadd_op->add_input(conv_output);
432     biasadd_op->add_input(src_op.inputs[2]);
433     (*biasadd_op->mutable_attr())["T"].set_type(DT_FLOAT);
434     CHECK(model.HasArray(src_op.inputs[2]));
435     const string& bias_array_name =
436         WalkUpToConstantArray(model, src_op.inputs[2]);
437     const auto& bias_array = model.GetArray(bias_array_name);
438     // TODO(b/62904716) Bias arrays should be 1-D, and used directly.
439     Shape bias_shape_1d = bias_array.shape();
440     UnextendShape(&bias_shape_1d, 1);
441     CHECK(bias_array.buffer->type == ArrayDataType::kFloat);
442     const float* bias_data =
443         bias_array.GetBuffer<ArrayDataType::kFloat>().data.data();
444     ConvertFloatTensorConst(bias_array_name, bias_shape_1d, bias_data,
445                             AxesOrder::kOneAxis, AxesOrder::kOneAxis,
446                             tensorflow_graph,
447                             LegacyScalarPolicy::kDoCreateLegacyScalars);
448   }
449 }
450 
ConvertDepthwiseConvOperator(const Model & model,const DepthwiseConvOperator & src_op,GraphDef * tensorflow_graph)451 void ConvertDepthwiseConvOperator(const Model& model,
452                                   const DepthwiseConvOperator& src_op,
453                                   GraphDef* tensorflow_graph) {
454   const bool has_bias = src_op.inputs.size() >= 3;
455   string conv_output = src_op.outputs[0];
456   if (has_bias) {
457     conv_output += "/conv";
458   }
459 
460   tensorflow::NodeDef* dc2d_op = tensorflow_graph->add_node();
461   dc2d_op->set_op("DepthwiseConv2dNative");
462   dc2d_op->set_name(conv_output);
463   *dc2d_op->add_input() = src_op.inputs[0];
464   *dc2d_op->add_input() = src_op.inputs[1];
465   (*dc2d_op->mutable_attr())["T"].set_type(DT_FLOAT);
466 
467   // Our internal DepthwiseConv weights are 1 x H x W x OutputDepth.
468   // We need to convert that to H x W x InputDepth x Multiplier.
469   // That's only a matter of constructing a Dims object; the actual
470   // array layout is the same.
471   CHECK(model.HasArray(src_op.inputs[1]));
472   const string& src_weights_name =
473       WalkUpToConstantArray(model, src_op.inputs[1]);
474   const auto& src_weights_array = model.GetArray(src_weights_name);
475   const auto& src_weights_shape = src_weights_array.shape();
476   CHECK_EQ(src_weights_shape.dimensions_count(), 4);
477   const Shape dst_weights_shape =
478       Shape({src_weights_shape.dims(1), src_weights_shape.dims(2),
479              src_weights_shape.dims(3) / src_op.depth_multiplier,
480              src_op.depth_multiplier});
481   CHECK_EQ(src_weights_shape.dims(3) % src_op.depth_multiplier, 0);
482   CHECK(dst_weights_shape.dims(2) * dst_weights_shape.dims(3) ==
483         src_weights_shape.dims(3));
484   CHECK_EQ(src_weights_shape.dims(0), 1);
485 
486   CHECK(src_weights_array.buffer->type == ArrayDataType::kFloat);
487   const float* src_weights_data =
488       src_weights_array.GetBuffer<ArrayDataType::kFloat>().data.data();
489   ConvertFloatTensorConst(src_weights_name, dst_weights_shape, src_weights_data,
490                           AxesOrder::kHWIM, AxesOrder::kHWIM, tensorflow_graph);
491 
492   auto& strides = (*dc2d_op->mutable_attr())["strides"];
493   strides.mutable_list()->add_i(1);
494   strides.mutable_list()->add_i(src_op.stride_height);
495   strides.mutable_list()->add_i(src_op.stride_width);
496   strides.mutable_list()->add_i(1);
497   // TODO(b/116063589): To return a working TF GraphDef, we should be returning
498   // the correct SpaceToBatchNd and BatchToSpaceND operation before and after
499   // the conv since TF doesn't support dilations.
500   if ((src_op.dilation_width_factor != 1) ||
501       (src_op.dilation_height_factor != 1)) {
502     auto& dilations = (*dc2d_op->mutable_attr())["dilations"];
503     dilations.mutable_list()->add_i(1);
504     dilations.mutable_list()->add_i(src_op.dilation_height_factor);
505     dilations.mutable_list()->add_i(src_op.dilation_width_factor);
506     dilations.mutable_list()->add_i(1);
507   }
508   string padding;
509   if (src_op.padding.type == PaddingType::kSame) {
510     padding = "SAME";
511   } else if (src_op.padding.type == PaddingType::kValid) {
512     padding = "VALID";
513   } else {
514     LOG(FATAL) << "Bad padding (only SAME and VALID are supported)";
515   }
516   (*dc2d_op->mutable_attr())["padding"].set_s(padding);
517 
518   if (has_bias) {
519     tensorflow::NodeDef* biasadd_op = tensorflow_graph->add_node();
520     biasadd_op->set_op("BiasAdd");
521     biasadd_op->set_name(src_op.outputs[0]);
522     biasadd_op->add_input(conv_output);
523     biasadd_op->add_input(src_op.inputs[2]);
524     (*biasadd_op->mutable_attr())["T"].set_type(DT_FLOAT);
525     CHECK(model.HasArray(src_op.inputs[2]));
526     const string& bias_name = WalkUpToConstantArray(model, src_op.inputs[2]);
527     const auto& bias_array = model.GetArray(bias_name);
528     // TODO(b/62904716) Bias arrays should be 1-D, and used directly.
529     Shape bias_shape_1d = bias_array.shape();
530     UnextendShape(&bias_shape_1d, 1);
531     CHECK(bias_array.buffer->type == ArrayDataType::kFloat);
532     const float* bias_data =
533         bias_array.GetBuffer<ArrayDataType::kFloat>().data.data();
534     ConvertFloatTensorConst(bias_name, bias_shape_1d, bias_data,
535                             AxesOrder::kOneAxis, AxesOrder::kOneAxis,
536                             tensorflow_graph,
537                             LegacyScalarPolicy::kDoCreateLegacyScalars);
538   }
539 }
540 
ConvertTransposeConvOperator(const Model & model,const TransposeConvOperator & src_op,GraphDef * tensorflow_graph)541 void ConvertTransposeConvOperator(const Model& model,
542                                   const TransposeConvOperator& src_op,
543                                   GraphDef* tensorflow_graph) {
544   tensorflow::NodeDef* conv2d_op = tensorflow_graph->add_node();
545   conv2d_op->set_op("Conv2DBackpropInput");
546   conv2d_op->set_name(src_op.outputs[0]);
547   *conv2d_op->add_input() = src_op.inputs[0];
548   *conv2d_op->add_input() = src_op.inputs[1];
549   *conv2d_op->add_input() = src_op.inputs[2];
550   (*conv2d_op->mutable_attr())["T"].set_type(DT_FLOAT);
551   const string& weights_array_name = WalkUpToConstantArray(
552       model, src_op.inputs[TransposeConvOperator::WEIGHTS]);
553   const auto& weights_array = model.GetArray(weights_array_name);
554   CHECK(weights_array.buffer->type == ArrayDataType::kFloat);
555   ConvertFloatTensorConst(model, weights_array_name, AxesOrder::kOHWI,
556                           AxesOrder::kHWOI, tensorflow_graph);
557   auto& strides = (*conv2d_op->mutable_attr())["strides"];
558   strides.mutable_list()->add_i(1);
559   strides.mutable_list()->add_i(src_op.stride_height);
560   strides.mutable_list()->add_i(src_op.stride_width);
561   strides.mutable_list()->add_i(1);
562   string padding;
563   if (src_op.padding.type == PaddingType::kSame) {
564     padding = "SAME";
565   } else if (src_op.padding.type == PaddingType::kValid) {
566     padding = "VALID";
567   } else {
568     LOG(FATAL) << "Bad padding (only SAME and VALID are supported)";
569   }
570   (*conv2d_op->mutable_attr())["padding"].set_s(padding);
571 }
572 
ConvertDepthToSpaceOperator(const Model & model,const DepthToSpaceOperator & src_op,GraphDef * tensorflow_graph)573 void ConvertDepthToSpaceOperator(const Model& model,
574                                  const DepthToSpaceOperator& src_op,
575                                  GraphDef* tensorflow_graph) {
576   tensorflow::NodeDef* op = tensorflow_graph->add_node();
577   op->set_op("DepthToSpace");
578   op->set_name(src_op.outputs[0]);
579   *op->add_input() = src_op.inputs[0];
580   (*op->mutable_attr())["T"].set_type(DT_FLOAT);
581   (*op->mutable_attr())["block_size"].set_i(src_op.block_size);
582 }
583 
ConvertSpaceToDepthOperator(const Model & model,const SpaceToDepthOperator & src_op,GraphDef * tensorflow_graph)584 void ConvertSpaceToDepthOperator(const Model& model,
585                                  const SpaceToDepthOperator& src_op,
586                                  GraphDef* tensorflow_graph) {
587   tensorflow::NodeDef* op = tensorflow_graph->add_node();
588   op->set_op("SpaceToDepth");
589   op->set_name(src_op.outputs[0]);
590   *op->add_input() = src_op.inputs[0];
591   (*op->mutable_attr())["T"].set_type(DT_FLOAT);
592   (*op->mutable_attr())["block_size"].set_i(src_op.block_size);
593 }
594 
ConvertFullyConnectedOperator(const Model & model,const FullyConnectedOperator & src_op,GraphDef * tensorflow_graph)595 void ConvertFullyConnectedOperator(const Model& model,
596                                    const FullyConnectedOperator& src_op,
597                                    GraphDef* tensorflow_graph) {
598   // Reshape input activations to have the shape expected by the MatMul.
599   const string reshape_output =
600       AvailableArrayName(model, src_op.outputs[0] + "/reshape");
601   const string reshape_shape =
602       AvailableArrayName(model, reshape_output + "/shape");
603   const auto& fc_weights_array = model.GetArray(src_op.inputs[1]);
604   const auto& fc_weights_shape = fc_weights_array.shape();
605   CHECK_EQ(fc_weights_shape.dimensions_count(), 2);
606   CreateMatrixShapeTensorConst(reshape_shape, fc_weights_shape.dims(1), -1,
607                                tensorflow_graph);
608   tensorflow::NodeDef* reshape_op = tensorflow_graph->add_node();
609   reshape_op->set_op("Reshape");
610   reshape_op->set_name(reshape_output);
611   reshape_op->add_input(src_op.inputs[0]);
612   reshape_op->add_input(reshape_shape);
613   (*reshape_op->mutable_attr())["T"].set_type(
614       GetTensorFlowDataType(model, src_op.inputs[0]));
615 
616   const bool has_bias = src_op.inputs.size() >= 3;
617   string matmul_output = src_op.outputs[0];
618   if (has_bias) {
619     matmul_output += "/matmul";
620   }
621 
622   // Transpose the RHS input from column-major to row-major to match TensorFlow
623   // expectations. This is the inverse of the transpose we do during
624   // ResolveTensorFlowMatMul.
625   const string transpose_output =
626       AvailableArrayName(model, matmul_output + "/transpose_weights");
627   const string transpose_perm =
628       AvailableArrayName(model, transpose_output + "/perm");
629   CreateIntTensorConst(transpose_perm, {1, 0}, {2}, tensorflow_graph);
630   tensorflow::NodeDef* transpose_op = tensorflow_graph->add_node();
631   transpose_op->set_op("Transpose");
632   transpose_op->set_name(transpose_output);
633   *transpose_op->add_input() = src_op.inputs[1];
634   *transpose_op->add_input() = transpose_perm;
635   (*transpose_op->mutable_attr())["T"].set_type(
636       GetTensorFlowDataType(model, src_op.inputs[1]));
637   (*transpose_op->mutable_attr())["Tperm"].set_type(DT_INT32);
638 
639   tensorflow::NodeDef* matmul_op = tensorflow_graph->add_node();
640   matmul_op->set_op("MatMul");
641   matmul_op->set_name(matmul_output);
642   *matmul_op->add_input() = reshape_output;
643   *matmul_op->add_input() = transpose_op->name();
644   (*matmul_op->mutable_attr())["T"].set_type(
645       GetTensorFlowDataType(model, src_op.inputs[0]));
646   (*matmul_op->mutable_attr())["transpose_a"].set_b(false);
647   (*matmul_op->mutable_attr())["transpose_b"].set_b(false);
648   CHECK(model.HasArray(src_op.inputs[1]));
649 
650   // Add the bias, if it exists.
651   if (has_bias) {
652     tensorflow::NodeDef* biasadd_op = tensorflow_graph->add_node();
653     biasadd_op->set_op("BiasAdd");
654     biasadd_op->set_name(src_op.outputs[0]);
655     biasadd_op->add_input(matmul_output);
656     biasadd_op->add_input(src_op.inputs[2]);
657     (*biasadd_op->mutable_attr())["T"].set_type(
658         GetTensorFlowDataType(model, src_op.inputs[0]));
659     CHECK(model.HasArray(src_op.inputs[2]));
660     const auto& bias_array = model.GetArray(src_op.inputs[2]);
661     // TODO(b/62904716) Bias arrays should be 1-D, and used directly.
662     Shape bias_shape_1d = bias_array.shape();
663     UnextendShape(&bias_shape_1d, 1);
664     CHECK(bias_array.buffer);
665     CHECK(bias_array.buffer->type == ArrayDataType::kFloat);
666     const float* bias_data =
667         bias_array.GetBuffer<ArrayDataType::kFloat>().data.data();
668     ConvertFloatTensorConst(WalkUpToConstantArray(model, src_op.inputs[2]),
669                             bias_shape_1d, bias_data, AxesOrder::kOneAxis,
670                             AxesOrder::kOneAxis, tensorflow_graph,
671                             LegacyScalarPolicy::kDoCreateLegacyScalars);
672   }
673 }
674 
ConvertAddOperator(const Model & model,const AddOperator & src_op,GraphDef * tensorflow_graph)675 void ConvertAddOperator(const Model& model, const AddOperator& src_op,
676                         GraphDef* tensorflow_graph) {
677   tensorflow::NodeDef* add_op = tensorflow_graph->add_node();
678   add_op->set_op("Add");
679   add_op->set_name(src_op.outputs[0]);
680   CHECK_EQ(src_op.inputs.size(), 2);
681   *add_op->add_input() = src_op.inputs[0];
682   *add_op->add_input() = src_op.inputs[1];
683   (*add_op->mutable_attr())["T"].set_type(
684       GetTensorFlowDataType(model, src_op.outputs[0]));
685 }
686 
ConvertAddNOperator(const Model & model,const AddNOperator & src_op,GraphDef * tensorflow_graph)687 void ConvertAddNOperator(const Model& model, const AddNOperator& src_op,
688                          GraphDef* tensorflow_graph) {
689   tensorflow::NodeDef* add_op = tensorflow_graph->add_node();
690   add_op->set_op("AddN");
691   add_op->set_name(src_op.outputs[0]);
692   for (const auto& input : src_op.inputs) {
693     *add_op->add_input() = input;
694   }
695   (*add_op->mutable_attr())["N"].set_i(src_op.inputs.size());
696   (*add_op->mutable_attr())["T"].set_type(
697       GetTensorFlowDataType(model, src_op.outputs[0]));
698 }
699 
ConvertMulOperator(const Model & model,const MulOperator & src_op,GraphDef * tensorflow_graph)700 void ConvertMulOperator(const Model& model, const MulOperator& src_op,
701                         GraphDef* tensorflow_graph) {
702   tensorflow::NodeDef* mul_op = tensorflow_graph->add_node();
703   mul_op->set_op("Mul");
704   mul_op->set_name(src_op.outputs[0]);
705   CHECK_EQ(src_op.inputs.size(), 2);
706   *mul_op->add_input() = src_op.inputs[0];
707   *mul_op->add_input() = src_op.inputs[1];
708   (*mul_op->mutable_attr())["T"].set_type(
709       GetTensorFlowDataType(model, src_op.outputs[0]));
710 }
711 
ConvertDivOperator(const Model & model,const DivOperator & src_op,GraphDef * tensorflow_graph)712 void ConvertDivOperator(const Model& model, const DivOperator& src_op,
713                         GraphDef* tensorflow_graph) {
714   tensorflow::NodeDef* div_op = tensorflow_graph->add_node();
715   div_op->set_op("Div");
716   div_op->set_name(src_op.outputs[0]);
717   CHECK_EQ(src_op.inputs.size(), 2);
718   *div_op->add_input() = src_op.inputs[0];
719   *div_op->add_input() = src_op.inputs[1];
720   (*div_op->mutable_attr())["T"].set_type(
721       GetTensorFlowDataType(model, src_op.outputs[0]));
722 }
723 
ConvertReluOperator(const Model & model,const ReluOperator & src_op,GraphDef * tensorflow_graph)724 void ConvertReluOperator(const Model& model, const ReluOperator& src_op,
725                          GraphDef* tensorflow_graph) {
726   tensorflow::NodeDef* relu_op = tensorflow_graph->add_node();
727   relu_op->set_op("Relu");
728   relu_op->set_name(src_op.outputs[0]);
729   *relu_op->add_input() = src_op.inputs[0];
730   (*relu_op->mutable_attr())["T"].set_type(
731       GetTensorFlowDataType(model, src_op.outputs[0]));
732 }
733 
ConvertRelu1Operator(const Relu1Operator & src_op,GraphDef * tensorflow_graph)734 void ConvertRelu1Operator(const Relu1Operator& src_op,
735                           GraphDef* tensorflow_graph) {
736   const string max_bounds = src_op.outputs[0] + "/max_bounds";
737   const string min_bounds = src_op.outputs[0] + "/min_bounds";
738   const string max_output = src_op.outputs[0] + "/max_output";
739 
740   tensorflow::NodeDef* max_bounds_const_op = tensorflow_graph->add_node();
741   max_bounds_const_op->set_op("Const");
742   max_bounds_const_op->set_name(max_bounds);
743   (*max_bounds_const_op->mutable_attr())["dtype"].set_type(DT_FLOAT);
744   auto* max_bounds_const_op_tensor =
745       (*max_bounds_const_op->mutable_attr())["value"].mutable_tensor();
746   max_bounds_const_op_tensor->set_dtype(DT_FLOAT);
747   max_bounds_const_op_tensor->add_float_val(-1.0f);
748 
749   tensorflow::NodeDef* min_bounds_const_op = tensorflow_graph->add_node();
750   min_bounds_const_op->set_op("Const");
751   min_bounds_const_op->set_name(min_bounds);
752   (*min_bounds_const_op->mutable_attr())["dtype"].set_type(DT_FLOAT);
753   auto* min_bounds_const_op_tensor =
754       (*min_bounds_const_op->mutable_attr())["value"].mutable_tensor();
755   min_bounds_const_op_tensor->set_dtype(DT_FLOAT);
756   min_bounds_const_op_tensor->add_float_val(1.0f);
757 
758   tensorflow::NodeDef* max_op = tensorflow_graph->add_node();
759   max_op->set_op("Maximum");
760   max_op->set_name(max_output);
761   *max_op->add_input() = src_op.inputs[0];
762   *max_op->add_input() = max_bounds;
763   (*max_op->mutable_attr())["T"].set_type(DT_FLOAT);
764 
765   tensorflow::NodeDef* min_op = tensorflow_graph->add_node();
766   min_op->set_op("Minimum");
767   min_op->set_name(src_op.outputs[0]);
768   *min_op->add_input() = max_output;
769   *min_op->add_input() = min_bounds;
770   (*min_op->mutable_attr())["T"].set_type(DT_FLOAT);
771 }
772 
ConvertRelu6Operator(const Relu6Operator & src_op,GraphDef * tensorflow_graph)773 void ConvertRelu6Operator(const Relu6Operator& src_op,
774                           GraphDef* tensorflow_graph) {
775   tensorflow::NodeDef* relu_op = tensorflow_graph->add_node();
776   relu_op->set_op("Relu6");
777   relu_op->set_name(src_op.outputs[0]);
778   *relu_op->add_input() = src_op.inputs[0];
779   (*relu_op->mutable_attr())["T"].set_type(DT_FLOAT);
780 }
781 
ConvertLogOperator(const LogOperator & src_op,GraphDef * tensorflow_graph)782 void ConvertLogOperator(const LogOperator& src_op, GraphDef* tensorflow_graph) {
783   tensorflow::NodeDef* op = tensorflow_graph->add_node();
784   op->set_op("Log");
785   op->set_name(src_op.outputs[0]);
786   CHECK_EQ(src_op.inputs.size(), 1);
787   *op->add_input() = src_op.inputs[0];
788   (*op->mutable_attr())["T"].set_type(DT_FLOAT);
789 }
790 
ConvertLogisticOperator(const LogisticOperator & src_op,GraphDef * tensorflow_graph)791 void ConvertLogisticOperator(const LogisticOperator& src_op,
792                              GraphDef* tensorflow_graph) {
793   tensorflow::NodeDef* relu_op = tensorflow_graph->add_node();
794   relu_op->set_op("Sigmoid");
795   relu_op->set_name(src_op.outputs[0]);
796   *relu_op->add_input() = src_op.inputs[0];
797   (*relu_op->mutable_attr())["T"].set_type(DT_FLOAT);
798 }
799 
ConvertTanhOperator(const TanhOperator & src_op,GraphDef * tensorflow_graph)800 void ConvertTanhOperator(const TanhOperator& src_op,
801                          GraphDef* tensorflow_graph) {
802   tensorflow::NodeDef* tanh_op = tensorflow_graph->add_node();
803   tanh_op->set_op("Tanh");
804   tanh_op->set_name(src_op.outputs[0]);
805   *tanh_op->add_input() = src_op.inputs[0];
806   (*tanh_op->mutable_attr())["T"].set_type(DT_FLOAT);
807 }
808 
ConvertSoftmaxOperator(const Model & model,const SoftmaxOperator & src_op,GraphDef * tensorflow_graph)809 void ConvertSoftmaxOperator(const Model& model, const SoftmaxOperator& src_op,
810                             GraphDef* tensorflow_graph) {
811   string softmax_input;
812   Operator* providing_op = GetOpWithOutput(model, src_op.inputs[0]);
813   if (providing_op != nullptr && providing_op->type == OperatorType::kReshape) {
814     softmax_input = src_op.inputs[0];
815   } else {
816     // Insert a reshape operator that reduces the dimensions down to the 2 that
817     // are required for TensorFlow Logits.
818     const string reshape_output = src_op.outputs[0] + "/softmax_insert_reshape";
819     const string softmax_size = src_op.outputs[0] + "/softmax_insert_size";
820     softmax_input = reshape_output;
821 
822     tensorflow::NodeDef* reshape_op = tensorflow_graph->add_node();
823     reshape_op->set_op("Reshape");
824     reshape_op->set_name(reshape_output);
825     *reshape_op->add_input() = src_op.inputs[0];
826     *reshape_op->add_input() = softmax_size;
827     (*reshape_op->mutable_attr())["T"].set_type(DT_FLOAT);
828 
829     const auto& input_shape = model.GetArray(src_op.inputs[0]).shape();
830     int32 flattened_size = 1;
831     for (int i = 0; i < input_shape.dimensions_count() - 1; ++i) {
832       flattened_size *= input_shape.dims(i);
833     }
834     const std::vector<int32> shape_data = {
835         flattened_size, input_shape.dims(input_shape.dimensions_count() - 1)};
836     CreateReshapeShapeTensorConst(softmax_size, shape_data, tensorflow_graph);
837   }
838 
839   tensorflow::NodeDef* softmax_op = tensorflow_graph->add_node();
840   softmax_op->set_op("Softmax");
841   softmax_op->set_name(src_op.outputs[0]);
842   *softmax_op->add_input() = softmax_input;
843   // TensorFlow's Softmax doesn't seem to admit a 'beta' parameter
844   CHECK_EQ(src_op.beta, 1.f);
845   (*softmax_op->mutable_attr())["T"].set_type(DT_FLOAT);
846 }
847 
ConvertLogSoftmaxOperator(const Model & model,const LogSoftmaxOperator & src_op,GraphDef * tensorflow_graph)848 void ConvertLogSoftmaxOperator(const Model& model,
849                                const LogSoftmaxOperator& src_op,
850                                GraphDef* tensorflow_graph) {
851   string softmax_input;
852   Operator* providing_op = GetOpWithOutput(model, src_op.inputs[0]);
853   if (providing_op != nullptr && providing_op->type == OperatorType::kReshape) {
854     softmax_input = src_op.inputs[0];
855   } else {
856     // Insert a reshape operator that reduces the dimensions down to the 2 that
857     // are required for TensorFlow Logits.
858     const string reshape_output =
859         src_op.outputs[0] + "/log_softmax_insert_reshape";
860     const string softmax_size = src_op.outputs[0] + "/log_softmax_insert_size";
861     softmax_input = reshape_output;
862 
863     tensorflow::NodeDef* reshape_op = tensorflow_graph->add_node();
864     reshape_op->set_op("Reshape");
865     reshape_op->set_name(reshape_output);
866     *reshape_op->add_input() = src_op.inputs[0];
867     *reshape_op->add_input() = softmax_size;
868     (*reshape_op->mutable_attr())["T"].set_type(DT_FLOAT);
869 
870     const auto& input_shape = model.GetArray(src_op.inputs[0]).shape();
871     int32 flattened_size = 1;
872     for (int i = 0; i < input_shape.dimensions_count() - 1; ++i) {
873       flattened_size *= input_shape.dims(i);
874     }
875     const std::vector<int32> shape_data = {
876         flattened_size, input_shape.dims(input_shape.dimensions_count() - 1)};
877     CreateReshapeShapeTensorConst(softmax_size, shape_data, tensorflow_graph);
878   }
879 
880   tensorflow::NodeDef* log_softmax_op = tensorflow_graph->add_node();
881   log_softmax_op->set_op("LogSoftmax");
882   log_softmax_op->set_name(src_op.outputs[0]);
883   *log_softmax_op->add_input() = softmax_input;
884   (*log_softmax_op->mutable_attr())["T"].set_type(DT_FLOAT);
885 }
886 
ConvertL2NormalizationOperator(const L2NormalizationOperator & src_op,GraphDef * tensorflow_graph)887 void ConvertL2NormalizationOperator(const L2NormalizationOperator& src_op,
888                                     GraphDef* tensorflow_graph) {
889   const string square_output = src_op.outputs[0] + "/square";
890   const string sum_reduction_indices = src_op.outputs[0] + "/reduction_indices";
891   const string sum_output = src_op.outputs[0] + "/sum";
892   const string rsqrt_output = src_op.outputs[0] + "/rsqrt";
893   const string rsqrt_tiled_output = src_op.outputs[0] + "/rsqrt_tiled";
894 
895   tensorflow::NodeDef* sum_reduction_indices_op = tensorflow_graph->add_node();
896   sum_reduction_indices_op->set_op("Const");
897   sum_reduction_indices_op->set_name(sum_reduction_indices);
898   (*sum_reduction_indices_op->mutable_attr())["dtype"].set_type(DT_INT32);
899   auto* sum_reduction_indices_tensor =
900       (*sum_reduction_indices_op->mutable_attr())["value"].mutable_tensor();
901   sum_reduction_indices_tensor->set_dtype(DT_INT32);
902   auto* sum_reduction_indices_shape =
903       sum_reduction_indices_tensor->mutable_tensor_shape();
904   auto* sum_reduction_indices_dim = sum_reduction_indices_shape->add_dim();
905   sum_reduction_indices_dim->set_size(2);
906   sum_reduction_indices_tensor->add_int_val(0);
907   sum_reduction_indices_tensor->add_int_val(1);
908 
909   tensorflow::NodeDef* square_op = tensorflow_graph->add_node();
910   square_op->set_op("Square");
911   square_op->set_name(square_output);
912   *square_op->add_input() = src_op.inputs[0];
913   (*square_op->mutable_attr())["T"].set_type(DT_FLOAT);
914 
915   tensorflow::NodeDef* sum_op = tensorflow_graph->add_node();
916   sum_op->set_op("Sum");
917   sum_op->set_name(sum_output);
918   *sum_op->add_input() = square_output;
919   *sum_op->add_input() = sum_reduction_indices;
920   (*sum_op->mutable_attr())["T"].set_type(DT_FLOAT);
921 
922   tensorflow::NodeDef* rsqrt_op = tensorflow_graph->add_node();
923   rsqrt_op->set_op("Rsqrt");
924   rsqrt_op->set_name(rsqrt_output);
925   *rsqrt_op->add_input() = sum_output;
926   (*rsqrt_op->mutable_attr())["T"].set_type(DT_FLOAT);
927 
928   tensorflow::NodeDef* mul_op = tensorflow_graph->add_node();
929   mul_op->set_op("Mul");
930   mul_op->set_name(src_op.outputs[0]);
931   *mul_op->add_input() = src_op.inputs[0];
932   *mul_op->add_input() = rsqrt_output;
933   (*mul_op->mutable_attr())["T"].set_type(DT_FLOAT);
934 }
935 
ConvertLocalResponseNormalizationOperator(const LocalResponseNormalizationOperator & src_op,GraphDef * tensorflow_graph)936 void ConvertLocalResponseNormalizationOperator(
937     const LocalResponseNormalizationOperator& src_op,
938     GraphDef* tensorflow_graph) {
939   tensorflow::NodeDef* lrn_op = tensorflow_graph->add_node();
940   lrn_op->set_op("LRN");
941   lrn_op->set_name(src_op.outputs[0]);
942   *lrn_op->add_input() = src_op.inputs[0];
943   (*lrn_op->mutable_attr())["depth_radius"].set_i(src_op.range);
944   (*lrn_op->mutable_attr())["bias"].set_f(src_op.bias);
945   (*lrn_op->mutable_attr())["alpha"].set_f(src_op.alpha);
946   (*lrn_op->mutable_attr())["beta"].set_f(src_op.beta);
947 }
948 
ConvertFakeQuantOperator(const FakeQuantOperator & src_op,GraphDef * tensorflow_graph)949 void ConvertFakeQuantOperator(const FakeQuantOperator& src_op,
950                               GraphDef* tensorflow_graph) {
951   tensorflow::NodeDef* fakequant_op = tensorflow_graph->add_node();
952   fakequant_op->set_op("FakeQuantWithMinMaxArgs");
953   fakequant_op->set_name(src_op.outputs[0]);
954   CHECK_EQ(src_op.inputs.size(), 1);
955   *fakequant_op->add_input() = src_op.inputs[0];
956   CHECK(src_op.minmax);
957   (*fakequant_op->mutable_attr())["min"].set_f(src_op.minmax->min);
958   (*fakequant_op->mutable_attr())["max"].set_f(src_op.minmax->max);
959   if (src_op.num_bits) {
960     (*fakequant_op->mutable_attr())["num_bits"].set_i(src_op.num_bits);
961   }
962   if (src_op.narrow_range) {
963     (*fakequant_op->mutable_attr())["narrow_range"].set_b(src_op.narrow_range);
964   }
965 }
966 
ConvertMaxPoolOperator(const MaxPoolOperator & src_op,GraphDef * tensorflow_graph)967 void ConvertMaxPoolOperator(const MaxPoolOperator& src_op,
968                             GraphDef* tensorflow_graph) {
969   tensorflow::NodeDef* maxpool_op = tensorflow_graph->add_node();
970   maxpool_op->set_op("MaxPool");
971   maxpool_op->set_name(src_op.outputs[0]);
972   *maxpool_op->add_input() = src_op.inputs[0];
973   auto& strides = (*maxpool_op->mutable_attr())["strides"];
974   strides.mutable_list()->add_i(1);
975   strides.mutable_list()->add_i(src_op.stride_height);
976   strides.mutable_list()->add_i(src_op.stride_width);
977   strides.mutable_list()->add_i(1);
978   string padding;
979   if (src_op.padding.type == PaddingType::kSame) {
980     padding = "SAME";
981   } else if (src_op.padding.type == PaddingType::kValid) {
982     padding = "VALID";
983   } else {
984     LOG(FATAL) << "Bad padding (only SAME and VALID are supported)";
985   }
986   (*maxpool_op->mutable_attr())["padding"].set_s(padding);
987   (*maxpool_op->mutable_attr())["T"].set_type(DT_FLOAT);
988   auto& ksize = (*maxpool_op->mutable_attr())["ksize"];
989   ksize.mutable_list()->add_i(1);
990   ksize.mutable_list()->add_i(src_op.kheight);
991   ksize.mutable_list()->add_i(src_op.kwidth);
992   ksize.mutable_list()->add_i(1);
993 }
994 
ConvertAveragePoolOperator(const AveragePoolOperator & src_op,GraphDef * tensorflow_graph)995 void ConvertAveragePoolOperator(const AveragePoolOperator& src_op,
996                                 GraphDef* tensorflow_graph) {
997   tensorflow::NodeDef* avgpool_op = tensorflow_graph->add_node();
998   avgpool_op->set_op("AvgPool");
999   avgpool_op->set_name(src_op.outputs[0]);
1000   *avgpool_op->add_input() = src_op.inputs[0];
1001   auto& strides = (*avgpool_op->mutable_attr())["strides"];
1002   strides.mutable_list()->add_i(1);
1003   strides.mutable_list()->add_i(src_op.stride_height);
1004   strides.mutable_list()->add_i(src_op.stride_width);
1005   strides.mutable_list()->add_i(1);
1006   string padding;
1007   if (src_op.padding.type == PaddingType::kSame) {
1008     padding = "SAME";
1009   } else if (src_op.padding.type == PaddingType::kValid) {
1010     padding = "VALID";
1011   } else {
1012     LOG(FATAL) << "Bad padding (only SAME and VALID are supported)";
1013   }
1014   (*avgpool_op->mutable_attr())["padding"].set_s(padding);
1015   (*avgpool_op->mutable_attr())["T"].set_type(DT_FLOAT);
1016   auto& ksize = (*avgpool_op->mutable_attr())["ksize"];
1017   ksize.mutable_list()->add_i(1);
1018   ksize.mutable_list()->add_i(src_op.kheight);
1019   ksize.mutable_list()->add_i(src_op.kwidth);
1020   ksize.mutable_list()->add_i(1);
1021 }
1022 
ConvertConcatenationOperator(const Model & model,const ConcatenationOperator & src_op,GraphDef * tensorflow_graph)1023 void ConvertConcatenationOperator(const Model& model,
1024                                   const ConcatenationOperator& src_op,
1025                                   GraphDef* tensorflow_graph) {
1026   tensorflow::NodeDef* dc_op = tensorflow_graph->add_node();
1027   dc_op->set_op("ConcatV2");
1028   dc_op->set_name(src_op.outputs[0]);
1029   const string dummy_axis = src_op.outputs[0] + "/axis";
1030   CreateDummyConcatDimTensorConst(dummy_axis, src_op.axis, tensorflow_graph);
1031   for (const auto& input : src_op.inputs) {
1032     *dc_op->add_input() = input;
1033   }
1034   *dc_op->add_input() = dummy_axis;
1035   (*dc_op->mutable_attr())["T"].set_type(
1036       GetTensorFlowDataType(model, src_op.inputs[0]));
1037   (*dc_op->mutable_attr())["Tidx"].set_type(DT_INT32);
1038   (*dc_op->mutable_attr())["N"].set_i(src_op.inputs.size());
1039 }
1040 
ConvertTensorFlowReshapeOperator(const Model & model,const TensorFlowReshapeOperator & src_op,GraphDef * tensorflow_graph)1041 void ConvertTensorFlowReshapeOperator(const Model& model,
1042                                       const TensorFlowReshapeOperator& src_op,
1043                                       GraphDef* tensorflow_graph) {
1044   tensorflow::NodeDef* reshape_op = tensorflow_graph->add_node();
1045   reshape_op->set_op("Reshape");
1046   reshape_op->set_name(src_op.outputs[0]);
1047   CHECK_EQ(src_op.inputs.size(), 2);
1048   *reshape_op->add_input() = src_op.inputs[0];
1049   *reshape_op->add_input() = src_op.inputs[1];
1050   (*reshape_op->mutable_attr())["T"].set_type(
1051       GetTensorFlowDataType(model, src_op.outputs[0]));
1052   const auto& shape_array = model.GetArray(src_op.inputs[1]);
1053   QCHECK(shape_array.data_type == ArrayDataType::kInt32)
1054       << "Only int32 shape is supported.";
1055   QCHECK(shape_array.buffer != nullptr)
1056       << "Shape inferred at runtime is not supported.";
1057   const auto& shape_data = shape_array.GetBuffer<ArrayDataType::kInt32>().data;
1058   CreateReshapeShapeTensorConst(src_op.inputs[1], shape_data, tensorflow_graph);
1059 }
1060 
ConvertL2PoolOperator(const L2PoolOperator & src_op,GraphDef * tensorflow_graph)1061 void ConvertL2PoolOperator(const L2PoolOperator& src_op,
1062                            GraphDef* tensorflow_graph) {
1063   const string square_output = src_op.outputs[0] + "/square";
1064   const string avgpool_output = src_op.outputs[0] + "/avgpool";
1065 
1066   tensorflow::NodeDef* square_op = tensorflow_graph->add_node();
1067   square_op->set_op("Square");
1068   square_op->set_name(square_output);
1069   *square_op->add_input() = src_op.inputs[0];
1070   (*square_op->mutable_attr())["T"].set_type(DT_FLOAT);
1071 
1072   string padding;
1073   if (src_op.padding.type == PaddingType::kSame) {
1074     padding = "SAME";
1075   } else if (src_op.padding.type == PaddingType::kValid) {
1076     padding = "VALID";
1077   } else {
1078     LOG(FATAL) << "Bad padding (only SAME and VALID are supported)";
1079   }
1080 
1081   tensorflow::NodeDef* avgpool_op = tensorflow_graph->add_node();
1082   avgpool_op->set_op("AvgPool");
1083   avgpool_op->set_name(avgpool_output);
1084   *avgpool_op->add_input() = square_output;
1085   auto& strides = (*avgpool_op->mutable_attr())["strides"];
1086   strides.mutable_list()->add_i(1);
1087   strides.mutable_list()->add_i(src_op.stride_height);
1088   strides.mutable_list()->add_i(src_op.stride_width);
1089   strides.mutable_list()->add_i(1);
1090 
1091   (*avgpool_op->mutable_attr())["padding"].set_s(padding);
1092   (*avgpool_op->mutable_attr())["T"].set_type(DT_FLOAT);
1093   auto& ksize = (*avgpool_op->mutable_attr())["ksize"];
1094   ksize.mutable_list()->add_i(1);
1095   ksize.mutable_list()->add_i(src_op.kheight);
1096   ksize.mutable_list()->add_i(src_op.kwidth);
1097   ksize.mutable_list()->add_i(1);
1098 
1099   tensorflow::NodeDef* sqrt_op = tensorflow_graph->add_node();
1100   sqrt_op->set_op("Sqrt");
1101   sqrt_op->set_name(src_op.outputs[0]);
1102   *sqrt_op->add_input() = avgpool_output;
1103   (*sqrt_op->mutable_attr())["T"].set_type(DT_FLOAT);
1104 }
1105 
ConvertSquareOperator(const TensorFlowSquareOperator & src_op,GraphDef * tensorflow_graph)1106 void ConvertSquareOperator(const TensorFlowSquareOperator& src_op,
1107                            GraphDef* tensorflow_graph) {
1108   tensorflow::NodeDef* square_op = tensorflow_graph->add_node();
1109   square_op->set_op("Square");
1110   square_op->set_name(src_op.outputs[0]);
1111   CHECK_EQ(src_op.inputs.size(), 1);
1112   *square_op->add_input() = src_op.inputs[0];
1113   (*square_op->mutable_attr())["T"].set_type(DT_FLOAT);
1114 }
1115 
ConvertSqrtOperator(const TensorFlowSqrtOperator & src_op,GraphDef * tensorflow_graph)1116 void ConvertSqrtOperator(const TensorFlowSqrtOperator& src_op,
1117                          GraphDef* tensorflow_graph) {
1118   tensorflow::NodeDef* sqrt_op = tensorflow_graph->add_node();
1119   sqrt_op->set_op("Sqrt");
1120   sqrt_op->set_name(src_op.outputs[0]);
1121   CHECK_EQ(src_op.inputs.size(), 1);
1122   *sqrt_op->add_input() = src_op.inputs[0];
1123   (*sqrt_op->mutable_attr())["T"].set_type(DT_FLOAT);
1124 }
1125 
ConvertRsqrtOperator(const Model & model,const TensorFlowRsqrtOperator & src_op,GraphDef * tensorflow_graph)1126 void ConvertRsqrtOperator(const Model& model,
1127                           const TensorFlowRsqrtOperator& src_op,
1128                           GraphDef* tensorflow_graph) {
1129   tensorflow::NodeDef* rsqrt_op = tensorflow_graph->add_node();
1130   rsqrt_op->set_op("Rsqrt");
1131   rsqrt_op->set_name(src_op.outputs[0]);
1132   CHECK_EQ(src_op.inputs.size(), 1);
1133   *rsqrt_op->add_input() = src_op.inputs[0];
1134   const tensorflow::DataType data_type =
1135       GetTensorFlowDataType(model, src_op.inputs[0]);
1136   (*rsqrt_op->mutable_attr())["T"].set_type(data_type);
1137 }
1138 
ConvertSplitOperator(const Model & model,const TensorFlowSplitOperator & src_op,GraphDef * tensorflow_graph)1139 void ConvertSplitOperator(const Model& model,
1140                           const TensorFlowSplitOperator& src_op,
1141                           GraphDef* tensorflow_graph) {
1142   tensorflow::NodeDef* split_op = tensorflow_graph->add_node();
1143   split_op->set_op("Split");
1144   split_op->set_name(src_op.outputs[0]);
1145   for (const auto& input : src_op.inputs) {
1146     *split_op->add_input() = input;
1147   }
1148   (*split_op->mutable_attr())["T"].set_type(DT_FLOAT);
1149   (*split_op->mutable_attr())["num_split"].set_i(src_op.num_split);
1150   const auto& split_dim_array = model.GetArray(src_op.inputs[0]);
1151   CHECK(split_dim_array.buffer);
1152   CHECK(split_dim_array.data_type == ArrayDataType::kInt32);
1153   const auto& split_dim_data =
1154       split_dim_array.GetBuffer<ArrayDataType::kInt32>().data;
1155   CHECK_EQ(split_dim_data.size(), 1);
1156   const int split_dim = split_dim_data[0];
1157   CreateDummyConcatDimTensorConst(src_op.inputs[0], split_dim,
1158                                   tensorflow_graph);
1159 }
1160 
ConvertSplitVOperator(const Model & model,const TensorFlowSplitVOperator & src_op,GraphDef * tensorflow_graph)1161 void ConvertSplitVOperator(const Model& model,
1162                            const TensorFlowSplitVOperator& src_op,
1163                            GraphDef* tensorflow_graph) {
1164   tensorflow::NodeDef* split_v_op = tensorflow_graph->add_node();
1165   split_v_op->set_op("SplitV");
1166   split_v_op->set_name(src_op.outputs[0]);
1167   for (const auto& input : src_op.inputs) {
1168     *split_v_op->add_input() = input;
1169   }
1170   (*split_v_op->mutable_attr())["T"].set_type(
1171       GetTensorFlowDataType(model, src_op.inputs[0]));
1172   (*split_v_op->mutable_attr())["num_split"].set_i(src_op.num_split);
1173   const auto& split_dim_array = model.GetArray(src_op.inputs[1]);
1174   CHECK(split_dim_array.buffer);
1175   CHECK(split_dim_array.data_type == ArrayDataType::kInt32);
1176   const auto& split_dim_data =
1177       split_dim_array.GetBuffer<ArrayDataType::kInt32>().data;
1178   CHECK_EQ(split_dim_data.size(), 1);
1179   const int split_dim = split_dim_data[0];
1180   CreateDummyConcatDimTensorConst(src_op.inputs[0], split_dim,
1181                                   tensorflow_graph);
1182 }
1183 
ConvertCastOperator(const Model & model,const CastOperator & src_op,GraphDef * tensorflow_graph)1184 void ConvertCastOperator(const Model& model, const CastOperator& src_op,
1185                          GraphDef* tensorflow_graph) {
1186   tensorflow::NodeDef* cast_op = tensorflow_graph->add_node();
1187   cast_op->set_op("Cast");
1188   cast_op->set_name(src_op.outputs[0]);
1189   CHECK_EQ(src_op.inputs.size(), 1);
1190   *cast_op->add_input() = src_op.inputs[0];
1191 
1192   (*cast_op->mutable_attr())["DstT"].set_type(
1193       GetTensorFlowDataType(model, src_op.outputs[0]));
1194   (*cast_op->mutable_attr())["SrcT"].set_type(
1195       GetTensorFlowDataType(model, src_op.inputs[0]));
1196 }
1197 
ConvertFloorOperator(const Model & model,const FloorOperator & src_op,GraphDef * tensorflow_graph)1198 void ConvertFloorOperator(const Model& model, const FloorOperator& src_op,
1199                           GraphDef* tensorflow_graph) {
1200   tensorflow::NodeDef* floor_op = tensorflow_graph->add_node();
1201   floor_op->set_op("Floor");
1202   floor_op->set_name(src_op.outputs[0]);
1203   CHECK_EQ(src_op.inputs.size(), 1);
1204   *floor_op->add_input() = src_op.inputs[0];
1205   (*floor_op->mutable_attr())["T"].set_type(DT_FLOAT);
1206 }
1207 
ConvertCeilOperator(const Model & model,const CeilOperator & src_op,GraphDef * tensorflow_graph)1208 void ConvertCeilOperator(const Model& model, const CeilOperator& src_op,
1209                          GraphDef* tensorflow_graph) {
1210   tensorflow::NodeDef* ceil_op = tensorflow_graph->add_node();
1211   ceil_op->set_op("Ceil");
1212   ceil_op->set_name(src_op.outputs[0]);
1213   CHECK_EQ(src_op.inputs.size(), 1);
1214   *ceil_op->add_input() = src_op.inputs[0];
1215   (*ceil_op->mutable_attr())["T"].set_type(DT_FLOAT);
1216 }
1217 
ConvertGatherOperator(const Model & model,const GatherOperator & src_op,GraphDef * tensorflow_graph)1218 void ConvertGatherOperator(const Model& model, const GatherOperator& src_op,
1219                            GraphDef* tensorflow_graph) {
1220   tensorflow::NodeDef* gather_op = tensorflow_graph->add_node();
1221   gather_op->set_op("GatherV2");
1222   gather_op->set_name(src_op.outputs[0]);
1223   *gather_op->add_input() = src_op.inputs[0];
1224   *gather_op->add_input() = src_op.inputs[1];
1225 
1226   if (!src_op.axis) {
1227     // Dynamic axis.
1228     CHECK_EQ(src_op.inputs.size(), 3);
1229     *gather_op->add_input() = src_op.inputs[2];
1230   } else {
1231     // Constant axis.
1232     CHECK_EQ(src_op.inputs.size(), 2);
1233     const string gather_axis =
1234         AvailableArrayName(model, gather_op->name() + "/axis");
1235     CreateIntTensorConst(gather_axis, {src_op.axis.value()}, {},
1236                          tensorflow_graph);
1237     *gather_op->add_input() = gather_axis;
1238   }
1239 
1240   (*gather_op->mutable_attr())["Tindices"].set_type(DT_INT32);
1241   (*gather_op->mutable_attr())["Taxis"].set_type(DT_INT32);
1242   const tensorflow::DataType params_type =
1243       GetTensorFlowDataType(model, src_op.inputs[0]);
1244   (*gather_op->mutable_attr())["Tparams"].set_type(params_type);
1245 }
1246 
ConvertArgMaxOperator(const Model & model,const ArgMaxOperator & src_op,GraphDef * tensorflow_graph)1247 void ConvertArgMaxOperator(const Model& model, const ArgMaxOperator& src_op,
1248                            GraphDef* tensorflow_graph) {
1249   tensorflow::NodeDef* argmax_op = tensorflow_graph->add_node();
1250   argmax_op->set_op("ArgMax");
1251   argmax_op->set_name(src_op.outputs[0]);
1252   CHECK_EQ(src_op.inputs.size(), 2);
1253   *argmax_op->add_input() = src_op.inputs[0];
1254   *argmax_op->add_input() = src_op.inputs[1];
1255   (*argmax_op->mutable_attr())["T"].set_type(
1256       GetTensorFlowDataType(model, src_op.inputs[0]));
1257   (*argmax_op->mutable_attr())["Tidx"].set_type(
1258       GetTensorFlowDataType(model, src_op.inputs[1]));
1259   (*argmax_op->mutable_attr())["output_type"].set_type(
1260       GetTensorFlowDataType(model, src_op.outputs[0]));
1261 }
1262 
ConvertArgMinOperator(const Model & model,const ArgMinOperator & src_op,GraphDef * tensorflow_graph)1263 void ConvertArgMinOperator(const Model& model, const ArgMinOperator& src_op,
1264                            GraphDef* tensorflow_graph) {
1265   tensorflow::NodeDef* argmin_op = tensorflow_graph->add_node();
1266   argmin_op->set_op("ArgMin");
1267   argmin_op->set_name(src_op.outputs[0]);
1268   CHECK_EQ(src_op.inputs.size(), 2);
1269   *argmin_op->add_input() = src_op.inputs[0];
1270   *argmin_op->add_input() = src_op.inputs[1];
1271   (*argmin_op->mutable_attr())["T"].set_type(
1272       GetTensorFlowDataType(model, src_op.inputs[0]));
1273   (*argmin_op->mutable_attr())["Tidx"].set_type(
1274       GetTensorFlowDataType(model, src_op.inputs[1]));
1275   (*argmin_op->mutable_attr())["output_type"].set_type(
1276       GetTensorFlowDataType(model, src_op.outputs[0]));
1277 }
1278 
ConvertTransposeOperator(const Model & model,const TransposeOperator & src_op,GraphDef * tensorflow_graph)1279 void ConvertTransposeOperator(const Model& model,
1280                               const TransposeOperator& src_op,
1281                               GraphDef* tensorflow_graph) {
1282   tensorflow::NodeDef* transpose_op = tensorflow_graph->add_node();
1283   transpose_op->set_op("Transpose");
1284   transpose_op->set_name(src_op.outputs[0]);
1285   CHECK_EQ(src_op.inputs.size(), 2);
1286   *transpose_op->add_input() = src_op.inputs[0];
1287   *transpose_op->add_input() = src_op.inputs[1];
1288   (*transpose_op->mutable_attr())["T"].set_type(
1289       GetTensorFlowDataType(model, src_op.inputs[0]));
1290   (*transpose_op->mutable_attr())["Tperm"].set_type(
1291       GetTensorFlowDataType(model, src_op.inputs[1]));
1292 }
1293 
ConvertTensorFlowShapeOperator(const Model & model,const TensorFlowShapeOperator & src_op,GraphDef * tensorflow_graph)1294 void ConvertTensorFlowShapeOperator(const Model& model,
1295                                     const TensorFlowShapeOperator& src_op,
1296                                     GraphDef* tensorflow_graph) {
1297   tensorflow::NodeDef* shape_op = tensorflow_graph->add_node();
1298   shape_op->set_op("Shape");
1299   shape_op->set_name(src_op.outputs[0]);
1300   CHECK_EQ(src_op.inputs.size(), 1);
1301   *shape_op->add_input() = src_op.inputs[0];
1302   (*shape_op->mutable_attr())["T"].set_type(
1303       GetTensorFlowDataType(model, src_op.inputs[0]));
1304   (*shape_op->mutable_attr())["out_type"].set_type(
1305       GetTensorFlowDataType(model, src_op.outputs[0]));
1306 }
1307 
ConvertRankOperator(const Model & model,const TensorFlowRankOperator & src_op,GraphDef * tensorflow_graph)1308 void ConvertRankOperator(const Model& model,
1309                          const TensorFlowRankOperator& src_op,
1310                          GraphDef* tensorflow_graph) {
1311   tensorflow::NodeDef* rank_op = tensorflow_graph->add_node();
1312   rank_op->set_op("Rank");
1313   rank_op->set_name(src_op.outputs[0]);
1314   CHECK_EQ(src_op.inputs.size(), 1);
1315   *rank_op->add_input() = src_op.inputs[0];
1316   (*rank_op->mutable_attr())["T"].set_type(
1317       GetTensorFlowDataType(model, src_op.inputs[0]));
1318 }
1319 
ConvertRangeOperator(const Model & model,const RangeOperator & src_op,GraphDef * tensorflow_graph)1320 void ConvertRangeOperator(const Model& model, const RangeOperator& src_op,
1321                           GraphDef* tensorflow_graph) {
1322   tensorflow::NodeDef* range_op = tensorflow_graph->add_node();
1323   range_op->set_op("Range");
1324   range_op->set_name(src_op.outputs[0]);
1325   CHECK_EQ(src_op.inputs.size(), 3);
1326   *range_op->add_input() = src_op.inputs[0];
1327   *range_op->add_input() = src_op.inputs[1];
1328   *range_op->add_input() = src_op.inputs[2];
1329   (*range_op->mutable_attr())["Tidx"].set_type(
1330       GetTensorFlowDataTypeForOp(src_op.dtype, /*op_name=*/src_op.outputs[0]));
1331 }
1332 
ConvertPackOperator(const Model & model,const PackOperator & src_op,GraphDef * tensorflow_graph)1333 void ConvertPackOperator(const Model& model, const PackOperator& src_op,
1334                          GraphDef* tensorflow_graph) {
1335   tensorflow::NodeDef* pack_op = tensorflow_graph->add_node();
1336   pack_op->set_op("Pack");
1337   pack_op->set_name(src_op.outputs[0]);
1338   for (const auto& input : src_op.inputs) {
1339     *pack_op->add_input() = input;
1340   }
1341   (*pack_op->mutable_attr())["axis"].set_i(src_op.axis);
1342   (*pack_op->mutable_attr())["N"].set_i(src_op.inputs.size());
1343   (*pack_op->mutable_attr())["T"].set_type(
1344       GetTensorFlowDataTypeForOp(src_op.dtype, src_op.outputs[0]));
1345 }
1346 
ConvertFillOperator(const Model & model,const FillOperator & src_op,GraphDef * tensorflow_graph)1347 void ConvertFillOperator(const Model& model, const FillOperator& src_op,
1348                          GraphDef* tensorflow_graph) {
1349   tensorflow::NodeDef* fill_op = tensorflow_graph->add_node();
1350   fill_op->set_op("Fill");
1351   fill_op->set_name(src_op.outputs[0]);
1352   CHECK_EQ(src_op.inputs.size(), 2);
1353   *fill_op->add_input() = src_op.inputs[0];
1354   *fill_op->add_input() = src_op.inputs[1];
1355   (*fill_op->mutable_attr())["index_type"].set_type(
1356       GetTensorFlowDataType(model, src_op.inputs[0]));
1357   (*fill_op->mutable_attr())["T"].set_type(
1358       GetTensorFlowDataType(model, src_op.inputs[1]));
1359 }
1360 
ConvertFloorDivOperator(const Model & model,const FloorDivOperator & src_op,GraphDef * tensorflow_graph)1361 void ConvertFloorDivOperator(const Model& model, const FloorDivOperator& src_op,
1362                              GraphDef* tensorflow_graph) {
1363   tensorflow::NodeDef* floor_div_op = tensorflow_graph->add_node();
1364   floor_div_op->set_op("FloorDiv");
1365   floor_div_op->set_name(src_op.outputs[0]);
1366   CHECK_EQ(src_op.inputs.size(), 2);
1367   *floor_div_op->add_input() = src_op.inputs[0];
1368   *floor_div_op->add_input() = src_op.inputs[1];
1369   (*floor_div_op->mutable_attr())["T"].set_type(
1370       GetTensorFlowDataType(model, src_op.inputs[0]));
1371 }
1372 
ConvertFloorModOperator(const Model & model,const FloorModOperator & src_op,GraphDef * tensorflow_graph)1373 void ConvertFloorModOperator(const Model& model, const FloorModOperator& src_op,
1374                              GraphDef* tensorflow_graph) {
1375   tensorflow::NodeDef* floor_mod_op = tensorflow_graph->add_node();
1376   floor_mod_op->set_op("FloorMod");
1377   floor_mod_op->set_name(src_op.outputs[0]);
1378   DCHECK_EQ(src_op.inputs.size(), 2);
1379   *floor_mod_op->add_input() = src_op.inputs[0];
1380   *floor_mod_op->add_input() = src_op.inputs[1];
1381   (*floor_mod_op->mutable_attr())["T"].set_type(
1382       GetTensorFlowDataType(model, src_op.inputs[0]));
1383 }
1384 
ConvertExpandDimsOperator(const Model & model,const ExpandDimsOperator & src_op,GraphDef * tensorflow_graph)1385 void ConvertExpandDimsOperator(const Model& model,
1386                                const ExpandDimsOperator& src_op,
1387                                GraphDef* tensorflow_graph) {
1388   tensorflow::NodeDef* expand_dims_op = tensorflow_graph->add_node();
1389   expand_dims_op->set_op("ExpandDims");
1390   expand_dims_op->set_name(src_op.outputs[0]);
1391   CHECK_EQ(src_op.inputs.size(), 2);
1392   *expand_dims_op->add_input() = src_op.inputs[0];
1393   *expand_dims_op->add_input() = src_op.inputs[1];
1394   (*expand_dims_op->mutable_attr())["T"].set_type(
1395       GetTensorFlowDataType(model, src_op.inputs[0]));
1396   (*expand_dims_op->mutable_attr())["Tdim"].set_type(
1397       GetTensorFlowDataType(model, src_op.inputs[1]));
1398 }
1399 
ConvertResizeBilinearOperator(const Model & model,const ResizeBilinearOperator & src_op,GraphDef * tensorflow_graph)1400 void ConvertResizeBilinearOperator(const Model& model,
1401                                    const ResizeBilinearOperator& src_op,
1402                                    GraphDef* tensorflow_graph) {
1403   tensorflow::NodeDef* resize_op = tensorflow_graph->add_node();
1404   resize_op->set_op("ResizeBilinear");
1405   resize_op->set_name(src_op.outputs[0]);
1406   CHECK_EQ(src_op.inputs.size(), 2);
1407   *resize_op->add_input() = src_op.inputs[0];
1408   *resize_op->add_input() = src_op.inputs[1];
1409   (*resize_op->mutable_attr())["T"].set_type(DT_FLOAT);
1410   (*resize_op->mutable_attr())["align_corners"].set_b(src_op.align_corners);
1411 }
1412 
ConvertOneHotOperator(const Model & model,const OneHotOperator & src_op,GraphDef * tensorflow_graph)1413 void ConvertOneHotOperator(const Model& model, const OneHotOperator& src_op,
1414                            GraphDef* tensorflow_graph) {
1415   tensorflow::NodeDef* onehot_op = tensorflow_graph->add_node();
1416   onehot_op->set_op("OneHot");
1417   onehot_op->set_name(src_op.outputs[0]);
1418   CHECK_EQ(src_op.inputs.size(), 4);
1419   for (const auto& input : src_op.inputs) {
1420     *onehot_op->add_input() = input;
1421   }
1422   (*onehot_op->mutable_attr())["T"].set_type(
1423       GetTensorFlowDataType(model, src_op.outputs[0]));
1424   (*onehot_op->mutable_attr())["axis"].set_i(src_op.axis);
1425 }
1426 
1427 namespace {
1428 // TODO(aselle): Remove when available in absl
FindLongestCommonPrefix(absl::string_view a,absl::string_view b)1429 absl::string_view FindLongestCommonPrefix(absl::string_view a,
1430                                           absl::string_view b) {
1431   if (a.empty() || b.empty()) return absl::string_view();
1432 
1433   const char* pa = a.data();
1434   const char* pb = b.data();
1435   string::difference_type count = 0;
1436   const string::difference_type limit = std::min(a.size(), b.size());
1437   while (count < limit && *pa == *pb) {
1438     ++pa;
1439     ++pb;
1440     ++count;
1441   }
1442 
1443   return absl::string_view(a.data(), count);
1444 }
1445 }  // namespace
1446 
// Expands a toco LstmCellOperator into a subgraph of primitive TensorFlow ops:
// ConcatV2 -> MatMul -> BiasAdd -> Split(4) -> Sigmoid/Tanh/Mul/Add.
//
// Reads src_op.inputs at DATA_INPUT, PREV_ACTIV_INPUT, WEIGHTS_INPUT,
// BIASES_INPUT and PREV_STATE_INPUT. The final "Add" node is named
// src_op.outputs[STATE_OUTPUT] and the final "Mul" node is named
// src_op.outputs[ACTIV_OUTPUT]. Every intermediate node is named
// `base + <suffix>`, where `base` is the longest common prefix of those two
// output names.
//
// WEIGHTS_INPUT and BIASES_INPUT must resolve (via WalkUpToConstantArray) to
// float arrays that have buffers, and the weights must be 2-D. Any violation
// CHECK-fails.
//
// Computation emitted (s = Split output, prev_c = PREV_STATE_INPUT):
//   new_c = prev_c * Sigmoid(s:2) + Sigmoid(s:0) * Tanh(s:1)
//   new_h = Tanh(new_c) * Sigmoid(s:3)
// NOTE(review): this presumably matches tf BasicLSTMCell's (i, j, f, o) gate
// order, with any forget bias already folded into the biases (none is added
// here). Confirm against the importer.
void ConvertLstmCellOperator(const Model& model, const LstmCellOperator& src_op,
                             GraphDef* tensorflow_graph) {
  // Find the base name: the longest common prefix of the two output names,
  // used as the prefix of every intermediate node created below.
  const string base(
      FindLongestCommonPrefix(src_op.outputs[LstmCellOperator::STATE_OUTPUT],
                              src_op.outputs[LstmCellOperator::ACTIV_OUTPUT]));

  // Concatenate inputs
  const string concat_output = base + "basic_lstm_cell/concat";
  // Op names have been chosen to match the tf.slim LSTM naming
  // as closely as possible.
  // Concatenation axis: the last dimension of the previous activation array.
  const int axis =
      model.GetArray(src_op.inputs[LstmCellOperator::PREV_ACTIV_INPUT])
          .shape()
          .dimensions_count() -
      1;
  // Note that DATA_INPUT may have extra size 1 dimensions, but TF concat
  // works the same since the tensor has the same underlying data layout.
  const string axis_output = concat_output + "/axis";
  CreateDummyConcatDimTensorConst(axis_output, axis, tensorflow_graph);
  tensorflow::NodeDef* concat_op = tensorflow_graph->add_node();
  concat_op->set_op("ConcatV2");
  concat_op->set_name(concat_output);
  *concat_op->add_input() = src_op.inputs[LstmCellOperator::DATA_INPUT];
  *concat_op->add_input() = src_op.inputs[LstmCellOperator::PREV_ACTIV_INPUT];
  *concat_op->add_input() = axis_output;
  (*concat_op->mutable_attr())["T"].set_type(DT_FLOAT);
  (*concat_op->mutable_attr())["Tidx"].set_type(DT_INT32);
  (*concat_op->mutable_attr())["N"].set_i(2);  // Number of inputs

  // Write weights
  const string weights_output = base + "weights";
  CHECK(model.HasArray(src_op.inputs[LstmCellOperator::WEIGHTS_INPUT]));
  const string weights_name = WalkUpToConstantArray(
      model, src_op.inputs[LstmCellOperator::WEIGHTS_INPUT]);
  const auto& weights_array = model.GetArray(weights_name);
  // The weights must already be a 2-D float array with a buffer. They are
  // exported as `weights_output` with source axes order kCR and destination
  // order kRC (presumably a transpose performed by ConvertFloatTensorConst;
  // confirm).
  const auto& weights_shape = weights_array.shape();
  CHECK_EQ(weights_shape.dimensions_count(), 2);
  CHECK(weights_array.buffer);
  CHECK(weights_array.buffer->type == ArrayDataType::kFloat);
  const float* weights_data =
      weights_array.GetBuffer<ArrayDataType::kFloat>().data.data();
  ConvertFloatTensorConst(weights_output, weights_shape, weights_data,
                          AxesOrder::kCR, AxesOrder::kRC, tensorflow_graph);

  // Fully connected matrix multiply: concat_output x weights_output, with no
  // transposition of either operand.
  const string matmul_output = base + "MatMul";
  tensorflow::NodeDef* matmul_op = tensorflow_graph->add_node();
  matmul_op->set_op("MatMul");
  matmul_op->set_name(matmul_output);
  *matmul_op->add_input() = concat_output;
  *matmul_op->add_input() = weights_output;
  (*matmul_op->mutable_attr())["transpose_a"].set_b(false);
  (*matmul_op->mutable_attr())["transpose_b"].set_b(false);
  (*matmul_op->mutable_attr())["T"].set_type(DT_FLOAT);

  // Write biases
  const string biases_output = base + "biases";
  CHECK(model.HasArray(src_op.inputs[LstmCellOperator::BIASES_INPUT]));
  const string bias_name = WalkUpToConstantArray(
      model, src_op.inputs[LstmCellOperator::BIASES_INPUT]);
  const auto& bias_array = model.GetArray(bias_name);
  // TODO(b/62904716) Bias arrays should be 1-D, and used directly.
  // Work on a copy of the shape, unextended to 1 dimension (presumably
  // dropping leading size-1 dims; see UnextendShape) before export.
  Shape bias_shape_1d = bias_array.shape();
  UnextendShape(&bias_shape_1d, 1);
  CHECK(bias_array.buffer);
  CHECK(bias_array.buffer->type == ArrayDataType::kFloat);
  const float* bias_data =
      bias_array.GetBuffer<ArrayDataType::kFloat>().data.data();
  ConvertFloatTensorConst(biases_output, bias_shape_1d, bias_data,
                          AxesOrder::kOneAxis, AxesOrder::kOneAxis,
                          tensorflow_graph,
                          LegacyScalarPolicy::kDoCreateLegacyScalars);

  // Add biases
  string biasadd_output = base + "BiasAdd";
  tensorflow::NodeDef* biasadd_op = tensorflow_graph->add_node();
  biasadd_op->set_op("BiasAdd");
  biasadd_op->set_name(biasadd_output);
  biasadd_op->add_input(matmul_output);
  biasadd_op->add_input(biases_output);
  (*biasadd_op->mutable_attr())["data_format"].set_s("NHWC");
  (*biasadd_op->mutable_attr())["T"].set_type(DT_FLOAT);

  // Split the BiasAdd result into four parts along `axis`.
  string split_dim_output = base + "split/split_dim";
  // The dimension is the same as the concatenation dimension
  CreateDummyConcatDimTensorConst(split_dim_output, axis, tensorflow_graph);
  string split_output = base + "split";
  tensorflow::NodeDef* split_op = tensorflow_graph->add_node();
  split_op->set_op("Split");
  split_op->set_name(split_output);
  *split_op->add_input() = split_dim_output;
  *split_op->add_input() = biasadd_output;
  (*split_op->mutable_attr())["T"].set_type(DT_FLOAT);
  (*split_op->mutable_attr())["num_split"].set_i(4);  // Split into four outputs

  // Activation functions and memory computations

  // Tanh = Tanh(split:1)
  const string tanh_0_output = base + "Tanh";
  tensorflow::NodeDef* tanh_0_op = tensorflow_graph->add_node();
  tanh_0_op->set_op("Tanh");
  tanh_0_op->set_name(tanh_0_output);
  *tanh_0_op->add_input() = split_output + ":1";
  (*tanh_0_op->mutable_attr())["T"].set_type(DT_FLOAT);

  // Sigmoid_1 = Sigmoid(split:0); a bare node name refers to output 0.
  const string sigmoid_1_output = base + "Sigmoid_1";
  tensorflow::NodeDef* logistic_1_op = tensorflow_graph->add_node();
  logistic_1_op->set_op("Sigmoid");
  logistic_1_op->set_name(sigmoid_1_output);
  *logistic_1_op->add_input() = split_output;
  (*logistic_1_op->mutable_attr())["T"].set_type(DT_FLOAT);

  // mul_1 = Sigmoid_1 * Tanh
  const string mul_1_output = base + "mul_1";
  tensorflow::NodeDef* mul_1_op = tensorflow_graph->add_node();
  mul_1_op->set_op("Mul");
  mul_1_op->set_name(mul_1_output);
  *mul_1_op->add_input() = sigmoid_1_output;
  *mul_1_op->add_input() = tanh_0_output;
  (*mul_1_op->mutable_attr())["T"].set_type(DT_FLOAT);

  // Sigmoid = Sigmoid(split:2)
  const string sigmoid_0_output = base + "Sigmoid";
  tensorflow::NodeDef* logistic_2_op = tensorflow_graph->add_node();
  logistic_2_op->set_op("Sigmoid");
  logistic_2_op->set_name(sigmoid_0_output);
  *logistic_2_op->add_input() = split_output + ":2";
  (*logistic_2_op->mutable_attr())["T"].set_type(DT_FLOAT);

  // Sigmoid_2 = Sigmoid(split:3)
  const string sigmoid_2_output = base + "Sigmoid_2";
  tensorflow::NodeDef* logistic_3_op = tensorflow_graph->add_node();
  logistic_3_op->set_op("Sigmoid");
  logistic_3_op->set_name(sigmoid_2_output);
  *logistic_3_op->add_input() = split_output + ":3";
  (*logistic_3_op->mutable_attr())["T"].set_type(DT_FLOAT);

  // mul = PREV_STATE_INPUT * Sigmoid
  const string mul_0_output = base + "mul";
  tensorflow::NodeDef* mul_0_op = tensorflow_graph->add_node();
  mul_0_op->set_op("Mul");
  mul_0_op->set_name(mul_0_output);
  *mul_0_op->add_input() = src_op.inputs[LstmCellOperator::PREV_STATE_INPUT];
  *mul_0_op->add_input() = sigmoid_0_output;
  (*mul_0_op->mutable_attr())["T"].set_type(DT_FLOAT);

  // New state: STATE_OUTPUT = mul + mul_1
  const string add_1_output = src_op.outputs[LstmCellOperator::STATE_OUTPUT];
  tensorflow::NodeDef* add_1_op = tensorflow_graph->add_node();
  add_1_op->set_op("Add");
  add_1_op->set_name(add_1_output);
  *add_1_op->add_input() = mul_0_output;
  *add_1_op->add_input() = mul_1_output;
  (*add_1_op->mutable_attr())["T"].set_type(DT_FLOAT);

  // Tanh_1 = Tanh(new state)
  const string tanh_1_output = base + "Tanh_1";
  tensorflow::NodeDef* tanh_1_op = tensorflow_graph->add_node();
  tanh_1_op->set_op("Tanh");
  tanh_1_op->set_name(tanh_1_output);
  *tanh_1_op->add_input() = add_1_output;
  (*tanh_1_op->mutable_attr())["T"].set_type(DT_FLOAT);

  // New activation: ACTIV_OUTPUT = Tanh_1 * Sigmoid_2
  const string mul_2_output = src_op.outputs[LstmCellOperator::ACTIV_OUTPUT];
  tensorflow::NodeDef* mul_2_op = tensorflow_graph->add_node();
  mul_2_op->set_op("Mul");
  mul_2_op->set_name(mul_2_output);
  *mul_2_op->add_input() = tanh_1_output;
  *mul_2_op->add_input() = sigmoid_2_output;
  (*mul_2_op->mutable_attr())["T"].set_type(DT_FLOAT);
}
1613 
ConvertSpaceToBatchNDOperator(const Model & model,const SpaceToBatchNDOperator & src_op,GraphDef * tensorflow_graph)1614 void ConvertSpaceToBatchNDOperator(const Model& model,
1615                                    const SpaceToBatchNDOperator& src_op,
1616                                    GraphDef* tensorflow_graph) {
1617   tensorflow::NodeDef* new_op = tensorflow_graph->add_node();
1618   new_op->set_op("SpaceToBatchND");
1619   new_op->set_name(src_op.outputs[0]);
1620   CHECK_EQ(src_op.inputs.size(), 3);
1621   *new_op->add_input() = src_op.inputs[0];
1622   *new_op->add_input() = src_op.inputs[1];
1623   *new_op->add_input() = src_op.inputs[2];
1624   const tensorflow::DataType params_type =
1625       GetTensorFlowDataType(model, src_op.inputs[0]);
1626   (*new_op->mutable_attr())["T"].set_type(params_type);
1627   (*new_op->mutable_attr())["Tblock_shape"].set_type(DT_INT32);
1628   (*new_op->mutable_attr())["Tpaddings"].set_type(DT_INT32);
1629 }
1630 
ConvertBatchToSpaceNDOperator(const Model & model,const BatchToSpaceNDOperator & src_op,GraphDef * tensorflow_graph)1631 void ConvertBatchToSpaceNDOperator(const Model& model,
1632                                    const BatchToSpaceNDOperator& src_op,
1633                                    GraphDef* tensorflow_graph) {
1634   tensorflow::NodeDef* new_op = tensorflow_graph->add_node();
1635   new_op->set_op("BatchToSpaceND");
1636   new_op->set_name(src_op.outputs[0]);
1637   CHECK_EQ(src_op.inputs.size(), 3);
1638   *new_op->add_input() = src_op.inputs[0];
1639   *new_op->add_input() = src_op.inputs[1];
1640   *new_op->add_input() = src_op.inputs[2];
1641   const tensorflow::DataType params_type =
1642       GetTensorFlowDataType(model, src_op.inputs[0]);
1643   (*new_op->mutable_attr())["T"].set_type(params_type);
1644   (*new_op->mutable_attr())["Tblock_shape"].set_type(DT_INT32);
1645   (*new_op->mutable_attr())["Tcrops"].set_type(DT_INT32);
1646 }
1647 
ConvertPadOperator(const Model & model,const PadOperator & src_op,GraphDef * tensorflow_graph)1648 void ConvertPadOperator(const Model& model, const PadOperator& src_op,
1649                         GraphDef* tensorflow_graph) {
1650   tensorflow::NodeDef* new_op = tensorflow_graph->add_node();
1651   new_op->set_op("Pad");
1652   new_op->set_name(src_op.outputs[0]);
1653   CHECK_EQ(src_op.inputs.size(), 2);
1654   *new_op->add_input() = src_op.inputs[0];
1655   *new_op->add_input() = src_op.inputs[1];
1656 
1657   const tensorflow::DataType params_type =
1658       GetTensorFlowDataType(model, src_op.inputs[0]);
1659   (*new_op->mutable_attr())["T"].set_type(params_type);
1660 
1661   // Create the params tensor.
1662   tensorflow::NodeDef* params_op = tensorflow_graph->add_node();
1663   params_op->set_op("Const");
1664   params_op->set_name(src_op.inputs[1]);
1665   (*params_op->mutable_attr())["dtype"].set_type(DT_INT32);
1666   auto* tensor = (*params_op->mutable_attr())["value"].mutable_tensor();
1667   tensor->set_dtype(DT_INT32);
1668 
1669   CHECK_EQ(src_op.left_padding.size(), src_op.right_padding.size());
1670   for (int i = 0; i < src_op.left_padding.size(); ++i) {
1671     tensor->add_int_val(src_op.left_padding[i]);
1672     tensor->add_int_val(src_op.right_padding[i]);
1673   }
1674   auto* shape = tensor->mutable_tensor_shape();
1675   shape->add_dim()->set_size(src_op.left_padding.size());
1676   shape->add_dim()->set_size(2);
1677 }
1678 
ConvertPadV2Operator(const Model & model,const PadV2Operator & src_op,GraphDef * tensorflow_graph)1679 void ConvertPadV2Operator(const Model& model, const PadV2Operator& src_op,
1680                           GraphDef* tensorflow_graph) {
1681   tensorflow::NodeDef* new_op = tensorflow_graph->add_node();
1682   new_op->set_op("PadV2");
1683   new_op->set_name(src_op.outputs[0]);
1684   CHECK_EQ(src_op.inputs.size(), 2);
1685   *new_op->add_input() = src_op.inputs[0];
1686   *new_op->add_input() = src_op.inputs[1];
1687   *new_op->add_input() = src_op.inputs[2];
1688 
1689   const tensorflow::DataType params_type =
1690       GetTensorFlowDataType(model, src_op.inputs[0]);
1691   (*new_op->mutable_attr())["T"].set_type(params_type);
1692 
1693   // Create the params tensor.
1694   tensorflow::NodeDef* params_op = tensorflow_graph->add_node();
1695   params_op->set_op("Const");
1696   params_op->set_name(src_op.inputs[1]);
1697   (*params_op->mutable_attr())["dtype"].set_type(DT_INT32);
1698   auto* tensor = (*params_op->mutable_attr())["value"].mutable_tensor();
1699   tensor->set_dtype(DT_INT32);
1700 
1701   CHECK_EQ(src_op.left_padding.size(), src_op.right_padding.size());
1702   for (int i = 0; i < src_op.left_padding.size(); ++i) {
1703     tensor->add_int_val(src_op.left_padding[i]);
1704     tensor->add_int_val(src_op.right_padding[i]);
1705   }
1706   auto* shape = tensor->mutable_tensor_shape();
1707   shape->add_dim()->set_size(src_op.left_padding.size());
1708   shape->add_dim()->set_size(2);
1709 }
1710 
CreateSliceInput(const string & input_name,const std::vector<int> & values,GraphDef * tensorflow_graph)1711 void CreateSliceInput(const string& input_name, const std::vector<int>& values,
1712                       GraphDef* tensorflow_graph) {
1713   tensorflow::NodeDef* params_op = tensorflow_graph->add_node();
1714   params_op->set_op("Const");
1715   params_op->set_name(input_name);
1716   (*params_op->mutable_attr())["dtype"].set_type(DT_INT32);
1717   auto* tensor = (*params_op->mutable_attr())["value"].mutable_tensor();
1718   tensor->set_dtype(DT_INT32);
1719 
1720   for (int i = 0; i < values.size(); ++i) {
1721     tensor->add_int_val(values[i]);
1722   }
1723   auto* shape = tensor->mutable_tensor_shape();
1724   shape->add_dim()->set_size(values.size());
1725 }
1726 
ConvertStridedSliceOperator(const Model & model,const StridedSliceOperator & src_op,GraphDef * tensorflow_graph)1727 void ConvertStridedSliceOperator(const Model& model,
1728                                  const StridedSliceOperator& src_op,
1729                                  GraphDef* tensorflow_graph) {
1730   tensorflow::NodeDef* new_op = tensorflow_graph->add_node();
1731   new_op->set_op("StridedSlice");
1732   new_op->set_name(src_op.outputs[0]);
1733   CHECK_EQ(src_op.inputs.size(), 4);
1734   *new_op->add_input() = src_op.inputs[0];
1735   *new_op->add_input() = src_op.inputs[1];
1736   *new_op->add_input() = src_op.inputs[2];
1737   *new_op->add_input() = src_op.inputs[3];
1738 
1739   const tensorflow::DataType params_type =
1740       GetTensorFlowDataType(model, src_op.inputs[0]);
1741   (*new_op->mutable_attr())["T"].set_type(params_type);
1742 
1743   (*new_op->mutable_attr())["Index"].set_type(DT_INT32);
1744   (*new_op->mutable_attr())["begin_mask"].set_i(src_op.begin_mask);
1745   (*new_op->mutable_attr())["ellipsis_mask"].set_i(src_op.ellipsis_mask);
1746   (*new_op->mutable_attr())["end_mask"].set_i(src_op.end_mask);
1747   (*new_op->mutable_attr())["new_axis_mask"].set_i(src_op.new_axis_mask);
1748   (*new_op->mutable_attr())["shrink_axis_mask"].set_i(src_op.shrink_axis_mask);
1749 
1750   // Create tensors for start/stop indices and strides.
1751   CreateSliceInput(src_op.inputs[1], src_op.start_indices, tensorflow_graph);
1752   CreateSliceInput(src_op.inputs[2], src_op.stop_indices, tensorflow_graph);
1753   CreateSliceInput(src_op.inputs[3], src_op.strides, tensorflow_graph);
1754 }
1755 
ConvertSliceOperator(const Model & model,const SliceOperator & src_op,GraphDef * tensorflow_graph)1756 void ConvertSliceOperator(const Model& model, const SliceOperator& src_op,
1757                           GraphDef* tensorflow_graph) {
1758   tensorflow::NodeDef* new_op = tensorflow_graph->add_node();
1759   new_op->set_op("Slice");
1760   new_op->set_name(src_op.outputs[0]);
1761   CHECK_EQ(src_op.inputs.size(), 3);
1762   *new_op->add_input() = src_op.inputs[0];
1763   *new_op->add_input() = src_op.inputs[1];
1764   *new_op->add_input() = src_op.inputs[2];
1765 
1766   const tensorflow::DataType params_type =
1767       GetTensorFlowDataType(model, src_op.inputs[0]);
1768   (*new_op->mutable_attr())["T"].set_type(params_type);
1769   (*new_op->mutable_attr())["Index"].set_type(DT_INT32);
1770 
1771   // Create tensors for begin and size inputs.
1772   CreateSliceInput(src_op.inputs[1], src_op.begin, tensorflow_graph);
1773   CreateSliceInput(src_op.inputs[2], src_op.size, tensorflow_graph);
1774 }
1775 
1776 template <typename T>
ConvertReduceOperator(const Model & model,const T & src_op,GraphDef * tensorflow_graph,const string & op_name)1777 void ConvertReduceOperator(const Model& model, const T& src_op,
1778                            GraphDef* tensorflow_graph, const string& op_name) {
1779   tensorflow::NodeDef* new_op = tensorflow_graph->add_node();
1780   new_op->set_op(op_name);
1781   new_op->set_name(src_op.outputs[0]);
1782   CHECK_EQ(src_op.inputs.size(), 2);
1783   *new_op->add_input() = src_op.inputs[0];
1784   *new_op->add_input() = src_op.inputs[1];
1785 
1786   if (src_op.type != OperatorType::kAny) {
1787     const tensorflow::DataType params_type =
1788         GetTensorFlowDataType(model, src_op.inputs[0]);
1789     (*new_op->mutable_attr())["T"].set_type(params_type);
1790   }
1791   const tensorflow::DataType indices_type =
1792       GetTensorFlowDataType(model, src_op.inputs[1]);
1793   (*new_op->mutable_attr())["Tidx"].set_type(indices_type);
1794 
1795   if (src_op.keep_dims) {
1796     (*new_op->mutable_attr())["keep_dims"].set_b(true);
1797   }
1798 
1799   // Create the params tensor.
1800   tensorflow::NodeDef* params_op = tensorflow_graph->add_node();
1801   params_op->set_op("Const");
1802   params_op->set_name(src_op.inputs[1]);
1803   (*params_op->mutable_attr())["dtype"].set_type(DT_INT32);
1804   auto* tensor = (*params_op->mutable_attr())["value"].mutable_tensor();
1805   tensor->set_dtype(DT_INT32);
1806 
1807   for (int i = 0; i < src_op.axis.size(); ++i) {
1808     tensor->add_int_val(src_op.axis[i]);
1809   }
1810   auto* shape = tensor->mutable_tensor_shape();
1811   shape->add_dim()->set_size(src_op.axis.size());
1812 }
1813 
ConvertSqueezeOperator(const Model & model,const SqueezeOperator & src_op,GraphDef * tensorflow_graph)1814 void ConvertSqueezeOperator(const Model& model, const SqueezeOperator& src_op,
1815                             GraphDef* tensorflow_graph) {
1816   tensorflow::NodeDef* new_op = tensorflow_graph->add_node();
1817   new_op->set_op("Squeeze");
1818   new_op->set_name(src_op.outputs[0]);
1819   CHECK_EQ(src_op.inputs.size(), 1);
1820   *new_op->add_input() = src_op.inputs[0];
1821 
1822   const tensorflow::DataType params_type =
1823       GetTensorFlowDataType(model, src_op.inputs[0]);
1824   (*new_op->mutable_attr())["T"].set_type(params_type);
1825 
1826   if (!src_op.squeeze_dims.empty()) {
1827     auto& squeeze_dims = (*new_op->mutable_attr())["squeeze_dims"];
1828     for (int i : src_op.squeeze_dims) {
1829       squeeze_dims.mutable_list()->add_i(i);
1830     }
1831   }
1832 }
1833 
ConvertSubOperator(const Model & model,const SubOperator & src_op,GraphDef * tensorflow_graph)1834 void ConvertSubOperator(const Model& model, const SubOperator& src_op,
1835                         GraphDef* tensorflow_graph) {
1836   tensorflow::NodeDef* sub_op = tensorflow_graph->add_node();
1837   sub_op->set_op("Sub");
1838   sub_op->set_name(src_op.outputs[0]);
1839   CHECK_EQ(src_op.inputs.size(), 2);
1840   *sub_op->add_input() = src_op.inputs[0];
1841   *sub_op->add_input() = src_op.inputs[1];
1842   const tensorflow::DataType data_type =
1843       GetTensorFlowDataType(model, src_op.inputs[0]);
1844   (*sub_op->mutable_attr())["T"].set_type(data_type);
1845 }
1846 
ConvertTensorFlowMinimumOperator(const Model & model,const TensorFlowMinimumOperator & src_op,GraphDef * tensorflow_graph)1847 void ConvertTensorFlowMinimumOperator(const Model& model,
1848                                       const TensorFlowMinimumOperator& src_op,
1849                                       GraphDef* tensorflow_graph) {
1850   tensorflow::NodeDef* min_op = tensorflow_graph->add_node();
1851   min_op->set_op("Minimum");
1852   min_op->set_name(src_op.outputs[0]);
1853   CHECK_EQ(src_op.inputs.size(), 2);
1854   *min_op->add_input() = src_op.inputs[0];
1855   *min_op->add_input() = src_op.inputs[1];
1856   const tensorflow::DataType data_type =
1857       GetTensorFlowDataType(model, src_op.inputs[0]);
1858   (*min_op->mutable_attr())["T"].set_type(data_type);
1859 }
1860 
ConvertTensorFlowMaximumOperator(const Model & model,const TensorFlowMaximumOperator & src_op,GraphDef * tensorflow_graph)1861 void ConvertTensorFlowMaximumOperator(const Model& model,
1862                                       const TensorFlowMaximumOperator& src_op,
1863                                       GraphDef* tensorflow_graph) {
1864   tensorflow::NodeDef* max_op = tensorflow_graph->add_node();
1865   max_op->set_op("Maximum");
1866   max_op->set_name(src_op.outputs[0]);
1867   CHECK_EQ(src_op.inputs.size(), 2);
1868   *max_op->add_input() = src_op.inputs[0];
1869   *max_op->add_input() = src_op.inputs[1];
1870   const tensorflow::DataType data_type =
1871       GetTensorFlowDataType(model, src_op.inputs[0]);
1872   (*max_op->mutable_attr())["T"].set_type(data_type);
1873 }
1874 
ConvertSelectOperator(const Model & model,const SelectOperator & src_op,GraphDef * tensorflow_graph)1875 void ConvertSelectOperator(const Model& model, const SelectOperator& src_op,
1876                            GraphDef* tensorflow_graph) {
1877   tensorflow::NodeDef* select_op = tensorflow_graph->add_node();
1878   select_op->set_op("Select");
1879   select_op->set_name(src_op.outputs[0]);
1880   CHECK_EQ(src_op.inputs.size(), 3);
1881   *select_op->add_input() = src_op.inputs[0];
1882   *select_op->add_input() = src_op.inputs[1];
1883   *select_op->add_input() = src_op.inputs[2];
1884   const tensorflow::DataType data_type =
1885       GetTensorFlowDataType(model, src_op.inputs[1]);
1886   (*select_op->mutable_attr())["T"].set_type(data_type);
1887 }
1888 
ConvertTileOperator(const Model & model,const TensorFlowTileOperator & src_op,GraphDef * tensorflow_graph)1889 void ConvertTileOperator(const Model& model,
1890                          const TensorFlowTileOperator& src_op,
1891                          GraphDef* tensorflow_graph) {
1892   tensorflow::NodeDef* tile_op = tensorflow_graph->add_node();
1893   tile_op->set_op("Tile");
1894   tile_op->set_name(src_op.outputs[0]);
1895   CHECK_EQ(src_op.inputs.size(), 2);
1896   *tile_op->add_input() = src_op.inputs[0];
1897   *tile_op->add_input() = src_op.inputs[1];
1898   const tensorflow::DataType data_type =
1899       GetTensorFlowDataType(model, src_op.inputs[0]);
1900   (*tile_op->mutable_attr())["T"].set_type(data_type);
1901   const tensorflow::DataType multiples_data_type =
1902       GetTensorFlowDataType(model, src_op.inputs[1]);
1903   (*tile_op->mutable_attr())["Tmultiples"].set_type(multiples_data_type);
1904 }
1905 
ConvertTopKV2Operator(const Model & model,const TopKV2Operator & src_op,GraphDef * tensorflow_graph)1906 void ConvertTopKV2Operator(const Model& model, const TopKV2Operator& src_op,
1907                            GraphDef* tensorflow_graph) {
1908   tensorflow::NodeDef* topk_op = tensorflow_graph->add_node();
1909   topk_op->set_op("TopKV2");
1910   topk_op->set_name(src_op.outputs[0]);
1911   CHECK_EQ(src_op.inputs.size(), 2);
1912   *topk_op->add_input() = src_op.inputs[0];
1913   *topk_op->add_input() = src_op.inputs[1];
1914   const tensorflow::DataType data_type =
1915       GetTensorFlowDataType(model, src_op.inputs[0]);
1916   (*topk_op->mutable_attr())["T"].set_type(data_type);
1917   (*topk_op->mutable_attr())["sorted"].set_b(true);
1918 }
1919 
ConvertRandomUniformOperator(const Model & model,const RandomUniformOperator & src_op,GraphDef * tensorflow_graph)1920 void ConvertRandomUniformOperator(const Model& model,
1921                                   const RandomUniformOperator& src_op,
1922                                   GraphDef* tensorflow_graph) {
1923   CHECK(tensorflow_graph != nullptr);
1924   tensorflow::NodeDef* new_op = tensorflow_graph->add_node();
1925   new_op->set_op("RandomUniform");
1926   CHECK_EQ(src_op.inputs.size(), 1);
1927   new_op->set_name(src_op.outputs[0]);
1928   *new_op->add_input() = src_op.inputs[0];
1929   const tensorflow::DataType shape_type =
1930       GetTensorFlowDataType(model, src_op.inputs[0]);
1931   (*new_op->mutable_attr())["T"].set_type(shape_type);
1932   (*new_op->mutable_attr())["dtype"].set_type(
1933       GetTensorFlowDataTypeForOp(src_op.dtype, src_op.outputs[0]));
1934   (*new_op->mutable_attr())["seed"].set_i(src_op.seed);
1935   (*new_op->mutable_attr())["seed2"].set_i(src_op.seed2);
1936 }
1937 
ConvertComparisonOperator(const Model & model,const Operator & src_op,const char * op_name,GraphDef * tensorflow_graph)1938 void ConvertComparisonOperator(const Model& model, const Operator& src_op,
1939                                const char* op_name,
1940                                GraphDef* tensorflow_graph) {
1941   tensorflow::NodeDef* comparison_op = tensorflow_graph->add_node();
1942   comparison_op->set_op(op_name);
1943   comparison_op->set_name(src_op.outputs[0]);
1944   CHECK_EQ(src_op.inputs.size(), 2);
1945   *comparison_op->add_input() = src_op.inputs[0];
1946   *comparison_op->add_input() = src_op.inputs[1];
1947   const tensorflow::DataType data_type =
1948       GetTensorFlowDataType(model, src_op.inputs[0]);
1949   (*comparison_op->mutable_attr())["T"].set_type(data_type);
1950 }
1951 
ConvertSparseToDenseOperator(const Model & model,const SparseToDenseOperator & src_op,const char * op_name,GraphDef * tensorflow_graph)1952 void ConvertSparseToDenseOperator(const Model& model,
1953                                   const SparseToDenseOperator& src_op,
1954                                   const char* op_name,
1955                                   GraphDef* tensorflow_graph) {
1956   tensorflow::NodeDef* sparse_to_dense_op = tensorflow_graph->add_node();
1957   sparse_to_dense_op->set_op(op_name);
1958   sparse_to_dense_op->set_name(src_op.outputs[0]);
1959   CHECK_EQ(src_op.inputs.size(), 4);
1960   for (int i = 0; i < 4; ++i) {
1961     *sparse_to_dense_op->add_input() = src_op.inputs[i];
1962   }
1963   const tensorflow::DataType data_type =
1964       GetTensorFlowDataType(model, src_op.inputs[3]);
1965   (*sparse_to_dense_op->mutable_attr())["T"].set_type(data_type);
1966   const tensorflow::DataType index_type =
1967       GetTensorFlowDataType(model, src_op.inputs[0]);
1968   (*sparse_to_dense_op->mutable_attr())["Tindices"].set_type(index_type);
1969   (*sparse_to_dense_op->mutable_attr())["Tindices"].set_b(
1970       src_op.validate_indices);
1971 }
1972 
ConvertPowOperator(const Model & model,const PowOperator & src_op,const char * op_name,GraphDef * tensorflow_graph)1973 void ConvertPowOperator(const Model& model, const PowOperator& src_op,
1974                         const char* op_name, GraphDef* tensorflow_graph) {
1975   tensorflow::NodeDef* pow_op = tensorflow_graph->add_node();
1976   pow_op->set_op(op_name);
1977   pow_op->set_name(src_op.outputs[0]);
1978   CHECK_EQ(src_op.inputs.size(), 2);
1979   for (int i = 0; i < 2; ++i) {
1980     *pow_op->add_input() = src_op.inputs[i];
1981   }
1982   const tensorflow::DataType data_type =
1983       GetTensorFlowDataType(model, src_op.inputs[0]);
1984   (*pow_op->mutable_attr())["T"].set_type(data_type);
1985 }
1986 
ConvertLogicalAndOperator(const Model & model,const LogicalAndOperator & src_op,GraphDef * tensorflow_graph)1987 void ConvertLogicalAndOperator(const Model& model,
1988                                const LogicalAndOperator& src_op,
1989                                GraphDef* tensorflow_graph) {
1990   tensorflow::NodeDef* logical_op = tensorflow_graph->add_node();
1991   logical_op->set_op("LogicalAnd");
1992   logical_op->set_name(src_op.outputs[0]);
1993   CHECK_EQ(src_op.inputs.size(), 2);
1994   for (int i = 0; i < 2; ++i) {
1995     *logical_op->add_input() = src_op.inputs[i];
1996   }
1997 }
1998 
ConvertLogicalNotOperator(const Model & model,const LogicalNotOperator & src_op,GraphDef * tensorflow_graph)1999 void ConvertLogicalNotOperator(const Model& model,
2000                                const LogicalNotOperator& src_op,
2001                                GraphDef* tensorflow_graph) {
2002   tensorflow::NodeDef* logical_op = tensorflow_graph->add_node();
2003   logical_op->set_op("LogicalNot");
2004   logical_op->set_name(src_op.outputs[0]);
2005   CHECK_EQ(src_op.inputs.size(), 1);
2006   *logical_op->add_input() = src_op.inputs[0];
2007 }
2008 
ConvertLogicalOrOperator(const Model & model,const LogicalOrOperator & src_op,const char * op_name,GraphDef * tensorflow_graph)2009 void ConvertLogicalOrOperator(const Model& model,
2010                               const LogicalOrOperator& src_op,
2011                               const char* op_name, GraphDef* tensorflow_graph) {
2012   tensorflow::NodeDef* logical_or_op = tensorflow_graph->add_node();
2013   logical_or_op->set_op(op_name);
2014   logical_or_op->set_name(src_op.outputs[0]);
2015   CHECK_EQ(src_op.inputs.size(), 2);
2016   for (int i = 0; i < 2; ++i) {
2017     *logical_or_op->add_input() = src_op.inputs[i];
2018   }
2019   const tensorflow::DataType data_type =
2020       GetTensorFlowDataType(model, src_op.inputs[0]);
2021   (*logical_or_op->mutable_attr())["T"].set_type(data_type);
2022 }
2023 
ConvertCTCBeamSearchDecoderOperator(const Model & model,const CTCBeamSearchDecoderOperator & src_op,const char * op_name,GraphDef * tensorflow_graph)2024 void ConvertCTCBeamSearchDecoderOperator(
2025     const Model& model, const CTCBeamSearchDecoderOperator& src_op,
2026     const char* op_name, GraphDef* tensorflow_graph) {
2027   auto* op = tensorflow_graph->add_node();
2028   op->set_op(op_name);
2029   op->set_name(src_op.outputs[0]);
2030   CHECK_EQ(src_op.inputs.size(), 2);
2031   for (int i = 0; i < 2; ++i) {
2032     *op->add_input() = src_op.inputs[i];
2033   }
2034   (*op->mutable_attr())["beam_width"].set_i(src_op.beam_width);
2035   (*op->mutable_attr())["top_paths"].set_i(src_op.top_paths);
2036   (*op->mutable_attr())["merge_repeated"].set_b(src_op.merge_repeated);
2037 }
2038 
ConvertUnpackOperator(const Model & model,const UnpackOperator & src_op,const char * op_name,GraphDef * tensorflow_graph)2039 void ConvertUnpackOperator(const Model& model, const UnpackOperator& src_op,
2040                            const char* op_name, GraphDef* tensorflow_graph) {
2041   tensorflow::NodeDef* unpack_op = tensorflow_graph->add_node();
2042   unpack_op->set_op(op_name);
2043   unpack_op->set_name(src_op.outputs[0]);
2044   CHECK_EQ(src_op.inputs.size(), 2);
2045   *unpack_op->add_input() = src_op.inputs[0];
2046   const tensorflow::DataType data_type =
2047       GetTensorFlowDataType(model, src_op.inputs[0]);
2048   (*unpack_op->mutable_attr())["T"].set_type(data_type);
2049   (*unpack_op->mutable_attr())["num"].set_i(src_op.num);
2050   (*unpack_op->mutable_attr())["axis"].set_i(src_op.axis);
2051 }
2052 
ConvertZerosLikeOperator(const Model & model,const TensorFlowZerosLikeOperator & src_op,const char * op_name,GraphDef * tensorflow_graph)2053 void ConvertZerosLikeOperator(const Model& model,
2054                               const TensorFlowZerosLikeOperator& src_op,
2055                               const char* op_name, GraphDef* tensorflow_graph) {
2056   tensorflow::NodeDef* zeros_like_op = tensorflow_graph->add_node();
2057   zeros_like_op->set_op(op_name);
2058   zeros_like_op->set_name(src_op.outputs[0]);
2059   DCHECK_EQ(src_op.inputs.size(), 1);
2060   *zeros_like_op->add_input() = src_op.inputs[0];
2061   const tensorflow::DataType data_type =
2062       GetTensorFlowDataType(model, src_op.inputs[0]);
2063   (*zeros_like_op->mutable_attr())["T"].set_type(data_type);
2064 }
2065 
ConvertReverseV2Operator(const Model & model,const ReverseV2Operator & src_op,const char * op_name,GraphDef * tensorflow_graph)2066 void ConvertReverseV2Operator(const Model& model,
2067                               const ReverseV2Operator& src_op,
2068                               const char* op_name, GraphDef* tensorflow_graph) {
2069   tensorflow::NodeDef* reverse_v2_op = tensorflow_graph->add_node();
2070   reverse_v2_op->set_op(op_name);
2071   reverse_v2_op->set_name(src_op.outputs[0]);
2072   DCHECK_EQ(src_op.inputs.size(), 2);
2073   *reverse_v2_op->add_input() = src_op.inputs[0];
2074   *reverse_v2_op->add_input() = src_op.inputs[1];
2075   const tensorflow::DataType data_type =
2076       GetTensorFlowDataType(model, src_op.inputs[0]);
2077   (*reverse_v2_op->mutable_attr())["T"].set_type(data_type);
2078 }
2079 
ConvertReverseSequenceOperator(const Model & model,const ReverseSequenceOperator & src_op,GraphDef * tensorflow_graph)2080 void ConvertReverseSequenceOperator(const Model& model,
2081                                     const ReverseSequenceOperator& src_op,
2082                                     GraphDef* tensorflow_graph) {
2083   tensorflow::NodeDef* reverse_seq_op = tensorflow_graph->add_node();
2084   reverse_seq_op->set_op("ReverseSequence");
2085   reverse_seq_op->set_name(src_op.outputs[0]);
2086   CHECK_EQ(src_op.inputs.size(), 2);
2087   *reverse_seq_op->add_input() = src_op.inputs[0];
2088   *reverse_seq_op->add_input() = src_op.inputs[1];
2089   (*reverse_seq_op->mutable_attr())["seq_dim"].set_i(src_op.seq_dim);
2090   (*reverse_seq_op->mutable_attr())["batch_dim"].set_i(src_op.batch_dim);
2091 }
2092 
ConvertOperator(const Model & model,const Operator & src_op,GraphDef * tensorflow_graph)2093 void ConvertOperator(const Model& model, const Operator& src_op,
2094                      GraphDef* tensorflow_graph) {
2095   if (src_op.fused_activation_function != FusedActivationFunctionType::kNone) {
2096     LOG(FATAL)
2097         << "Unsupported: the input model has a fused activation function";
2098   }
2099 
2100   if (src_op.type == OperatorType::kConv) {
2101     ConvertConvOperator(model, static_cast<const ConvOperator&>(src_op),
2102                         tensorflow_graph);
2103   } else if (src_op.type == OperatorType::kDepthwiseConv) {
2104     ConvertDepthwiseConvOperator(
2105         model, static_cast<const DepthwiseConvOperator&>(src_op),
2106         tensorflow_graph);
2107   } else if (src_op.type == OperatorType::kDepthToSpace) {
2108     ConvertDepthToSpaceOperator(
2109         model, static_cast<const DepthToSpaceOperator&>(src_op),
2110         tensorflow_graph);
2111   } else if (src_op.type == OperatorType::kSpaceToDepth) {
2112     ConvertSpaceToDepthOperator(
2113         model, static_cast<const SpaceToDepthOperator&>(src_op),
2114         tensorflow_graph);
2115   } else if (src_op.type == OperatorType::kFullyConnected) {
2116     ConvertFullyConnectedOperator(
2117         model, static_cast<const FullyConnectedOperator&>(src_op),
2118         tensorflow_graph);
2119   } else if (src_op.type == OperatorType::kAdd) {
2120     ConvertAddOperator(model, static_cast<const AddOperator&>(src_op),
2121                        tensorflow_graph);
2122   } else if (src_op.type == OperatorType::kAddN) {
2123     ConvertAddNOperator(model, static_cast<const AddNOperator&>(src_op),
2124                         tensorflow_graph);
2125   } else if (src_op.type == OperatorType::kMul) {
2126     ConvertMulOperator(model, static_cast<const MulOperator&>(src_op),
2127                        tensorflow_graph);
2128   } else if (src_op.type == OperatorType::kDiv) {
2129     ConvertDivOperator(model, static_cast<const DivOperator&>(src_op),
2130                        tensorflow_graph);
2131   } else if (src_op.type == OperatorType::kRelu) {
2132     ConvertReluOperator(model, static_cast<const ReluOperator&>(src_op),
2133                         tensorflow_graph);
2134   } else if (src_op.type == OperatorType::kRelu1) {
2135     ConvertRelu1Operator(static_cast<const Relu1Operator&>(src_op),
2136                          tensorflow_graph);
2137   } else if (src_op.type == OperatorType::kRelu6) {
2138     ConvertRelu6Operator(static_cast<const Relu6Operator&>(src_op),
2139                          tensorflow_graph);
2140   } else if (src_op.type == OperatorType::kLog) {
2141     ConvertLogOperator(static_cast<const LogOperator&>(src_op),
2142                        tensorflow_graph);
2143   } else if (src_op.type == OperatorType::kLogistic) {
2144     ConvertLogisticOperator(static_cast<const LogisticOperator&>(src_op),
2145                             tensorflow_graph);
2146   } else if (src_op.type == OperatorType::kTanh) {
2147     ConvertTanhOperator(static_cast<const TanhOperator&>(src_op),
2148                         tensorflow_graph);
2149   } else if (src_op.type == OperatorType::kL2Normalization) {
2150     ConvertL2NormalizationOperator(
2151         static_cast<const L2NormalizationOperator&>(src_op), tensorflow_graph);
2152   } else if (src_op.type == OperatorType::kSoftmax) {
2153     ConvertSoftmaxOperator(model, static_cast<const SoftmaxOperator&>(src_op),
2154                            tensorflow_graph);
2155   } else if (src_op.type == OperatorType::kLogSoftmax) {
2156     ConvertLogSoftmaxOperator(model,
2157                               static_cast<const LogSoftmaxOperator&>(src_op),
2158                               tensorflow_graph);
2159   } else if (src_op.type == OperatorType::kLocalResponseNormalization) {
2160     ConvertLocalResponseNormalizationOperator(
2161         static_cast<const LocalResponseNormalizationOperator&>(src_op),
2162         tensorflow_graph);
2163   } else if (src_op.type == OperatorType::kLstmCell) {
2164     ConvertLstmCellOperator(model, static_cast<const LstmCellOperator&>(src_op),
2165                             tensorflow_graph);
2166   } else if (src_op.type == OperatorType::kMaxPool) {
2167     ConvertMaxPoolOperator(static_cast<const MaxPoolOperator&>(src_op),
2168                            tensorflow_graph);
2169   } else if (src_op.type == OperatorType::kAveragePool) {
2170     ConvertAveragePoolOperator(static_cast<const AveragePoolOperator&>(src_op),
2171                                tensorflow_graph);
2172   } else if (src_op.type == OperatorType::kConcatenation) {
2173     ConvertConcatenationOperator(
2174         model, static_cast<const ConcatenationOperator&>(src_op),
2175         tensorflow_graph);
2176   } else if (src_op.type == OperatorType::kReshape) {
2177     ConvertTensorFlowReshapeOperator(
2178         model, static_cast<const TensorFlowReshapeOperator&>(src_op),
2179         tensorflow_graph);
2180   } else if (src_op.type == OperatorType::kL2Pool) {
2181     ConvertL2PoolOperator(static_cast<const L2PoolOperator&>(src_op),
2182                           tensorflow_graph);
2183   } else if (src_op.type == OperatorType::kSquare) {
2184     ConvertSquareOperator(static_cast<const TensorFlowSquareOperator&>(src_op),
2185                           tensorflow_graph);
2186   } else if (src_op.type == OperatorType::kSqrt) {
2187     ConvertSqrtOperator(static_cast<const TensorFlowSqrtOperator&>(src_op),
2188                         tensorflow_graph);
2189   } else if (src_op.type == OperatorType::kRsqrt) {
2190     ConvertRsqrtOperator(model,
2191                          static_cast<const TensorFlowRsqrtOperator&>(src_op),
2192                          tensorflow_graph);
2193   } else if (src_op.type == OperatorType::kSplit) {
2194     ConvertSplitOperator(model,
2195                          static_cast<const TensorFlowSplitOperator&>(src_op),
2196                          tensorflow_graph);
2197   } else if (src_op.type == OperatorType::kSplitV) {
2198     ConvertSplitVOperator(model,
2199                           static_cast<const TensorFlowSplitVOperator&>(src_op),
2200                           tensorflow_graph);
2201   } else if (src_op.type == OperatorType::kFakeQuant) {
2202     ConvertFakeQuantOperator(static_cast<const FakeQuantOperator&>(src_op),
2203                              tensorflow_graph);
2204   } else if (src_op.type == OperatorType::kCast) {
2205     ConvertCastOperator(model, static_cast<const CastOperator&>(src_op),
2206                         tensorflow_graph);
2207   } else if (src_op.type == OperatorType::kFloor) {
2208     ConvertFloorOperator(model, static_cast<const FloorOperator&>(src_op),
2209                          tensorflow_graph);
2210   } else if (src_op.type == OperatorType::kCeil) {
2211     ConvertCeilOperator(model, static_cast<const CeilOperator&>(src_op),
2212                         tensorflow_graph);
2213   } else if (src_op.type == OperatorType::kGather) {
2214     ConvertGatherOperator(model, static_cast<const GatherOperator&>(src_op),
2215                           tensorflow_graph);
2216   } else if (src_op.type == OperatorType::kResizeBilinear) {
2217     ConvertResizeBilinearOperator(
2218         model, static_cast<const ResizeBilinearOperator&>(src_op),
2219         tensorflow_graph);
2220   } else if (src_op.type == OperatorType::kSpaceToBatchND) {
2221     ConvertSpaceToBatchNDOperator(
2222         model, static_cast<const SpaceToBatchNDOperator&>(src_op),
2223         tensorflow_graph);
2224   } else if (src_op.type == OperatorType::kBatchToSpaceND) {
2225     ConvertBatchToSpaceNDOperator(
2226         model, static_cast<const BatchToSpaceNDOperator&>(src_op),
2227         tensorflow_graph);
2228   } else if (src_op.type == OperatorType::kPad) {
2229     ConvertPadOperator(model, static_cast<const PadOperator&>(src_op),
2230                        tensorflow_graph);
2231   } else if (src_op.type == OperatorType::kPadV2) {
2232     ConvertPadV2Operator(model, static_cast<const PadV2Operator&>(src_op),
2233                          tensorflow_graph);
2234   } else if (src_op.type == OperatorType::kStridedSlice) {
2235     ConvertStridedSliceOperator(
2236         model, static_cast<const StridedSliceOperator&>(src_op),
2237         tensorflow_graph);
2238   } else if (src_op.type == OperatorType::kMean) {
2239     ConvertReduceOperator(model, static_cast<const MeanOperator&>(src_op),
2240                           tensorflow_graph, "Mean");
2241   } else if (src_op.type == OperatorType::kSum) {
2242     ConvertReduceOperator(model,
2243                           static_cast<const TensorFlowSumOperator&>(src_op),
2244                           tensorflow_graph, "Sum");
2245   } else if (src_op.type == OperatorType::kReduceProd) {
2246     ConvertReduceOperator(model,
2247                           static_cast<const TensorFlowProdOperator&>(src_op),
2248                           tensorflow_graph, "Prod");
2249   } else if (src_op.type == OperatorType::kReduceMin) {
2250     ConvertReduceOperator(model,
2251                           static_cast<const TensorFlowMinOperator&>(src_op),
2252                           tensorflow_graph, "Min");
2253   } else if (src_op.type == OperatorType::kReduceMax) {
2254     ConvertReduceOperator(model,
2255                           static_cast<const TensorFlowMaxOperator&>(src_op),
2256                           tensorflow_graph, "Max");
2257   } else if (src_op.type == OperatorType::kSub) {
2258     ConvertSubOperator(model, static_cast<const SubOperator&>(src_op),
2259                        tensorflow_graph);
2260   } else if (src_op.type == OperatorType::kMinimum) {
2261     ConvertTensorFlowMinimumOperator(
2262         model, static_cast<const TensorFlowMinimumOperator&>(src_op),
2263         tensorflow_graph);
2264   } else if (src_op.type == OperatorType::kMaximum) {
2265     ConvertTensorFlowMaximumOperator(
2266         model, static_cast<const TensorFlowMaximumOperator&>(src_op),
2267         tensorflow_graph);
2268   } else if (src_op.type == OperatorType::kSqueeze) {
2269     ConvertSqueezeOperator(model, static_cast<const SqueezeOperator&>(src_op),
2270                            tensorflow_graph);
2271   } else if (src_op.type == OperatorType::kSlice) {
2272     ConvertSliceOperator(model, static_cast<const SliceOperator&>(src_op),
2273                          tensorflow_graph);
2274   } else if (src_op.type == OperatorType::kArgMax) {
2275     ConvertArgMaxOperator(model, static_cast<const ArgMaxOperator&>(src_op),
2276                           tensorflow_graph);
2277   } else if (src_op.type == OperatorType::kArgMin) {
2278     ConvertArgMinOperator(model, static_cast<const ArgMinOperator&>(src_op),
2279                           tensorflow_graph);
2280   } else if (src_op.type == OperatorType::kTopK_V2) {
2281     ConvertTopKV2Operator(model, static_cast<const TopKV2Operator&>(src_op),
2282                           tensorflow_graph);
2283   } else if (src_op.type == OperatorType::kTranspose) {
2284     ConvertTransposeOperator(
2285         model, static_cast<const TransposeOperator&>(src_op), tensorflow_graph);
2286   } else if (src_op.type == OperatorType::kShape) {
2287     ConvertTensorFlowShapeOperator(
2288         model, static_cast<const TensorFlowShapeOperator&>(src_op),
2289         tensorflow_graph);
2290   } else if (src_op.type == OperatorType::kRank) {
2291     ConvertRankOperator(model,
2292                         static_cast<const TensorFlowRankOperator&>(src_op),
2293                         tensorflow_graph);
2294   } else if (src_op.type == OperatorType::kRange) {
2295     ConvertRangeOperator(model, static_cast<const RangeOperator&>(src_op),
2296                          tensorflow_graph);
2297   } else if (src_op.type == OperatorType::kPack) {
2298     ConvertPackOperator(model, static_cast<const PackOperator&>(src_op),
2299                         tensorflow_graph);
2300   } else if (src_op.type == OperatorType::kFill) {
2301     ConvertFillOperator(model, static_cast<const FillOperator&>(src_op),
2302                         tensorflow_graph);
2303   } else if (src_op.type == OperatorType::kFloorDiv) {
2304     ConvertFloorDivOperator(model, static_cast<const FloorDivOperator&>(src_op),
2305                             tensorflow_graph);
2306   } else if (src_op.type == OperatorType::kFloorMod) {
2307     ConvertFloorModOperator(model, static_cast<const FloorModOperator&>(src_op),
2308                             tensorflow_graph);
2309   } else if (src_op.type == OperatorType::kExpandDims) {
2310     ConvertExpandDimsOperator(model,
2311                               static_cast<const ExpandDimsOperator&>(src_op),
2312                               tensorflow_graph);
2313   } else if (src_op.type == OperatorType::kTransposeConv) {
2314     ConvertTransposeConvOperator(
2315         model, static_cast<const TransposeConvOperator&>(src_op),
2316         tensorflow_graph);
2317   } else if (src_op.type == OperatorType::kRandomUniform) {
2318     ConvertRandomUniformOperator(
2319         model, static_cast<const RandomUniformOperator&>(src_op),
2320         tensorflow_graph);
2321   } else if (src_op.type == OperatorType::kEqual) {
2322     ConvertComparisonOperator(model, src_op, "Equal", tensorflow_graph);
2323   } else if (src_op.type == OperatorType::kNotEqual) {
2324     ConvertComparisonOperator(model, src_op, "NotEqual", tensorflow_graph);
2325   } else if (src_op.type == OperatorType::kGreater) {
2326     ConvertComparisonOperator(model, src_op, "Greater", tensorflow_graph);
2327   } else if (src_op.type == OperatorType::kGreaterEqual) {
2328     ConvertComparisonOperator(model, src_op, "GreaterEqual", tensorflow_graph);
2329   } else if (src_op.type == OperatorType::kLess) {
2330     ConvertComparisonOperator(model, src_op, "Less", tensorflow_graph);
2331   } else if (src_op.type == OperatorType::kLessEqual) {
2332     ConvertComparisonOperator(model, src_op, "LessEqual", tensorflow_graph);
2333   } else if (src_op.type == OperatorType::kSelect) {
2334     ConvertSelectOperator(model, static_cast<const SelectOperator&>(src_op),
2335                           tensorflow_graph);
2336   } else if (src_op.type == OperatorType::kTile) {
2337     ConvertTileOperator(model,
2338                         static_cast<const TensorFlowTileOperator&>(src_op),
2339                         tensorflow_graph);
2340   } else if (src_op.type == OperatorType::kPow) {
2341     ConvertPowOperator(model, static_cast<const PowOperator&>(src_op), "Pow",
2342                        tensorflow_graph);
2343   } else if (src_op.type == OperatorType::kAny) {
2344     ConvertReduceOperator(model,
2345                           static_cast<const TensorFlowAnyOperator&>(src_op),
2346                           tensorflow_graph, "Any");
2347   } else if (src_op.type == OperatorType::kLogicalAnd) {
2348     ConvertLogicalAndOperator(model,
2349                               static_cast<const LogicalAndOperator&>(src_op),
2350                               tensorflow_graph);
2351   } else if (src_op.type == OperatorType::kLogicalNot) {
2352     ConvertLogicalNotOperator(model,
2353                               static_cast<const LogicalNotOperator&>(src_op),
2354                               tensorflow_graph);
2355   } else if (src_op.type == OperatorType::kOneHot) {
2356     ConvertOneHotOperator(model, static_cast<const OneHotOperator&>(src_op),
2357                           tensorflow_graph);
2358   } else if (src_op.type == OperatorType::kLogicalOr) {
2359     ConvertLogicalOrOperator(model,
2360                              static_cast<const LogicalOrOperator&>(src_op),
2361                              "LogicalOr", tensorflow_graph);
2362   } else if (src_op.type == OperatorType::kCTCBeamSearchDecoder) {
2363     ConvertCTCBeamSearchDecoderOperator(
2364         model, static_cast<const CTCBeamSearchDecoderOperator&>(src_op),
2365         "CTCBeamSearchDecoder", tensorflow_graph);
2366   } else if (src_op.type == OperatorType::kUnpack) {
2367     ConvertUnpackOperator(model, static_cast<const UnpackOperator&>(src_op),
2368                           "Unpack", tensorflow_graph);
2369   } else if (src_op.type == OperatorType::kZerosLike) {
2370     ConvertZerosLikeOperator(
2371         model, static_cast<const TensorFlowZerosLikeOperator&>(src_op),
2372         "ZerosLike", tensorflow_graph);
2373   } else if (src_op.type == OperatorType::kReverseV2) {
2374     ConvertReverseV2Operator(model,
2375                              static_cast<const ReverseV2Operator&>(src_op),
2376                              "Reverse_V2", tensorflow_graph);
2377   } else if (src_op.type == OperatorType::kReverseSequence) {
2378     ConvertReverseSequenceOperator(
2379         model, static_cast<const ReverseSequenceOperator&>(src_op),
2380         tensorflow_graph);
2381   } else {
2382     LOG(FATAL) << "Unhandled operator type " << OperatorTypeName(src_op.type);
2383   }
2384 }
2385 
AddPlaceholder(const string & name,ArrayDataType type,GraphDef * tensorflow_graph)2386 void AddPlaceholder(const string& name, ArrayDataType type,
2387                     GraphDef* tensorflow_graph) {
2388   tensorflow::NodeDef* placeholder = tensorflow_graph->add_node();
2389   placeholder->set_op("Placeholder");
2390   switch (type) {
2391     case ArrayDataType::kBool:
2392       (*placeholder->mutable_attr())["dtype"].set_type(DT_BOOL);
2393       break;
2394     case ArrayDataType::kFloat:
2395       (*placeholder->mutable_attr())["dtype"].set_type(DT_FLOAT);
2396       break;
2397     case ArrayDataType::kUint8:
2398       (*placeholder->mutable_attr())["dtype"].set_type(DT_UINT8);
2399       break;
2400     case ArrayDataType::kInt32:
2401       (*placeholder->mutable_attr())["dtype"].set_type(DT_INT32);
2402       break;
2403     case ArrayDataType::kInt64:
2404       (*placeholder->mutable_attr())["dtype"].set_type(DT_INT64);
2405       break;
2406     case ArrayDataType::kInt16:
2407       (*placeholder->mutable_attr())["dtype"].set_type(DT_INT16);
2408       break;
2409     case ArrayDataType::kComplex64:
2410       (*placeholder->mutable_attr())["dtype"].set_type(DT_COMPLEX64);
2411       break;
2412     default:
2413       LOG(FATAL) << "Unexpected data type in array \"" << name << "\"";
2414   }
2415   placeholder->set_name(name);
2416 }
2417 
AddPlaceholderForRNNState(const Model & model,const string & name,int size,GraphDef * tensorflow_graph)2418 void AddPlaceholderForRNNState(const Model& model, const string& name, int size,
2419                                GraphDef* tensorflow_graph) {
2420   tensorflow::NodeDef* placeholder = tensorflow_graph->add_node();
2421   placeholder->set_op("Placeholder");
2422   placeholder->set_name(name);
2423   (*placeholder->mutable_attr())["dtype"].set_type(DT_FLOAT);
2424 
2425   auto* shape = (*placeholder->mutable_attr())["shape"].mutable_shape();
2426   const auto& state_array = model.GetArray(name);
2427   if (state_array.has_shape()) {
2428     const auto& state_shape = state_array.shape();
2429     const int kDims = state_shape.dimensions_count();
2430     for (int i = 0; i < kDims; ++i) {
2431       shape->add_dim()->set_size(state_shape.dims(i));
2432     }
2433   } else {
2434     shape->add_dim()->set_size(1);
2435     shape->add_dim()->set_size(size);
2436   }
2437 }
2438 
ExportTensorFlowGraphDefImplementation(const Model & model,GraphDef * tensorflow_graph)2439 void ExportTensorFlowGraphDefImplementation(const Model& model,
2440                                             GraphDef* tensorflow_graph) {
2441   for (const auto& input_array : model.flags.input_arrays()) {
2442     AddPlaceholder(input_array.name(),
2443                    model.GetArray(input_array.name()).data_type,
2444                    tensorflow_graph);
2445   }
2446   for (const auto& rnn_state : model.flags.rnn_states()) {
2447     AddPlaceholderForRNNState(model, rnn_state.state_array(), rnn_state.size(),
2448                               tensorflow_graph);
2449   }
2450   for (const auto& op : model.operators) {
2451     ConvertOperator(model, *op, tensorflow_graph);
2452   }
2453   // Generically export arrays that haven't been exported already
2454   // by the above operators export. It's important that this comes
2455   // after, as some operators need to export arrays that they reference
2456   // in a specific way, rather than in the generic way done below.
2457   for (const auto& array_pair : model.GetArrayMap()) {
2458     const string& array_name = array_pair.first;
2459     const auto& array = *array_pair.second;
2460     if (array.buffer) {
2461       switch (array.data_type) {
2462         case ArrayDataType::kBool:
2463           ConvertBoolTensorConst(model, array_name, tensorflow_graph);
2464           break;
2465         case ArrayDataType::kFloat:
2466           ConvertFloatTensorConst(model, array_name, tensorflow_graph);
2467           break;
2468         case ArrayDataType::kInt32:
2469           ConvertIntTensorConst(model, array_name, tensorflow_graph);
2470           break;
2471         case ArrayDataType::kComplex64:
2472           ConvertComplex64TensorConst(model, array_name, tensorflow_graph);
2473           break;
2474         default:
2475           break;
2476       }
2477     }
2478   }
2479 }
2480 }  // namespace
2481 
EncodeConstantArraysMinMaxByWrappingThemInFakeQuantNodes(Model * model)2482 void EncodeConstantArraysMinMaxByWrappingThemInFakeQuantNodes(Model* model) {
2483   for (const auto& array_kv : model->GetArrayMap()) {
2484     const string& array_name = array_kv.first;
2485     Array& array = *array_kv.second;
2486     if (!array.buffer || !array.minmax) {
2487       continue;
2488     }
2489     const string& wrapped_array_name =
2490         AvailableArrayName(*model, array_name + "/data");
2491     Array& wrapped_array = model->GetOrCreateArray(wrapped_array_name);
2492     wrapped_array.data_type = array.data_type;
2493     wrapped_array.copy_shape(array.shape());
2494     wrapped_array.buffer = std::move(array.buffer);
2495     FakeQuantOperator* fakequant_op = new FakeQuantOperator;
2496     fakequant_op->inputs = {wrapped_array_name};
2497     fakequant_op->outputs = {array_name};
2498     fakequant_op->minmax.reset(new MinMax);
2499     *fakequant_op->minmax = *array.minmax;
2500     const auto& it = FindOpWithInput(*model, array_name);
2501     model->operators.emplace(it, fakequant_op);
2502   }
2503   CheckInvariants(*model);
2504 }
2505 
// Converts `model` to a TensorFlow GraphDef and writes its binary protobuf
// serialization into `*output_file_contents`.
//
// `*output_file_contents` must be empty on entry (CHECK-enforced). The built
// graph is passed to LogDumpGraphDef under the "AT EXPORT" label before it is
// serialized. A serialization failure is a fatal CHECK failure.
void ExportTensorFlowGraphDef(const Model& model,
                              string* output_file_contents) {
  CHECK(output_file_contents->empty());
  GraphDef tensorflow_graph;
  ExportTensorFlowGraphDefImplementation(model, &tensorflow_graph);
  LogDumpGraphDef(kLogLevelModelChanged, "AT EXPORT", tensorflow_graph);
  // Serialization happens inside CHECK; TF's CHECK (unlike DCHECK) always
  // evaluates its argument, so this runs in every build mode.
  CHECK(tensorflow_graph.SerializeToString(output_file_contents));
}
2514 }  // namespace toco
2515