1 /* Copyright 2017 The TensorFlow Authors. All Rights Reserved.
2 
3 Licensed under the Apache License, Version 2.0 (the "License");
4 you may not use this file except in compliance with the License.
5 You may obtain a copy of the License at
6 
7     http://www.apache.org/licenses/LICENSE-2.0
8 
9 Unless required by applicable law or agreed to in writing, software
10 distributed under the License is distributed on an "AS IS" BASIS,
11 WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
12 See the License for the specific language governing permissions and
13 limitations under the License.
14 ==============================================================================*/
15 #include "tensorflow/lite/toco/tflite/export.h"
16 
17 #include <gmock/gmock.h>
18 #include <gtest/gtest.h>
19 #include "tensorflow/lite/schema/schema_generated.h"
20 #include "tensorflow/lite/toco/tflite/builtin_operator.h"
21 #include "tensorflow/lite/toco/tflite/operator.h"
22 #include "tensorflow/lite/toco/tflite/types.h"
23 #include "tensorflow/core/framework/node_def.pb.h"
24 
25 namespace toco {
26 namespace tflite {
27 namespace {
28 
29 using ::testing::ElementsAre;
30 
31 class ExportTest : public ::testing::Test {
32  protected:
ResetOperators()33   void ResetOperators() { input_model_.operators.clear(); }
AddTensorsByName(std::initializer_list<string> names)34   void AddTensorsByName(std::initializer_list<string> names) {
35     for (const string& name : names) {
36       input_model_.GetOrCreateArray(name);
37     }
38   }
AddOperatorsByName(std::initializer_list<string> names)39   void AddOperatorsByName(std::initializer_list<string> names) {
40     for (const string& name : names) {
41       if (name == "Conv") {
42         auto* op = new ConvOperator;
43         op->padding.type = PaddingType::kSame;
44         op->inputs = {"input", "filter"};
45         op->outputs = {"output"};
46         Array& input_array = input_model_.GetOrCreateArray(op->inputs[0]);
47         Array& filter_array = input_model_.GetOrCreateArray(op->inputs[1]);
48         Array& output_array = input_model_.GetOrCreateArray(op->outputs[0]);
49         input_array.data_type = ArrayDataType::kFloat;
50         filter_array.data_type = ArrayDataType::kFloat;
51         output_array.data_type = ArrayDataType::kFloat;
52         input_model_.operators.emplace_back(op);
53       } else if (name == "Add") {
54         auto* op = new AddOperator;
55         op->inputs = {"input1", "input2"};
56         op->outputs = {"output"};
57         Array& input1_array = input_model_.GetOrCreateArray(op->inputs[0]);
58         Array& input2_array = input_model_.GetOrCreateArray(op->inputs[1]);
59         Array& output_array = input_model_.GetOrCreateArray(op->outputs[0]);
60         input1_array.data_type = ArrayDataType::kFloat;
61         input2_array.data_type = ArrayDataType::kFloat;
62         output_array.data_type = ArrayDataType::kFloat;
63         input_model_.operators.emplace_back(op);
64       } else if (name == "Sub") {
65         auto* op = new SubOperator;
66         op->inputs = {"input1", "input2"};
67         op->outputs = {"output"};
68         Array& input1_array = input_model_.GetOrCreateArray(op->inputs[0]);
69         Array& input2_array = input_model_.GetOrCreateArray(op->inputs[1]);
70         Array& output_array = input_model_.GetOrCreateArray(op->outputs[0]);
71         input1_array.data_type = ArrayDataType::kFloat;
72         input2_array.data_type = ArrayDataType::kFloat;
73         output_array.data_type = ArrayDataType::kFloat;
74         input_model_.operators.emplace_back(op);
75       } else if (name == "Assert") {
76         auto* op = new TensorFlowAssertOperator;
77 
78         // Even though assert is known to TOCO, it doesn't have a tflite
79         // serializer, so it has to be exported as a custom op. If we attach a
80         // NodeDef to it, however, it will be exported as a flex op instead.
81         ::tensorflow::NodeDef node_def;
82         node_def.set_name("Assert");
83         node_def.set_op("Assert");
84         node_def.SerializeToString(&op->tensorflow_node_def);
85 
86         input_model_.operators.emplace_back(op);
87       } else {
88         auto* op = new TensorFlowUnsupportedOperator;
89         op->tensorflow_op = name;
90         input_model_.operators.emplace_back(op);
91       }
92     }
93   }
94 
BuildQuantizableTestModel()95   void BuildQuantizableTestModel() {
96     input_model_.GetOrCreateArray("inputs");
97     Array& weight_array = input_model_.GetOrCreateArray("weights");
98 
99     // Make the buffer large enough for QuantizeWeights transformation to take
100     // effect.
101     int buf_size = 1296;
102     auto weight_buf = absl::make_unique<float[]>(buf_size);
103     for (int i = 0; i < buf_size; i++) {
104       // Fill the array with some garbage values.
105       weight_buf[i] = static_cast<float>(i % 128);
106     }
107 
108     weight_array.data_type = ArrayDataType::kFloat;
109 
110     // Initialize shape for the input array.
111     Shape* weight_array_shape = weight_array.mutable_shape();
112     std::vector<int>* weight_array_shape_dim =
113         weight_array_shape->mutable_dims();
114     weight_array_shape_dim->resize(4, 6);
115     auto& weight_array_buffer =
116         weight_array.GetMutableBuffer<ArrayDataType::kFloat>();
117     weight_array_buffer.data.resize(buf_size);
118     float* buf_ptr =
119         weight_array.GetMutableBuffer<ArrayDataType::kFloat>().data.data();
120     std::copy(weight_buf.get(), weight_buf.get() + buf_size, buf_ptr);
121 
122     {
123       auto* op = new ConvOperator;
124       op->padding.type = PaddingType::kSame;
125       op->inputs = {"inputs", "weights"};
126       op->outputs = {"output"};
127       Array& input_array = input_model_.GetArray(op->inputs[0]);
128       Array& filter_array = input_model_.GetArray(op->inputs[1]);
129       Array& output_array = input_model_.GetOrCreateArray(op->outputs[0]);
130       input_array.data_type = ArrayDataType::kFloat;
131       filter_array.data_type = ArrayDataType::kFloat;
132       output_array.data_type = ArrayDataType::kFloat;
133       input_model_.operators.emplace_back(op);
134     }
135     {
136       auto* op = new AddOperator;
137       op->inputs = {"input1", "input2"};
138       op->outputs = {"output"};
139       Array& input1_array = input_model_.GetOrCreateArray(op->inputs[0]);
140       Array& input2_array = input_model_.GetOrCreateArray(op->inputs[1]);
141       Array& output_array = input_model_.GetOrCreateArray(op->outputs[0]);
142       input1_array.data_type = ArrayDataType::kFloat;
143       input2_array.data_type = ArrayDataType::kFloat;
144       output_array.data_type = ArrayDataType::kFloat;
145       input_model_.operators.emplace_back(op);
146     }
147   }
148 
ExportAndSummarizeOperators(const ExportParams & params)149   std::vector<string> ExportAndSummarizeOperators(const ExportParams& params) {
150     std::vector<string> names;
151 
152     string result;
153     auto status = Export(input_model_, &result, params);
154     if (!status.ok()) {
155       LOG(INFO) << status.error_message();
156       return names;
157     }
158 
159     auto* model = ::tflite::GetModel(result.data());
160 
161     for (const ::tflite::OperatorCode* opcode : *model->operator_codes()) {
162       if (opcode->builtin_code() != ::tflite::BuiltinOperator_CUSTOM) {
163         names.push_back(string("builtin:") + ::tflite::EnumNameBuiltinOperator(
164                                                  opcode->builtin_code()));
165       } else {
166         names.push_back(string("custom:") + opcode->custom_code()->c_str());
167       }
168     }
169 
170     return names;
171   }
172 
ExportAndGetOperatorIndices(const ExportParams & params)173   std::vector<uint32_t> ExportAndGetOperatorIndices(
174       const ExportParams& params) {
175     std::vector<uint32_t> indices;
176 
177     string result;
178     if (!Export(input_model_, &result, params).ok()) return indices;
179     auto* model = ::tflite::GetModel(result.data());
180 
181     auto operators = (*model->subgraphs())[0]->operators();
182     for (const auto* op : *operators) {
183       indices.push_back(op->opcode_index());
184     }
185     return indices;
186   }
187 
188   Model input_model_;
189 };
190 
TEST_F(ExportTest, LoadTensorsMap) {
  AddTensorsByName({"tensor_one", "tensor_two"});

  // Tensors are assigned indices in the order they were added to the model.
  details::TensorsMap tensors;
  details::LoadTensorsMap(input_model_, &tensors);
  EXPECT_EQ(tensors["tensor_one"], 0);
  EXPECT_EQ(tensors["tensor_two"], 1);
}
199 
TEST_F(ExportTest, LoadOperatorsMap) {
  AddOperatorsByName({"Conv", "Add", "MyCrazyOp", "Sub"});

  const auto ops_by_type = BuildOperatorByTypeMap();
  details::OperatorsMap operators;
  details::LoadOperatorsMap(input_model_, &operators, ops_by_type, false);

  // Each distinct (op, custom code, version) key gets its own index.
  const details::OperatorKey add_key(::tflite::BuiltinOperator_ADD, "", 1);
  const details::OperatorKey conv_key(::tflite::BuiltinOperator_CONV_2D, "", 1);
  const details::OperatorKey crazy_key(::tflite::BuiltinOperator_CUSTOM,
                                       "MyCrazyOp", 1);
  const details::OperatorKey sub_key(::tflite::BuiltinOperator_SUB, "", 1);
  EXPECT_EQ(operators[add_key], 0);
  EXPECT_EQ(operators[conv_key], 1);
  EXPECT_EQ(operators[crazy_key], 2);
  EXPECT_EQ(operators[sub_key], 3);
}
215 
TEST_F(ExportTest, Export) {
  AddOperatorsByName({"Conv", "Add", "MyCrazyOp", "Sub"});

  ExportParams params;
  params.allow_custom_ops = true;
  params.enable_select_tf_ops = false;
  params.quantize_weights = false;

  // The opcode table is ordered independently of the model's operator order
  // (Conv was added first but CONV_2D gets index 1), and each operator refers
  // back to its opcode by index.
  const auto summarized = ExportAndSummarizeOperators(params);
  EXPECT_THAT(summarized, ElementsAre("builtin:ADD", "builtin:CONV_2D",
                                      "custom:MyCrazyOp", "builtin:SUB"));
  EXPECT_THAT(ExportAndGetOperatorIndices(params), ElementsAre(1, 0, 2, 3));
}
229 
TEST_F(ExportTest, QuantizeWeights) {
  // Sanity check for the quantize_weights parameter: the quantized model must
  // serialize to fewer bytes than the float one.
  BuildQuantizableTestModel();
  string float_model;
  Export(input_model_, true, /*quantize_weights*/ false, &float_model);

  BuildQuantizableTestModel();
  string quantized_model;
  Export(input_model_, true, /*quantize_weights*/ true, &quantized_model);

  EXPECT_LT(quantized_model.size(), float_model.size());
}
243 
244 class OpSetsTest : public ExportTest {
245  public:
246   enum OpSet { kTfLiteBuiltins, kSelectTfOps, kCustomOps };
247 
SetAllowedOpSets(std::initializer_list<OpSet> sets)248   void SetAllowedOpSets(std::initializer_list<OpSet> sets) {
249     import_all_ops_as_unsupported_ = true;
250     params_.allow_custom_ops = false;
251     params_.enable_select_tf_ops = false;
252     params_.quantize_weights = false;
253 
254     for (OpSet i : sets) {
255       switch (i) {
256         case kTfLiteBuiltins:
257           import_all_ops_as_unsupported_ = false;
258           break;
259         case kSelectTfOps:
260           params_.enable_select_tf_ops = true;
261           break;
262         case kCustomOps:
263           params_.allow_custom_ops = true;
264           break;
265       }
266     }
267   }
268 
ImportExport(std::initializer_list<string> op_names)269   std::vector<string> ImportExport(std::initializer_list<string> op_names) {
270     ResetOperators();
271     if (!import_all_ops_as_unsupported_) {
272       AddOperatorsByName(op_names);
273     } else {
274       for (const string& name : op_names) {
275         auto* op = new TensorFlowUnsupportedOperator;
276         op->tensorflow_op = name;
277         input_model_.operators.emplace_back(op);
278       }
279     }
280     return ExportAndSummarizeOperators(params_);
281   }
282 
283  private:
284   bool import_all_ops_as_unsupported_;
285   ExportParams params_;
286 };
287 
TEST_F(OpSetsTest, BuiltinsOnly) {
  const std::initializer_list<string> kMixedOps = {"Add", "AdjustHue",
                                                   "UnrollAndFold", "Assert"};

  // --target_op_set=TFLITE_BUILTINS: a model containing any non-builtin op
  // fails to export entirely.
  SetAllowedOpSets({kTfLiteBuiltins});
  EXPECT_THAT(ImportExport(kMixedOps), ElementsAre());
  EXPECT_THAT(ImportExport({"Add"}), ElementsAre("builtin:ADD"));

  // --target_op_set=TFLITE_BUILTINS --allow_custom_ops: non-builtins fall
  // back to custom ops.
  SetAllowedOpSets({kTfLiteBuiltins, kCustomOps});
  EXPECT_THAT(ImportExport(kMixedOps),
              ElementsAre("builtin:ADD", "custom:AdjustHue", "custom:Assert",
                          "custom:UnrollAndFold"));
}
301 
TEST_F(OpSetsTest, TfSelectOnly) {
  const std::initializer_list<string> kAllOps = {
      "Add", "AdjustHue", "RandomUniform", "UnrollAndFold", "Assert"};

  // --target_op_set=SELECT_TF_OPS: ops without a flex equivalent make the
  // whole export fail.
  SetAllowedOpSets({kSelectTfOps});
  EXPECT_THAT(ImportExport(kAllOps), ElementsAre());
  EXPECT_THAT(ImportExport({"Add"}), ElementsAre("custom:FlexAdd"));

  // --target_op_set=SELECT_TF_OPS --allow_custom_ops: the remaining ops are
  // exported as plain custom ops.
  SetAllowedOpSets({kSelectTfOps, kCustomOps});
  EXPECT_THAT(ImportExport(kAllOps),
              ElementsAre("custom:AdjustHue", "custom:FlexAdd",
                          "custom:FlexAssert", "custom:FlexRandomUniform",
                          "custom:UnrollAndFold"));
}
318 
TEST_F(OpSetsTest, BuiltinsAndTfSelect) {
  // --target_op_set=TFLITE_BUILTINS,SELECT_TF_OPS
  SetAllowedOpSets({kTfLiteBuiltins, kSelectTfOps});
  EXPECT_THAT(ImportExport({"Add", "AdjustHue", "UnrollAndFold", "Assert"}),
              ElementsAre());
  EXPECT_THAT(ImportExport({"Add", "RandomUniform"}),
              ElementsAre("builtin:ADD", "custom:FlexRandomUniform"));

  // --target_op_set=TFLITE_BUILTINS,SELECT_TF_OPS --allow_custom_ops:
  // builtins stay builtin, flex-capable ops get the Flex prefix, and the rest
  // become custom ops.
  SetAllowedOpSets({kTfLiteBuiltins, kSelectTfOps, kCustomOps});
  const auto exported = ImportExport(
      {"Add", "AdjustHue", "RandomUniform", "UnrollAndFold", "Assert"});
  EXPECT_THAT(exported,
              ElementsAre("builtin:ADD", "custom:AdjustHue",
                          "custom:FlexAssert", "custom:FlexRandomUniform",
                          "custom:UnrollAndFold"));
}
335 
// This test is based on a hypothetical scenario that dilation is supported
// only in Conv version 2. So Toco populates version=1 when dilation
// parameters are all 1, and version=2 otherwise.
// Fake serializer for Conv that reports version 2 whenever dilation factors
// differ from 1, and version 1 otherwise. Used to exercise per-version
// operator-code handling in the exporter.
class FakeConvolutionOperator
    : public BuiltinOperator<ConvOperator, ::tflite::Conv2DOptions,
                             ::tflite::BuiltinOptions_Conv2DOptions> {
 public:
  FakeConvolutionOperator()
      : BuiltinOperator(::tflite::BuiltinOperator_CONV_2D,
                        OperatorType::kConv) {}

  // Returning the op version according to the op parameters.
  int GetVersion(const OperatorSignature& op_signature) const override {
    const TocoOperator& conv_op =
        static_cast<const TocoOperator&>(*op_signature.op);
    if (conv_op.dilation_width_factor != 1 ||
        conv_op.dilation_height_factor != 1) {
      // Version 2 if dilation is used.
      return 2;
    }
    // Version 1 covers the dilation-free case.
    return 1;
  }

  // Note: The read / write code doesn't need to be changed if we stick with
  // the restrictions:
  // * Only adding parameters at the bottom of the Flatbuffer tables.
  // * When the default value of parameters are used, the op works consistently
  //   with the previous version.
  flatbuffers::Offset<TfLiteOptions> WriteOptions(
      const TocoOperator& op,
      flatbuffers::FlatBufferBuilder* builder) const override {
    auto padding = Padding::Serialize(op.padding.type);
    auto activation_function =
        ActivationFunction::Serialize(op.fused_activation_function);
    // Field order must match the Conv2DOptions flatbuffer schema.
    return ::tflite::CreateConv2DOptions(*builder, padding, op.stride_width,
                                         op.stride_height, activation_function,
                                         op.dilation_width_factor,
                                         op.dilation_height_factor);
  }

  // Inverse of WriteOptions: restores every serialized field onto the toco op.
  void ReadOptions(const TfLiteOptions& options,
                   TocoOperator* op) const override {
    op->padding.type = Padding::Deserialize(options.padding());
    op->stride_width = options.stride_w();
    op->stride_height = options.stride_h();
    op->dilation_width_factor = options.dilation_w_factor();
    op->dilation_height_factor = options.dilation_h_factor();
    op->fused_activation_function =
        ActivationFunction::Deserialize(options.fused_activation_function());
  }
};
387 
388 class VersionedOpExportTest : public ::testing::Test {
389  protected:
SetUp()390   void SetUp() override {
391     input_model_.GetOrCreateArray("input");
392     input_model_.GetOrCreateArray("filter");
393     input_model_.GetOrCreateArray("output");
394   }
AddConvOp(bool use_dialation)395   void AddConvOp(bool use_dialation) {
396     {
397       auto* op = new ConvOperator;
398       op->inputs.push_back("input");
399       op->inputs.push_back("filter");
400       op->inputs.push_back("output");
401 
402       op->padding.type = PaddingType::kSame;
403       op->stride_width = 1;
404       op->stride_height = 1;
405       if (use_dialation) {
406         op->dilation_width_factor = 2;
407         op->dilation_height_factor = 2;
408       } else {
409         op->dilation_width_factor = 1;
410         op->dilation_height_factor = 1;
411       }
412       input_model_.operators.emplace_back(op);
413     }
414   }
415 
416   std::map<OperatorType, std::unique_ptr<BaseOperator>>
BuildFakeOperatorByTypeMap()417   BuildFakeOperatorByTypeMap() {
418     std::map<OperatorType, std::unique_ptr<BaseOperator>> result;
419     result[OperatorType::kConv] =
420         std::unique_ptr<BaseOperator>(new FakeConvolutionOperator);
421     return result;
422   }
423 
424   Model input_model_;
425 };
426 
TEST_F(VersionedOpExportTest, LoadOperatorsMapWithOpV1) {
  AddConvOp(false);

  const auto ops_by_type = BuildFakeOperatorByTypeMap();
  details::OperatorsMap operators;
  details::LoadOperatorsMap(input_model_, &operators, ops_by_type, false);

  // A dilation-free Conv produces a single version-1 opcode.
  const details::OperatorKey conv_v1(::tflite::BuiltinOperator_CONV_2D, "", 1);
  EXPECT_EQ(1, operators.size());
  EXPECT_EQ(0, operators.at(conv_v1));
}
438 
TEST_F(VersionedOpExportTest, LoadOperatorsMapWithOpV2) {
  AddConvOp(true);

  const auto ops_by_type = BuildFakeOperatorByTypeMap();
  details::OperatorsMap operators;
  details::LoadOperatorsMap(input_model_, &operators, ops_by_type, false);

  // A dilated Conv produces a single version-2 opcode.
  const details::OperatorKey conv_v2(::tflite::BuiltinOperator_CONV_2D, "", 2);
  EXPECT_EQ(1, operators.size());
  EXPECT_EQ(0, operators.at(conv_v2));
}
450 
TEST_F(VersionedOpExportTest, LoadOperatorsMapWithBothVersions) {
  AddConvOp(false);
  AddConvOp(true);

  const auto ops_by_type = BuildFakeOperatorByTypeMap();
  details::OperatorsMap operators;
  details::LoadOperatorsMap(input_model_, &operators, ops_by_type, false);

  // Mixing dilated and non-dilated Convs yields two distinct opcodes, one per
  // version.
  const details::OperatorKey conv_v1(::tflite::BuiltinOperator_CONV_2D, "", 1);
  const details::OperatorKey conv_v2(::tflite::BuiltinOperator_CONV_2D, "", 2);
  EXPECT_EQ(2, operators.size());
  EXPECT_EQ(0, operators.at(conv_v1));
  EXPECT_EQ(1, operators.at(conv_v2));
}
465 
TEST_F(VersionedOpExportTest, Export) {
  AddConvOp(false);
  AddConvOp(true);

  const auto ops_by_type = BuildFakeOperatorByTypeMap();
  string result;
  Export(input_model_, true, false, &result, ops_by_type);

  auto* model = ::tflite::GetModel(result.data());

  // Two operator codes are populated: both CONV_2D, but with different
  // versions.
  auto operator_codes = model->operator_codes();
  EXPECT_EQ(2, operator_codes->size());
  const auto* code_v1 = (*operator_codes)[0];
  const auto* code_v2 = (*operator_codes)[1];
  EXPECT_EQ(::tflite::BuiltinOperator_CONV_2D, code_v1->builtin_code());
  EXPECT_EQ(1, code_v1->version());
  EXPECT_EQ(::tflite::BuiltinOperator_CONV_2D, code_v2->builtin_code());
  EXPECT_EQ(2, code_v2->version());

  // Each operator points at the opcode matching its own version.
  auto operators = (*model->subgraphs())[0]->operators();
  EXPECT_EQ(2, operators->size());
  EXPECT_EQ(0, (*operators)[0]->opcode_index());
  EXPECT_EQ(1, (*operators)[1]->opcode_index());
}
494 
TEST(OperatorKeyTest, TestBuiltinOp) {
  Model model;
  auto op = absl::make_unique<ConvOperator>();

  // A normal float convolution keys to CONV_2D version 1 with no custom code.
  op->inputs = {"input", "filter"};
  op->outputs = {"output"};
  for (const string& input : op->inputs) {
    model.GetOrCreateArray(input).data_type = ArrayDataType::kFloat;
  }
  model.GetOrCreateArray(op->outputs[0]).data_type = ArrayDataType::kFloat;

  const auto ops_by_type = BuildOperatorByTypeMap();
  const toco::OperatorSignature op_signature = {op.get(), &model};
  const auto key = details::OperatorKey(op_signature, ops_by_type, false);

  EXPECT_EQ(::tflite::BuiltinOperator_CONV_2D, key.type());
  EXPECT_EQ("", key.custom_code());
  EXPECT_EQ(1, key.version());
}
517 
TEST(OperatorKeyTest, TestBuiltinOpWithVersionedInputTypes) {
  Model model;
  auto op = absl::make_unique<DequantizeOperator>();

  // A signed int8 dequantize operation keys to DEQUANTIZE version 2.
  op->inputs = {"input"};
  op->outputs = {"output"};
  model.GetOrCreateArray(op->inputs[0]).data_type = ArrayDataType::kInt8;
  model.GetOrCreateArray(op->outputs[0]).data_type = ArrayDataType::kFloat;

  const auto ops_by_type = BuildOperatorByTypeMap();
  const toco::OperatorSignature op_signature = {op.get(), &model};
  const auto key = details::OperatorKey(op_signature, ops_by_type, false);

  EXPECT_EQ(::tflite::BuiltinOperator_DEQUANTIZE, key.type());
  EXPECT_EQ("", key.custom_code());
  EXPECT_EQ(2, key.version());
}
539 
TEST(OperatorKeyTest, TestCustomOp) {
  Model model;
  auto op = absl::make_unique<TensorFlowUnsupportedOperator>();
  op->tensorflow_op = "MyCrazyCustomOp";

  // Unsupported ops are keyed as CUSTOM with their TF op name as custom code.
  const auto ops_by_type = BuildOperatorByTypeMap();
  const toco::OperatorSignature op_signature = {op.get(), &model};
  const auto key = details::OperatorKey(op_signature, ops_by_type, false);

  EXPECT_EQ(::tflite::BuiltinOperator_CUSTOM, key.type());
  EXPECT_EQ("MyCrazyCustomOp", key.custom_code());
  EXPECT_EQ(1, key.version());
}
553 
TEST(OperatorKeyTest, TestFlexOp) {
  Model model;
  auto op = absl::make_unique<TensorFlowUnsupportedOperator>();
  op->tensorflow_op = "BatchMatMul";

  const auto ops_by_type = BuildOperatorByTypeMap();
  const toco::OperatorSignature op_signature = {op.get(), &model};

  {
    // With `allow_flex_op` false, the op stays a plain custom op.
    const auto key = details::OperatorKey(op_signature, ops_by_type, false);
    EXPECT_EQ(::tflite::BuiltinOperator_CUSTOM, key.type());
    EXPECT_EQ("BatchMatMul", key.custom_code());
    EXPECT_EQ(1, key.version());
    EXPECT_TRUE(key.is_custom_op());
    EXPECT_FALSE(key.is_flex_op());
  }

  {
    // With `allow_flex_op` true, the custom code gains the "Flex" prefix and
    // `is_flex_op` is set.
    const auto key = details::OperatorKey(op_signature, ops_by_type, true);
    EXPECT_EQ(::tflite::BuiltinOperator_CUSTOM, key.type());
    EXPECT_EQ("FlexBatchMatMul", key.custom_code());
    EXPECT_EQ(1, key.version());
    EXPECT_FALSE(key.is_custom_op());
    EXPECT_TRUE(key.is_flex_op());
  }
}
583 
TEST(OperatorKeyTest, TestFlexWithControlFlowOp) {
  Model model;
  auto op = absl::make_unique<TensorFlowUnsupportedOperator>();
  op->tensorflow_op = "Merge";

  const auto ops_by_type = BuildOperatorByTypeMap();
  const toco::OperatorSignature op_signature = {op.get(), &model};
  const auto key = details::OperatorKey(op_signature, ops_by_type, true);

  EXPECT_EQ(::tflite::BuiltinOperator_CUSTOM, key.type());
  EXPECT_EQ("FlexMerge", key.custom_code());
  EXPECT_EQ(1, key.version());
  EXPECT_FALSE(key.is_custom_op());
  EXPECT_TRUE(key.is_flex_op());
  // Control flow ops such as Merge are marked as unsupported flex ops.
  EXPECT_TRUE(key.is_unsupported_flex_op());
}
601 
TEST(OperatorKeyTest, TestFlexWithUnsupportedOp) {
  Model model;
  auto op = absl::make_unique<TensorFlowUnsupportedOperator>();
  op->tensorflow_op = "HashTableV2";

  const auto ops_by_type = BuildOperatorByTypeMap();
  const toco::OperatorSignature op_signature = {op.get(), &model};
  const auto key = details::OperatorKey(op_signature, ops_by_type, true);

  EXPECT_EQ(::tflite::BuiltinOperator_CUSTOM, key.type());
  EXPECT_EQ("HashTableV2", key.custom_code());
  EXPECT_EQ(1, key.version());
  // HashTableV2 is currently excluded from the allowed flex op list (it is
  // explicitly denied due to lack of asset support). If that ever changes,
  // the expectations below will need to change too.
  EXPECT_FALSE(key.is_flex_op());
  EXPECT_FALSE(key.is_unsupported_flex_op());
}
620 
TEST(OperatorKeyTest, TestFlexWithPartiallySupportedOps) {
  // Test Toco-supported/TFLite-unsupported operators.
  Model model;
  // TODO(ycling): The test will be broken if TensorFlowAssert is implemented
  // in TFLite. Find a more robust way to test the fallback logic.
  auto op = absl::make_unique<TensorFlowAssertOperator>();

  const auto ops_by_type = BuildOperatorByTypeMap();
  const toco::OperatorSignature op_signature = {op.get(), &model};

  {
    // Without a retained NodeDef, the op falls back to a regular custom op.
    const auto key = details::OperatorKey(op_signature, ops_by_type, true);
    EXPECT_EQ(::tflite::BuiltinOperator_CUSTOM, key.type());
    EXPECT_EQ("Assert", key.custom_code());
    EXPECT_EQ(1, key.version());
    EXPECT_TRUE(key.is_custom_op());
    EXPECT_FALSE(key.is_flex_op());
  }

  // Attach a NodeDef to the op so it can be exported as a flex op.
  ::tensorflow::NodeDef node_def;
  node_def.set_name("TensorFlowAssert");
  node_def.set_op("TensorFlowAssert");
  node_def.SerializeToString(&op->tensorflow_node_def);

  {
    // With the NodeDef retained, a Flex op is exported instead.
    const auto key = details::OperatorKey(op_signature, ops_by_type, true);
    EXPECT_EQ(::tflite::BuiltinOperator_CUSTOM, key.type());
    EXPECT_EQ("FlexAssert", key.custom_code());
    EXPECT_EQ(1, key.version());
    EXPECT_FALSE(key.is_custom_op());
    EXPECT_TRUE(key.is_flex_op());
  }
}
658 
659 // TODO(ahentz): tests for tensors, inputs, outputs, opcodes and operators.
660 
661 }  // namespace
662 }  // namespace tflite
663 }  // namespace toco
664