1 /* Copyright 2017 The TensorFlow Authors. All Rights Reserved.
2
3 Licensed under the Apache License, Version 2.0 (the "License");
4 you may not use this file except in compliance with the License.
5 You may obtain a copy of the License at
6
7 http://www.apache.org/licenses/LICENSE-2.0
8
9 Unless required by applicable law or agreed to in writing, software
10 distributed under the License is distributed on an "AS IS" BASIS,
11 WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
12 See the License for the specific language governing permissions and
13 limitations under the License.
14 ==============================================================================*/
15 #include "tensorflow/lite/toco/tflite/export.h"
16
17 #include <gmock/gmock.h>
18 #include <gtest/gtest.h>
19 #include "tensorflow/lite/schema/schema_generated.h"
20 #include "tensorflow/lite/toco/tflite/builtin_operator.h"
21 #include "tensorflow/lite/toco/tflite/operator.h"
22 #include "tensorflow/lite/toco/tflite/types.h"
23 #include "tensorflow/core/framework/node_def.pb.h"
24
25 namespace toco {
26 namespace tflite {
27 namespace {
28
29 using ::testing::ElementsAre;
30
31 class ExportTest : public ::testing::Test {
32 protected:
ResetOperators()33 void ResetOperators() { input_model_.operators.clear(); }
AddTensorsByName(std::initializer_list<string> names)34 void AddTensorsByName(std::initializer_list<string> names) {
35 for (const string& name : names) {
36 input_model_.GetOrCreateArray(name);
37 }
38 }
AddOperatorsByName(std::initializer_list<string> names)39 void AddOperatorsByName(std::initializer_list<string> names) {
40 for (const string& name : names) {
41 if (name == "Conv") {
42 auto* op = new ConvOperator;
43 op->padding.type = PaddingType::kSame;
44 op->inputs = {"input", "filter"};
45 op->outputs = {"output"};
46 Array& input_array = input_model_.GetOrCreateArray(op->inputs[0]);
47 Array& filter_array = input_model_.GetOrCreateArray(op->inputs[1]);
48 Array& output_array = input_model_.GetOrCreateArray(op->outputs[0]);
49 input_array.data_type = ArrayDataType::kFloat;
50 filter_array.data_type = ArrayDataType::kFloat;
51 output_array.data_type = ArrayDataType::kFloat;
52 input_model_.operators.emplace_back(op);
53 } else if (name == "Add") {
54 auto* op = new AddOperator;
55 op->inputs = {"input1", "input2"};
56 op->outputs = {"output"};
57 Array& input1_array = input_model_.GetOrCreateArray(op->inputs[0]);
58 Array& input2_array = input_model_.GetOrCreateArray(op->inputs[1]);
59 Array& output_array = input_model_.GetOrCreateArray(op->outputs[0]);
60 input1_array.data_type = ArrayDataType::kFloat;
61 input2_array.data_type = ArrayDataType::kFloat;
62 output_array.data_type = ArrayDataType::kFloat;
63 input_model_.operators.emplace_back(op);
64 } else if (name == "Sub") {
65 auto* op = new SubOperator;
66 op->inputs = {"input1", "input2"};
67 op->outputs = {"output"};
68 Array& input1_array = input_model_.GetOrCreateArray(op->inputs[0]);
69 Array& input2_array = input_model_.GetOrCreateArray(op->inputs[1]);
70 Array& output_array = input_model_.GetOrCreateArray(op->outputs[0]);
71 input1_array.data_type = ArrayDataType::kFloat;
72 input2_array.data_type = ArrayDataType::kFloat;
73 output_array.data_type = ArrayDataType::kFloat;
74 input_model_.operators.emplace_back(op);
75 } else if (name == "Assert") {
76 auto* op = new TensorFlowAssertOperator;
77
78 // Even though assert is known to TOCO, it doesn't have a tflite
79 // serializer, so it has to be exported as a custom op. If we attach a
80 // NodeDef to it, however, it will be exported as a flex op instead.
81 ::tensorflow::NodeDef node_def;
82 node_def.set_name("Assert");
83 node_def.set_op("Assert");
84 node_def.SerializeToString(&op->tensorflow_node_def);
85
86 input_model_.operators.emplace_back(op);
87 } else {
88 auto* op = new TensorFlowUnsupportedOperator;
89 op->tensorflow_op = name;
90 input_model_.operators.emplace_back(op);
91 }
92 }
93 }
94
BuildQuantizableTestModel()95 void BuildQuantizableTestModel() {
96 input_model_.GetOrCreateArray("inputs");
97 Array& weight_array = input_model_.GetOrCreateArray("weights");
98
99 // Make the buffer large enough for QuantizeWeights transformation to take
100 // effect.
101 int buf_size = 1296;
102 auto weight_buf = absl::make_unique<float[]>(buf_size);
103 for (int i = 0; i < buf_size; i++) {
104 // Fill the array with some garbage values.
105 weight_buf[i] = static_cast<float>(i % 128);
106 }
107
108 weight_array.data_type = ArrayDataType::kFloat;
109
110 // Initialize shape for the input array.
111 Shape* weight_array_shape = weight_array.mutable_shape();
112 std::vector<int>* weight_array_shape_dim =
113 weight_array_shape->mutable_dims();
114 weight_array_shape_dim->resize(4, 6);
115 auto& weight_array_buffer =
116 weight_array.GetMutableBuffer<ArrayDataType::kFloat>();
117 weight_array_buffer.data.resize(buf_size);
118 float* buf_ptr =
119 weight_array.GetMutableBuffer<ArrayDataType::kFloat>().data.data();
120 std::copy(weight_buf.get(), weight_buf.get() + buf_size, buf_ptr);
121
122 {
123 auto* op = new ConvOperator;
124 op->padding.type = PaddingType::kSame;
125 op->inputs = {"inputs", "weights"};
126 op->outputs = {"output"};
127 Array& input_array = input_model_.GetArray(op->inputs[0]);
128 Array& filter_array = input_model_.GetArray(op->inputs[1]);
129 Array& output_array = input_model_.GetOrCreateArray(op->outputs[0]);
130 input_array.data_type = ArrayDataType::kFloat;
131 filter_array.data_type = ArrayDataType::kFloat;
132 output_array.data_type = ArrayDataType::kFloat;
133 input_model_.operators.emplace_back(op);
134 }
135 {
136 auto* op = new AddOperator;
137 op->inputs = {"input1", "input2"};
138 op->outputs = {"output"};
139 Array& input1_array = input_model_.GetOrCreateArray(op->inputs[0]);
140 Array& input2_array = input_model_.GetOrCreateArray(op->inputs[1]);
141 Array& output_array = input_model_.GetOrCreateArray(op->outputs[0]);
142 input1_array.data_type = ArrayDataType::kFloat;
143 input2_array.data_type = ArrayDataType::kFloat;
144 output_array.data_type = ArrayDataType::kFloat;
145 input_model_.operators.emplace_back(op);
146 }
147 }
148
ExportAndSummarizeOperators(const ExportParams & params)149 std::vector<string> ExportAndSummarizeOperators(const ExportParams& params) {
150 std::vector<string> names;
151
152 string result;
153 auto status = Export(input_model_, &result, params);
154 if (!status.ok()) {
155 LOG(INFO) << status.error_message();
156 return names;
157 }
158
159 auto* model = ::tflite::GetModel(result.data());
160
161 for (const ::tflite::OperatorCode* opcode : *model->operator_codes()) {
162 if (opcode->builtin_code() != ::tflite::BuiltinOperator_CUSTOM) {
163 names.push_back(string("builtin:") + ::tflite::EnumNameBuiltinOperator(
164 opcode->builtin_code()));
165 } else {
166 names.push_back(string("custom:") + opcode->custom_code()->c_str());
167 }
168 }
169
170 return names;
171 }
172
ExportAndGetOperatorIndices(const ExportParams & params)173 std::vector<uint32_t> ExportAndGetOperatorIndices(
174 const ExportParams& params) {
175 std::vector<uint32_t> indices;
176
177 string result;
178 if (!Export(input_model_, &result, params).ok()) return indices;
179 auto* model = ::tflite::GetModel(result.data());
180
181 auto operators = (*model->subgraphs())[0]->operators();
182 for (const auto* op : *operators) {
183 indices.push_back(op->opcode_index());
184 }
185 return indices;
186 }
187
188 Model input_model_;
189 };
190
// LoadTensorsMap must assign consecutive indices to the model's arrays.
TEST_F(ExportTest, LoadTensorsMap) {
  AddTensorsByName({"tensor_one", "tensor_two"});

  details::TensorsMap tensors;
  details::LoadTensorsMap(input_model_, &tensors);

  // Use at() rather than operator[]: operator[] would silently insert a
  // zero-valued entry for a missing name, which makes the index-0 check pass
  // vacuously.
  EXPECT_EQ(2, tensors.size());
  EXPECT_EQ(0, tensors.at("tensor_one"));
  EXPECT_EQ(1, tensors.at("tensor_two"));
}
199
// LoadOperatorsMap must give each distinct (type, custom code, version) key
// its own index, independent of the order in which operators were added.
TEST_F(ExportTest, LoadOperatorsMap) {
  AddOperatorsByName({"Conv", "Add", "MyCrazyOp", "Sub"});

  details::OperatorsMap operators;
  const auto ops_by_type = BuildOperatorByTypeMap();
  details::LoadOperatorsMap(input_model_, &operators, ops_by_type, false);

  // Use at() rather than operator[]: operator[] would silently insert a
  // zero-valued entry for a missing key, which makes the index-0 check pass
  // vacuously.
  EXPECT_EQ(4, operators.size());
  EXPECT_EQ(0, operators.at(
                   details::OperatorKey(::tflite::BuiltinOperator_ADD, "", 1)));
  EXPECT_EQ(1, operators.at(details::OperatorKey(
                   ::tflite::BuiltinOperator_CONV_2D, "", 1)));
  EXPECT_EQ(2, operators.at(details::OperatorKey(
                   ::tflite::BuiltinOperator_CUSTOM, "MyCrazyOp", 1)));
  EXPECT_EQ(3, operators.at(
                   details::OperatorKey(::tflite::BuiltinOperator_SUB, "", 1)));
}
215
// Exports a mix of builtin and custom operators and checks both the operator
// code table and each operator's index into it.
TEST_F(ExportTest, Export) {
  AddOperatorsByName({"Conv", "Add", "MyCrazyOp", "Sub"});

  // Custom ops allowed; no select-TF (Flex) ops and no weight quantization.
  ExportParams params;
  params.quantize_weights = false;
  params.enable_select_tf_ops = false;
  params.allow_custom_ops = true;

  // The operator-code table is expected in the order below, not in the order
  // the operators were added; each operator points back into it by index.
  const std::vector<string> opcode_summary =
      ExportAndSummarizeOperators(params);
  EXPECT_THAT(opcode_summary,
              ElementsAre("builtin:ADD", "builtin:CONV_2D", "custom:MyCrazyOp",
                          "builtin:SUB"));

  const std::vector<uint32_t> opcode_indices =
      ExportAndGetOperatorIndices(params);
  EXPECT_THAT(opcode_indices, ElementsAre(1, 0, 2, 3));
}
229
// Sanity check for the quantize_weights parameter: the same model exported
// with weight quantization enabled must yield a smaller flatbuffer.
TEST_F(ExportTest, QuantizeWeights) {
  ExportParams params;
  params.allow_custom_ops = true;
  params.enable_select_tf_ops = false;

  BuildQuantizableTestModel();
  string unquantized_result;
  params.quantize_weights = false;
  ASSERT_TRUE(Export(input_model_, &unquantized_result, params).ok());

  // BuildQuantizableTestModel() appends operators, so drop the ones added
  // above. Otherwise the second model would hold twice as many operators and
  // the size comparison would not be like-for-like.
  ResetOperators();
  BuildQuantizableTestModel();
  string quantized_result;
  params.quantize_weights = true;
  ASSERT_TRUE(Export(input_model_, &quantized_result, params).ok());

  // The quantized model should be smaller.
  EXPECT_LT(quantized_result.size(), unquantized_result.size());
}
243
244 class OpSetsTest : public ExportTest {
245 public:
246 enum OpSet { kTfLiteBuiltins, kSelectTfOps, kCustomOps };
247
SetAllowedOpSets(std::initializer_list<OpSet> sets)248 void SetAllowedOpSets(std::initializer_list<OpSet> sets) {
249 import_all_ops_as_unsupported_ = true;
250 params_.allow_custom_ops = false;
251 params_.enable_select_tf_ops = false;
252 params_.quantize_weights = false;
253
254 for (OpSet i : sets) {
255 switch (i) {
256 case kTfLiteBuiltins:
257 import_all_ops_as_unsupported_ = false;
258 break;
259 case kSelectTfOps:
260 params_.enable_select_tf_ops = true;
261 break;
262 case kCustomOps:
263 params_.allow_custom_ops = true;
264 break;
265 }
266 }
267 }
268
ImportExport(std::initializer_list<string> op_names)269 std::vector<string> ImportExport(std::initializer_list<string> op_names) {
270 ResetOperators();
271 if (!import_all_ops_as_unsupported_) {
272 AddOperatorsByName(op_names);
273 } else {
274 for (const string& name : op_names) {
275 auto* op = new TensorFlowUnsupportedOperator;
276 op->tensorflow_op = name;
277 input_model_.operators.emplace_back(op);
278 }
279 }
280 return ExportAndSummarizeOperators(params_);
281 }
282
283 private:
284 bool import_all_ops_as_unsupported_;
285 ExportParams params_;
286 };
287
// Exports restricted to TFLite builtins, first without and then with custom
// ops allowed.
TEST_F(OpSetsTest, BuiltinsOnly) {
  // Flags: --target_op_set=TFLITE_BUILTINS
  // The mixed list is expected to fail to export (empty summary), while a
  // pure-builtin list succeeds.
  SetAllowedOpSets({kTfLiteBuiltins});
  const std::vector<string> mixed_ops =
      ImportExport({"Add", "AdjustHue", "UnrollAndFold", "Assert"});
  EXPECT_THAT(mixed_ops, ElementsAre());
  const std::vector<string> builtin_ops = ImportExport({"Add"});
  EXPECT_THAT(builtin_ops, ElementsAre("builtin:ADD"));

  // Flags: --target_op_set=TFLITE_BUILTINS --allow_custom_ops
  // Ops without a builtin now fall back to custom ops.
  SetAllowedOpSets({kTfLiteBuiltins, kCustomOps});
  const std::vector<string> with_custom =
      ImportExport({"Add", "AdjustHue", "UnrollAndFold", "Assert"});
  EXPECT_THAT(with_custom,
              ElementsAre("builtin:ADD", "custom:AdjustHue", "custom:Assert",
                          "custom:UnrollAndFold"));
}
301
// Exports restricted to select TF (Flex) ops, first without and then with
// custom ops allowed.
TEST_F(OpSetsTest, TfSelectOnly) {
  // Flags: --target_op_set=SELECT_TF_OPS
  // The mixed list is expected to fail to export (empty summary); a single
  // TF op is exported under its "Flex" name.
  SetAllowedOpSets({kSelectTfOps});
  const std::vector<string> mixed_ops = ImportExport(
      {"Add", "AdjustHue", "RandomUniform", "UnrollAndFold", "Assert"});
  EXPECT_THAT(mixed_ops, ElementsAre());
  const std::vector<string> flex_ops = ImportExport({"Add"});
  EXPECT_THAT(flex_ops, ElementsAre("custom:FlexAdd"));

  // Flags: --target_op_set=SELECT_TF_OPS --allow_custom_ops
  SetAllowedOpSets({kSelectTfOps, kCustomOps});
  const std::vector<string> with_custom = ImportExport(
      {"Add", "AdjustHue", "RandomUniform", "UnrollAndFold", "Assert"});
  EXPECT_THAT(with_custom,
              ElementsAre("custom:AdjustHue", "custom:FlexAdd",
                          "custom:FlexAssert", "custom:FlexRandomUniform",
                          "custom:UnrollAndFold"));
}
318
// Exports allowing both TFLite builtins and select TF (Flex) ops, first
// without and then with custom ops allowed.
TEST_F(OpSetsTest, BuiltinsAndTfSelect) {
  // Flags: --target_op_set=TFLITE_BUILTINS,SELECT_TF_OPS
  SetAllowedOpSets({kTfLiteBuiltins, kSelectTfOps});
  const std::vector<string> mixed_ops =
      ImportExport({"Add", "AdjustHue", "UnrollAndFold", "Assert"});
  EXPECT_THAT(mixed_ops, ElementsAre());
  const std::vector<string> builtin_and_flex =
      ImportExport({"Add", "RandomUniform"});
  EXPECT_THAT(builtin_and_flex,
              ElementsAre("builtin:ADD", "custom:FlexRandomUniform"));

  // Flags: --target_op_set=TFLITE_BUILTINS,SELECT_TF_OPS --allow_custom_ops
  SetAllowedOpSets({kTfLiteBuiltins, kSelectTfOps, kCustomOps});
  const std::vector<string> with_custom = ImportExport(
      {"Add", "AdjustHue", "RandomUniform", "UnrollAndFold", "Assert"});
  EXPECT_THAT(with_custom,
              ElementsAre("builtin:ADD", "custom:AdjustHue",
                          "custom:FlexAssert", "custom:FlexRandomUniform",
                          "custom:UnrollAndFold"));
}
335
// This test is based on a hypothetical scenario in which dilation is supported
// only in Conv version 2. Toco therefore populates version=1 when the dilation
// parameters are all 1, and version=2 otherwise.
339 class FakeConvolutionOperator
340 : public BuiltinOperator<ConvOperator, ::tflite::Conv2DOptions,
341 ::tflite::BuiltinOptions_Conv2DOptions> {
342 public:
FakeConvolutionOperator()343 FakeConvolutionOperator()
344 : BuiltinOperator(::tflite::BuiltinOperator_CONV_2D,
345 OperatorType::kConv) {}
346
347 // Returning the op version according to the op parameters.
GetVersion(const OperatorSignature & op_signature) const348 int GetVersion(const OperatorSignature& op_signature) const override {
349 const TocoOperator& conv_op =
350 static_cast<const TocoOperator&>(*op_signature.op);
351 if (conv_op.dilation_width_factor != 1 ||
352 conv_op.dilation_height_factor != 1) {
353 // Version 2 if dilation is used.
354 return 2;
355 }
356 return 1;
357 }
358
359 // Note: The read / write code doesn't need to be changed if we stick with
360 // the restrictions:
361 // * Only adding parameters at the bottom of the Flatbuffer tables.
362 // * When the default value of parameters are used, the op works consistently
363 // with the previous version.
WriteOptions(const TocoOperator & op,flatbuffers::FlatBufferBuilder * builder) const364 flatbuffers::Offset<TfLiteOptions> WriteOptions(
365 const TocoOperator& op,
366 flatbuffers::FlatBufferBuilder* builder) const override {
367 auto padding = Padding::Serialize(op.padding.type);
368 auto activation_function =
369 ActivationFunction::Serialize(op.fused_activation_function);
370 return ::tflite::CreateConv2DOptions(*builder, padding, op.stride_width,
371 op.stride_height, activation_function,
372 op.dilation_width_factor,
373 op.dilation_height_factor);
374 }
375
ReadOptions(const TfLiteOptions & options,TocoOperator * op) const376 void ReadOptions(const TfLiteOptions& options,
377 TocoOperator* op) const override {
378 op->padding.type = Padding::Deserialize(options.padding());
379 op->stride_width = options.stride_w();
380 op->stride_height = options.stride_h();
381 op->dilation_width_factor = options.dilation_w_factor();
382 op->dilation_height_factor = options.dilation_h_factor();
383 op->fused_activation_function =
384 ActivationFunction::Deserialize(options.fused_activation_function());
385 }
386 };
387
388 class VersionedOpExportTest : public ::testing::Test {
389 protected:
SetUp()390 void SetUp() override {
391 input_model_.GetOrCreateArray("input");
392 input_model_.GetOrCreateArray("filter");
393 input_model_.GetOrCreateArray("output");
394 }
AddConvOp(bool use_dialation)395 void AddConvOp(bool use_dialation) {
396 {
397 auto* op = new ConvOperator;
398 op->inputs.push_back("input");
399 op->inputs.push_back("filter");
400 op->inputs.push_back("output");
401
402 op->padding.type = PaddingType::kSame;
403 op->stride_width = 1;
404 op->stride_height = 1;
405 if (use_dialation) {
406 op->dilation_width_factor = 2;
407 op->dilation_height_factor = 2;
408 } else {
409 op->dilation_width_factor = 1;
410 op->dilation_height_factor = 1;
411 }
412 input_model_.operators.emplace_back(op);
413 }
414 }
415
416 std::map<OperatorType, std::unique_ptr<BaseOperator>>
BuildFakeOperatorByTypeMap()417 BuildFakeOperatorByTypeMap() {
418 std::map<OperatorType, std::unique_ptr<BaseOperator>> result;
419 result[OperatorType::kConv] =
420 std::unique_ptr<BaseOperator>(new FakeConvolutionOperator);
421 return result;
422 }
423
424 Model input_model_;
425 };
426
// An undilated Conv must map to a single CONV_2D version-1 key at index 0.
TEST_F(VersionedOpExportTest, LoadOperatorsMapWithOpV1) {
  AddConvOp(false);

  const auto ops_by_type = BuildFakeOperatorByTypeMap();
  details::OperatorsMap operators;
  details::LoadOperatorsMap(input_model_, &operators, ops_by_type, false);

  const details::OperatorKey conv_v1(::tflite::BuiltinOperator_CONV_2D, "", 1);
  EXPECT_EQ(1, operators.size());
  EXPECT_EQ(0, operators.at(conv_v1));
}
438
// A dilated Conv must map to a single CONV_2D version-2 key at index 0.
TEST_F(VersionedOpExportTest, LoadOperatorsMapWithOpV2) {
  AddConvOp(true);

  const auto ops_by_type = BuildFakeOperatorByTypeMap();
  details::OperatorsMap operators;
  details::LoadOperatorsMap(input_model_, &operators, ops_by_type, false);

  const details::OperatorKey conv_v2(::tflite::BuiltinOperator_CONV_2D, "", 2);
  EXPECT_EQ(1, operators.size());
  EXPECT_EQ(0, operators.at(conv_v2));
}
450
// One undilated and one dilated Conv must yield two keys, one per version.
TEST_F(VersionedOpExportTest, LoadOperatorsMapWithBothVersions) {
  AddConvOp(false);
  AddConvOp(true);

  const auto ops_by_type = BuildFakeOperatorByTypeMap();
  details::OperatorsMap operators;
  details::LoadOperatorsMap(input_model_, &operators, ops_by_type, false);

  const details::OperatorKey conv_v1(::tflite::BuiltinOperator_CONV_2D, "", 1);
  const details::OperatorKey conv_v2(::tflite::BuiltinOperator_CONV_2D, "", 2);
  EXPECT_EQ(2, operators.size());
  EXPECT_EQ(0, operators.at(conv_v1));
  EXPECT_EQ(1, operators.at(conv_v2));
}
465
// Exporting two Conv ops of different versions must produce two CONV_2D
// operator codes, and each operator must reference the code of its version.
TEST_F(VersionedOpExportTest, Export) {
  AddConvOp(false);
  AddConvOp(true);

  string result;
  const auto ops_by_type = BuildFakeOperatorByTypeMap();
  Export(input_model_, true, false, &result, ops_by_type);
  // Parsing an empty buffer with GetModel() would read out of bounds.
  ASSERT_FALSE(result.empty());

  auto* model = ::tflite::GetModel(result.data());
  auto operator_codes = model->operator_codes();
  ASSERT_NE(nullptr, operator_codes);

  // Verify that 2 operator codes are populated. Both are CONV_2D but with
  // different versions. ASSERT (not EXPECT) keeps the indexing below in
  // bounds if the count is wrong.
  ASSERT_EQ(2, operator_codes->size());
  EXPECT_EQ(::tflite::BuiltinOperator_CONV_2D,
            (*operator_codes)[0]->builtin_code());
  EXPECT_EQ(1, (*operator_codes)[0]->version());
  EXPECT_EQ(::tflite::BuiltinOperator_CONV_2D,
            (*operator_codes)[1]->builtin_code());
  EXPECT_EQ(2, (*operator_codes)[1]->version());

  // Verify that the 2 operators point to the correct operator-code indices.
  ASSERT_NE(nullptr, model->subgraphs());
  ASSERT_GT(model->subgraphs()->size(), 0u);
  auto operators = (*model->subgraphs())[0]->operators();
  ASSERT_NE(nullptr, operators);
  ASSERT_EQ(2, operators->size());
  EXPECT_EQ(0, (*operators)[0]->opcode_index());
  EXPECT_EQ(1, (*operators)[1]->opcode_index());
}
494
// A float Conv must produce a CONV_2D version-1 key with no custom code.
TEST(OperatorKeyTest, TestBuiltinOp) {
  Model model;
  auto op = absl::make_unique<ConvOperator>();

  // A plain float operation: every input and output array is float.
  op->inputs = {"input", "filter"};
  op->outputs = {"output"};
  for (const string& array_name :
       {op->inputs[0], op->inputs[1], op->outputs[0]}) {
    model.GetOrCreateArray(array_name).data_type = ArrayDataType::kFloat;
  }

  const auto ops_by_type = BuildOperatorByTypeMap();
  const toco::OperatorSignature signature = {op.get(), &model};
  const auto key = details::OperatorKey(signature, ops_by_type, false);

  EXPECT_EQ(key.type(), ::tflite::BuiltinOperator_CONV_2D);
  EXPECT_EQ(key.custom_code(), "");
  EXPECT_EQ(key.version(), 1);
}
517
// A Dequantize with a signed int8 input is expected to produce version 2.
TEST(OperatorKeyTest, TestBuiltinOpWithVersionedInputTypes) {
  Model model;
  auto op = absl::make_unique<DequantizeOperator>();

  op->inputs = {"input"};
  op->outputs = {"output"};
  Array& in = model.GetOrCreateArray(op->inputs[0]);
  Array& out = model.GetOrCreateArray(op->outputs[0]);
  in.data_type = ArrayDataType::kInt8;
  out.data_type = ArrayDataType::kFloat;

  const auto ops_by_type = BuildOperatorByTypeMap();

  // Signed int8 -> float dequantize.
  const toco::OperatorSignature signature = {op.get(), &model};
  const auto key = details::OperatorKey(signature, ops_by_type, false);

  EXPECT_EQ(key.type(), ::tflite::BuiltinOperator_DEQUANTIZE);
  EXPECT_EQ(key.custom_code(), "");
  EXPECT_EQ(key.version(), 2);
}
539
// An op unknown to TFLite becomes a CUSTOM key named after the TF op.
TEST(OperatorKeyTest, TestCustomOp) {
  Model model;
  auto op = absl::make_unique<TensorFlowUnsupportedOperator>();
  op->tensorflow_op = "MyCrazyCustomOp";

  const auto ops_by_type = BuildOperatorByTypeMap();
  const toco::OperatorSignature signature = {op.get(), &model};
  const auto key = details::OperatorKey(signature, ops_by_type, false);

  EXPECT_EQ(key.type(), ::tflite::BuiltinOperator_CUSTOM);
  EXPECT_EQ(key.custom_code(), "MyCrazyCustomOp");
  EXPECT_EQ(key.version(), 1);
}
553
// A TF op becomes a plain custom op unless Flex is enabled, in which case its
// custom code gets the "Flex" prefix.
TEST(OperatorKeyTest, TestFlexOp) {
  Model model;
  auto op = absl::make_unique<TensorFlowUnsupportedOperator>();
  op->tensorflow_op = "BatchMatMul";

  const auto ops_by_type = BuildOperatorByTypeMap();

  // With `allow_flex_op` false it must not be converted to a Flex op.
  {
    const toco::OperatorSignature signature = {op.get(), &model};
    const auto plain_key = details::OperatorKey(signature, ops_by_type, false);
    EXPECT_EQ(plain_key.type(), ::tflite::BuiltinOperator_CUSTOM);
    EXPECT_EQ(plain_key.custom_code(), "BatchMatMul");
    EXPECT_EQ(plain_key.version(), 1);
    EXPECT_TRUE(plain_key.is_custom_op());
    EXPECT_FALSE(plain_key.is_flex_op());
  }

  // With `allow_flex_op` true the custom code is prefixed by "Flex" and
  // `is_flex_op` is set.
  {
    const toco::OperatorSignature signature = {op.get(), &model};
    const auto flex_key = details::OperatorKey(signature, ops_by_type, true);
    EXPECT_EQ(flex_key.type(), ::tflite::BuiltinOperator_CUSTOM);
    EXPECT_EQ(flex_key.custom_code(), "FlexBatchMatMul");
    EXPECT_EQ(flex_key.version(), 1);
    EXPECT_FALSE(flex_key.is_custom_op());
    EXPECT_TRUE(flex_key.is_flex_op());
  }
}
583
// A TF control-flow op ("Merge") is exported as a Flex op but flagged as an
// unsupported Flex op.
TEST(OperatorKeyTest, TestFlexWithControlFlowOp) {
  Model model;
  auto op = absl::make_unique<TensorFlowUnsupportedOperator>();
  op->tensorflow_op = "Merge";

  const auto ops_by_type = BuildOperatorByTypeMap();
  const toco::OperatorSignature signature = {op.get(), &model};
  const auto key = details::OperatorKey(signature, ops_by_type, true);

  EXPECT_EQ(key.type(), ::tflite::BuiltinOperator_CUSTOM);
  EXPECT_EQ(key.custom_code(), "FlexMerge");
  EXPECT_EQ(key.version(), 1);
  EXPECT_FALSE(key.is_custom_op());
  EXPECT_TRUE(key.is_flex_op());
  // Control-flow ops must be marked as unsupported.
  EXPECT_TRUE(key.is_unsupported_flex_op());
}
601
// HashTableV2 is not currently treated as a Flex op even when Flex is
// enabled; it stays a plain custom op.
TEST(OperatorKeyTest, TestFlexWithUnsupportedOp) {
  Model model;
  auto op = absl::make_unique<TensorFlowUnsupportedOperator>();
  op->tensorflow_op = "HashTableV2";

  const auto ops_by_type = BuildOperatorByTypeMap();
  const toco::OperatorSignature signature = {op.get(), &model};
  const auto key = details::OperatorKey(signature, ops_by_type, true);

  EXPECT_EQ(key.type(), ::tflite::BuiltinOperator_CUSTOM);
  EXPECT_EQ(key.custom_code(), "HashTableV2");
  EXPECT_EQ(key.version(), 1);
  // HashTableV2 is not yet in the allowlist of Flex ops, so it is exported as
  // a regular custom op. Once it is added, these expectations will need to
  // change, because the op is explicitly marked unsupported for lack of asset
  // support.
  EXPECT_FALSE(key.is_flex_op());
  EXPECT_FALSE(key.is_unsupported_flex_op());
}
620
// An op known to Toco but without a TFLite serializer (TensorFlowAssert)
// falls back to a custom op, or to a Flex op once it carries a NodeDef.
TEST(OperatorKeyTest, TestFlexWithPartiallySupportedOps) {
  Model model;
  // TODO(ycling): The test will be broken if TensorFlowAssert is implemented in
  // TFLite. Find a more robust way to test the fallback logic.
  auto op = absl::make_unique<TensorFlowAssertOperator>();

  const auto ops_by_type = BuildOperatorByTypeMap();

  // Without a NodeDef on the Toco op, a regular custom op is exported.
  {
    const toco::OperatorSignature signature = {op.get(), &model};
    const auto custom_key = details::OperatorKey(signature, ops_by_type, true);
    EXPECT_EQ(custom_key.type(), ::tflite::BuiltinOperator_CUSTOM);
    EXPECT_EQ(custom_key.custom_code(), "Assert");
    EXPECT_EQ(custom_key.version(), 1);
    EXPECT_TRUE(custom_key.is_custom_op());
    EXPECT_FALSE(custom_key.is_flex_op());
  }

  // Attach a serialized NodeDef to the op.
  ::tensorflow::NodeDef assert_def;
  assert_def.set_name("TensorFlowAssert");
  assert_def.set_op("TensorFlowAssert");
  assert_def.SerializeToString(&op->tensorflow_node_def);

  // With the NodeDef retained, a Flex op is exported instead.
  {
    const toco::OperatorSignature signature = {op.get(), &model};
    const auto flex_key = details::OperatorKey(signature, ops_by_type, true);
    EXPECT_EQ(flex_key.type(), ::tflite::BuiltinOperator_CUSTOM);
    EXPECT_EQ(flex_key.custom_code(), "FlexAssert");
    EXPECT_EQ(flex_key.version(), 1);
    EXPECT_FALSE(flex_key.is_custom_op());
    EXPECT_TRUE(flex_key.is_flex_op());
  }
}
658
659 // TODO(ahentz): tests for tensors, inputs, outputs, opcodes and operators.
660
661 } // namespace
662 } // namespace tflite
663 } // namespace toco
664