1 /* Copyright 2018 The TensorFlow Authors. All Rights Reserved.
2 
3 Licensed under the Apache License, Version 2.0 (the "License");
4 you may not use this file except in compliance with the License.
5 You may obtain a copy of the License at
6 
7     http://www.apache.org/licenses/LICENSE-2.0
8 
9 Unless required by applicable law or agreed to in writing, software
10 distributed under the License is distributed on an "AS IS" BASIS,
11 WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
12 See the License for the specific language governing permissions and
13 limitations under the License.
14 ==============================================================================*/
15 
16 #include "tensorflow/lite/mutable_op_resolver.h"
17 
18 #include <gtest/gtest.h>
19 #include "tensorflow/lite/testing/util.h"
20 
21 namespace tflite {
22 namespace {
23 
24 // We need some dummy functions to identify the registrations.
DummyInvoke(TfLiteContext * context,TfLiteNode * node)25 TfLiteStatus DummyInvoke(TfLiteContext* context, TfLiteNode* node) {
26   return kTfLiteOk;
27 }
28 
GetDummyRegistration()29 TfLiteRegistration* GetDummyRegistration() {
30   static TfLiteRegistration registration = {
31       .init = nullptr,
32       .free = nullptr,
33       .prepare = nullptr,
34       .invoke = DummyInvoke,
35   };
36   return &registration;
37 }
38 
Dummy2Invoke(TfLiteContext * context,TfLiteNode * node)39 TfLiteStatus Dummy2Invoke(TfLiteContext* context, TfLiteNode* node) {
40   return kTfLiteOk;
41 }
42 
GetDummy2Registration()43 TfLiteRegistration* GetDummy2Registration() {
44   static TfLiteRegistration registration = {
45       .init = nullptr,
46       .free = nullptr,
47       .prepare = nullptr,
48       .invoke = Dummy2Invoke,
49   };
50   return &registration;
51 }
52 
// A builtin added without an explicit version range is expected to be found
// at version 1, with the returned registration carrying the original invoke
// callback, the builtin code, and the version.
TEST(MutableOpResolverTest, FindOp) {
  MutableOpResolver resolver;
  resolver.AddBuiltin(BuiltinOperator_ADD, GetDummyRegistration());

  const TfLiteRegistration* found_registration =
      resolver.FindOp(BuiltinOperator_ADD, 1);
  ASSERT_NE(found_registration, nullptr);
  // EXPECT_EQ (rather than EXPECT_TRUE(a == b)) prints both values on failure.
  EXPECT_EQ(found_registration->invoke, &DummyInvoke);
  EXPECT_EQ(found_registration->builtin_code, BuiltinOperator_ADD);
  EXPECT_EQ(found_registration->version, 1);
}
64 
// Looking up a builtin that was never registered yields nullptr.
TEST(MutableOpResolverTest, FindMissingOp) {
  MutableOpResolver resolver;
  resolver.AddBuiltin(BuiltinOperator_ADD, GetDummyRegistration());

  // Only ADD is registered, so CONV_2D must not resolve.
  EXPECT_EQ(resolver.FindOp(BuiltinOperator_CONV_2D, 1), nullptr);
}
73 
// Registering a builtin over the version range [2, 3] is expected to make it
// findable at each version in the range. Each lookup returns a registration
// with the original invoke callback, the builtin code, and the requested
// version.
TEST(MutableOpResolverTest, RegisterOpWithMultipleVersions) {
  MutableOpResolver resolver;
  // The kernel supports version 2 and 3.
  resolver.AddBuiltin(BuiltinOperator_ADD, GetDummyRegistration(), 2, 3);

  // Each result gets its own initialized variable instead of one
  // uninitialized pointer that is reassigned.
  const TfLiteRegistration* found_v2 = resolver.FindOp(BuiltinOperator_ADD, 2);
  ASSERT_NE(found_v2, nullptr);
  EXPECT_EQ(found_v2->invoke, &DummyInvoke);
  EXPECT_EQ(found_v2->builtin_code, BuiltinOperator_ADD);
  EXPECT_EQ(found_v2->version, 2);

  const TfLiteRegistration* found_v3 = resolver.FindOp(BuiltinOperator_ADD, 3);
  ASSERT_NE(found_v3, nullptr);
  EXPECT_EQ(found_v3->invoke, &DummyInvoke);
  EXPECT_EQ(found_v3->builtin_code, BuiltinOperator_ADD);
  EXPECT_EQ(found_v3->version, 3);
}
91 
// Versions just outside a registered [min, max] range must not resolve.
TEST(MutableOpResolverTest, FindOpWithUnsupportedVersions) {
  MutableOpResolver resolver;
  // Register ADD for versions 2 through 3 only.
  resolver.AddBuiltin(BuiltinOperator_ADD, GetDummyRegistration(), 2, 3);

  EXPECT_EQ(resolver.FindOp(BuiltinOperator_ADD, 1), nullptr);  // Below range.
  EXPECT_EQ(resolver.FindOp(BuiltinOperator_ADD, 4), nullptr);  // Above range.
}
105 
// A custom op added by name without a version range is expected to be found
// at version 1, tagged with BuiltinOperator_CUSTOM and carrying the original
// invoke callback.
TEST(MutableOpResolverTest, FindCustomOp) {
  MutableOpResolver resolver;
  resolver.AddCustom("AWESOME", GetDummyRegistration());

  const TfLiteRegistration* found_registration = resolver.FindOp("AWESOME", 1);
  ASSERT_NE(found_registration, nullptr);
  EXPECT_EQ(found_registration->builtin_code, BuiltinOperator_CUSTOM);
  // EXPECT_EQ (rather than EXPECT_TRUE(a == b)) prints both values on failure.
  EXPECT_EQ(found_registration->invoke, &DummyInvoke);
  EXPECT_EQ(found_registration->version, 1);
  // TODO(ycling): The `custom_name` in TfLiteRegistration isn't properly
  // filled yet. Fix this and add tests.
}
118 
// Looking up a custom op under a name that was never registered yields nullptr.
TEST(MutableOpResolverTest, FindMissingCustomOp) {
  MutableOpResolver resolver;
  resolver.AddCustom("AWESOME", GetDummyRegistration());

  // Only "AWESOME" is registered; "EXCELLENT" must not resolve.
  EXPECT_EQ(resolver.FindOp("EXCELLENT", 1), nullptr);
}
127 
// A custom op added without a version range must not resolve at version 2.
TEST(MutableOpResolverTest, FindCustomOpWithUnsupportedVersion) {
  MutableOpResolver resolver;
  resolver.AddCustom("AWESOME", GetDummyRegistration());

  EXPECT_EQ(resolver.FindOp("AWESOME", 2), nullptr);
}
135 
// AddAll is expected to merge another resolver's registrations into this one:
// - ops only in the destination (MUL) are kept;
// - ops only in the source (SUB) are added;
// - ops in both (ADD) take the source's kernel.
TEST(MutableOpResolverTest, AddAll) {
  MutableOpResolver resolver1;
  resolver1.AddBuiltin(BuiltinOperator_ADD, GetDummyRegistration());
  resolver1.AddBuiltin(BuiltinOperator_MUL, GetDummy2Registration());

  MutableOpResolver resolver2;
  resolver2.AddBuiltin(BuiltinOperator_SUB, GetDummyRegistration());
  resolver2.AddBuiltin(BuiltinOperator_ADD, GetDummy2Registration());

  // resolver2's ADD op should replace resolver1's ADD op, while augmenting
  // non-overlapping ops.
  resolver1.AddAll(resolver2);

  // Check each lookup for nullptr before dereferencing it, so a missing op is
  // reported as a test failure instead of crashing the test binary.
  const TfLiteRegistration* found_add = resolver1.FindOp(BuiltinOperator_ADD, 1);
  ASSERT_NE(found_add, nullptr);
  EXPECT_EQ(found_add->invoke, GetDummy2Registration()->invoke);

  const TfLiteRegistration* found_mul = resolver1.FindOp(BuiltinOperator_MUL, 1);
  ASSERT_NE(found_mul, nullptr);
  EXPECT_EQ(found_mul->invoke, GetDummy2Registration()->invoke);

  const TfLiteRegistration* found_sub = resolver1.FindOp(BuiltinOperator_SUB, 1);
  ASSERT_NE(found_sub, nullptr);
  EXPECT_EQ(found_sub->invoke, GetDummyRegistration()->invoke);
}
155 
156 }  // namespace
157 }  // namespace tflite
158 
main(int argc,char ** argv)159 int main(int argc, char** argv) {
160   ::tflite::LogToStderr();
161   ::testing::InitGoogleTest(&argc, argv);
162   return RUN_ALL_TESTS();
163 }
164