1 /* Copyright 2017 The TensorFlow Authors. All Rights Reserved.
2 
3 Licensed under the Apache License, Version 2.0 (the "License");
4 you may not use this file except in compliance with the License.
5 You may obtain a copy of the License at
6 
7     http://www.apache.org/licenses/LICENSE-2.0
8 
9 Unless required by applicable law or agreed to in writing, software
10 distributed under the License is distributed on an "AS IS" BASIS,
11 WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
12 See the License for the specific language governing permissions and
13 limitations under the License.
14 ==============================================================================*/
15 
16 #include <cstddef>
17 #include <cstdint>
18 
19 #include "tensorflow/lite/c/builtin_op_data.h"
20 #include "tensorflow/lite/c/common.h"
21 #include "tensorflow/lite/micro/all_ops_resolver.h"
22 #include "tensorflow/lite/micro/kernels/kernel_runner.h"
23 #include "tensorflow/lite/micro/micro_utils.h"
24 #include "tensorflow/lite/micro/test_helpers.h"
25 #include "tensorflow/lite/micro/testing/micro_test.h"
26 
27 namespace tflite {
28 namespace testing {
29 namespace {
30 
// Simple test data: a 2x10 input (two batches of 10 values), 3x10 weights
// (three output units of 10 values each), a 3-element bias, and the expected
// 2x3 output. The *_dims arrays are rank-prefixed: {2, 2, 10} describes a
// rank-2 tensor of shape 2x10.
const int simple_input_size = 20;
const int simple_input_dims[] = {2, 2, 10};
const float simple_input_data[] = {
    1, 2, 3, 4, 5, 6, 7, 8,  -9, -10,  // b = 0
    1, 2, 3, 4, 5, 6, 7, -8, 9,  -10,  // b = 1
};
const int simple_weights_size = 30;
const int simple_weights_dims[] = {2, 3, 10};
const float simple_weights_data[] = {
    1, 2, 3, 4, 5, 6, 7, 8, 9, 10,  // u = 0
    1, 2, 3, 4, 5, 6, 7, 8, 9, 10,  // u = 1
    1, 2, 3, 4, 5, 6, 7, 8, 9, 10,  // u = 2
};
const int simple_bias_dims[] = {1, 3};
const float simple_bias_data[] = {1, 2, 3};
// golden[b * 3 + u] = dot(input row b, weights row u) + bias[u],
// e.g. 23 + 1 = 24 for b = 0, u = 0.
const float simple_golden[] = {
    24, 25, 26, 58, 59, 60,
};
const int simple_output_size = 6;
const int simple_output_dims[] = {2, 2, 3};

// The tests size their scratch buffers from the *_size constants, so keep the
// constants and the tables in sync at compile time.
static_assert(static_cast<int>(sizeof(simple_input_data) /
                               sizeof(simple_input_data[0])) ==
                  simple_input_size,
              "simple_input_size does not match simple_input_data");
static_assert(static_cast<int>(sizeof(simple_weights_data) /
                               sizeof(simple_weights_data[0])) ==
                  simple_weights_size,
              "simple_weights_size does not match simple_weights_data");
static_assert(static_cast<int>(sizeof(simple_golden) /
                               sizeof(simple_golden[0])) == simple_output_size,
              "simple_output_size does not match simple_golden");
static_assert(static_cast<int>(sizeof(simple_bias_data) /
                               sizeof(simple_bias_data[0])) == 3,
              "simple_bias_data must have one entry per output unit");
52 
// Test data for ReLU: same shapes as the simple data (2x10 input, 3x10
// weights, 3-element bias, 2x3 output). Weights row u = 1 is negated and its
// bias is -2, so unit 1's pre-activation is negative for both batches and the
// golden output clamps it to 0. The *_dims arrays are rank-prefixed.
const int relu_input_size = 20;
const int relu_input_dims[] = {2, 2, 10};
const float relu_input_data[] = {
    1, 2, 3, 4, 5, 6, 7, 8,  -9, -10,  // b = 0
    1, 2, 3, 4, 5, 6, 7, -8, 9,  -10,  // b = 1
};
const int relu_weights_size = 30;
const int relu_weights_dims[] = {2, 3, 10};
const float relu_weights_data[] = {
    1,  2,  3,  4,  5,  6,  7,  8,  9,  10,   // u = 0
    -1, -2, -3, -4, -5, -6, -7, -8, -9, -10,  // u = 1
    1,  2,  3,  4,  5,  6,  7,  8,  9,  10,   // u = 2
};
const int relu_bias_dims[] = {1, 3};
const float relu_bias_data[] = {1, -2, 3};
// golden[b * 3 + u] = max(0, dot(input row b, weights row u) + bias[u]).
const float relu_golden[] = {
    24, 0, 26, 58, 0, 60,
};
const int relu_output_size = 6;
const int relu_output_dims[] = {2, 2, 3};

// The tests size their scratch buffers from the *_size constants, so keep the
// constants and the tables in sync at compile time.
static_assert(static_cast<int>(sizeof(relu_input_data) /
                               sizeof(relu_input_data[0])) == relu_input_size,
              "relu_input_size does not match relu_input_data");
static_assert(static_cast<int>(sizeof(relu_weights_data) /
                               sizeof(relu_weights_data[0])) ==
                  relu_weights_size,
              "relu_weights_size does not match relu_weights_data");
static_assert(static_cast<int>(sizeof(relu_golden) / sizeof(relu_golden[0])) ==
                  relu_output_size,
              "relu_output_size does not match relu_golden");
static_assert(static_cast<int>(sizeof(relu_bias_data) /
                               sizeof(relu_bias_data[0])) == 3,
              "relu_bias_data must have one entry per output unit");
74 
// Input and filter similar to a real model. Input shape is 1x64 and output is
// 1x16. The *_dims arrays are rank-prefixed: {2, 1, 64} is a 1x64 tensor.
const int representative_64x16_input_size = 64;
const int representative_64x16_input_dims[] = {2, 1, 64};
const float representative_64x16_input_data[] = {
    0.0000, 0.1543, 0.0000, 0.0000, 1.8520, 0.0000, 4.7844, 1.1832,
    0.0000, 0.0000, 0.0000, 0.0000, 0.0000, 0.0000, 1.5948, 0.0000,
    1.5948, 1.9549, 0.0000, 1.2347, 0.0000, 1.5948, 1.5948, 0.5145,
    0.0000, 0.0000, 0.0000, 0.0000, 2.6237, 0.0000, 0.0000, 0.0000,
    1.3890, 5.3503, 2.3665, 2.9838, 0.0000, 1.2861, 0.0000, 3.0867,
    0.9775, 0.0000, 5.9676, 0.0000, 0.0000, 1.4405, 0.5145, 2.5723,
    3.1896, 4.4757, 0.0000, 0.0000, 0.0000, 0.0000, 4.1671, 0.0000,
    2.8295, 3.0353, 0.0000, 2.7780, 0.0000, 0.0000, 0.0000, 0.0000};
// Tests size their quantized input buffers from the size constant.
static_assert(static_cast<int>(sizeof(representative_64x16_input_data) /
                               sizeof(representative_64x16_input_data[0])) ==
                  representative_64x16_input_size,
              "representative_64x16_input_size does not match the input data");
// 16x64 weight matrix for the representative case: dims are rank-prefixed,
// {2, 16, 64}, i.e. 16 output units with 64 weights each.
const int representative_64x16_weights_size = 64 * 16;
const int representative_64x16_weights_dims[] = {2, 16, 64};
const float representative_64x16_weights_data[] = {
    -0.1075, 0.1245,  0.1811,  -0.1302, -0.1868, 0.0679,  0.1245,  0.2321,
    -0.1981, -0.2094, 0.1358,  -0.1698, 0.0113,  0.0566,  0.1358,  -0.2490,
    0.0000,  -0.1189, -0.0170, -0.0396, -0.3113, 0.1641,  -0.4188, 0.0566,
    -0.4471, 0.4754,  -0.0396, 0.0113,  -0.0340, 0.0170,  0.0170,  0.1811,
    -0.0792, 0.4981,  0.2490,  -0.1924, 0.0792,  0.1868,  -0.1075, -0.3962,
    0.1358,  0.2547,  -0.1245, -0.0962, -0.0283, 0.4132,  -0.0057, -0.5150,
    0.1019,  0.1585,  -0.0962, -0.2207, -0.2377, 0.2830,  0.4471,  0.0170,
    0.0566,  0.2038,  0.1019,  -0.0226, 0.2830,  0.1415,  0.0283,  -0.0792,
    0.4301,  0.3226,  -0.1132, 0.4981,  -0.3849, -0.2943, -0.2547, -0.2264,
    0.0453,  -0.0170, 0.0396,  0.1415,  0.3000,  0.2547,  0.0962,  0.2151,
    -0.1585, -0.1302, -0.0057, -0.2773, 0.0283,  -0.0906, 0.1302,  -0.1075,
    -0.0566, 0.1755,  0.2773,  0.0283,  0.0566,  0.1528,  -0.0736, -0.2830,
    0.0792,  0.0962,  -0.2321, -0.0113, 0.2660,  -0.2887, -0.0566, 0.0057,
    -0.2547, -0.0679, -0.2321, 0.0340,  0.1868,  0.2490,  0.2264,  -0.3509,
    0.1585,  -0.0849, -0.0623, 0.1132,  0.3396,  -0.2490, 0.1528,  0.0679,
    0.1755,  0.4754,  -0.0057, -0.2151, -0.1415, -0.1302, -0.2717, 0.1641,
    0.5037,  -0.2321, 0.0170,  -0.1755, -0.1075, -0.0226, 0.2038,  -0.0340,
    -0.5150, -0.3113, 0.1472,  -0.0226, 0.1528,  0.1189,  -0.1472, 0.0396,
    -0.3000, -0.1924, -0.0283, 0.0283,  0.1641,  0.0736,  0.1472,  -0.1755,
    -0.1132, 0.0113,  -0.1868, -0.2604, -0.3283, -0.0509, 0.0283,  -0.0679,
    0.0623,  0.0792,  -0.0283, -0.0962, 0.0396,  0.1641,  0.4584,  0.3226,
    0.0226,  -0.1811, 0.2377,  -0.1019, 0.2321,  0.1811,  -0.1924, -0.0057,
    0.0736,  0.0113,  0.2547,  -0.2264, -0.0170, -0.0396, 0.1245,  -0.1415,
    0.1755,  0.3679,  -0.2377, -0.0396, -0.1585, -0.3000, -0.1641, -0.1302,
    -0.0396, -0.1698, 0.1189,  0.2434,  0.1132,  -0.1245, -0.1415, 0.0453,
    0.1868,  -0.0906, -0.1189, -0.0509, 0.0057,  -0.1189, -0.0057, 0.0170,
    -0.1924, 0.2207,  0.0792,  -0.4641, -0.2660, 0.2943,  0.1358,  -0.0340,
    -0.3339, -0.1189, 0.0906,  -0.4358, 0.0453,  -0.1755, 0.1415,  0.0340,
    0.1924,  -0.0057, 0.2321,  -0.2094, -0.1132, 0.0000,  0.1924,  -0.3000,
    0.0340,  -0.3396, -0.0906, -0.0340, 0.1641,  -0.0226, -0.1472, -0.1019,
    0.2377,  -0.0962, -0.3396, -0.5433, 0.0906,  0.2151,  -0.0679, 0.1755,
    0.1528,  0.0283,  -0.4188, -0.0340, -0.0057, -0.0679, 0.0509,  0.1472,
    -0.3849, -0.0113, 0.3962,  0.0849,  0.1472,  0.0340,  -0.1358, 0.1641,
    -0.2038, 0.2151,  -0.1189, -0.3679, 0.0906,  -0.0679, 0.5716,  -0.0057,
    -0.0736, 0.0113,  0.2830,  -0.2887, 0.0396,  0.0849,  -0.0736, -0.0736,
    -0.3679, 0.2264,  0.0113,  -0.1641, 0.0396,  -0.1132, -0.0623, 0.3113,
    0.5999,  -0.1415, 0.1472,  -0.2038, -0.1132, -0.2377, 0.0566,  0.1755,
    -0.0057, -0.0453, 0.0226,  0.1132,  0.1698,  0.0340,  -0.0226, 0.0226,
    0.4415,  -0.3792, 0.0792,  0.3736,  -0.5999, -0.3056, -0.1924, -0.1132,
    -0.0962, 0.0283,  0.0000,  -0.3339, -0.3226, 0.3679,  -0.0453, -0.1641,
    0.0170,  0.1302,  -0.0170, -0.0509, 0.1755,  -0.0283, -0.1302, -0.2887,
    -0.0679, 0.0340,  0.4641,  0.2321,  0.7188,  0.3339,  -0.1075, 0.4754,
    -0.0226, 0.3226,  -0.1528, -0.0849, 0.0509,  -0.1981, 0.0113,  0.2321,
    0.2773,  -0.1019, 0.4075,  0.0396,  0.0792,  0.1132,  -0.0906, -0.4188,
    0.1924,  -0.3679, -0.6396, 0.1358,  0.4981,  0.4132,  -0.0283, 0.3849,
    -0.3509, -0.0566, -0.0962, 0.3113,  -0.1811, 0.4019,  0.0453,  -0.0057,
    -0.1868, -0.2490, -0.0792, -0.3622, 0.1924,  -0.0453, -0.1528, -0.1811,
    0.5943,  -0.1302, 0.3170,  -0.0170, 0.0509,  -0.1528, -0.1755, 0.5547,
    0.2490,  -0.0906, 0.0000,  0.1698,  0.0000,  0.0340,  -0.1132, -0.0509,
    -0.1755, -0.2943, 0.1472,  0.0849,  0.0000,  0.1528,  -0.0566, 0.1528,
    -0.5264, -0.5320, -0.0736, 0.0566,  0.2604,  -0.4075, 0.0962,  -0.3453,
    -0.1415, 0.0057,  0.3905,  0.2830,  0.3679,  0.5320,  -0.2660, 0.0340,
    0.0736,  0.0057,  0.2207,  0.4471,  0.0849,  0.3000,  -0.0057, -0.0623,
    0.1415,  -0.0566, 0.5264,  -0.0340, 0.0226,  -0.0623, -0.0113, -0.5037,
    -0.4471, 0.0170,  -0.0396, -0.1358, -0.1698, 0.1924,  0.0057,  -0.1585,
    0.0849,  -0.1698, 0.0057,  -0.1245, -0.0170, -0.1755, -0.0792, 0.5264,
    0.1358,  0.2434,  0.1585,  -0.4188, -0.1472, -0.1358, -0.0849, -0.1189,
    0.5037,  0.0736,  -0.0453, -0.2434, 0.1868,  -0.0679, 0.1415,  -0.2717,
    0.2604,  0.0057,  -0.1528, -0.1811, 0.0226,  -0.1641, 0.3170,  -0.1981,
    0.1245,  0.0226,  0.0566,  0.2830,  -0.1755, 0.0396,  -0.2094, 0.1924,
    0.1698,  0.0283,  0.1641,  0.0849,  0.0000,  -0.1698, -0.1415, -0.3000,
    0.4471,  0.3056,  -0.0283, -0.4245, -0.0453, 0.0226,  0.0000,  -0.1075,
    -0.1528, -0.3226, 0.2773,  -0.2264, -0.1811, 0.1755,  -0.3566, -0.4188,
    0.1755,  -0.0057, 0.2038,  0.1075,  0.3679,  -0.0792, 0.2207,  -0.0453,
    0.3736,  0.2943,  -0.0113, -0.0623, 0.2264,  0.0113,  -0.0396, -0.2207,
    0.0453,  -0.2830, -0.1302, 0.0623,  -0.1924, -0.1811, -0.2717, 0.2830,
    0.2094,  0.0170,  -0.3170, -0.0283, -0.1189, -0.0509, -0.0566, -0.3622,
    0.1132,  -0.0906, 0.1132,  0.4019,  -0.4698, -0.1019, -0.1075, -0.2094,
    -0.2207, -0.0509, 0.0057,  0.1019,  -0.0509, 0.2264,  -0.5716, 0.0226,
    -0.4019, 0.1641,  -0.3000, 0.3849,  0.1245,  0.0679,  0.3056,  0.2377,
    0.0679,  -0.0170, -0.5377, -0.0170, 0.0057,  0.1358,  -0.1132, -0.2038,
    0.0679,  0.1075,  -0.2773, 0.5943,  0.0623,  -0.1472, 0.3566,  0.0396,
    -0.2377, 0.2604,  0.0849,  0.1358,  -0.3792, -0.0340, -0.1415, 0.3566,
    -0.3736, 0.1245,  0.0566,  0.3396,  0.0736,  0.4019,  -0.1528, 0.1075,
    0.0792,  -0.2547, 0.0453,  -0.1755, 0.1868,  -0.2547, 0.1075,  0.0623,
    0.1698,  -0.0170, 0.1585,  -0.0736, -0.4358, -0.0113, -0.6792, -0.0849,
    -0.0396, -0.6056, 0.1358,  0.1189,  0.2547,  0.1528,  0.2887,  0.0453,
    -0.1075, -0.3283, -0.0453, -0.0509, 0.2038,  0.2547,  0.0849,  -0.0566,
    -0.1698, 0.0509,  -0.0113, -0.1585, 0.1924,  -0.0792, -0.1868, 0.0509,
    -0.1698, -0.0849, -0.0170, 0.0453,  0.3170,  0.0906,  -0.5943, -0.1245,
    0.1585,  -0.1755, -0.2151, 0.0906,  0.1924,  0.3170,  -0.2490, -0.5660,
    -0.0283, 0.0962,  -0.1358, 0.1585,  0.0057,  -0.2604, 0.1189,  -0.0170,
    0.3509,  0.0623,  0.0679,  -0.1302, -0.0792, 0.0906,  -0.0792, 0.0849,
    -0.1924, 0.2604,  -0.1245, -0.3679, 0.0340,  0.0113,  -0.1698, 0.2490,
    0.0283,  0.1019,  -0.3736, 0.1019,  -0.2207, -0.0340, 0.3170,  0.1755,
    0.0962,  0.3226,  -0.0113, -0.1189, -0.2321, -0.0226, -0.2434, -0.0170,
    -0.1585, -0.0283, -0.1132, 0.0679,  -0.4188, -0.0453, 0.1528,  -0.1302,
    -0.3792, 0.1415,  -0.1358, -0.1811, 0.1302,  0.1415,  0.5207,  0.0509,
    -0.1358, -0.0396, -0.2434, 0.0396,  0.0792,  -0.2264, -0.1415, 0.0906,
    0.1245,  0.0170,  0.0623,  -0.1415, 0.2773,  -0.3566, -0.0396, 0.2887,
    0.4188,  0.1698,  -0.2547, 0.1132,  -0.0453, -0.0113, -0.1358, 0.1075,
    0.0566,  0.1075,  0.2604,  -0.0849, -0.2490, 0.1415,  0.0509,  -0.2151,
    0.0340,  0.1698,  0.0509,  -0.0906, 0.0566,  -0.1075, -0.2151, 0.2038,
    -0.1924, -0.0113, 0.2830,  0.1358,  -0.1189, 0.0113,  -0.5603, -0.2830,
    -0.2943, 0.0453,  -0.0396, 0.1358,  0.0566,  0.2038,  -0.3283, -0.0509,
    0.0509,  0.1641,  0.2094,  -0.2038, -0.1868, -0.1585, -0.2207, -0.1302,
    0.0396,  -0.1019, -0.0679, 0.1075,  -0.4584, -0.2207, 0.2434,  -0.0113,
    0.0849,  0.1755,  -0.3056, 0.1585,  -0.2547, 0.0453,  0.0906,  -0.1358,
    -0.0679, -0.0509, 0.0679,  -0.3509, 0.0057,  0.0453,  0.4132,  -0.1981,
    0.2264,  -0.0736, 0.1075,  0.0679,  -0.0906, -0.3113, 0.0509,  0.0849,
    0.2604,  0.0623,  -0.3113, 0.3849,  0.0000,  0.6396,  -0.2038, -0.1019,
    0.1245,  -0.0453, 0.1641,  0.1075,  -0.1075, -0.2660, -0.4528, -0.0566,
    -0.0170, 0.0453,  0.0340,  0.1189,  -0.2434, -0.0283, -0.1811, 0.2547,
    0.0000,  -0.0226, 0.4471,  0.1019,  -0.1472, 0.0849,  0.1075,  0.1075,
    0.0283,  -0.2773, 0.4415,  -0.1811, 0.2717,  0.3170,  0.0509,  0.0623,
    -0.0962, 0.1585,  -0.0792, -0.1811, -0.0792, -0.3283, 0.0962,  -0.1698,
    -0.0736, 0.0453,  0.0962,  -0.3566, -0.4584, 0.3396,  -0.4811, 0.3056,
    -0.1755, 0.2490,  -0.1698, -0.2377, -0.3339, -0.0453, 0.1811,  0.0736,
    0.0340,  -0.0962, -0.0113, -0.3056, -0.3339, 0.2038,  0.2038,  -0.1924,
    0.2547,  -0.4471, -0.0849, -0.2038, 0.3566,  -0.4811, 0.3453,  0.0849,
    0.1189,  0.3170,  -0.1358, 0.2717,  0.0113,  -0.4754, -0.1924, 0.4245,
    -0.2773, 0.3453,  0.2264,  0.2943,  0.5320,  0.2773,  -0.2264, -0.1019,
    -0.1132, -0.3962, 0.3679,  0.0509,  -0.0623, -0.0906, -0.5603, -0.1641,
    -0.3170, -0.2377, 0.1415,  -0.0509, 0.0792,  0.0170,  -0.0226, -0.0057,
    -0.1358, -0.4245, 0.3905,  0.3113,  0.0340,  -0.1189, 0.2887,  -0.2943,
    -0.3056, 0.2434,  0.1019,  -0.0170, 0.3849,  0.1528,  -0.0736, -0.0170,
    0.0792,  0.1755,  0.0509,  0.3509,  0.1472,  0.1528,  0.1472,  0.0057,
    0.0113,  -0.0113, -0.3283, -0.3962, -0.0792, -0.1245, -0.0283, -0.1868,
    0.4019,  0.2943,  -0.0906, -0.2321, 0.6056,  0.1189,  0.0340,  -0.2207,
    -0.0453, 0.3339,  0.2377,  -0.1641, 0.3736,  0.2151,  -0.2547, 0.0453,
    0.1924,  -0.1019, -0.0340, -0.2207, 0.3962,  -0.4471, -0.2547, -0.2151,
    -0.3736, 0.0283,  0.1189,  0.0283,  0.0736,  0.0396,  0.1019,  0.0283,
    0.0170,  0.2321,  0.3509,  -0.0226, -0.0226, 0.0736,  0.0283,  0.1641,
    -0.0906, 0.1811,  0.0226,  0.5716,  -0.0396, -0.0509, -0.1641, -0.0509,
    0.4132,  -0.2604, 0.1019,  -0.0283, -0.0340, 0.0453,  0.1472,  -0.0057,
    0.2717,  -0.2094, 0.3396,  0.0340,  0.1245,  0.2547,  -0.5886, 0.2717,
    -0.0906, 0.1641,  0.0962,  -0.0792, -0.0113, 0.2264,  -0.0736, 0.3170,
    0.0623,  0.0679,  0.0623,  -0.0792, -0.2207, 0.1924,  0.1245,  -0.2773};
// Tests size their quantized weight buffers from the size constant; a dropped
// or duplicated value in the table above is caught here.
static_assert(static_cast<int>(sizeof(representative_64x16_weights_data) /
                               sizeof(representative_64x16_weights_data[0])) ==
                  representative_64x16_weights_size,
              "representative_64x16_weights_size does not match the weights");
// One bias value per output unit (rank-prefixed dims: {1, 16}).
const int representative_64x16_bias_dims[] = {1, 16};
const float representative_64x16_bias_data[] = {
    -0.0084, 0.0006,  0.0000,  0.0000,  -0.0087, -0.0006, -0.0003, -0.0003,
    0.0006,  -0.0003, -0.0003, -0.0003, -0.0253, 0.0012,  0.0000,  0.0000};
// Expected float outputs for the 1x64 input times the 16x64 weights plus
// bias, with no activation.
const float representative_64x16_golden[] = {
    3.8624,  -2.9580, 4.3043,  -1.2844, -1.5769, -2.7998, -0.1011, -3.4029,
    -1.0557, -7.1931, -1.4852, -0.4163, 1.7186,  -0.6965, 0.3580,  2.7378};
const int representative_64x16_output_size = 16;
const int representative_64x16_output_dims[] = {2, 1, 16};

// Bias, golden and output buffers are all sized by the output size.
static_assert(static_cast<int>(sizeof(representative_64x16_bias_data) /
                               sizeof(representative_64x16_bias_data[0])) ==
                  representative_64x16_output_size,
              "representative bias must have one entry per output unit");
static_assert(static_cast<int>(sizeof(representative_64x16_golden) /
                               sizeof(representative_64x16_golden[0])) ==
                  representative_64x16_output_size,
              "representative_64x16_output_size does not match the golden");
228 
229 template <typename T>
ValidateFullyConnectedGoldens(TfLiteTensor * tensors,const int tensors_size,const TfLiteFusedActivation activation,const float tolerance,const int output_len,const T * golden,T * output_data)230 TfLiteStatus ValidateFullyConnectedGoldens(
231     TfLiteTensor* tensors, const int tensors_size,
232     const TfLiteFusedActivation activation, const float tolerance,
233     const int output_len, const T* golden, T* output_data) {
234   TfLiteFullyConnectedParams builtin_data = {
235       activation, kTfLiteFullyConnectedWeightsFormatDefault, false, false};
236 
237   int inputs_array_data[] = {3, 0, 1, 2};
238   TfLiteIntArray* inputs_array = IntArrayFromInts(inputs_array_data);
239   int outputs_array_data[] = {1, 3};
240   TfLiteIntArray* outputs_array = IntArrayFromInts(outputs_array_data);
241 
242   const TfLiteRegistration registration = Register_FULLY_CONNECTED();
243   micro::KernelRunner runner(registration, tensors, tensors_size, inputs_array,
244                              outputs_array,
245                              reinterpret_cast<void*>(&builtin_data));
246 
247   TfLiteStatus status = runner.InitAndPrepare();
248   if (status != kTfLiteOk) {
249     return status;
250   }
251 
252   status = runner.Invoke();
253   if (status != kTfLiteOk) {
254     return status;
255   }
256 
257   for (int i = 0; i < output_len; ++i) {
258     TF_LITE_MICRO_EXPECT_NEAR(golden[i], output_data[i], tolerance);
259   }
260   return kTfLiteOk;
261 }
262 
263 #if !defined(XTENSA)  // Needed to avoid build error from unused functions.
TestFullyConnectedFloat(const int * input_dims_data,const float * input_data,const int * weights_dims_data,const float * weights_data,const int * bias_dims_data,const float * bias_data,const float * golden,const int * output_dims_data,TfLiteFusedActivation activation,float * output_data)264 TfLiteStatus TestFullyConnectedFloat(
265     const int* input_dims_data, const float* input_data,
266     const int* weights_dims_data, const float* weights_data,
267     const int* bias_dims_data, const float* bias_data, const float* golden,
268     const int* output_dims_data, TfLiteFusedActivation activation,
269     float* output_data) {
270   TfLiteIntArray* input_dims = IntArrayFromInts(input_dims_data);
271   TfLiteIntArray* weights_dims = IntArrayFromInts(weights_dims_data);
272   TfLiteIntArray* bias_dims = IntArrayFromInts(bias_dims_data);
273   TfLiteIntArray* output_dims = IntArrayFromInts(output_dims_data);
274   const int output_dims_count = ElementCount(*output_dims);
275 
276   constexpr int inputs_size = 3;
277   constexpr int outputs_size = 1;
278   constexpr int tensors_size = inputs_size + outputs_size;
279   TfLiteTensor tensors[tensors_size] = {
280       CreateTensor(input_data, input_dims),
281       CreateTensor(weights_data, weights_dims),
282       CreateTensor(bias_data, bias_dims),
283       CreateTensor(output_data, output_dims),
284   };
285 
286   return ValidateFullyConnectedGoldens(tensors, tensors_size, activation, 1e-4f,
287                                        output_dims_count, golden, output_data);
288 }
289 #endif
290 
291 template <typename T>
TestFullyConnectedQuantized(const int * input_dims_data,const float * input_data,T * input_quantized,const float input_scale,const int input_zero_point,const int * weights_dims_data,const float * weights_data,T * weights_quantized,const float weights_scale,const int weights_zero_point,const int * bias_dims_data,const float * bias_data,int32_t * bias_quantized,const float * golden,T * golden_quantized,const int * output_dims_data,const float output_scale,const int output_zero_point,TfLiteFusedActivation activation,T * output_data)292 TfLiteStatus TestFullyConnectedQuantized(
293     const int* input_dims_data, const float* input_data, T* input_quantized,
294     const float input_scale, const int input_zero_point,
295     const int* weights_dims_data, const float* weights_data,
296     T* weights_quantized, const float weights_scale,
297     const int weights_zero_point, const int* bias_dims_data,
298     const float* bias_data, int32_t* bias_quantized, const float* golden,
299     T* golden_quantized, const int* output_dims_data, const float output_scale,
300     const int output_zero_point, TfLiteFusedActivation activation,
301     T* output_data) {
302   TfLiteIntArray* input_dims = IntArrayFromInts(input_dims_data);
303   TfLiteIntArray* weights_dims = IntArrayFromInts(weights_dims_data);
304   TfLiteIntArray* bias_dims = IntArrayFromInts(bias_dims_data);
305   TfLiteIntArray* output_dims = IntArrayFromInts(output_dims_data);
306   const int output_dims_count = ElementCount(*output_dims);
307 
308   constexpr int inputs_size = 3;
309   constexpr int outputs_size = 1;
310   constexpr int tensors_size = inputs_size + outputs_size;
311   TfLiteTensor tensors[tensors_size] = {
312       CreateQuantizedTensor(input_data, input_quantized, input_dims,
313                             input_scale, input_zero_point),
314       CreateQuantizedTensor(weights_data, weights_quantized, weights_dims,
315                             weights_scale, weights_zero_point),
316       CreateQuantizedBiasTensor(bias_data, bias_quantized, bias_dims,
317                                 input_scale, weights_scale),
318       CreateQuantizedTensor(output_data, output_dims, output_scale,
319                             output_zero_point),
320   };
321 
322   Quantize(golden, golden_quantized, output_dims_count, output_scale,
323            output_zero_point);
324 
325   return ValidateFullyConnectedGoldens(tensors, tensors_size, activation, 0.0f,
326                                        output_dims_count, golden_quantized,
327                                        output_data);
328 }
329 
330 }  // namespace
331 }  // namespace testing
332 }  // namespace tflite
333 
334 TF_LITE_MICRO_TESTS_BEGIN
335 
336 #if !defined(XTENSA)  // TODO(b/170503075): xtensa kernels are less general than
337                       // reference kernels and we ifdef out test cases that are
338                       // currently known to fail.
// Float kernel, no activation: the 2x10 simple input times the 3x10 simple
// weights plus bias, checked against simple_golden (2x3).
TF_LITE_MICRO_TEST(SimpleTest) {
  // Receives the kernel output; sized from simple_output_size.
  float output_data[tflite::testing::simple_output_size];
  TF_LITE_MICRO_EXPECT_EQ(
      tflite::testing::TestFullyConnectedFloat(
          tflite::testing::simple_input_dims,
          tflite::testing::simple_input_data,
          tflite::testing::simple_weights_dims,
          tflite::testing::simple_weights_data,
          tflite::testing::simple_bias_dims, tflite::testing::simple_bias_data,
          tflite::testing::simple_golden, tflite::testing::simple_output_dims,
          kTfLiteActNone, output_data),
      kTfLiteOk);
}
352 
// Same computation as SimpleTest with uint8 input/weights/output and an
// int32 bias. The zero points shift the signed test values into the uint8
// range (inputs span [-10, 9], weights span [1, 10]), and an output scale of
// 0.5 keeps the largest golden value representable (60 -> 127 + 120).
TF_LITE_MICRO_TEST(SimpleTestQuantizedUInt8) {
  const float input_scale = 1.0f;
  const int input_zero_point = 127;
  const float weights_scale = 1.0f;
  const int weights_zero_point = 128;
  const float output_scale = 0.5f;
  const int output_zero_point = 127;

  // Caller-owned storage for the quantized tensors, the int32 bias, the
  // quantized golden values and the kernel output.
  uint8_t input_quantized[tflite::testing::simple_input_size];
  uint8_t weights_quantized[tflite::testing::simple_weights_size];
  int32_t bias_quantized[tflite::testing::simple_output_size];
  uint8_t golden_quantized[tflite::testing::simple_output_size];
  uint8_t output_data[tflite::testing::simple_output_size];

  TF_LITE_MICRO_EXPECT_EQ(
      tflite::testing::TestFullyConnectedQuantized(
          tflite::testing::simple_input_dims,
          tflite::testing::simple_input_data, input_quantized, input_scale,
          input_zero_point, tflite::testing::simple_weights_dims,
          tflite::testing::simple_weights_data, weights_quantized,
          weights_scale, weights_zero_point, tflite::testing::simple_bias_dims,
          tflite::testing::simple_bias_data, bias_quantized,
          tflite::testing::simple_golden, golden_quantized,
          tflite::testing::simple_output_dims, output_scale, output_zero_point,
          kTfLiteActNone, output_data),
      kTfLiteOk);
}
380 #endif
381 
// Int8 variant of SimpleTest: weights use zero point 0, input and output use
// zero point -1, and the output scale of 0.5 maps the golden values
// 24..60 to 47..119, inside the int8 range.
TF_LITE_MICRO_TEST(SimpleTestQuantizedInt8) {
  const float input_scale = 1.0f;
  const int input_zero_point = -1;
  const float weights_scale = 1.0f;
  const int weights_zero_point = 0;
  const float output_scale = 0.5f;
  const int output_zero_point = -1;

  // Caller-owned storage for the quantized tensors, the int32 bias, the
  // quantized golden values and the kernel output.
  int8_t input_quantized[tflite::testing::simple_input_size];
  int8_t weights_quantized[tflite::testing::simple_weights_size];
  int32_t bias_quantized[tflite::testing::simple_output_size];
  int8_t golden_quantized[tflite::testing::simple_output_size];
  int8_t output_data[tflite::testing::simple_output_size];

  TF_LITE_MICRO_EXPECT_EQ(
      tflite::testing::TestFullyConnectedQuantized(
          tflite::testing::simple_input_dims,
          tflite::testing::simple_input_data, input_quantized, input_scale,
          input_zero_point, tflite::testing::simple_weights_dims,
          tflite::testing::simple_weights_data, weights_quantized,
          weights_scale, weights_zero_point, tflite::testing::simple_bias_dims,
          tflite::testing::simple_bias_data, bias_quantized,
          tflite::testing::simple_golden, golden_quantized,
          tflite::testing::simple_output_dims, output_scale, output_zero_point,
          kTfLiteActNone, output_data),
      kTfLiteOk);
}
409 
// Int8 variant where the same 20 input values are described as a 4D tensor
// of shape 1x1x2x10; the expected output is still the 2x3 simple golden.
TF_LITE_MICRO_TEST(SimpleTest4DInputQuantizedInt8) {
  const float input_scale = 1.0f;
  const int input_zero_point = -1;
  const float weights_scale = 1.0f;
  const int weights_zero_point = 0;

  const float output_scale = 0.5f;
  const int output_zero_point = -1;

  // Rank-prefixed dims: rank 4, shape 1x1x2x10.
  const int input_dims_4d[] = {4, 1, 1, 2, 10};

  // Caller-owned storage for the quantized tensors, the int32 bias, the
  // quantized golden values and the kernel output.
  int8_t input_quantized[tflite::testing::simple_input_size];
  int8_t weights_quantized[tflite::testing::simple_weights_size];
  int32_t bias_quantized[tflite::testing::simple_output_size];
  int8_t golden_quantized[tflite::testing::simple_output_size];
  int8_t output_data[tflite::testing::simple_output_size];

  TF_LITE_MICRO_EXPECT_EQ(
      tflite::testing::TestFullyConnectedQuantized(
          input_dims_4d, tflite::testing::simple_input_data, input_quantized,
          input_scale, input_zero_point, tflite::testing::simple_weights_dims,
          tflite::testing::simple_weights_data, weights_quantized,
          weights_scale, weights_zero_point, tflite::testing::simple_bias_dims,
          tflite::testing::simple_bias_data, bias_quantized,
          tflite::testing::simple_golden, golden_quantized,
          tflite::testing::simple_output_dims, output_scale, output_zero_point,
          kTfLiteActNone, output_data),
      kTfLiteOk);
}
439 
// Int8 with ReLU on the relu data set, where unit 1's pre-activation is
// negative and must clamp to 0. The output zero point of -128 puts real 0 at
// the bottom of the int8 range, since the expected outputs are all >= 0.
TF_LITE_MICRO_TEST(SimpleTestQuantizedInt8Relu) {
  const float input_scale = 1.0f;
  const int input_zero_point = -1;
  const float weights_scale = 1.0f;
  const int weights_zero_point = 0;

  const float output_scale = 0.5f;
  const int output_zero_point = -128;

  // Caller-owned storage for the quantized tensors, the int32 bias, the
  // quantized golden values and the kernel output.
  int8_t input_quantized[tflite::testing::relu_input_size];
  int8_t weights_quantized[tflite::testing::relu_weights_size];
  int32_t bias_quantized[tflite::testing::relu_output_size];
  int8_t golden_quantized[tflite::testing::relu_output_size];
  int8_t output_data[tflite::testing::relu_output_size];

  TF_LITE_MICRO_EXPECT_EQ(
      tflite::testing::TestFullyConnectedQuantized(
          tflite::testing::relu_input_dims, tflite::testing::relu_input_data,
          input_quantized, input_scale, input_zero_point,
          tflite::testing::relu_weights_dims,
          tflite::testing::relu_weights_data, weights_quantized, weights_scale,
          weights_zero_point, tflite::testing::relu_bias_dims,
          tflite::testing::relu_bias_data, bias_quantized,
          tflite::testing::relu_golden, golden_quantized,
          tflite::testing::relu_output_dims, output_scale, output_zero_point,
          kTfLiteActRelu, output_data),
      kTfLiteOk);
}
468 
469 #if !defined(XTENSA)  // TODO(b/170503075): xtensa kernels are less general than
470                       // reference kernels and we ifdef out test cases that are
471                       // currently known to fail.
// Uint8 with ReLU on the relu data set. Input/weight zero points 127/128
// shift the signed test values (including the negated weights row) into the
// uint8 range; output zero point 0 places real 0 at the bottom of the range.
TF_LITE_MICRO_TEST(SimpleTestQuantizedUInt8Relu) {
  const float input_scale = 1.0f;
  const int input_zero_point = 127;
  const float weights_scale = 1.0f;
  const int weights_zero_point = 128;

  const float output_scale = 0.5f;
  const int output_zero_point = 0;

  // Caller-owned storage for the quantized tensors, the int32 bias, the
  // quantized golden values and the kernel output.
  uint8_t input_quantized[tflite::testing::relu_input_size];
  uint8_t weights_quantized[tflite::testing::relu_weights_size];
  int32_t bias_quantized[tflite::testing::relu_output_size];
  uint8_t golden_quantized[tflite::testing::relu_output_size];
  uint8_t output_data[tflite::testing::relu_output_size];

  TF_LITE_MICRO_EXPECT_EQ(
      tflite::testing::TestFullyConnectedQuantized(
          tflite::testing::relu_input_dims, tflite::testing::relu_input_data,
          input_quantized, input_scale, input_zero_point,
          tflite::testing::relu_weights_dims,
          tflite::testing::relu_weights_data, weights_quantized, weights_scale,
          weights_zero_point, tflite::testing::relu_bias_dims,
          tflite::testing::relu_bias_data, bias_quantized,
          tflite::testing::relu_golden, golden_quantized,
          tflite::testing::relu_output_dims, output_scale, output_zero_point,
          kTfLiteActRelu, output_data),
      kTfLiteOk);
}
500 
// Float kernel with the simple input described as a 1x1x2x10 tensor; the
// expected output is still the 2x3 simple golden.
TF_LITE_MICRO_TEST(SimpleTest4DInput) {
  // Rank-prefixed dims: rank 4, shape 1x1x2x10.
  const int input_dims_4d[] = {4, 1, 1, 2, 10};

  float output_data[tflite::testing::simple_output_size];

  TF_LITE_MICRO_EXPECT_EQ(
      tflite::testing::TestFullyConnectedFloat(
          input_dims_4d, tflite::testing::simple_input_data,
          tflite::testing::simple_weights_dims,
          tflite::testing::simple_weights_data,
          tflite::testing::simple_bias_dims, tflite::testing::simple_bias_data,
          tflite::testing::simple_golden, tflite::testing::simple_output_dims,
          kTfLiteActNone, output_data),
      kTfLiteOk);
}
516 
// Uint8 variant of SimpleTest4DInput, using the same quantization parameters
// as SimpleTestQuantizedUInt8.
TF_LITE_MICRO_TEST(SimpleTest4DInputQuantizedUInt8) {
  const float input_scale = 1.0f;
  const int input_zero_point = 127;
  const float weights_scale = 1.0f;
  const int weights_zero_point = 128;

  const float output_scale = 0.5f;
  const int output_zero_point = 127;

  // Rank-prefixed dims: rank 4, shape 1x1x2x10.
  const int input_dims_4d[] = {4, 1, 1, 2, 10};

  // Caller-owned storage for the quantized tensors, the int32 bias, the
  // quantized golden values and the kernel output.
  uint8_t input_quantized[tflite::testing::simple_input_size];
  uint8_t weights_quantized[tflite::testing::simple_weights_size];
  int32_t bias_quantized[tflite::testing::simple_output_size];
  uint8_t golden_quantized[tflite::testing::simple_output_size];
  uint8_t output_data[tflite::testing::simple_output_size];

  TF_LITE_MICRO_EXPECT_EQ(
      tflite::testing::TestFullyConnectedQuantized(
          input_dims_4d, tflite::testing::simple_input_data, input_quantized,
          input_scale, input_zero_point, tflite::testing::simple_weights_dims,
          tflite::testing::simple_weights_data, weights_quantized,
          weights_scale, weights_zero_point, tflite::testing::simple_bias_dims,
          tflite::testing::simple_bias_data, bias_quantized,
          tflite::testing::simple_golden, golden_quantized,
          tflite::testing::simple_output_dims, output_scale, output_zero_point,
          kTfLiteActNone, output_data),
      kTfLiteOk);
}
546 
// Float kernel on the representative data: 1x64 input times 16x64 weights
// plus bias, no activation, checked against representative_64x16_golden.
TF_LITE_MICRO_TEST(Representative1x64Input1x16Output) {
  float output_data[tflite::testing::representative_64x16_output_size];

  TF_LITE_MICRO_EXPECT_EQ(
      tflite::testing::TestFullyConnectedFloat(
          tflite::testing::representative_64x16_input_dims,
          tflite::testing::representative_64x16_input_data,
          tflite::testing::representative_64x16_weights_dims,
          tflite::testing::representative_64x16_weights_data,
          tflite::testing::representative_64x16_bias_dims,
          tflite::testing::representative_64x16_bias_data,
          tflite::testing::representative_64x16_golden,
          tflite::testing::representative_64x16_output_dims, kTfLiteActNone,
          output_data),
      kTfLiteOk);
}
563 
// Uint8 run of the representative 1x64 -> 1x16 case. The scales are float
// literals (`f` suffix) so they are not implicitly narrowed from double,
// matching the `1.0f` / `0.5f` style used by the other tests.
// Input zero point 0 suits the input data, which contains no negative values;
// weights use zero point 128 to center the signed weights in the uint8 range.
TF_LITE_MICRO_TEST(Representative1x64Input1x16OutputQuantizedUInt8) {
  const float input_scale = 0.051445f;
  const int input_zero_point = 0;
  const float weights_scale = 0.005660f;
  const int weights_zero_point = 128;

  const float output_scale = 0.069785f;
  const int output_zero_point = 119;

  // Caller-owned storage for the quantized tensors, the int32 bias, the
  // quantized golden values and the kernel output.
  uint8_t input_quantized[tflite::testing::representative_64x16_input_size];
  uint8_t weights_quantized[tflite::testing::representative_64x16_weights_size];
  int32_t bias_quantized[tflite::testing::representative_64x16_output_size];
  uint8_t golden_quantized[tflite::testing::representative_64x16_output_size];
  uint8_t output_data[tflite::testing::representative_64x16_output_size];

  TF_LITE_MICRO_EXPECT_EQ(
      tflite::testing::TestFullyConnectedQuantized(
          tflite::testing::representative_64x16_input_dims,
          tflite::testing::representative_64x16_input_data, input_quantized,
          input_scale, input_zero_point,
          tflite::testing::representative_64x16_weights_dims,
          tflite::testing::representative_64x16_weights_data, weights_quantized,
          weights_scale, weights_zero_point,
          tflite::testing::representative_64x16_bias_dims,
          tflite::testing::representative_64x16_bias_data, bias_quantized,
          tflite::testing::representative_64x16_golden, golden_quantized,
          tflite::testing::representative_64x16_output_dims, output_scale,
          output_zero_point, kTfLiteActNone, output_data),
      kTfLiteOk);
}
594 
595 #endif
596 
// Int8 run of the representative 1x64 -> 1x16 case. The scales are float
// literals (`f` suffix) so they are not implicitly narrowed from double,
// matching the `1.0f` / `0.5f` style used by the other tests.
// The input data contains no negative values, so input zero point -128 puts
// real 0 at the bottom of the int8 range; weights use zero point 0.
TF_LITE_MICRO_TEST(Representative1x64Input1x16OutputQuantizedInt8) {
  const float input_scale = 0.051445f;
  const int input_zero_point = -128;
  const float weights_scale = 0.005660f;
  const int weights_zero_point = 0;

  const float output_scale = 0.069785f;
  const int output_zero_point = -9;

  // Caller-owned storage for the quantized tensors, the int32 bias, the
  // quantized golden values and the kernel output.
  int8_t input_quantized[tflite::testing::representative_64x16_input_size];
  int8_t weights_quantized[tflite::testing::representative_64x16_weights_size];
  int32_t bias_quantized[tflite::testing::representative_64x16_output_size];
  int8_t golden_quantized[tflite::testing::representative_64x16_output_size];
  int8_t output_data[tflite::testing::representative_64x16_output_size];

  TF_LITE_MICRO_EXPECT_EQ(
      tflite::testing::TestFullyConnectedQuantized(
          tflite::testing::representative_64x16_input_dims,
          tflite::testing::representative_64x16_input_data, input_quantized,
          input_scale, input_zero_point,
          tflite::testing::representative_64x16_weights_dims,
          tflite::testing::representative_64x16_weights_data, weights_quantized,
          weights_scale, weights_zero_point,
          tflite::testing::representative_64x16_bias_dims,
          tflite::testing::representative_64x16_bias_data, bias_quantized,
          tflite::testing::representative_64x16_golden, golden_quantized,
          tflite::testing::representative_64x16_output_dims, output_scale,
          output_zero_point, kTfLiteActNone, output_data),
      kTfLiteOk);
}
627 
628 TF_LITE_MICRO_TESTS_END
629