1 /* Copyright 2018 The TensorFlow Authors. All Rights Reserved.
2
3 Licensed under the Apache License, Version 2.0 (the "License");
4 you may not use this file except in compliance with the License.
5 You may obtain a copy of the License at
6
7 http://www.apache.org/licenses/LICENSE-2.0
8
9 Unless required by applicable law or agreed to in writing, software
10 distributed under the License is distributed on an "AS IS" BASIS,
11 WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
12 See the License for the specific language governing permissions and
13 limitations under the License.
14 ==============================================================================*/
15
16 #include "tensorflow/lite/mutable_op_resolver.h"
17
18 #include <gtest/gtest.h>
19 #include "tensorflow/lite/testing/util.h"
20
21 namespace tflite {
22 namespace {
23
// Stub kernels whose function addresses let the tests tell registrations apart.
DummyInvoke(TfLiteContext * context,TfLiteNode * node)25 TfLiteStatus DummyInvoke(TfLiteContext* context, TfLiteNode* node) {
26 return kTfLiteOk;
27 }
28
GetDummyRegistration()29 TfLiteRegistration* GetDummyRegistration() {
30 static TfLiteRegistration registration = {
31 .init = nullptr,
32 .free = nullptr,
33 .prepare = nullptr,
34 .invoke = DummyInvoke,
35 };
36 return ®istration;
37 }
38
Dummy2Invoke(TfLiteContext * context,TfLiteNode * node)39 TfLiteStatus Dummy2Invoke(TfLiteContext* context, TfLiteNode* node) {
40 return kTfLiteOk;
41 }
42
GetDummy2Registration()43 TfLiteRegistration* GetDummy2Registration() {
44 static TfLiteRegistration registration = {
45 .init = nullptr,
46 .free = nullptr,
47 .prepare = nullptr,
48 .invoke = Dummy2Invoke,
49 };
50 return ®istration;
51 }
52
// Registers ADD without explicit versions, then checks that looking it up at
// version 1 yields the dummy kernel tagged with the ADD builtin code and
// version 1.
TEST(MutableOpResolverTest, FindOp) {
  MutableOpResolver resolver;
  resolver.AddBuiltin(BuiltinOperator_ADD, GetDummyRegistration());

  const TfLiteRegistration* found_registration =
      resolver.FindOp(BuiltinOperator_ADD, 1);
  ASSERT_NE(found_registration, nullptr);
  EXPECT_TRUE(found_registration->invoke == DummyInvoke);
  EXPECT_EQ(found_registration->builtin_code, BuiltinOperator_ADD);
  EXPECT_EQ(found_registration->version, 1);
}
64
// Looking up a builtin that was never registered (CONV_2D here, while only
// ADD was added) must yield nullptr.
TEST(MutableOpResolverTest, FindMissingOp) {
  MutableOpResolver resolver;
  resolver.AddBuiltin(BuiltinOperator_ADD, GetDummyRegistration());

  EXPECT_EQ(resolver.FindOp(BuiltinOperator_CONV_2D, 1), nullptr);
}
73
// A kernel registered for versions 2 through 3 must be found at each of those
// versions, and every lookup must report the version that was requested.
TEST(MutableOpResolverTest, RegisterOpWithMultipleVersions) {
  MutableOpResolver resolver;
  resolver.AddBuiltin(BuiltinOperator_ADD, GetDummyRegistration(), 2, 3);

  const int kSupportedVersions[] = {2, 3};
  for (const int version : kSupportedVersions) {
    const TfLiteRegistration* reg =
        resolver.FindOp(BuiltinOperator_ADD, version);
    // ASSERT_* returns from the test body, so a missing version stops here.
    ASSERT_NE(reg, nullptr);
    EXPECT_TRUE(reg->invoke == DummyInvoke);
    EXPECT_EQ(reg->version, version);
  }
}
91
// Versions immediately outside the registered range (2 through 3) must not
// resolve to anything.
TEST(MutableOpResolverTest, FindOpWithUnsupportedVersions) {
  MutableOpResolver resolver;
  resolver.AddBuiltin(BuiltinOperator_ADD, GetDummyRegistration(), 2, 3);

  const int kUnsupportedVersions[] = {1, 4};
  for (const int version : kUnsupportedVersions) {
    EXPECT_EQ(resolver.FindOp(BuiltinOperator_ADD, version), nullptr);
  }
}
105
// A custom op registered by name is found under that name at version 1. The
// result carries the CUSTOM builtin code and the dummy kernel.
TEST(MutableOpResolverTest, FindCustomOp) {
  MutableOpResolver resolver;
  resolver.AddCustom("AWESOME", GetDummyRegistration());

  const TfLiteRegistration* reg = resolver.FindOp("AWESOME", 1);
  ASSERT_NE(reg, nullptr);
  EXPECT_TRUE(reg->invoke == DummyInvoke);
  EXPECT_EQ(reg->version, 1);
  EXPECT_EQ(reg->builtin_code, BuiltinOperator_CUSTOM);
  // TODO(ycling): The `custom_name` in TfLiteRegistration isn't properly
  // filled yet. Fix this and add tests.
}
118
// A custom-op name that was never registered must not resolve, even when a
// different custom op is present.
TEST(MutableOpResolverTest, FindMissingCustomOp) {
  MutableOpResolver resolver;
  resolver.AddCustom("AWESOME", GetDummyRegistration());

  EXPECT_EQ(resolver.FindOp("EXCELLENT", 1), nullptr);
}
127
// A custom op added without explicit versions is found at version 1 (see
// FindCustomOp) but must not resolve at version 2.
TEST(MutableOpResolverTest, FindCustomOpWithUnsupportedVersion) {
  MutableOpResolver resolver;
  resolver.AddCustom("AWESOME", GetDummyRegistration());

  EXPECT_EQ(resolver.FindOp("AWESOME", 2), nullptr);
}
135
// Merging resolver2 into resolver1 with AddAll must:
//   * override ops present in both (ADD now resolves to Dummy2Invoke),
//   * keep ops only in resolver1 (MUL, Dummy2Invoke),
//   * add ops only in resolver2 (SUB, DummyInvoke).
// Each lookup is null-checked before it is dereferenced, so a missing op
// fails the test cleanly instead of crashing the test binary.
TEST(MutableOpResolverTest, AddAll) {
  MutableOpResolver resolver1;
  resolver1.AddBuiltin(BuiltinOperator_ADD, GetDummyRegistration());
  resolver1.AddBuiltin(BuiltinOperator_MUL, GetDummy2Registration());

  MutableOpResolver resolver2;
  resolver2.AddBuiltin(BuiltinOperator_SUB, GetDummyRegistration());
  resolver2.AddBuiltin(BuiltinOperator_ADD, GetDummy2Registration());

  // resolver2's ADD op should replace resolver1's ADD op, while
  // non-overlapping ops from both resolvers are kept.
  resolver1.AddAll(resolver2);

  const TfLiteRegistration* add = resolver1.FindOp(BuiltinOperator_ADD, 1);
  ASSERT_NE(add, nullptr);
  EXPECT_EQ(add->invoke, GetDummy2Registration()->invoke);

  const TfLiteRegistration* mul = resolver1.FindOp(BuiltinOperator_MUL, 1);
  ASSERT_NE(mul, nullptr);
  EXPECT_EQ(mul->invoke, GetDummy2Registration()->invoke);

  const TfLiteRegistration* sub = resolver1.FindOp(BuiltinOperator_SUB, 1);
  ASSERT_NE(sub, nullptr);
  EXPECT_EQ(sub->invoke, GetDummyRegistration()->invoke);
}
155
156 } // namespace
157 } // namespace tflite
158
main(int argc,char ** argv)159 int main(int argc, char** argv) {
160 ::tflite::LogToStderr();
161 ::testing::InitGoogleTest(&argc, argv);
162 return RUN_ALL_TESTS();
163 }
164