1 /*
2 * Copyright (c) 2019, Alliance for Open Media. All rights reserved
3 *
4 * This source code is subject to the terms of the BSD 2 Clause License and
5 * the Alliance for Open Media Patent License 1.0. If the BSD 2 Clause License
6 * was not distributed with this source code in the LICENSE file, you can
7 * obtain it at www.aomedia.org/license/software. If the Alliance for Open
8 * Media Patent License 1.0 was not distributed with this source code in the
9 * PATENTS file, you can obtain it at www.aomedia.org/license/patent.
10 */
11
12 #include <assert.h>
13 #include <math.h>
14 #include <stdio.h>
15
16 #include "third_party/googletest/src/googletest/include/gtest/gtest.h"
17
18 #include "config/av1_rtcd.h"
19
20 #include "av1/encoder/cnn.h"
21
// Squares its argument. Kept as a macro so it works for any arithmetic type;
// the argument is evaluated twice, so do not pass expressions with side
// effects.
#define SQR(x) ((x) * (x))

// Best pixelwise precision that can be guaranteed, given that each float has
// at most 3 specified decimals.
constexpr double PIXELWISE_FLOAT_TOL = 1E-2;

// Mean-squared-error tolerances: MSE_FLOAT_TOL for tests with fractional data,
// MSE_INT_TOL (exact match) for tests whose data are all integers.
constexpr double MSE_FLOAT_TOL = 1E-6;
constexpr int MSE_INT_TOL = 0;
30
31 namespace {
32
33 class CNNTest : public ::testing::Test {
34 protected:
RunCNNTest(int image_width,int image_height,const float * input,const float * expected,const CNN_CONFIG * cnn_config,int in_stride,CNN_THREAD_DATA * thread_data,double tolerance)35 static void RunCNNTest(int image_width, int image_height, const float *input,
36 const float *expected, const CNN_CONFIG *cnn_config,
37 int in_stride, CNN_THREAD_DATA *thread_data,
38 double tolerance) {
39 int out_width, out_height, out_channels;
40 av1_find_cnn_output_size(image_width, image_height, cnn_config, &out_width,
41 &out_height, &out_channels);
42
43 const int out_size = out_width * out_height;
44 const int out_stride = out_width;
45
46 float *output_ =
47 (float *)aom_malloc(sizeof(*output_) * out_size * out_channels);
48 float *output[CNN_MAX_CHANNELS] = { nullptr };
49 for (int channel = 0; channel < out_channels; ++channel) {
50 output[channel] = output_ + (channel * out_size);
51 }
52 const int num_outputs = 1;
53 const int output_chs[1] = { out_channels };
54 const int output_strides[1] = { out_stride };
55 CNN_MULTI_OUT output_struct = { num_outputs, output_chs, output_strides,
56 output };
57
58 RunMultiOutCNNTest(&input, image_width, image_height, in_stride, cnn_config,
59 thread_data, &output_struct, &expected, tolerance);
60
61 aom_free(output_);
62 }
63
RunMultiOutCNNTest(const float ** input,int image_width,int image_height,int in_stride,const CNN_CONFIG * cnn_config,CNN_THREAD_DATA * thread_data,CNN_MULTI_OUT * output,const float ** expected,double tolerance)64 static void RunMultiOutCNNTest(const float **input, int image_width,
65 int image_height, int in_stride,
66 const CNN_CONFIG *cnn_config,
67 CNN_THREAD_DATA *thread_data,
68 CNN_MULTI_OUT *output, const float **expected,
69 double tolerance) {
70 const int num_outputs = output->num_outputs;
71 const int *output_chs = output->output_channels;
72
73 int *out_widths = (int *)aom_calloc(num_outputs, sizeof(*out_widths));
74 int *out_heights = (int *)aom_calloc(num_outputs, sizeof(*out_heights));
75 int *not_used = (int *)aom_calloc(num_outputs, sizeof(*not_used));
76
77 av1_find_cnn_output_size(image_width, image_height, cnn_config, out_widths,
78 out_heights, not_used);
79 av1_cnn_predict(input, image_width, image_height, in_stride, cnn_config,
80 thread_data, output);
81
82 int channel_offset = 0;
83 for (int output_idx = 0; output_idx < num_outputs; output_idx++) {
84 const float *expected_out = expected[output_idx];
85 const int curr_output_chs = output_chs[output_idx];
86 const int out_size = out_widths[output_idx] * out_heights[output_idx];
87
88 double mse = 0;
89 int expected_ite = 0;
90 for (int channel = 0; channel < curr_output_chs; ++channel) {
91 const float *buf_out = output->output_buffer[channel_offset];
92
93 for (int i = 0; i < out_size; ++i) {
94 EXPECT_NEAR(expected_out[expected_ite], buf_out[i],
95 PIXELWISE_FLOAT_TOL)
96 << " output " << output_idx << " channel " << channel << " pixel "
97 << expected_ite % out_size << ": " << expected_out[expected_ite]
98 << "/" << buf_out[i] << std::endl;
99 mse += SQR(expected_out[expected_ite] - buf_out[i]);
100 expected_ite++;
101 }
102
103 channel_offset++;
104 }
105 mse /= (out_size * curr_output_chs);
106 EXPECT_LE(mse, tolerance) << " output " << output_idx << std::endl;
107 }
108
109 aom_free(out_widths);
110 aom_free(out_heights);
111 aom_free(not_used);
112 }
113
AssignLayerWeightsBiases(CNN_CONFIG * cnn_config,float * weights,float * bias)114 static void AssignLayerWeightsBiases(CNN_CONFIG *cnn_config, float *weights,
115 float *bias) {
116 size_t weight_offset = 0;
117 size_t bias_offset = 0;
118 for (int layer = 0; layer < cnn_config->num_layers; ++layer) {
119 CNN_LAYER_CONFIG *layer_config = &cnn_config->layer_config[layer];
120 layer_config->weights = weights + weight_offset;
121 layer_config->bias = bias + bias_offset;
122 weight_offset += layer_config->filter_width *
123 layer_config->filter_height * layer_config->in_channels *
124 layer_config->out_channels;
125 bias_offset += layer_config->out_channels;
126
127 ASSERT_NE(layer_config->weights, nullptr);
128 ASSERT_NE(layer_config->bias, nullptr);
129 }
130 }
131 };
132
133 } // namespace
134
TEST_F(CNNTest, TestMultilayerConvolution) {
  const int image_height = 16;
  const int image_width = 16;
  const int filter_height = 5;
  const int filter_width = 4;

  const float input[] = {
    -3, 1,  -3, 2,  -2, -2, 2,  -2, 1,  -2, -3, 1,  2,  2,  2,  -2, 0,  1,  -1,
    -3, -1, -1, 1,  0,  -3, 1,  0,  -1, 1,  0,  0,  -3, -3, -3, 0,  2,  1,  -1,
    2,  0,  1,  -3, -1, 2,  2,  1,  -2, 0,  -1, 0,  -2, -2, -1, 1,  0,  0,  0,
    -2, -2, -2, 1,  1,  -2, 1,  1,  -2, -2, 1,  -2, -1, -2, -3, 2,  -3, -1, 1,
    0,  -2, -2, -2, 1,  -2, -2, -1, -1, 2,  2,  2,  -1, 1,  -3, -3, 0,  2,  0,
    2,  1,  -3, -3, 1,  2,  2,  1,  -2, -3, 0,  -3, 0,  -3, -2, 0,  1,  1,  0,
    -3, 2,  -1, 2,  1,  0,  1,  -2, 1,  -1, -1, 2,  0,  -2, -3, 1,  1,  -2, -1,
    -3, -3, -1, 0,  -3, -2, 0,  0,  1,  0,  -3, -2, -1, 1,  0,  2,  1,  0,  -3,
    -2, -3, -3, -1, 0,  -2, 2,  -1, -3, 0,  -1, -1, 2,  0,  -3, -2, -1, 0,  0,
    1,  -2, 1,  2,  1,  2,  2,  -3, 2,  -1, 0,  0,  -1, 0,  2,  2,  -1, 2,  -2,
    1,  1,  -3, -3, 1,  -1, -1, -2, 2,  -2, -2, 2,  -1, -3, 2,  -3, 1,  -1, -1,
    -3, 1,  -1, 1,  0,  -3, -3, 1,  -3, -3, 0,  2,  2,  -2, -1, 2,  0,  2,  1,
    -1, -3, 0,  0,  -1, -1, 1,  0,  2,  0,  -3, 2,  1,  0,  1,  -3, 2,  -3, -3,
    -1, -3, -3, 2,  0,  2,  -2, 1,  -1,
  };

  // Weights and biases of all three layers packed back to back; they are
  // sliced per layer by AssignLayerWeightsBiases() below.
  float weights[] = {
    -2, 2,  -2, 2,  -1, -3, 2,  2,  0,  0,  -3, -1, -2, -3, 1,  -1, 0,  0,  0,
    2,  -2, 2,  -2, -3, 1,  1,  1,  -3, -1, 0,  1,  2,  -2, 0,  -1, -3, -1, -2,
    2,  -3, -3, 1,  -2, -3, 0,  2,  1,  -3, -3, -1, -3, -2, -1, -3, -1, -3, -2,
    -1, -3, -1, -2, -2, -3, 2,  0,  -3, 0,  -3, -3, 1,  -3, -1, 0,  -1, 1,  1,
    -1, 1,  -2, 0,  2,  0,  -3, 1,  -1, -1, 2,  0,  1,  -3, -3, 1,  2,  -3, -3,
    1,  -3, 2,  0,  -3, 1,  2,  2,  -2, -1, -2, 1,  1,  0,  -2, -2, 1,  2,  -1,
    -3, 1,  -2, 2,  -3, -2, -3, 2,  1,  0,  -2, 0,  1,  -3, 2,  -2, -2, 0,  2,
    -3, 2,  0,  0,  1,  -2, 1,  1,  -2, -1, -2, 1,  -2, 0,  -2, -2, 0,  -1, -1,
    -3, -3, -3, 1,  -3, -2, 2,  -1, 2,  0,  2,  -2, 2,  -2, 1,  -3, -3, -1, 0,
    2,  2,  1,  -1, -3, -1, -3, 2,  1,  -2, 0,  -3, -1, -3, -1, 2,  1,  0,  2,
    -1, 1,  0,  1,  2,  -1, -2, 2,  1,  -3, -1, -3, 0,  1,  -2, 0,  -2, -3, 0,
    -2, 2,  2,  0,  0,  2,  -3, 2,  -3, -2, 1,  2,  -3, -3, -1, -3, 0,  -3, -3,
    -2, -2, -2, 0,  0,  1,  0,  0,  -1, 0,  0,  -3, 0,  -3, -1, -2, 1,  -2, -1,
    2,  -2, 0,  0,  1,  0,  -2, -1, 0,  -3, 1,  0,  -1, -3, 1,  -1, 1,  -1, -3,
    1,  0,  1,  1,  -1, 2,  2,  0,  0,  1,  -3, 2,  -2, -2, -3, -2, -1, -2, 2,
    0,  2,  -2, -3, -1, -3, 2,  2,  -1, 2,  2,  -1, 0,  -3, 1,
  };

  float bias[] = {
    1, -1, 0, 1, 1, 1, -2,
  };

  const float expected_same[] = {
    -1125, 2926,  6406,  631,   -1244, 97,    -1454, 2526,  1065,  3292,  3464,
    2553,  -330,  532,   1038,  1182,  -402,  3758,  3392,  9854,  4365,  1408,
    4736,  3134,  3838,  2409,  3221,  4350,  6750,  4045,  815,   1188,  2959,
    9802,  9590,  4572,  5740,  4253,  1701,  7974,  7012,  6854,  7093,  3907,
    4539,  3886,  4267,  3505,  465,   7824,  9219,  10026, 7968,  957,   2295,
    5594,  10811, 9641,  5950,  10043, 8783,  3132,  1421,  1110,  4108,  13929,
    10660, -84,   -61,   3932,  -180,  6811,  13393, 15147, 15640, 9337,  6961,
    3808,  1604,  1398,  1047,  6739,  10144, 6517,  4698,  2678,  7389,  2595,
    5248,  12075, 11272, 13951, 8820,  1090,  2199,  2206,  2788,  12116, 6683,
    2612,  -291,  3183,  9414,  12316, 14524, 12333, 13208, 7832,  4664,  4657,
    3534,  1298,  -666,  4250,  7707,  9103,  5760,  688,   9571,  15782, 14203,
    14878, 17339, 14684, 8690,  5671,  875,   1429,  1531,  6173,  2984,  5558,
    2996,  7928,  6733,  16117, 15262, 12757, 7980,  3923,  4795,  5973,  2051,
    455,   -1922, 1816,  5906,  3321,  10908, 10910, 7377,  12204, 12809, 11195,
    7451,  6666,  74,    -1645, -35,   -391,  3813,  7324,  892,   1656,  6095,
    12193, 14648, 12156, 14663, 10251, 10325, 7821,  3925,  323,   697,   442,
    1324,  4669,  7002,  5485,  5171,  5086,  10582, 11053, 9709,  11353, 8543,
    5256,  2873,  235,   -628,  1496,  1878,  -867,  3420,  6865,  5937,  10182,
    13277, 10069, 10789, 5998,  624,   -2082, 4417,  1258,  -1080, -819,  -1430,
    1033,  5220,  6335,  8471,  8980,  11908, 14430, 12584, 8404,  1576,  -803,
    985,   1481,  1367,  -193,  873,   3684,  2288,  6676,  9477,  11155, 9602,
    9707,  10507, 4739,  3174,  -575,  -178,  3002,  1710,  423,   -477,  554,
    3088,  2029,  5113,  5000,  3771,  6090,  5365,  1185,  2855,  399,   -312,
    -1577, 176,   955,
  };

  const float expected_replicate[] = {
    13768, 13528, 12999, 6906,  4618,  4043,  2611,  9955,  6685,  4776,  2753,
    1036,  3063,  4544,  5183,  7349,  12451, 12501, 9131,  12753, 8908,  4058,
    6299,  7542,  7115,  3307,  3360,  3543,  9754,  7808,  5991,  9019,  14320,
    14919, 12492, 6871,  7373,  3336,  2085,  10604, 9377,  6882,  5009,  3103,
    6220,  6278,  7588,  10196, 11045, 11563, 11842, 11911, 8279,  2030,  1858,
    6368,  12123, 9909,  6347,  10345, 9365,  4038,  1673,  3051,  16492, 16649,
    12276, 408,   -301,  4122,  -654,  7864,  14038, 15279, 15315, 9744,  8243,
    5298,  746,   380,   9824,  9124,  10895, 6640,  4712,  2669,  6980,  2759,
    5385,  12345, 11336, 13129, 8600,  2370,  3682,  5219,  12407, 13123, 6784,
    2612,  -291,  3183,  9414,  12316, 14524, 12333, 13397, 7543,  3916,  4153,
    4477,  4314,  7983,  8418,  9163,  9103,  5760,  688,   9571,  15782, 14203,
    14878, 17718, 14570, 7940,  6642,  5094,  7133,  9964,  10219, 3224,  5558,
    2996,  7928,  6733,  16117, 15262, 12757, 7958,  4401,  5187,  5476,  5529,
    6055,  2206,  3909,  6015,  3321,  10908, 10910, 7377,  12204, 12809, 11195,
    6967,  6840,  481,   -1600, 274,   1,     10373, 8514,  1123,  2117,  6758,
    12736, 16223, 13585, 15988, 11771, 10600, 7918,  4156,  2840,  3111,  3287,
    6359,  7652,  8813,  6530,  6967,  7789,  13671, 13990, 13247, 13241, 9836,
    5251,  3024,  2313,  1834,  4187,  2637,  -1312, 2139,  7378,  7665,  11933,
    15591, 15314, 15678, 9531,  2820,  -1516, 3400,  1314,  22,    363,   -2896,
    -898,  5906,  7308,  10650, 12975, 16978, 20370, 18817, 12381, 4118,  -861,
    -137,  236,   1802,  1632,  -350,  2334,  3400,  8680,  14064, 18216, 18675,
    21765, 22871, 11491, 4937,  -1555, -11,   1669,  2392,  3265,  -5254, -217,
    5001,  8063,  13444, 18884, 19706, 22794, 21064, 9545,  6689,  -7,    289,
    -2021, 504,   2347,
  };

  const float expected_valid[] = {
    2612,  -291,  3183,  9414,  12316, 14524, 12333, 9103,  5760,  688,
    9571,  15782, 14203, 14878, 5558,  2996,  7928,  6733,  16117, 15262,
    12757, 3321,  10908, 10910, 7377,  12204, 12809, 11195,
  };

  // Three layers with 4x5 (width x height) filters. Fields are positional (see
  // CNN_LAYER_CONFIG); the weight/bias pointers start as nullptr and are filled
  // in by AssignLayerWeightsBiases(). Only the last layer has output_num 0.
  CNN_CONFIG cnn_config = {
    3,
    0,
    0,
    0,
    0,
    {
        { 1, filter_width, filter_height, 3, 1, 1, 0, nullptr, nullptr,
          PADDING_SAME_ZERO, NONE, 0, 0, BRANCH_NO_COPY, BRANCH_NOC, {}, {},
          -1 },
        { 3, filter_width, filter_height, 3, 1, 1, 0, nullptr, nullptr,
          PADDING_SAME_ZERO, NONE, 0, 0, BRANCH_NO_COPY, BRANCH_NOC, {}, {},
          -1 },
        { 3, filter_width, filter_height, 1, 1, 1, 0, nullptr, nullptr,
          PADDING_SAME_ZERO, NONE, 0, 0, BRANCH_NO_COPY, BRANCH_NOC, {}, {},
          0 },
    },
  };

  // Weights and biases are assigned separately because each layer starts at
  // a different offset into the packed arrays. The helper uses ASSERT_*, which
  // only aborts the helper, so propagate any fatal failure here.
  ASSERT_NO_FATAL_FAILURE(AssignLayerWeightsBiases(&cnn_config, weights, bias));

  CNN_THREAD_DATA thread_data = { 1, nullptr };

  RunCNNTest(image_width, image_height, input, expected_same, &cnn_config,
             image_width, &thread_data, MSE_INT_TOL);

  for (int i = 0; i < cnn_config.num_layers; ++i) {
    cnn_config.layer_config[i].pad = PADDING_SAME_REPLICATE;
  }

  RunCNNTest(image_width, image_height, input, expected_replicate, &cnn_config,
             image_width, &thread_data, MSE_INT_TOL);

  for (int i = 0; i < cnn_config.num_layers; ++i) {
    cnn_config.layer_config[i].pad = PADDING_VALID;
  }

  RunCNNTest(image_width, image_height, input, expected_valid, &cnn_config,
             image_width, &thread_data, MSE_INT_TOL);
}
332
TEST_F(CNNTest, TestRELUSingleLayer) {
  const int image_width = 8;
  const int image_height = 8;
  const int filter_height = 5;
  const int filter_width = 4;
  const float input[] = {
    0,  -2, -3, 1,  -1, 2,  -2, 1,  -3, -1, 0,  1,  -2, -3, -2, -2,
    1,  -3, 2,  -3, -1, -1, 2,  0,  -2, -3, 0,  -2, -3, 1,  -1, -1,
    2,  -2, 0,  -2, -3, -3, 1,  1,  -1, 1,  0,  1,  -3, 0,  2,  2,
    0,  -3, 1,  -3, 2,  -2, 1,  -1, -1, -2, -3, -2, -1, -3, -2, -1,
  };
  const float expected_same[] = {
    9,  0,  1,  1,  0,  3,  0,  19, 0,  12, 10, 0,  0,  0, 5, 0,
    0,  18, 21, 7,  19, 4,  3,  0,  0,  9,  16, 0,  11, 16, 0, 11,
    12, 2,  0,  11, 0,  16, 6,  0,  8,  22, 13, 10, 12, 0, 0, 0,
    0,  1,  2,  12, 29, 6,  10, 0,  13, 0,  0,  5,  8,  10, 0, 0,
  };
  const float expected_replicate[] = {
    18, 17, 12, 2,  0,  0,  5,  11, 0,  17, 22, 6,  0,  0,  17, 0,
    0,  18, 21, 7,  19, 4,  3,  5,  3,  9,  16, 0,  11, 16, 0,  3,
    3,  2,  0,  11, 0,  16, 6,  0,  17, 22, 13, 10, 12, 0,  0,  0,
    0,  4,  1,  10, 30, 7,  10, 0,  23, 8,  0,  13, 15, 19, 8,  10,
  };
  const float expected_valid[] = {
    18, 21, 7, 19, 4, 9, 16, 0, 11, 16, 2, 0, 11, 0, 16, 22, 13, 10, 12, 0,
  };
  float weights[] = {
    -2, -3, 1, 2, 2, -2, -3, 0, -3, 2, 2, -3, -3, -2, 0, 1, 2, 0, -1, -1,
  };
  float bias[] = { -3 };

  // One 4x5 (width x height) layer with RELU activation. Fields are
  // positional; see CNN_LAYER_CONFIG.
  CNN_CONFIG cnn_config = {
    1,
    0,
    0,
    0,
    0,
    { { 1, filter_width, filter_height, 1, 1, 1, 0, weights, bias,
        PADDING_SAME_ZERO, RELU, 0, 0, BRANCH_NO_COPY, BRANCH_NOC, {}, {},
        0 } },
  };

  CNN_THREAD_DATA thread_data = { 1, nullptr };

  RunCNNTest(image_width, image_height, input, expected_same, &cnn_config,
             image_width, &thread_data, MSE_INT_TOL);

  cnn_config.layer_config[0].pad = PADDING_SAME_REPLICATE;

  RunCNNTest(image_width, image_height, input, expected_replicate, &cnn_config,
             image_width, &thread_data, MSE_INT_TOL);

  cnn_config.layer_config[0].pad = PADDING_VALID;

  RunCNNTest(image_width, image_height, input, expected_valid, &cnn_config,
             image_width, &thread_data, MSE_INT_TOL);
}
405
TEST_F(CNNTest, TestVaryingStridesVaryingDimImages) {
  float weights[] = {
    1,  -5, -3, -4, -1, 1,  2,  -3, 2,  2,  -1, 1,  -5, 1,  1,
    -3, -5, 3,  1,  4,  -2, -5, -2, -3, -5, 0,  -1, -5, 2,  -2,
    -2, 1,  -2, -4, 1,  3,  -2, 2,  0,  -3, 2,  -3, -2, -3,
  };
  float bias[] = { 2 };

  // One layer with a 4x11 (width x height) filter and initial strides of 7
  // (width) and 6 (height). Fields are positional; see CNN_LAYER_CONFIG.
  CNN_CONFIG cnn_config = {
    1,
    0,
    0,
    0,
    0,
    {
        { 1, 4, 11, 1, 7, 6, 0, weights, bias, PADDING_SAME_ZERO, NONE, 0, 0,
          BRANCH_NO_COPY, BRANCH_NOC, {}, {}, 0 },
    },
  };

  const int image_height = 24;
  const int image_width = 17;
  const float input[] = {
    -1, -3, 4,  4,  -5, 4,  3,  -5, -1, -3, 4,  -4, 2,  -3, 3,  -5, 2,  -1, -5,
    1,  -1, 3,  1,  -3, -3, 4,  0,  2,  -3, -5, -5, -4, 0,  -5, -2, -3, -1, -2,
    2,  -5, 4,  4,  0,  -4, -3, 1,  -3, -5, -4, -4, 1,  -2, -3, 3,  -3, -3, -1,
    -5, -5, -2, 3,  1,  -1, -5, -5, 1,  -4, -2, -1, -2, -4, -4, 2,  -2, 2,  1,
    -2, -4, -1, 1,  -2, -5, 3,  -2, -1, -1, -5, -3, 1,  -2, -2, -3, -1, -2, -4,
    -2, 1,  -4, -1, 4,  3,  -4, 0,  4,  2,  2,  4,  -3, -5, 2,  2,  1,  -1, -4,
    -2, 1,  3,  2,  0,  4,  -1, -3, 2,  1,  -4, 2,  2,  -4, -2, 0,  -2, -1, 4,
    4,  2,  3,  -4, 2,  -4, -5, 4,  -1, -3, -1, 0,  -4, 1,  3,  -1, -3, -5, 3,
    -2, -4, 1,  2,  -2, -3, -3, -5, 1,  -3, -1, 0,  -1, 3,  -4, -1, -5, -5, 1,
    0,  0,  -2, -2, 2,  -2, 0,  0,  2,  0,  -3, 0,  -1, -4, -4, -1, 3,  -4, -4,
    -1, 0,  -5, -3, -2, 4,  -3, -4, -4, 0,  -5, 1,  -2, -3, -3, -4, 4,  3,  4,
    3,  3,  -1, 3,  1,  -3, -2, 3,  3,  0,  2,  -4, -3, 2,  2,  0,  -2, 4,  -2,
    2,  -2, -1, -4, -2, 2,  -4, 3,  -1, 4,  1,  1,  4,  -1, -4, -4, 1,  1,  -2,
    4,  -1, 3,  2,  -3, 4,  3,  1,  4,  0,  -4, 2,  0,  2,  4,  -2, -2, 4,  2,
    -1, -2, 1,  -3, 2,  3,  -5, -3, 4,  4,  2,  -5, -4, -5, -2, -4, 2,  0,  2,
    -5, 4,  -4, -2, -5, 2,  1,  0,  4,  1,  -2, -3, -4, -3, -4, 3,  3,  2,  0,
    -3, 1,  -5, 4,  0,  4,  -1, 3,  -5, -5, -2, -1, -1, 4,  3,  3,  4,  3,  -4,
    4,  -3, -3, -1, -4, -1, -4, -1, -2, 4,  -2, -4, 4,  4,  -3, -4, -1, 1,  2,
    -1, -2, -2, 3,  2,  2,  -3, 0,  -1, 0,  3,  2,  -5, 0,  -4, 0,  0,  2,  -4,
    -1, -1, 0,  -2, 0,  1,  0,  0,  4,  -5, -1, -5, 2,  -1, 0,  2,  -1, 1,  3,
    -3, -5, -2, -3, 4,  -2, -2, -1, -3, -4, -1, -2, -4, 1,  4,  -3, -2, -1, 3,
    -3, -2, 3,  2,  1,  -4, -3, -5, 1,
  };
  const float expected_1[] = {
    41, -26, 5, 76, 13, 83, -21, 53, -54, -14, 21, 121,
  };

  CNN_THREAD_DATA thread_data = { 1, nullptr };

  RunCNNTest(image_width, image_height, input, expected_1, &cnn_config,
             image_width, &thread_data, MSE_INT_TOL);

  cnn_config.layer_config[0].skip_width = 6;
  cnn_config.layer_config[0].skip_height = 7;

  const float expected_2[] = {
    21, -50, 41, 20, 72, 127, -21, 103, 62, -37, 83, -3,
  };
  RunCNNTest(image_width, image_height, input, expected_2, &cnn_config,
             image_width, &thread_data, MSE_INT_TOL);

  cnn_config.layer_config[0].skip_width = 3;
  cnn_config.layer_config[0].skip_height = 10;

  const float expected_3[] = {
    -26, -21, -35, 69, 49,  4,  -51, -43, -56,
    -41, 15,  -44, 40, -62, 63, 38,  27,  47,
  };
  RunCNNTest(image_width, image_height, input, expected_3, &cnn_config,
             image_width, &thread_data, MSE_INT_TOL);

  cnn_config.layer_config[0].skip_width = 10;
  cnn_config.layer_config[0].skip_height = 3;

  const float expected_4[] = {
    21, 49, 28, 87, 50, 40, 102, 81, 58, 85, 51, 66, 36, 19, -37, -45,
  };

  RunCNNTest(image_width, image_height, input, expected_4, &cnn_config,
             image_width, &thread_data, MSE_INT_TOL);
}
506
TEST_F(CNNTest, TestMaxPool) {
  const int image_width = 8;
  const int image_height = 8;
  const int stride = 3;
  const float input[] = {
    1,  -4, -4, 8,  0,  7,  -5, -2, 8,  2,  2,  8,  5,  -1, -1, 9,
    -3, 0,  -2, 0,  6,  3,  -4, 8,  7,  8,  7,  -1, 4,  -1, 0,  2,
    -5, -2, 8,  5,  5,  4,  2,  7,  4,  6,  2,  8,  8,  -4, -3, -4,
    -3, -1, 2,  3,  3,  6,  -5, 8,  9,  5,  0,  -2, -1, 6,  5,  7,
  };

  const float expected[] = {
    49, 58, 70, 68, 68, 70, 48, 57, 88,
  };

  float weights[] = {
    3, 1, 3, 4, -1, 5, -2, 1, -4,
  };

  float bias[] = {
    -3,
  };

  // One 3x3 layer with both strides set to `stride` and the maxpool field set
  // to 1. Fields are positional; see CNN_LAYER_CONFIG.
  CNN_CONFIG cnn_config = {
    1,
    0,
    0,
    0,
    0,
    { { 1, 3, 3, 1, stride, stride, 1, weights, bias, PADDING_SAME_ZERO, NONE,
        0, 0, BRANCH_NO_COPY, BRANCH_NOC, {}, {}, 0 } },
  };

  CNN_THREAD_DATA thread_data = { 1, nullptr };

  RunCNNTest(image_width, image_height, input, expected, &cnn_config,
             image_width, &thread_data, MSE_INT_TOL);
}
561
TEST_F(CNNTest, TestDeconvolveNonActivationSingleLayerSingleKernel) {
  const int image_width = 4;
  const int image_height = 7;
  const float input[] = {
    9,  6,   181, 9,  218, 30, 80,  108, 68,  216, 70, 128, 179, 228,
    33, 212, 34,  14, 48,  27, 230, 23,  202, 113, 80, 56,  122, 112,
  };

  const float expected_1_same[] = {
    15,   -30,  36,   -525, 377,  -193, 558,  531,  6,    -24,  -15,  124,
    166,  -561, -356, -754, -3,   -3,   -3,   -3,   -3,   -3,   -3,   -3,
    433,  -311, 711,  381,  247,  -317, 453,  129,  215,  -627, -409, -885,
    17,   -255, -55,  -647, -3,   -3,   -3,   -3,   -3,   -3,   -3,   -3,
    133,  -719, 633,  -225, 785,  191,  463,  79,   65,   9,    77,   -853,
    -365, -949, -15,  -667, -3,   -3,   -3,   -3,   -3,   -3,   -3,   -3,
    355,  -866, 990,  207,  747,  12,   520,  -116, 176,  -312, -133, -1370,
    -426, -802, 143,  -771, -3,   -3,   -3,   -3,   -3,   -3,   -3,   -3,
    65,   -79,  127,  -59,  135,  -90,  195,  114,  31,   -91,  -57,  -133,
    17,   -176, -72,  -276, -3,   -3,   -3,   -3,   -3,   -3,   -3,   -3,
    457,  -302, 733,  58,   470,  -475, 829,  490,  227,  -670, -440, -790,
    153,  -588, -294, -1150, -3,  -3,   -3,   -3,   -3,   -3,   -3,   -3,
    157,  -251, 349,  -185, 409,  -293, 587,  251,  77,   -187, -107, -369,
    7,    -481, -135, -827, -3,   -3,   -3,   -3,   -3,   -3,   -3,   -3,
  };
  const float expected_1_valid[] = {
    -30,  15,   -30,  36,   -525,  377,  -193, 558,  531,  24,   24,   6,
    6,    -24,  -15,  124,  166,   -561, -356, -754, -21,  -39,  -3,   -3,
    -3,   -3,   -3,   -3,   -3,    -3,   -3,   -3,   -3,   -657, 433,  -311,
    711,  381,  247,  -317, 453,   129,  321,  321,  215,  215,  -627, -409,
    -885, 17,   -255, -55,  -647,  -219, -435, -3,   -3,   -3,   -3,   -3,
    -3,   -3,   -3,   -3,   -3,    -3,   -207, 133,  -719, 633,  -225, 785,
    191,  463,  79,   381,  381,   65,   65,   9,    77,   -853, -365, -949,
    -15,  -667, -259, -515, -3,    -3,   -3,   -3,   -3,   -3,   -3,   -3,
    -3,   -3,   -3,   -540, 355,   -866, 990,  207,  747,  12,   520,  -116,
    633,  633,  176,  176,  -312,  -133, -1370, -426, -802, 143, -771, -427,
    -851, -3,   -3,   -3,   -3,    -3,   -3,   -3,   -3,   -3,   -3,   -3,
    -105, 65,   -79,  127,  -59,   135,  -90,  195,  114,  78,   78,   31,
    31,   -91,  -57,  -133, 17,    -176, -72,  -276, -57,  -111, -3,   -3,
    -3,   -3,   -3,   -3,   -3,    -3,   -3,   -3,   -3,   -693, 457,  -302,
    733,  58,   470,  -475, 829,   490,  336,  336,  227,  227,  -670, -440,
    -790, 153,  -588, -294, -1150, -229, -455, -3,   -3,   -3,   -3,   -3,
    -3,   -3,   -3,   -3,   -3,    -3,   -243, 157,  -251, 349,  -185, 409,
    -293, 587,  251,  333,  333,   77,   77,   -187, -107, -369, 7,    -481,
    -135, -827, -227, -451,
  };
  float weights_1[] = { -3, 2, -1, 3, 3, 1, 1, -3, -2, -4 };
  float bias_1[] = { -3 };

  // One deconvolution layer (5x2 filter, strides 2 x 3) with no activation.
  // Fields are positional; see CNN_LAYER_CONFIG.
  CNN_CONFIG cnn_config = {
    1,
    0,
    0,
    0,
    0,
    { { 1, 5, 2, 1, 2, 3, 0, weights_1, bias_1, PADDING_SAME_ZERO, NONE, 1, 0,
        BRANCH_NO_COPY, BRANCH_NOC, {}, {}, 0 } },
  };

  CNN_THREAD_DATA thread_data = { 1, nullptr };

  RunCNNTest(image_width, image_height, input, expected_1_same, &cnn_config,
             image_width, &thread_data, MSE_INT_TOL);

  // Change padding to valid
  cnn_config.layer_config[0].pad = PADDING_VALID;

  RunCNNTest(image_width, image_height, input, expected_1_valid, &cnn_config,
             image_width, &thread_data, MSE_INT_TOL);

  const float expected_12_same[] = {
    15,  -12,  6,    36,   -9,   -528, 377,  -184, 513,  558,  -12,  24,
    6,   -30,  -15,  -33,  -21,  166,  154,  -546, -356, -718, -30,  -21,
    433, -221, 561,  711,  -33,  -153, 247,  -83,  -87,  453,  -111, 321,
    215, -657, -409, -845, -93,  17,   -43,  -243, -55,  -215, -327, -219,
    133, -71,  -447, 633,  -219, 435,  785,  -73,  -177, 463,  -131, 381,
    65,  -207, 77,   -59,  -651, -365, -797, -213, -15,  -155, -387, -259,
    355, -182, -150, 990,  -231, 582,  747,  -36,  -540, 520,  -215, 633,
    176, -540, -133, -491, -687, -426, -882, -102, 143,  77,   -639, -427,
    65,  -37,  57,   127,  -17,  -105, 135,  -51,  60,   195,  -30,  78,
    31,  -105, -57,  -125, -45,  17,   -11,  -147, -72,  -168, -84,  -57,
    457, -233, 618,  733,  -26,  -540, 470,  -205, 264,  829,  -116, 336,
    227, -693, -440, -900, -72,  153,  107,  -609, -294, -698, -342, -229,
    157, -83,  69,   349,  -59,  -201, 409,  -125, 27,   587,  -115, 333,
    77,  -243, -107, -267, -171, 7,    -105, -369, -135, -379, -339, -227,
  };
  const float expected_12_valid[] = {
    -30,  15,   -12,  6,    36,   -9,   -528, 377,  -184, 513,  558,  -12,
    24,   24,   6,    6,    -30,  -15,  -33,  -21,  166,  154,  -546, -356,
    -718, -30,  -21,  -39,  -657, 433,  -221, 561,  711,  -33,  -153, 247,
    -83,  -87,  453,  -111, 321,  321,  215,  215,  -657, -409, -845, -93,
    17,   -43,  -243, -55,  -215, -327, -219, -435, -207, 133,  -71,  -447,
    633,  -219, 435,  785,  -73,  -177, 463,  -131, 381,  381,  65,   65,
    -207, 77,   -59,  -651, -365, -797, -213, -15,  -155, -387, -259, -515,
    -540, 355,  -182, -150, 990,  -231, 582,  747,  -36,  -540, 520,  -215,
    633,  633,  176,  176,  -540, -133, -491, -687, -426, -882, -102, 143,
    77,   -639, -427, -851, -105, 65,   -37,  57,   127,  -17,  -105, 135,
    -51,  60,   195,  -30,  78,   78,   31,   31,   -105, -57,  -125, -45,
    17,   -11,  -147, -72,  -168, -84,  -57,  -111, -693, 457,  -233, 618,
    733,  -26,  -540, 470,  -205, 264,  829,  -116, 336,  336,  227,  227,
    -693, -440, -900, -72,  153,  107,  -609, -294, -698, -342, -229, -455,
    -243, 157,  -83,  69,   349,  -59,  -201, 409,  -125, 27,   587,  -115,
    333,  333,  77,   77,   -243, -107, -267, -171, 7,    -105, -369, -135,
    -379, -339, -227, -451,
  };

  // Change skip_width, skip_height to {3, 2}
  cnn_config.layer_config[0].skip_width = 3;
  cnn_config.layer_config[0].skip_height = 2;
  // Set padding to same
  cnn_config.layer_config[0].pad = PADDING_SAME_ZERO;

  RunCNNTest(image_width, image_height, input, expected_12_same, &cnn_config,
             image_width, &thread_data, MSE_INT_TOL);

  // Change padding to valid
  cnn_config.layer_config[0].pad = PADDING_VALID;
  RunCNNTest(image_width, image_height, input, expected_12_valid, &cnn_config,
             image_width, &thread_data, MSE_INT_TOL);

  // Switch to a 4x3 filter with its own weights and bias.
  cnn_config.layer_config[0].filter_width = 4;
  cnn_config.layer_config[0].filter_height = 3;
  float weights_2[] = { -1, -3, -1, -3, 0, 2, -2, 4, 3, 0, 1, 4 };
  float bias_2[] = { -4 };
  cnn_config.layer_config[0].weights = weights_2;
  cnn_config.layer_config[0].bias = bias_2;

  cnn_config.layer_config[0].skip_width = 5;
  cnn_config.layer_config[0].skip_height = 2;
  const float expected_2_same[] = {
    -13,  -31,  -13,  -31,  -4,   -10,  -22,  -10,  -22,  -4,   -185, -547,
    -185, -547, -4,   -13,  -31,  -13,  -31,  -4,   -4,   14,   -22,  32,
    -4,   -4,   8,    -16,  20,   -4,   -4,   358,  -366, 720,  -4,   -4,
    14,   -22,  32,   -4,   -195, -658, -213, -622, -4,   -16,  -94,  -28,
    -70,  -4,   459,  -244, 97,   480,  -4,   -85,  -328, -103, -292, -4,
    -4,   432,  -440, 868,  -4,   -4,   56,   -64,  116,  -4,   -4,   156,
    -164, 316,  -4,   -4,   212,  -220, 428,  -4,   582,  -208, 146,  664,
    -4,   -130, -652, -190, -532, -4,   166,  -214, 6,    106,  -4,   192,
    -388, -24,  44,   -4,   -4,   132,  -140, 268,  -4,   -4,   428,  -436,
    860,  -4,   -4,   136,  -144, 276,  -4,   -4,   252,  -260, 508,  -4,
    21,   -541, -115, -269, -4,   416,  -688, -16,  176,  -4,   173,  -103,
    33,   177,  -4,   168,  -640, -88,  -128, -4,   -4,   354,  -362, 712,
    -4,   -4,   452,  -460, 908,  -4,   -4,   62,   -70,  128,  -4,   -4,
    420,  -428, 844,  -4,   499,  -106, 141,  610,  -4,   666,  -46,  210,
    866,  -4,   47,   -148, -19,  -16,  -4,   605,  -85,  181,  763,  -4,
    -4,   64,   -72,  132,  -4,   -4,   24,   -32,  52,   -4,   -4,   92,
    -100, 188,  -4,   -4,   50,   -58,  104,  -4,   -132, -694, -200, -558,
    -4,   15,   -73,  -13,  -17,  -4,   -62,  -610, -158, -418, -4,   -36,
    -343, -90,  -235, -4,   -4,   456,  -464, 916,  -4,   -4,   42,   -50,
    88,   -4,   -4,   400,  -408, 804,  -4,   -4,   222,  -230, 448,  -4,
    606,  -244, 146,  676,  -4,   9,    -172, -37,  -80,  -4,   480,  -370,
    76,   438,  -4,   223,  -340, -3,   112,  -4,   -4,   156,  -164, 316,
    -4,   -4,   108,  -116, 220,  -4,   -4,   240,  -248, 484,  -4,   -4,
    220,  -228, 444,  -4,
  };
  const float expected_2_valid[] = {
    -13,  -31,  -13,  -31,  -4,   -10,  -22,  -10,  -22,  -4,   -185, -547,
    -185, -547, -4,   -13,  -31,  -13,  -31,  -4,   14,   -22,  32,   -4,
    -4,   8,    -16,  20,   -4,   -4,   358,  -366, 720,  -4,   -4,   14,
    -22,  32,   -195, -658, -213, -622, -4,   -16,  -94,  -28,  -70,  -4,
    459,  -244, 97,   480,  -4,   -85,  -328, -103, -292, -4,   432,  -440,
    868,  -4,   -4,   56,   -64,  116,  -4,   -4,   156,  -164, 316,  -4,
    -4,   212,  -220, 428,  582,  -208, 146,  664,  -4,   -130, -652, -190,
    -532, -4,   166,  -214, 6,    106,  -4,   192,  -388, -24,  44,   -4,
    132,  -140, 268,  -4,   -4,   428,  -436, 860,  -4,   -4,   136,  -144,
    276,  -4,   -4,   252,  -260, 508,  21,   -541, -115, -269, -4,   416,
    -688, -16,  176,  -4,   173,  -103, 33,   177,  -4,   168,  -640, -88,
    -128, -4,   354,  -362, 712,  -4,   -4,   452,  -460, 908,  -4,   -4,
    62,   -70,  128,  -4,   -4,   420,  -428, 844,  499,  -106, 141,  610,
    -4,   666,  -46,  210,  866,  -4,   47,   -148, -19,  -16,  -4,   605,
    -85,  181,  763,  -4,   64,   -72,  132,  -4,   -4,   24,   -32,  52,
    -4,   -4,   92,   -100, 188,  -4,   -4,   50,   -58,  104,  -132, -694,
    -200, -558, -4,   15,   -73,  -13,  -17,  -4,   -62,  -610, -158, -418,
    -4,   -36,  -343, -90,  -235, -4,   456,  -464, 916,  -4,   -4,   42,
    -50,  88,   -4,   -4,   400,  -408, 804,  -4,   -4,   222,  -230, 448,
    606,  -244, 146,  676,  -4,   9,    -172, -37,  -80,  -4,   480,  -370,
    76,   438,  -4,   223,  -340, -3,   112,  -4,   156,  -164, 316,  -4,
    -4,   108,  -116, 220,  -4,   -4,   240,  -248, 484,  -4,   -4,   220,
    -228, 444,  236,  -4,   76,   316,  -4,   164,  -4,   52,   220,  -4,
    362,  -4,   118,  484,  -4,   332,  -4,   108,  444,
  };
  // Set padding to same
  cnn_config.layer_config[0].pad = PADDING_SAME_ZERO;

  RunCNNTest(image_width, image_height, input, expected_2_same, &cnn_config,
             image_width, &thread_data, MSE_INT_TOL);

  cnn_config.layer_config[0].pad = PADDING_VALID;

  RunCNNTest(image_width, image_height, input, expected_2_valid, &cnn_config,
             image_width, &thread_data, MSE_INT_TOL);

  cnn_config.layer_config[0].skip_width = 2;
  cnn_config.layer_config[0].skip_height = 5;
  const float expected_21_same[] = {
    -31,  -19,  -49,   -191, -565, -194, -574, -13,  14,   -22,  44,   -16,
    382,  -366, 738,   -22,  -4,   23,   32,   545,  20,   204,  720,  5,
    -4,   -4,   -4,    -4,   -4,   -4,   -4,   -4,   -4,   -4,   -4,   -4,
    -4,   -4,   -4,    -4,   -658, -252, -748, -114, -334, -192, -568, -112,
    432,  -440, 928,   -64,  276,  -164, 532,  -220, -4,   304,  868,  266,
    116,  400,  316,   104,  -4,   -4,   -4,   -4,   -4,   -4,   -4,   -4,
    -4,   -4,   -4,    -4,   -4,   -4,   -4,   -4,   -208, -288, -856, -290,
    -862, -202, -598,  -132, 132,  -140, 700,  -436, 1000, -144, 532,  -260,
    -4,   712,  268,   422,  860,  450,  276,  124,  -4,   -4,   -4,   -4,
    -4,   -4,   -4,    -4,   -4,   -4,   -4,   -4,   -4,   -4,   -4,   -4,
    -541, -411, -1225, -265, -787, -249, -739, -216, 354,  -362, 1168, -460,
    974,  -70,  552,   -428, -4,   859,  712,  323,  908,  665,  128,  208,
    -4,   -4,   -4,    -4,   -4,   -4,   -4,   -4,   -4,   -4,   -4,   -4,
    -4,   -4,   -4,    -4,   -106, -52,  -148, -66,  -190, -79,  -229, -31,
    64,   -72,  160,   -32,  148,  -100, 242,  -58,  -4,   72,   132,  154,
    52,   125,  188,   23,   -4,   -4,   -4,   -4,   -4,   -4,   -4,   -4,
    -4,   -4,   -4,    -4,   -4,   -4,   -4,   -4,   -694, -257, -763, -229,
    -679, -319, -949,  -117, 456,  -464, 962,  -50,  492,  -408, 1030, -230,
    -4,   295,  916,   625,  88,   537,  804,  109,  -4,   -4,   -4,   -4,
    -4,   -4,   -4,    -4,   -4,   -4,   -4,   -4,   -4,   -4,   -4,   -4,
    -244, -140, -412,  -182, -538, -238, -706, -116, 156,  -164, 428,  -116,
    464,  -248, 708,   -228, -4,   244,  316,  418,  220,  454,  484,  108,
    -4,   -4,   -4,    -4,   -4,   -4,   -4,   -4,   -4,   -4,   -4,   -4,
    -4,   -4,   -4,    -4,
  };
  const float expected_21_valid[] = {
    -13,  -31,  -19,  -49,  -191, -565, -194, -574, -13,  -31,   -4,   14,
    -22,  44,   -16,  382,  -366, 738,  -22,  32,   23,   -4,    23,   32,
    545,  20,   204,  720,  5,    32,   -4,   -4,   -4,   -4,    -4,   -4,
    -4,   -4,   -4,   -4,   -4,   -4,   -4,   -4,   -4,   -4,    -4,   -4,
    -4,   -4,   -222, -658, -252, -748, -114, -334, -192, -568,  -112, -328,
    -4,   432,  -440, 928,  -64,  276,  -164, 532,  -220, 428,   650,  -4,
    304,  868,  266,  116,  400,  316,  104,  428,  -4,   -4,    -4,   -4,
    -4,   -4,   -4,   -4,   -4,   -4,   -4,   -4,   -4,   -4,    -4,   -4,
    -4,   -4,   -4,   -4,   -72,  -208, -288, -856, -290, -862,  -202, -598,
    -132, -388, -4,   132,  -140, 700,  -436, 1000, -144, 532,   -260, 508,
    200,  -4,   712,  268,  422,  860,  450,  276,  124,  508,   -4,   -4,
    -4,   -4,   -4,   -4,   -4,   -4,   -4,   -4,   -4,   -4,    -4,   -4,
    -4,   -4,   -4,   -4,   -4,   -4,   -183, -541, -411, -1225, -265, -787,
    -249, -739, -216, -640, -4,   354,  -362, 1168, -460, 974,   -70,  552,
    -428, 844,  533,  -4,   859,  712,  323,  908,  665,  128,   208,  844,
    -4,   -4,   -4,   -4,   -4,   -4,   -4,   -4,   -4,   -4,    -4,   -4,
    -4,   -4,   -4,   -4,   -4,   -4,   -4,   -4,   -38,  -106,  -52,  -148,
    -66,  -190, -79,  -229, -31,  -85,  -4,   64,   -72,  160,   -32,  148,
    -100, 242,  -58,  104,  98,   -4,   72,   132,  154,  52,    125,  188,
    23,   104,  -4,   -4,   -4,   -4,   -4,   -4,   -4,   -4,    -4,   -4,
    -4,   -4,   -4,   -4,   -4,   -4,   -4,   -4,   -4,   -4,    -234, -694,
    -257, -763, -229, -679, -319, -949, -117, -343, -4,   456,   -464, 962,
    -50,  492,  -408, 1030, -230, 448,  686,  -4,   295,  916,   625,  88,
    537,  804,  109,  448,  -4,   -4,   -4,   -4,   -4,   -4,    -4,   -4,
    -4,   -4,   -4,   -4,   -4,   -4,   -4,   -4,   -4,   -4,    -4,   -4,
    -84,  -244, -140, -412, -182, -538, -238, -706, -116, -340,  -4,   156,
    -164, 428,  -116, 464,  -248, 708,  -228, 444,  236,  -4,    244,  316,
    418,  220,  454,  484,  108,  444,
  };

  cnn_config.layer_config[0].pad = PADDING_SAME_ZERO;

  RunCNNTest(image_width, image_height, input, expected_21_same, &cnn_config,
             image_width, &thread_data, MSE_INT_TOL);

  cnn_config.layer_config[0].pad = PADDING_VALID;

  RunCNNTest(image_width, image_height, input, expected_21_valid, &cnn_config,
             image_width, &thread_data, MSE_INT_TOL);
}
838
TEST_F(CNNTest, TestLargeKernelsAndStrides) {
  CNN_THREAD_DATA thread_data = { 1, nullptr };

  // Case 1: a 10 (high) x 11 (wide) single-channel image convolved with a
  // 23 (wide) x 20 (high) kernel using skips of 15 (horizontal) and
  // 20 (vertical); the expected result is a single value.
  int image_height = 10;
  int image_width = 11;

  const float input_10x11[] = {
    4, 4, 2, 4, 2, -5, -2, 3, -1, 0, 0, 1, 2, 0, -5, -2, -5, 1, -3,
    -1, 4, -3, 2, -2, 1, 0, 1, -3, -3, -4, -2, -2, 1, -4, -1, 4, 1, -4,
    -4, -4, 3, 2, -5, 3, -5, 1, 2, -4, 1, -1, 3, 4, -2, 3, -3, 3, 0,
    2, -4, -5, -5, -2, -1, -2, 1, 1, 1, -2, 4, -5, 4, -1, -1, 2, 3, -4,
    2, 2, 3, 0, 0, 1, 0, 3, 2, 3, 1, -2, 3, -4, 3, 2, 4, -2, 0,
    4, -4, 1, -3, -3, -3, -5, 1, -3, -5, 0, 4, -1, -3, 2,
  };

  float weights_10x11[] = {
    -3, 4, -4, -3, -5, 1, -2, 3, 1, -4, -4, 0, -1, 0, 3, 1, -3, -2, 0,
    -1, 1, 3, -4, -4, -3, -3, -2, 4, 3, -5, 4, 2, -3, 4, -2, -1, 2, -1,
    -5, 0, -3, 0, 3, -5, -5, 3, -4, -1, -5, 3, 4, 0, 4, -5, 2, -1, 2,
    -1, -1, -1, -5, 0, -4, 3, -1, 1, 1, -1, 3, 2, -5, -4, 0, -4, 4, -5,
    -3, 4, -5, 2, -5, -4, -4, -1, 3, 3, 0, 2, -4, 1, -2, 1, 1, 0, 3,
    -2, 0, 1, 2, 4, -3, -1, -5, -5, 2, -4, 1, 1, 2, -4, -2, -2, 2, 1,
    3, 4, -5, 1, -1, -3, -3, -1, -2, -5, 1, -1, 0, 1, 4, 4, 0, 0, 4,
    -3, -1, -5, -3, 0, 1, 1, 1, -5, 3, 4, 3, -5, 3, -2, -2, 0, -4, 0,
    0, -2, 1, -4, -1, 0, -5, -2, -2, -5, -3, -3, 1, 1, -3, 2, 4, 2, 4,
    -4, -3, 3, 1, 1, 3, -4, 4, -2, -3, -3, -3, -3, -4, -2, 3, -5, 2, 4,
    -1, -4, -4, 4, -2, -1, 3, -3, -4, -4, -2, 4, 1, 0, 2, -1, 4, -3, 1,
    4, -3, 4, 4, 0, -4, 3, -2, -3, 2, 3, -1, -3, 2, 1, 4, -2, -3, 1,
    4, -2, 2, -2, -5, -2, 1, 4, -1, -4, 4, -5, 2, -5, -4, -1, -2, 3, 1,
    2, 1, -5, 1, -5, -4, -1, -2, 2, -2, -4, -3, -2, -2, 4, -1, 2, 2, -4,
    2, -2, 4, -4, -2, -2, 1, -1, 1, 1, 1, -4, -5, -2, 3, -4, -1, 3, -2,
    3, 2, -5, -4, 0, 3, -2, -4, -5, 3, -2, -4, 2, -2, 1, -4, 0, 2, -5,
    1, -4, -1, -1, 4, -5, -4, 0, -5, -4, -3, -5, -4, 0, 2, 0, -4, 2, -2,
    1, 1, -3, 2, 0, -4, 0, -4, 1, 0, -5, -1, -1, -1, -5, 4, 2, 2, -4,
    3, -2, -2, 2, -3, -2, -1, 2, -4, -5, 2, -2, -4, -5, -5, -1, 2, -1, 0,
    -5, -2, -2, -5, 0, 1, -1, -5, 0, 3, 2, 3, 0, -3, -2, 0, -5, -1, -2,
    2, -4, -1, 2, 2, -5, 2, -4, 0, 3, -3, 1, 0, 0, 1, -5, -3, 1, -1,
    0, -4, -3, 2, -4, -4, 4, -1, 0, 1, 2, -4, -5, 4, -2, 1, -4, -4, -3,
    -1, -1, 1, -1, -4, -1, -4, -3, 2, -1, -2, -4, 1, 1, 0, -2, 0, -4, 3,
    -3, 0, -4, -1, -4, 2, -1, -2, -5, -1, -2, -3, 3, -1, 0, -3, 0, 1, -5,
    1, -5, 0, 1,
  };

  float bias_10x11[] = { 3 };

  const float expected_10x11[] = {
    118,
  };

  // Layer fields are positional: input channels, filter width, filter height,
  // output channels, skip width, skip height, 0, weights, bias, padding,
  // activation, 0, branch index, branch copy type, branch combine type,
  // branch config, {}, output index.
  CNN_CONFIG cnn_config = {
    1, 0, 0, 0, 0,
    { {
        1, 23, 20, 1, 15, 20, 0,
        weights_10x11, bias_10x11, PADDING_SAME_ZERO, NONE,
        0, 0, BRANCH_NO_COPY, BRANCH_NOC,
        {}, {}, 0,
    } },
  };

  RunCNNTest(image_width, image_height, input_10x11, expected_10x11,
             &cnn_config, image_width, &thread_data, MSE_INT_TOL);

  // Case 2: an 11 (high) x 10 (wide) image with a 20 (wide) x 23 (high)
  // kernel and unit skips, so one value is expected per input pixel.
  image_height = 11;
  image_width = 10;

  const float input_11x10[] = {
    -2, -2, 3, -5, -1, -3, 1, 3, 2, 1, 1, -5, 4, 1, 3, -5, 3, -3, -5,
    0, -1, -3, -3, 1, 1, -5, -1, -5, -5, -3, 0, 1, -3, -1, -3, -3, 0, 3,
    4, -4, -1, 3, -3, -1, -3, 1, -3, -2, -1, -4, -3, 2, -4, 1, -4, -1, -3,
    -5, -1, 2, 3, 0, 2, 2, -5, 4, 1, 2, -1, -4, 4, -4, -4, 0, -1, 1,
    -1, 1, -3, -3, -2, 1, 2, 4, 4, 4, -3, -3, 0, 1, 0, 1, 4, 1, 3,
    4, -3, -2, -4, 4, 2, 0, 3, 4, -1, 2, -2, 1, -3, -2,
  };

  float weights_11x10[] = {
    4, -1, 1, -1, 2, 4, 3, 3, -4, 3, -5, 1, -1, -1, -2, -2, 0, 2, -3,
    -2, 3, -5, -1, 0, -1, -2, -2, -1, 2, 4, 3, 1, 0, 0, -3, 3, -4, -1,
    -5, 4, -2, -2, 1, 2, -1, -3, 1, 2, -5, 1, -3, 3, 3, 0, -4, -4, -5,
    -3, -4, -4, 4, -2, 4, 4, -2, 2, -5, -1, -2, -5, -1, 4, -3, 3, -2, 0,
    -4, -3, 0, -1, -2, 4, 2, 0, -2, -5, -4, 1, 4, -4, -2, 2, -2, 1, 1,
    -4, 1, -4, -4, -2, 4, 2, -1, -5, -5, 1, -3, -3, 3, -3, -5, -3, 4, -1,
    -1, -3, 0, -4, 3, -1, 0, -2, 0, -5, -2, -5, 2, 0, -5, 2, 3, -2, 2,
    4, -1, 1, -3, 2, 3, 2, 0, -5, -4, -5, 2, 1, 1, -1, -2, 3, 4, 2,
    -2, 4, -2, 3, 1, -4, -3, -1, 4, 4, -3, -5, -2, 2, 0, 3, -2, 3, -1,
    -4, 0, -2, 0, 3, 4, -2, -3, -2, 0, 3, 4, 2, -4, 0, 1, 2, 2, -1,
    -1, 4, 1, 4, -2, -1, -1, -5, 1, -3, 3, 3, -1, -4, 3, -5, 0, 0, -1,
    -4, -1, -2, 4, -2, 3, 3, -3, 1, -1, 2, -1, 4, 4, -2, -2, 4, -2, 0,
    3, -3, -5, -1, -2, 4, -4, 2, -4, 0, -2, 3, -3, 2, 2, -2, -5, -1, 4,
    3, -2, -1, 3, 3, -1, 3, 0, -3, 0, 4, 2, 0, -1, 4, 1, 1, 2, 1,
    3, 1, 1, 1, -3, -5, -4, 4, -4, 2, 0, 0, -4, 1, 4, -5, 4, 4, 0,
    1, 0, -2, -4, -4, -3, 0, 1, -5, 4, 0, -3, -2, -4, 2, 4, 1, -5, 1,
    -4, 1, 0, -3, -3, 0, 2, -5, 4, 3, -2, -5, 3, 1, -1, 0, 3, -2, -2,
    3, -2, -5, 4, 1, -2, 2, -1, 0, 4, 0, -5, 3, -2, 1, 2, 1, -5, -3,
    -2, -5, 4, -4, 0, 3, 2, -1, -4, -1, 2, 1, -2, 3, -1, -4, 2, 0, -3,
    1, -1, 2, -5, -4, -1, -5, 1, 4, 3, 4, 2, -3, 1, -5, -1, 3, 0, -1,
    -4, 3, 4, -5, 4, 4, -3, 2, -3, -1, -3, -5, -3, 2, -3, -2, 1, 1, 0,
    -5, 3, 2, 1, -5, 1, 1, 1, 3, 4, -4, -1, -2, 0, -5, -3, -5, -2, -4,
    3, 3, 3, 4, 0, -4, -1, -5, 0, -3, 1, 4, 4, -4, 4, -5, -5, -1, -2,
    -5, 3, -4, 4, 3, 0, -3, 2, -2, 0, 0, 4, 4, 0, -2, 1, -1, -3, 2,
    -1, 1, -3, -5,
  };

  float bias_11x10[] = {
    -5,
  };

  const float expected_11x10[] = {
    36, -84, 95, 45, 18, 46, 77, -54, -99, -149, 66, 49, 161, 11,
    39, 61, -66, 61, 4, -3, 34, -44, -23, 31, 64, 29, 47, 72,
    -27, -27, 121, -3, 100, 1, 30, -78, -12, -89, -59, 8, -16, 112,
    91, -102, -26, -4, 30, 54, 4, -84, -24, -58, 27, -53, -33, 5,
    53, -26, 63, 50, -103, -130, -23, 6, -104, -207, 73, 23, 77, 132,
    38, 32, -130, -44, -60, 7, 27, 176, 45, -32, -2, 99, -97, 63,
    69, 126, 47, 63, 136, -57, 5, 16, -40, -157, 8, 38, -44, -10,
    91, 7, 122, 140, 30, -105, 4, -1, 113, 64, 180, 141,
  };

  // Reuse the single layer from case 1; only its parameters, kernel shape
  // and skips change (padding stays PADDING_SAME_ZERO).
  auto &layer = cnn_config.layer_config[0];
  layer.weights = weights_11x10;
  layer.bias = bias_11x10;
  layer.filter_width = 20;
  layer.filter_height = 23;
  layer.skip_width = 1;
  layer.skip_height = 1;

  RunCNNTest(image_width, image_height, input_11x10, expected_11x10,
             &cnn_config, image_width, &thread_data, MSE_INT_TOL);
}
981
TEST_F(CNNTest, TestSoftsignSingleLayer) {
  // One convolution layer (4 wide x 5 high, unit skips) with a SOFTSIGN
  // activation, evaluated on an 8x8 image under three padding modes.
  const int image_width = 8;
  const int image_height = 8;
  const int filter_width = 4;
  const int filter_height = 5;

  const float input[] = {
    -0.5220f, 0.8410f, -0.8990f, -0.0090f, 0.6710f, -0.9470f, -0.8240f,
    -0.0870f, 0.5380f, 0.4750f, 0.570f, -0.3760f, -0.6960f, -0.5940f,
    -0.3830f, 0.080f, -0.0980f, -0.4940f, -0.4030f, 0.9460f, -0.6020f,
    0.4220f, 0.6190f, 0.6640f, -0.9210f, -0.1470f, -0.2480f, -0.1120f,
    -0.580f, -0.0650f, 0.3330f, 0.9860f, -0.7430f, 0.7610f, 0.4840f,
    0.1030f, 0.9570f, 0.6120f, -0.5240f, -0.1220f, -0.5850f, -0.270f,
    0.7840f, -0.9790f, 0.7290f, -0.30f, -0.6460f, 0.0780f, 0.4750f,
    -0.0510f, 0.4550f, 0.3850f, -0.7230f, 0.4460f, -0.6260f, -0.810f,
    0.8720f, -0.2120f, -0.580f, -0.9510f, -0.8430f, -0.1340f, -0.0850f,
    0.9190f,
  };

  // PADDING_SAME_ZERO: 64 values (8x8 output).
  const float expected_same[] = {
    0.430f, 0.660f, 0.5510f, -0.610f, 0.450f, -0.1610f, 0.0520f, 0.3240f,
    0.6820f, 0.3820f, 0.6360f, 0.7480f, 0.3080f, 0.090f, 0.3910f, 0.1730f,
    0.340f, 0.6660f, -0.4990f, 0.4280f, 0.1540f, 0.120f, 0.4670f, 0.6150f,
    -0.3880f, 0.7590f, 0.4190f, 0.7350f, 0.5310f, -0.5160f, -0.1760f, 0.6790f,
    -0.6780f, 0.5470f, 0.5750f, -0.6420f, 0.7210f, -0.4620f, 0.5430f, 0.770f,
    -0.1990f, 0.3950f, 0.7860f, -0.4380f, 0.7540f, 0.2640f, -0.6430f, 0.4510f,
    -0.1260f, 0.1590f, -0.2110f, -0.0560f, 0.6570f, 0.680f, 0.5870f, 0.4720f,
    0.4040f, 0.3630f, 0.670f, 0.2360f, 0.410f, 0.6980f, -0.5350f, 0.3940f,
  };

  // PADDING_SAME_REPLICATE: 64 values (8x8 output).
  const float expected_replicate[] = {
    0.540f, 0.7230f, -0.3530f, -0.2130f, 0.7440f, -0.4470f, -0.6260f,
    -0.2050f, 0.7230f, 0.4630f, 0.5920f, 0.7440f, 0.6080f, 0.3130f,
    -0.5670f, -0.4720f, 0.5480f, 0.6660f, -0.4990f, 0.4280f, 0.1540f,
    0.120f, 0.3390f, 0.6090f, 0.4160f, 0.7590f, 0.4190f, 0.7350f,
    0.5310f, -0.5160f, -0.490f, 0.4450f, -0.610f, 0.5470f, 0.5750f,
    -0.6420f, 0.7210f, -0.4620f, 0.3150f, 0.7370f, -0.5820f, 0.3950f,
    0.7860f, -0.4380f, 0.7540f, 0.2640f, -0.7430f, -0.5340f, -0.6270f,
    0.4430f, 0.4730f, 0.4570f, 0.7450f, 0.630f, 0.2620f, 0.3140f,
    -0.1840f, 0.1810f, 0.7210f, 0.2760f, 0.6430f, 0.6720f, -0.4390f,
    0.2040f,
  };

  // PADDING_VALID: 20 values (5 wide x 4 high output).
  const float expected_valid[] = {
    0.6660f, -0.4990f, 0.4280f, 0.1540f, 0.120f, 0.7590f, 0.4190f,
    0.7350f, 0.5310f, -0.5160f, 0.5470f, 0.5750f, -0.6420f, 0.7210f,
    -0.4620f, 0.3950f, 0.7860f, -0.4380f, 0.7540f, 0.2640f,
  };

  float weights[] = {
    0.6210f, 0.3710f, -0.2770f, -0.7230f, -0.2450f, 0.6770f, 0.3080f,
    -0.9880f, -0.080f, 0.7190f, -0.6760f, -0.0170f, -0.8970f, 0.8260f,
    0.7390f, -0.4550f, -0.4260f, -0.6330f, 0.0880f, -0.9390f,
  };

  float bias[] = {
    0.750f,
  };

  // Layer fields are positional: input channels, filter width, filter height,
  // output channels, skip width, skip height, 0, weights, bias, padding,
  // activation, 0, branch index, branch copy type, branch combine type,
  // branch config, {}, output index.
  CNN_CONFIG cnn_config = {
    1, 0, 0, 0, 0,
    { {
        1, filter_width, filter_height, 1, 1, 1, 0,
        weights, bias, PADDING_SAME_ZERO, SOFTSIGN,
        0, 0, BRANCH_NO_COPY, BRANCH_NOC,
        {}, {}, 0,
    } },
  };

  CNN_THREAD_DATA thread_data = { 1, nullptr };

  // Every case shares the input, network and tolerance; only the padding
  // mode (set on the config before each call) and the expectation change.
  auto run_case = [&](const float *expected) {
    RunCNNTest(image_width, image_height, input, expected, &cnn_config,
               image_width, &thread_data, MSE_FLOAT_TOL);
  };

  run_case(expected_same);

  cnn_config.layer_config[0].pad = PADDING_SAME_REPLICATE;
  run_case(expected_replicate);

  cnn_config.layer_config[0].pad = PADDING_VALID;
  run_case(expected_valid);
}
1076
TEST_F(CNNTest, TestBranchTensorAdd) {
  // Six layers on a 4x4 image: a side branch (branch 1) is forked off the
  // main branch, runs two layers, and is merged back with BRANCH_ADD.
  const int image_width = 4;
  const int image_height = 4;
  const int filter_width = 2;
  const int filter_height = 3;
  const int channels = 2;

  const float input[] = {
    -3, -2, -2, 0, -1, 3, 2, -2, 1, 3, 4, 0, 2, -5, -4, 0,
  };

  float weights[] = {
    -3, -1, 4, -1, -3, 3, 3, 0, 2, 0, 3, 2, 4, 4, 4, -5, 1, -4,
    2, -4, 1, -3, 0, 4, -5, 4, 0, -4, -3, -1, 0, 0, -2, 0, 0, 2,
    -5, -1, 1, -3, 3, 4, 3, 0, 1, -1, 1, 1, 2, 4, -2, -5, 2, -2,
    3, -2, 4, -1, 0, 2, 3, 2, -2, -1, -3, 1, 3, 4, -1, -3, 0, -4,
    4, 2, -3, -3, -1, 0, 1, 0, 3, 3, -3, 0, 3, 2, -5, -3, 4, -5,
    3, -1, -1, -3, 0, 1, -1, -4, 2, 4, -1, 4, -1, 1, 3, 4, 4, 4,
    0, -1, -3, -3, -3, -3, 2, -3, -2, 2, 3, -3,
  };

  float bias[] = {
    3, 4, -1, -1, 2, 1, -2, 1, 4, 1, 3,
  };

  const float expected[] = {
    -11502, -4101, -3424, 668, -17950, -5470, -5504, 626,
    4835, 446, 1779, -3483, 3679, -4214, 4578, -105,
  };

  // Layer fields are positional: input channels, filter width, filter height,
  // output channels, skip width, skip height, 0, weights, bias, padding,
  // activation, 0, branch index, branch copy type, branch combine type,
  // branch config, {}, output index (-1 on intermediate layers).
  CNN_CONFIG cnn_config = {
    6, 0, 0, 0, 0,
    {
        // Layer 0 (branch 0): 1 -> 2 channels.
        { 1, filter_width, filter_height, channels, 1, 1, 0,
          weights, bias, PADDING_SAME_ZERO, NONE,
          0, 0, BRANCH_NO_COPY, BRANCH_NOC,
          {}, {}, -1 },
        // Layer 1 (branch 0): BRANCH_INPUT with mask 0x02 feeds branch 1.
        { channels, filter_width, filter_height, channels, 1, 1, 0,
          nullptr, nullptr, PADDING_SAME_ZERO, NONE,
          0, 0, BRANCH_INPUT, BRANCH_NOC,
          { 0x02, 0, 0x00 }, {}, -1 },
        // Layers 2 and 3 run on branch 1.
        { channels, filter_width, filter_height, channels, 1, 1, 0,
          nullptr, nullptr, PADDING_SAME_ZERO, NONE,
          0, 1, BRANCH_NO_COPY, BRANCH_NOC,
          {}, {}, -1 },
        { channels, filter_width, filter_height, channels, 1, 1, 0,
          nullptr, nullptr, PADDING_SAME_ZERO, NONE,
          0, 1, BRANCH_NO_COPY, BRANCH_NOC,
          {}, {}, -1 },
        // Layer 4 (branch 0): BRANCH_ADD merges branch 1 (mask 0x02).
        { channels, filter_width, filter_height, channels, 1, 1, 0,
          nullptr, nullptr, PADDING_SAME_ZERO, NONE,
          0, 0, BRANCH_NO_COPY, BRANCH_ADD,
          { 0x00, 0, 0x02 }, {}, -1 },
        // Layer 5 (branch 0): 2 -> 1 channel; the final layer.
        { channels, filter_width, filter_height, 1, 1, 1, 0,
          nullptr, nullptr, PADDING_SAME_ZERO, NONE,
          0, 0, BRANCH_NO_COPY, BRANCH_NOC,
          {}, {}, 0 },
    },
  };

  // `weights` and `bias` hold the parameters of all layers back to back, so
  // they are distributed per layer (with offsets) by the helper.
  AssignLayerWeightsBiases(&cnn_config, weights, bias);

  CNN_THREAD_DATA thread_data = { 1, nullptr };

  RunCNNTest(image_width, image_height, input, expected, &cnn_config,
             image_width, &thread_data, MSE_INT_TOL);
}
1252
TEST_F(CNNTest, TestBranchTensorConcatenation) {
  // Six layers on a 4x4 image: a side branch (branch 1) is forked off the
  // main branch, runs two layers, and is merged back with BRANCH_CAT, which
  // doubles the channel count seen by the final layer.
  const int image_width = 4;
  const int image_height = 4;
  const int filter_width = 2;
  const int filter_height = 3;
  const int channels = 2;

  const float input[] = {
    -3, -2, -2, 0, -1, 3, 2, -2, 1, 3, 4, 0, 2, -5, -4, 0,
  };

  float weights[] = {
    3, 0, 2, 0, 2, 3, 1, -3, 1, -5, -3, 0, -4, 4, 0, -5, 0, -5, -1,
    -2, -5, 0, -3, 2, -4, 2, 0, 2, -1, 0, -4, 3, 0, 0, -1, -5, 2, -1,
    4, -4, -2, -3, -3, 3, 4, -2, -1, -4, -1, 4, 4, -1, 4, 3, -4, 2, -2,
    -4, -3, -2, 3, -3, -5, -1, 3, -2, 4, 1, -4, -3, -5, -5, -3, 4, -2, -2,
    -1, -5, -5, 0, -1, -2, -3, 3, -4, -5, 2, -3, 1, 0, -5, 2, 2, -2, 0,
    2, 2, -2, 4, 2, 2, 0, 1, -5, -3, 0, 2, -2, 1, 2, -5, 2, 3, 3,
    -1, 3, 0, -3, 3, -4, -4, 3, 3, -4, -2, 2, -2, 2, -2, -1, 3, 0,
  };

  float bias[] = {
    -3, -5, 4, -4, -3, -2, 0, 3, -4, 4, -3,
  };

  const float expected[] = {
    -33533, -32087, -6741, -2124, 39979, 41453, 14034, 689,
    -22611, -42203, -14882, -239, 15781, 15963, 9524, 837,
  };

  // Layer fields are positional: input channels, filter width, filter height,
  // output channels, skip width, skip height, 0, weights, bias, padding,
  // activation, 0, branch index, branch copy type, branch combine type,
  // branch config, {}, output index (-1 on intermediate layers).
  CNN_CONFIG cnn_config = {
    6, 0, 0, 0, 0,
    {
        // Layer 0 (branch 0): 1 -> 2 channels.
        { 1, filter_width, filter_height, channels, 1, 1, 0,
          weights, bias, PADDING_SAME_ZERO, NONE,
          0, 0, BRANCH_NO_COPY, BRANCH_NOC,
          {}, {}, -1 },
        // Layer 1 (branch 0): BRANCH_INPUT with mask 0x02 feeds branch 1.
        { channels, filter_width, filter_height, channels, 1, 1, 0,
          nullptr, nullptr, PADDING_SAME_ZERO, NONE,
          0, 0, BRANCH_INPUT, BRANCH_NOC,
          { 0x02, 0, 0x00 }, {}, -1 },
        // Layers 2 and 3 run on branch 1.
        { channels, filter_width, filter_height, channels, 1, 1, 0,
          nullptr, nullptr, PADDING_SAME_ZERO, NONE,
          0, 1, BRANCH_NO_COPY, BRANCH_NOC,
          {}, {}, -1 },
        { channels, filter_width, filter_height, channels, 1, 1, 0,
          nullptr, nullptr, PADDING_SAME_ZERO, NONE,
          0, 1, BRANCH_NO_COPY, BRANCH_NOC,
          {}, {}, -1 },
        // Layer 4 (branch 0): BRANCH_CAT appends branch 1 (mask 0x02).
        { channels, filter_width, filter_height, channels, 1, 1, 0,
          nullptr, nullptr, PADDING_SAME_ZERO, NONE,
          0, 0, BRANCH_NO_COPY, BRANCH_CAT,
          { 0x00, 0, 0x02 }, {}, -1 },
        // Layer 5 (branch 0): consumes both concatenated halves -> 1 channel.
        { channels + channels, filter_width, filter_height, 1, 1, 1, 0,
          nullptr, nullptr, PADDING_SAME_ZERO, NONE,
          0, 0, BRANCH_NO_COPY, BRANCH_NOC,
          {}, {}, 0 },
    },
  };

  // `weights` and `bias` hold the parameters of all layers back to back, so
  // they are distributed per layer (with offsets) by the helper.
  AssignLayerWeightsBiases(&cnn_config, weights, bias);

  CNN_THREAD_DATA thread_data = { 1, nullptr };

  RunCNNTest(image_width, image_height, input, expected, &cnn_config,
             image_width, &thread_data, MSE_INT_TOL);
}
1428
1429 // TODO(logangw): Add test to test all combinations of branch_copy_type.
1430
TEST_F(CNNTest, TestBranchCombinations) {
  // Ten layers on a 4x4 image mixing BRANCH_INPUT / BRANCH_OUTPUT copies into
  // branches 1, 2 and 3 with several BRANCH_ADD merges back toward branch 0.
  const int image_width = 4;
  const int image_height = 4;
  const int filter_width = 2;
  const int filter_height = 3;
  const int channels = 2;

  const float input[] = {
    3, 2, -5, -4, 4, -2, -4, -3, 4, 2, -3, 2, -3, 1, -5, -1,
  };

  float weights[] = {
    2, 3, 0, 4, 4, 3, 1, 0, 1, -5, 4, -3, 3, 0, 4, -1, -1, -5,
    2, 1, -3, -5, 3, -1, -3, -2, 0, -2, 3, 0, -2, -4, -2, -2, 2, -5,
    4, -5, 0, 1, -5, -4, -3, -4, 2, -2, 1, 0, 3, -2, -4, 3, 4, -4,
    -1, -1, -3, -2, -2, -1, 2, 0, 2, -1, 2, -4, -4, -1, 2, 0, 3, -2,
    -2, 3, -3, 4, -2, 4, 3, 4, 1, 0, -2, -3, -5, 1, -3, 2, 0, -2,
    -2, -1, -1, -5, -2, -3, -1, 3, 3, 4, 4, 0, 2, 1, 3, -3, 2, -5,
    -5, 1, -5, -1, 3, 3, 2, -4, -1, 3, -4, -2, -5, -2, 1, 3, 2, 2,
    -5, -2, -3, -1, -2, -4, -1, -2, 2, 1, -4, -4, 2, 0, 2, 0, 2, -3,
    -2, -4, 4, 0, 1, -3, -5, 4, -1, 2, 3, -5, -1, 0, 4, -1, -1, 3,
    -1, -3, 3, 1, 4, 3, 4, 3, -4, -5, -1, 3, 3, -4, 3, 1, 3, -5,
    3, 4, -5, 4, 2, -1, -5, 2, 1, 0, 4, 0, -3, 2, 0, 2, -2, 1,
    -1, -2, -1, -5, 4, 3, 3, -2, 2, 4, -5, -5, -3, -2, 4, 0, -4, 1,
  };

  float bias[] = {
    -1, 4, 0, 2, 2, -2, 0, -4, -5, -1, 1, -2, 3, 0, 4, -2, 1, 0, 0,
  };

  const float expected[] = {
    149496, 15553, -24193, -20956, 134094, 86432, -68283, -6366,
    -53031, 133739, 67407, -13539, -53205, -58635, -20033, 1979,
  };

  // Layer fields are positional: input channels, filter width, filter height,
  // output channels, skip width, skip height, 0, weights, bias, padding,
  // activation, 0, branch index, branch copy type, branch combine type,
  // branch config, {}, output index (-1 on intermediate layers).
  CNN_CONFIG cnn_config = {
    10, 0, 0, 0, 0,
    {
        // Layer 0 (branch 0): 1 -> 2 channels.
        { 1, filter_width, filter_height, channels, 1, 1, 0,
          weights, bias, PADDING_SAME_ZERO, NONE,
          0, 0, BRANCH_NO_COPY, BRANCH_NOC,
          {}, {}, -1 },
        // Layer 1 (branch 0): BRANCH_INPUT with mask 0x06 (branches 1, 2).
        { channels, filter_width, filter_height, channels, 1, 1, 0,
          nullptr, nullptr, PADDING_SAME_ZERO, NONE,
          0, 0, BRANCH_INPUT, BRANCH_NOC,
          { 0x06, 0, 0x00 }, {}, -1 },
        // Layer 2 (branch 2): BRANCH_OUTPUT with mask 0x08 (branch 3).
        { channels, filter_width, filter_height, channels, 1, 1, 0,
          nullptr, nullptr, PADDING_SAME_ZERO, NONE,
          0, 2, BRANCH_OUTPUT, BRANCH_NOC,
          { 0x08, 0, 0x00 }, {}, -1 },
        // Layer 3 (branch 3).
        { channels, filter_width, filter_height, channels, 1, 1, 0,
          nullptr, nullptr, PADDING_SAME_ZERO, NONE,
          0, 3, BRANCH_NO_COPY, BRANCH_NOC,
          {}, {}, -1 },
        // Layer 4 (branch 2): BRANCH_ADD merges branch 3 (mask 0x08).
        { channels, filter_width, filter_height, channels, 1, 1, 0,
          nullptr, nullptr, PADDING_SAME_ZERO, NONE,
          0, 2, BRANCH_NO_COPY, BRANCH_ADD,
          { 0x00, 0, 0x08 }, {}, -1 },
        // Layer 5 (branch 2).
        { channels, filter_width, filter_height, channels, 1, 1, 0,
          nullptr, nullptr, PADDING_SAME_ZERO, NONE,
          0, 2, BRANCH_NO_COPY, BRANCH_NOC,
          {}, {}, -1 },
        // Layer 6 (branch 1).
        { channels, filter_width, filter_height, channels, 1, 1, 0,
          nullptr, nullptr, PADDING_SAME_ZERO, NONE,
          0, 1, BRANCH_NO_COPY, BRANCH_NOC,
          {}, {}, -1 },
        // Layer 7 (branch 1): BRANCH_ADD merges branches 2 and 3 (mask 0x0C).
        { channels, filter_width, filter_height, channels, 1, 1, 0,
          nullptr, nullptr, PADDING_SAME_ZERO, NONE,
          0, 1, BRANCH_NO_COPY, BRANCH_ADD,
          { 0x00, 0, 0x0C }, {}, -1 },
        // Layer 8 (branch 0): BRANCH_ADD merges branch 1 (mask 0x02).
        { channels, filter_width, filter_height, channels, 1, 1, 0,
          nullptr, nullptr, PADDING_SAME_ZERO, NONE,
          0, 0, BRANCH_NO_COPY, BRANCH_ADD,
          { 0x00, 0, 0x02 }, {}, -1 },
        // Layer 9 (branch 0): 2 -> 1 channel; the final layer.
        { channels, filter_width, filter_height, 1, 1, 1, 0,
          nullptr, nullptr, PADDING_SAME_ZERO, NONE,
          0, 0, BRANCH_NO_COPY, BRANCH_NOC,
          {}, {}, 0 },
    },
  };

  // `weights` and `bias` hold the parameters of all layers back to back, so
  // they are distributed per layer (with offsets) by the helper.
  AssignLayerWeightsBiases(&cnn_config, weights, bias);

  CNN_THREAD_DATA thread_data = { 1, nullptr };

  RunCNNTest(image_width, image_height, input, expected, &cnn_config,
             image_width, &thread_data, MSE_INT_TOL);
}
1705
TEST_F(CNNTest, TestSplittingTensors) {
  // Three layers on a 4x4 image. Layer 0 emits 4 channels and uses
  // BRANCH_OUTPUT with a channel count of 2, presumably copying only part of
  // its output to branch 1 (TODO confirm against cnn.h). Layer 1 maps 4 -> 2
  // channels and concatenates branch 1 back in, giving the 4 channels that
  // layer 2 consumes.
  const int image_width = 4;
  const int image_height = 4;
  const int filter_width = 2;
  const int filter_height = 3;

  const float input[] = {
    -1, -1, 2, 1, 3, 2, 4, -3, -4, -2, 2, -3, 1, -3, 4, -2,
  };

  float weights[] = {
    -4, 1, 0, 2, 3, 4, 4, -4, -5, -3, 2, 2, -4, -3, 3, 2,
    4, -4, -3, -4, -4, 1, -3, -5, -3, 4, 2, -2, 2, -1, -4, -1,
    -2, -3, 1, 1, 0, -5, -1, 3, 3, -5, -3, 0, -3, 1, -3, -1,
    1, -3, -2, -2, 4, -2, 0, 1, 2, 2, -4, 2, 4, 0, -5, -2,
    4, 4, -5, 1, 0, 2, -2, -5, -5, -3, -5, -5, 4, -3, 0, 0,
    -4, -4, 0, -5, -4, 0, 0, -3, -5, -3, -1, 2, -1, 4, -1, 2,
  };

  float bias[] = {
    -4, -2, -3, -3, 3, 1, -2,
  };

  const float expected[] = {
    530, -762, 1469, 777, 849, -771, -1698, 600,
    -658, -1821, 98, -668, -1798, 30, 887, -971,
  };

  // Layer fields are positional: input channels, filter width, filter height,
  // output channels, skip width, skip height, 0, weights, bias, padding,
  // activation, 0, branch index, branch copy type, branch combine type,
  // branch config, {}, output index (-1 on intermediate layers).
  CNN_CONFIG cnn_config = {
    3, 0, 0, 0, 0,
    {
        // Layer 0 (branch 0): 1 -> 4 channels, BRANCH_OUTPUT to branch 1.
        { 1, filter_width, filter_height, 4, 1, 1, 0,
          nullptr, nullptr, PADDING_SAME_ZERO, NONE,
          0, 0, BRANCH_OUTPUT, BRANCH_NOC,
          { 0x02, 2, 0x00 }, {}, -1 },
        // Layer 1 (branch 0): 4 -> 2 channels, BRANCH_CAT with branch 1.
        { 4, filter_width, filter_height, 2, 1, 1, 0,
          nullptr, nullptr, PADDING_SAME_ZERO, NONE,
          0, 0, BRANCH_NO_COPY, BRANCH_CAT,
          { 0x00, 0, 0x02 }, {}, -1 },
        // Layer 2 (branch 0): 4 -> 1 channel; the final layer.
        { 4, filter_width, filter_height, 1, 1, 1, 0,
          nullptr, nullptr, PADDING_SAME_ZERO, NONE,
          0, 0, BRANCH_NO_COPY, BRANCH_NOC,
          {}, {}, 0 },
    },
  };

  // `weights` and `bias` hold the parameters of all layers back to back, so
  // they are distributed per layer (with offsets) by the helper.
  AssignLayerWeightsBiases(&cnn_config, weights, bias);

  CNN_THREAD_DATA thread_data = { 1, nullptr };

  RunCNNTest(image_width, image_height, input, expected, &cnn_config,
             image_width, &thread_data, MSE_INT_TOL);
}
1820
TEST_F(CNNTest, TestOutputChannelsCount) {
  // All-zero data through three layers that use BRANCH_CAT merges. The point
  // is the size of the output: with a 2x2 image and 1x1 filters, the 28
  // expected values correspond to 7 output channels of 4 pixels each.
  const int image_width = 2;
  const int image_height = 2;
  const int filter_width = 1;
  const int filter_height = 1;

  const float input[] = { 0, 0, 0, 0 };

  float weights[] = { 0, 0, 0, 0, 0, 0, 0, 0 };

  float bias[] = { 0, 0, 0, 0, 0, 0 };

  const float expected[] = {
    0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0,
    0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0,
  };

  // Layer fields are positional: input channels, filter width, filter height,
  // output channels, skip width, skip height, 0, weights, bias, padding,
  // activation, 0, branch index, branch copy type, branch combine type,
  // branch config, {}, output index (-1 on intermediate layers).
  CNN_CONFIG cnn_config = {
    3, 0, 0, 0, 0,
    {
        // Layer 0 (branch 0): BRANCH_INPUT with mask 0x06 (branches 1, 2).
        { 1, filter_width, filter_height, 2, 1, 1, 0,
          weights, bias, PADDING_SAME_ZERO, NONE,
          0, 0, BRANCH_INPUT, BRANCH_NOC,
          { 0x06, 0, 0x00 }, {}, -1 },
        // Layer 1 (branch 2): BRANCH_CAT with branches 0 and 1 (mask 0x03).
        { 1, filter_width, filter_height, 2, 1, 1, 0,
          weights, bias, PADDING_SAME_ZERO, NONE,
          0, 2, BRANCH_NO_COPY, BRANCH_CAT,
          { 0x00, 0, 0x03 }, {}, -1 },
        // Layer 2 (branch 0): BRANCH_CAT with branch 2 (mask 0x04); final.
        { 2, filter_width, filter_height, 2, 1, 1, 0,
          weights, bias, PADDING_SAME_ZERO, NONE,
          0, 0, BRANCH_NO_COPY, BRANCH_CAT,
          { 0x00, 0, 0x04 }, {}, 0 },
    },
  };

  // `weights` and `bias` hold the parameters of all layers back to back, so
  // they are distributed per layer (with offsets) by the helper.
  AssignLayerWeightsBiases(&cnn_config, weights, bias);

  CNN_THREAD_DATA thread_data = { 1, nullptr };

  RunCNNTest(image_width, image_height, input, expected, &cnn_config,
             image_width, &thread_data, MSE_FLOAT_TOL);
}
1928
TEST_F(CNNTest,TestBatchNorm)1929 TEST_F(CNNTest, TestBatchNorm) {
1930 int image_width = 28;
1931 int image_height = 28;
1932 int filter_height = 7;
1933 int filter_width = 7;
1934 float input[] = {
1935 0.0f, 0.0f, 0.0f, 0.0f, 0.0f, 0.0f,
1936 0.0f, 0.0f, 0.0f, 0.0f, 0.0f, 0.0f,
1937 0.0f, 0.0f, 0.0f, 0.0f, 0.0f, 0.0f,
1938 0.0f, 0.0f, 0.0f, 0.0f, 0.0f, 0.0f,
1939 0.0f, 0.0f, 0.0f, 0.0f, 0.0f, 0.0f,
1940 0.0f, 0.0f, 0.0f, 0.0f, 0.0f, 0.0f,
1941 0.0f, 0.0f, 0.0f, 0.0f, 0.0f, 0.0f,
1942 0.0f, 0.0f, 0.0f, 0.0f, 0.0f, 0.0f,
1943 0.0f, 0.0f, 0.0f, 0.0f, 0.0f, 0.0f,
1944 0.0f, 0.0f, 0.0f, 0.0f, 0.0f, 0.0f,
1945 0.0f, 0.0f, 0.0f, 0.0f, 0.0f, 0.0f,
1946 0.0f, 0.0f, 0.0f, 0.0f, 0.0f, 0.0f,
1947 0.0f, 0.0f, 0.0f, 0.0f, 0.0f, 0.0f,
1948 0.0f, 0.0f, 0.0f, 0.0f, 0.0f, 0.0f,
1949 0.0f, 0.0f, 0.0f, 0.0f, 0.0f, 0.0f,
1950 0.0f, 0.0f, 0.0f, 0.0f, 0.0f, 0.0f,
1951 0.0f, 0.0f, 0.0f, 0.0f, 0.0f, 0.0f,
1952 0.0f, 0.0f, 0.0f, 0.0f, 0.0f, 0.0f,
1953 0.0f, 0.0f, 0.0f, 0.0f, 0.0f, 0.0f,
1954 0.0f, 0.0f, 0.0f, 0.0f, 0.0f, 0.0f,
1955 0.0f, 0.0f, 0.0f, 0.0f, 0.0f, 0.0f,
1956 0.0f, 0.0f, 0.0f, 0.0f, 0.0f, 0.0f,
1957 0.0f, 0.0f, 0.0f, 0.0f, 0.0f, 0.0f,
1958 0.0f, 0.0f, 0.0f, 0.0f, 0.0f, 0.0f,
1959 0.0f, 0.0f, 0.0f, 0.0f, 0.0f, 0.0f,
1960 0.0f, 0.0f, 0.0117647f, 0.0705882f, 0.0705882f, 0.0705882f,
1961 0.494118f, 0.533333f, 0.686275f, 0.101961f, 0.65098f, 1.0f,
1962 0.968627f, 0.498039f, 0.0f, 0.0f, 0.0f, 0.0f,
1963 0.0f, 0.0f, 0.0f, 0.0f, 0.0f, 0.0f,
1964 0.0f, 0.0f, 0.117647f, 0.141176f, 0.368627f, 0.603922f,
1965 0.666667f, 0.992157f, 0.992157f, 0.992157f, 0.992157f, 0.992157f,
1966 0.882353f, 0.67451f, 0.992157f, 0.94902f, 0.764706f, 0.25098f,
1967 0.0f, 0.0f, 0.0f, 0.0f, 0.0f, 0.0f,
1968 0.0f, 0.0f, 0.0f, 0.0f, 0.0f, 0.192157f,
1969 0.933333f, 0.992157f, 0.992157f, 0.992157f, 0.992157f, 0.992157f,
1970 0.992157f, 0.992157f, 0.992157f, 0.984314f, 0.364706f, 0.321569f,
1971 0.321569f, 0.219608f, 0.152941f, 0.0f, 0.0f, 0.0f,
1972 0.0f, 0.0f, 0.0f, 0.0f, 0.0f, 0.0f,
1973 0.0f, 0.0f, 0.0f, 0.0705882f, 0.858824f, 0.992157f,
1974 0.992157f, 0.992157f, 0.992157f, 0.992157f, 0.776471f, 0.713725f,
1975 0.968627f, 0.945098f, 0.0f, 0.0f, 0.0f, 0.0f,
1976 0.0f, 0.0f, 0.0f, 0.0f, 0.0f, 0.0f,
1977 0.0f, 0.0f, 0.0f, 0.0f, 0.0f, 0.0f,
1978 0.0f, 0.0f, 0.313725f, 0.611765f, 0.419608f, 0.992157f,
1979 0.992157f, 0.803922f, 0.0431373f, 0.0f, 0.168627f, 0.603922f,
1980 0.0f, 0.0f, 0.0f, 0.0f, 0.0f, 0.0f,
1981 0.0f, 0.0f, 0.0f, 0.0f, 0.0f, 0.0f,
1982 0.0f, 0.0f, 0.0f, 0.0f, 0.0f, 0.0f,
1983 0.0f, 0.054902f, 0.00392157f, 0.603922f, 0.992157f, 0.352941f,
1984 0.0f, 0.0f, 0.0f, 0.0f, 0.0f, 0.0f,
1985 0.0f, 0.0f, 0.0f, 0.0f, 0.0f, 0.0f,
1986 0.0f, 0.0f, 0.0f, 0.0f, 0.0f, 0.0f,
1987 0.0f, 0.0f, 0.0f, 0.0f, 0.0f, 0.0f,
1988 0.0f, 0.545098f, 0.992157f, 0.745098f, 0.00784314f, 0.0f,
1989 0.0f, 0.0f, 0.0f, 0.0f, 0.0f, 0.0f,
1990 0.0f, 0.0f, 0.0f, 0.0f, 0.0f, 0.0f,
1991 0.0f, 0.0f, 0.0f, 0.0f, 0.0f, 0.0f,
1992 0.0f, 0.0f, 0.0f, 0.0f, 0.0f, 0.0431373f,
1993 0.745098f, 0.992157f, 0.27451f, 0.0f, 0.0f, 0.0f,
1994 0.0f, 0.0f, 0.0f, 0.0f, 0.0f, 0.0f,
1995 0.0f, 0.0f, 0.0f, 0.0f, 0.0f, 0.0f,
1996 0.0f, 0.0f, 0.0f, 0.0f, 0.0f, 0.0f,
1997 0.0f, 0.0f, 0.0f, 0.0f, 0.137255f, 0.945098f,
1998 0.882353f, 0.627451f, 0.423529f, 0.00392157f, 0.0f, 0.0f,
1999 0.0f, 0.0f, 0.0f, 0.0f, 0.0f, 0.0f,
2000 0.0f, 0.0f, 0.0f, 0.0f, 0.0f, 0.0f,
2001 0.0f, 0.0f, 0.0f, 0.0f, 0.0f, 0.0f,
2002 0.0f, 0.0f, 0.0f, 0.317647f, 0.941176f, 0.992157f,
2003 0.992157f, 0.466667f, 0.0980392f, 0.0f, 0.0f, 0.0f,
2004 0.0f, 0.0f, 0.0f, 0.0f, 0.0f, 0.0f,
2005 0.0f, 0.0f, 0.0f, 0.0f, 0.0f, 0.0f,
2006 0.0f, 0.0f, 0.0f, 0.0f, 0.0f, 0.0f,
2007 0.0f, 0.0f, 0.176471f, 0.729412f, 0.992157f, 0.992157f,
2008 0.588235f, 0.105882f, 0.0f, 0.0f, 0.0f, 0.0f,
2009 0.0f, 0.0f, 0.0f, 0.0f, 0.0f, 0.0f,
2010 0.0f, 0.0f, 0.0f, 0.0f, 0.0f, 0.0f,
2011 0.0f, 0.0f, 0.0f, 0.0f, 0.0f, 0.0f,
2012 0.0f, 0.0627451f, 0.364706f, 0.988235f, 0.992157f, 0.733333f,
2013 0.0f, 0.0f, 0.0f, 0.0f, 0.0f, 0.0f,
2014 0.0f, 0.0f, 0.0f, 0.0f, 0.0f, 0.0f,
2015 0.0f, 0.0f, 0.0f, 0.0f, 0.0f, 0.0f,
2016 0.0f, 0.0f, 0.0f, 0.0f, 0.0f, 0.0f,
2017 0.0f, 0.976471f, 0.992157f, 0.976471f, 0.25098f, 0.0f,
2018 0.0f, 0.0f, 0.0f, 0.0f, 0.0f, 0.0f,
2019 0.0f, 0.0f, 0.0f, 0.0f, 0.0f, 0.0f,
2020 0.0f, 0.0f, 0.0f, 0.0f, 0.0f, 0.0f,
2021 0.0f, 0.0f, 0.180392f, 0.509804f, 0.717647f, 0.992157f,
2022 0.992157f, 0.811765f, 0.00784314f, 0.0f, 0.0f, 0.0f,
2023 0.0f, 0.0f, 0.0f, 0.0f, 0.0f, 0.0f,
2024 0.0f, 0.0f, 0.0f, 0.0f, 0.0f, 0.0f,
2025 0.0f, 0.0f, 0.0f, 0.0f, 0.152941f, 0.580392f,
2026 0.898039f, 0.992157f, 0.992157f, 0.992157f, 0.980392f, 0.713725f,
2027 0.0f, 0.0f, 0.0f, 0.0f, 0.0f, 0.0f,
2028 0.0f, 0.0f, 0.0f, 0.0f, 0.0f, 0.0f,
2029 0.0f, 0.0f, 0.0f, 0.0f, 0.0f, 0.0f,
2030 0.0941176f, 0.447059f, 0.866667f, 0.992157f, 0.992157f, 0.992157f,
2031 0.992157f, 0.788235f, 0.305882f, 0.0f, 0.0f, 0.0f,
2032 0.0f, 0.0f, 0.0f, 0.0f, 0.0f, 0.0f,
2033 0.0f, 0.0f, 0.0f, 0.0f, 0.0f, 0.0f,
2034 0.0f, 0.0f, 0.0901961f, 0.258824f, 0.835294f, 0.992157f,
2035 0.992157f, 0.992157f, 0.992157f, 0.776471f, 0.317647f, 0.00784314f,
2036 0.0f, 0.0f, 0.0f, 0.0f, 0.0f, 0.0f,
2037 0.0f, 0.0f, 0.0f, 0.0f, 0.0f, 0.0f,
2038 0.0f, 0.0f, 0.0f, 0.0f, 0.0705882f, 0.670588f,
2039 0.858824f, 0.992157f, 0.992157f, 0.992157f, 0.992157f, 0.764706f,
2040 0.313725f, 0.0352941f, 0.0f, 0.0f, 0.0f, 0.0f,
2041 0.0f, 0.0f, 0.0f, 0.0f, 0.0f, 0.0f,
2042 0.0f, 0.0f, 0.0f, 0.0f, 0.0f, 0.0f,
2043 0.215686f, 0.67451f, 0.886275f, 0.992157f, 0.992157f, 0.992157f,
2044 0.992157f, 0.956863f, 0.521569f, 0.0431373f, 0.0f, 0.0f,
2045 0.0f, 0.0f, 0.0f, 0.0f, 0.0f, 0.0f,
2046 0.0f, 0.0f, 0.0f, 0.0f, 0.0f, 0.0f,
2047 0.0f, 0.0f, 0.0f, 0.0f, 0.533333f, 0.992157f,
2048 0.992157f, 0.992157f, 0.831373f, 0.529412f, 0.517647f, 0.0627451f,
2049 0.0f, 0.0f, 0.0f, 0.0f, 0.0f, 0.0f,
2050 0.0f, 0.0f, 0.0f, 0.0f, 0.0f, 0.0f,
2051 0.0f, 0.0f, 0.0f, 0.0f, 0.0f, 0.0f,
2052 0.0f, 0.0f, 0.0f, 0.0f, 0.0f, 0.0f,
2053 0.0f, 0.0f, 0.0f, 0.0f, 0.0f, 0.0f,
2054 0.0f, 0.0f, 0.0f, 0.0f, 0.0f, 0.0f,
2055 0.0f, 0.0f, 0.0f, 0.0f, 0.0f, 0.0f,
2056 0.0f, 0.0f, 0.0f, 0.0f, 0.0f, 0.0f,
2057 0.0f, 0.0f, 0.0f, 0.0f, 0.0f, 0.0f,
2058 0.0f, 0.0f, 0.0f, 0.0f, 0.0f, 0.0f,
2059 0.0f, 0.0f, 0.0f, 0.0f, 0.0f, 0.0f,
2060 0.0f, 0.0f, 0.0f, 0.0f, 0.0f, 0.0f,
2061 0.0f, 0.0f, 0.0f, 0.0f, 0.0f, 0.0f,
2062 0.0f, 0.0f, 0.0f, 0.0f, 0.0f, 0.0f,
2063 0.0f, 0.0f, 0.0f, 0.0f, 0.0f, 0.0f,
2064 0.0f, 0.0f, 0.0f, 0.0f, 0.0f, 0.0f,
2065 0.0f, 0.0f, 0.0f, 0.0f
2066 };
2067 float expected[] = {
2068 -0.836424f, -0.857365f, -1.62739f, -1.62739f, -0.836424f, 5.40742f,
2069 0.920853f, -0.692567f, -0.836424f, -0.534405f, -1.62739f, -0.836424f,
2070 1.32602f, 1.36312f, 0.112766f, -0.836424f, -0.192962f, 1.56975f,
2071 2.45777f, 0.944414f, -0.192962f, -1.5519f, -1.5519f, -0.554006f,
2072 -0.192962f, 1.4231f, -1.5519f, -0.192962f, 1.3661f, -1.5519f,
2073 -1.5519f, -0.192962f, -0.843708f, -0.359025f, -0.843708f, -0.843708f,
2074 -0.843708f, 4.53065f, 0.0429584f, -0.796804f, -0.843708f, 0.3473f,
2075 -0.843708f, -0.843708f, -0.114439f, 3.14817f, 0.0811934f, -0.843708f
2076 };
2077 float kernel[] = {
2078 0.119643f, -0.237864f, 0.0462892f, 0.0502297f, -0.0134528f,
2079 0.146347f, 0.153133f, 0.0513307f, 0.0752369f, 0.0135557f,
2080 -0.111434f, 0.0941854f, 0.0788362f, 0.0299412f, 0.111762f,
2081 0.144066f, 0.00431504f, -0.0177954f, 0.0738092f, -0.0344215f,
2082 0.0832582f, 0.053989f, -0.112691f, 0.0962145f, 0.0186525f,
2083 -0.00660205f, -0.111962f, -0.126801f, -0.231625f, 0.17309f,
2084 0.0748875f, -0.179569f, -0.00513812f, -0.156579f, -0.147322f,
2085 0.184168f, 0.189308f, -0.200359f, -0.0156733f, 0.140649f,
2086 0.0858496f, -0.0263217f, -0.0740749f, -0.112563f, 0.107528f,
2087 0.0609729f, -0.221625f, 0.0769944f, -0.00900815f, -0.00136441f,
2088 -0.0236521f, -0.0418025f, -0.00286299f, 0.12241f, 0.0964093f,
2089 -0.0150897f, 0.0532171f, 0.0625916f, 0.116939f, 0.118024f,
2090 0.161918f, -0.00909767f, 0.100897f, -0.054563f, -0.175179f,
2091 -0.0687892f, 0.00734235f, 0.109833f, -0.113776f, 0.0595405f,
2092 -0.170255f, 0.0124815f, -0.0363301f, -0.0127038f, 0.0445554f,
2093 -0.0729894f, 0.107428f, -0.0341417f, 0.132619f, 0.00984557f,
2094 -0.00443654f, 0.202929f, 0.0945134f, 0.0148725f, 0.00998574f,
2095 -0.0226449f, 0.0478197f, -0.0793442f, 0.0707599f, -0.084225f,
2096 0.0865795f, 0.071104f, -0.047894f, 0.0838322f, 0.0635493f,
2097 -0.00370265f, -0.157247f, -0.0289622f, -0.0590963f, 0.13207f,
2098 0.00468011f, -0.0345372f, 0.217939f, 0.18861f, -0.0290393f,
2099 -0.0440664f, 0.0126197f, -0.129132f, -0.124943f, 0.0968156f,
2100 -0.0853643f, -0.182305f, 0.00461618f, -0.147095f, -0.230282f,
2101 0.00856019f, 0.0278893f, -0.0300229f, 0.0417871f, 0.0804717f,
2102 -0.0768571f, -0.0397085f, -0.0601096f, 0.100901f, -0.0184926f,
2103 0.0350673f, 0.0971094f, -0.0171837f, -0.289644f, -0.0899041f,
2104 0.08998f, -0.160319f, -0.0195103f, 0.0392167f, -0.137864f,
2105 -0.0136294f, 0.0330886f, -0.0409244f, -0.092533f, -0.0427934f,
2106 -0.191144f, -0.0969461f, 0.112035f, 0.138611f, 0.128717f,
2107 0.191184f, 0.197462f
2108 };
2109 float bias[] = { 0.186703f, 0.204358f, -0.0230452f };
2110
2111 float bn_gamma[] = { 1.32173f, 1.26171f, 1.21966f };
2112 float bn_beta[] = { -0.232595f, -0.222652f, -0.232209f };
2113 float bn_mean[] = { 0.329233f, 0.199894f, 0.12389f };
2114 float bn_std[] = { 0.311986f, 0.189737f, 0.247104f };
2115
2116 CNN_BATCHNORM_PARAMS bn_params = {
2117 bn_gamma,
2118 bn_beta,
2119 bn_mean,
2120 bn_std,
2121 };
2122
2123 CNN_CONFIG cnn_config = {
2124 1,
2125 0,
2126 0,
2127 0,
2128 0,
2129 {
2130 {
2131 1,
2132 filter_width,
2133 filter_height,
2134 3,
2135 7,
2136 7,
2137 0,
2138 kernel,
2139 bias,
2140 PADDING_VALID,
2141 RELU,
2142 0,
2143 0,
2144 BRANCH_NO_COPY,
2145 BRANCH_NOC,
2146 {},
2147 bn_params,
2148 0,
2149 },
2150 },
2151 };
2152
2153 CNN_THREAD_DATA thread_data = { 1, NULL };
2154
2155 RunCNNTest(image_width, image_height, input, expected, &cnn_config,
2156 image_width, &thread_data, MSE_FLOAT_TOL);
2157 }
2158
// Runs one SAME-zero-padded 3x3 convolution layer (1 input channel,
// 4 output channels, skip 1, no activation) on a 2x2 image twice: first with
// a single-worker CNN_THREAD_DATA that has no worker pool, then with a pool of
// AVxWorker threads. Both runs must match the same expected output.
TEST_F(CNNTest, TestMultithreading) {
  const int image_width = 2;
  const int image_height = 2;
  const int filter_width = 3;
  const int filter_height = 3;

  // Single-channel 2x2 input image.
  const float input[] = { -2, 4, 1, 0 };

  // 3x3 kernels: 9 taps x 4 output channels = 36 values.
  const float weights[] = {
    -4, 2, -2, 0, -4, 4, -3, -3, -3, -1, 1, 0, -5, -3, 0, -5, 0, 0,
    -1, 0, 2, -5, 0, 1, 4, 2, 1, 0, -2, -1, -5, -3, 2, -2, 1, -5,
  };

  // One bias per output channel.
  const float bias[] = { -4, -3, -2, 3 };

  // 4 output channels x (2x2) pixels.
  const float expected[] = {
    2, 10, -8, -17, -24, 5, -15, 6, -5, -5, 7, -10, 4, 13, 9, -14,
  };

  const CNN_CONFIG cnn_config = {
    1,  // num_layers
    0,  // is_residue
    0,  // ext_width
    0,  // ext_height
    0,  // strict_bounds
    {
        // layer_config
        {
            1,                  // in_channels
            filter_width,       // filter_width
            filter_height,      // filter_height
            4,                  // out_channels
            1,                  // skip_width
            1,                  // skip_height
            0,                  // max_pool
            weights,            // weights
            bias,               // bias
            PADDING_SAME_ZERO,  // pad
            NONE,               // activation
            0,                  // deconvolve
            0,                  // branch
            BRANCH_NO_COPY,     // branch_copy_type
            BRANCH_NOC,         // branch_combine_type
            {},                 // branch_config
            {},                 // bn_params
            0,                  // output_num
        },
    },
  };

  // Baseline: num_workers = 1 and no worker pool.
  CNN_THREAD_DATA single_thread_data = { 1, nullptr };
  RunCNNTest(image_width, image_height, input, expected, &cnn_config,
             image_width, &single_thread_data, MSE_FLOAT_TOL);

  // Same computation handed a pool of worker threads.
  const int kNumWorkers = 4;
  const AVxWorkerInterface *const winterface = aom_get_worker_interface();
  AVxWorker workers[kNumWorkers];
  for (AVxWorker &worker : workers) {
    winterface->init(&worker);
  }

  CNN_THREAD_DATA multi_thread_data = { kNumWorkers, workers };
  RunCNNTest(image_width, image_height, input, expected, &cnn_config,
             image_width, &multi_thread_data, MSE_FLOAT_TOL);

  for (AVxWorker &worker : workers) {
    winterface->end(&worker);
  }
}
2239
// Exercises a four-layer, two-branch network that writes four separate
// outputs (each layer's output_num selects its slot in CNN_MULTI_OUT):
//   output 0: layer 0, 2x2 conv, skip 2, SAME zero padding, no activation on
//             the 8x8x3 input -> 4x4 x num_filters. Its result is also copied
//             to another branch (BRANCH_OUTPUT; branch_config { 2, 0, 0 }
//             presumably selects branch 1 -- verify against cnn.h).
//   output 1: layer 1, 2x2 conv, skip 2, RELU -> 2x2 x num_filters.
//   output 2: layer 2, 2x2 conv, skip 2, RELU -> 1x1 x num_filters.
//   output 3: layer 3 on branch 1, 4x4 conv, skip 4, VALID padding, RELU,
//             then BRANCH_CAT -> 1x1 x (2 * num_filters). The last two values
//             of expected_3 equal expected_2, i.e. layer 2's channels are
//             appended after layer 3's own channels.
TEST_F(CNNTest, TestMultiOutput) {
  const int image_dim = 8;
  const int image_ch = 3;
  const int filter_dim = 2;
  const int stride = 2;
  const int num_filters = 2;

  // image_ch planes of image_dim x image_dim samples, stored plane by plane.
  const float input_[] = {
    1.7537929121f,     0.134331551012f,   0.123580039877f,   0.957731845246f,
    0.391006834217f,   1.00699352042f,    -0.778177955829f,  -0.814166433059f,
    -0.656374394915f,  0.321967305228f,   -2.19455719176f,   0.708035038966f,
    0.409148822266f,   -0.318254408902f,  0.152450211189f,   -0.250210793369f,
    0.826811563186f,   1.6804156584f,     0.273626975978f,   0.437936241887f,
    -0.329935520167f,  -0.288761611645f,  0.156937008304f,   0.271054157295f,
    -0.0224828854332f, 1.70110336895f,    -0.989066699309f,  1.30863131729f,
    -0.165813705702f,  0.00380178619265f, -0.0837342367587f, 0.760954783156f,
    -0.413610373524f,  1.17968204175f,    0.720295719536f,   0.308718974472f,
    -1.10091337671f,   0.693160033687f,   -0.0202862320697f, 1.0221927503f,
    -1.24521801881f,   -0.478501952308f,  -1.71648619442f,   -0.182571723636f,
    0.339292649504f,   2.0806519131f,     0.967974033444f,   0.175248672328f,
    0.0658124561472f,  0.795504169496f,   0.750592557361f,   -1.46631013249f,
    -1.79052846838f,   -1.03672179515f,   -0.841985521653f,  1.20995011489f,
    0.140859718215f,   -0.651552622661f,  0.451065110806f,   1.1189443693f,
    0.100213260593f,   -0.834076868118f,  -1.28734321611f,   1.22064420095f,
    -0.364143084361f,  0.750961509335f,   -0.888689074553f,  -0.8253547106f,
    -1.21800999027f,   -0.966670603566f,  1.37384014741f,    0.47281264834f,
    -0.420416235531f,  0.520163906493f,   0.501296589423f,   1.53418976951f,
    0.715234751485f,   0.644551588907f,   0.0763504863375f,  -0.0018541943723f,
    0.322853189656f,   -0.795099723224f,  -0.125177096675f,  1.4476577471f,
    -0.585888410088f,  -1.44391754955f,   -0.610543221933f,  -0.221859179799f,
    0.252060200774f,   -0.86287169623f,   -0.0350246229157f, 1.0932311997f,
    0.899464648842f,   -0.468806951704f,  -0.300861137168f,  1.15776414206f,
    1.03268544738f,    -0.171579585622f,  -0.179136557119f,  -0.354091003368f,
    -0.612298249394f,  -1.20237379258f,   1.54604109659f,    0.130664370287f,
    0.885225111868f,   1.0362799581f,     0.980561720868f,   -0.619379186999f,
    -1.33818929924f,   -0.237233737961f,  -1.89335425073f,   0.567821011321f,
    0.862420368465f,   -1.37380916821f,   0.352190056666f,   0.611261516274f,
    0.393237747152f,   0.894686247967f,   0.190405182149f,   0.264872662911f,
    -0.0657009133797f, 0.0580512653493f,  -0.401825294366f,  0.4106081318f,
    0.49484512188f,    -0.0751103149442f, -1.43243736382f,   1.79855656009f,
    -1.1075351975f,    0.000354882733011f, -0.950716438608f, 1.27129831688f,
    1.00495189838f,    0.110358656713f,   1.08315032822f,    -0.972676676218f,
    -0.0757668962831f, 1.88932045165f,    -0.0672638136275f, 0.425913010161f,
    -0.781540372017f,  0.976000248609f,   0.687218504122f,   1.31374513445f,
    -0.932658930672f,  -1.25339468479f,   0.422071294078f,   -0.24189927912f,
    0.216906604642f,   -1.88720997548f,   1.99252872889f,    0.353943735777f,
    0.737434784132f,   -1.17848645017f,   1.70424254896f,    0.775297112968f,
    -0.516392797501f,  0.398130609129f,   0.737248101457f,   0.166282500886f,
    1.24699015468f,    0.47116183125f,    1.19091180182f,    -0.372695424578f,
    0.219773209389f,   -0.829467838962f,  -0.52533122724f,   1.98707754595f,
    0.553692606972f,   -0.933228902369f,  1.55427751643f,    -1.08813399144f,
    -0.325686682094f,  0.205091443796f,   -1.70381666435f,   0.466465327942f,
    1.73126863447f,    -0.939133672634f,  1.48318077459f,    -0.599414038168f,
    -1.1583078687f,    0.518116190201f,   0.133571482458f,   0.84958342672f,
    1.02205000597f,    -0.0772082009087f, -1.69567503859f,   1.4697939436f,
    1.67813743122f,    -0.627911582938f,  0.131380509137f,   -1.35717850726f,
  };
  static_assert(sizeof(input_) ==
                    sizeof(float) * image_ch * image_dim * image_dim,
                "input_ must hold image_ch planes of image_dim x image_dim");

  // One pointer per input plane.
  const float *input[image_ch] = { input_, &input_[image_dim * image_dim],
                                   &input_[2 * image_dim * image_dim] };

  // All layers share a zero bias.
  const float bias[] = { 0.0f, 0.0f };

  const float weights_1[] = {
    -0.489547413618f, 0.141916424749f,  -0.279286485585f,  -0.115322211094f,
    0.299572786936f,  0.205289980785f,  -0.536254480088f,  -0.253626313744f,
    -0.422883815849f, -0.169702966298f, -0.540104704793f,  0.495319646763f,
    0.298799079422f,  -0.10054550901f,  -0.306085047056f,  0.171061886165f,
    -0.108058703878f, -0.410734629888f, -0.0640674673049f, -0.386524840979f,
    -0.157203423678f, -0.362138920529f, -0.216206085209f,  0.147502517971f,
  };
  static_assert(sizeof(weights_1) == sizeof(float) * image_ch * filter_dim *
                                         filter_dim * num_filters,
                "weights_1 size must match layer 0's shape");

  const float weights_2[] = {
    0.207580604357f,  0.480821146263f,  -0.29111909562f,   0.47422567493f,
    0.206892553253f,  -0.235067084092f, 0.354516800602f,   -0.212399370252f,
    -0.419071343731f, -0.050350731631f, -0.0516457320279f, -0.0359310500731f,
    0.567044864811f,  -0.060341127522f, 0.0501464839637f,  -0.437785677916f,
  };
  static_assert(sizeof(weights_2) == sizeof(float) * num_filters * filter_dim *
                                         filter_dim * num_filters,
                "weights_2 size must match layer 1's shape");

  const float weights_3[] = {
    -0.0690452401448f, -0.356657338763f, -0.219464031809f, 0.551288365843f,
    0.181372090853f,   -0.00245268542109f, 0.409000696276f, -0.593209108763f,
    0.587352566749f,   -0.243720660227f, 0.266232713887f,  -0.00439285245097f,
    0.252883228305f,   0.152646192631f,  0.0918944932026f, 0.398853715057f,
  };
  static_assert(sizeof(weights_3) == sizeof(float) * num_filters * filter_dim *
                                         filter_dim * num_filters,
                "weights_3 size must match layer 2's shape");

  const float weights_4[] = {
    0.207560791573f,   0.194201350401f,   0.227802322443f,   0.206533663345f,
    0.0557331066805f,  0.0224159800424f,  -0.143939197467f,  -0.27703361602f,
    0.130643888389f,   -0.269456557461f,  0.186242862864f,   -0.162879944774f,
    -0.145503996718f,  -0.0768822987581f, -0.203127976359f,  -0.238119922873f,
    -0.258806479994f,  0.0357957680385f,  -0.1027606976f,    -0.287920082345f,
    0.189047820993f,   0.250711538481f,   -0.272815714175f,  -0.0431449742024f,
    0.207261230996f,   -0.0396472677451f, 0.131236557412f,   0.174291832499f,
    -0.251515885765f,  -0.107164007499f,  0.185824534748f,   -0.00561585838161f,
    0.273393799578f,   -0.139563699075f,  -0.263922456031f,  -0.118859844081f,
    0.109230982597f,   -0.170170294794f,  0.0123025648515f,  -0.0839368964355f,
    -0.0774058234297f, 0.255847138286f,   -0.208430879637f,  0.279170114319f,
    -0.272890330712f,  -0.217725903006f,  -0.295923275459f,  -0.17008723953f,
    -0.284281803405f,  0.281406323629f,   0.266910044663f,   -0.209963914338f,
    0.271980962964f,   0.142013581699f,   -0.143896509026f,  -0.290509242975f,
    -0.305768180935f,  0.196902832117f,   -0.090424189662f,  -0.147460802346f,
    0.217722016651f,   0.12353848977f,    -0.169177363577f,  -0.0454230918512f,
  };
  static_assert(sizeof(weights_4) == sizeof(float) * num_filters *
                                         (2 * filter_dim) * (2 * filter_dim) *
                                         num_filters,
                "weights_4 size must match layer 3's shape");

  const float expected_0[] = {
    -2.04858441055f,  -2.12883075791f,    -0.045177363807f, 0.763949675768f,
    -0.544361512821f, -1.58123168032f,    1.89319847039f,   0.16859080901f,
    -1.16023321135f,  -0.396988107751f,   1.76637090744f,   -1.40434786514f,
    0.908227575669f,  0.817064817605f,    0.215631134908f,  -0.848605613428f,
    -0.106756747018f, 0.0193027166685f,   0.801345615113f,  -0.395407237598f,
    -1.79983795658f,  -1.73054496242f,    0.0584392594454f, -0.388786095569f,
    -0.237269619354f, 0.000843578271263f, -1.24043512104f,  0.487839445893f,
    -0.394259726605f, 0.559632843424f,    -0.527224052291f, -1.53792340282f,
  };

  const float expected_1[] = {
    0.0f, 0.0f, 0.0f, 0.0f, 0.4057888292f, 0.325309571755f,
    0.0f, 1.22013465602f,
  };

  const float expected_2[] = {
    0.156119444687f,
    0.517385299817f,
  };

  const float expected_3[] = {
    0.224177852984f,
    0.503384419034f,
    0.156119444687f,
    0.517385299817f,
  };

  const float *expected[] = { expected_0, expected_1, expected_2, expected_3 };

  CNN_CONFIG cnn_config = {
    4,  // num_layers
    0,  // is_residue
    0,  // ext_width
    0,  // ext_height
    0,  // strict_bounds
    {
        // layer_config
        {
            image_ch,           // in_channels
            filter_dim,         // filter_width
            filter_dim,         // filter_height
            num_filters,        // out_channels
            stride,             // skip_width
            stride,             // skip_height
            0,                  // max_pool
            weights_1,          // weights
            bias,               // bias
            PADDING_SAME_ZERO,  // pad
            NONE,               // activation
            0,                  // deconvolve
            0,                  // branch
            BRANCH_OUTPUT,      // branch_copy_type
            BRANCH_NOC,         // branch_combine_type
            { 2, 0, 0 },        // branch_config
            {},                 // bn_params
            0,                  // output_num
        },
        {
            num_filters,        // in_channels
            filter_dim,         // filter_width
            filter_dim,         // filter_height
            num_filters,        // out_channels
            stride,             // skip_width
            stride,             // skip_height
            0,                  // max_pool
            weights_2,          // weights
            bias,               // bias
            PADDING_SAME_ZERO,  // pad
            RELU,               // activation
            0,                  // deconvolve
            0,                  // branch
            BRANCH_NO_COPY,     // branch_copy_type
            BRANCH_NOC,         // branch_combine_type
            {},                 // branch_config
            {},                 // bn_params
            1,                  // output_num
        },
        {
            num_filters,        // in_channels
            filter_dim,         // filter_width
            filter_dim,         // filter_height
            num_filters,        // out_channels
            stride,             // skip_width
            stride,             // skip_height
            0,                  // max_pool
            weights_3,          // weights
            bias,               // bias
            PADDING_SAME_ZERO,  // pad
            RELU,               // activation
            0,                  // deconvolve
            0,                  // branch
            BRANCH_NO_COPY,     // branch_copy_type
            BRANCH_NOC,         // branch_combine_type
            {},                 // branch_config
            {},                 // bn_params
            2,                  // output_num
        },
        {
            num_filters,     // in_channels
            2 * filter_dim,  // filter_width
            2 * filter_dim,  // filter_height
            num_filters,     // out_channels
            2 * stride,      // skip_width
            2 * stride,      // skip_height
            0,               // max_pool
            weights_4,       // weights
            bias,            // bias
            PADDING_VALID,   // pad
            RELU,            // activation
            0,               // deconvolve
            1,               // branch
            BRANCH_NO_COPY,  // branch_copy_type
            BRANCH_CAT,      // branch_combine_type
            { 0, 0, 1 },     // branch_config
            {},              // bn_params
            3,               // output_num
        },
    },
  };

  CNN_THREAD_DATA thread_data = { 1, nullptr };

  const int num_outputs = 4;
  // Channel counts come from each layer's out_channels (num_filters), not
  // from the kernel size. Output 3 concatenates layer 3's num_filters
  // channels with num_filters channels from the other branch.
  const int output_chs[num_outputs] = { num_filters, num_filters, num_filters,
                                        2 * num_filters };
  // Square spatial size of each output.
  const int output_dims[num_outputs] = { 4, 2, 1, 1 };

  // Back all outputs with a single allocation.
  int total_output_size = 0;
  for (int i = 0; i < num_outputs; ++i) {
    total_output_size += output_chs[i] * output_dims[i] * output_dims[i];
  }
  float *const output_ =
      (float *)aom_malloc(sizeof(*output_) * total_output_size);
  ASSERT_NE(output_, nullptr);

  // Carve the buffer into one plane pointer per channel, output by output.
  float *output[CNN_MAX_CHANNELS] = { nullptr };
  int ch_ite = 0;
  float *output_ite = output_;
  for (int output_idx = 0; output_idx < num_outputs; output_idx++) {
    for (int channel = 0; channel < output_chs[output_idx]; ++channel) {
      output[ch_ite++] = output_ite;
      output_ite += output_dims[output_idx] * output_dims[output_idx];
    }
  }
  CNN_MULTI_OUT output_struct = { num_outputs, output_chs, output_dims,
                                  output };

  RunMultiOutCNNTest(input, image_dim, image_dim, image_dim, &cnn_config,
                     &thread_data, &output_struct, expected, MSE_FLOAT_TOL);

  aom_free(output_);
}
2497