1 /* Copyright 2017 The TensorFlow Authors. All Rights Reserved.
2
3 Licensed under the Apache License, Version 2.0 (the "License");
4 you may not use this file except in compliance with the License.
5 You may obtain a copy of the License at
6
7 http://www.apache.org/licenses/LICENSE-2.0
8
9 Unless required by applicable law or agreed to in writing, software
10 distributed under the License is distributed on an "AS IS" BASIS,
11 WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
12 See the License for the specific language governing permissions and
13 limitations under the License.
14 ==============================================================================*/
15 #include <algorithm>
16 #include <memory>
17 #include <string>
18 #include <unordered_map>
19 #include <vector>
20
21 #include "google/protobuf/map.h"
22 #include "google/protobuf/text_format.h"
23 #include "absl/memory/memory.h"
24 #include "absl/strings/string_view.h"
25 #include "tensorflow/contrib/lite/toco/model.h"
26 #include "tensorflow/contrib/lite/toco/model_flags.pb.h"
27 #include "tensorflow/contrib/lite/toco/runtime/types.h"
28 #include "tensorflow/contrib/lite/toco/tensorflow_util.h"
29 #include "tensorflow/contrib/lite/toco/tooling_util.h"
30 #include "tensorflow/core/framework/attr_value.pb.h"
31 #include "tensorflow/core/framework/graph.pb.h"
32 #include "tensorflow/core/framework/node_def.pb.h"
33 #include "tensorflow/core/framework/tensor.pb.h"
34 #include "tensorflow/core/framework/tensor_shape.pb.h"
35 #include "tensorflow/core/framework/types.pb.h"
36 #include "tensorflow/core/platform/logging.h"
37
38 using tensorflow::DT_BOOL;
39 using tensorflow::DT_FLOAT;
40 using tensorflow::DT_INT32;
41 using tensorflow::DT_INT64;
42 using tensorflow::DT_UINT8;
43 using tensorflow::GraphDef;
44 using tensorflow::TensorProto;
45
46 namespace toco {
47 namespace {
48
GetTensorFlowDataType(ArrayDataType data_type)49 tensorflow::DataType GetTensorFlowDataType(ArrayDataType data_type) {
50 switch (data_type) {
51 case ArrayDataType::kBool:
52 return tensorflow::DT_BOOL;
53 case ArrayDataType::kFloat:
54 return tensorflow::DT_FLOAT;
55 case ArrayDataType::kUint8:
56 return tensorflow::DT_UINT8;
57 case ArrayDataType::kInt32:
58 return tensorflow::DT_INT32;
59 case ArrayDataType::kInt64:
60 return tensorflow::DT_INT64;
61 case ArrayDataType::kString:
62 return tensorflow::DT_STRING;
63 default:
64 case ArrayDataType::kNone:
65 LOG(FATAL) << "Unsupported data type: " << static_cast<int>(data_type);
66 return tensorflow::DT_INVALID;
67 }
68 }
69
GetTensorFlowDataType(const Model & model,const string & array_name)70 tensorflow::DataType GetTensorFlowDataType(const Model& model,
71 const string& array_name) {
72 return GetTensorFlowDataType(model.GetArray(array_name).data_type);
73 }
74
// Policy controlling whether exported constants may be "legacy scalars".
//
// A "legacy scalar", in TensorFlow terminology, is a 1-D shape whose single
// dimension has size 1. TensorFlow sometimes rejects these (see
// OpKernel::IsLegacyScalar and OpKernel::allow_legacy_scalars), so by default
// the exporter detects a 1-D shape of size 1 and emits a 0-D shape instead.
//
// Bias vectors consumed by BiasAdd nodes are the exception: TensorFlow
// requires them to be 1-D, so a bias of depth 1 has to stay 1-D even though
// that makes it a legacy scalar.
enum class LegacyScalarPolicy {
  // Replace a 1-D shape of size 1 by a 0-D shape.
  kAvoidLegacyScalars,
  // Always keep the shape as given, even if it is a legacy scalar.
  kDoCreateLegacyScalars
};
88
ExportFloatArray(const Shape & input_shape,const float * input_data,TensorProto * output_tensor,LegacyScalarPolicy legacy_scalar_policy)89 void ExportFloatArray(const Shape& input_shape, const float* input_data,
90 TensorProto* output_tensor,
91 LegacyScalarPolicy legacy_scalar_policy) {
92 output_tensor->set_dtype(DT_FLOAT);
93 const int input_flat_size = RequiredBufferSizeForShape(input_shape);
94 auto* shape = output_tensor->mutable_tensor_shape();
95
96 const int kDims = input_shape.dimensions_count();
97 if (legacy_scalar_policy == LegacyScalarPolicy::kDoCreateLegacyScalars ||
98 kDims > 1 || (kDims == 1 && input_shape.dims(0) > 1)) {
99 for (int i = 0; i < kDims; ++i) {
100 shape->add_dim()->set_size(input_shape.dims(i));
101 }
102 }
103 output_tensor->set_tensor_content(
104 string(reinterpret_cast<const char*>(input_data),
105 sizeof(*input_data) * input_flat_size));
106 }
107
ExportFloatArray(AxesOrder input_axes_order,const Shape & input_shape,const float * input_data,AxesOrder output_axes_order,TensorProto * output_tensor,LegacyScalarPolicy legacy_scalar_policy)108 void ExportFloatArray(AxesOrder input_axes_order, const Shape& input_shape,
109 const float* input_data, AxesOrder output_axes_order,
110 TensorProto* output_tensor,
111 LegacyScalarPolicy legacy_scalar_policy) {
112 CHECK_EQ(AxesCount(output_axes_order), AxesCount(input_axes_order));
113 output_tensor->set_dtype(DT_FLOAT);
114 CHECK_EQ(input_shape.dimensions_count(), AxesCount(input_axes_order));
115 const int input_flat_size = RequiredBufferSizeForShape(input_shape);
116
117 Shape shuffled_shape;
118 ShuffleDims(input_shape, input_axes_order, output_axes_order,
119 &shuffled_shape);
120 std::vector<float> shuffled_data(input_flat_size);
121 ShuffleArray(input_shape, input_axes_order, output_axes_order, shuffled_shape,
122 input_data, shuffled_data.data());
123
124 ExportFloatArray(shuffled_shape, shuffled_data.data(), output_tensor,
125 legacy_scalar_policy);
126 }
127
HasAlreadyExportedConst(const string & name,const GraphDef & tensorflow_graph)128 bool HasAlreadyExportedConst(const string& name,
129 const GraphDef& tensorflow_graph) {
130 for (const auto& node : tensorflow_graph.node()) {
131 if (node.op() == "Const" && node.name() == name) {
132 return true;
133 }
134 }
135 return false;
136 }
137
ConvertFloatTensorConst(const string & name,const Shape & input_shape,const float * input_data,AxesOrder input_axes_order,AxesOrder output_axes_order,GraphDef * tensorflow_graph,LegacyScalarPolicy legacy_scalar_policy)138 void ConvertFloatTensorConst(const string& name, const Shape& input_shape,
139 const float* input_data,
140 AxesOrder input_axes_order,
141 AxesOrder output_axes_order,
142 GraphDef* tensorflow_graph,
143 LegacyScalarPolicy legacy_scalar_policy) {
144 if (HasAlreadyExportedConst(name, *tensorflow_graph)) {
145 return;
146 }
147 auto* const_op = tensorflow_graph->add_node();
148 const_op->set_op("Const");
149 const_op->set_name(name);
150 (*const_op->mutable_attr())["dtype"].set_type(DT_FLOAT);
151 auto* tensor = (*const_op->mutable_attr())["value"].mutable_tensor();
152 ExportFloatArray(input_axes_order, input_shape, input_data, output_axes_order,
153 tensor, legacy_scalar_policy);
154 }
155
ConvertFloatTensorConst(const string & name,const Shape & input_shape,const float * input_data,AxesOrder input_axes_order,AxesOrder output_axes_order,GraphDef * tensorflow_graph)156 void ConvertFloatTensorConst(const string& name, const Shape& input_shape,
157 const float* input_data,
158 AxesOrder input_axes_order,
159 AxesOrder output_axes_order,
160 GraphDef* tensorflow_graph) {
161 if (HasAlreadyExportedConst(name, *tensorflow_graph)) {
162 return;
163 }
164 auto* const_op = tensorflow_graph->add_node();
165 const_op->set_op("Const");
166 const_op->set_name(name);
167 (*const_op->mutable_attr())["dtype"].set_type(DT_FLOAT);
168 auto* tensor = (*const_op->mutable_attr())["value"].mutable_tensor();
169 ExportFloatArray(input_axes_order, input_shape, input_data, output_axes_order,
170 tensor, LegacyScalarPolicy::kAvoidLegacyScalars);
171 }
172
ConvertFloatTensorConst(const Model & model,const string & name,AxesOrder input_axes_order,AxesOrder output_axes_order,GraphDef * tensorflow_graph)173 void ConvertFloatTensorConst(const Model& model, const string& name,
174 AxesOrder input_axes_order,
175 AxesOrder output_axes_order,
176 GraphDef* tensorflow_graph) {
177 if (HasAlreadyExportedConst(name, *tensorflow_graph)) {
178 return;
179 }
180 auto* const_op = tensorflow_graph->add_node();
181 const_op->set_op("Const");
182 const_op->set_name(name);
183 (*const_op->mutable_attr())["dtype"].set_type(DT_FLOAT);
184 auto* tensor = (*const_op->mutable_attr())["value"].mutable_tensor();
185 CHECK(model.HasArray(name));
186 const auto& input_array = model.GetArray(name);
187 const auto& input_shape = input_array.shape();
188 CHECK(input_array.buffer);
189 CHECK(input_array.buffer->type == ArrayDataType::kFloat);
190 const float* input_data =
191 input_array.GetBuffer<ArrayDataType::kFloat>().data.data();
192 ExportFloatArray(input_axes_order, input_shape, input_data, output_axes_order,
193 tensor, LegacyScalarPolicy::kAvoidLegacyScalars);
194 }
195
ConvertFloatTensorConst(const Model & model,const string & name,GraphDef * tensorflow_graph)196 void ConvertFloatTensorConst(const Model& model, const string& name,
197 GraphDef* tensorflow_graph) {
198 if (HasAlreadyExportedConst(name, *tensorflow_graph)) {
199 return;
200 }
201 auto* const_op = tensorflow_graph->add_node();
202 const_op->set_op("Const");
203 const_op->set_name(name);
204 (*const_op->mutable_attr())["dtype"].set_type(DT_FLOAT);
205 auto* tensor = (*const_op->mutable_attr())["value"].mutable_tensor();
206 CHECK(model.HasArray(name));
207 const auto& input_array = model.GetArray(name);
208 const auto& input_shape = input_array.shape();
209 CHECK(input_array.buffer);
210 CHECK(input_array.buffer->type == ArrayDataType::kFloat);
211 const float* input_data =
212 input_array.GetBuffer<ArrayDataType::kFloat>().data.data();
213 ExportFloatArray(input_shape, input_data, tensor,
214 LegacyScalarPolicy::kAvoidLegacyScalars);
215 }
216
ConvertIntTensorConst(const Model & model,const string & name,GraphDef * tensorflow_graph)217 void ConvertIntTensorConst(const Model& model, const string& name,
218 GraphDef* tensorflow_graph) {
219 if (HasAlreadyExportedConst(name, *tensorflow_graph)) {
220 return;
221 }
222 CHECK(model.HasArray(name));
223 const auto& array = model.GetArray(name);
224 auto* const_op = tensorflow_graph->add_node();
225 const_op->set_op("Const");
226 const_op->set_name(name);
227 (*const_op->mutable_attr())["dtype"].set_type(DT_INT32);
228 auto* tensor = (*const_op->mutable_attr())["value"].mutable_tensor();
229 tensor->set_dtype(DT_INT32);
230 const auto& data = array.GetBuffer<ArrayDataType::kInt32>().data;
231 for (auto index : data) {
232 tensor->add_int_val(index);
233 }
234 const auto& array_shape = array.shape();
235 auto* shape = tensor->mutable_tensor_shape();
236 for (int i = 0; i < array_shape.dimensions_count(); i++) {
237 shape->add_dim()->set_size(array_shape.dims(i));
238 }
239 }
240
CreateIntTensorConst(const string & name,const std::vector<int32> & data,GraphDef * tensorflow_graph)241 void CreateIntTensorConst(const string& name, const std::vector<int32>& data,
242 GraphDef* tensorflow_graph) {
243 if (HasAlreadyExportedConst(name, *tensorflow_graph)) {
244 return;
245 }
246 auto* const_op = tensorflow_graph->add_node();
247 const_op->set_op("Const");
248 const_op->set_name(name);
249 (*const_op->mutable_attr())["dtype"].set_type(DT_INT32);
250 auto* tensor = (*const_op->mutable_attr())["value"].mutable_tensor();
251 tensor->set_dtype(DT_INT32);
252 for (auto index : data) {
253 tensor->add_int_val(index);
254 }
255 auto* shape = tensor->mutable_tensor_shape();
256 shape->add_dim()->set_size(data.size());
257 }
258
CreateMatrixShapeTensorConst(const string & name,int rows,int cols,GraphDef * tensorflow_graph)259 void CreateMatrixShapeTensorConst(const string& name, int rows, int cols,
260 GraphDef* tensorflow_graph) {
261 if (HasAlreadyExportedConst(name, *tensorflow_graph)) {
262 return;
263 }
264 auto* const_op = tensorflow_graph->add_node();
265 const_op->set_op("Const");
266 const_op->set_name(name);
267 (*const_op->mutable_attr())["dtype"].set_type(DT_INT32);
268 auto* tensor = (*const_op->mutable_attr())["value"].mutable_tensor();
269 tensor->set_dtype(DT_INT32);
270 const int32 data[2] = {cols, rows};
271 tensor->set_tensor_content(
272 string(reinterpret_cast<const char*>(data), sizeof(data)));
273 auto* shape = tensor->mutable_tensor_shape();
274 shape->add_dim()->set_size(2);
275 }
276
CreateDummyConcatDimTensorConst(const string & name,int dim,GraphDef * tensorflow_graph)277 void CreateDummyConcatDimTensorConst(const string& name, int dim,
278 GraphDef* tensorflow_graph) {
279 if (HasAlreadyExportedConst(name, *tensorflow_graph)) {
280 return;
281 }
282 auto* const_op = tensorflow_graph->add_node();
283 const_op->set_op("Const");
284 const_op->set_name(name);
285 (*const_op->mutable_attr())["dtype"].set_type(DT_INT32);
286 auto* tensor = (*const_op->mutable_attr())["value"].mutable_tensor();
287 tensor->set_dtype(DT_INT32);
288 tensor->add_int_val(dim);
289 }
290
CreateReshapeShapeTensorConst(const string & name,const std::vector<int32> & shape,GraphDef * tensorflow_graph)291 void CreateReshapeShapeTensorConst(const string& name,
292 const std::vector<int32>& shape,
293 GraphDef* tensorflow_graph) {
294 if (HasAlreadyExportedConst(name, *tensorflow_graph)) {
295 return;
296 }
297 auto* const_op = tensorflow_graph->add_node();
298 const_op->set_op("Const");
299 const_op->set_name(name);
300 (*const_op->mutable_attr())["dtype"].set_type(DT_INT32);
301 auto* tensor = (*const_op->mutable_attr())["value"].mutable_tensor();
302 tensor->set_dtype(DT_INT32);
303 for (auto s : shape) {
304 tensor->add_int_val(s);
305 }
306 // TensorFlow sometimes forbids what it calls "legacy scalars",
307 // which are shapes of size 1 where the unique shape size is 1.
308 // See OpKernel::IsLegacyScalar and OpKernel::allow_legacy_scalars.
309 if (shape.size() > 1) {
310 auto* tensor_shape = tensor->mutable_tensor_shape();
311 tensor_shape->add_dim()->set_size(shape.size());
312 }
313 }
314
// Returns the name of the constant array behind `name`.
//
// If the array `name` owns a buffer, `name` itself is returned. Otherwise
// the array must be the output of a FakeQuant operator whose first input
// owns a buffer, and that input's name is returned. Any other situation
// fails a CHECK.
string WalkUpToConstantArray(const Model& model, const string& name) {
  if (model.GetArray(name).buffer) {
    return name;
  }

  // Not constant itself: step over the FakeQuant op that produces it.
  const auto* op = GetOpWithOutput(model, name);
  CHECK(op);
  CHECK(op->type == OperatorType::kFakeQuant);

  const string& fakequant_source = op->inputs[0];
  const Array& input_of_fakequant = model.GetArray(fakequant_source);
  CHECK(input_of_fakequant.buffer);
  return fakequant_source;
}
328
ConvertConvOperator(const Model & model,const ConvOperator & src_op,GraphDef * tensorflow_graph)329 void ConvertConvOperator(const Model& model, const ConvOperator& src_op,
330 GraphDef* tensorflow_graph) {
331 const bool has_bias = src_op.inputs.size() >= 3;
332 string conv_output = src_op.outputs[0];
333 if (has_bias) {
334 conv_output += "/conv";
335 }
336
337 auto* conv2d_op = tensorflow_graph->add_node();
338 conv2d_op->set_op("Conv2D");
339 conv2d_op->set_name(conv_output);
340 *conv2d_op->add_input() = src_op.inputs[0];
341 *conv2d_op->add_input() = src_op.inputs[1];
342 (*conv2d_op->mutable_attr())["T"].set_type(DT_FLOAT);
343 const string& weights_array_name =
344 WalkUpToConstantArray(model, src_op.inputs[1]);
345 const auto& weights_array = model.GetArray(weights_array_name);
346 CHECK(weights_array.buffer->type == ArrayDataType::kFloat);
347 ConvertFloatTensorConst(model, weights_array_name, AxesOrder::kOHWI,
348 AxesOrder::kHWIO, tensorflow_graph);
349 auto& strides = (*conv2d_op->mutable_attr())["strides"];
350 strides.mutable_list()->add_i(1);
351 strides.mutable_list()->add_i(src_op.stride_height);
352 strides.mutable_list()->add_i(src_op.stride_width);
353 strides.mutable_list()->add_i(1);
354 string padding;
355 if (src_op.padding.type == PaddingType::kSame) {
356 padding = "SAME";
357 } else if (src_op.padding.type == PaddingType::kValid) {
358 padding = "VALID";
359 } else {
360 LOG(FATAL) << "Bad padding (only SAME and VALID are supported)";
361 }
362 (*conv2d_op->mutable_attr())["padding"].set_s(padding);
363
364 if (has_bias) {
365 auto* biasadd_op = tensorflow_graph->add_node();
366 biasadd_op->set_op("BiasAdd");
367 biasadd_op->set_name(src_op.outputs[0]);
368 biasadd_op->add_input(conv_output);
369 biasadd_op->add_input(src_op.inputs[2]);
370 (*biasadd_op->mutable_attr())["T"].set_type(DT_FLOAT);
371 CHECK(model.HasArray(src_op.inputs[2]));
372 const string& bias_array_name =
373 WalkUpToConstantArray(model, src_op.inputs[2]);
374 const auto& bias_array = model.GetArray(bias_array_name);
375 // TODO(b/62904716) Bias arrays should be 1-D, and used directly.
376 Shape bias_shape_1d = bias_array.shape();
377 UnextendShape(&bias_shape_1d, 1);
378 CHECK(bias_array.buffer->type == ArrayDataType::kFloat);
379 const float* bias_data =
380 bias_array.GetBuffer<ArrayDataType::kFloat>().data.data();
381 ConvertFloatTensorConst(bias_array_name, bias_shape_1d, bias_data,
382 AxesOrder::kOneAxis, AxesOrder::kOneAxis,
383 tensorflow_graph,
384 LegacyScalarPolicy::kDoCreateLegacyScalars);
385 }
386 }
387
ConvertDepthwiseConvOperator(const Model & model,const DepthwiseConvOperator & src_op,GraphDef * tensorflow_graph)388 void ConvertDepthwiseConvOperator(const Model& model,
389 const DepthwiseConvOperator& src_op,
390 GraphDef* tensorflow_graph) {
391 const bool has_bias = src_op.inputs.size() >= 3;
392 string conv_output = src_op.outputs[0];
393 if (has_bias) {
394 conv_output += "/conv";
395 }
396
397 auto* dc2d_op = tensorflow_graph->add_node();
398 dc2d_op->set_op("DepthwiseConv2dNative");
399 dc2d_op->set_name(conv_output);
400 *dc2d_op->add_input() = src_op.inputs[0];
401 *dc2d_op->add_input() = src_op.inputs[1];
402 (*dc2d_op->mutable_attr())["T"].set_type(DT_FLOAT);
403
404 // Our internal DepthwiseConv weights are 1 x H x W x OutputDepth.
405 // We need to convert that to H x W x InputDepth x Multiplier.
406 // That's only a matter of constructing a Dims object; the actual
407 // array layout is the same.
408 CHECK(model.HasArray(src_op.inputs[1]));
409 const string& src_weights_name =
410 WalkUpToConstantArray(model, src_op.inputs[1]);
411 const auto& src_weights_array = model.GetArray(src_weights_name);
412 const auto& src_weights_shape = src_weights_array.shape();
413 CHECK_EQ(src_weights_shape.dimensions_count(), 4);
414 const Shape dst_weights_shape =
415 Shape({src_weights_shape.dims(1), src_weights_shape.dims(2),
416 src_weights_shape.dims(3) / src_op.depth_multiplier,
417 src_op.depth_multiplier});
418 CHECK_EQ(src_weights_shape.dims(3) % src_op.depth_multiplier, 0);
419 CHECK(dst_weights_shape.dims(2) * dst_weights_shape.dims(3) ==
420 src_weights_shape.dims(3));
421 CHECK_EQ(src_weights_shape.dims(0), 1);
422
423 CHECK(src_weights_array.buffer->type == ArrayDataType::kFloat);
424 const float* src_weights_data =
425 src_weights_array.GetBuffer<ArrayDataType::kFloat>().data.data();
426 ConvertFloatTensorConst(src_weights_name, dst_weights_shape, src_weights_data,
427 AxesOrder::kHWIM, AxesOrder::kHWIM, tensorflow_graph);
428
429 auto& strides = (*dc2d_op->mutable_attr())["strides"];
430 strides.mutable_list()->add_i(1);
431 strides.mutable_list()->add_i(src_op.stride_height);
432 strides.mutable_list()->add_i(src_op.stride_width);
433 strides.mutable_list()->add_i(1);
434 string padding;
435 if (src_op.padding.type == PaddingType::kSame) {
436 padding = "SAME";
437 } else if (src_op.padding.type == PaddingType::kValid) {
438 padding = "VALID";
439 } else {
440 LOG(FATAL) << "Bad padding (only SAME and VALID are supported)";
441 }
442 (*dc2d_op->mutable_attr())["padding"].set_s(padding);
443
444 if (has_bias) {
445 auto* biasadd_op = tensorflow_graph->add_node();
446 biasadd_op->set_op("BiasAdd");
447 biasadd_op->set_name(src_op.outputs[0]);
448 biasadd_op->add_input(conv_output);
449 biasadd_op->add_input(src_op.inputs[2]);
450 (*biasadd_op->mutable_attr())["T"].set_type(DT_FLOAT);
451 CHECK(model.HasArray(src_op.inputs[2]));
452 const string& bias_name = WalkUpToConstantArray(model, src_op.inputs[2]);
453 const auto& bias_array = model.GetArray(bias_name);
454 // TODO(b/62904716) Bias arrays should be 1-D, and used directly.
455 Shape bias_shape_1d = bias_array.shape();
456 UnextendShape(&bias_shape_1d, 1);
457 CHECK(bias_array.buffer->type == ArrayDataType::kFloat);
458 const float* bias_data =
459 bias_array.GetBuffer<ArrayDataType::kFloat>().data.data();
460 ConvertFloatTensorConst(bias_name, bias_shape_1d, bias_data,
461 AxesOrder::kOneAxis, AxesOrder::kOneAxis,
462 tensorflow_graph,
463 LegacyScalarPolicy::kDoCreateLegacyScalars);
464 }
465 }
466
ConvertDepthToSpaceOperator(const Model & model,const DepthToSpaceOperator & src_op,GraphDef * tensorflow_graph)467 void ConvertDepthToSpaceOperator(const Model& model,
468 const DepthToSpaceOperator& src_op,
469 GraphDef* tensorflow_graph) {
470 auto* op = tensorflow_graph->add_node();
471 op->set_op("DepthToSpace");
472 op->set_name(src_op.outputs[0]);
473 *op->add_input() = src_op.inputs[0];
474 (*op->mutable_attr())["T"].set_type(DT_FLOAT);
475 (*op->mutable_attr())["block_size"].set_i(src_op.block_size);
476 }
477
ConvertSpaceToDepthOperator(const Model & model,const SpaceToDepthOperator & src_op,GraphDef * tensorflow_graph)478 void ConvertSpaceToDepthOperator(const Model& model,
479 const SpaceToDepthOperator& src_op,
480 GraphDef* tensorflow_graph) {
481 auto* op = tensorflow_graph->add_node();
482 op->set_op("SpaceToDepth");
483 op->set_name(src_op.outputs[0]);
484 *op->add_input() = src_op.inputs[0];
485 (*op->mutable_attr())["T"].set_type(DT_FLOAT);
486 (*op->mutable_attr())["block_size"].set_i(src_op.block_size);
487 }
488
ConvertFullyConnectedOperator(const Model & model,const FullyConnectedOperator & src_op,GraphDef * tensorflow_graph)489 void ConvertFullyConnectedOperator(const Model& model,
490 const FullyConnectedOperator& src_op,
491 GraphDef* tensorflow_graph) {
492 // Reshape input activations to have the shape expected by the MatMul.
493 const string reshape_output =
494 AvailableArrayName(model, src_op.outputs[0] + "/reshape");
495 const string reshape_shape =
496 AvailableArrayName(model, reshape_output + "/shape");
497 const auto& fc_weights_array = model.GetArray(src_op.inputs[1]);
498 const auto& fc_weights_shape = fc_weights_array.shape();
499 CHECK_EQ(fc_weights_shape.dimensions_count(), 2);
500 CreateMatrixShapeTensorConst(reshape_shape, fc_weights_shape.dims(1), -1,
501 tensorflow_graph);
502 auto* reshape_op = tensorflow_graph->add_node();
503 reshape_op->set_op("Reshape");
504 reshape_op->set_name(reshape_output);
505 reshape_op->add_input(src_op.inputs[0]);
506 reshape_op->add_input(reshape_shape);
507 (*reshape_op->mutable_attr())["T"].set_type(
508 GetTensorFlowDataType(model, src_op.inputs[0]));
509
510 const bool has_bias = src_op.inputs.size() >= 3;
511 string matmul_output = src_op.outputs[0];
512 if (has_bias) {
513 matmul_output += "/matmul";
514 }
515
516 // Transpose the RHS input from column-major to row-major to match TensorFlow
517 // expectations. This is the inverse of the transpose we do during
518 // ResolveTensorFlowMatMul.
519 const string transpose_output =
520 AvailableArrayName(model, matmul_output + "/transpose_weights");
521 const string transpose_perm =
522 AvailableArrayName(model, transpose_output + "/perm");
523 CreateIntTensorConst(transpose_perm, {1, 0}, tensorflow_graph);
524 auto transpose_op = tensorflow_graph->add_node();
525 transpose_op->set_op("Transpose");
526 transpose_op->set_name(transpose_output);
527 *transpose_op->add_input() = src_op.inputs[1];
528 *transpose_op->add_input() = transpose_perm;
529 (*transpose_op->mutable_attr())["T"].set_type(
530 GetTensorFlowDataType(model, src_op.inputs[1]));
531 (*transpose_op->mutable_attr())["Tperm"].set_type(DT_INT32);
532
533 auto* matmul_op = tensorflow_graph->add_node();
534 matmul_op->set_op("MatMul");
535 matmul_op->set_name(matmul_output);
536 *matmul_op->add_input() = reshape_output;
537 *matmul_op->add_input() = transpose_op->name();
538 (*matmul_op->mutable_attr())["T"].set_type(
539 GetTensorFlowDataType(model, src_op.inputs[0]));
540 (*matmul_op->mutable_attr())["transpose_a"].set_b(false);
541 (*matmul_op->mutable_attr())["transpose_b"].set_b(false);
542 CHECK(model.HasArray(src_op.inputs[1]));
543
544 // Add the bias, if it exists.
545 if (has_bias) {
546 auto* biasadd_op = tensorflow_graph->add_node();
547 biasadd_op->set_op("BiasAdd");
548 biasadd_op->set_name(src_op.outputs[0]);
549 biasadd_op->add_input(matmul_output);
550 biasadd_op->add_input(src_op.inputs[2]);
551 (*biasadd_op->mutable_attr())["T"].set_type(
552 GetTensorFlowDataType(model, src_op.inputs[0]));
553 CHECK(model.HasArray(src_op.inputs[2]));
554 const auto& bias_array = model.GetArray(src_op.inputs[2]);
555 // TODO(b/62904716) Bias arrays should be 1-D, and used directly.
556 Shape bias_shape_1d = bias_array.shape();
557 UnextendShape(&bias_shape_1d, 1);
558 CHECK(bias_array.buffer);
559 CHECK(bias_array.buffer->type == ArrayDataType::kFloat);
560 const float* bias_data =
561 bias_array.GetBuffer<ArrayDataType::kFloat>().data.data();
562 ConvertFloatTensorConst(WalkUpToConstantArray(model, src_op.inputs[2]),
563 bias_shape_1d, bias_data, AxesOrder::kOneAxis,
564 AxesOrder::kOneAxis, tensorflow_graph,
565 LegacyScalarPolicy::kDoCreateLegacyScalars);
566 }
567 }
568
ConvertAddOperator(const Model & model,const AddOperator & src_op,GraphDef * tensorflow_graph)569 void ConvertAddOperator(const Model& model, const AddOperator& src_op,
570 GraphDef* tensorflow_graph) {
571 auto* add_op = tensorflow_graph->add_node();
572 add_op->set_op("Add");
573 add_op->set_name(src_op.outputs[0]);
574 CHECK_EQ(src_op.inputs.size(), 2);
575 *add_op->add_input() = src_op.inputs[0];
576 *add_op->add_input() = src_op.inputs[1];
577 (*add_op->mutable_attr())["T"].set_type(DT_FLOAT);
578 }
579
ConvertAddNOperator(const Model & model,const AddNOperator & src_op,GraphDef * tensorflow_graph)580 void ConvertAddNOperator(const Model& model, const AddNOperator& src_op,
581 GraphDef* tensorflow_graph) {
582 auto* add_op = tensorflow_graph->add_node();
583 add_op->set_op("AddN");
584 add_op->set_name(src_op.outputs[0]);
585 for (const auto& input : src_op.inputs) {
586 *add_op->add_input() = input;
587 }
588 (*add_op->mutable_attr())["N"].set_i(src_op.inputs.size());
589 (*add_op->mutable_attr())["T"].set_type(DT_FLOAT);
590 }
591
ConvertMulOperator(const Model & model,const MulOperator & src_op,GraphDef * tensorflow_graph)592 void ConvertMulOperator(const Model& model, const MulOperator& src_op,
593 GraphDef* tensorflow_graph) {
594 auto* add_op = tensorflow_graph->add_node();
595 add_op->set_op("Mul");
596 add_op->set_name(src_op.outputs[0]);
597 CHECK_EQ(src_op.inputs.size(), 2);
598 *add_op->add_input() = src_op.inputs[0];
599 *add_op->add_input() = src_op.inputs[1];
600 (*add_op->mutable_attr())["T"].set_type(DT_FLOAT);
601 }
602
ConvertReluOperator(const ReluOperator & src_op,GraphDef * tensorflow_graph)603 void ConvertReluOperator(const ReluOperator& src_op,
604 GraphDef* tensorflow_graph) {
605 auto* relu_op = tensorflow_graph->add_node();
606 relu_op->set_op("Relu");
607 relu_op->set_name(src_op.outputs[0]);
608 *relu_op->add_input() = src_op.inputs[0];
609 (*relu_op->mutable_attr())["T"].set_type(DT_FLOAT);
610 }
611
ConvertRelu1Operator(const Relu1Operator & src_op,GraphDef * tensorflow_graph)612 void ConvertRelu1Operator(const Relu1Operator& src_op,
613 GraphDef* tensorflow_graph) {
614 const string max_bounds = src_op.outputs[0] + "/max_bounds";
615 const string min_bounds = src_op.outputs[0] + "/min_bounds";
616 const string max_output = src_op.outputs[0] + "/max_output";
617
618 auto* max_bounds_const_op = tensorflow_graph->add_node();
619 max_bounds_const_op->set_op("Const");
620 max_bounds_const_op->set_name(max_bounds);
621 (*max_bounds_const_op->mutable_attr())["dtype"].set_type(DT_FLOAT);
622 auto* max_bounds_const_op_tensor =
623 (*max_bounds_const_op->mutable_attr())["value"].mutable_tensor();
624 max_bounds_const_op_tensor->set_dtype(DT_FLOAT);
625 max_bounds_const_op_tensor->add_float_val(-1.0f);
626
627 auto* min_bounds_const_op = tensorflow_graph->add_node();
628 min_bounds_const_op->set_op("Const");
629 min_bounds_const_op->set_name(min_bounds);
630 (*min_bounds_const_op->mutable_attr())["dtype"].set_type(DT_FLOAT);
631 auto* min_bounds_const_op_tensor =
632 (*min_bounds_const_op->mutable_attr())["value"].mutable_tensor();
633 min_bounds_const_op_tensor->set_dtype(DT_FLOAT);
634 min_bounds_const_op_tensor->add_float_val(1.0f);
635
636 auto* max_op = tensorflow_graph->add_node();
637 max_op->set_op("Maximum");
638 max_op->set_name(max_output);
639 *max_op->add_input() = src_op.inputs[0];
640 *max_op->add_input() = max_bounds;
641 (*max_op->mutable_attr())["T"].set_type(DT_FLOAT);
642
643 auto* min_op = tensorflow_graph->add_node();
644 min_op->set_op("Minimum");
645 min_op->set_name(src_op.outputs[0]);
646 *min_op->add_input() = max_output;
647 *min_op->add_input() = min_bounds;
648 (*min_op->mutable_attr())["T"].set_type(DT_FLOAT);
649 }
650
ConvertRelu6Operator(const Relu6Operator & src_op,GraphDef * tensorflow_graph)651 void ConvertRelu6Operator(const Relu6Operator& src_op,
652 GraphDef* tensorflow_graph) {
653 auto* relu_op = tensorflow_graph->add_node();
654 relu_op->set_op("Relu6");
655 relu_op->set_name(src_op.outputs[0]);
656 *relu_op->add_input() = src_op.inputs[0];
657 (*relu_op->mutable_attr())["T"].set_type(DT_FLOAT);
658 }
659
ConvertLogisticOperator(const LogisticOperator & src_op,GraphDef * tensorflow_graph)660 void ConvertLogisticOperator(const LogisticOperator& src_op,
661 GraphDef* tensorflow_graph) {
662 auto* relu_op = tensorflow_graph->add_node();
663 relu_op->set_op("Sigmoid");
664 relu_op->set_name(src_op.outputs[0]);
665 *relu_op->add_input() = src_op.inputs[0];
666 (*relu_op->mutable_attr())["T"].set_type(DT_FLOAT);
667 }
668
ConvertTanhOperator(const TanhOperator & src_op,GraphDef * tensorflow_graph)669 void ConvertTanhOperator(const TanhOperator& src_op,
670 GraphDef* tensorflow_graph) {
671 auto* tanh_op = tensorflow_graph->add_node();
672 tanh_op->set_op("Tanh");
673 tanh_op->set_name(src_op.outputs[0]);
674 *tanh_op->add_input() = src_op.inputs[0];
675 (*tanh_op->mutable_attr())["T"].set_type(DT_FLOAT);
676 }
677
ConvertSoftmaxOperator(const Model & model,const SoftmaxOperator & src_op,GraphDef * tensorflow_graph)678 void ConvertSoftmaxOperator(const Model& model, const SoftmaxOperator& src_op,
679 GraphDef* tensorflow_graph) {
680 string softmax_input;
681 Operator* providing_op = GetOpWithOutput(model, src_op.inputs[0]);
682 if (providing_op != nullptr &&
683 providing_op->type == OperatorType::kTensorFlowReshape) {
684 softmax_input = src_op.inputs[0];
685 } else {
686 // Insert a reshape operator that reduces the dimensions down to the 2 that
687 // are required for TensorFlow Logits.
688 const string reshape_output = src_op.outputs[0] + "/softmax_insert_reshape";
689 const string softmax_size = src_op.outputs[0] + "/softmax_insert_size";
690 softmax_input = reshape_output;
691
692 auto* reshape_op = tensorflow_graph->add_node();
693 reshape_op->set_op("Reshape");
694 reshape_op->set_name(reshape_output);
695 *reshape_op->add_input() = src_op.inputs[0];
696 *reshape_op->add_input() = softmax_size;
697 (*reshape_op->mutable_attr())["T"].set_type(DT_FLOAT);
698
699 const auto& input_shape = model.GetArray(src_op.inputs[0]).shape();
700 int32 flattened_size = 1;
701 for (int i = 0; i < input_shape.dimensions_count() - 1; ++i) {
702 flattened_size *= input_shape.dims(i);
703 }
704 const std::vector<int32> shape_data = {
705 flattened_size, input_shape.dims(input_shape.dimensions_count() - 1)};
706 CreateReshapeShapeTensorConst(softmax_size, shape_data, tensorflow_graph);
707 }
708
709 auto* softmax_op = tensorflow_graph->add_node();
710 softmax_op->set_op("Softmax");
711 softmax_op->set_name(src_op.outputs[0]);
712 *softmax_op->add_input() = softmax_input;
713 // TensorFlow's Softmax doesn't seem to admit a 'beta' parameter
714 CHECK_EQ(src_op.beta, 1.f);
715 (*softmax_op->mutable_attr())["T"].set_type(DT_FLOAT);
716 }
717
ConvertLogSoftmaxOperator(const Model & model,const LogSoftmaxOperator & src_op,GraphDef * tensorflow_graph)718 void ConvertLogSoftmaxOperator(const Model& model,
719 const LogSoftmaxOperator& src_op,
720 GraphDef* tensorflow_graph) {
721 string softmax_input;
722 Operator* providing_op = GetOpWithOutput(model, src_op.inputs[0]);
723 if (providing_op->type == OperatorType::kTensorFlowReshape) {
724 softmax_input = src_op.inputs[0];
725 } else {
726 // Insert a reshape operator that reduces the dimensions down to the 2 that
727 // are required for TensorFlow Logits.
728 const string reshape_output =
729 src_op.outputs[0] + "/log_softmax_insert_reshape";
730 const string softmax_size = src_op.outputs[0] + "/log_softmax_insert_size";
731 softmax_input = reshape_output;
732
733 auto* reshape_op = tensorflow_graph->add_node();
734 reshape_op->set_op("Reshape");
735 reshape_op->set_name(reshape_output);
736 *reshape_op->add_input() = src_op.inputs[0];
737 *reshape_op->add_input() = softmax_size;
738 (*reshape_op->mutable_attr())["T"].set_type(DT_FLOAT);
739
740 const auto& input_shape = model.GetArray(src_op.inputs[0]).shape();
741 int32 flattened_size = 1;
742 for (int i = 0; i < input_shape.dimensions_count() - 1; ++i) {
743 flattened_size *= input_shape.dims(i);
744 }
745 const std::vector<int32> shape_data = {
746 flattened_size, input_shape.dims(input_shape.dimensions_count() - 1)};
747 CreateReshapeShapeTensorConst(softmax_size, shape_data, tensorflow_graph);
748 }
749
750 auto* log_softmax_op = tensorflow_graph->add_node();
751 log_softmax_op->set_op("LogSoftmax");
752 log_softmax_op->set_name(src_op.outputs[0]);
753 *log_softmax_op->add_input() = softmax_input;
754 (*log_softmax_op->mutable_attr())["T"].set_type(DT_FLOAT);
755 }
756
ConvertL2NormalizationOperator(const L2NormalizationOperator & src_op,GraphDef * tensorflow_graph)757 void ConvertL2NormalizationOperator(const L2NormalizationOperator& src_op,
758 GraphDef* tensorflow_graph) {
759 const string square_output = src_op.outputs[0] + "/square";
760 const string sum_reduction_indices = src_op.outputs[0] + "/reduction_indices";
761 const string sum_output = src_op.outputs[0] + "/sum";
762 const string rsqrt_output = src_op.outputs[0] + "/rsqrt";
763 const string rsqrt_tiled_output = src_op.outputs[0] + "/rsqrt_tiled";
764
765 auto* sum_reduction_indices_op = tensorflow_graph->add_node();
766 sum_reduction_indices_op->set_op("Const");
767 sum_reduction_indices_op->set_name(sum_reduction_indices);
768 (*sum_reduction_indices_op->mutable_attr())["dtype"].set_type(DT_INT32);
769 auto* sum_reduction_indices_tensor =
770 (*sum_reduction_indices_op->mutable_attr())["value"].mutable_tensor();
771 sum_reduction_indices_tensor->set_dtype(DT_INT32);
772 auto* sum_reduction_indices_shape =
773 sum_reduction_indices_tensor->mutable_tensor_shape();
774 auto* sum_reduction_indices_dim = sum_reduction_indices_shape->add_dim();
775 sum_reduction_indices_dim->set_size(2);
776 sum_reduction_indices_tensor->add_int_val(0);
777 sum_reduction_indices_tensor->add_int_val(1);
778
779 auto* square_op = tensorflow_graph->add_node();
780 square_op->set_op("Square");
781 square_op->set_name(square_output);
782 *square_op->add_input() = src_op.inputs[0];
783 (*square_op->mutable_attr())["T"].set_type(DT_FLOAT);
784
785 auto* sum_op = tensorflow_graph->add_node();
786 sum_op->set_op("Sum");
787 sum_op->set_name(sum_output);
788 *sum_op->add_input() = square_output;
789 *sum_op->add_input() = sum_reduction_indices;
790 (*sum_op->mutable_attr())["T"].set_type(DT_FLOAT);
791
792 auto* rsqrt_op = tensorflow_graph->add_node();
793 rsqrt_op->set_op("Rsqrt");
794 rsqrt_op->set_name(rsqrt_output);
795 *rsqrt_op->add_input() = sum_output;
796 (*rsqrt_op->mutable_attr())["T"].set_type(DT_FLOAT);
797
798 auto* mul_op = tensorflow_graph->add_node();
799 mul_op->set_op("Mul");
800 mul_op->set_name(src_op.outputs[0]);
801 *mul_op->add_input() = src_op.inputs[0];
802 *mul_op->add_input() = rsqrt_output;
803 (*mul_op->mutable_attr())["T"].set_type(DT_FLOAT);
804 }
805
ConvertLocalResponseNormalizationOperator(const LocalResponseNormalizationOperator & src_op,GraphDef * tensorflow_graph)806 void ConvertLocalResponseNormalizationOperator(
807 const LocalResponseNormalizationOperator& src_op,
808 GraphDef* tensorflow_graph) {
809 auto* lrn_op = tensorflow_graph->add_node();
810 lrn_op->set_op("LRN");
811 lrn_op->set_name(src_op.outputs[0]);
812 *lrn_op->add_input() = src_op.inputs[0];
813 (*lrn_op->mutable_attr())["depth_radius"].set_i(src_op.range);
814 (*lrn_op->mutable_attr())["bias"].set_f(src_op.bias);
815 (*lrn_op->mutable_attr())["alpha"].set_f(src_op.alpha);
816 (*lrn_op->mutable_attr())["beta"].set_f(src_op.beta);
817 }
818
ConvertFakeQuantOperator(const FakeQuantOperator & src_op,GraphDef * tensorflow_graph)819 void ConvertFakeQuantOperator(const FakeQuantOperator& src_op,
820 GraphDef* tensorflow_graph) {
821 auto* fakequant_op = tensorflow_graph->add_node();
822 fakequant_op->set_op("FakeQuantWithMinMaxArgs");
823 fakequant_op->set_name(src_op.outputs[0]);
824 CHECK_EQ(src_op.inputs.size(), 1);
825 *fakequant_op->add_input() = src_op.inputs[0];
826 CHECK(src_op.minmax);
827 (*fakequant_op->mutable_attr())["min"].set_f(src_op.minmax->min);
828 (*fakequant_op->mutable_attr())["max"].set_f(src_op.minmax->max);
829 }
830
ConvertMaxPoolOperator(const MaxPoolOperator & src_op,GraphDef * tensorflow_graph)831 void ConvertMaxPoolOperator(const MaxPoolOperator& src_op,
832 GraphDef* tensorflow_graph) {
833 auto* maxpool_op = tensorflow_graph->add_node();
834 maxpool_op->set_op("MaxPool");
835 maxpool_op->set_name(src_op.outputs[0]);
836 *maxpool_op->add_input() = src_op.inputs[0];
837 auto& strides = (*maxpool_op->mutable_attr())["strides"];
838 strides.mutable_list()->add_i(1);
839 strides.mutable_list()->add_i(src_op.stride_height);
840 strides.mutable_list()->add_i(src_op.stride_width);
841 strides.mutable_list()->add_i(1);
842 string padding;
843 if (src_op.padding.type == PaddingType::kSame) {
844 padding = "SAME";
845 } else if (src_op.padding.type == PaddingType::kValid) {
846 padding = "VALID";
847 } else {
848 LOG(FATAL) << "Bad padding (only SAME and VALID are supported)";
849 }
850 (*maxpool_op->mutable_attr())["padding"].set_s(padding);
851 (*maxpool_op->mutable_attr())["T"].set_type(DT_FLOAT);
852 auto& ksize = (*maxpool_op->mutable_attr())["ksize"];
853 ksize.mutable_list()->add_i(1);
854 ksize.mutable_list()->add_i(src_op.kheight);
855 ksize.mutable_list()->add_i(src_op.kwidth);
856 ksize.mutable_list()->add_i(1);
857 }
858
ConvertAveragePoolOperator(const AveragePoolOperator & src_op,GraphDef * tensorflow_graph)859 void ConvertAveragePoolOperator(const AveragePoolOperator& src_op,
860 GraphDef* tensorflow_graph) {
861 auto* avgpool_op = tensorflow_graph->add_node();
862 avgpool_op->set_op("AvgPool");
863 avgpool_op->set_name(src_op.outputs[0]);
864 *avgpool_op->add_input() = src_op.inputs[0];
865 auto& strides = (*avgpool_op->mutable_attr())["strides"];
866 strides.mutable_list()->add_i(1);
867 strides.mutable_list()->add_i(src_op.stride_height);
868 strides.mutable_list()->add_i(src_op.stride_width);
869 strides.mutable_list()->add_i(1);
870 string padding;
871 if (src_op.padding.type == PaddingType::kSame) {
872 padding = "SAME";
873 } else if (src_op.padding.type == PaddingType::kValid) {
874 padding = "VALID";
875 } else {
876 LOG(FATAL) << "Bad padding (only SAME and VALID are supported)";
877 }
878 (*avgpool_op->mutable_attr())["padding"].set_s(padding);
879 (*avgpool_op->mutable_attr())["T"].set_type(DT_FLOAT);
880 auto& ksize = (*avgpool_op->mutable_attr())["ksize"];
881 ksize.mutable_list()->add_i(1);
882 ksize.mutable_list()->add_i(src_op.kheight);
883 ksize.mutable_list()->add_i(src_op.kwidth);
884 ksize.mutable_list()->add_i(1);
885 }
886
ConvertConcatenationOperator(const Model & model,const ConcatenationOperator & src_op,GraphDef * tensorflow_graph)887 void ConvertConcatenationOperator(const Model& model,
888 const ConcatenationOperator& src_op,
889 GraphDef* tensorflow_graph) {
890 auto* dc_op = tensorflow_graph->add_node();
891 dc_op->set_op("ConcatV2");
892 dc_op->set_name(src_op.outputs[0]);
893 const string dummy_axis = src_op.outputs[0] + "/axis";
894 CreateDummyConcatDimTensorConst(dummy_axis, src_op.axis, tensorflow_graph);
895 for (const auto& input : src_op.inputs) {
896 *dc_op->add_input() = input;
897 }
898 *dc_op->add_input() = dummy_axis;
899 (*dc_op->mutable_attr())["T"].set_type(
900 GetTensorFlowDataType(model, src_op.inputs[0]));
901 (*dc_op->mutable_attr())["Tidx"].set_type(DT_INT32);
902 (*dc_op->mutable_attr())["N"].set_i(src_op.inputs.size());
903 }
904
ConvertTensorFlowReshapeOperator(const Model & model,const TensorFlowReshapeOperator & src_op,GraphDef * tensorflow_graph)905 void ConvertTensorFlowReshapeOperator(const Model& model,
906 const TensorFlowReshapeOperator& src_op,
907 GraphDef* tensorflow_graph) {
908 auto* reshape_op = tensorflow_graph->add_node();
909 reshape_op->set_op("Reshape");
910 reshape_op->set_name(src_op.outputs[0]);
911 CHECK_EQ(src_op.inputs.size(), 2);
912 *reshape_op->add_input() = src_op.inputs[0];
913 *reshape_op->add_input() = src_op.inputs[1];
914 (*reshape_op->mutable_attr())["T"].set_type(
915 GetTensorFlowDataType(model, src_op.outputs[0]));
916 const auto& shape_array = model.GetArray(src_op.inputs[1]);
917 QCHECK(shape_array.data_type == ArrayDataType::kInt32)
918 << "Only int32 shape is supported.";
919 QCHECK(shape_array.buffer != nullptr)
920 << "Shape inferred at runtime is not supported.";
921 const auto& shape_data = shape_array.GetBuffer<ArrayDataType::kInt32>().data;
922 CreateReshapeShapeTensorConst(src_op.inputs[1], shape_data, tensorflow_graph);
923 }
924
ConvertL2PoolOperator(const L2PoolOperator & src_op,GraphDef * tensorflow_graph)925 void ConvertL2PoolOperator(const L2PoolOperator& src_op,
926 GraphDef* tensorflow_graph) {
927 const string square_output = src_op.outputs[0] + "/square";
928 const string avgpool_output = src_op.outputs[0] + "/avgpool";
929
930 auto* square_op = tensorflow_graph->add_node();
931 square_op->set_op("Square");
932 square_op->set_name(square_output);
933 *square_op->add_input() = src_op.inputs[0];
934 (*square_op->mutable_attr())["T"].set_type(DT_FLOAT);
935
936 string padding;
937 if (src_op.padding.type == PaddingType::kSame) {
938 padding = "SAME";
939 } else if (src_op.padding.type == PaddingType::kValid) {
940 padding = "VALID";
941 } else {
942 LOG(FATAL) << "Bad padding (only SAME and VALID are supported)";
943 }
944
945 auto* avgpool_op = tensorflow_graph->add_node();
946 avgpool_op->set_op("AvgPool");
947 avgpool_op->set_name(avgpool_output);
948 *avgpool_op->add_input() = square_output;
949 auto& strides = (*avgpool_op->mutable_attr())["strides"];
950 strides.mutable_list()->add_i(1);
951 strides.mutable_list()->add_i(src_op.stride_height);
952 strides.mutable_list()->add_i(src_op.stride_width);
953 strides.mutable_list()->add_i(1);
954
955 (*avgpool_op->mutable_attr())["padding"].set_s(padding);
956 (*avgpool_op->mutable_attr())["T"].set_type(DT_FLOAT);
957 auto& ksize = (*avgpool_op->mutable_attr())["ksize"];
958 ksize.mutable_list()->add_i(1);
959 ksize.mutable_list()->add_i(src_op.kheight);
960 ksize.mutable_list()->add_i(src_op.kwidth);
961 ksize.mutable_list()->add_i(1);
962
963 auto* sqrt_op = tensorflow_graph->add_node();
964 sqrt_op->set_op("Sqrt");
965 sqrt_op->set_name(src_op.outputs[0]);
966 *sqrt_op->add_input() = avgpool_output;
967 (*sqrt_op->mutable_attr())["T"].set_type(DT_FLOAT);
968 }
969
ConvertSquareOperator(const TensorFlowSquareOperator & src_op,GraphDef * tensorflow_graph)970 void ConvertSquareOperator(const TensorFlowSquareOperator& src_op,
971 GraphDef* tensorflow_graph) {
972 auto* square_op = tensorflow_graph->add_node();
973 square_op->set_op("Square");
974 square_op->set_name(src_op.outputs[0]);
975 CHECK_EQ(src_op.inputs.size(), 1);
976 *square_op->add_input() = src_op.inputs[0];
977 (*square_op->mutable_attr())["T"].set_type(DT_FLOAT);
978 }
979
ConvertSqrtOperator(const TensorFlowSqrtOperator & src_op,GraphDef * tensorflow_graph)980 void ConvertSqrtOperator(const TensorFlowSqrtOperator& src_op,
981 GraphDef* tensorflow_graph) {
982 auto* sqrt_op = tensorflow_graph->add_node();
983 sqrt_op->set_op("Sqrt");
984 sqrt_op->set_name(src_op.outputs[0]);
985 CHECK_EQ(src_op.inputs.size(), 1);
986 *sqrt_op->add_input() = src_op.inputs[0];
987 (*sqrt_op->mutable_attr())["T"].set_type(DT_FLOAT);
988 }
989
ConvertSplitOperator(const Model & model,const TensorFlowSplitOperator & src_op,GraphDef * tensorflow_graph)990 void ConvertSplitOperator(const Model& model,
991 const TensorFlowSplitOperator& src_op,
992 GraphDef* tensorflow_graph) {
993 auto* split_op = tensorflow_graph->add_node();
994 split_op->set_op("Split");
995 split_op->set_name(src_op.outputs[0]);
996 for (const auto& input : src_op.inputs) {
997 *split_op->add_input() = input;
998 }
999 (*split_op->mutable_attr())["T"].set_type(DT_FLOAT);
1000 (*split_op->mutable_attr())["num_split"].set_i(src_op.num_split);
1001 const auto& split_dim_array = model.GetArray(src_op.inputs[0]);
1002 CHECK(split_dim_array.buffer);
1003 CHECK(split_dim_array.data_type == ArrayDataType::kInt32);
1004 const auto& split_dim_data =
1005 split_dim_array.GetBuffer<ArrayDataType::kInt32>().data;
1006 CHECK_EQ(split_dim_data.size(), 1);
1007 const int split_dim = split_dim_data[0];
1008 CreateDummyConcatDimTensorConst(src_op.inputs[0], split_dim,
1009 tensorflow_graph);
1010 }
1011
ConvertCastOperator(const Model & model,const CastOperator & src_op,GraphDef * tensorflow_graph)1012 void ConvertCastOperator(const Model& model, const CastOperator& src_op,
1013 GraphDef* tensorflow_graph) {
1014 auto* cast_op = tensorflow_graph->add_node();
1015 cast_op->set_op("Cast");
1016 cast_op->set_name(src_op.outputs[0]);
1017 CHECK_EQ(src_op.inputs.size(), 1);
1018 *cast_op->add_input() = src_op.inputs[0];
1019
1020 (*cast_op->mutable_attr())["DstT"].set_type(
1021 GetTensorFlowDataType(model, src_op.outputs[0]));
1022 (*cast_op->mutable_attr())["SrcT"].set_type(
1023 GetTensorFlowDataType(model, src_op.inputs[0]));
1024 }
1025
ConvertFloorOperator(const Model & model,const FloorOperator & src_op,GraphDef * tensorflow_graph)1026 void ConvertFloorOperator(const Model& model, const FloorOperator& src_op,
1027 GraphDef* tensorflow_graph) {
1028 auto* floor_op = tensorflow_graph->add_node();
1029 floor_op->set_op("Floor");
1030 floor_op->set_name(src_op.outputs[0]);
1031 CHECK_EQ(src_op.inputs.size(), 1);
1032 *floor_op->add_input() = src_op.inputs[0];
1033 (*floor_op->mutable_attr())["T"].set_type(DT_FLOAT);
1034 }
1035
ConvertGatherOperator(const Model & model,const GatherOperator & src_op,GraphDef * tensorflow_graph)1036 void ConvertGatherOperator(const Model& model, const GatherOperator& src_op,
1037 GraphDef* tensorflow_graph) {
1038 auto* gather_op = tensorflow_graph->add_node();
1039 gather_op->set_op("Gather");
1040 gather_op->set_name(src_op.outputs[0]);
1041 CHECK_EQ(src_op.inputs.size(), 2);
1042 *gather_op->add_input() = src_op.inputs[0];
1043 *gather_op->add_input() = src_op.inputs[1];
1044
1045 (*gather_op->mutable_attr())["Tindices"].set_type(DT_INT32);
1046 const auto params_type = GetTensorFlowDataType(model, src_op.inputs[0]);
1047 (*gather_op->mutable_attr())["Tparams"].set_type(params_type);
1048 }
1049
ConvertArgMaxOperator(const Model & model,const ArgMaxOperator & src_op,GraphDef * tensorflow_graph)1050 void ConvertArgMaxOperator(const Model& model, const ArgMaxOperator& src_op,
1051 GraphDef* tensorflow_graph) {
1052 auto* argmax_op = tensorflow_graph->add_node();
1053 argmax_op->set_op("ArgMax");
1054 argmax_op->set_name(src_op.outputs[0]);
1055 CHECK_EQ(src_op.inputs.size(), 2);
1056 *argmax_op->add_input() = src_op.inputs[0];
1057 *argmax_op->add_input() = src_op.inputs[1];
1058 (*argmax_op->mutable_attr())["T"].set_type(
1059 GetTensorFlowDataType(model, src_op.inputs[0]));
1060 (*argmax_op->mutable_attr())["Tidx"].set_type(
1061 GetTensorFlowDataType(model, src_op.inputs[1]));
1062 (*argmax_op->mutable_attr())["output_type"].set_type(
1063 GetTensorFlowDataType(model, src_op.outputs[0]));
1064 }
1065
ConvertTransposeOperator(const Model & model,const TransposeOperator & src_op,GraphDef * tensorflow_graph)1066 void ConvertTransposeOperator(const Model& model,
1067 const TransposeOperator& src_op,
1068 GraphDef* tensorflow_graph) {
1069 auto* transpose_op = tensorflow_graph->add_node();
1070 transpose_op->set_op("Transpose");
1071 transpose_op->set_name(src_op.outputs[0]);
1072 CHECK_EQ(src_op.inputs.size(), 2);
1073 *transpose_op->add_input() = src_op.inputs[0];
1074 *transpose_op->add_input() = src_op.inputs[1];
1075 (*transpose_op->mutable_attr())["T"].set_type(
1076 GetTensorFlowDataType(model, src_op.inputs[0]));
1077 (*transpose_op->mutable_attr())["Tperm"].set_type(
1078 GetTensorFlowDataType(model, src_op.inputs[1]));
1079 }
1080
ConvertTensorFlowShapeOperator(const Model & model,const TensorFlowShapeOperator & src_op,GraphDef * tensorflow_graph)1081 void ConvertTensorFlowShapeOperator(const Model& model,
1082 const TensorFlowShapeOperator& src_op,
1083 GraphDef* tensorflow_graph) {
1084 auto* shape_op = tensorflow_graph->add_node();
1085 shape_op->set_op("Shape");
1086 shape_op->set_name(src_op.outputs[0]);
1087 CHECK_EQ(src_op.inputs.size(), 1);
1088 *shape_op->add_input() = src_op.inputs[0];
1089 (*shape_op->mutable_attr())["T"].set_type(
1090 GetTensorFlowDataType(model, src_op.inputs[0]));
1091 (*shape_op->mutable_attr())["out_type"].set_type(
1092 GetTensorFlowDataType(model, src_op.outputs[0]));
1093 }
1094
ConvertRankOperator(const Model & model,const RankOperator & src_op,GraphDef * tensorflow_graph)1095 void ConvertRankOperator(const Model& model, const RankOperator& src_op,
1096 GraphDef* tensorflow_graph) {
1097 auto* rank_op = tensorflow_graph->add_node();
1098 rank_op->set_op("Rank");
1099 rank_op->set_name(src_op.outputs[0]);
1100 CHECK_EQ(src_op.inputs.size(), 1);
1101 *rank_op->add_input() = src_op.inputs[0];
1102 (*rank_op->mutable_attr())["T"].set_type(
1103 GetTensorFlowDataType(model, src_op.inputs[0]));
1104 }
1105
ConvertRangeOperator(const Model & model,const RangeOperator & src_op,GraphDef * tensorflow_graph)1106 void ConvertRangeOperator(const Model& model, const RangeOperator& src_op,
1107 GraphDef* tensorflow_graph) {
1108 auto* range_op = tensorflow_graph->add_node();
1109 range_op->set_op("Range");
1110 range_op->set_name(src_op.outputs[0]);
1111 CHECK_EQ(src_op.inputs.size(), 3);
1112 *range_op->add_input() = src_op.inputs[0];
1113 *range_op->add_input() = src_op.inputs[1];
1114 *range_op->add_input() = src_op.inputs[2];
1115 (*range_op->mutable_attr())["Tidx"].set_type(
1116 GetTensorFlowDataType(src_op.dtype));
1117 }
1118
ConvertStackOperator(const Model & model,const StackOperator & src_op,GraphDef * tensorflow_graph)1119 void ConvertStackOperator(const Model& model, const StackOperator& src_op,
1120 GraphDef* tensorflow_graph) {
1121 auto* stack_op = tensorflow_graph->add_node();
1122 stack_op->set_op("Stack");
1123 stack_op->set_name(src_op.outputs[0]);
1124 for (const auto& input : src_op.inputs) {
1125 *stack_op->add_input() = input;
1126 }
1127 (*stack_op->mutable_attr())["elem_type"].set_type(
1128 GetTensorFlowDataType(model, src_op.outputs[0]));
1129 (*stack_op->mutable_attr())["axis"].set_i(src_op.axis);
1130 }
1131
ConvertFillOperator(const Model & model,const FillOperator & src_op,GraphDef * tensorflow_graph)1132 void ConvertFillOperator(const Model& model, const FillOperator& src_op,
1133 GraphDef* tensorflow_graph) {
1134 auto* fill_op = tensorflow_graph->add_node();
1135 fill_op->set_op("Fill");
1136 fill_op->set_name(src_op.outputs[0]);
1137 CHECK_EQ(src_op.inputs.size(), 2);
1138 *fill_op->add_input() = src_op.inputs[0];
1139 *fill_op->add_input() = src_op.inputs[1];
1140 (*fill_op->mutable_attr())["index_type"].set_type(
1141 GetTensorFlowDataType(model, src_op.inputs[0]));
1142 (*fill_op->mutable_attr())["T"].set_type(
1143 GetTensorFlowDataType(model, src_op.inputs[1]));
1144 }
1145
ConvertFloorDivOperator(const Model & model,const FloorDivOperator & src_op,GraphDef * tensorflow_graph)1146 void ConvertFloorDivOperator(const Model& model, const FloorDivOperator& src_op,
1147 GraphDef* tensorflow_graph) {
1148 auto* floor_div_op = tensorflow_graph->add_node();
1149 floor_div_op->set_op("FloorDiv");
1150 floor_div_op->set_name(src_op.outputs[0]);
1151 CHECK_EQ(src_op.inputs.size(), 2);
1152 *floor_div_op->add_input() = src_op.inputs[0];
1153 *floor_div_op->add_input() = src_op.inputs[1];
1154 (*floor_div_op->mutable_attr())["T"].set_type(
1155 GetTensorFlowDataType(model, src_op.inputs[0]));
1156 }
1157
ConvertExpandDimsOperator(const Model & model,const ExpandDimsOperator & src_op,GraphDef * tensorflow_graph)1158 void ConvertExpandDimsOperator(const Model& model,
1159 const ExpandDimsOperator& src_op,
1160 GraphDef* tensorflow_graph) {
1161 auto* expand_dims_op = tensorflow_graph->add_node();
1162 expand_dims_op->set_op("ExpandDims");
1163 expand_dims_op->set_name(src_op.outputs[0]);
1164 CHECK_EQ(src_op.inputs.size(), 2);
1165 *expand_dims_op->add_input() = src_op.inputs[0];
1166 *expand_dims_op->add_input() = src_op.inputs[1];
1167 (*expand_dims_op->mutable_attr())["T"].set_type(
1168 GetTensorFlowDataType(model, src_op.inputs[0]));
1169 (*expand_dims_op->mutable_attr())["Tdim"].set_type(
1170 GetTensorFlowDataType(model, src_op.inputs[1]));
1171 }
1172
ConvertResizeBilinearOperator(const Model & model,const ResizeBilinearOperator & src_op,GraphDef * tensorflow_graph)1173 void ConvertResizeBilinearOperator(const Model& model,
1174 const ResizeBilinearOperator& src_op,
1175 GraphDef* tensorflow_graph) {
1176 auto* resize_op = tensorflow_graph->add_node();
1177 resize_op->set_op("ResizeBilinear");
1178 resize_op->set_name(src_op.outputs[0]);
1179 CHECK_EQ(src_op.inputs.size(), 2);
1180 *resize_op->add_input() = src_op.inputs[0];
1181 *resize_op->add_input() = src_op.inputs[1];
1182 (*resize_op->mutable_attr())["T"].set_type(DT_FLOAT);
1183 (*resize_op->mutable_attr())["align_corners"].set_b(src_op.align_corners);
1184 }
1185
1186 namespace {
1187 // TODO(aselle): Remove when available in absl
FindLongestCommonPrefix(absl::string_view a,absl::string_view b)1188 absl::string_view FindLongestCommonPrefix(absl::string_view a,
1189 absl::string_view b) {
1190 if (a.empty() || b.empty()) return absl::string_view();
1191
1192 const char* pa = a.data();
1193 const char* pb = b.data();
1194 string::difference_type count = 0;
1195 const string::difference_type limit = std::min(a.size(), b.size());
1196 while (count < limit && *pa == *pb) {
1197 ++pa;
1198 ++pb;
1199 ++count;
1200 }
1201
1202 return absl::string_view(a.data(), count);
1203 }
1204 } // namespace
1205
ConvertLstmCellOperator(const Model & model,const LstmCellOperator & src_op,GraphDef * tensorflow_graph)1206 void ConvertLstmCellOperator(const Model& model, const LstmCellOperator& src_op,
1207 GraphDef* tensorflow_graph) {
1208 // Find the base name
1209 const string base(
1210 FindLongestCommonPrefix(src_op.outputs[LstmCellOperator::STATE_OUTPUT],
1211 src_op.outputs[LstmCellOperator::ACTIV_OUTPUT]));
1212
1213 // Concatenate inputs
1214 const string concat_output = base + "basic_lstm_cell/concat";
1215 // Op names have been chosen to match the tf.slim LSTM naming
1216 // as closely as possible.
1217 const int axis =
1218 model.GetArray(src_op.inputs[LstmCellOperator::PREV_ACTIV_INPUT])
1219 .shape()
1220 .dimensions_count() -
1221 1;
1222 // Note that DATA_INPUT may have extra size 1 dimensions, but TF concat
1223 // works the same since the tensor has the same underlying data layout.
1224 const string axis_output = concat_output + "/axis";
1225 CreateDummyConcatDimTensorConst(axis_output, axis, tensorflow_graph);
1226 auto* concat_op = tensorflow_graph->add_node();
1227 concat_op->set_op("ConcatV2");
1228 concat_op->set_name(concat_output);
1229 *concat_op->add_input() = src_op.inputs[LstmCellOperator::DATA_INPUT];
1230 *concat_op->add_input() = src_op.inputs[LstmCellOperator::PREV_ACTIV_INPUT];
1231 *concat_op->add_input() = axis_output;
1232 (*concat_op->mutable_attr())["T"].set_type(DT_FLOAT);
1233 (*concat_op->mutable_attr())["Tidx"].set_type(DT_INT32);
1234 (*concat_op->mutable_attr())["N"].set_i(2); // Number of inputs
1235
1236 // Write weights
1237 const string weights_output = base + "weights";
1238 CHECK(model.HasArray(src_op.inputs[LstmCellOperator::WEIGHTS_INPUT]));
1239 const string weights_name = WalkUpToConstantArray(
1240 model, src_op.inputs[LstmCellOperator::WEIGHTS_INPUT]);
1241 const auto& weights_array = model.GetArray(weights_name);
1242 // Convert 4D FullyConnected weights into 2D matrix
1243 const auto& weights_shape = weights_array.shape();
1244 CHECK_EQ(weights_shape.dimensions_count(), 2);
1245 CHECK(weights_array.buffer);
1246 CHECK(weights_array.buffer->type == ArrayDataType::kFloat);
1247 const float* weights_data =
1248 weights_array.GetBuffer<ArrayDataType::kFloat>().data.data();
1249 ConvertFloatTensorConst(weights_output, weights_shape, weights_data,
1250 AxesOrder::kCR, AxesOrder::kRC, tensorflow_graph);
1251
1252 // Fully connected matrix multiply
1253 const string matmul_output = base + "MatMul";
1254 auto* matmul_op = tensorflow_graph->add_node();
1255 matmul_op->set_op("MatMul");
1256 matmul_op->set_name(matmul_output);
1257 *matmul_op->add_input() = concat_output;
1258 *matmul_op->add_input() = weights_output;
1259 (*matmul_op->mutable_attr())["transpose_a"].set_b(false);
1260 (*matmul_op->mutable_attr())["transpose_b"].set_b(false);
1261 (*matmul_op->mutable_attr())["T"].set_type(DT_FLOAT);
1262
1263 // Write biases
1264 const string biases_output = base + "biases";
1265 CHECK(model.HasArray(src_op.inputs[LstmCellOperator::BIASES_INPUT]));
1266 const string bias_name = WalkUpToConstantArray(
1267 model, src_op.inputs[LstmCellOperator::BIASES_INPUT]);
1268 const auto& bias_array = model.GetArray(bias_name);
1269 // TODO(b/62904716) Bias arrays should be 1-D, and used directly.
1270 Shape bias_shape_1d = bias_array.shape();
1271 UnextendShape(&bias_shape_1d, 1);
1272 CHECK(bias_array.buffer);
1273 CHECK(bias_array.buffer->type == ArrayDataType::kFloat);
1274 const float* bias_data =
1275 bias_array.GetBuffer<ArrayDataType::kFloat>().data.data();
1276 ConvertFloatTensorConst(biases_output, bias_shape_1d, bias_data,
1277 AxesOrder::kOneAxis, AxesOrder::kOneAxis,
1278 tensorflow_graph,
1279 LegacyScalarPolicy::kDoCreateLegacyScalars);
1280
1281 // Add biases
1282 string biasadd_output = base + "BiasAdd";
1283 auto* biasadd_op = tensorflow_graph->add_node();
1284 biasadd_op->set_op("BiasAdd");
1285 biasadd_op->set_name(biasadd_output);
1286 biasadd_op->add_input(matmul_output);
1287 biasadd_op->add_input(biases_output);
1288 (*biasadd_op->mutable_attr())["data_format"].set_s("NHWC");
1289 (*biasadd_op->mutable_attr())["T"].set_type(DT_FLOAT);
1290
1291 // Split
1292 string split_dim_output = base + "split/split_dim";
1293 // The dimension is the same as the concatenation dimension
1294 CreateDummyConcatDimTensorConst(split_dim_output, axis, tensorflow_graph);
1295 string split_output = base + "split";
1296 auto* split_op = tensorflow_graph->add_node();
1297 split_op->set_op("Split");
1298 split_op->set_name(split_output);
1299 *split_op->add_input() = split_dim_output;
1300 *split_op->add_input() = biasadd_output;
1301 (*split_op->mutable_attr())["T"].set_type(DT_FLOAT);
1302 (*split_op->mutable_attr())["num_split"].set_i(4); // Split into four outputs
1303
1304 // Activation functions and memory computations
1305 const string tanh_0_output = base + "Tanh";
1306 auto* tanh_0_op = tensorflow_graph->add_node();
1307 tanh_0_op->set_op("Tanh");
1308 tanh_0_op->set_name(tanh_0_output);
1309 *tanh_0_op->add_input() = split_output + ":1";
1310 (*tanh_0_op->mutable_attr())["T"].set_type(DT_FLOAT);
1311
1312 const string sigmoid_1_output = base + "Sigmoid_1";
1313 auto* logistic_1_op = tensorflow_graph->add_node();
1314 logistic_1_op->set_op("Sigmoid");
1315 logistic_1_op->set_name(sigmoid_1_output);
1316 *logistic_1_op->add_input() = split_output;
1317 (*logistic_1_op->mutable_attr())["T"].set_type(DT_FLOAT);
1318
1319 const string mul_1_output = base + "mul_1";
1320 auto* mul_1_op = tensorflow_graph->add_node();
1321 mul_1_op->set_op("Mul");
1322 mul_1_op->set_name(mul_1_output);
1323 *mul_1_op->add_input() = sigmoid_1_output;
1324 *mul_1_op->add_input() = tanh_0_output;
1325 (*mul_1_op->mutable_attr())["T"].set_type(DT_FLOAT);
1326
1327 const string sigmoid_0_output = base + "Sigmoid";
1328 auto* logistic_2_op = tensorflow_graph->add_node();
1329 logistic_2_op->set_op("Sigmoid");
1330 logistic_2_op->set_name(sigmoid_0_output);
1331 *logistic_2_op->add_input() = split_output + ":2";
1332 (*logistic_2_op->mutable_attr())["T"].set_type(DT_FLOAT);
1333
1334 const string sigmoid_2_output = base + "Sigmoid_2";
1335 auto* logistic_3_op = tensorflow_graph->add_node();
1336 logistic_3_op->set_op("Sigmoid");
1337 logistic_3_op->set_name(sigmoid_2_output);
1338 *logistic_3_op->add_input() = split_output + ":3";
1339 (*logistic_3_op->mutable_attr())["T"].set_type(DT_FLOAT);
1340
1341 const string mul_0_output = base + "mul";
1342 auto* mul_0_op = tensorflow_graph->add_node();
1343 mul_0_op->set_op("Mul");
1344 mul_0_op->set_name(mul_0_output);
1345 *mul_0_op->add_input() = src_op.inputs[LstmCellOperator::PREV_STATE_INPUT];
1346 *mul_0_op->add_input() = sigmoid_0_output;
1347 (*mul_0_op->mutable_attr())["T"].set_type(DT_FLOAT);
1348
1349 const string add_1_output = src_op.outputs[LstmCellOperator::STATE_OUTPUT];
1350 auto* add_1_op = tensorflow_graph->add_node();
1351 add_1_op->set_op("Add");
1352 add_1_op->set_name(add_1_output);
1353 *add_1_op->add_input() = mul_0_output;
1354 *add_1_op->add_input() = mul_1_output;
1355 (*add_1_op->mutable_attr())["T"].set_type(DT_FLOAT);
1356
1357 const string tanh_1_output = base + "Tanh_1";
1358 auto* tanh_1_op = tensorflow_graph->add_node();
1359 tanh_1_op->set_op("Tanh");
1360 tanh_1_op->set_name(tanh_1_output);
1361 *tanh_1_op->add_input() = add_1_output;
1362 (*tanh_1_op->mutable_attr())["T"].set_type(DT_FLOAT);
1363
1364 const string mul_2_output = src_op.outputs[LstmCellOperator::ACTIV_OUTPUT];
1365 auto* mul_2_op = tensorflow_graph->add_node();
1366 mul_2_op->set_op("Mul");
1367 mul_2_op->set_name(mul_2_output);
1368 *mul_2_op->add_input() = tanh_1_output;
1369 *mul_2_op->add_input() = sigmoid_2_output;
1370 (*mul_2_op->mutable_attr())["T"].set_type(DT_FLOAT);
1371 }
1372
ConvertSpaceToBatchNDOperator(const Model & model,const SpaceToBatchNDOperator & src_op,GraphDef * tensorflow_graph)1373 void ConvertSpaceToBatchNDOperator(const Model& model,
1374 const SpaceToBatchNDOperator& src_op,
1375 GraphDef* tensorflow_graph) {
1376 auto* new_op = tensorflow_graph->add_node();
1377 new_op->set_op("SpaceToBatchND");
1378 new_op->set_name(src_op.outputs[0]);
1379 CHECK_EQ(src_op.inputs.size(), 3);
1380 *new_op->add_input() = src_op.inputs[0];
1381 *new_op->add_input() = src_op.inputs[1];
1382 *new_op->add_input() = src_op.inputs[2];
1383 const auto params_type = GetTensorFlowDataType(model, src_op.inputs[0]);
1384 (*new_op->mutable_attr())["T"].set_type(params_type);
1385 (*new_op->mutable_attr())["Tblock_shape"].set_type(DT_INT32);
1386 (*new_op->mutable_attr())["Tpaddings"].set_type(DT_INT32);
1387 }
1388
ConvertBatchToSpaceNDOperator(const Model & model,const BatchToSpaceNDOperator & src_op,GraphDef * tensorflow_graph)1389 void ConvertBatchToSpaceNDOperator(const Model& model,
1390 const BatchToSpaceNDOperator& src_op,
1391 GraphDef* tensorflow_graph) {
1392 auto* new_op = tensorflow_graph->add_node();
1393 new_op->set_op("BatchToSpaceND");
1394 new_op->set_name(src_op.outputs[0]);
1395 CHECK_EQ(src_op.inputs.size(), 3);
1396 *new_op->add_input() = src_op.inputs[0];
1397 *new_op->add_input() = src_op.inputs[1];
1398 *new_op->add_input() = src_op.inputs[2];
1399 const auto params_type = GetTensorFlowDataType(model, src_op.inputs[0]);
1400 (*new_op->mutable_attr())["T"].set_type(params_type);
1401 (*new_op->mutable_attr())["Tblock_shape"].set_type(DT_INT32);
1402 (*new_op->mutable_attr())["Tcrops"].set_type(DT_INT32);
1403 }
1404
ConvertPadOperator(const Model & model,const PadOperator & src_op,GraphDef * tensorflow_graph)1405 void ConvertPadOperator(const Model& model, const PadOperator& src_op,
1406 GraphDef* tensorflow_graph) {
1407 auto* new_op = tensorflow_graph->add_node();
1408 new_op->set_op("Pad");
1409 new_op->set_name(src_op.outputs[0]);
1410 CHECK_EQ(src_op.inputs.size(), 2);
1411 *new_op->add_input() = src_op.inputs[0];
1412 *new_op->add_input() = src_op.inputs[1];
1413
1414 const auto params_type = GetTensorFlowDataType(model, src_op.inputs[0]);
1415 (*new_op->mutable_attr())["T"].set_type(params_type);
1416
1417 // Create the params tensor.
1418 auto* params_op = tensorflow_graph->add_node();
1419 params_op->set_op("Const");
1420 params_op->set_name(src_op.inputs[1]);
1421 (*params_op->mutable_attr())["dtype"].set_type(DT_INT32);
1422 auto* tensor = (*params_op->mutable_attr())["value"].mutable_tensor();
1423 tensor->set_dtype(DT_INT32);
1424
1425 CHECK_EQ(src_op.left_padding.size(), src_op.right_padding.size());
1426 for (int i = 0; i < src_op.left_padding.size(); ++i) {
1427 tensor->add_int_val(src_op.left_padding[i]);
1428 tensor->add_int_val(src_op.right_padding[i]);
1429 }
1430 auto* shape = tensor->mutable_tensor_shape();
1431 shape->add_dim()->set_size(src_op.left_padding.size());
1432 shape->add_dim()->set_size(2);
1433 }
1434
CreateSliceInput(const string & input_name,const std::vector<int> & values,GraphDef * tensorflow_graph)1435 void CreateSliceInput(const string& input_name, const std::vector<int>& values,
1436 GraphDef* tensorflow_graph) {
1437 auto* params_op = tensorflow_graph->add_node();
1438 params_op->set_op("Const");
1439 params_op->set_name(input_name);
1440 (*params_op->mutable_attr())["dtype"].set_type(DT_INT32);
1441 auto* tensor = (*params_op->mutable_attr())["value"].mutable_tensor();
1442 tensor->set_dtype(DT_INT32);
1443
1444 for (int i = 0; i < values.size(); ++i) {
1445 tensor->add_int_val(values[i]);
1446 }
1447 auto* shape = tensor->mutable_tensor_shape();
1448 shape->add_dim()->set_size(values.size());
1449 }
1450
ConvertStridedSliceOperator(const Model & model,const StridedSliceOperator & src_op,GraphDef * tensorflow_graph)1451 void ConvertStridedSliceOperator(const Model& model,
1452 const StridedSliceOperator& src_op,
1453 GraphDef* tensorflow_graph) {
1454 auto* new_op = tensorflow_graph->add_node();
1455 new_op->set_op("StridedSlice");
1456 new_op->set_name(src_op.outputs[0]);
1457 CHECK_EQ(src_op.inputs.size(), 4);
1458 *new_op->add_input() = src_op.inputs[0];
1459 *new_op->add_input() = src_op.inputs[1];
1460 *new_op->add_input() = src_op.inputs[2];
1461 *new_op->add_input() = src_op.inputs[3];
1462
1463 const auto params_type = GetTensorFlowDataType(model, src_op.inputs[0]);
1464 (*new_op->mutable_attr())["T"].set_type(params_type);
1465
1466 (*new_op->mutable_attr())["Index"].set_type(DT_INT32);
1467 (*new_op->mutable_attr())["begin_mask"].set_i(src_op.begin_mask);
1468 (*new_op->mutable_attr())["ellipsis_mask"].set_i(src_op.ellipsis_mask);
1469 (*new_op->mutable_attr())["end_mask"].set_i(src_op.end_mask);
1470 (*new_op->mutable_attr())["new_axis_mask"].set_i(src_op.new_axis_mask);
1471 (*new_op->mutable_attr())["shrink_axis_mask"].set_i(src_op.shrink_axis_mask);
1472
1473 // Create tensors for start/stop indices and strides.
1474 CreateSliceInput(src_op.inputs[1], src_op.start_indices, tensorflow_graph);
1475 CreateSliceInput(src_op.inputs[2], src_op.stop_indices, tensorflow_graph);
1476 CreateSliceInput(src_op.inputs[3], src_op.strides, tensorflow_graph);
1477 }
1478
ConvertSliceOperator(const Model & model,const SliceOperator & src_op,GraphDef * tensorflow_graph)1479 void ConvertSliceOperator(const Model& model, const SliceOperator& src_op,
1480 GraphDef* tensorflow_graph) {
1481 auto* new_op = tensorflow_graph->add_node();
1482 new_op->set_op("Slice");
1483 new_op->set_name(src_op.outputs[0]);
1484 CHECK_EQ(src_op.inputs.size(), 3);
1485 *new_op->add_input() = src_op.inputs[0];
1486 *new_op->add_input() = src_op.inputs[1];
1487 *new_op->add_input() = src_op.inputs[2];
1488
1489 const auto params_type = GetTensorFlowDataType(model, src_op.inputs[0]);
1490 (*new_op->mutable_attr())["T"].set_type(params_type);
1491 (*new_op->mutable_attr())["Index"].set_type(DT_INT32);
1492
1493 // Create tensors for begin and size inputs.
1494 CreateSliceInput(src_op.inputs[1], src_op.begin, tensorflow_graph);
1495 CreateSliceInput(src_op.inputs[2], src_op.size, tensorflow_graph);
1496 }
1497
ConvertMeanOperator(const Model & model,const MeanOperator & src_op,GraphDef * tensorflow_graph)1498 void ConvertMeanOperator(const Model& model, const MeanOperator& src_op,
1499 GraphDef* tensorflow_graph) {
1500 auto* new_op = tensorflow_graph->add_node();
1501 new_op->set_op("Mean");
1502 new_op->set_name(src_op.outputs[0]);
1503 CHECK_EQ(src_op.inputs.size(), 2);
1504 *new_op->add_input() = src_op.inputs[0];
1505 *new_op->add_input() = src_op.inputs[1];
1506
1507 const auto params_type = GetTensorFlowDataType(model, src_op.inputs[0]);
1508 (*new_op->mutable_attr())["T"].set_type(params_type);
1509
1510 if (src_op.keep_dims) {
1511 (*new_op->mutable_attr())["keep_dims"].set_b(true);
1512 }
1513
1514 // Create the params tensor.
1515 auto* params_op = tensorflow_graph->add_node();
1516 params_op->set_op("Const");
1517 params_op->set_name(src_op.inputs[1]);
1518 (*params_op->mutable_attr())["dtype"].set_type(DT_INT32);
1519 auto* tensor = (*params_op->mutable_attr())["value"].mutable_tensor();
1520 tensor->set_dtype(DT_INT32);
1521
1522 for (int i = 0; i < src_op.axis.size(); ++i) {
1523 tensor->add_int_val(src_op.axis[i]);
1524 }
1525 auto* shape = tensor->mutable_tensor_shape();
1526 shape->add_dim()->set_size(src_op.axis.size());
1527 }
1528
ConvertSqueezeOperator(const Model & model,const SqueezeOperator & src_op,GraphDef * tensorflow_graph)1529 void ConvertSqueezeOperator(const Model& model, const SqueezeOperator& src_op,
1530 GraphDef* tensorflow_graph) {
1531 auto* new_op = tensorflow_graph->add_node();
1532 new_op->set_op("Squeeze");
1533 new_op->set_name(src_op.outputs[0]);
1534 CHECK_EQ(src_op.inputs.size(), 1);
1535 *new_op->add_input() = src_op.inputs[0];
1536
1537 const auto params_type = GetTensorFlowDataType(model, src_op.inputs[0]);
1538 (*new_op->mutable_attr())["T"].set_type(params_type);
1539
1540 auto& squeeze_dims = (*new_op->mutable_attr())["squeeze_dims"];
1541 for (int i : src_op.squeeze_dims) {
1542 squeeze_dims.mutable_list()->add_i(i);
1543 }
1544 }
1545
ConvertSubOperator(const Model & model,const SubOperator & src_op,GraphDef * tensorflow_graph)1546 void ConvertSubOperator(const Model& model, const SubOperator& src_op,
1547 GraphDef* tensorflow_graph) {
1548 auto* sub_op = tensorflow_graph->add_node();
1549 sub_op->set_op("Sub");
1550 sub_op->set_name(src_op.outputs[0]);
1551 CHECK_EQ(src_op.inputs.size(), 2);
1552 *sub_op->add_input() = src_op.inputs[0];
1553 *sub_op->add_input() = src_op.inputs[1];
1554 const auto data_type = GetTensorFlowDataType(model, src_op.inputs[0]);
1555 (*sub_op->mutable_attr())["T"].set_type(data_type);
1556 }
1557
ConvertTensorFlowMinimumOperator(const Model & model,const TensorFlowMinimumOperator & src_op,GraphDef * tensorflow_graph)1558 void ConvertTensorFlowMinimumOperator(const Model& model,
1559 const TensorFlowMinimumOperator& src_op,
1560 GraphDef* tensorflow_graph) {
1561 auto* sub_op = tensorflow_graph->add_node();
1562 sub_op->set_op("Minimum");
1563 sub_op->set_name(src_op.outputs[0]);
1564 CHECK_EQ(src_op.inputs.size(), 2);
1565 *sub_op->add_input() = src_op.inputs[0];
1566 *sub_op->add_input() = src_op.inputs[1];
1567 const auto data_type = GetTensorFlowDataType(model, src_op.inputs[0]);
1568 (*sub_op->mutable_attr())["T"].set_type(data_type);
1569 }
1570
ConvertTensorFlowMaximumOperator(const Model & model,const TensorFlowMaximumOperator & src_op,GraphDef * tensorflow_graph)1571 void ConvertTensorFlowMaximumOperator(const Model& model,
1572 const TensorFlowMaximumOperator& src_op,
1573 GraphDef* tensorflow_graph) {
1574 auto* sub_op = tensorflow_graph->add_node();
1575 sub_op->set_op("Maximum");
1576 sub_op->set_name(src_op.outputs[0]);
1577 CHECK_EQ(src_op.inputs.size(), 2);
1578 *sub_op->add_input() = src_op.inputs[0];
1579 *sub_op->add_input() = src_op.inputs[1];
1580 const auto data_type = GetTensorFlowDataType(model, src_op.inputs[0]);
1581 (*sub_op->mutable_attr())["T"].set_type(data_type);
1582 }
1583
ConvertTopKV2Operator(const Model & model,const TopKV2Operator & src_op,GraphDef * tensorflow_graph)1584 void ConvertTopKV2Operator(const Model& model, const TopKV2Operator& src_op,
1585 GraphDef* tensorflow_graph) {
1586 auto* topk_op = tensorflow_graph->add_node();
1587 topk_op->set_op("TOPKV2");
1588 topk_op->set_name(src_op.outputs[0]);
1589 CHECK_EQ(src_op.inputs.size(), 2);
1590 *topk_op->add_input() = src_op.inputs[0];
1591 *topk_op->add_input() = src_op.inputs[1];
1592 (*topk_op->mutable_attr())["sorted"].set_b(true);
1593 }
1594
ConvertOperator(const Model & model,const Operator & src_op,GraphDef * tensorflow_graph)1595 void ConvertOperator(const Model& model, const Operator& src_op,
1596 GraphDef* tensorflow_graph) {
1597 if (src_op.fused_activation_function != FusedActivationFunctionType::kNone) {
1598 LOG(FATAL)
1599 << "Unsupported: the input model has a fused activation function";
1600 }
1601
1602 if (src_op.type == OperatorType::kConv) {
1603 ConvertConvOperator(model, static_cast<const ConvOperator&>(src_op),
1604 tensorflow_graph);
1605 } else if (src_op.type == OperatorType::kDepthwiseConv) {
1606 ConvertDepthwiseConvOperator(
1607 model, static_cast<const DepthwiseConvOperator&>(src_op),
1608 tensorflow_graph);
1609 } else if (src_op.type == OperatorType::kDepthToSpace) {
1610 ConvertDepthToSpaceOperator(
1611 model, static_cast<const DepthToSpaceOperator&>(src_op),
1612 tensorflow_graph);
1613 } else if (src_op.type == OperatorType::kSpaceToDepth) {
1614 ConvertSpaceToDepthOperator(
1615 model, static_cast<const SpaceToDepthOperator&>(src_op),
1616 tensorflow_graph);
1617 } else if (src_op.type == OperatorType::kFullyConnected) {
1618 ConvertFullyConnectedOperator(
1619 model, static_cast<const FullyConnectedOperator&>(src_op),
1620 tensorflow_graph);
1621 } else if (src_op.type == OperatorType::kAdd) {
1622 ConvertAddOperator(model, static_cast<const AddOperator&>(src_op),
1623 tensorflow_graph);
1624 } else if (src_op.type == OperatorType::kAddN) {
1625 ConvertAddNOperator(model, static_cast<const AddNOperator&>(src_op),
1626 tensorflow_graph);
1627 } else if (src_op.type == OperatorType::kMul) {
1628 ConvertMulOperator(model, static_cast<const MulOperator&>(src_op),
1629 tensorflow_graph);
1630 } else if (src_op.type == OperatorType::kRelu) {
1631 ConvertReluOperator(static_cast<const ReluOperator&>(src_op),
1632 tensorflow_graph);
1633 } else if (src_op.type == OperatorType::kRelu1) {
1634 ConvertRelu1Operator(static_cast<const Relu1Operator&>(src_op),
1635 tensorflow_graph);
1636 } else if (src_op.type == OperatorType::kRelu6) {
1637 ConvertRelu6Operator(static_cast<const Relu6Operator&>(src_op),
1638 tensorflow_graph);
1639 } else if (src_op.type == OperatorType::kLogistic) {
1640 ConvertLogisticOperator(static_cast<const LogisticOperator&>(src_op),
1641 tensorflow_graph);
1642 } else if (src_op.type == OperatorType::kTanh) {
1643 ConvertTanhOperator(static_cast<const TanhOperator&>(src_op),
1644 tensorflow_graph);
1645 } else if (src_op.type == OperatorType::kL2Normalization) {
1646 ConvertL2NormalizationOperator(
1647 static_cast<const L2NormalizationOperator&>(src_op), tensorflow_graph);
1648 } else if (src_op.type == OperatorType::kSoftmax) {
1649 ConvertSoftmaxOperator(model, static_cast<const SoftmaxOperator&>(src_op),
1650 tensorflow_graph);
1651 } else if (src_op.type == OperatorType::kLogSoftmax) {
1652 ConvertLogSoftmaxOperator(model,
1653 static_cast<const LogSoftmaxOperator&>(src_op),
1654 tensorflow_graph);
1655 } else if (src_op.type == OperatorType::kLocalResponseNormalization) {
1656 ConvertLocalResponseNormalizationOperator(
1657 static_cast<const LocalResponseNormalizationOperator&>(src_op),
1658 tensorflow_graph);
1659 } else if (src_op.type == OperatorType::kLstmCell) {
1660 ConvertLstmCellOperator(model, static_cast<const LstmCellOperator&>(src_op),
1661 tensorflow_graph);
1662 } else if (src_op.type == OperatorType::kMaxPool) {
1663 ConvertMaxPoolOperator(static_cast<const MaxPoolOperator&>(src_op),
1664 tensorflow_graph);
1665 } else if (src_op.type == OperatorType::kAveragePool) {
1666 ConvertAveragePoolOperator(static_cast<const AveragePoolOperator&>(src_op),
1667 tensorflow_graph);
1668 } else if (src_op.type == OperatorType::kConcatenation) {
1669 ConvertConcatenationOperator(
1670 model, static_cast<const ConcatenationOperator&>(src_op),
1671 tensorflow_graph);
1672 } else if (src_op.type == OperatorType::kTensorFlowReshape) {
1673 ConvertTensorFlowReshapeOperator(
1674 model, static_cast<const TensorFlowReshapeOperator&>(src_op),
1675 tensorflow_graph);
1676 } else if (src_op.type == OperatorType::kL2Pool) {
1677 ConvertL2PoolOperator(static_cast<const L2PoolOperator&>(src_op),
1678 tensorflow_graph);
1679 } else if (src_op.type == OperatorType::kTensorFlowSquare) {
1680 ConvertSquareOperator(static_cast<const TensorFlowSquareOperator&>(src_op),
1681 tensorflow_graph);
1682 } else if (src_op.type == OperatorType::kTensorFlowSqrt) {
1683 ConvertSqrtOperator(static_cast<const TensorFlowSqrtOperator&>(src_op),
1684 tensorflow_graph);
1685 } else if (src_op.type == OperatorType::kTensorFlowSplit) {
1686 ConvertSplitOperator(model,
1687 static_cast<const TensorFlowSplitOperator&>(src_op),
1688 tensorflow_graph);
1689 } else if (src_op.type == OperatorType::kFakeQuant) {
1690 ConvertFakeQuantOperator(static_cast<const FakeQuantOperator&>(src_op),
1691 tensorflow_graph);
1692 } else if (src_op.type == OperatorType::kCast) {
1693 ConvertCastOperator(model, static_cast<const CastOperator&>(src_op),
1694 tensorflow_graph);
1695 } else if (src_op.type == OperatorType::kFloor) {
1696 ConvertFloorOperator(model, static_cast<const FloorOperator&>(src_op),
1697 tensorflow_graph);
1698 } else if (src_op.type == OperatorType::kGather) {
1699 ConvertGatherOperator(model, static_cast<const GatherOperator&>(src_op),
1700 tensorflow_graph);
1701 } else if (src_op.type == OperatorType::kResizeBilinear) {
1702 ConvertResizeBilinearOperator(
1703 model, static_cast<const ResizeBilinearOperator&>(src_op),
1704 tensorflow_graph);
1705 } else if (src_op.type == OperatorType::kSpaceToBatchND) {
1706 ConvertSpaceToBatchNDOperator(
1707 model, static_cast<const SpaceToBatchNDOperator&>(src_op),
1708 tensorflow_graph);
1709 } else if (src_op.type == OperatorType::kBatchToSpaceND) {
1710 ConvertBatchToSpaceNDOperator(
1711 model, static_cast<const BatchToSpaceNDOperator&>(src_op),
1712 tensorflow_graph);
1713 } else if (src_op.type == OperatorType::kPad) {
1714 ConvertPadOperator(model, static_cast<const PadOperator&>(src_op),
1715 tensorflow_graph);
1716 } else if (src_op.type == OperatorType::kStridedSlice) {
1717 ConvertStridedSliceOperator(
1718 model, static_cast<const StridedSliceOperator&>(src_op),
1719 tensorflow_graph);
1720 } else if (src_op.type == OperatorType::kMean) {
1721 ConvertMeanOperator(model, static_cast<const MeanOperator&>(src_op),
1722 tensorflow_graph);
1723 } else if (src_op.type == OperatorType::kSub) {
1724 ConvertSubOperator(model, static_cast<const SubOperator&>(src_op),
1725 tensorflow_graph);
1726 } else if (src_op.type == OperatorType::kTensorFlowMinimum) {
1727 ConvertTensorFlowMinimumOperator(
1728 model, static_cast<const TensorFlowMinimumOperator&>(src_op),
1729 tensorflow_graph);
1730 } else if (src_op.type == OperatorType::kTensorFlowMaximum) {
1731 ConvertTensorFlowMaximumOperator(
1732 model, static_cast<const TensorFlowMaximumOperator&>(src_op),
1733 tensorflow_graph);
1734 } else if (src_op.type == OperatorType::kSqueeze) {
1735 ConvertSqueezeOperator(model, static_cast<const SqueezeOperator&>(src_op),
1736 tensorflow_graph);
1737 } else if (src_op.type == OperatorType::kSlice) {
1738 ConvertSliceOperator(model, static_cast<const SliceOperator&>(src_op),
1739 tensorflow_graph);
1740 } else if (src_op.type == OperatorType::kArgMax) {
1741 ConvertArgMaxOperator(model, static_cast<const ArgMaxOperator&>(src_op),
1742 tensorflow_graph);
1743 } else if (src_op.type == OperatorType::kTopK_V2) {
1744 ConvertTopKV2Operator(model, static_cast<const TopKV2Operator&>(src_op),
1745 tensorflow_graph);
1746 } else if (src_op.type == OperatorType::kTranspose) {
1747 ConvertTransposeOperator(
1748 model, static_cast<const TransposeOperator&>(src_op), tensorflow_graph);
1749 } else if (src_op.type == OperatorType::kTensorFlowShape) {
1750 ConvertTensorFlowShapeOperator(
1751 model, static_cast<const TensorFlowShapeOperator&>(src_op),
1752 tensorflow_graph);
1753 } else if (src_op.type == OperatorType::kRank) {
1754 ConvertRankOperator(model, static_cast<const RankOperator&>(src_op),
1755 tensorflow_graph);
1756 } else if (src_op.type == OperatorType::kRange) {
1757 ConvertRangeOperator(model, static_cast<const RangeOperator&>(src_op),
1758 tensorflow_graph);
1759 } else if (src_op.type == OperatorType::kStack) {
1760 ConvertStackOperator(model, static_cast<const StackOperator&>(src_op),
1761 tensorflow_graph);
1762 } else if (src_op.type == OperatorType::kFill) {
1763 ConvertFillOperator(model, static_cast<const FillOperator&>(src_op),
1764 tensorflow_graph);
1765 } else if (src_op.type == OperatorType::kFloorDiv) {
1766 ConvertFloorDivOperator(model, static_cast<const FloorDivOperator&>(src_op),
1767 tensorflow_graph);
1768 } else if (src_op.type == OperatorType::kExpandDims) {
1769 ConvertExpandDimsOperator(model,
1770 static_cast<const ExpandDimsOperator&>(src_op),
1771 tensorflow_graph);
1772 } else {
1773 LOG(FATAL) << "Unhandled operator type " << OperatorTypeName(src_op.type);
1774 }
1775 }
1776
AddPlaceholder(const string & name,ArrayDataType type,GraphDef * tensorflow_graph)1777 void AddPlaceholder(const string& name, ArrayDataType type,
1778 GraphDef* tensorflow_graph) {
1779 auto* placeholder = tensorflow_graph->add_node();
1780 placeholder->set_op("Placeholder");
1781 switch (type) {
1782 case ArrayDataType::kBool:
1783 (*placeholder->mutable_attr())["dtype"].set_type(DT_BOOL);
1784 break;
1785 case ArrayDataType::kFloat:
1786 (*placeholder->mutable_attr())["dtype"].set_type(DT_FLOAT);
1787 break;
1788 case ArrayDataType::kUint8:
1789 (*placeholder->mutable_attr())["dtype"].set_type(DT_UINT8);
1790 break;
1791 case ArrayDataType::kInt32:
1792 (*placeholder->mutable_attr())["dtype"].set_type(DT_INT32);
1793 break;
1794 case ArrayDataType::kInt64:
1795 (*placeholder->mutable_attr())["dtype"].set_type(DT_INT64);
1796 break;
1797 default:
1798 LOG(FATAL) << "Unexpected data type in array \"" << name << "\"";
1799 }
1800 placeholder->set_name(name);
1801 }
1802
AddPlaceholderForRNNState(const Model & model,const string & name,int size,GraphDef * tensorflow_graph)1803 void AddPlaceholderForRNNState(const Model& model, const string& name, int size,
1804 GraphDef* tensorflow_graph) {
1805 auto* placeholder = tensorflow_graph->add_node();
1806 placeholder->set_op("Placeholder");
1807 placeholder->set_name(name);
1808 (*placeholder->mutable_attr())["dtype"].set_type(DT_FLOAT);
1809
1810 auto* shape = (*placeholder->mutable_attr())["shape"].mutable_shape();
1811 const auto& state_array = model.GetArray(name);
1812 if (state_array.has_shape()) {
1813 const auto& state_shape = state_array.shape();
1814 const int kDims = state_shape.dimensions_count();
1815 for (int i = 0; i < kDims; ++i) {
1816 shape->add_dim()->set_size(state_shape.dims(i));
1817 }
1818 } else {
1819 shape->add_dim()->set_size(1);
1820 shape->add_dim()->set_size(size);
1821 }
1822 }
1823
ExportTensorFlowGraphDefImplementation(const Model & model,GraphDef * tensorflow_graph)1824 void ExportTensorFlowGraphDefImplementation(const Model& model,
1825 GraphDef* tensorflow_graph) {
1826 for (const auto& input_array : model.flags.input_arrays()) {
1827 AddPlaceholder(input_array.name(),
1828 model.GetArray(input_array.name()).data_type,
1829 tensorflow_graph);
1830 }
1831 for (const auto& rnn_state : model.flags.rnn_states()) {
1832 AddPlaceholderForRNNState(model, rnn_state.state_array(), rnn_state.size(),
1833 tensorflow_graph);
1834 }
1835 for (const auto& op : model.operators) {
1836 ConvertOperator(model, *op, tensorflow_graph);
1837 }
1838 // Generically export arrays that haven't been exported already
1839 // by the above operators export. It's important that this comes
1840 // after, as some operators need to export arrays that they reference
1841 // in a specific way, rather than in the generic way done below.
1842 for (const auto& array_pair : model.GetArrayMap()) {
1843 const string& array_name = array_pair.first;
1844 const auto& array = *array_pair.second;
1845 if (array.buffer) {
1846 switch (array.data_type) {
1847 case ArrayDataType::kFloat:
1848 ConvertFloatTensorConst(model, array_name, tensorflow_graph);
1849 break;
1850 case ArrayDataType::kInt32:
1851 ConvertIntTensorConst(model, array_name, tensorflow_graph);
1852 break;
1853 default:
1854 break;
1855 }
1856 }
1857 }
1858 }
1859 } // namespace
1860
EncodeConstantArraysMinMaxByWrappingThemInFakeQuantNodes(Model * model)1861 void EncodeConstantArraysMinMaxByWrappingThemInFakeQuantNodes(Model* model) {
1862 for (const auto& array_kv : model->GetArrayMap()) {
1863 const string& array_name = array_kv.first;
1864 Array& array = *array_kv.second;
1865 if (!array.buffer || !array.minmax) {
1866 continue;
1867 }
1868 const string& wrapped_array_name =
1869 AvailableArrayName(*model, array_name + "/data");
1870 Array& wrapped_array = model->GetOrCreateArray(wrapped_array_name);
1871 wrapped_array.data_type = array.data_type;
1872 wrapped_array.copy_shape(array.shape());
1873 wrapped_array.buffer = std::move(array.buffer);
1874 FakeQuantOperator* fakequant_op = new FakeQuantOperator;
1875 fakequant_op->inputs = {wrapped_array_name};
1876 fakequant_op->outputs = {array_name};
1877 fakequant_op->minmax.reset(new MinMax);
1878 *fakequant_op->minmax = *array.minmax;
1879 const auto& it = FindOpWithInput(*model, array_name);
1880 model->operators.emplace(it, fakequant_op);
1881 }
1882 CheckInvariants(*model);
1883 }
1884
ExportTensorFlowGraphDef(const Model & model,string * output_file_contents)1885 void ExportTensorFlowGraphDef(const Model& model,
1886 string* output_file_contents) {
1887 CHECK(output_file_contents->empty());
1888 GraphDef tensorflow_graph;
1889 ExportTensorFlowGraphDefImplementation(model, &tensorflow_graph);
1890 LogDumpGraphDef(kLogLevelModelChanged, "AT EXPORT", tensorflow_graph);
1891 CHECK(tensorflow_graph.SerializeToString(output_file_contents));
1892 }
1893 } // namespace toco
1894