• Home
  • History
  • Annotate
  • Line#
  • Scopes#
  • Navigate#
  • Raw
  • Download
1 /*
2  * Copyright (c) 2019, Alliance for Open Media. All rights reserved
3  *
4  * This source code is subject to the terms of the BSD 2 Clause License and
5  * the Alliance for Open Media Patent License 1.0. If the BSD 2 Clause License
6  * was not distributed with this source code in the LICENSE file, you can
7  * obtain it at www.aomedia.org/license/software. If the Alliance for Open
8  * Media Patent License 1.0 was not distributed with this source code in the
9  * PATENTS file, you can obtain it at www.aomedia.org/license/patent.
10  */
11 
#include <assert.h>
#include <math.h>
#include <stdio.h>

#include <vector>

#include "third_party/googletest/src/googletest/include/gtest/gtest.h"

#include "config/av1_rtcd.h"

#include "av1/encoder/cnn.h"
21 
#define SQR(x) ((x) * (x))

// Best possible pixelwise guaranteed precision given each float has at most
// 3 specified decimals.
constexpr double PIXELWISE_FLOAT_TOL = 1E-2;

// Mean-squared-error bounds passed as the `tolerance` argument of the CNN run
// helpers. MSE_INT_TOL is used by tests whose expected outputs are exact
// integers, so any deviation at all fails the MSE check.
constexpr double MSE_FLOAT_TOL = 1E-6;
constexpr double MSE_INT_TOL = 0;
30 
31 namespace {
32 
33 class CNNTest : public ::testing::Test {
34  protected:
RunCNNTest(int image_width,int image_height,const float * input,const float * expected,const CNN_CONFIG * cnn_config,int in_stride,CNN_THREAD_DATA * thread_data,double tolerance)35   static void RunCNNTest(int image_width, int image_height, const float *input,
36                          const float *expected, const CNN_CONFIG *cnn_config,
37                          int in_stride, CNN_THREAD_DATA *thread_data,
38                          double tolerance) {
39     int out_width, out_height, out_channels;
40     av1_find_cnn_output_size(image_width, image_height, cnn_config, &out_width,
41                              &out_height, &out_channels);
42 
43     const int out_size = out_width * out_height;
44     const int out_stride = out_width;
45 
46     float *output_ =
47         (float *)aom_malloc(sizeof(*output_) * out_size * out_channels);
48     float *output[CNN_MAX_CHANNELS] = { nullptr };
49     for (int channel = 0; channel < out_channels; ++channel) {
50       output[channel] = output_ + (channel * out_size);
51     }
52     const int num_outputs = 1;
53     const int output_chs[1] = { out_channels };
54     const int output_strides[1] = { out_stride };
55     CNN_MULTI_OUT output_struct = { num_outputs, output_chs, output_strides,
56                                     output };
57 
58     RunMultiOutCNNTest(&input, image_width, image_height, in_stride, cnn_config,
59                        thread_data, &output_struct, &expected, tolerance);
60 
61     aom_free(output_);
62   }
63 
RunMultiOutCNNTest(const float ** input,int image_width,int image_height,int in_stride,const CNN_CONFIG * cnn_config,CNN_THREAD_DATA * thread_data,CNN_MULTI_OUT * output,const float ** expected,double tolerance)64   static void RunMultiOutCNNTest(const float **input, int image_width,
65                                  int image_height, int in_stride,
66                                  const CNN_CONFIG *cnn_config,
67                                  CNN_THREAD_DATA *thread_data,
68                                  CNN_MULTI_OUT *output, const float **expected,
69                                  double tolerance) {
70     const int num_outputs = output->num_outputs;
71     const int *output_chs = output->output_channels;
72 
73     int *out_widths = (int *)aom_calloc(num_outputs, sizeof(*out_widths));
74     int *out_heights = (int *)aom_calloc(num_outputs, sizeof(*out_heights));
75     int *not_used = (int *)aom_calloc(num_outputs, sizeof(*not_used));
76 
77     av1_find_cnn_output_size(image_width, image_height, cnn_config, out_widths,
78                              out_heights, not_used);
79     av1_cnn_predict(input, image_width, image_height, in_stride, cnn_config,
80                     thread_data, output);
81 
82     int channel_offset = 0;
83     for (int output_idx = 0; output_idx < num_outputs; output_idx++) {
84       const float *expected_out = expected[output_idx];
85       const int curr_output_chs = output_chs[output_idx];
86       const int out_size = out_widths[output_idx] * out_heights[output_idx];
87 
88       double mse = 0;
89       int expected_ite = 0;
90       for (int channel = 0; channel < curr_output_chs; ++channel) {
91         const float *buf_out = output->output_buffer[channel_offset];
92 
93         for (int i = 0; i < out_size; ++i) {
94           EXPECT_NEAR(expected_out[expected_ite], buf_out[i],
95                       PIXELWISE_FLOAT_TOL)
96               << " output " << output_idx << " channel " << channel << " pixel "
97               << expected_ite % out_size << ": " << expected_out[expected_ite]
98               << "/" << buf_out[i] << std::endl;
99           mse += SQR(expected_out[expected_ite] - buf_out[i]);
100           expected_ite++;
101         }
102 
103         channel_offset++;
104       }
105       mse /= (out_size * curr_output_chs);
106       EXPECT_LE(mse, tolerance) << " output " << output_idx << std::endl;
107     }
108 
109     aom_free(out_widths);
110     aom_free(out_heights);
111     aom_free(not_used);
112   }
113 
AssignLayerWeightsBiases(CNN_CONFIG * cnn_config,float * weights,float * bias)114   static void AssignLayerWeightsBiases(CNN_CONFIG *cnn_config, float *weights,
115                                        float *bias) {
116     size_t weight_offset = 0;
117     size_t bias_offset = 0;
118     for (int layer = 0; layer < cnn_config->num_layers; ++layer) {
119       CNN_LAYER_CONFIG *layer_config = &cnn_config->layer_config[layer];
120       layer_config->weights = weights + weight_offset;
121       layer_config->bias = bias + bias_offset;
122       weight_offset += layer_config->filter_width *
123                        layer_config->filter_height * layer_config->in_channels *
124                        layer_config->out_channels;
125       bias_offset += layer_config->out_channels;
126 
127       ASSERT_NE(layer_config->weights, nullptr);
128       ASSERT_NE(layer_config->bias, nullptr);
129     }
130   }
131 };
132 
133 }  // namespace
134 
TEST_F(CNNTest,TestMultilayerConvolution)135 TEST_F(CNNTest, TestMultilayerConvolution) {
136   int image_height = 16;
137   int image_width = 16;
138   int filter_height = 5;
139   int filter_width = 4;
140 
141   float input[] = {
142     -3, 1,  -3, 2,  -2, -2, 2,  -2, 1,  -2, -3, 1,  2,  2,  2,  -2, 0,  1,  -1,
143     -3, -1, -1, 1,  0,  -3, 1,  0,  -1, 1,  0,  0,  -3, -3, -3, 0,  2,  1,  -1,
144     2,  0,  1,  -3, -1, 2,  2,  1,  -2, 0,  -1, 0,  -2, -2, -1, 1,  0,  0,  0,
145     -2, -2, -2, 1,  1,  -2, 1,  1,  -2, -2, 1,  -2, -1, -2, -3, 2,  -3, -1, 1,
146     0,  -2, -2, -2, 1,  -2, -2, -1, -1, 2,  2,  2,  -1, 1,  -3, -3, 0,  2,  0,
147     2,  1,  -3, -3, 1,  2,  2,  1,  -2, -3, 0,  -3, 0,  -3, -2, 0,  1,  1,  0,
148     -3, 2,  -1, 2,  1,  0,  1,  -2, 1,  -1, -1, 2,  0,  -2, -3, 1,  1,  -2, -1,
149     -3, -3, -1, 0,  -3, -2, 0,  0,  1,  0,  -3, -2, -1, 1,  0,  2,  1,  0,  -3,
150     -2, -3, -3, -1, 0,  -2, 2,  -1, -3, 0,  -1, -1, 2,  0,  -3, -2, -1, 0,  0,
151     1,  -2, 1,  2,  1,  2,  2,  -3, 2,  -1, 0,  0,  -1, 0,  2,  2,  -1, 2,  -2,
152     1,  1,  -3, -3, 1,  -1, -1, -2, 2,  -2, -2, 2,  -1, -3, 2,  -3, 1,  -1, -1,
153     -3, 1,  -1, 1,  0,  -3, -3, 1,  -3, -3, 0,  2,  2,  -2, -1, 2,  0,  2,  1,
154     -1, -3, 0,  0,  -1, -1, 1,  0,  2,  0,  -3, 2,  1,  0,  1,  -3, 2,  -3, -3,
155     -1, -3, -3, 2,  0,  2,  -2, 1,  -1,
156   };
157 
158   float weights[] = {
159     -2, 2,  -2, 2,  -1, -3, 2,  2,  0,  0,  -3, -1, -2, -3, 1,  -1, 0,  0,  0,
160     2,  -2, 2,  -2, -3, 1,  1,  1,  -3, -1, 0,  1,  2,  -2, 0,  -1, -3, -1, -2,
161     2,  -3, -3, 1,  -2, -3, 0,  2,  1,  -3, -3, -1, -3, -2, -1, -3, -1, -3, -2,
162     -1, -3, -1, -2, -2, -3, 2,  0,  -3, 0,  -3, -3, 1,  -3, -1, 0,  -1, 1,  1,
163     -1, 1,  -2, 0,  2,  0,  -3, 1,  -1, -1, 2,  0,  1,  -3, -3, 1,  2,  -3, -3,
164     1,  -3, 2,  0,  -3, 1,  2,  2,  -2, -1, -2, 1,  1,  0,  -2, -2, 1,  2,  -1,
165     -3, 1,  -2, 2,  -3, -2, -3, 2,  1,  0,  -2, 0,  1,  -3, 2,  -2, -2, 0,  2,
166     -3, 2,  0,  0,  1,  -2, 1,  1,  -2, -1, -2, 1,  -2, 0,  -2, -2, 0,  -1, -1,
167     -3, -3, -3, 1,  -3, -2, 2,  -1, 2,  0,  2,  -2, 2,  -2, 1,  -3, -3, -1, 0,
168     2,  2,  1,  -1, -3, -1, -3, 2,  1,  -2, 0,  -3, -1, -3, -1, 2,  1,  0,  2,
169     -1, 1,  0,  1,  2,  -1, -2, 2,  1,  -3, -1, -3, 0,  1,  -2, 0,  -2, -3, 0,
170     -2, 2,  2,  0,  0,  2,  -3, 2,  -3, -2, 1,  2,  -3, -3, -1, -3, 0,  -3, -3,
171     -2, -2, -2, 0,  0,  1,  0,  0,  -1, 0,  0,  -3, 0,  -3, -1, -2, 1,  -2, -1,
172     2,  -2, 0,  0,  1,  0,  -2, -1, 0,  -3, 1,  0,  -1, -3, 1,  -1, 1,  -1, -3,
173     1,  0,  1,  1,  -1, 2,  2,  0,  0,  1,  -3, 2,  -2, -2, -3, -2, -1, -2, 2,
174     0,  2,  -2, -3, -1, -3, 2,  2,  -1, 2,  2,  -1, 0,  -3, 1,
175   };
176 
177   float bias[] = {
178     1, -1, 0, 1, 1, 1, -2,
179   };
180 
181   float expected_same[] = {
182     -1125, 2926,  6406,  631,   -1244, 97,    -1454, 2526,  1065,  3292,  3464,
183     2553,  -330,  532,   1038,  1182,  -402,  3758,  3392,  9854,  4365,  1408,
184     4736,  3134,  3838,  2409,  3221,  4350,  6750,  4045,  815,   1188,  2959,
185     9802,  9590,  4572,  5740,  4253,  1701,  7974,  7012,  6854,  7093,  3907,
186     4539,  3886,  4267,  3505,  465,   7824,  9219,  10026, 7968,  957,   2295,
187     5594,  10811, 9641,  5950,  10043, 8783,  3132,  1421,  1110,  4108,  13929,
188     10660, -84,   -61,   3932,  -180,  6811,  13393, 15147, 15640, 9337,  6961,
189     3808,  1604,  1398,  1047,  6739,  10144, 6517,  4698,  2678,  7389,  2595,
190     5248,  12075, 11272, 13951, 8820,  1090,  2199,  2206,  2788,  12116, 6683,
191     2612,  -291,  3183,  9414,  12316, 14524, 12333, 13208, 7832,  4664,  4657,
192     3534,  1298,  -666,  4250,  7707,  9103,  5760,  688,   9571,  15782, 14203,
193     14878, 17339, 14684, 8690,  5671,  875,   1429,  1531,  6173,  2984,  5558,
194     2996,  7928,  6733,  16117, 15262, 12757, 7980,  3923,  4795,  5973,  2051,
195     455,   -1922, 1816,  5906,  3321,  10908, 10910, 7377,  12204, 12809, 11195,
196     7451,  6666,  74,    -1645, -35,   -391,  3813,  7324,  892,   1656,  6095,
197     12193, 14648, 12156, 14663, 10251, 10325, 7821,  3925,  323,   697,   442,
198     1324,  4669,  7002,  5485,  5171,  5086,  10582, 11053, 9709,  11353, 8543,
199     5256,  2873,  235,   -628,  1496,  1878,  -867,  3420,  6865,  5937,  10182,
200     13277, 10069, 10789, 5998,  624,   -2082, 4417,  1258,  -1080, -819,  -1430,
201     1033,  5220,  6335,  8471,  8980,  11908, 14430, 12584, 8404,  1576,  -803,
202     985,   1481,  1367,  -193,  873,   3684,  2288,  6676,  9477,  11155, 9602,
203     9707,  10507, 4739,  3174,  -575,  -178,  3002,  1710,  423,   -477,  554,
204     3088,  2029,  5113,  5000,  3771,  6090,  5365,  1185,  2855,  399,   -312,
205     -1577, 176,   955,
206   };
207 
208   float expected_replicate[] = {
209     13768, 13528, 12999, 6906,  4618,  4043,  2611,  9955,  6685,  4776,  2753,
210     1036,  3063,  4544,  5183,  7349,  12451, 12501, 9131,  12753, 8908,  4058,
211     6299,  7542,  7115,  3307,  3360,  3543,  9754,  7808,  5991,  9019,  14320,
212     14919, 12492, 6871,  7373,  3336,  2085,  10604, 9377,  6882,  5009,  3103,
213     6220,  6278,  7588,  10196, 11045, 11563, 11842, 11911, 8279,  2030,  1858,
214     6368,  12123, 9909,  6347,  10345, 9365,  4038,  1673,  3051,  16492, 16649,
215     12276, 408,   -301,  4122,  -654,  7864,  14038, 15279, 15315, 9744,  8243,
216     5298,  746,   380,   9824,  9124,  10895, 6640,  4712,  2669,  6980,  2759,
217     5385,  12345, 11336, 13129, 8600,  2370,  3682,  5219,  12407, 13123, 6784,
218     2612,  -291,  3183,  9414,  12316, 14524, 12333, 13397, 7543,  3916,  4153,
219     4477,  4314,  7983,  8418,  9163,  9103,  5760,  688,   9571,  15782, 14203,
220     14878, 17718, 14570, 7940,  6642,  5094,  7133,  9964,  10219, 3224,  5558,
221     2996,  7928,  6733,  16117, 15262, 12757, 7958,  4401,  5187,  5476,  5529,
222     6055,  2206,  3909,  6015,  3321,  10908, 10910, 7377,  12204, 12809, 11195,
223     6967,  6840,  481,   -1600, 274,   1,     10373, 8514,  1123,  2117,  6758,
224     12736, 16223, 13585, 15988, 11771, 10600, 7918,  4156,  2840,  3111,  3287,
225     6359,  7652,  8813,  6530,  6967,  7789,  13671, 13990, 13247, 13241, 9836,
226     5251,  3024,  2313,  1834,  4187,  2637,  -1312, 2139,  7378,  7665,  11933,
227     15591, 15314, 15678, 9531,  2820,  -1516, 3400,  1314,  22,    363,   -2896,
228     -898,  5906,  7308,  10650, 12975, 16978, 20370, 18817, 12381, 4118,  -861,
229     -137,  236,   1802,  1632,  -350,  2334,  3400,  8680,  14064, 18216, 18675,
230     21765, 22871, 11491, 4937,  -1555, -11,   1669,  2392,  3265,  -5254, -217,
231     5001,  8063,  13444, 18884, 19706, 22794, 21064, 9545,  6689,  -7,    289,
232     -2021, 504,   2347,
233   };
234 
235   float expected_valid[] = {
236     2612,  -291,  3183,  9414,  12316, 14524, 12333, 9103,  5760,  688,
237     9571,  15782, 14203, 14878, 5558,  2996,  7928,  6733,  16117, 15262,
238     12757, 3321,  10908, 10910, 7377,  12204, 12809, 11195,
239   };
240 
241   CNN_CONFIG cnn_config = { 3,
242                             0,
243                             0,
244                             0,
245                             0,
246                             {
247                                 {
248                                     1,
249                                     filter_width,
250                                     filter_height,
251                                     3,
252                                     1,
253                                     1,
254                                     0,
255                                     nullptr,
256                                     nullptr,
257                                     PADDING_SAME_ZERO,
258                                     NONE,
259                                     0,
260                                     0,
261                                     BRANCH_NO_COPY,
262                                     BRANCH_NOC,
263                                     {},
264                                     {},
265                                     -1,
266                                 },
267                                 {
268                                     3,
269                                     filter_width,
270                                     filter_height,
271                                     3,
272                                     1,
273                                     1,
274                                     0,
275                                     nullptr,
276                                     nullptr,
277                                     PADDING_SAME_ZERO,
278                                     NONE,
279                                     0,
280                                     0,
281                                     BRANCH_NO_COPY,
282                                     BRANCH_NOC,
283                                     {},
284                                     {},
285                                     -1,
286                                 },
287                                 {
288                                     3,
289                                     filter_width,
290                                     filter_height,
291                                     1,
292                                     1,
293                                     1,
294                                     0,
295                                     nullptr,
296                                     nullptr,
297                                     PADDING_SAME_ZERO,
298                                     NONE,
299                                     0,
300                                     0,
301                                     BRANCH_NO_COPY,
302                                     BRANCH_NOC,
303                                     {},
304                                     {},
305                                     0,
306                                 },
307                             } };
308 
309   // Weights and biases need to be specified separately because
310   // of the offset.
311   AssignLayerWeightsBiases(&cnn_config, weights, bias);
312 
313   CNN_THREAD_DATA thread_data = { 1, NULL };
314 
315   RunCNNTest(image_width, image_height, input, expected_same, &cnn_config,
316              image_width, &thread_data, MSE_INT_TOL);
317 
318   for (int i = 0; i < cnn_config.num_layers; ++i) {
319     cnn_config.layer_config[i].pad = PADDING_SAME_REPLICATE;
320   }
321 
322   RunCNNTest(image_width, image_height, input, expected_replicate, &cnn_config,
323              image_width, &thread_data, MSE_INT_TOL);
324 
325   for (int i = 0; i < cnn_config.num_layers; ++i) {
326     cnn_config.layer_config[i].pad = PADDING_VALID;
327   }
328 
329   RunCNNTest(image_width, image_height, input, expected_valid, &cnn_config,
330              image_width, &thread_data, MSE_INT_TOL);
331 }
332 
TEST_F(CNNTest,TestRELUSingleLayer)333 TEST_F(CNNTest, TestRELUSingleLayer) {
334   int image_width = 8;
335   int image_height = 8;
336   int filter_height = 5;
337   int filter_width = 4;
338   float input[] = {
339     0, -2, -3, 1,  -1, 2,  -2, 1,  -3, -1, 0,  1,  -2, -3, -2, -2,
340     1, -3, 2,  -3, -1, -1, 2,  0,  -2, -3, 0,  -2, -3, 1,  -1, -1,
341     2, -2, 0,  -2, -3, -3, 1,  1,  -1, 1,  0,  1,  -3, 0,  2,  2,
342     0, -3, 1,  -3, 2,  -2, 1,  -1, -1, -2, -3, -2, -1, -3, -2, -1,
343   };
344   float expected_same[] = {
345     9,  0,  1,  1,  0,  3,  0,  19, 0,  12, 10, 0,  0,  0,  5, 0,
346     0,  18, 21, 7,  19, 4,  3,  0,  0,  9,  16, 0,  11, 16, 0, 11,
347     12, 2,  0,  11, 0,  16, 6,  0,  8,  22, 13, 10, 12, 0,  0, 0,
348     0,  1,  2,  12, 29, 6,  10, 0,  13, 0,  0,  5,  8,  10, 0, 0,
349   };
350   float expected_replicate[] = {
351     18, 17, 12, 2,  0,  0,  5,  11, 0,  17, 22, 6,  0,  0,  17, 0,
352     0,  18, 21, 7,  19, 4,  3,  5,  3,  9,  16, 0,  11, 16, 0,  3,
353     3,  2,  0,  11, 0,  16, 6,  0,  17, 22, 13, 10, 12, 0,  0,  0,
354     0,  4,  1,  10, 30, 7,  10, 0,  23, 8,  0,  13, 15, 19, 8,  10,
355   };
356   float expected_valid[] = {
357     18, 21, 7, 19, 4, 9, 16, 0, 11, 16, 2, 0, 11, 0, 16, 22, 13, 10, 12, 0,
358   };
359   float weights[] = {
360     -2, -3, 1, 2, 2, -2, -3, 0, -3, 2, 2, -3, -3, -2, 0, 1, 2, 0, -1, -1,
361   };
362   float bias[] = { -3 };
363 
364   CNN_CONFIG cnn_config = { 1,
365                             0,
366                             0,
367                             0,
368                             0,
369                             { {
370                                 1,
371                                 filter_width,
372                                 filter_height,
373                                 1,
374                                 1,
375                                 1,
376                                 0,
377                                 weights,
378                                 bias,
379                                 PADDING_SAME_ZERO,
380                                 RELU,
381                                 0,
382                                 0,
383                                 BRANCH_NO_COPY,
384                                 BRANCH_NOC,
385                                 {},
386                                 {},
387                                 0,
388                             } } };
389 
390   CNN_THREAD_DATA thread_data = { 1, NULL };
391 
392   RunCNNTest(image_width, image_height, input, expected_same, &cnn_config,
393              image_width, &thread_data, MSE_INT_TOL);
394 
395   cnn_config.layer_config[0].pad = PADDING_SAME_REPLICATE;
396 
397   RunCNNTest(image_width, image_height, input, expected_replicate, &cnn_config,
398              image_width, &thread_data, MSE_INT_TOL);
399 
400   cnn_config.layer_config[0].pad = PADDING_VALID;
401 
402   RunCNNTest(image_width, image_height, input, expected_valid, &cnn_config,
403              image_width, &thread_data, MSE_INT_TOL);
404 }
405 
TEST_F(CNNTest,TestVaryingStridesVaryingDimImages)406 TEST_F(CNNTest, TestVaryingStridesVaryingDimImages) {
407   float weights[] = {
408     1,  -5, -3, -4, -1, 1,  2,  -3, 2,  2,  -1, 1,  -5, 1,  1,
409     -3, -5, 3,  1,  4,  -2, -5, -2, -3, -5, 0,  -1, -5, 2,  -2,
410     -2, 1,  -2, -4, 1,  3,  -2, 2,  0,  -3, 2,  -3, -2, -3,
411   };
412   float bias[] = { 2 };
413 
414   CNN_CONFIG cnn_config = { 1,
415                             0,
416                             0,
417                             0,
418                             0,
419                             {
420                                 {
421                                     1,
422                                     4,
423                                     11,
424                                     1,
425                                     7,
426                                     6,
427                                     0,
428                                     weights,
429                                     bias,
430                                     PADDING_SAME_ZERO,
431                                     NONE,
432                                     0,
433                                     0,
434                                     BRANCH_NO_COPY,
435                                     BRANCH_NOC,
436                                     {},
437                                     {},
438                                     0,
439                                 },
440                             } };
441 
442   int image_height = 24;
443   int image_width = 17;
444   float input[] = {
445     -1, -3, 4,  4,  -5, 4,  3,  -5, -1, -3, 4,  -4, 2,  -3, 3,  -5, 2,  -1, -5,
446     1,  -1, 3,  1,  -3, -3, 4,  0,  2,  -3, -5, -5, -4, 0,  -5, -2, -3, -1, -2,
447     2,  -5, 4,  4,  0,  -4, -3, 1,  -3, -5, -4, -4, 1,  -2, -3, 3,  -3, -3, -1,
448     -5, -5, -2, 3,  1,  -1, -5, -5, 1,  -4, -2, -1, -2, -4, -4, 2,  -2, 2,  1,
449     -2, -4, -1, 1,  -2, -5, 3,  -2, -1, -1, -5, -3, 1,  -2, -2, -3, -1, -2, -4,
450     -2, 1,  -4, -1, 4,  3,  -4, 0,  4,  2,  2,  4,  -3, -5, 2,  2,  1,  -1, -4,
451     -2, 1,  3,  2,  0,  4,  -1, -3, 2,  1,  -4, 2,  2,  -4, -2, 0,  -2, -1, 4,
452     4,  2,  3,  -4, 2,  -4, -5, 4,  -1, -3, -1, 0,  -4, 1,  3,  -1, -3, -5, 3,
453     -2, -4, 1,  2,  -2, -3, -3, -5, 1,  -3, -1, 0,  -1, 3,  -4, -1, -5, -5, 1,
454     0,  0,  -2, -2, 2,  -2, 0,  0,  2,  0,  -3, 0,  -1, -4, -4, -1, 3,  -4, -4,
455     -1, 0,  -5, -3, -2, 4,  -3, -4, -4, 0,  -5, 1,  -2, -3, -3, -4, 4,  3,  4,
456     3,  3,  -1, 3,  1,  -3, -2, 3,  3,  0,  2,  -4, -3, 2,  2,  0,  -2, 4,  -2,
457     2,  -2, -1, -4, -2, 2,  -4, 3,  -1, 4,  1,  1,  4,  -1, -4, -4, 1,  1,  -2,
458     4,  -1, 3,  2,  -3, 4,  3,  1,  4,  0,  -4, 2,  0,  2,  4,  -2, -2, 4,  2,
459     -1, -2, 1,  -3, 2,  3,  -5, -3, 4,  4,  2,  -5, -4, -5, -2, -4, 2,  0,  2,
460     -5, 4,  -4, -2, -5, 2,  1,  0,  4,  1,  -2, -3, -4, -3, -4, 3,  3,  2,  0,
461     -3, 1,  -5, 4,  0,  4,  -1, 3,  -5, -5, -2, -1, -1, 4,  3,  3,  4,  3,  -4,
462     4,  -3, -3, -1, -4, -1, -4, -1, -2, 4,  -2, -4, 4,  4,  -3, -4, -1, 1,  2,
463     -1, -2, -2, 3,  2,  2,  -3, 0,  -1, 0,  3,  2,  -5, 0,  -4, 0,  0,  2,  -4,
464     -1, -1, 0,  -2, 0,  1,  0,  0,  4,  -5, -1, -5, 2,  -1, 0,  2,  -1, 1,  3,
465     -3, -5, -2, -3, 4,  -2, -2, -1, -3, -4, -1, -2, -4, 1,  4,  -3, -2, -1, 3,
466     -3, -2, 3,  2,  1,  -4, -3, -5, 1,
467   };
468   float expected_1[] = {
469     41, -26, 5, 76, 13, 83, -21, 53, -54, -14, 21, 121,
470   };
471 
472   CNN_THREAD_DATA thread_data = { 1, NULL };
473 
474   RunCNNTest(image_width, image_height, input, expected_1, &cnn_config,
475              image_width, &thread_data, MSE_INT_TOL);
476 
477   cnn_config.layer_config[0].skip_width = 6;
478   cnn_config.layer_config[0].skip_height = 7;
479 
480   float expected_2[] = {
481     21, -50, 41, 20, 72, 127, -21, 103, 62, -37, 83, -3,
482   };
483   RunCNNTest(image_width, image_height, input, expected_2, &cnn_config,
484              image_width, &thread_data, MSE_INT_TOL);
485 
486   cnn_config.layer_config[0].skip_width = 3;
487   cnn_config.layer_config[0].skip_height = 10;
488 
489   float expected_3[] = {
490     -26, -21, -35, 69, 49,  4,  -51, -43, -56,
491     -41, 15,  -44, 40, -62, 63, 38,  27,  47,
492   };
493   RunCNNTest(image_width, image_height, input, expected_3, &cnn_config,
494              image_width, &thread_data, MSE_INT_TOL);
495 
496   cnn_config.layer_config[0].skip_width = 10;
497   cnn_config.layer_config[0].skip_height = 3;
498 
499   float expected_4[] = {
500     21, 49, 28, 87, 50, 40, 102, 81, 58, 85, 51, 66, 36, 19, -37, -45,
501   };
502 
503   RunCNNTest(image_width, image_height, input, expected_4, &cnn_config,
504              image_width, &thread_data, MSE_INT_TOL);
505 }
506 
// Single-layer 3x3 convolution on an 8x8 image with zero padding, a skip of 3
// in both directions and the layer's maxpool flag set (maxpool = 1).
TEST_F(CNNTest, TestMaxPool) {
  const int image_width = 8;
  const int image_height = 8;
  const int stride = 3;
  float input[] = {
    1,  -4, -4, 8, 0, 7, -5, -2, 8, 2, 2, 8,  5,  -1, -1, 9,
    -3, 0,  -2, 0, 6, 3, -4, 8,  7, 8, 7, -1, 4,  -1, 0,  2,
    -5, -2, 8,  5, 5, 4, 2,  7,  4, 6, 2, 8,  8,  -4, -3, -4,
    -3, -1, 2,  3, 3, 6, -5, 8,  9, 5, 0, -2, -1, 6,  5,  7,
  };

  float expected[] = {
    49, 58, 70, 68, 68, 70, 48, 57, 88,
  };

  float weights[] = {
    3, 1, 3, 4, -1, 5, -2, 1, -4,
  };

  float bias[] = {
    -3,
  };

  CNN_CONFIG cnn_config = { 1,
                            0,
                            0,
                            0,
                            0,
                            { {
                                1,
                                3,
                                3,
                                1,
                                stride,
                                stride,
                                1,
                                weights,
                                bias,
                                PADDING_SAME_ZERO,
                                NONE,
                                0,
                                0,
                                BRANCH_NO_COPY,
                                BRANCH_NOC,
                                {},
                                {},
                                0,
                            } } };

  CNN_THREAD_DATA thread_data = { 1, nullptr };

  RunCNNTest(image_width, image_height, input, expected, &cnn_config,
             image_width, &thread_data, MSE_INT_TOL);
}
561 
TEST_F(CNNTest,TestDeconvolveNonActivationSingleLayerSingleKernel)562 TEST_F(CNNTest, TestDeconvolveNonActivationSingleLayerSingleKernel) {
563   int image_width = 4;
564   int image_height = 7;
565   float input[] = {
566     9,  6,   181, 9,  218, 30, 80,  108, 68,  216, 70, 128, 179, 228,
567     33, 212, 34,  14, 48,  27, 230, 23,  202, 113, 80, 56,  122, 112,
568   };
569 
570   float expected_1_same[] = {
571     15,   -30,  36,   -525,  377, -193, 558, 531,  6,   -24,  -15,  124,
572     166,  -561, -356, -754,  -3,  -3,   -3,  -3,   -3,  -3,   -3,   -3,
573     433,  -311, 711,  381,   247, -317, 453, 129,  215, -627, -409, -885,
574     17,   -255, -55,  -647,  -3,  -3,   -3,  -3,   -3,  -3,   -3,   -3,
575     133,  -719, 633,  -225,  785, 191,  463, 79,   65,  9,    77,   -853,
576     -365, -949, -15,  -667,  -3,  -3,   -3,  -3,   -3,  -3,   -3,   -3,
577     355,  -866, 990,  207,   747, 12,   520, -116, 176, -312, -133, -1370,
578     -426, -802, 143,  -771,  -3,  -3,   -3,  -3,   -3,  -3,   -3,   -3,
579     65,   -79,  127,  -59,   135, -90,  195, 114,  31,  -91,  -57,  -133,
580     17,   -176, -72,  -276,  -3,  -3,   -3,  -3,   -3,  -3,   -3,   -3,
581     457,  -302, 733,  58,    470, -475, 829, 490,  227, -670, -440, -790,
582     153,  -588, -294, -1150, -3,  -3,   -3,  -3,   -3,  -3,   -3,   -3,
583     157,  -251, 349,  -185,  409, -293, 587, 251,  77,  -187, -107, -369,
584     7,    -481, -135, -827,  -3,  -3,   -3,  -3,   -3,  -3,   -3,   -3,
585   };
586   float expected_1_valid[] = {
587     -30,  15,   -30,  36,   -525,  377,  -193,  558,  531,  24,   24,   6,
588     6,    -24,  -15,  124,  166,   -561, -356,  -754, -21,  -39,  -3,   -3,
589     -3,   -3,   -3,   -3,   -3,    -3,   -3,    -3,   -3,   -657, 433,  -311,
590     711,  381,  247,  -317, 453,   129,  321,   321,  215,  215,  -627, -409,
591     -885, 17,   -255, -55,  -647,  -219, -435,  -3,   -3,   -3,   -3,   -3,
592     -3,   -3,   -3,   -3,   -3,    -3,   -207,  133,  -719, 633,  -225, 785,
593     191,  463,  79,   381,  381,   65,   65,    9,    77,   -853, -365, -949,
594     -15,  -667, -259, -515, -3,    -3,   -3,    -3,   -3,   -3,   -3,   -3,
595     -3,   -3,   -3,   -540, 355,   -866, 990,   207,  747,  12,   520,  -116,
596     633,  633,  176,  176,  -312,  -133, -1370, -426, -802, 143,  -771, -427,
597     -851, -3,   -3,   -3,   -3,    -3,   -3,    -3,   -3,   -3,   -3,   -3,
598     -105, 65,   -79,  127,  -59,   135,  -90,   195,  114,  78,   78,   31,
599     31,   -91,  -57,  -133, 17,    -176, -72,   -276, -57,  -111, -3,   -3,
600     -3,   -3,   -3,   -3,   -3,    -3,   -3,    -3,   -3,   -693, 457,  -302,
601     733,  58,   470,  -475, 829,   490,  336,   336,  227,  227,  -670, -440,
602     -790, 153,  -588, -294, -1150, -229, -455,  -3,   -3,   -3,   -3,   -3,
603     -3,   -3,   -3,   -3,   -3,    -3,   -243,  157,  -251, 349,  -185, 409,
604     -293, 587,  251,  333,  333,   77,   77,    -187, -107, -369, 7,    -481,
605     -135, -827, -227, -451,
606   };
607   float weights_1[] = { -3, 2, -1, 3, 3, 1, 1, -3, -2, -4 };
608   float bias_1[] = { -3 };
609 
610   CNN_CONFIG cnn_config = { 1,
611                             0,
612                             0,
613                             0,
614                             0,
615                             { {
616                                 1,
617                                 5,
618                                 2,
619                                 1,
620                                 2,
621                                 3,
622                                 0,
623                                 weights_1,
624                                 bias_1,
625                                 PADDING_SAME_ZERO,
626                                 NONE,
627                                 1,
628                                 0,
629                                 BRANCH_NO_COPY,
630                                 BRANCH_NOC,
631                                 {},
632                                 {},
633                                 0,
634                             } } };
635 
636   CNN_THREAD_DATA thread_data = { 1, NULL };
637 
638   RunCNNTest(image_width, image_height, input, expected_1_same, &cnn_config,
639              image_width, &thread_data, MSE_INT_TOL);
640 
641   // Change padding to valid
642   cnn_config.layer_config[0].pad = PADDING_VALID;
643 
644   RunCNNTest(image_width, image_height, input, expected_1_valid, &cnn_config,
645              image_width, &thread_data, MSE_INT_TOL);
646 
647   float expected_12_same[] = {
648     15,  -12,  6,    36,   -9,   -528, 377,  -184, 513,  558,  -12,  24,
649     6,   -30,  -15,  -33,  -21,  166,  154,  -546, -356, -718, -30,  -21,
650     433, -221, 561,  711,  -33,  -153, 247,  -83,  -87,  453,  -111, 321,
651     215, -657, -409, -845, -93,  17,   -43,  -243, -55,  -215, -327, -219,
652     133, -71,  -447, 633,  -219, 435,  785,  -73,  -177, 463,  -131, 381,
653     65,  -207, 77,   -59,  -651, -365, -797, -213, -15,  -155, -387, -259,
654     355, -182, -150, 990,  -231, 582,  747,  -36,  -540, 520,  -215, 633,
655     176, -540, -133, -491, -687, -426, -882, -102, 143,  77,   -639, -427,
656     65,  -37,  57,   127,  -17,  -105, 135,  -51,  60,   195,  -30,  78,
657     31,  -105, -57,  -125, -45,  17,   -11,  -147, -72,  -168, -84,  -57,
658     457, -233, 618,  733,  -26,  -540, 470,  -205, 264,  829,  -116, 336,
659     227, -693, -440, -900, -72,  153,  107,  -609, -294, -698, -342, -229,
660     157, -83,  69,   349,  -59,  -201, 409,  -125, 27,   587,  -115, 333,
661     77,  -243, -107, -267, -171, 7,    -105, -369, -135, -379, -339, -227,
662   };
663   float expected_12_valid[] = {
664     -30,  15,   -12,  6,    36,   -9,   -528, 377,  -184, 513,  558,  -12,
665     24,   24,   6,    6,    -30,  -15,  -33,  -21,  166,  154,  -546, -356,
666     -718, -30,  -21,  -39,  -657, 433,  -221, 561,  711,  -33,  -153, 247,
667     -83,  -87,  453,  -111, 321,  321,  215,  215,  -657, -409, -845, -93,
668     17,   -43,  -243, -55,  -215, -327, -219, -435, -207, 133,  -71,  -447,
669     633,  -219, 435,  785,  -73,  -177, 463,  -131, 381,  381,  65,   65,
670     -207, 77,   -59,  -651, -365, -797, -213, -15,  -155, -387, -259, -515,
671     -540, 355,  -182, -150, 990,  -231, 582,  747,  -36,  -540, 520,  -215,
672     633,  633,  176,  176,  -540, -133, -491, -687, -426, -882, -102, 143,
673     77,   -639, -427, -851, -105, 65,   -37,  57,   127,  -17,  -105, 135,
674     -51,  60,   195,  -30,  78,   78,   31,   31,   -105, -57,  -125, -45,
675     17,   -11,  -147, -72,  -168, -84,  -57,  -111, -693, 457,  -233, 618,
676     733,  -26,  -540, 470,  -205, 264,  829,  -116, 336,  336,  227,  227,
677     -693, -440, -900, -72,  153,  107,  -609, -294, -698, -342, -229, -455,
678     -243, 157,  -83,  69,   349,  -59,  -201, 409,  -125, 27,   587,  -115,
679     333,  333,  77,   77,   -243, -107, -267, -171, 7,    -105, -369, -135,
680     -379, -339, -227, -451,
681   };
682 
683   // Change skip_width, skip_height to {2, 3}
684   cnn_config.layer_config[0].skip_width = 3;
685   cnn_config.layer_config[0].skip_height = 2;
686   // Set padding to same
687   cnn_config.layer_config[0].pad = PADDING_SAME_ZERO;
688 
689   RunCNNTest(image_width, image_height, input, expected_12_same, &cnn_config,
690              image_width, &thread_data, MSE_INT_TOL);
691 
692   // Change padding to valid
693   cnn_config.layer_config[0].pad = PADDING_VALID;
694   RunCNNTest(image_width, image_height, input, expected_12_valid, &cnn_config,
695              image_width, &thread_data, MSE_INT_TOL);
696 
697   cnn_config.layer_config[0].filter_width = 4;
698   cnn_config.layer_config[0].filter_height = 3;
699   float weights_2[] = { -1, -3, -1, -3, 0, 2, -2, 4, 3, 0, 1, 4 };
700   float bias_2[] = { -4 };
701   cnn_config.layer_config[0].weights = weights_2;
702   cnn_config.layer_config[0].bias = bias_2;
703 
704   cnn_config.layer_config[0].skip_width = 5;
705   cnn_config.layer_config[0].skip_height = 2;
706   float expected_2_same[] = {
707     -13,  -31,  -13,  -31,  -4,   -10,  -22,  -10,  -22,  -4,   -185, -547,
708     -185, -547, -4,   -13,  -31,  -13,  -31,  -4,   -4,   14,   -22,  32,
709     -4,   -4,   8,    -16,  20,   -4,   -4,   358,  -366, 720,  -4,   -4,
710     14,   -22,  32,   -4,   -195, -658, -213, -622, -4,   -16,  -94,  -28,
711     -70,  -4,   459,  -244, 97,   480,  -4,   -85,  -328, -103, -292, -4,
712     -4,   432,  -440, 868,  -4,   -4,   56,   -64,  116,  -4,   -4,   156,
713     -164, 316,  -4,   -4,   212,  -220, 428,  -4,   582,  -208, 146,  664,
714     -4,   -130, -652, -190, -532, -4,   166,  -214, 6,    106,  -4,   192,
715     -388, -24,  44,   -4,   -4,   132,  -140, 268,  -4,   -4,   428,  -436,
716     860,  -4,   -4,   136,  -144, 276,  -4,   -4,   252,  -260, 508,  -4,
717     21,   -541, -115, -269, -4,   416,  -688, -16,  176,  -4,   173,  -103,
718     33,   177,  -4,   168,  -640, -88,  -128, -4,   -4,   354,  -362, 712,
719     -4,   -4,   452,  -460, 908,  -4,   -4,   62,   -70,  128,  -4,   -4,
720     420,  -428, 844,  -4,   499,  -106, 141,  610,  -4,   666,  -46,  210,
721     866,  -4,   47,   -148, -19,  -16,  -4,   605,  -85,  181,  763,  -4,
722     -4,   64,   -72,  132,  -4,   -4,   24,   -32,  52,   -4,   -4,   92,
723     -100, 188,  -4,   -4,   50,   -58,  104,  -4,   -132, -694, -200, -558,
724     -4,   15,   -73,  -13,  -17,  -4,   -62,  -610, -158, -418, -4,   -36,
725     -343, -90,  -235, -4,   -4,   456,  -464, 916,  -4,   -4,   42,   -50,
726     88,   -4,   -4,   400,  -408, 804,  -4,   -4,   222,  -230, 448,  -4,
727     606,  -244, 146,  676,  -4,   9,    -172, -37,  -80,  -4,   480,  -370,
728     76,   438,  -4,   223,  -340, -3,   112,  -4,   -4,   156,  -164, 316,
729     -4,   -4,   108,  -116, 220,  -4,   -4,   240,  -248, 484,  -4,   -4,
730     220,  -228, 444,  -4,
731   };
732   float expected_2_valid[] = {
733     -13,  -31,  -13,  -31,  -4,   -10,  -22,  -10,  -22,  -4,   -185, -547,
734     -185, -547, -4,   -13,  -31,  -13,  -31,  -4,   14,   -22,  32,   -4,
735     -4,   8,    -16,  20,   -4,   -4,   358,  -366, 720,  -4,   -4,   14,
736     -22,  32,   -195, -658, -213, -622, -4,   -16,  -94,  -28,  -70,  -4,
737     459,  -244, 97,   480,  -4,   -85,  -328, -103, -292, -4,   432,  -440,
738     868,  -4,   -4,   56,   -64,  116,  -4,   -4,   156,  -164, 316,  -4,
739     -4,   212,  -220, 428,  582,  -208, 146,  664,  -4,   -130, -652, -190,
740     -532, -4,   166,  -214, 6,    106,  -4,   192,  -388, -24,  44,   -4,
741     132,  -140, 268,  -4,   -4,   428,  -436, 860,  -4,   -4,   136,  -144,
742     276,  -4,   -4,   252,  -260, 508,  21,   -541, -115, -269, -4,   416,
743     -688, -16,  176,  -4,   173,  -103, 33,   177,  -4,   168,  -640, -88,
744     -128, -4,   354,  -362, 712,  -4,   -4,   452,  -460, 908,  -4,   -4,
745     62,   -70,  128,  -4,   -4,   420,  -428, 844,  499,  -106, 141,  610,
746     -4,   666,  -46,  210,  866,  -4,   47,   -148, -19,  -16,  -4,   605,
747     -85,  181,  763,  -4,   64,   -72,  132,  -4,   -4,   24,   -32,  52,
748     -4,   -4,   92,   -100, 188,  -4,   -4,   50,   -58,  104,  -132, -694,
749     -200, -558, -4,   15,   -73,  -13,  -17,  -4,   -62,  -610, -158, -418,
750     -4,   -36,  -343, -90,  -235, -4,   456,  -464, 916,  -4,   -4,   42,
751     -50,  88,   -4,   -4,   400,  -408, 804,  -4,   -4,   222,  -230, 448,
752     606,  -244, 146,  676,  -4,   9,    -172, -37,  -80,  -4,   480,  -370,
753     76,   438,  -4,   223,  -340, -3,   112,  -4,   156,  -164, 316,  -4,
754     -4,   108,  -116, 220,  -4,   -4,   240,  -248, 484,  -4,   -4,   220,
755     -228, 444,  236,  -4,   76,   316,  -4,   164,  -4,   52,   220,  -4,
756     362,  -4,   118,  484,  -4,   332,  -4,   108,  444,
757   };
758   // Set padding to same
759   cnn_config.layer_config[0].pad = PADDING_SAME_ZERO;
760 
761   RunCNNTest(image_width, image_height, input, expected_2_same, &cnn_config,
762              image_width, &thread_data, MSE_INT_TOL);
763 
764   cnn_config.layer_config[0].pad = PADDING_VALID;
765 
766   RunCNNTest(image_width, image_height, input, expected_2_valid, &cnn_config,
767              image_width, &thread_data, MSE_INT_TOL);
768 
769   cnn_config.layer_config[0].skip_width = 2;
770   cnn_config.layer_config[0].skip_height = 5;
771   float expected_21_same[] = {
772     -31,  -19,  -49,   -191, -565, -194, -574, -13,  14,   -22,  44,   -16,
773     382,  -366, 738,   -22,  -4,   23,   32,   545,  20,   204,  720,  5,
774     -4,   -4,   -4,    -4,   -4,   -4,   -4,   -4,   -4,   -4,   -4,   -4,
775     -4,   -4,   -4,    -4,   -658, -252, -748, -114, -334, -192, -568, -112,
776     432,  -440, 928,   -64,  276,  -164, 532,  -220, -4,   304,  868,  266,
777     116,  400,  316,   104,  -4,   -4,   -4,   -4,   -4,   -4,   -4,   -4,
778     -4,   -4,   -4,    -4,   -4,   -4,   -4,   -4,   -208, -288, -856, -290,
779     -862, -202, -598,  -132, 132,  -140, 700,  -436, 1000, -144, 532,  -260,
780     -4,   712,  268,   422,  860,  450,  276,  124,  -4,   -4,   -4,   -4,
781     -4,   -4,   -4,    -4,   -4,   -4,   -4,   -4,   -4,   -4,   -4,   -4,
782     -541, -411, -1225, -265, -787, -249, -739, -216, 354,  -362, 1168, -460,
783     974,  -70,  552,   -428, -4,   859,  712,  323,  908,  665,  128,  208,
784     -4,   -4,   -4,    -4,   -4,   -4,   -4,   -4,   -4,   -4,   -4,   -4,
785     -4,   -4,   -4,    -4,   -106, -52,  -148, -66,  -190, -79,  -229, -31,
786     64,   -72,  160,   -32,  148,  -100, 242,  -58,  -4,   72,   132,  154,
787     52,   125,  188,   23,   -4,   -4,   -4,   -4,   -4,   -4,   -4,   -4,
788     -4,   -4,   -4,    -4,   -4,   -4,   -4,   -4,   -694, -257, -763, -229,
789     -679, -319, -949,  -117, 456,  -464, 962,  -50,  492,  -408, 1030, -230,
790     -4,   295,  916,   625,  88,   537,  804,  109,  -4,   -4,   -4,   -4,
791     -4,   -4,   -4,    -4,   -4,   -4,   -4,   -4,   -4,   -4,   -4,   -4,
792     -244, -140, -412,  -182, -538, -238, -706, -116, 156,  -164, 428,  -116,
793     464,  -248, 708,   -228, -4,   244,  316,  418,  220,  454,  484,  108,
794     -4,   -4,   -4,    -4,   -4,   -4,   -4,   -4,   -4,   -4,   -4,   -4,
795     -4,   -4,   -4,    -4,
796   };
797   float expected_21_valid[] = {
798     -13,  -31,  -19,  -49,  -191, -565, -194, -574, -13,  -31,   -4,   14,
799     -22,  44,   -16,  382,  -366, 738,  -22,  32,   23,   -4,    23,   32,
800     545,  20,   204,  720,  5,    32,   -4,   -4,   -4,   -4,    -4,   -4,
801     -4,   -4,   -4,   -4,   -4,   -4,   -4,   -4,   -4,   -4,    -4,   -4,
802     -4,   -4,   -222, -658, -252, -748, -114, -334, -192, -568,  -112, -328,
803     -4,   432,  -440, 928,  -64,  276,  -164, 532,  -220, 428,   650,  -4,
804     304,  868,  266,  116,  400,  316,  104,  428,  -4,   -4,    -4,   -4,
805     -4,   -4,   -4,   -4,   -4,   -4,   -4,   -4,   -4,   -4,    -4,   -4,
806     -4,   -4,   -4,   -4,   -72,  -208, -288, -856, -290, -862,  -202, -598,
807     -132, -388, -4,   132,  -140, 700,  -436, 1000, -144, 532,   -260, 508,
808     200,  -4,   712,  268,  422,  860,  450,  276,  124,  508,   -4,   -4,
809     -4,   -4,   -4,   -4,   -4,   -4,   -4,   -4,   -4,   -4,    -4,   -4,
810     -4,   -4,   -4,   -4,   -4,   -4,   -183, -541, -411, -1225, -265, -787,
811     -249, -739, -216, -640, -4,   354,  -362, 1168, -460, 974,   -70,  552,
812     -428, 844,  533,  -4,   859,  712,  323,  908,  665,  128,   208,  844,
813     -4,   -4,   -4,   -4,   -4,   -4,   -4,   -4,   -4,   -4,    -4,   -4,
814     -4,   -4,   -4,   -4,   -4,   -4,   -4,   -4,   -38,  -106,  -52,  -148,
815     -66,  -190, -79,  -229, -31,  -85,  -4,   64,   -72,  160,   -32,  148,
816     -100, 242,  -58,  104,  98,   -4,   72,   132,  154,  52,    125,  188,
817     23,   104,  -4,   -4,   -4,   -4,   -4,   -4,   -4,   -4,    -4,   -4,
818     -4,   -4,   -4,   -4,   -4,   -4,   -4,   -4,   -4,   -4,    -234, -694,
819     -257, -763, -229, -679, -319, -949, -117, -343, -4,   456,   -464, 962,
820     -50,  492,  -408, 1030, -230, 448,  686,  -4,   295,  916,   625,  88,
821     537,  804,  109,  448,  -4,   -4,   -4,   -4,   -4,   -4,    -4,   -4,
822     -4,   -4,   -4,   -4,   -4,   -4,   -4,   -4,   -4,   -4,    -4,   -4,
823     -84,  -244, -140, -412, -182, -538, -238, -706, -116, -340,  -4,   156,
824     -164, 428,  -116, 464,  -248, 708,  -228, 444,  236,  -4,    244,  316,
825     418,  220,  454,  484,  108,  444,
826   };
827 
828   cnn_config.layer_config[0].pad = PADDING_SAME_ZERO;
829 
830   RunCNNTest(image_width, image_height, input, expected_21_same, &cnn_config,
831              image_width, &thread_data, MSE_INT_TOL);
832 
833   cnn_config.layer_config[0].pad = PADDING_VALID;
834 
835   RunCNNTest(image_width, image_height, input, expected_21_valid, &cnn_config,
836              image_width, &thread_data, MSE_INT_TOL);
837 }
838 
TEST_F(CNNTest,TestLargeKernelsAndStrides)839 TEST_F(CNNTest, TestLargeKernelsAndStrides) {
840   float input_10x11[] = {
841     4,  4,  2,  4,  2,  -5, -2, 3, -1, 0,  0,  1,  2,  0,  -5, -2, -5, 1,  -3,
842     -1, 4,  -3, 2,  -2, 1,  0,  1, -3, -3, -4, -2, -2, 1,  -4, -1, 4,  1,  -4,
843     -4, -4, 3,  2,  -5, 3,  -5, 1, 2,  -4, 1,  -1, 3,  4,  -2, 3,  -3, 3,  0,
844     2,  -4, -5, -5, -2, -1, -2, 1, 1,  1,  -2, 4,  -5, 4,  -1, -1, 2,  3,  -4,
845     2,  2,  3,  0,  0,  1,  0,  3, 2,  3,  1,  -2, 3,  -4, 3,  2,  4,  -2, 0,
846     4,  -4, 1,  -3, -3, -3, -5, 1, -3, -5, 0,  4,  -1, -3, 2,
847   };
848 
849   float weights_10x11[] = {
850     -3, 4,  -4, -3, -5, 1,  -2, 3,  1,  -4, -4, 0,  -1, 0,  3,  1,  -3, -2, 0,
851     -1, 1,  3,  -4, -4, -3, -3, -2, 4,  3,  -5, 4,  2,  -3, 4,  -2, -1, 2,  -1,
852     -5, 0,  -3, 0,  3,  -5, -5, 3,  -4, -1, -5, 3,  4,  0,  4,  -5, 2,  -1, 2,
853     -1, -1, -1, -5, 0,  -4, 3,  -1, 1,  1,  -1, 3,  2,  -5, -4, 0,  -4, 4,  -5,
854     -3, 4,  -5, 2,  -5, -4, -4, -1, 3,  3,  0,  2,  -4, 1,  -2, 1,  1,  0,  3,
855     -2, 0,  1,  2,  4,  -3, -1, -5, -5, 2,  -4, 1,  1,  2,  -4, -2, -2, 2,  1,
856     3,  4,  -5, 1,  -1, -3, -3, -1, -2, -5, 1,  -1, 0,  1,  4,  4,  0,  0,  4,
857     -3, -1, -5, -3, 0,  1,  1,  1,  -5, 3,  4,  3,  -5, 3,  -2, -2, 0,  -4, 0,
858     0,  -2, 1,  -4, -1, 0,  -5, -2, -2, -5, -3, -3, 1,  1,  -3, 2,  4,  2,  4,
859     -4, -3, 3,  1,  1,  3,  -4, 4,  -2, -3, -3, -3, -3, -4, -2, 3,  -5, 2,  4,
860     -1, -4, -4, 4,  -2, -1, 3,  -3, -4, -4, -2, 4,  1,  0,  2,  -1, 4,  -3, 1,
861     4,  -3, 4,  4,  0,  -4, 3,  -2, -3, 2,  3,  -1, -3, 2,  1,  4,  -2, -3, 1,
862     4,  -2, 2,  -2, -5, -2, 1,  4,  -1, -4, 4,  -5, 2,  -5, -4, -1, -2, 3,  1,
863     2,  1,  -5, 1,  -5, -4, -1, -2, 2,  -2, -4, -3, -2, -2, 4,  -1, 2,  2,  -4,
864     2,  -2, 4,  -4, -2, -2, 1,  -1, 1,  1,  1,  -4, -5, -2, 3,  -4, -1, 3,  -2,
865     3,  2,  -5, -4, 0,  3,  -2, -4, -5, 3,  -2, -4, 2,  -2, 1,  -4, 0,  2,  -5,
866     1,  -4, -1, -1, 4,  -5, -4, 0,  -5, -4, -3, -5, -4, 0,  2,  0,  -4, 2,  -2,
867     1,  1,  -3, 2,  0,  -4, 0,  -4, 1,  0,  -5, -1, -1, -1, -5, 4,  2,  2,  -4,
868     3,  -2, -2, 2,  -3, -2, -1, 2,  -4, -5, 2,  -2, -4, -5, -5, -1, 2,  -1, 0,
869     -5, -2, -2, -5, 0,  1,  -1, -5, 0,  3,  2,  3,  0,  -3, -2, 0,  -5, -1, -2,
870     2,  -4, -1, 2,  2,  -5, 2,  -4, 0,  3,  -3, 1,  0,  0,  1,  -5, -3, 1,  -1,
871     0,  -4, -3, 2,  -4, -4, 4,  -1, 0,  1,  2,  -4, -5, 4,  -2, 1,  -4, -4, -3,
872     -1, -1, 1,  -1, -4, -1, -4, -3, 2,  -1, -2, -4, 1,  1,  0,  -2, 0,  -4, 3,
873     -3, 0,  -4, -1, -4, 2,  -1, -2, -5, -1, -2, -3, 3,  -1, 0,  -3, 0,  1,  -5,
874     1,  -5, 0,  1,
875   };
876 
877   float bias_10x11[] = { 3 };
878 
879   float expected_10x11[] = {
880     118,
881   };
882 
883   CNN_CONFIG cnn_config = { 1,
884                             0,
885                             0,
886                             0,
887                             0,
888                             { {
889                                 1,
890                                 23,
891                                 20,
892                                 1,
893                                 15,
894                                 20,
895                                 0,
896                                 weights_10x11,
897                                 bias_10x11,
898                                 PADDING_SAME_ZERO,
899                                 NONE,
900                                 0,
901                                 0,
902                                 BRANCH_NO_COPY,
903                                 BRANCH_NOC,
904                                 {},
905                                 {},
906                                 0,
907                             } } };
908 
909   int image_height = 10;
910   int image_width = 11;
911 
912   CNN_THREAD_DATA thread_data = { 1, NULL };
913 
914   RunCNNTest(image_width, image_height, input_10x11, expected_10x11,
915              &cnn_config, image_width, &thread_data, MSE_INT_TOL);
916 
917   float input_11x10[] = {
918     -2, -2, 3,  -5, -1, -3, 1,  3,  2,  1,  1,  -5, 4,  1,  3,  -5, 3,  -3, -5,
919     0,  -1, -3, -3, 1,  1,  -5, -1, -5, -5, -3, 0,  1,  -3, -1, -3, -3, 0,  3,
920     4,  -4, -1, 3,  -3, -1, -3, 1,  -3, -2, -1, -4, -3, 2,  -4, 1,  -4, -1, -3,
921     -5, -1, 2,  3,  0,  2,  2,  -5, 4,  1,  2,  -1, -4, 4,  -4, -4, 0,  -1, 1,
922     -1, 1,  -3, -3, -2, 1,  2,  4,  4,  4,  -3, -3, 0,  1,  0,  1,  4,  1,  3,
923     4,  -3, -2, -4, 4,  2,  0,  3,  4,  -1, 2,  -2, 1,  -3, -2,
924   };
925 
926   float weights_11x10[] = {
927     4,  -1, 1,  -1, 2,  4,  3,  3,  -4, 3,  -5, 1,  -1, -1, -2, -2, 0,  2,  -3,
928     -2, 3,  -5, -1, 0,  -1, -2, -2, -1, 2,  4,  3,  1,  0,  0,  -3, 3,  -4, -1,
929     -5, 4,  -2, -2, 1,  2,  -1, -3, 1,  2,  -5, 1,  -3, 3,  3,  0,  -4, -4, -5,
930     -3, -4, -4, 4,  -2, 4,  4,  -2, 2,  -5, -1, -2, -5, -1, 4,  -3, 3,  -2, 0,
931     -4, -3, 0,  -1, -2, 4,  2,  0,  -2, -5, -4, 1,  4,  -4, -2, 2,  -2, 1,  1,
932     -4, 1,  -4, -4, -2, 4,  2,  -1, -5, -5, 1,  -3, -3, 3,  -3, -5, -3, 4,  -1,
933     -1, -3, 0,  -4, 3,  -1, 0,  -2, 0,  -5, -2, -5, 2,  0,  -5, 2,  3,  -2, 2,
934     4,  -1, 1,  -3, 2,  3,  2,  0,  -5, -4, -5, 2,  1,  1,  -1, -2, 3,  4,  2,
935     -2, 4,  -2, 3,  1,  -4, -3, -1, 4,  4,  -3, -5, -2, 2,  0,  3,  -2, 3,  -1,
936     -4, 0,  -2, 0,  3,  4,  -2, -3, -2, 0,  3,  4,  2,  -4, 0,  1,  2,  2,  -1,
937     -1, 4,  1,  4,  -2, -1, -1, -5, 1,  -3, 3,  3,  -1, -4, 3,  -5, 0,  0,  -1,
938     -4, -1, -2, 4,  -2, 3,  3,  -3, 1,  -1, 2,  -1, 4,  4,  -2, -2, 4,  -2, 0,
939     3,  -3, -5, -1, -2, 4,  -4, 2,  -4, 0,  -2, 3,  -3, 2,  2,  -2, -5, -1, 4,
940     3,  -2, -1, 3,  3,  -1, 3,  0,  -3, 0,  4,  2,  0,  -1, 4,  1,  1,  2,  1,
941     3,  1,  1,  1,  -3, -5, -4, 4,  -4, 2,  0,  0,  -4, 1,  4,  -5, 4,  4,  0,
942     1,  0,  -2, -4, -4, -3, 0,  1,  -5, 4,  0,  -3, -2, -4, 2,  4,  1,  -5, 1,
943     -4, 1,  0,  -3, -3, 0,  2,  -5, 4,  3,  -2, -5, 3,  1,  -1, 0,  3,  -2, -2,
944     3,  -2, -5, 4,  1,  -2, 2,  -1, 0,  4,  0,  -5, 3,  -2, 1,  2,  1,  -5, -3,
945     -2, -5, 4,  -4, 0,  3,  2,  -1, -4, -1, 2,  1,  -2, 3,  -1, -4, 2,  0,  -3,
946     1,  -1, 2,  -5, -4, -1, -5, 1,  4,  3,  4,  2,  -3, 1,  -5, -1, 3,  0,  -1,
947     -4, 3,  4,  -5, 4,  4,  -3, 2,  -3, -1, -3, -5, -3, 2,  -3, -2, 1,  1,  0,
948     -5, 3,  2,  1,  -5, 1,  1,  1,  3,  4,  -4, -1, -2, 0,  -5, -3, -5, -2, -4,
949     3,  3,  3,  4,  0,  -4, -1, -5, 0,  -3, 1,  4,  4,  -4, 4,  -5, -5, -1, -2,
950     -5, 3,  -4, 4,  3,  0,  -3, 2,  -2, 0,  0,  4,  4,  0,  -2, 1,  -1, -3, 2,
951     -1, 1,  -3, -5,
952   };
953 
954   float bias_11x10[] = {
955     -5,
956   };
957 
958   float expected_11x10[] = {
959     36,  -84,  95,   45,  18,   46,   77,  -54, -99,  -149, 66,  49,  161, 11,
960     39,  61,   -66,  61,  4,    -3,   34,  -44, -23,  31,   64,  29,  47,  72,
961     -27, -27,  121,  -3,  100,  1,    30,  -78, -12,  -89,  -59, 8,   -16, 112,
962     91,  -102, -26,  -4,  30,   54,   4,   -84, -24,  -58,  27,  -53, -33, 5,
963     53,  -26,  63,   50,  -103, -130, -23, 6,   -104, -207, 73,  23,  77,  132,
964     38,  32,   -130, -44, -60,  7,    27,  176, 45,   -32,  -2,  99,  -97, 63,
965     69,  126,  47,   63,  136,  -57,  5,   16,  -40,  -157, 8,   38,  -44, -10,
966     91,  7,    122,  140, 30,   -105, 4,   -1,  113,  64,   180, 141,
967   };
968 
969   cnn_config.layer_config[0].weights = weights_11x10;
970   cnn_config.layer_config[0].bias = bias_11x10;
971   cnn_config.layer_config[0].filter_width = 20;
972   cnn_config.layer_config[0].filter_height = 23;
973   cnn_config.layer_config[0].skip_width = 1;
974   cnn_config.layer_config[0].skip_height = 1;
975   image_height = 11;
976   image_width = 10;
977 
978   RunCNNTest(image_width, image_height, input_11x10, expected_11x10,
979              &cnn_config, image_width, &thread_data, MSE_INT_TOL);
980 }
981 
TEST_F(CNNTest,TestSoftsignSingleLayer)982 TEST_F(CNNTest, TestSoftsignSingleLayer) {
983   int image_width = 8;
984   int image_height = 8;
985   int filter_height = 5;
986   int filter_width = 4;
987   float input[] = {
988     -0.5220f, 0.8410f,  -0.8990f, -0.0090f, 0.6710f,  -0.9470f, -0.8240f,
989     -0.0870f, 0.5380f,  0.4750f,  0.570f,   -0.3760f, -0.6960f, -0.5940f,
990     -0.3830f, 0.080f,   -0.0980f, -0.4940f, -0.4030f, 0.9460f,  -0.6020f,
991     0.4220f,  0.6190f,  0.6640f,  -0.9210f, -0.1470f, -0.2480f, -0.1120f,
992     -0.580f,  -0.0650f, 0.3330f,  0.9860f,  -0.7430f, 0.7610f,  0.4840f,
993     0.1030f,  0.9570f,  0.6120f,  -0.5240f, -0.1220f, -0.5850f, -0.270f,
994     0.7840f,  -0.9790f, 0.7290f,  -0.30f,   -0.6460f, 0.0780f,  0.4750f,
995     -0.0510f, 0.4550f,  0.3850f,  -0.7230f, 0.4460f,  -0.6260f, -0.810f,
996     0.8720f,  -0.2120f, -0.580f,  -0.9510f, -0.8430f, -0.1340f, -0.0850f,
997     0.9190f,
998   };
999   float expected_same[] = {
1000     0.430f,   0.660f,  0.5510f,  -0.610f,  0.450f,  -0.1610f, 0.0520f,  0.3240f,
1001     0.6820f,  0.3820f, 0.6360f,  0.7480f,  0.3080f, 0.090f,   0.3910f,  0.1730f,
1002     0.340f,   0.6660f, -0.4990f, 0.4280f,  0.1540f, 0.120f,   0.4670f,  0.6150f,
1003     -0.3880f, 0.7590f, 0.4190f,  0.7350f,  0.5310f, -0.5160f, -0.1760f, 0.6790f,
1004     -0.6780f, 0.5470f, 0.5750f,  -0.6420f, 0.7210f, -0.4620f, 0.5430f,  0.770f,
1005     -0.1990f, 0.3950f, 0.7860f,  -0.4380f, 0.7540f, 0.2640f,  -0.6430f, 0.4510f,
1006     -0.1260f, 0.1590f, -0.2110f, -0.0560f, 0.6570f, 0.680f,   0.5870f,  0.4720f,
1007     0.4040f,  0.3630f, 0.670f,   0.2360f,  0.410f,  0.6980f,  -0.5350f, 0.3940f,
1008   };
1009   float expected_replicate[] = {
1010     0.540f,   0.7230f,  -0.3530f, -0.2130f, 0.7440f,  -0.4470f, -0.6260f,
1011     -0.2050f, 0.7230f,  0.4630f,  0.5920f,  0.7440f,  0.6080f,  0.3130f,
1012     -0.5670f, -0.4720f, 0.5480f,  0.6660f,  -0.4990f, 0.4280f,  0.1540f,
1013     0.120f,   0.3390f,  0.6090f,  0.4160f,  0.7590f,  0.4190f,  0.7350f,
1014     0.5310f,  -0.5160f, -0.490f,  0.4450f,  -0.610f,  0.5470f,  0.5750f,
1015     -0.6420f, 0.7210f,  -0.4620f, 0.3150f,  0.7370f,  -0.5820f, 0.3950f,
1016     0.7860f,  -0.4380f, 0.7540f,  0.2640f,  -0.7430f, -0.5340f, -0.6270f,
1017     0.4430f,  0.4730f,  0.4570f,  0.7450f,  0.630f,   0.2620f,  0.3140f,
1018     -0.1840f, 0.1810f,  0.7210f,  0.2760f,  0.6430f,  0.6720f,  -0.4390f,
1019     0.2040f,
1020   };
1021   float expected_valid[] = {
1022     0.6660f,  -0.4990f, 0.4280f,  0.1540f,  0.120f,  0.7590f,  0.4190f,
1023     0.7350f,  0.5310f,  -0.5160f, 0.5470f,  0.5750f, -0.6420f, 0.7210f,
1024     -0.4620f, 0.3950f,  0.7860f,  -0.4380f, 0.7540f, 0.2640f,
1025   };
1026   float weights[] = {
1027     0.6210f,  0.3710f,  -0.2770f, -0.7230f, -0.2450f, 0.6770f,  0.3080f,
1028     -0.9880f, -0.080f,  0.7190f,  -0.6760f, -0.0170f, -0.8970f, 0.8260f,
1029     0.7390f,  -0.4550f, -0.4260f, -0.6330f, 0.0880f,  -0.9390f,
1030   };
1031   float bias[] = {
1032     0.750f,
1033   };
1034 
1035   CNN_CONFIG cnn_config = { 1,
1036                             0,
1037                             0,
1038                             0,
1039                             0,
1040                             { {
1041                                 1,
1042                                 filter_width,
1043                                 filter_height,
1044                                 1,
1045                                 1,
1046                                 1,
1047                                 0,
1048                                 weights,
1049                                 bias,
1050                                 PADDING_SAME_ZERO,
1051                                 SOFTSIGN,
1052                                 0,
1053                                 0,
1054                                 BRANCH_NO_COPY,
1055                                 BRANCH_NOC,
1056                                 {},
1057                                 {},
1058                                 0,
1059                             } } };
1060 
1061   CNN_THREAD_DATA thread_data = { 1, NULL };
1062 
1063   RunCNNTest(image_width, image_height, input, expected_same, &cnn_config,
1064              image_width, &thread_data, MSE_FLOAT_TOL);
1065 
1066   cnn_config.layer_config[0].pad = PADDING_SAME_REPLICATE;
1067 
1068   RunCNNTest(image_width, image_height, input, expected_replicate, &cnn_config,
1069              image_width, &thread_data, MSE_FLOAT_TOL);
1070 
1071   cnn_config.layer_config[0].pad = PADDING_VALID;
1072 
1073   RunCNNTest(image_width, image_height, input, expected_valid, &cnn_config,
1074              image_width, &thread_data, MSE_FLOAT_TOL);
1075 }
1076 
TEST_F(CNNTest,TestBranchTensorAdd)1077 TEST_F(CNNTest, TestBranchTensorAdd) {
1078   int filter_width = 2;
1079   int filter_height = 3;
1080 
1081   int image_width = 4;
1082   int image_height = 4;
1083 
1084   float input[] = {
1085     -3, -2, -2, 0, -1, 3, 2, -2, 1, 3, 4, 0, 2, -5, -4, 0,
1086   };
1087 
1088   float weights[] = {
1089     -3, -1, 4,  -1, -3, 3,  3,  0,  2,  0,  3,  2,  4,  4, 4,  -5, 1, -4,
1090     2,  -4, 1,  -3, 0,  4,  -5, 4,  0,  -4, -3, -1, 0,  0, -2, 0,  0, 2,
1091     -5, -1, 1,  -3, 3,  4,  3,  0,  1,  -1, 1,  1,  2,  4, -2, -5, 2, -2,
1092     3,  -2, 4,  -1, 0,  2,  3,  2,  -2, -1, -3, 1,  3,  4, -1, -3, 0, -4,
1093     4,  2,  -3, -3, -1, 0,  1,  0,  3,  3,  -3, 0,  3,  2, -5, -3, 4, -5,
1094     3,  -1, -1, -3, 0,  1,  -1, -4, 2,  4,  -1, 4,  -1, 1, 3,  4,  4, 4,
1095     0,  -1, -3, -3, -3, -3, 2,  -3, -2, 2,  3,  -3,
1096   };
1097 
1098   float bias[] = {
1099     3, 4, -1, -1, 2, 1, -2, 1, 4, 1, 3,
1100   };
1101 
1102   float expected[] = {
1103     -11502, -4101, -3424, 668,   -17950, -5470, -5504, 626,
1104     4835,   446,   1779,  -3483, 3679,   -4214, 4578,  -105,
1105   };
1106 
1107   int channels = 2;
1108 
1109   CNN_CONFIG cnn_config = { 6,
1110                             0,
1111                             0,
1112                             0,
1113                             0,
1114                             { {
1115                                   1,
1116                                   filter_width,
1117                                   filter_height,
1118                                   channels,
1119                                   1,
1120                                   1,
1121                                   0,
1122                                   weights,
1123                                   bias,
1124                                   PADDING_SAME_ZERO,
1125                                   NONE,
1126                                   0,
1127                                   0,
1128                                   BRANCH_NO_COPY,
1129                                   BRANCH_NOC,
1130                                   {},
1131                                   {},
1132                                   -1,
1133                               },
1134                               {
1135                                   channels,
1136                                   filter_width,
1137                                   filter_height,
1138                                   channels,
1139                                   1,
1140                                   1,
1141                                   0,
1142                                   nullptr,
1143                                   nullptr,
1144                                   PADDING_SAME_ZERO,
1145                                   NONE,
1146                                   0,
1147                                   0,
1148                                   BRANCH_INPUT,
1149                                   BRANCH_NOC,
1150                                   {
1151                                       0x02,
1152                                       0,
1153                                       0x00,
1154                                   },
1155                                   {},
1156                                   -1,
1157                               },
1158                               {
1159                                   channels,
1160                                   filter_width,
1161                                   filter_height,
1162                                   channels,
1163                                   1,
1164                                   1,
1165                                   0,
1166                                   nullptr,
1167                                   nullptr,
1168                                   PADDING_SAME_ZERO,
1169                                   NONE,
1170                                   0,
1171                                   1,
1172                                   BRANCH_NO_COPY,
1173                                   BRANCH_NOC,
1174                                   {},
1175                                   {},
1176                                   -1,
1177                               },
1178                               {
1179                                   channels,
1180                                   filter_width,
1181                                   filter_height,
1182                                   channels,
1183                                   1,
1184                                   1,
1185                                   0,
1186                                   nullptr,
1187                                   nullptr,
1188                                   PADDING_SAME_ZERO,
1189                                   NONE,
1190                                   0,
1191                                   1,
1192                                   BRANCH_NO_COPY,
1193                                   BRANCH_NOC,
1194                                   {},
1195                                   {},
1196                                   -1,
1197                               },
1198                               {
1199                                   channels,
1200                                   filter_width,
1201                                   filter_height,
1202                                   channels,
1203                                   1,
1204                                   1,
1205                                   0,
1206                                   nullptr,
1207                                   nullptr,
1208                                   PADDING_SAME_ZERO,
1209                                   NONE,
1210                                   0,
1211                                   0,
1212                                   BRANCH_NO_COPY,
1213                                   BRANCH_ADD,
1214                                   {
1215                                       0x00,
1216                                       0,
1217                                       0x02,
1218                                   },
1219                                   {},
1220                                   -1,
1221                               },
1222                               {
1223                                   channels,
1224                                   filter_width,
1225                                   filter_height,
1226                                   1,
1227                                   1,
1228                                   1,
1229                                   0,
1230                                   nullptr,
1231                                   nullptr,
1232                                   PADDING_SAME_ZERO,
1233                                   NONE,
1234                                   0,
1235                                   0,
1236                                   BRANCH_NO_COPY,
1237                                   BRANCH_NOC,
1238                                   {},
1239                                   {},
1240                                   0,
1241                               } } };
1242 
1243   // Weights and biases need to be specified separately because
1244   // of the offset.
1245   AssignLayerWeightsBiases(&cnn_config, weights, bias);
1246 
1247   CNN_THREAD_DATA thread_data = { 1, NULL };
1248 
1249   RunCNNTest(image_width, image_height, input, expected, &cnn_config,
1250              image_width, &thread_data, MSE_INT_TOL);
1251 }
1252 
TEST_F(CNNTest,TestBranchTensorConcatenation)1253 TEST_F(CNNTest, TestBranchTensorConcatenation) {
1254   int filter_width = 2;
1255   int filter_height = 3;
1256 
1257   int image_width = 4;
1258   int image_height = 4;
1259 
1260   float input[] = {
1261     -3, -2, -2, 0, -1, 3, 2, -2, 1, 3, 4, 0, 2, -5, -4, 0,
1262   };
1263 
1264   float weights[] = {
1265     3,  0,  2,  0,  2,  3,  1,  -3, 1,  -5, -3, 0,  -4, 4,  0,  -5, 0,  -5, -1,
1266     -2, -5, 0,  -3, 2,  -4, 2,  0,  2,  -1, 0,  -4, 3,  0,  0,  -1, -5, 2,  -1,
1267     4,  -4, -2, -3, -3, 3,  4,  -2, -1, -4, -1, 4,  4,  -1, 4,  3,  -4, 2,  -2,
1268     -4, -3, -2, 3,  -3, -5, -1, 3,  -2, 4,  1,  -4, -3, -5, -5, -3, 4,  -2, -2,
1269     -1, -5, -5, 0,  -1, -2, -3, 3,  -4, -5, 2,  -3, 1,  0,  -5, 2,  2,  -2, 0,
1270     2,  2,  -2, 4,  2,  2,  0,  1,  -5, -3, 0,  2,  -2, 1,  2,  -5, 2,  3,  3,
1271     -1, 3,  0,  -3, 3,  -4, -4, 3,  3,  -4, -2, 2,  -2, 2,  -2, -1, 3,  0,
1272   };
1273 
1274   float bias[] = {
1275     -3, -5, 4, -4, -3, -2, 0, 3, -4, 4, -3,
1276   };
1277 
1278   float expected[] = {
1279     -33533, -32087, -6741,  -2124, 39979, 41453, 14034, 689,
1280     -22611, -42203, -14882, -239,  15781, 15963, 9524,  837,
1281   };
1282 
1283   int channels = 2;
1284 
1285   CNN_CONFIG cnn_config = { 6,
1286                             0,
1287                             0,
1288                             0,
1289                             0,
1290                             { {
1291                                   1,
1292                                   filter_width,
1293                                   filter_height,
1294                                   channels,
1295                                   1,
1296                                   1,
1297                                   0,
1298                                   weights,
1299                                   bias,
1300                                   PADDING_SAME_ZERO,
1301                                   NONE,
1302                                   0,
1303                                   0,
1304                                   BRANCH_NO_COPY,
1305                                   BRANCH_NOC,
1306                                   {},
1307                                   {},
1308                                   -1,
1309                               },
1310                               {
1311                                   channels,
1312                                   filter_width,
1313                                   filter_height,
1314                                   channels,
1315                                   1,
1316                                   1,
1317                                   0,
1318                                   nullptr,
1319                                   nullptr,
1320                                   PADDING_SAME_ZERO,
1321                                   NONE,
1322                                   0,
1323                                   0,
1324                                   BRANCH_INPUT,
1325                                   BRANCH_NOC,
1326                                   {
1327                                       0x02,
1328                                       0,
1329                                       0x00,
1330                                   },
1331                                   {},
1332                                   -1,
1333                               },
1334                               {
1335                                   channels,
1336                                   filter_width,
1337                                   filter_height,
1338                                   channels,
1339                                   1,
1340                                   1,
1341                                   0,
1342                                   nullptr,
1343                                   nullptr,
1344                                   PADDING_SAME_ZERO,
1345                                   NONE,
1346                                   0,
1347                                   1,
1348                                   BRANCH_NO_COPY,
1349                                   BRANCH_NOC,
1350                                   {},
1351                                   {},
1352                                   -1,
1353                               },
1354                               {
1355                                   channels,
1356                                   filter_width,
1357                                   filter_height,
1358                                   channels,
1359                                   1,
1360                                   1,
1361                                   0,
1362                                   nullptr,
1363                                   nullptr,
1364                                   PADDING_SAME_ZERO,
1365                                   NONE,
1366                                   0,
1367                                   1,
1368                                   BRANCH_NO_COPY,
1369                                   BRANCH_NOC,
1370                                   {},
1371                                   {},
1372                                   -1,
1373                               },
1374                               {
1375                                   channels,
1376                                   filter_width,
1377                                   filter_height,
1378                                   channels,
1379                                   1,
1380                                   1,
1381                                   0,
1382                                   nullptr,
1383                                   nullptr,
1384                                   PADDING_SAME_ZERO,
1385                                   NONE,
1386                                   0,
1387                                   0,
1388                                   BRANCH_NO_COPY,
1389                                   BRANCH_CAT,
1390                                   {
1391                                       0x00,
1392                                       0,
1393                                       0x02,
1394                                   },
1395                                   {},
1396                                   -1,
1397                               },
1398                               {
1399                                   channels + channels,
1400                                   filter_width,
1401                                   filter_height,
1402                                   1,
1403                                   1,
1404                                   1,
1405                                   0,
1406                                   nullptr,
1407                                   nullptr,
1408                                   PADDING_SAME_ZERO,
1409                                   NONE,
1410                                   0,
1411                                   0,
1412                                   BRANCH_NO_COPY,
1413                                   BRANCH_NOC,
1414                                   {},
1415                                   {},
1416                                   0,
1417                               } } };
1418 
1419   // Weights and biases need to be specified separately because
1420   // of the offset.
1421   AssignLayerWeightsBiases(&cnn_config, weights, bias);
1422 
1423   CNN_THREAD_DATA thread_data = { 1, NULL };
1424 
1425   RunCNNTest(image_width, image_height, input, expected, &cnn_config,
1426              image_width, &thread_data, MSE_INT_TOL);
1427 }
1428 
// TODO(logangw): Add a test that covers all combinations of branch_copy_type.
1430 
TEST_F(CNNTest,TestBranchCombinations)1431 TEST_F(CNNTest, TestBranchCombinations) {
1432   int filter_width = 2;
1433   int filter_height = 3;
1434 
1435   int image_width = 4;
1436   int image_height = 4;
1437 
1438   float input[] = {
1439     3, 2, -5, -4, 4, -2, -4, -3, 4, 2, -3, 2, -3, 1, -5, -1,
1440   };
1441 
1442   float weights[] = {
1443     2,  3,  0,  4,  4,  3,  1,  0,  1,  -5, 4,  -3, 3,  0,  4,  -1, -1, -5,
1444     2,  1,  -3, -5, 3,  -1, -3, -2, 0,  -2, 3,  0,  -2, -4, -2, -2, 2,  -5,
1445     4,  -5, 0,  1,  -5, -4, -3, -4, 2,  -2, 1,  0,  3,  -2, -4, 3,  4,  -4,
1446     -1, -1, -3, -2, -2, -1, 2,  0,  2,  -1, 2,  -4, -4, -1, 2,  0,  3,  -2,
1447     -2, 3,  -3, 4,  -2, 4,  3,  4,  1,  0,  -2, -3, -5, 1,  -3, 2,  0,  -2,
1448     -2, -1, -1, -5, -2, -3, -1, 3,  3,  4,  4,  0,  2,  1,  3,  -3, 2,  -5,
1449     -5, 1,  -5, -1, 3,  3,  2,  -4, -1, 3,  -4, -2, -5, -2, 1,  3,  2,  2,
1450     -5, -2, -3, -1, -2, -4, -1, -2, 2,  1,  -4, -4, 2,  0,  2,  0,  2,  -3,
1451     -2, -4, 4,  0,  1,  -3, -5, 4,  -1, 2,  3,  -5, -1, 0,  4,  -1, -1, 3,
1452     -1, -3, 3,  1,  4,  3,  4,  3,  -4, -5, -1, 3,  3,  -4, 3,  1,  3,  -5,
1453     3,  4,  -5, 4,  2,  -1, -5, 2,  1,  0,  4,  0,  -3, 2,  0,  2,  -2, 1,
1454     -1, -2, -1, -5, 4,  3,  3,  -2, 2,  4,  -5, -5, -3, -2, 4,  0,  -4, 1,
1455   };
1456 
1457   float bias[] = {
1458     -1, 4, 0, 2, 2, -2, 0, -4, -5, -1, 1, -2, 3, 0, 4, -2, 1, 0, 0,
1459   };
1460 
1461   float expected[] = {
1462     149496, 15553,  -24193, -20956, 134094, 86432,  -68283, -6366,
1463     -53031, 133739, 67407,  -13539, -53205, -58635, -20033, 1979,
1464   };
1465 
1466   int channels = 2;
1467 
1468   CNN_CONFIG cnn_config = { 10,
1469                             0,
1470                             0,
1471                             0,
1472                             0,
1473                             {
1474                                 {
1475                                     1,
1476                                     filter_width,
1477                                     filter_height,
1478                                     channels,
1479                                     1,
1480                                     1,
1481                                     0,
1482                                     weights,
1483                                     bias,
1484                                     PADDING_SAME_ZERO,
1485                                     NONE,
1486                                     0,
1487                                     0,
1488                                     BRANCH_NO_COPY,
1489                                     BRANCH_NOC,
1490                                     {},
1491                                     {},
1492                                     -1,
1493                                 },
1494                                 {
1495                                     channels,
1496                                     filter_width,
1497                                     filter_height,
1498                                     channels,
1499                                     1,
1500                                     1,
1501                                     0,
1502                                     nullptr,
1503                                     nullptr,
1504                                     PADDING_SAME_ZERO,
1505                                     NONE,
1506                                     0,
1507                                     0,
1508                                     BRANCH_INPUT,
1509                                     BRANCH_NOC,
1510                                     {
1511                                         0x06,
1512                                         0,
1513                                         0x00,
1514                                     },
1515                                     {},
1516                                     -1,
1517                                 },
1518                                 {
1519                                     channels,
1520                                     filter_width,
1521                                     filter_height,
1522                                     channels,
1523                                     1,
1524                                     1,
1525                                     0,
1526                                     nullptr,
1527                                     nullptr,
1528                                     PADDING_SAME_ZERO,
1529                                     NONE,
1530                                     0,
1531                                     2,
1532                                     BRANCH_OUTPUT,
1533                                     BRANCH_NOC,
1534                                     {
1535                                         0x08,
1536                                         0,
1537                                         0x00,
1538                                     },
1539                                     {},
1540                                     -1,
1541                                 },
1542                                 {
1543                                     channels,
1544                                     filter_width,
1545                                     filter_height,
1546                                     channels,
1547                                     1,
1548                                     1,
1549                                     0,
1550                                     nullptr,
1551                                     nullptr,
1552                                     PADDING_SAME_ZERO,
1553                                     NONE,
1554                                     0,
1555                                     3,
1556                                     BRANCH_NO_COPY,
1557                                     BRANCH_NOC,
1558                                     {},
1559                                     {},
1560                                     -1,
1561                                 },
1562                                 {
1563                                     channels,
1564                                     filter_width,
1565                                     filter_height,
1566                                     channels,
1567                                     1,
1568                                     1,
1569                                     0,
1570                                     nullptr,
1571                                     nullptr,
1572                                     PADDING_SAME_ZERO,
1573                                     NONE,
1574                                     0,
1575                                     2,
1576                                     BRANCH_NO_COPY,
1577                                     BRANCH_ADD,
1578                                     {
1579                                         0x00,
1580                                         0,
1581                                         0x08,
1582                                     },
1583                                     {},
1584                                     -1,
1585                                 },
1586                                 {
1587                                     channels,
1588                                     filter_width,
1589                                     filter_height,
1590                                     channels,
1591                                     1,
1592                                     1,
1593                                     0,
1594                                     nullptr,
1595                                     nullptr,
1596                                     PADDING_SAME_ZERO,
1597                                     NONE,
1598                                     0,
1599                                     2,
1600                                     BRANCH_NO_COPY,
1601                                     BRANCH_NOC,
1602                                     {},
1603                                     {},
1604                                     -1,
1605                                 },
1606                                 {
1607                                     channels,
1608                                     filter_width,
1609                                     filter_height,
1610                                     channels,
1611                                     1,
1612                                     1,
1613                                     0,
1614                                     nullptr,
1615                                     nullptr,
1616                                     PADDING_SAME_ZERO,
1617                                     NONE,
1618                                     0,
1619                                     1,
1620                                     BRANCH_NO_COPY,
1621                                     BRANCH_NOC,
1622                                     {},
1623                                     {},
1624                                     -1,
1625                                 },
1626                                 {
1627                                     channels,
1628                                     filter_width,
1629                                     filter_height,
1630                                     channels,
1631                                     1,
1632                                     1,
1633                                     0,
1634                                     nullptr,
1635                                     nullptr,
1636                                     PADDING_SAME_ZERO,
1637                                     NONE,
1638                                     0,
1639                                     1,
1640                                     BRANCH_NO_COPY,
1641                                     BRANCH_ADD,
1642                                     {
1643                                         0x00,
1644                                         0,
1645                                         0x0C,
1646                                     },
1647                                     {},
1648                                     -1,
1649                                 },
1650                                 {
1651                                     channels,
1652                                     filter_width,
1653                                     filter_height,
1654                                     channels,
1655                                     1,
1656                                     1,
1657                                     0,
1658                                     nullptr,
1659                                     nullptr,
1660                                     PADDING_SAME_ZERO,
1661                                     NONE,
1662                                     0,
1663                                     0,
1664                                     BRANCH_NO_COPY,
1665                                     BRANCH_ADD,
1666                                     {
1667                                         0x00,
1668                                         0,
1669                                         0x02,
1670                                     },
1671                                     {},
1672                                     -1,
1673                                 },
1674                                 {
1675                                     channels,
1676                                     filter_width,
1677                                     filter_height,
1678                                     1,
1679                                     1,
1680                                     1,
1681                                     0,
1682                                     nullptr,
1683                                     nullptr,
1684                                     PADDING_SAME_ZERO,
1685                                     NONE,
1686                                     0,
1687                                     0,
1688                                     BRANCH_NO_COPY,
1689                                     BRANCH_NOC,
1690                                     {},
1691                                     {},
1692                                     0,
1693                                 },
1694                             } };
1695 
1696   // Weights and biases need to be specified separately because
1697   // of the offset.
1698   AssignLayerWeightsBiases(&cnn_config, weights, bias);
1699 
1700   CNN_THREAD_DATA thread_data = { 1, NULL };
1701 
1702   RunCNNTest(image_width, image_height, input, expected, &cnn_config,
1703              image_width, &thread_data, MSE_INT_TOL);
1704 }
1705 
TEST_F(CNNTest,TestSplittingTensors)1706 TEST_F(CNNTest, TestSplittingTensors) {
1707   int filter_width = 2;
1708   int filter_height = 3;
1709 
1710   int image_width = 4;
1711   int image_height = 4;
1712 
1713   float input[] = {
1714     -1, -1, 2, 1, 3, 2, 4, -3, -4, -2, 2, -3, 1, -3, 4, -2,
1715   };
1716 
1717   float weights[] = {
1718     -4, 1,  0,  2,  3,  4,  4,  -4, -5, -3, 2,  2,  -4, -3, 3,  2,
1719     4,  -4, -3, -4, -4, 1,  -3, -5, -3, 4,  2,  -2, 2,  -1, -4, -1,
1720     -2, -3, 1,  1,  0,  -5, -1, 3,  3,  -5, -3, 0,  -3, 1,  -3, -1,
1721     1,  -3, -2, -2, 4,  -2, 0,  1,  2,  2,  -4, 2,  4,  0,  -5, -2,
1722     4,  4,  -5, 1,  0,  2,  -2, -5, -5, -3, -5, -5, 4,  -3, 0,  0,
1723     -4, -4, 0,  -5, -4, 0,  0,  -3, -5, -3, -1, 2,  -1, 4,  -1, 2,
1724   };
1725 
1726   float bias[] = {
1727     -4, -2, -3, -3, 3, 1, -2,
1728   };
1729 
1730   float expected[] = {
1731     530,  -762,  1469, 777,  849,   -771, -1698, 600,
1732     -658, -1821, 98,   -668, -1798, 30,   887,   -971,
1733   };
1734 
1735   CNN_CONFIG cnn_config = { 3,
1736                             0,
1737                             0,
1738                             0,
1739                             0,
1740                             {
1741                                 {
1742                                     1,
1743                                     filter_width,
1744                                     filter_height,
1745                                     4,
1746                                     1,
1747                                     1,
1748                                     0,
1749                                     nullptr,
1750                                     nullptr,
1751                                     PADDING_SAME_ZERO,
1752                                     NONE,
1753                                     0,
1754                                     0,
1755                                     BRANCH_OUTPUT,
1756                                     BRANCH_NOC,
1757                                     {
1758                                         0x02,
1759                                         2,
1760                                         0x00,
1761                                     },
1762                                     {},
1763                                     -1,
1764                                 },
1765                                 {
1766                                     4,
1767                                     filter_width,
1768                                     filter_height,
1769                                     2,
1770                                     1,
1771                                     1,
1772                                     0,
1773                                     nullptr,
1774                                     nullptr,
1775                                     PADDING_SAME_ZERO,
1776                                     NONE,
1777                                     0,
1778                                     0,
1779                                     BRANCH_NO_COPY,
1780                                     BRANCH_CAT,
1781                                     {
1782                                         0x00,
1783                                         0,
1784                                         0x02,
1785                                     },
1786                                     {},
1787                                     -1,
1788                                 },
1789                                 {
1790                                     4,
1791                                     filter_width,
1792                                     filter_height,
1793                                     1,
1794                                     1,
1795                                     1,
1796                                     0,
1797                                     nullptr,
1798                                     nullptr,
1799                                     PADDING_SAME_ZERO,
1800                                     NONE,
1801                                     0,
1802                                     0,
1803                                     BRANCH_NO_COPY,
1804                                     BRANCH_NOC,
1805                                     {},
1806                                     {},
1807                                     0,
1808                                 },
1809                             } };
1810 
1811   // Weights and biases need to be specified separately because
1812   // of the offset.
1813   AssignLayerWeightsBiases(&cnn_config, weights, bias);
1814 
1815   CNN_THREAD_DATA thread_data = { 1, NULL };
1816 
1817   RunCNNTest(image_width, image_height, input, expected, &cnn_config,
1818              image_width, &thread_data, MSE_INT_TOL);
1819 }
1820 
TEST_F(CNNTest,TestOutputChannelsCount)1821 TEST_F(CNNTest, TestOutputChannelsCount) {
1822   int filter_width = 1;
1823   int filter_height = 1;
1824 
1825   int image_width = 2;
1826   int image_height = 2;
1827 
1828   float input[] = { 0, 0, 0, 0 };
1829 
1830   float weights[] = { 0, 0, 0, 0, 0, 0, 0, 0 };
1831 
1832   float bias[] = { 0, 0, 0, 0, 0, 0 };
1833 
1834   float expected[] = {
1835     0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0,
1836     0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0,
1837   };
1838 
1839   CNN_CONFIG cnn_config = { 3,
1840                             0,
1841                             0,
1842                             0,
1843                             0,
1844                             {
1845                                 {
1846                                     1,
1847                                     filter_width,
1848                                     filter_height,
1849                                     2,
1850                                     1,
1851                                     1,
1852                                     0,
1853                                     weights,
1854                                     bias,
1855                                     PADDING_SAME_ZERO,
1856                                     NONE,
1857                                     0,
1858                                     0,
1859                                     BRANCH_INPUT,
1860                                     BRANCH_NOC,
1861                                     {
1862                                         0x06,
1863                                         0,
1864                                         0x00,
1865                                     },
1866                                     {},
1867                                     -1,
1868                                 },
1869                                 {
1870                                     1,
1871                                     filter_width,
1872                                     filter_height,
1873                                     2,
1874                                     1,
1875                                     1,
1876                                     0,
1877                                     weights,
1878                                     bias,
1879                                     PADDING_SAME_ZERO,
1880                                     NONE,
1881                                     0,
1882                                     2,
1883                                     BRANCH_NO_COPY,
1884                                     BRANCH_CAT,
1885                                     {
1886                                         0x00,
1887                                         0,
1888                                         0x03,
1889                                     },
1890                                     {},
1891                                     -1,
1892                                 },
1893                                 {
1894                                     2,
1895                                     filter_width,
1896                                     filter_height,
1897                                     2,
1898                                     1,
1899                                     1,
1900                                     0,
1901                                     weights,
1902                                     bias,
1903                                     PADDING_SAME_ZERO,
1904                                     NONE,
1905                                     0,
1906                                     0,
1907                                     BRANCH_NO_COPY,
1908                                     BRANCH_CAT,
1909                                     {
1910                                         0x00,
1911                                         0,
1912                                         0x04,
1913                                     },
1914                                     {},
1915                                     0,
1916                                 },
1917                             } };
1918 
1919   // Weights and biases need to be specified separately because
1920   // of the offset.
1921   AssignLayerWeightsBiases(&cnn_config, weights, bias);
1922 
1923   CNN_THREAD_DATA thread_data = { 1, NULL };
1924 
1925   RunCNNTest(image_width, image_height, input, expected, &cnn_config,
1926              image_width, &thread_data, MSE_FLOAT_TOL);
1927 }
1928 
TEST_F(CNNTest,TestBatchNorm)1929 TEST_F(CNNTest, TestBatchNorm) {
1930   int image_width = 28;
1931   int image_height = 28;
1932   int filter_height = 7;
1933   int filter_width = 7;
1934   float input[] = {
1935     0.0f,       0.0f,       0.0f,        0.0f,        0.0f,        0.0f,
1936     0.0f,       0.0f,       0.0f,        0.0f,        0.0f,        0.0f,
1937     0.0f,       0.0f,       0.0f,        0.0f,        0.0f,        0.0f,
1938     0.0f,       0.0f,       0.0f,        0.0f,        0.0f,        0.0f,
1939     0.0f,       0.0f,       0.0f,        0.0f,        0.0f,        0.0f,
1940     0.0f,       0.0f,       0.0f,        0.0f,        0.0f,        0.0f,
1941     0.0f,       0.0f,       0.0f,        0.0f,        0.0f,        0.0f,
1942     0.0f,       0.0f,       0.0f,        0.0f,        0.0f,        0.0f,
1943     0.0f,       0.0f,       0.0f,        0.0f,        0.0f,        0.0f,
1944     0.0f,       0.0f,       0.0f,        0.0f,        0.0f,        0.0f,
1945     0.0f,       0.0f,       0.0f,        0.0f,        0.0f,        0.0f,
1946     0.0f,       0.0f,       0.0f,        0.0f,        0.0f,        0.0f,
1947     0.0f,       0.0f,       0.0f,        0.0f,        0.0f,        0.0f,
1948     0.0f,       0.0f,       0.0f,        0.0f,        0.0f,        0.0f,
1949     0.0f,       0.0f,       0.0f,        0.0f,        0.0f,        0.0f,
1950     0.0f,       0.0f,       0.0f,        0.0f,        0.0f,        0.0f,
1951     0.0f,       0.0f,       0.0f,        0.0f,        0.0f,        0.0f,
1952     0.0f,       0.0f,       0.0f,        0.0f,        0.0f,        0.0f,
1953     0.0f,       0.0f,       0.0f,        0.0f,        0.0f,        0.0f,
1954     0.0f,       0.0f,       0.0f,        0.0f,        0.0f,        0.0f,
1955     0.0f,       0.0f,       0.0f,        0.0f,        0.0f,        0.0f,
1956     0.0f,       0.0f,       0.0f,        0.0f,        0.0f,        0.0f,
1957     0.0f,       0.0f,       0.0f,        0.0f,        0.0f,        0.0f,
1958     0.0f,       0.0f,       0.0f,        0.0f,        0.0f,        0.0f,
1959     0.0f,       0.0f,       0.0f,        0.0f,        0.0f,        0.0f,
1960     0.0f,       0.0f,       0.0117647f,  0.0705882f,  0.0705882f,  0.0705882f,
1961     0.494118f,  0.533333f,  0.686275f,   0.101961f,   0.65098f,    1.0f,
1962     0.968627f,  0.498039f,  0.0f,        0.0f,        0.0f,        0.0f,
1963     0.0f,       0.0f,       0.0f,        0.0f,        0.0f,        0.0f,
1964     0.0f,       0.0f,       0.117647f,   0.141176f,   0.368627f,   0.603922f,
1965     0.666667f,  0.992157f,  0.992157f,   0.992157f,   0.992157f,   0.992157f,
1966     0.882353f,  0.67451f,   0.992157f,   0.94902f,    0.764706f,   0.25098f,
1967     0.0f,       0.0f,       0.0f,        0.0f,        0.0f,        0.0f,
1968     0.0f,       0.0f,       0.0f,        0.0f,        0.0f,        0.192157f,
1969     0.933333f,  0.992157f,  0.992157f,   0.992157f,   0.992157f,   0.992157f,
1970     0.992157f,  0.992157f,  0.992157f,   0.984314f,   0.364706f,   0.321569f,
1971     0.321569f,  0.219608f,  0.152941f,   0.0f,        0.0f,        0.0f,
1972     0.0f,       0.0f,       0.0f,        0.0f,        0.0f,        0.0f,
1973     0.0f,       0.0f,       0.0f,        0.0705882f,  0.858824f,   0.992157f,
1974     0.992157f,  0.992157f,  0.992157f,   0.992157f,   0.776471f,   0.713725f,
1975     0.968627f,  0.945098f,  0.0f,        0.0f,        0.0f,        0.0f,
1976     0.0f,       0.0f,       0.0f,        0.0f,        0.0f,        0.0f,
1977     0.0f,       0.0f,       0.0f,        0.0f,        0.0f,        0.0f,
1978     0.0f,       0.0f,       0.313725f,   0.611765f,   0.419608f,   0.992157f,
1979     0.992157f,  0.803922f,  0.0431373f,  0.0f,        0.168627f,   0.603922f,
1980     0.0f,       0.0f,       0.0f,        0.0f,        0.0f,        0.0f,
1981     0.0f,       0.0f,       0.0f,        0.0f,        0.0f,        0.0f,
1982     0.0f,       0.0f,       0.0f,        0.0f,        0.0f,        0.0f,
1983     0.0f,       0.054902f,  0.00392157f, 0.603922f,   0.992157f,   0.352941f,
1984     0.0f,       0.0f,       0.0f,        0.0f,        0.0f,        0.0f,
1985     0.0f,       0.0f,       0.0f,        0.0f,        0.0f,        0.0f,
1986     0.0f,       0.0f,       0.0f,        0.0f,        0.0f,        0.0f,
1987     0.0f,       0.0f,       0.0f,        0.0f,        0.0f,        0.0f,
1988     0.0f,       0.545098f,  0.992157f,   0.745098f,   0.00784314f, 0.0f,
1989     0.0f,       0.0f,       0.0f,        0.0f,        0.0f,        0.0f,
1990     0.0f,       0.0f,       0.0f,        0.0f,        0.0f,        0.0f,
1991     0.0f,       0.0f,       0.0f,        0.0f,        0.0f,        0.0f,
1992     0.0f,       0.0f,       0.0f,        0.0f,        0.0f,        0.0431373f,
1993     0.745098f,  0.992157f,  0.27451f,    0.0f,        0.0f,        0.0f,
1994     0.0f,       0.0f,       0.0f,        0.0f,        0.0f,        0.0f,
1995     0.0f,       0.0f,       0.0f,        0.0f,        0.0f,        0.0f,
1996     0.0f,       0.0f,       0.0f,        0.0f,        0.0f,        0.0f,
1997     0.0f,       0.0f,       0.0f,        0.0f,        0.137255f,   0.945098f,
1998     0.882353f,  0.627451f,  0.423529f,   0.00392157f, 0.0f,        0.0f,
1999     0.0f,       0.0f,       0.0f,        0.0f,        0.0f,        0.0f,
2000     0.0f,       0.0f,       0.0f,        0.0f,        0.0f,        0.0f,
2001     0.0f,       0.0f,       0.0f,        0.0f,        0.0f,        0.0f,
2002     0.0f,       0.0f,       0.0f,        0.317647f,   0.941176f,   0.992157f,
2003     0.992157f,  0.466667f,  0.0980392f,  0.0f,        0.0f,        0.0f,
2004     0.0f,       0.0f,       0.0f,        0.0f,        0.0f,        0.0f,
2005     0.0f,       0.0f,       0.0f,        0.0f,        0.0f,        0.0f,
2006     0.0f,       0.0f,       0.0f,        0.0f,        0.0f,        0.0f,
2007     0.0f,       0.0f,       0.176471f,   0.729412f,   0.992157f,   0.992157f,
2008     0.588235f,  0.105882f,  0.0f,        0.0f,        0.0f,        0.0f,
2009     0.0f,       0.0f,       0.0f,        0.0f,        0.0f,        0.0f,
2010     0.0f,       0.0f,       0.0f,        0.0f,        0.0f,        0.0f,
2011     0.0f,       0.0f,       0.0f,        0.0f,        0.0f,        0.0f,
2012     0.0f,       0.0627451f, 0.364706f,   0.988235f,   0.992157f,   0.733333f,
2013     0.0f,       0.0f,       0.0f,        0.0f,        0.0f,        0.0f,
2014     0.0f,       0.0f,       0.0f,        0.0f,        0.0f,        0.0f,
2015     0.0f,       0.0f,       0.0f,        0.0f,        0.0f,        0.0f,
2016     0.0f,       0.0f,       0.0f,        0.0f,        0.0f,        0.0f,
2017     0.0f,       0.976471f,  0.992157f,   0.976471f,   0.25098f,    0.0f,
2018     0.0f,       0.0f,       0.0f,        0.0f,        0.0f,        0.0f,
2019     0.0f,       0.0f,       0.0f,        0.0f,        0.0f,        0.0f,
2020     0.0f,       0.0f,       0.0f,        0.0f,        0.0f,        0.0f,
2021     0.0f,       0.0f,       0.180392f,   0.509804f,   0.717647f,   0.992157f,
2022     0.992157f,  0.811765f,  0.00784314f, 0.0f,        0.0f,        0.0f,
2023     0.0f,       0.0f,       0.0f,        0.0f,        0.0f,        0.0f,
2024     0.0f,       0.0f,       0.0f,        0.0f,        0.0f,        0.0f,
2025     0.0f,       0.0f,       0.0f,        0.0f,        0.152941f,   0.580392f,
2026     0.898039f,  0.992157f,  0.992157f,   0.992157f,   0.980392f,   0.713725f,
2027     0.0f,       0.0f,       0.0f,        0.0f,        0.0f,        0.0f,
2028     0.0f,       0.0f,       0.0f,        0.0f,        0.0f,        0.0f,
2029     0.0f,       0.0f,       0.0f,        0.0f,        0.0f,        0.0f,
2030     0.0941176f, 0.447059f,  0.866667f,   0.992157f,   0.992157f,   0.992157f,
2031     0.992157f,  0.788235f,  0.305882f,   0.0f,        0.0f,        0.0f,
2032     0.0f,       0.0f,       0.0f,        0.0f,        0.0f,        0.0f,
2033     0.0f,       0.0f,       0.0f,        0.0f,        0.0f,        0.0f,
2034     0.0f,       0.0f,       0.0901961f,  0.258824f,   0.835294f,   0.992157f,
2035     0.992157f,  0.992157f,  0.992157f,   0.776471f,   0.317647f,   0.00784314f,
2036     0.0f,       0.0f,       0.0f,        0.0f,        0.0f,        0.0f,
2037     0.0f,       0.0f,       0.0f,        0.0f,        0.0f,        0.0f,
2038     0.0f,       0.0f,       0.0f,        0.0f,        0.0705882f,  0.670588f,
2039     0.858824f,  0.992157f,  0.992157f,   0.992157f,   0.992157f,   0.764706f,
2040     0.313725f,  0.0352941f, 0.0f,        0.0f,        0.0f,        0.0f,
2041     0.0f,       0.0f,       0.0f,        0.0f,        0.0f,        0.0f,
2042     0.0f,       0.0f,       0.0f,        0.0f,        0.0f,        0.0f,
2043     0.215686f,  0.67451f,   0.886275f,   0.992157f,   0.992157f,   0.992157f,
2044     0.992157f,  0.956863f,  0.521569f,   0.0431373f,  0.0f,        0.0f,
2045     0.0f,       0.0f,       0.0f,        0.0f,        0.0f,        0.0f,
2046     0.0f,       0.0f,       0.0f,        0.0f,        0.0f,        0.0f,
2047     0.0f,       0.0f,       0.0f,        0.0f,        0.533333f,   0.992157f,
2048     0.992157f,  0.992157f,  0.831373f,   0.529412f,   0.517647f,   0.0627451f,
2049     0.0f,       0.0f,       0.0f,        0.0f,        0.0f,        0.0f,
2050     0.0f,       0.0f,       0.0f,        0.0f,        0.0f,        0.0f,
2051     0.0f,       0.0f,       0.0f,        0.0f,        0.0f,        0.0f,
2052     0.0f,       0.0f,       0.0f,        0.0f,        0.0f,        0.0f,
2053     0.0f,       0.0f,       0.0f,        0.0f,        0.0f,        0.0f,
2054     0.0f,       0.0f,       0.0f,        0.0f,        0.0f,        0.0f,
2055     0.0f,       0.0f,       0.0f,        0.0f,        0.0f,        0.0f,
2056     0.0f,       0.0f,       0.0f,        0.0f,        0.0f,        0.0f,
2057     0.0f,       0.0f,       0.0f,        0.0f,        0.0f,        0.0f,
2058     0.0f,       0.0f,       0.0f,        0.0f,        0.0f,        0.0f,
2059     0.0f,       0.0f,       0.0f,        0.0f,        0.0f,        0.0f,
2060     0.0f,       0.0f,       0.0f,        0.0f,        0.0f,        0.0f,
2061     0.0f,       0.0f,       0.0f,        0.0f,        0.0f,        0.0f,
2062     0.0f,       0.0f,       0.0f,        0.0f,        0.0f,        0.0f,
2063     0.0f,       0.0f,       0.0f,        0.0f,        0.0f,        0.0f,
2064     0.0f,       0.0f,       0.0f,        0.0f,        0.0f,        0.0f,
2065     0.0f,       0.0f,       0.0f,        0.0f
2066   };
2067   float expected[] = {
2068     -0.836424f, -0.857365f, -1.62739f,  -1.62739f,  -0.836424f, 5.40742f,
2069     0.920853f,  -0.692567f, -0.836424f, -0.534405f, -1.62739f,  -0.836424f,
2070     1.32602f,   1.36312f,   0.112766f,  -0.836424f, -0.192962f, 1.56975f,
2071     2.45777f,   0.944414f,  -0.192962f, -1.5519f,   -1.5519f,   -0.554006f,
2072     -0.192962f, 1.4231f,    -1.5519f,   -0.192962f, 1.3661f,    -1.5519f,
2073     -1.5519f,   -0.192962f, -0.843708f, -0.359025f, -0.843708f, -0.843708f,
2074     -0.843708f, 4.53065f,   0.0429584f, -0.796804f, -0.843708f, 0.3473f,
2075     -0.843708f, -0.843708f, -0.114439f, 3.14817f,   0.0811934f, -0.843708f
2076   };
2077   float kernel[] = {
2078     0.119643f,    -0.237864f,   0.0462892f,   0.0502297f,   -0.0134528f,
2079     0.146347f,    0.153133f,    0.0513307f,   0.0752369f,   0.0135557f,
2080     -0.111434f,   0.0941854f,   0.0788362f,   0.0299412f,   0.111762f,
2081     0.144066f,    0.00431504f,  -0.0177954f,  0.0738092f,   -0.0344215f,
2082     0.0832582f,   0.053989f,    -0.112691f,   0.0962145f,   0.0186525f,
2083     -0.00660205f, -0.111962f,   -0.126801f,   -0.231625f,   0.17309f,
2084     0.0748875f,   -0.179569f,   -0.00513812f, -0.156579f,   -0.147322f,
2085     0.184168f,    0.189308f,    -0.200359f,   -0.0156733f,  0.140649f,
2086     0.0858496f,   -0.0263217f,  -0.0740749f,  -0.112563f,   0.107528f,
2087     0.0609729f,   -0.221625f,   0.0769944f,   -0.00900815f, -0.00136441f,
2088     -0.0236521f,  -0.0418025f,  -0.00286299f, 0.12241f,     0.0964093f,
2089     -0.0150897f,  0.0532171f,   0.0625916f,   0.116939f,    0.118024f,
2090     0.161918f,    -0.00909767f, 0.100897f,    -0.054563f,   -0.175179f,
2091     -0.0687892f,  0.00734235f,  0.109833f,    -0.113776f,   0.0595405f,
2092     -0.170255f,   0.0124815f,   -0.0363301f,  -0.0127038f,  0.0445554f,
2093     -0.0729894f,  0.107428f,    -0.0341417f,  0.132619f,    0.00984557f,
2094     -0.00443654f, 0.202929f,    0.0945134f,   0.0148725f,   0.00998574f,
2095     -0.0226449f,  0.0478197f,   -0.0793442f,  0.0707599f,   -0.084225f,
2096     0.0865795f,   0.071104f,    -0.047894f,   0.0838322f,   0.0635493f,
2097     -0.00370265f, -0.157247f,   -0.0289622f,  -0.0590963f,  0.13207f,
2098     0.00468011f,  -0.0345372f,  0.217939f,    0.18861f,     -0.0290393f,
2099     -0.0440664f,  0.0126197f,   -0.129132f,   -0.124943f,   0.0968156f,
2100     -0.0853643f,  -0.182305f,   0.00461618f,  -0.147095f,   -0.230282f,
2101     0.00856019f,  0.0278893f,   -0.0300229f,  0.0417871f,   0.0804717f,
2102     -0.0768571f,  -0.0397085f,  -0.0601096f,  0.100901f,    -0.0184926f,
2103     0.0350673f,   0.0971094f,   -0.0171837f,  -0.289644f,   -0.0899041f,
2104     0.08998f,     -0.160319f,   -0.0195103f,  0.0392167f,   -0.137864f,
2105     -0.0136294f,  0.0330886f,   -0.0409244f,  -0.092533f,   -0.0427934f,
2106     -0.191144f,   -0.0969461f,  0.112035f,    0.138611f,    0.128717f,
2107     0.191184f,    0.197462f
2108   };
2109   float bias[] = { 0.186703f, 0.204358f, -0.0230452f };
2110 
2111   float bn_gamma[] = { 1.32173f, 1.26171f, 1.21966f };
2112   float bn_beta[] = { -0.232595f, -0.222652f, -0.232209f };
2113   float bn_mean[] = { 0.329233f, 0.199894f, 0.12389f };
2114   float bn_std[] = { 0.311986f, 0.189737f, 0.247104f };
2115 
2116   CNN_BATCHNORM_PARAMS bn_params = {
2117     bn_gamma,
2118     bn_beta,
2119     bn_mean,
2120     bn_std,
2121   };
2122 
2123   CNN_CONFIG cnn_config = {
2124     1,
2125     0,
2126     0,
2127     0,
2128     0,
2129     {
2130         {
2131             1,
2132             filter_width,
2133             filter_height,
2134             3,
2135             7,
2136             7,
2137             0,
2138             kernel,
2139             bias,
2140             PADDING_VALID,
2141             RELU,
2142             0,
2143             0,
2144             BRANCH_NO_COPY,
2145             BRANCH_NOC,
2146             {},
2147             bn_params,
2148             0,
2149         },
2150     },
2151   };
2152 
2153   CNN_THREAD_DATA thread_data = { 1, NULL };
2154 
2155   RunCNNTest(image_width, image_height, input, expected, &cnn_config,
2156              image_width, &thread_data, MSE_FLOAT_TOL);
2157 }
2158 
TEST_F(CNNTest,TestMultithreading)2159 TEST_F(CNNTest, TestMultithreading) {
2160   int image_height = 2;
2161   int image_width = 2;
2162   int filter_height = 3;
2163   int filter_width = 3;
2164 
2165   float input[] = {
2166     -2,
2167     4,
2168     1,
2169     0,
2170   };
2171 
2172   float weights[] = {
2173     -4, 2, -2, 0,  -4, 4, -3, -3, -3, -1, 1,  0,  -5, -3, 0, -5, 0, 0,
2174     -1, 0, 2,  -5, 0,  1, 4,  2,  1,  0,  -2, -1, -5, -3, 2, -2, 1, -5,
2175   };
2176 
2177   float bias[] = {
2178     -4,
2179     -3,
2180     -2,
2181     3,
2182   };
2183 
2184   float expected[] = {
2185     2, 10, -8, -17, -24, 5, -15, 6, -5, -5, 7, -10, 4, 13, 9, -14,
2186   };
2187 
2188   CNN_CONFIG cnn_config = {
2189     1,
2190     0,
2191     0,
2192     0,
2193     0,
2194     {
2195         {
2196             1,
2197             filter_width,
2198             filter_height,
2199             4,
2200             1,
2201             1,
2202             0,
2203             weights,
2204             bias,
2205             PADDING_SAME_ZERO,
2206             NONE,
2207             0,
2208             0,
2209             BRANCH_NO_COPY,
2210             BRANCH_NOC,
2211             {},
2212             {},
2213             0,
2214         },
2215     },
2216   };
2217 
2218   CNN_THREAD_DATA thread_data = { 1, NULL };
2219 
2220   RunCNNTest(image_width, image_height, input, expected, &cnn_config,
2221              image_width, &thread_data, MSE_FLOAT_TOL);
2222 
2223   const AVxWorkerInterface *const winterface = aom_get_worker_interface();
2224   AVxWorker workers[4];
2225 
2226   for (int i = 0; i < 4; ++i) {
2227     winterface->init(&workers[i]);
2228   }
2229 
2230   thread_data = { 4, workers };
2231 
2232   RunCNNTest(image_width, image_height, input, expected, &cnn_config,
2233              image_width, &thread_data, MSE_FLOAT_TOL);
2234 
2235   for (int i = 0; i < 4; ++i) {
2236     winterface->end(&workers[i]);
2237   }
2238 }
2239 
TEST_F(CNNTest,TestMultiOutput)2240 TEST_F(CNNTest, TestMultiOutput) {
2241   const int image_dim = 8;
2242   const int image_ch = 3;
2243   const int filter_dim = 2;
2244   const int stride = 2;
2245   const int num_filters = 2;
2246 
2247   const float input_[] = {
2248     1.7537929121f,     0.134331551012f,    0.123580039877f,   0.957731845246f,
2249     0.391006834217f,   1.00699352042f,     -0.778177955829f,  -0.814166433059f,
2250     -0.656374394915f,  0.321967305228f,    -2.19455719176f,   0.708035038966f,
2251     0.409148822266f,   -0.318254408902f,   0.152450211189f,   -0.250210793369f,
2252     0.826811563186f,   1.6804156584f,      0.273626975978f,   0.437936241887f,
2253     -0.329935520167f,  -0.288761611645f,   0.156937008304f,   0.271054157295f,
2254     -0.0224828854332f, 1.70110336895f,     -0.989066699309f,  1.30863131729f,
2255     -0.165813705702f,  0.00380178619265f,  -0.0837342367587f, 0.760954783156f,
2256     -0.413610373524f,  1.17968204175f,     0.720295719536f,   0.308718974472f,
2257     -1.10091337671f,   0.693160033687f,    -0.0202862320697f, 1.0221927503f,
2258     -1.24521801881f,   -0.478501952308f,   -1.71648619442f,   -0.182571723636f,
2259     0.339292649504f,   2.0806519131f,      0.967974033444f,   0.175248672328f,
2260     0.0658124561472f,  0.795504169496f,    0.750592557361f,   -1.46631013249f,
2261     -1.79052846838f,   -1.03672179515f,    -0.841985521653f,  1.20995011489f,
2262     0.140859718215f,   -0.651552622661f,   0.451065110806f,   1.1189443693f,
2263     0.100213260593f,   -0.834076868118f,   -1.28734321611f,   1.22064420095f,
2264     -0.364143084361f,  0.750961509335f,    -0.888689074553f,  -0.8253547106f,
2265     -1.21800999027f,   -0.966670603566f,   1.37384014741f,    0.47281264834f,
2266     -0.420416235531f,  0.520163906493f,    0.501296589423f,   1.53418976951f,
2267     0.715234751485f,   0.644551588907f,    0.0763504863375f,  -0.0018541943723f,
2268     0.322853189656f,   -0.795099723224f,   -0.125177096675f,  1.4476577471f,
2269     -0.585888410088f,  -1.44391754955f,    -0.610543221933f,  -0.221859179799f,
2270     0.252060200774f,   -0.86287169623f,    -0.0350246229157f, 1.0932311997f,
2271     0.899464648842f,   -0.468806951704f,   -0.300861137168f,  1.15776414206f,
2272     1.03268544738f,    -0.171579585622f,   -0.179136557119f,  -0.354091003368f,
2273     -0.612298249394f,  -1.20237379258f,    1.54604109659f,    0.130664370287f,
2274     0.885225111868f,   1.0362799581f,      0.980561720868f,   -0.619379186999f,
2275     -1.33818929924f,   -0.237233737961f,   -1.89335425073f,   0.567821011321f,
2276     0.862420368465f,   -1.37380916821f,    0.352190056666f,   0.611261516274f,
2277     0.393237747152f,   0.894686247967f,    0.190405182149f,   0.264872662911f,
2278     -0.0657009133797f, 0.0580512653493f,   -0.401825294366f,  0.4106081318f,
2279     0.49484512188f,    -0.0751103149442f,  -1.43243736382f,   1.79855656009f,
2280     -1.1075351975f,    0.000354882733011f, -0.950716438608f,  1.27129831688f,
2281     1.00495189838f,    0.110358656713f,    1.08315032822f,    -0.972676676218f,
2282     -0.0757668962831f, 1.88932045165f,     -0.0672638136275f, 0.425913010161f,
2283     -0.781540372017f,  0.976000248609f,    0.687218504122f,   1.31374513445f,
2284     -0.932658930672f,  -1.25339468479f,    0.422071294078f,   -0.24189927912f,
2285     0.216906604642f,   -1.88720997548f,    1.99252872889f,    0.353943735777f,
2286     0.737434784132f,   -1.17848645017f,    1.70424254896f,    0.775297112968f,
2287     -0.516392797501f,  0.398130609129f,    0.737248101457f,   0.166282500886f,
2288     1.24699015468f,    0.47116183125f,     1.19091180182f,    -0.372695424578f,
2289     0.219773209389f,   -0.829467838962f,   -0.52533122724f,   1.98707754595f,
2290     0.553692606972f,   -0.933228902369f,   1.55427751643f,    -1.08813399144f,
2291     -0.325686682094f,  0.205091443796f,    -1.70381666435f,   0.466465327942f,
2292     1.73126863447f,    -0.939133672634f,   1.48318077459f,    -0.599414038168f,
2293     -1.1583078687f,    0.518116190201f,    0.133571482458f,   0.84958342672f,
2294     1.02205000597f,    -0.0772082009087f,  -1.69567503859f,   1.4697939436f,
2295     1.67813743122f,    -0.627911582938f,   0.131380509137f,   -1.35717850726f,
2296   };
2297   const float *input[3] = { input_, &input_[image_dim * image_dim],
2298                             &input_[2 * image_dim * image_dim] };
2299 
2300   const float bias[] = { 0.0f, 0.0f };
2301 
2302   const float weights_1[] = {
2303     -0.489547413618f, 0.141916424749f,  -0.279286485585f,  -0.115322211094f,
2304     0.299572786936f,  0.205289980785f,  -0.536254480088f,  -0.253626313744f,
2305     -0.422883815849f, -0.169702966298f, -0.540104704793f,  0.495319646763f,
2306     0.298799079422f,  -0.10054550901f,  -0.306085047056f,  0.171061886165f,
2307     -0.108058703878f, -0.410734629888f, -0.0640674673049f, -0.386524840979f,
2308     -0.157203423678f, -0.362138920529f, -0.216206085209f,  0.147502517971f,
2309   };
2310 
2311   const float weights_2[] = {
2312     0.207580604357f,  0.480821146263f,  -0.29111909562f,   0.47422567493f,
2313     0.206892553253f,  -0.235067084092f, 0.354516800602f,   -0.212399370252f,
2314     -0.419071343731f, -0.050350731631f, -0.0516457320279f, -0.0359310500731f,
2315     0.567044864811f,  -0.060341127522f, 0.0501464839637f,  -0.437785677916f,
2316   };
2317 
2318   const float weights_3[] = {
2319     -0.0690452401448f, -0.356657338763f,   -0.219464031809f, 0.551288365843f,
2320     0.181372090853f,   -0.00245268542109f, 0.409000696276f,  -0.593209108763f,
2321     0.587352566749f,   -0.243720660227f,   0.266232713887f,  -0.00439285245097f,
2322     0.252883228305f,   0.152646192631f,    0.0918944932026f, 0.398853715057f,
2323   };
2324 
2325   const float weights_4[] = {
2326     0.207560791573f,   0.194201350401f,   0.227802322443f,  0.206533663345f,
2327     0.0557331066805f,  0.0224159800424f,  -0.143939197467f, -0.27703361602f,
2328     0.130643888389f,   -0.269456557461f,  0.186242862864f,  -0.162879944774f,
2329     -0.145503996718f,  -0.0768822987581f, -0.203127976359f, -0.238119922873f,
2330     -0.258806479994f,  0.0357957680385f,  -0.1027606976f,   -0.287920082345f,
2331     0.189047820993f,   0.250711538481f,   -0.272815714175f, -0.0431449742024f,
2332     0.207261230996f,   -0.0396472677451f, 0.131236557412f,  0.174291832499f,
2333     -0.251515885765f,  -0.107164007499f,  0.185824534748f,  -0.00561585838161f,
2334     0.273393799578f,   -0.139563699075f,  -0.263922456031f, -0.118859844081f,
2335     0.109230982597f,   -0.170170294794f,  0.0123025648515f, -0.0839368964355f,
2336     -0.0774058234297f, 0.255847138286f,   -0.208430879637f, 0.279170114319f,
2337     -0.272890330712f,  -0.217725903006f,  -0.295923275459f, -0.17008723953f,
2338     -0.284281803405f,  0.281406323629f,   0.266910044663f,  -0.209963914338f,
2339     0.271980962964f,   0.142013581699f,   -0.143896509026f, -0.290509242975f,
2340     -0.305768180935f,  0.196902832117f,   -0.090424189662f, -0.147460802346f,
2341     0.217722016651f,   0.12353848977f,    -0.169177363577f, -0.0454230918512f,
2342   };
2343 
2344   const float expected_0[] = {
2345     -2.04858441055f,  -2.12883075791f,    -0.045177363807f, 0.763949675768f,
2346     -0.544361512821f, -1.58123168032f,    1.89319847039f,   0.16859080901f,
2347     -1.16023321135f,  -0.396988107751f,   1.76637090744f,   -1.40434786514f,
2348     0.908227575669f,  0.817064817605f,    0.215631134908f,  -0.848605613428f,
2349     -0.106756747018f, 0.0193027166685f,   0.801345615113f,  -0.395407237598f,
2350     -1.79983795658f,  -1.73054496242f,    0.0584392594454f, -0.388786095569f,
2351     -0.237269619354f, 0.000843578271263f, -1.24043512104f,  0.487839445893f,
2352     -0.394259726605f, 0.559632843424f,    -0.527224052291f, -1.53792340282f,
2353   };
2354 
2355   const float expected_1[] = {
2356     0.0f, 0.0f,           0.0f, 0.0f, 0.4057888292f, 0.325309571755f,
2357     0.0f, 1.22013465602f,
2358   };
2359 
2360   const float expected_2[] = {
2361     0.156119444687f,
2362     0.517385299817f,
2363   };
2364 
2365   const float expected_3[] = {
2366     0.224177852984f,
2367     0.503384419034f,
2368     0.156119444687f,
2369     0.517385299817f,
2370   };
2371 
2372   const float *expected[] = { expected_0, expected_1, expected_2, expected_3 };
2373 
2374   CNN_CONFIG cnn_config = {
2375     4,  // num_layers
2376     0,  // is_residue
2377     0,  // ext_width
2378     0,  // ext_height
2379     0,  // strict_bounds
2380     {
2381         // layer_config
2382         {
2383             image_ch,           // in_channels
2384             filter_dim,         // filter_width
2385             filter_dim,         // filter_height
2386             num_filters,        // out_channels
2387             stride,             // skip_width
2388             stride,             // skip_height
2389             0,                  // max_pool
2390             weights_1,          // weights
2391             bias,               // bias
2392             PADDING_SAME_ZERO,  // pad
2393             NONE,               // activation
2394             0,                  // deconvolve
2395             0,                  // branch
2396             BRANCH_OUTPUT,      // branch_copy_type
2397             BRANCH_NOC,         // branch_combine_type
2398             { 2, 0, 0 },        // branch_config
2399             {},                 // bn_params
2400             0,                  // output_num
2401         },
2402         {
2403             num_filters,        // in_channels
2404             filter_dim,         // filter_width
2405             filter_dim,         // filter_height
2406             num_filters,        // out_channels
2407             stride,             // skip_width
2408             stride,             // skip_height
2409             0,                  // max_pool
2410             weights_2,          // weights
2411             bias,               // bias
2412             PADDING_SAME_ZERO,  // pad
2413             RELU,               // activation
2414             0,                  // deconvolve
2415             0,                  // branch
2416             BRANCH_NO_COPY,     // branch_copy_type
2417             BRANCH_NOC,         // branch_combine_type
2418             {},                 // branch_config
2419             {},                 // bn_params
2420             1,                  // output_num
2421         },
2422         {
2423             num_filters,        // in_channels
2424             filter_dim,         // filter_width
2425             filter_dim,         // filter_height
2426             num_filters,        // out_channels
2427             stride,             // skip_width
2428             stride,             // skip_height
2429             0,                  // max_pool
2430             weights_3,          // weights
2431             bias,               // bias
2432             PADDING_SAME_ZERO,  // pad
2433             RELU,               // activation
2434             0,                  // deconvolve
2435             0,                  // branch
2436             BRANCH_NO_COPY,     // branch_copy_type
2437             BRANCH_NOC,         // branch_combine_type
2438             {},                 // branch_config
2439             {},                 // bn_params
2440             2,                  // output_num
2441         },
2442         {
2443             num_filters,     // in_channels
2444             2 * filter_dim,  // filter_width
2445             2 * filter_dim,  // filter_height
2446             num_filters,     // out_channels
2447             2 * stride,      // skip_width
2448             2 * stride,      // skip_height
2449             0,               // max_pool
2450             weights_4,       // weights
2451             bias,            // bias
2452             PADDING_VALID,   // pad
2453             RELU,            // activation
2454             0,               // deconvolve
2455             1,               // branch
2456             BRANCH_NO_COPY,  // branch_copy_type
2457             BRANCH_CAT,      // branch_combine_type
2458             { 0, 0, 1 },     // branch_config
2459             {},              // bn_params
2460             3,               // output_num
2461         },
2462     },
2463   };
2464 
2465   CNN_THREAD_DATA thread_data = { 1, NULL };
2466 
2467   const int num_outputs = 4;
2468   const int output_chs[4] = { filter_dim, filter_dim, filter_dim,
2469                               2 * filter_dim };
2470   const int output_dims[4] = { 4, 2, 1, 1 };
2471   const int output_sizes[4] = {
2472     output_chs[0] * output_dims[0] * output_dims[0],
2473     output_chs[1] * output_dims[1] * output_dims[1],
2474     output_chs[2] * output_dims[2] * output_dims[2],
2475     output_chs[3] * output_dims[3] * output_dims[3],
2476   };
2477   float *const output_ = (float *)aom_malloc(
2478       sizeof(*output_) *
2479       (output_sizes[0] + output_sizes[1] + output_sizes[2] + output_sizes[3]));
2480   float *output[CNN_MAX_CHANNELS] = { nullptr };
2481   int ch_ite = 0;
2482   float *output_ite = output_;
2483   for (int output_idx = 0; output_idx < num_outputs; output_idx++) {
2484     for (int channel = 0; channel < output_chs[output_idx]; ++channel) {
2485       output[ch_ite++] = output_ite;
2486       output_ite += output_dims[output_idx] * output_dims[output_idx];
2487     }
2488   }
2489   CNN_MULTI_OUT output_struct = { num_outputs, output_chs, output_dims,
2490                                   output };
2491 
2492   RunMultiOutCNNTest(input, image_dim, image_dim, image_dim, &cnn_config,
2493                      &thread_data, &output_struct, expected, MSE_FLOAT_TOL);
2494 
2495   aom_free(output_);
2496 }
2497