1 /*
2  * Copyright (c) 2019, Alliance for Open Media. All rights reserved
3  *
4  * This source code is subject to the terms of the BSD 2 Clause License and
5  * the Alliance for Open Media Patent License 1.0. If the BSD 2 Clause License
6  * was not distributed with this source code in the LICENSE file, you can
7  * obtain it at www.aomedia.org/license/software. If the Alliance for Open
8  * Media Patent License 1.0 was not distributed with this source code in the
9  * PATENTS file, you can obtain it at www.aomedia.org/license/patent.
10  */
11 
12 #include <assert.h>
13 #include <math.h>
14 
15 #include "aom_dsp/aom_dsp_common.h"
16 #include "av1/encoder/cnn.h"
17 #include "av1/common/av1_common_int.h"
18 
// Clamps index |a| into the range [0, hi - 1] (replicate-padding lookup).
// Both arguments may be evaluated more than once, so never pass expressions
// with side effects.
#define CLAMPINDEX(a, hi) ((a) < 0 ? 0 : ((a) >= (hi) ? ((hi)-1) : (a)))
20 
// Arguments for one worker's share of a layer convolution. Filled in by
// convolve_layer_mt() and unpacked by convolve_layer(), which forwards them
// unchanged to av1_cnn_convolve().
typedef struct {
  const float **input;  // Input channel planes.
  int in_width;
  int in_height;
  int in_stride;
  const CNN_LAYER_CONFIG *layer_config;
  float **output;  // Output channel planes.
  int out_stride;
  int start_idx;  // Worker index; first work item handled by this worker.
  int th_step;    // Number of workers; distance between this worker's items.
} CONVOLVE_OPS;
32 
// Element-wise activation: maps one float to one float (see get_activation()).
typedef float (*activation_fn)(float);
34 
softsign(float x)35 static float softsign(float x) { return x / (float)(fabsf(x) + 1.0); }
36 
relu(float x)37 static float relu(float x) { return (x < 0) ? 0 : x; }
38 
identity(float x)39 static float identity(float x) { return x; }
40 
// A stack of equally sized 2D float planes, one plane per channel.
typedef struct {
  // Number of floats owned through buf[0]. 0 means the tensor does not own its
  // planes (e.g. after assign_tensor() or init_tensor()), and free_tensor()
  // will not release them.
  int allocsize;
  int channels;               // Number of valid entries in buf[].
  int width, height, stride;  // Plane size and row pitch, in floats.
  // buf[c] points at channel c. Tensors sized by realloc_tensor() keep all
  // planes in the single allocation at buf[0], packed back to back.
  float *buf[CNN_MAX_CHANNELS];
} TENSOR;
47 
init_tensor(TENSOR * tensor)48 static void init_tensor(TENSOR *tensor) { memset(tensor, 0, sizeof(*tensor)); }
49 
free_tensor(TENSOR * tensor)50 static void free_tensor(TENSOR *tensor) {
51   if (tensor->allocsize) {
52     aom_free(tensor->buf[0]);
53     tensor->buf[0] = NULL;
54     tensor->allocsize = 0;
55   }
56 }
57 
realloc_tensor(TENSOR * tensor,int channels,int width,int height)58 static void realloc_tensor(TENSOR *tensor, int channels, int width,
59                            int height) {
60   const int newallocsize = channels * width * height;
61   if (tensor->allocsize < newallocsize) {
62     free_tensor(tensor);
63     tensor->buf[0] =
64         (float *)aom_malloc(sizeof(*tensor->buf[0]) * newallocsize);
65     tensor->allocsize = newallocsize;
66   }
67   tensor->width = width;
68   tensor->height = height;
69   tensor->stride = width;
70   tensor->channels = channels;
71   for (int c = 1; c < channels; ++c)
72     tensor->buf[c] = &tensor->buf[0][c * width * height];
73 }
74 
copy_tensor(const TENSOR * src,int copy_channels,int dst_offset,TENSOR * dst)75 static void copy_tensor(const TENSOR *src, int copy_channels, int dst_offset,
76                         TENSOR *dst) {
77   assert(src->width == dst->width);
78   assert(src->height == dst->height);
79   assert(copy_channels <= src->channels);
80   if (src->stride == dst->width && dst->stride == dst->width) {
81     for (int c = 0; c < copy_channels; ++c) {
82       memcpy(dst->buf[dst_offset + c], src->buf[c],
83              sizeof(*dst->buf[0]) * src->width * src->height);
84     }
85   } else {
86     for (int c = 0; c < copy_channels; ++c) {
87       for (int r = 0; r < dst->height; ++r) {
88         memcpy(&dst->buf[dst_offset + c][r * dst->stride],
89                &src->buf[c][r * src->stride],
90                dst->width * sizeof(*dst->buf[c]));
91       }
92     }
93   }
94 }
95 
assign_tensor(TENSOR * tensor,float * buf[CNN_MAX_CHANNELS],int channels,int width,int height,int stride)96 static void assign_tensor(TENSOR *tensor, float *buf[CNN_MAX_CHANNELS],
97                           int channels, int width, int height, int stride) {
98   tensor->allocsize = 0;
99   tensor->channels = channels;
100   tensor->width = width;
101   tensor->height = height;
102   tensor->stride = stride;
103   if (buf) {
104     for (int c = 0; c < channels; ++c) tensor->buf[c] = buf[c];
105   } else {
106     for (int c = 0; c < channels; ++c) tensor->buf[c] = NULL;
107   }
108 }
109 
swap_tensor(TENSOR * t1,TENSOR * t2)110 static void swap_tensor(TENSOR *t1, TENSOR *t2) {
111   TENSOR t = *t1;
112   *t1 = *t2;
113   *t2 = t;
114 }
115 
116 // The concatenated tensor goes into dst with first the channels in
117 // original dst followed by the channels in the src
concat_tensor(const TENSOR * src,TENSOR * dst)118 static void concat_tensor(const TENSOR *src, TENSOR *dst) {
119   assert(src->width == dst->width);
120   assert(src->height == dst->height);
121 
122   const int dst_channels = dst->channels;
123   const int channels = dst->channels + src->channels;
124   const int newallocsize = channels * dst->width * dst->height;
125   if (dst->allocsize < newallocsize) {
126     TENSOR t;
127     init_tensor(&t);
128     // allocate new buffers and copy first the dst channels
129     realloc_tensor(&t, channels, dst->width, dst->height);
130     copy_tensor(dst, dst->channels, 0, &t);
131     // Swap the tensors and free the old buffers
132     swap_tensor(dst, &t);
133     free_tensor(&t);
134   }
135   for (int c = 1; c < channels; ++c)
136     dst->buf[c] = &dst->buf[0][c * dst->width * dst->height];
137   // Copy the channels in src after the first dst_channels channels.
138   copy_tensor(src, src->channels, dst_channels, dst);
139 }
140 
check_tensor_equal_dims(TENSOR * t1,TENSOR * t2)141 int check_tensor_equal_dims(TENSOR *t1, TENSOR *t2) {
142   return (t1->width == t2->width && t1->height == t2->height);
143 }
144 
check_tensor_equal_size(TENSOR * t1,TENSOR * t2)145 int check_tensor_equal_size(TENSOR *t1, TENSOR *t2) {
146   return (t1->channels == t2->channels && t1->width == t2->width &&
147           t1->height == t2->height);
148 }
149 
find_layer_output_size(int in_width,int in_height,const CNN_LAYER_CONFIG * layer_config,int * out_width,int * out_height)150 static void find_layer_output_size(int in_width, int in_height,
151                                    const CNN_LAYER_CONFIG *layer_config,
152                                    int *out_width, int *out_height) {
153   if (!layer_config->deconvolve) {
154     switch (layer_config->pad) {
155       case PADDING_SAME_ZERO:
156       case PADDING_SAME_REPLICATE:
157         *out_width = (in_width + layer_config->skip_width - 1) /
158                      layer_config->skip_width;
159         *out_height = (in_height + layer_config->skip_height - 1) /
160                       layer_config->skip_height;
161         break;
162       case PADDING_VALID:
163         *out_width =
164             (in_width - layer_config->filter_width + layer_config->skip_width) /
165             layer_config->skip_width;
166         *out_height = (in_height - layer_config->filter_height +
167                        layer_config->skip_height) /
168                       layer_config->skip_height;
169         break;
170       default: assert(0 && "Unknown padding type");
171     }
172   } else {
173     switch (layer_config->pad) {
174       case PADDING_SAME_ZERO:
175       case PADDING_SAME_REPLICATE:
176         *out_width = in_width * layer_config->skip_width;
177         *out_height = in_height * layer_config->skip_height;
178         break;
179       case PADDING_VALID:
180         *out_width = (in_width - 1) * layer_config->skip_width +
181                      layer_config->filter_width;
182         *out_height = (in_height - 1) * layer_config->skip_height +
183                       layer_config->filter_height;
184         break;
185       default: assert(0 && "Unknown padding type");
186     }
187   }
188 }
189 
find_cnn_out_channels(const CNN_LAYER_CONFIG * layer_config,int channels_per_branch[])190 void find_cnn_out_channels(const CNN_LAYER_CONFIG *layer_config,
191                            int channels_per_branch[]) {
192   int branch = layer_config->branch;
193   const CNN_BRANCH_CONFIG *branch_config = &layer_config->branch_config;
194   for (int b = 0; b < CNN_MAX_BRANCHES; ++b) {
195     if ((branch_config->input_to_branches & (1 << b)) && b != branch) {
196       if (layer_config->branch_copy_type == BRANCH_INPUT) {
197         channels_per_branch[b] = layer_config->in_channels;
198       } else if (layer_config->branch_copy_type == BRANCH_OUTPUT) {
199         channels_per_branch[b] = layer_config->out_channels;
200       } else if (layer_config->branch_copy_type == BRANCH_COMBINED) {
201         channels_per_branch[b] = layer_config->out_channels;
202         for (int c = 0; c < CNN_MAX_BRANCHES; ++c) {
203           if ((branch_config->branches_to_combine & (1 << c)) && c != branch) {
204             assert(channels_per_branch[c] > 0);
205             channels_per_branch[b] += channels_per_branch[c];
206           }
207         }
208       }
209     }
210   }
211   channels_per_branch[branch] = layer_config->out_channels;
212   for (int c = 0; c < CNN_MAX_BRANCHES; ++c) {
213     if ((branch_config->branches_to_combine & (1 << c)) && c != branch) {
214       assert(channels_per_branch[c] > 0);
215       channels_per_branch[branch] += channels_per_branch[c];
216     }
217   }
218 }
219 
220 #if CONFIG_DEBUG
cnn_has_at_least_one_output(const CNN_CONFIG * cnn_config)221 static INLINE int cnn_has_at_least_one_output(const CNN_CONFIG *cnn_config) {
222   const int num_layers = cnn_config->num_layers;
223   const CNN_LAYER_CONFIG *layer_configs = cnn_config->layer_config;
224 
225   for (int idx = 0; idx < num_layers; idx++) {
226     if (layer_configs[idx].output_num != -1) {
227       return 1;
228     }
229   }
230   return 0;
231 }
232 #endif
233 
av1_find_cnn_output_size(int in_width,int in_height,const CNN_CONFIG * cnn_config,int * out_width,int * out_height,int * out_channels)234 void av1_find_cnn_output_size(int in_width, int in_height,
235                               const CNN_CONFIG *cnn_config, int *out_width,
236                               int *out_height, int *out_channels) {
237   int channels_per_branch[CNN_MAX_BRANCHES] = { 0 };
238   int i_width[CNN_MAX_BRANCHES] = { 0 };
239   int i_height[CNN_MAX_BRANCHES] = { 0 };
240   i_width[0] = in_width + cnn_config->ext_width * 2;
241   i_height[0] = in_height + cnn_config->ext_height * 2;
242 
243 #if CONFIG_DEBUG
244   assert(cnn_has_at_least_one_output(cnn_config));
245 #endif
246 
247   for (int i = 0; i < cnn_config->num_layers; ++i) {
248     const CNN_LAYER_CONFIG *layer_config = &cnn_config->layer_config[i];
249     const CNN_BRANCH_CONFIG *branch_config = &layer_config->branch_config;
250     const int branch = layer_config->branch;
251     int o_width = 0, o_height = 0;
252 
253     if (layer_config->branch_copy_type == BRANCH_INPUT) {
254       for (int b = 0; b < CNN_MAX_BRANCHES; ++b) {
255         if ((branch_config->input_to_branches & (1 << b)) && b != branch) {
256           assert(i_width[branch] > 0 && i_height[branch] > 0);
257           i_width[b] = i_width[branch];
258           i_height[b] = i_height[branch];
259         }
260       }
261     }
262 
263     find_layer_output_size(i_width[branch], i_height[branch], layer_config,
264                            &o_width, &o_height);
265     i_width[branch] = o_width;
266     i_height[branch] = o_height;
267 
268     if (layer_config->branch_copy_type == BRANCH_OUTPUT) {
269       for (int b = 0; b < CNN_MAX_BRANCHES; ++b) {
270         if ((branch_config->input_to_branches & (1 << b)) && b != branch) {
271           i_width[b] = o_width;
272           i_height[b] = o_height;
273         }
274       }
275     }
276 
277     find_cnn_out_channels(layer_config, channels_per_branch);
278 
279     const int output_num = layer_config->output_num;
280     if (output_num != -1) {  // Current layer is an output layer
281       out_width[output_num] = o_width;
282       out_height[output_num] = o_height;
283       out_channels[output_num] = channels_per_branch[layer_config->branch];
284     }
285   }
286 }
287 
get_activation(ACTIVATION layer_activation)288 activation_fn get_activation(ACTIVATION layer_activation) {
289   switch (layer_activation) {
290     case NONE: return identity;
291     case RELU: return relu;
292     case SOFTSIGN: return softsign;
293     case SIGMOID:
294       assert(0 && "Sigmoid has not been supported in CNN.");  // TO DO
295       return NULL;
296     default: assert(0 && "Unknown activation type"); return NULL;
297   }
298 }
299 
get_start_shift_convolve(int width,int filt_width,int stride)300 static INLINE int get_start_shift_convolve(int width, int filt_width,
301                                            int stride) {
302   const int mod = (width % stride);
303   const int filt_off = (filt_width - 1) / 2;
304   const int dif = (mod ? mod - 1 : stride - 1);
305   return AOMMIN((dif + (filt_width % 2)) / 2, filt_off);
306 }
307 
av1_cnn_add_c(float ** output,int channels,int width,int height,int stride,const float ** add)308 void av1_cnn_add_c(float **output, int channels, int width, int height,
309                    int stride, const float **add) {
310   for (int c = 0; c < channels; ++c) {
311     for (int i = 0; i < height; ++i)
312       for (int j = 0; j < width; ++j)
313         output[c][i * stride + j] += add[c][i * stride + j];
314   }
315 }
316 
av1_cnn_activate_c(float ** output,int channels,int width,int height,int stride,ACTIVATION layer_activation)317 void av1_cnn_activate_c(float **output, int channels, int width, int height,
318                         int stride, ACTIVATION layer_activation) {
319   activation_fn activation = get_activation(layer_activation);
320   for (int c = 0; c < channels; ++c) {
321     for (int i = 0; i < height; ++i)
322       for (int j = 0; j < width; ++j)
323         output[c][i * stride + j] = activation(output[c][i * stride + j]);
324   }
325 }
326 
copy_active_tensor_to_branches(const TENSOR * layer_active_tensor,const CNN_LAYER_CONFIG * layer_config,int branch,TENSOR branch_output[])327 static void copy_active_tensor_to_branches(const TENSOR *layer_active_tensor,
328                                            const CNN_LAYER_CONFIG *layer_config,
329                                            int branch, TENSOR branch_output[]) {
330   const CNN_BRANCH_CONFIG *branch_config = &layer_config->branch_config;
331   for (int b = 0; b < CNN_MAX_BRANCHES; ++b) {
332     if ((branch_config->input_to_branches & (1 << b)) && b != branch) {
333       // Copy layer's active tensor to output tensor of branch b if set in
334       // mask. The output becomes the input of the first layer of the branch
335       // because the layer of the branch is not the first layer.
336       int copy_channels = branch_config->channels_to_copy > 0
337                               ? branch_config->channels_to_copy
338                               : layer_active_tensor->channels;
339       realloc_tensor(&branch_output[b], copy_channels,
340                      layer_active_tensor->width, layer_active_tensor->height);
341       copy_tensor(layer_active_tensor, copy_channels, 0, &branch_output[b]);
342     }
343   }
344 }
345 
convolve_layer(void * arg1,void * arg2)346 static int convolve_layer(void *arg1, void *arg2) {
347   const CONVOLVE_OPS *convolve_ops = arg1;
348   (void)arg2;
349   av1_cnn_convolve(
350       convolve_ops->input, convolve_ops->in_width, convolve_ops->in_height,
351       convolve_ops->in_stride, convolve_ops->layer_config, convolve_ops->output,
352       convolve_ops->out_stride, convolve_ops->start_idx, convolve_ops->th_step);
353   return 1;
354 }
355 
convolve_layer_mt(const float ** input,int in_width,int in_height,int in_stride,const CNN_LAYER_CONFIG * layer_config,const CNN_THREAD_DATA * thread_data,float ** output,int out_stride)356 static void convolve_layer_mt(const float **input, int in_width, int in_height,
357                               int in_stride,
358                               const CNN_LAYER_CONFIG *layer_config,
359                               const CNN_THREAD_DATA *thread_data,
360                               float **output, int out_stride) {
361   const AVxWorkerInterface *const winterface = aom_get_worker_interface();
362   const int num_workers = thread_data->num_workers;
363 
364   CONVOLVE_OPS convolve_ops[CNN_MAX_THREADS];
365   for (int th = 0; th < AOMMIN(num_workers, CNN_MAX_THREADS); ++th) {
366     AVxWorker *const worker = &thread_data->workers[th];
367     winterface->reset(worker);
368 
369     CONVOLVE_OPS convolve_op = { input,      in_width,     in_height,
370                                  in_stride,  layer_config, output,
371                                  out_stride, th,           num_workers };
372     convolve_ops[th] = convolve_op;
373     worker->hook = convolve_layer;
374     worker->data1 = &(convolve_ops[th]);
375     worker->data2 = NULL;
376 
377     // Start convolving.
378     if (th == num_workers - 1) {
379       winterface->execute(worker);
380     } else {
381       winterface->launch(worker);
382     }
383   }
384 
385   // Wait until all workers have finished.
386   for (int th = 0; th < AOMMIN(num_workers, CNN_MAX_THREADS); ++th) {
387     winterface->sync(&thread_data->workers[th]);
388   }
389 }
390 
av1_cnn_convolve_c(const float ** input,int in_width,int in_height,int in_stride,const CNN_LAYER_CONFIG * layer_config,float ** output,int out_stride,int start_idx,int step)391 void av1_cnn_convolve_c(const float **input, int in_width, int in_height,
392                         int in_stride, const CNN_LAYER_CONFIG *layer_config,
393                         float **output, int out_stride, int start_idx,
394                         int step) {
395   assert(!layer_config->deconvolve);
396   const int cstep = layer_config->in_channels * layer_config->out_channels;
397   const int filter_height_half = layer_config->filter_height >> 1;
398   const int filter_width_half = layer_config->filter_width >> 1;
399   const int channel_step = AOMMAX(step, 1);
400 
401   if (layer_config->maxpool &&
402       (layer_config->skip_height > 1 || layer_config->skip_width > 1)) {
403     switch (layer_config->pad) {
404       case PADDING_SAME_ZERO:
405         for (int i = 0; i < layer_config->out_channels; ++i) {
406           for (int h = 0, u = 0; h < in_height;
407                h += layer_config->skip_height, ++u) {
408             for (int w = 0, v = 0; w < in_width;
409                  w += layer_config->skip_width, ++v) {
410               for (int hh = h;
411                    hh < AOMMIN(in_height, h + layer_config->skip_height);
412                    ++hh) {
413                 for (int ww = w;
414                      ww < AOMMIN(in_width, w + layer_config->skip_width);
415                      ++ww) {
416                   float sum = layer_config->bias[i];
417                   for (int k = 0; k < layer_config->in_channels; ++k) {
418                     int off = k * layer_config->out_channels + i;
419                     for (int l = 0; l < layer_config->filter_height; ++l) {
420                       const int ii = hh + l - filter_height_half;
421                       for (int m = 0; m < layer_config->filter_width;
422                            ++m, off += cstep) {
423                         const int jj = ww + m - filter_width_half;
424                         if (ii < 0 || ii >= in_height || jj < 0 ||
425                             jj >= in_width)
426                           continue;
427                         sum += layer_config->weights[off] *
428                                input[k][ii * in_stride + jj];
429                       }
430                     }
431                   }
432                   const float a = sum;
433                   if (h == hh && w == ww)
434                     output[i][u * out_stride + v] = a;
435                   else
436                     output[i][u * out_stride + v] =
437                         AOMMAX(output[i][u * out_stride + v], a);
438                 }
439               }
440             }
441           }
442         }
443         break;
444       case PADDING_SAME_REPLICATE:
445         for (int i = 0; i < layer_config->out_channels; ++i) {
446           for (int h = 0, u = 0; h < in_height;
447                h += layer_config->skip_height, ++u) {
448             for (int w = 0, v = 0; w < in_width;
449                  w += layer_config->skip_width, ++v) {
450               for (int hh = h;
451                    hh < AOMMIN(in_height, h + layer_config->skip_height);
452                    ++hh) {
453                 for (int ww = w;
454                      ww < AOMMIN(in_width, w + layer_config->skip_width);
455                      ++ww) {
456                   float sum = layer_config->bias[i];
457                   for (int k = 0; k < layer_config->in_channels; ++k) {
458                     int off = k * layer_config->out_channels + i;
459                     for (int l = 0; l < layer_config->filter_height; ++l) {
460                       const int ii =
461                           CLAMPINDEX(hh + l - filter_height_half, in_height);
462                       for (int m = 0; m < layer_config->filter_width;
463                            ++m, off += cstep) {
464                         const int jj =
465                             CLAMPINDEX(ww + m - filter_width_half, in_width);
466                         assert(ii >= 0 && ii < in_height && jj >= 0 &&
467                                jj < in_width);
468                         sum += layer_config->weights[off] *
469                                input[k][ii * in_stride + jj];
470                       }
471                     }
472                   }
473                   const float a = sum;
474                   if (h == hh && w == ww)
475                     output[i][u * out_stride + v] = a;
476                   else
477                     output[i][u * out_stride + v] =
478                         AOMMAX(output[i][u * out_stride + v], a);
479                 }
480               }
481             }
482           }
483         }
484         break;
485       case PADDING_VALID:
486         for (int i = 0; i < layer_config->out_channels; ++i) {
487           for (int h = 0, u = 0;
488                h < in_height - layer_config->filter_height + 1;
489                h += layer_config->skip_height, ++u) {
490             for (int w = 0, v = 0;
491                  w < in_width - layer_config->filter_width + 1;
492                  w += layer_config->skip_width, ++v) {
493               for (int hh = h;
494                    hh < AOMMIN(in_height, h + layer_config->skip_height);
495                    ++hh) {
496                 for (int ww = w;
497                      ww < AOMMIN(in_width, w + layer_config->skip_width);
498                      ++ww) {
499                   float sum = layer_config->bias[i];
500                   for (int k = 0; k < layer_config->in_channels; ++k) {
501                     int off = k * layer_config->out_channels + i;
502                     for (int l = 0; l < layer_config->filter_height; ++l) {
503                       const int ii = hh + l;
504                       for (int m = 0; m < layer_config->filter_width;
505                            ++m, off += cstep) {
506                         const int jj = ww + m;
507                         assert(ii >= 0 && ii < in_height && jj >= 0 &&
508                                jj < in_width);
509                         sum += layer_config->weights[off] *
510                                input[k][ii * in_stride + jj];
511                       }
512                     }
513                   }
514                   const float a = sum;
515                   if (h == hh && w == ww)
516                     output[i][u * out_stride + v] = a;
517                   else
518                     output[i][u * out_stride + v] =
519                         AOMMAX(output[i][u * out_stride + v], a);
520                 }
521               }
522             }
523           }
524         }
525         break;
526       default: assert(0 && "Unknown padding type");
527     }
528   } else {
529     // Results in element-wise matrix multiplication.
530     if (layer_config->filter_height == 1 && layer_config->filter_width == 1) {
531       const int start_h = get_start_shift_convolve(
532           in_height, layer_config->filter_height, layer_config->skip_height);
533       const int start_w =
534           get_start_shift_convolve(in_width, layer_config->filter_width,
535                                    layer_config->skip_width) +
536           start_idx * layer_config->skip_width;
537       const int out_w_step = AOMMAX(step, 1);
538       const int in_w_step = layer_config->skip_width * out_w_step;
539       for (int i = 0; i < layer_config->out_channels; ++i) {
540         for (int h = start_h, u = 0; h < in_height;
541              h += layer_config->skip_height, ++u) {
542           const int in_h = h * in_stride;
543           const int out_h = u * out_stride + start_idx;
544           for (int w = start_w, out_index = out_h; w < in_width;
545                w += in_w_step, out_index += out_w_step) {
546             float sum = layer_config->bias[i];
547             for (int k = 0; k < layer_config->in_channels; ++k) {
548               sum += layer_config->weights[k * layer_config->out_channels + i] *
549                      input[k][in_h + w];
550             }
551             output[i][out_index] = sum;
552           }
553         }
554       }
555       return;
556     }
557     const int ii_shift =
558         filter_height_half - (layer_config->filter_height - 1) % 2;
559     const int jj_shift =
560         filter_width_half - (layer_config->filter_width - 1) % 2;
561     switch (layer_config->pad) {
562       case PADDING_SAME_ZERO: {
563         const int start_h = get_start_shift_convolve(
564             in_height, layer_config->filter_height, layer_config->skip_height);
565         const int start_w = get_start_shift_convolve(
566             in_width, layer_config->filter_width, layer_config->skip_width);
567         const int end_ii_shift = filter_height_half + 1;
568         const int end_jj_shift = filter_width_half + 1;
569         // *_filter_margin stores the number of pixels along a dimension in the
570         // intersection of the complement of the image in the extended image
571         // and the filter.
572         const int top_filter_margin = layer_config->filter_width * ii_shift;
573         const int right_filter_margin = end_jj_shift - in_width;
574         for (int i = start_idx; i < layer_config->out_channels;
575              i += channel_step) {
576           for (int h = start_h, u = 0; h < in_height;
577                h += layer_config->skip_height, ++u) {
578             const int out_h = u * out_stride;
579             const int top_cstep =
580                 AOMMAX(0, top_filter_margin - h * layer_config->filter_width) *
581                     cstep +
582                 i;
583             const int start_ii = AOMMAX(0, h - ii_shift);
584             const int end_ii = AOMMIN(in_height, h + end_ii_shift);
585             for (int w = start_w, out_index = out_h; w < in_width;
586                  w += layer_config->skip_width, ++out_index) {
587               const int left_cstep = AOMMAX(0, jj_shift - w) * cstep;
588               const int right_cstep =
589                   AOMMAX(0, right_filter_margin + w) * cstep;
590               const int start_jj = AOMMAX(0, w - jj_shift);
591               const int end_jj = AOMMIN(in_width, w + end_jj_shift);
592               float sum = layer_config->bias[i];
593               for (int k = 0; k < layer_config->in_channels; ++k) {
594                 int off = k * layer_config->out_channels + top_cstep;
595                 for (int ii = start_ii; ii < end_ii; ++ii) {
596                   off += left_cstep;
597                   for (int jj = start_jj; jj < end_jj; ++jj, off += cstep) {
598                     sum += layer_config->weights[off] *
599                            input[k][ii * in_stride + jj];
600                   }
601                   off += right_cstep;
602                 }
603               }
604               output[i][out_index] = sum;
605             }
606           }
607         }
608         break;
609       }
610       case PADDING_SAME_REPLICATE: {
611         // h and w are shifted to an offset coordinate system to reduce in-loop
612         // computation.
613         const int start_h =
614             get_start_shift_convolve(in_height, layer_config->filter_height,
615                                      layer_config->skip_height) -
616             ii_shift;
617         const int start_w =
618             get_start_shift_convolve(in_width, layer_config->filter_width,
619                                      layer_config->skip_width) -
620             jj_shift;
621         const int end_h = in_height - ii_shift;
622         const int end_w = in_width - jj_shift;
623         for (int i = start_idx; i < layer_config->out_channels;
624              i += channel_step) {
625           for (int h = start_h, u = 0; h < end_h;
626                h += layer_config->skip_height, ++u) {
627             const int out_h = u * out_stride;
628             const int upper_ii_index = layer_config->filter_height + h;
629             for (int w = start_w, out_index = out_h; w < end_w;
630                  w += layer_config->skip_width, ++out_index) {
631               const int upper_jj_index = layer_config->filter_width + w;
632               float sum = layer_config->bias[i];
633               for (int k = 0; k < layer_config->in_channels; ++k) {
634                 int off = k * layer_config->out_channels + i;
635                 for (int ii = h; ii < upper_ii_index; ++ii) {
636                   const int clamped_ii = CLAMPINDEX(ii, in_height);
637                   for (int jj = w; jj < upper_jj_index; ++jj) {
638                     const int clamped_jj = CLAMPINDEX(jj, in_width);
639                     assert(clamped_ii >= 0 && clamped_ii < in_height &&
640                            clamped_jj >= 0 && clamped_jj < in_width);
641                     sum += layer_config->weights[off] *
642                            input[k][clamped_ii * in_stride + clamped_jj];
643                     off += cstep;
644                   }
645                 }
646               }
647               output[i][out_index] = sum;
648             }
649           }
650         }
651         break;
652       }
653       case PADDING_VALID: {
654         for (int i = start_idx; i < layer_config->out_channels;
655              i += channel_step) {
656           for (int h = 0, u = 0;
657                h < in_height - layer_config->filter_height + 1;
658                h += layer_config->skip_height, ++u) {
659             const int out_h = u * out_stride;
660             const int upper_ii_index = layer_config->filter_height + h;
661             for (int w = 0, out_index = out_h;
662                  w < in_width - layer_config->filter_width + 1;
663                  w += layer_config->skip_width, ++out_index) {
664               const int upper_jj_index = layer_config->filter_width + w;
665               float sum = layer_config->bias[i];
666               for (int k = 0; k < layer_config->in_channels; ++k) {
667                 int off = k * layer_config->out_channels + i;
668                 for (int ii = h; ii < upper_ii_index; ++ii) {
669                   for (int jj = w; jj < upper_jj_index; ++jj) {
670                     assert(ii >= 0 && ii < in_height && jj >= 0 &&
671                            jj < in_width);
672                     sum += layer_config->weights[off] *
673                            input[k][ii * in_stride + jj];
674                     off += cstep;
675                   }
676                 }
677               }
678               output[i][out_index] = sum;
679             }
680           }
681         }
682         break;
683       }
684       default: assert(0 && "Unknown padding type");
685     }
686   }
687 }
688 
get_start_shift_deconvolve(int filt_width,int stride)689 static INLINE int get_start_shift_deconvolve(int filt_width, int stride) {
690   const int dif = AOMMAX(filt_width - stride, 0);
691   return dif / 2;
692 }
693 
// Applies batch normalization in place, channel by channel:
//   x <- gamma[ch] * (x - mean[ch]) / std[ch] + beta[ch]
// for every sample of the width x height region of each channel.
//
// image:  `channels` per-channel buffers; row r of channel ch starts at
//         image[ch] + r * stride. Samples beyond `width` in a row are not
//         touched.
// gamma, beta, mean, std: per-channel parameters, each with at least
//         `channels` entries. std[ch] is used directly as the divisor (no
//         epsilon is added here).
void av1_cnn_batchnorm_c(float **image, int channels, int width, int height,
                         int stride, const float *gamma, const float *beta,
                         const float *mean, const float *std) {
  // All four parameter arrays are dereferenced below, so all four must be
  // checked (previously `beta` was tested twice and `mean` not at all).
  assert(gamma && beta && mean && std && "batchnorm has null parameter!");
  for (int ch = 0; ch < channels; ch++) {
    const float ch_gamma = gamma[ch];
    const float ch_beta = beta[ch];
    const float ch_mean = mean[ch];
    const float ch_std = std[ch];
    float *image_row = image[ch];

    for (int row = 0; row < height; row++) {
      for (int col = 0; col < width; col++) {
        image_row[col] =
            ch_gamma * (image_row[col] - ch_mean) / ch_std + ch_beta;
      }
      image_row += stride;
    }
  }
}
714 
av1_cnn_deconvolve_c(const float ** input,int in_width,int in_height,int in_stride,const CNN_LAYER_CONFIG * layer_config,float ** output,int out_stride)715 void av1_cnn_deconvolve_c(const float **input, int in_width, int in_height,
716                           int in_stride, const CNN_LAYER_CONFIG *layer_config,
717                           float **output, int out_stride) {
718   assert(layer_config->deconvolve);
719 
720   const int cstep = layer_config->in_channels * layer_config->out_channels;
721 
722   int out_width = 0;
723   int out_height = 0;
724   find_layer_output_size(in_width, in_height, layer_config, &out_width,
725                          &out_height);
726   switch (layer_config->pad) {
727     case PADDING_SAME_ZERO:
728       for (int i = 0; i < layer_config->out_channels; ++i) {
729         for (int u = 0; u < out_height; ++u) {
730           for (int v = 0; v < out_width; ++v) {
731             float sum = layer_config->bias[i];
732             for (int k = 0; k < layer_config->in_channels; ++k) {
733               int off = k * layer_config->out_channels + i;
734               for (int l = 0; l < layer_config->filter_height; ++l) {
735                 const int h =
736                     u - l +
737                     get_start_shift_deconvolve(layer_config->filter_height,
738                                                layer_config->skip_height);
739                 for (int m = 0; m < layer_config->filter_width;
740                      ++m, off += cstep) {
741                   const int w =
742                       v - m +
743                       get_start_shift_deconvolve(layer_config->filter_width,
744                                                  layer_config->skip_width);
745                   if ((h % layer_config->skip_height) != 0 ||
746                       (w % layer_config->skip_width) != 0)
747                     continue;
748                   const int ii = h / layer_config->skip_height;
749                   const int jj = w / layer_config->skip_width;
750                   if (ii < 0 || ii >= in_height || jj < 0 || jj >= in_width)
751                     continue;
752                   sum += layer_config->weights[off] *
753                          input[k][ii * in_stride + jj];
754                 }
755               }
756             }
757             output[i][u * out_stride + v] = sum;
758           }
759         }
760       }
761       break;
762     case PADDING_SAME_REPLICATE:
763       for (int i = 0; i < layer_config->out_channels; ++i) {
764         for (int u = 0; u < out_height; ++u) {
765           for (int v = 0; v < out_width; ++v) {
766             float sum = layer_config->bias[i];
767             for (int k = 0; k < layer_config->in_channels; ++k) {
768               int off = k * layer_config->out_channels + i;
769               for (int l = 0; l < layer_config->filter_height; ++l) {
770                 const int h =
771                     u - l +
772                     get_start_shift_deconvolve(layer_config->filter_height,
773                                                layer_config->skip_height);
774                 for (int m = 0; m < layer_config->filter_width;
775                      ++m, off += cstep) {
776                   const int w =
777                       v - m +
778                       get_start_shift_deconvolve(layer_config->filter_width,
779                                                  layer_config->skip_width);
780                   if ((h % layer_config->skip_height) != 0 ||
781                       (w % layer_config->skip_width) != 0)
782                     continue;
783                   const int ii =
784                       CLAMPINDEX(h / layer_config->skip_height, in_height);
785                   const int jj =
786                       CLAMPINDEX(w / layer_config->skip_width, in_width);
787                   assert(ii >= 0 && ii < in_height && jj >= 0 && jj < in_width);
788                   continue;
789                   sum += layer_config->weights[off] *
790                          input[k][ii * in_stride + jj];
791                 }
792               }
793             }
794             output[i][u * out_stride + v] = sum;
795           }
796         }
797       }
798       break;
799     case PADDING_VALID:
800       for (int i = 0; i < layer_config->out_channels; ++i) {
801         for (int u = 0; u < out_height; ++u) {
802           for (int v = 0; v < out_width; ++v) {
803             float sum = layer_config->bias[i];
804             for (int k = 0; k < layer_config->in_channels; ++k) {
805               int off = k * layer_config->out_channels + i;
806               for (int l = 0; l < layer_config->filter_height; ++l) {
807                 const int h = u - l;
808                 for (int m = 0; m < layer_config->filter_width;
809                      ++m, off += cstep) {
810                   const int w = v - m;
811                   if ((h % layer_config->skip_height) != 0 ||
812                       (w % layer_config->skip_width) != 0)
813                     continue;
814                   const int ii = h / layer_config->skip_height;
815                   const int jj = w / layer_config->skip_width;
816                   if (ii < 0 || ii >= in_height || jj < 0 || jj >= in_width)
817                     continue;
818                   sum += layer_config->weights[off] *
819                          input[k][ii * in_stride + jj];
820                 }
821               }
822             }
823             output[i][u * out_stride + v] = sum;
824           }
825         }
826       }
827       break;
828     default: assert(0 && "Unknown padding type");
829   }
830 }
831 
// Runs every layer of cnn_config and writes the output layers' results into
// output_struct.
//
// input, in_width, in_height, in_stride: channel planes consumed by layer 0,
//   which must belong to branch 0 (asserted below).
// thread_data: when num_workers > 1, convolution (non-deconvolve) layers go
//   through convolve_layer_mt(); otherwise av1_cnn_convolve() is called
//   directly. Deconvolution layers always use av1_cnn_deconvolve().
// output_struct: output_buffer is a flat array of channel pointers holding
//   all outputs back to back (output j begins after the output_channels[]
//   entries of outputs 0..j-1); output_strides[j] is output j's row stride.
//   A layer whose output_num is not -1 writes directly into that output.
//
// Each branch b keeps two tensors: tensor1[b] is the current layer's input
// and tensor2[b] its output; they are swapped at the start of every layer
// after the first. free_tensor() is called on every tensor at the end.
// NOTE(review): caller-owned buffers are attached with assign_tensor();
// presumably that leaves allocsize at 0 so free_tensor() skips them — confirm.
av1_cnn_predict_c(const float ** input,int in_width,int in_height,int in_stride,const CNN_CONFIG * cnn_config,const CNN_THREAD_DATA * thread_data,CNN_MULTI_OUT * output_struct)832 void av1_cnn_predict_c(const float **input, int in_width, int in_height,
                       int in_stride, const CNN_CONFIG *cnn_config,
                       const CNN_THREAD_DATA *thread_data,
                       CNN_MULTI_OUT *output_struct) {
  TENSOR tensor1[CNN_MAX_BRANCHES] = { 0 };
  TENSOR tensor2[CNN_MAX_BRANCHES] = { 0 };

  // output[j] points at the first channel pointer of output j inside the flat
  // output_buffer array.
  // NOTE(review): assumes num_outputs <= CNN_MAX_BRANCHES — confirm.
  float **output[CNN_MAX_BRANCHES];
  const int *out_chs = output_struct->output_channels;
  output[0] = output_struct->output_buffer;
  for (int out_idx = 1; out_idx < output_struct->num_outputs; out_idx++) {
    output[out_idx] = output[out_idx - 1] + out_chs[out_idx - 1];
  }

  int i_width = in_width;
  int i_height = in_height;
  int o_width = 0, o_height = 0;
  for (int b = 0; b < CNN_MAX_BRANCHES; ++b) {
    init_tensor(&tensor1[b]);
    init_tensor(&tensor2[b]);
  }

  const int *out_stride = output_struct->output_strides;
  for (int layer = 0; layer < cnn_config->num_layers; ++layer) {
    const CNN_LAYER_CONFIG *layer_config = &cnn_config->layer_config[layer];
    const int branch = layer_config->branch;
    const CNN_BRANCH_CONFIG *branch_config = &layer_config->branch_config;

    // Allocate input tensor
    if (layer == 0) {       // First layer
      assert(branch == 0);  // First layer must be primary branch
      assign_tensor(&tensor1[branch], (float **)input,
                    layer_config->in_channels, in_width, in_height, in_stride);
    } else {  // Non-first layer
      // Swap tensor1 and tensor2: the previous output on this branch becomes
      // this layer's input.
      swap_tensor(&tensor1[branch], &tensor2[branch]);

      i_width = tensor1[branch].width;
      i_height = tensor1[branch].height;
    }

    // Allocate output tensor: a scratch tensor for intermediate layers, or
    // the caller's output buffer for output layers.
    find_layer_output_size(i_width, i_height, layer_config, &o_width,
                           &o_height);
    const int output_num = layer_config->output_num;
    if (output_num == -1) {  // Non-output layer
      realloc_tensor(&tensor2[branch], layer_config->out_channels, o_width,
                     o_height);
    } else {  // Output layer
      free_tensor(&tensor2[branch]);
      assign_tensor(&tensor2[branch], output[output_num],
                    layer_config->out_channels, o_width, o_height,
                    out_stride[output_num]);
    }

    // If we are combining branches make sure that the branch to combine
    // is different from the current branch.
    assert(IMPLIES(layer_config->branch_combine_type != BRANCH_NOC,
                   !(branch_config->branches_to_combine & (1 << branch))));

    // BRANCH_INPUT: hand this layer's input to other branches before it is
    // convolved (see copy_active_tensor_to_branches).
    if (layer_config->branch_copy_type == BRANCH_INPUT) {
      copy_active_tensor_to_branches(&tensor1[branch], layer_config, branch,
                                     tensor2);
    }
    // Check consistency of input and output channels
    assert(tensor1[branch].channels == layer_config->in_channels);
    assert(tensor2[branch].channels == layer_config->out_channels);

    // Convolve/Deconvolve
    if (!cnn_config->layer_config[layer].deconvolve) {
      if (thread_data->num_workers > 1) {
        convolve_layer_mt((const float **)tensor1[branch].buf,
                          tensor1[branch].width, tensor1[branch].height,
                          tensor1[branch].stride, layer_config, thread_data,
                          tensor2[branch].buf, tensor2[branch].stride);
      } else {
        av1_cnn_convolve((const float **)tensor1[branch].buf,
                         tensor1[branch].width, tensor1[branch].height,
                         tensor1[branch].stride, layer_config,
                         tensor2[branch].buf, tensor2[branch].stride, 0, 1);
      }
    } else {
      av1_cnn_deconvolve((const float **)tensor1[branch].buf,
                         tensor1[branch].width, tensor1[branch].height,
                         tensor1[branch].stride, layer_config,
                         tensor2[branch].buf, tensor2[branch].stride);
    }

    // BRANCH_OUTPUT: hand the raw convolution result (before add/activation/
    // batchnorm/concat below) to other branches.
    if (layer_config->branch_copy_type == BRANCH_OUTPUT) {
      copy_active_tensor_to_branches(&tensor2[branch], layer_config, branch,
                                     tensor2);
    }

    // Add tensors from other branches if needed
    if (layer_config->branch_combine_type == BRANCH_ADD) {
      for (int b = 0; b < CNN_MAX_BRANCHES; ++b) {
        if ((branch_config->branches_to_combine & (1 << b)) && b != branch) {
          assert(check_tensor_equal_size(&tensor2[b], &tensor2[branch]));
          av1_cnn_add(tensor2[branch].buf, tensor2[branch].channels,
                      tensor2[branch].width, tensor2[branch].height,
                      tensor2[branch].stride, (const float **)tensor2[b].buf);
        }
      }
    }

    // Non-linearity
    if (layer_config->activation != IDENTITY)
      av1_cnn_activate(tensor2[branch].buf, tensor2[branch].channels,
                       tensor2[branch].width, tensor2[branch].height,
                       tensor2[branch].stride, layer_config->activation);

    // Batch normalization runs after the activation and only when the layer
    // supplies gamma.
    if (layer_config->bn_params.bn_gamma) {
      av1_cnn_batchnorm(
          tensor2[branch].buf, tensor2[branch].channels, tensor2[branch].width,
          tensor2[branch].height, tensor2[branch].stride,
          layer_config->bn_params.bn_gamma, layer_config->bn_params.bn_beta,
          layer_config->bn_params.bn_mean, layer_config->bn_params.bn_std);
    }

    // Concatenate tensors
    if (layer_config->branch_combine_type == BRANCH_CAT) {
      if (output_num == -1) {  // Non-output layer
        for (int b = 0; b < CNN_MAX_BRANCHES; ++b) {
          if ((branch_config->branches_to_combine & (1 << b)) && b != branch) {
            assert(check_tensor_equal_dims(&tensor2[b], &tensor2[branch]));
            assert(tensor2[b].channels > 0);
            concat_tensor(&tensor2[b], &tensor2[branch]);
          }
        }
      } else {  // Output layer
        // The output buffer already holds this layer's own channels. Grow the
        // tensor view over the caller's buffer to the combined channel count,
        // then copy each other branch's channels in after the existing ones.
        const int existing_channels = tensor2[branch].channels;
        int num_chs = existing_channels;
        for (int b = 0; b < CNN_MAX_BRANCHES; ++b) {
          if ((branch_config->branches_to_combine & (1 << b)) && b != branch) {
            assert(check_tensor_equal_dims(&tensor2[b], &tensor2[branch]));
            // Needed only to assign the new channel buffers
            num_chs += tensor2[b].channels;
          }
        }
        assign_tensor(&tensor2[branch], output[output_num], num_chs, o_width,
                      o_height, out_stride[output_num]);

        num_chs = existing_channels;
        for (int b = 0; b < CNN_MAX_BRANCHES; ++b) {
          if ((branch_config->branches_to_combine & (1 << b)) && b != branch) {
            assert(check_tensor_equal_dims(&tensor2[b], &tensor2[branch]));
            // Copy branch b's channels starting at channel index num_chs.
            copy_tensor(&tensor2[b], tensor2[b].channels, num_chs,
                        &tensor2[branch]);
            num_chs += tensor2[b].channels;
          }
        }
      }
    }

    // BRANCH_COMBINED: hand the fully combined result to other branches.
    if (layer_config->branch_copy_type == BRANCH_COMBINED) {
      copy_active_tensor_to_branches(&tensor2[branch], layer_config, branch,
                                     tensor2);
    }
  }

  // Release scratch tensors of every branch.
  for (int b = 0; b < CNN_MAX_BRANCHES; ++b) {
    free_tensor(&tensor1[b]);
    free_tensor(&tensor2[b]);
  }
}
998 
999 // Assume output already has proper allocation
1000 // Assume input image buffers all have same resolution and strides
av1_cnn_predict_img_multi_out(uint8_t ** dgd,int width,int height,int stride,const CNN_CONFIG * cnn_config,const CNN_THREAD_DATA * thread_data,CNN_MULTI_OUT * output)1001 void av1_cnn_predict_img_multi_out(uint8_t **dgd, int width, int height,
1002                                    int stride, const CNN_CONFIG *cnn_config,
1003                                    const CNN_THREAD_DATA *thread_data,
1004                                    CNN_MULTI_OUT *output) {
1005   const float max_val = 255.0;
1006 
1007   const int in_width = width + 2 * cnn_config->ext_width;
1008   const int in_height = height + 2 * cnn_config->ext_height;
1009   const int in_channels = cnn_config->layer_config[0].in_channels;
1010   float *inputs[CNN_MAX_CHANNELS];
1011   float *input_ =
1012       (float *)aom_malloc(in_width * in_height * in_channels * sizeof(*input_));
1013   const int in_stride = in_width;
1014 
1015   for (int c = 0; c < in_channels; ++c) {
1016     inputs[c] = input_ + c * in_stride * in_height;
1017     float *input =
1018         inputs[c] + cnn_config->ext_height * in_stride + cnn_config->ext_width;
1019 
1020     if (cnn_config->strict_bounds) {
1021       for (int i = 0; i < height; ++i)
1022         for (int j = 0; j < width; ++j)
1023           input[i * in_stride + j] = (float)dgd[c][i * stride + j] / max_val;
1024       // extend left and right
1025       for (int i = 0; i < height; ++i) {
1026         for (int j = -cnn_config->ext_width; j < 0; ++j)
1027           input[i * in_stride + j] = input[i * in_stride];
1028         for (int j = width; j < width + cnn_config->ext_width; ++j)
1029           input[i * in_stride + j] = input[i * in_stride + width - 1];
1030       }
1031       // extend top and bottom
1032       for (int i = -cnn_config->ext_height; i < 0; ++i)
1033         memcpy(&input[i * in_stride - cnn_config->ext_width],
1034                &input[-cnn_config->ext_width], in_width * sizeof(*input));
1035       for (int i = height; i < height + cnn_config->ext_height; ++i)
1036         memcpy(&input[i * in_stride - cnn_config->ext_width],
1037                &input[(height - 1) * in_stride - cnn_config->ext_width],
1038                in_width * sizeof(*input));
1039     } else {
1040       for (int i = -cnn_config->ext_height; i < height + cnn_config->ext_height;
1041            ++i)
1042         for (int j = -cnn_config->ext_width; j < width + cnn_config->ext_width;
1043              ++j)
1044           input[i * in_stride + j] = (float)dgd[c][i * stride + j] / max_val;
1045     }
1046   }
1047   av1_cnn_predict((const float **)inputs, in_width, in_height, in_stride,
1048                   cnn_config, thread_data, output);
1049 
1050   aom_free(input_);
1051 }
1052 
1053 // Assume output already has proper allocation
1054 // Assume input image buffers all have same resolution and strides
av1_cnn_predict_img_multi_out_highbd(uint16_t ** dgd,int width,int height,int stride,const CNN_CONFIG * cnn_config,const CNN_THREAD_DATA * thread_data,int bit_depth,CNN_MULTI_OUT * output)1055 void av1_cnn_predict_img_multi_out_highbd(uint16_t **dgd, int width, int height,
1056                                           int stride,
1057                                           const CNN_CONFIG *cnn_config,
1058                                           const CNN_THREAD_DATA *thread_data,
1059                                           int bit_depth,
1060                                           CNN_MULTI_OUT *output) {
1061   const float max_val = (float)((1 << bit_depth) - 1);
1062 
1063   const int in_width = width + 2 * cnn_config->ext_width;
1064   const int in_height = height + 2 * cnn_config->ext_height;
1065   const int in_channels = cnn_config->layer_config[0].in_channels;
1066   float *inputs[CNN_MAX_CHANNELS];
1067   float *input_ =
1068       (float *)aom_malloc(in_width * in_height * in_channels * sizeof(*input_));
1069   const int in_stride = in_width;
1070 
1071   for (int c = 0; c < in_channels; ++c) {
1072     inputs[c] = input_ + c * in_stride * in_height;
1073     float *input =
1074         inputs[c] + cnn_config->ext_height * in_stride + cnn_config->ext_width;
1075 
1076     if (cnn_config->strict_bounds) {
1077       for (int i = 0; i < height; ++i)
1078         for (int j = 0; j < width; ++j)
1079           input[i * in_stride + j] = (float)dgd[c][i * stride + j] / max_val;
1080       // extend left and right
1081       for (int i = 0; i < height; ++i) {
1082         for (int j = -cnn_config->ext_width; j < 0; ++j)
1083           input[i * in_stride + j] = input[i * in_stride];
1084         for (int j = width; j < width + cnn_config->ext_width; ++j)
1085           input[i * in_stride + j] = input[i * in_stride + width - 1];
1086       }
1087       // extend top and bottom
1088       for (int i = -cnn_config->ext_height; i < 0; ++i)
1089         memcpy(&input[i * in_stride - cnn_config->ext_width],
1090                &input[-cnn_config->ext_width], in_width * sizeof(*input));
1091       for (int i = height; i < height + cnn_config->ext_height; ++i)
1092         memcpy(&input[i * in_stride - cnn_config->ext_width],
1093                &input[(height - 1) * in_stride - cnn_config->ext_width],
1094                in_width * sizeof(*input));
1095     } else {
1096       for (int i = -cnn_config->ext_height; i < height + cnn_config->ext_height;
1097            ++i)
1098         for (int j = -cnn_config->ext_width; j < width + cnn_config->ext_width;
1099              ++j)
1100           input[i * in_stride + j] = (float)dgd[c][i * stride + j] / max_val;
1101     }
1102   }
1103 
1104   av1_cnn_predict((const float **)inputs, in_width, in_height, in_stride,
1105                   cnn_config, thread_data, output);
1106 
1107   aom_free(input_);
1108 }
1109 
1110 // Assume output already has proper allocation
1111 // Assume input image buffers all have same resolution and strides
av1_cnn_predict_img(uint8_t ** dgd,int width,int height,int stride,const CNN_CONFIG * cnn_config,const CNN_THREAD_DATA * thread_data,float ** output,int out_stride)1112 void av1_cnn_predict_img(uint8_t **dgd, int width, int height, int stride,
1113                          const CNN_CONFIG *cnn_config,
1114                          const CNN_THREAD_DATA *thread_data, float **output,
1115                          int out_stride) {
1116   int out_width = 0, out_height = 0, out_channels = 0;
1117   av1_find_cnn_output_size(width, height, cnn_config, &out_width, &out_height,
1118                            &out_channels);
1119   const int output_chs[1] = { out_channels };
1120   const int output_strides[1] = { out_stride };
1121   CNN_MULTI_OUT output_struct = { .output_channels = output_chs,
1122                                   .output_strides = output_strides,
1123                                   .output_buffer = output };
1124   av1_cnn_predict_img_multi_out(dgd, width, height, stride, cnn_config,
1125                                 thread_data, &output_struct);
1126 }
1127 
1128 // Assume output already has proper allocation
1129 // Assume input image buffers all have same resolution and strides
av1_cnn_predict_img_highbd(uint16_t ** dgd,int width,int height,int stride,const CNN_CONFIG * cnn_config,const CNN_THREAD_DATA * thread_data,int bit_depth,float ** output,int out_stride)1130 void av1_cnn_predict_img_highbd(uint16_t **dgd, int width, int height,
1131                                 int stride, const CNN_CONFIG *cnn_config,
1132                                 const CNN_THREAD_DATA *thread_data,
1133                                 int bit_depth, float **output, int out_stride) {
1134   int out_width = 0, out_height = 0, out_channels = 0;
1135   av1_find_cnn_output_size(width, height, cnn_config, &out_width, &out_height,
1136                            &out_channels);
1137   const int output_chs[1] = { out_channels };
1138   const int output_strides[1] = { out_stride };
1139   CNN_MULTI_OUT output_struct = { .output_channels = output_chs,
1140                                   .output_strides = output_strides,
1141                                   .output_buffer = output };
1142   av1_cnn_predict_img_multi_out_highbd(dgd, width, height, stride, cnn_config,
1143                                        thread_data, bit_depth, &output_struct);
1144 }
1145