1 /*
2  * Copyright (c) 2019, Alliance for Open Media. All rights reserved
3  *
4  * This source code is subject to the terms of the BSD 2 Clause License and
5  * the Alliance for Open Media Patent License 1.0. If the BSD 2 Clause License
6  * was not distributed with this source code in the LICENSE file, you can
7  * obtain it at www.aomedia.org/license/software. If the Alliance for Open
8  * Media Patent License 1.0 was not distributed with this source code in the
9  * PATENTS file, you can obtain it at www.aomedia.org/license/patent.
10  */
11 
12 #include <float.h>
13 
14 #include "aom_ports/system_state.h"
15 
16 #include "av1/common/enums.h"
17 #include "av1/common/reconinter.h"
18 
19 #include "av1/encoder/encoder.h"
20 #include "av1/encoder/partition_model_weights.h"
21 #include "av1/encoder/partition_strategy.h"
22 #include "av1/encoder/rdopt.h"
23 
24 // Performs a simple_motion_search with a single reference frame and extract
25 // the variance of residues. Here features is assumed to be a length 6 array.
26 // After this function is called, we will store the following in to features:
27 // features[0] = log(1 + dc_q**2/256)
28 // features[1] = log(1 + variance_of_residue)
29 // for i in [2, 3, 4, 5]:
30 //  features[i] = log(1 + variance_of_residue_in_block[i]/variance_of_residue)
get_res_var_features(AV1_COMP * const cpi,MACROBLOCK * x,int mi_row,int mi_col,BLOCK_SIZE bsize,float * features)31 static void get_res_var_features(AV1_COMP *const cpi, MACROBLOCK *x, int mi_row,
32                                  int mi_col, BLOCK_SIZE bsize,
33                                  float *features) {
34   // TODO(chiyotsai@google.com): The data this model trained on did not also use
35   // SIMPLE_TRANSLATION to build the inter_predictor. Retraining and tuning the
36   // model with the correct data should give better performance.
37   assert(mi_size_wide[bsize] == mi_size_high[bsize]);
38 
39   MACROBLOCKD *xd = &x->e_mbd;
40 
41   // Perform a single motion search in Y_PLANE to make a prediction
42   const int use_subpixel = 0;
43 
44   // Start getting the features
45   int f_idx = 0;
46 
47   // Q_INDEX
48   const int dc_q = av1_dc_quant_QTX(x->qindex, 0, xd->bd) >> (xd->bd - 8);
49   aom_clear_system_state();
50   features[f_idx++] = logf(1.0f + (float)(dc_q * dc_q) / 256.0f);
51 
52   // VARIANCE
53   unsigned int sse = 0;
54   unsigned int var = 0;
55   const MV ref_mv_full = { .row = 0, .col = 0 };
56   av1_simple_motion_sse_var(cpi, x, mi_row, mi_col, bsize, ref_mv_full,
57                             use_subpixel, &sse, &var);
58   aom_clear_system_state();
59   features[f_idx++] = logf(1.0f + (float)var);
60 
61   // Regional
62   const uint8_t *src = x->plane[0].src.buf;
63   const int src_stride = x->plane[0].src.stride;
64   const uint8_t *dst = xd->plane[0].dst.buf;
65   const int dst_stride = xd->plane[0].dst.stride;
66   const int bw = block_size_wide[bsize];
67   const int bh = block_size_high[bsize];
68   const BLOCK_SIZE subsize = get_partition_subsize(bsize, PARTITION_SPLIT);
69   int r_idx = 0;
70   for (r_idx = 0; r_idx < 4; r_idx++) {
71     const int x_idx = (r_idx & 1) * bw / 2;
72     const int y_idx = (r_idx >> 1) * bh / 2;
73     const int src_offset = y_idx * src_stride + x_idx;
74     const int dst_offset = y_idx * dst_stride + x_idx;
75     const unsigned int sub_var = cpi->fn_ptr[subsize].vf(
76         src + src_offset, src_stride, dst + dst_offset, dst_stride, &sse);
77     aom_clear_system_state();
78     const float var_ratio = (1.0f + (float)sub_var) / (4.0f + (float)var);
79     features[f_idx++] = var_ratio;
80   }
81 }
82 
// Uses a per-bsize neural network on the residue-variance features from
// get_res_var_features() to restrict the partition search of a square block:
//  - score > thresh: PARTITION_NONE, PARTITION_HORZ, PARTITION_VERT and the
//    rectangular split search are all disabled (split only).
//  - Additionally, when sf.simple_motion_search_split_only >= 2:
//      score < -thresh        : the square split is disabled;
//      score > 0.875 * thresh : PARTITION_NONE is disabled.
// The output flags are only ever cleared, never set. Only square sizes
// 128x128 down to 16x16 (plus 8x8 when CONFIG_DISABLE_FULL_PIXEL_SPLIT_8X8 is
// 0) have a model; for other sizes nothing is changed.
void av1_simple_motion_search_based_split(
    AV1_COMP *const cpi, MACROBLOCK *x, int mi_row, int mi_col,
    BLOCK_SIZE bsize, int *partition_none_allowed, int *partition_horz_allowed,
    int *partition_vert_allowed, int *do_rectangular_split,
    int *do_square_split) {
  const NN_CONFIG *nn_config = NULL;
  float split_only_thresh = 0.0f;
  if (bsize == BLOCK_128X128) {
    nn_config = &av1_simple_motion_search_based_split_nn_config_128;
    split_only_thresh = av1_simple_motion_search_based_split_thresh_128;
  } else if (bsize == BLOCK_64X64) {
    nn_config = &av1_simple_motion_search_based_split_nn_config_64;
    split_only_thresh = av1_simple_motion_search_based_split_thresh_64;
  } else if (bsize == BLOCK_32X32) {
    nn_config = &av1_simple_motion_search_based_split_nn_config_32;
    split_only_thresh = av1_simple_motion_search_based_split_thresh_32;
  } else if (bsize == BLOCK_16X16) {
    nn_config = &av1_simple_motion_search_based_split_nn_config_16;
    split_only_thresh = av1_simple_motion_search_based_split_thresh_16;
  } else if (bsize == BLOCK_8X8) {
    // Disable BLOCK_8X8 for now
#if !CONFIG_DISABLE_FULL_PIXEL_SPLIT_8X8
    nn_config = &av1_simple_motion_search_based_split_nn_config_8;
    split_only_thresh = av1_simple_motion_search_based_split_thresh_8;
#endif
  } else {
    assert(0 && "Unexpected block size in simple_motion_based_split");
  }

  // No model for this block size: leave all partition flags untouched.
  if (!nn_config) return;

  float features[6] = { 0 };
  float score = 0;
  get_res_var_features(cpi, x, mi_row, mi_col, bsize, features);
  av1_nn_predict(features, nn_config, &score);
  // Reset the FP state before doing float comparisons on the score, as the
  // other av1_nn_predict() callers in this file do.
  aom_clear_system_state();

  if (score > split_only_thresh) {
    *partition_none_allowed = 0;
    *partition_horz_allowed = 0;
    *partition_vert_allowed = 0;
    *do_rectangular_split = 0;
  }

  if (cpi->sf.simple_motion_search_split_only >= 2) {
    if (score < -split_only_thresh) *do_square_split = 0;
    // For larger scores (>split_only_thresh), none and rectangular partitions
    // are skipped. As score reduces, possibility of split decreases. Hence
    // for near larger scores (.875 * split_only_thresh to split_only_thresh)
    // none partition is disabled, but rectangular partitions are evaluated
    // additionally.
    if (score > (split_only_thresh * 0.875)) *partition_none_allowed = 0;
  }
}
134 
135 // Given a list of ref frames in refs, performs simple_motion_search on each of
136 // the refs and returns the ref with the smallest sse. Returns -1 if none of the
137 // ref in the list is available. Also stores the best sse and var in best_sse,
138 // best_var, respectively. If save_mv_code is -1, don't update mv_ref_fulls in
139 // pc_tree. If save_mv_code is between 0 and 3, update mv_ref_fulls under
140 // pc_tree->split[i]. If save_mv_code is 4, update mv_ref_fulls under pc_tree.
simple_motion_search_get_best_ref(AV1_COMP * const cpi,MACROBLOCK * x,PC_TREE * pc_tree,int mi_row,int mi_col,BLOCK_SIZE bsize,const int * const refs,int num_refs,int use_subpixel,int save_mv_code,unsigned int * best_sse,unsigned int * best_var)141 static int simple_motion_search_get_best_ref(
142     AV1_COMP *const cpi, MACROBLOCK *x, PC_TREE *pc_tree, int mi_row,
143     int mi_col, BLOCK_SIZE bsize, const int *const refs, int num_refs,
144     int use_subpixel, int save_mv_code, unsigned int *best_sse,
145     unsigned int *best_var) {
146   // TODO(chiyotsai@google.com): The calculation of variance currently uses
147   // bsize, so we might take area outside of the image into account. We need to
148   // modify the SIMD functions to fix this later.
149   const AV1_COMMON *const cm = &cpi->common;
150   int best_ref = -1;
151 
152   if (mi_col >= cm->mi_cols || mi_row >= cm->mi_rows) {
153     // If the whole block is outside of the image, set the var and sse to 0.
154     *best_var = 0;
155     *best_sse = 0;
156 
157     return best_ref;
158   }
159 
160   // Otherwise do loop through the reference frames and find the one with the
161   // minimum SSE
162   const MACROBLOCKD *xd = &x->e_mbd;
163   const MV *mv_ref_fulls = pc_tree->mv_ref_fulls;
164 
165   const int num_planes = 1;
166 
167   *best_sse = INT_MAX;
168 
169   for (int ref_idx = 0; ref_idx < num_refs; ref_idx++) {
170     const int ref = refs[ref_idx];
171 
172     if (cpi->ref_frame_flags & av1_ref_frame_flag_list[ref]) {
173       unsigned int curr_sse = 0, curr_var = 0;
174       av1_simple_motion_search(cpi, x, mi_row, mi_col, bsize, ref,
175                                mv_ref_fulls[ref], num_planes, use_subpixel);
176       curr_var = cpi->fn_ptr[bsize].vf(
177           x->plane[0].src.buf, x->plane[0].src.stride, xd->plane[0].dst.buf,
178           xd->plane[0].dst.stride, &curr_sse);
179       if (curr_sse < *best_sse) {
180         *best_sse = curr_sse;
181         *best_var = curr_var;
182         best_ref = ref;
183       }
184 
185       const int new_mv_row = x->best_mv.as_mv.row / 8;
186       const int new_mv_col = x->best_mv.as_mv.col / 8;
187       if (save_mv_code == 4) {
188         pc_tree->mv_ref_fulls[ref].row = new_mv_row;
189         pc_tree->mv_ref_fulls[ref].col = new_mv_col;
190       } else if (save_mv_code >= 0 && save_mv_code < 4) {
191         // Propagate the new motion vectors to a lower level
192         pc_tree->split[save_mv_code]->mv_ref_fulls[ref].row = new_mv_row;
193         pc_tree->split[save_mv_code]->mv_ref_fulls[ref].col = new_mv_col;
194       } else {
195         assert(save_mv_code == -1 &&
196                "Unknown code in simple_motion_search_get_best_ref.");
197       }
198     }
199   }
200 
201   return best_ref;
202 }
203 
204 // Performs fullpixel simple_motion_search with LAST_FRAME and ALTREF_FRAME on
205 // each subblock and extract the variance and sse of residues. Then store the
206 // var and sse from each partition subblock to features. The DC qindex is also
207 // stored in features.
208 // Here features is assumed to be a length 19 array.
209 // After this function is called, we will store the following to features:
210 // features[0:17] = var and sse from subblocks
211 // features[18] = DC q_index
simple_motion_search_prune_part_features(AV1_COMP * const cpi,MACROBLOCK * x,PC_TREE * pc_tree,int mi_row,int mi_col,BLOCK_SIZE bsize,float * features)212 static void simple_motion_search_prune_part_features(
213     AV1_COMP *const cpi, MACROBLOCK *x, PC_TREE *pc_tree, int mi_row,
214     int mi_col, BLOCK_SIZE bsize, float *features) {
215   // TODO(chiyotsai@google.com): Cache the result of the motion search from the
216   // larger bsize.
217   const int w_mi = mi_size_wide[bsize];
218   const int h_mi = mi_size_high[bsize];
219   int f_idx = 0;
220   assert(mi_size_wide[bsize] == mi_size_high[bsize]);
221   assert(cpi->ref_frame_flags & av1_ref_frame_flag_list[LAST_FRAME] ||
222          cpi->ref_frame_flags & av1_ref_frame_flag_list[ALTREF_FRAME]);
223 
224   // Setting up motion search
225   const int ref_list[] = { LAST_FRAME, ALTREF_FRAME };
226   const int num_refs = 2;
227   const int use_subpixel = 1;
228 
229   unsigned int int_features[FEATURE_SIZE_SMS_PRUNE_PART - 1];
230 
231   // Doing whole block first to update the mv
232   simple_motion_search_get_best_ref(
233       cpi, x, pc_tree, mi_row, mi_col, bsize, ref_list, num_refs, use_subpixel,
234       4, &int_features[f_idx], &int_features[f_idx + 1]);
235   f_idx += 2;
236 
237   // Split subblocks
238   BLOCK_SIZE subsize = get_partition_subsize(bsize, PARTITION_SPLIT);
239   int r_idx = 0;
240   for (r_idx = 0; r_idx < 4; r_idx++) {
241     const int sub_mi_col = mi_col + (r_idx & 1) * w_mi / 2;
242     const int sub_mi_row = mi_row + (r_idx >> 1) * h_mi / 2;
243 
244     simple_motion_search_get_best_ref(
245         cpi, x, pc_tree, sub_mi_row, sub_mi_col, subsize, ref_list, num_refs,
246         use_subpixel, r_idx, &int_features[f_idx], &int_features[f_idx + 1]);
247     f_idx += 2;
248   }
249 
250   // Horz subblocks
251   subsize = get_partition_subsize(bsize, PARTITION_HORZ);
252   for (r_idx = 0; r_idx < 2; r_idx++) {
253     const int sub_mi_col = mi_col + 0;
254     const int sub_mi_row = mi_row + r_idx * h_mi / 2;
255 
256     simple_motion_search_get_best_ref(
257         cpi, x, pc_tree, sub_mi_row, sub_mi_col, subsize, ref_list, num_refs,
258         use_subpixel, -1, &int_features[f_idx], &int_features[f_idx + 1]);
259 
260     f_idx += 2;
261   }
262 
263   // Vert subblock
264   subsize = get_partition_subsize(bsize, PARTITION_VERT);
265   for (r_idx = 0; r_idx < 2; r_idx++) {
266     const int sub_mi_col = mi_col + r_idx * w_mi / 2;
267     const int sub_mi_row = mi_row + 0;
268 
269     simple_motion_search_get_best_ref(
270         cpi, x, pc_tree, sub_mi_row, sub_mi_col, subsize, ref_list, num_refs,
271         use_subpixel, -1, &int_features[f_idx], &int_features[f_idx + 1]);
272 
273     f_idx += 2;
274   }
275 
276   aom_clear_system_state();
277   for (int idx = 0; idx < f_idx; idx++) {
278     features[idx] = logf(1.0f + (float)int_features[idx]);
279   }
280 
281   const MACROBLOCKD *xd = &x->e_mbd;
282   set_offsets_for_motion_search(cpi, x, mi_row, mi_col, bsize);
283 
284   // Q_INDEX
285   const int dc_q = av1_dc_quant_QTX(x->qindex, 0, xd->bd) >> (xd->bd - 8);
286   features[f_idx++] = logf(1.0f + (float)(dc_q * dc_q) / 256.0f);
287 
288   // Neighbor stuff
289   const int has_above = !!xd->above_mbmi;
290   const int has_left = !!xd->left_mbmi;
291   const BLOCK_SIZE above_bsize = has_above ? xd->above_mbmi->sb_type : bsize;
292   const BLOCK_SIZE left_bsize = has_left ? xd->left_mbmi->sb_type : bsize;
293   features[f_idx++] = (float)has_above;
294   features[f_idx++] = (float)mi_size_wide_log2[above_bsize];
295   features[f_idx++] = (float)mi_size_high_log2[above_bsize];
296   features[f_idx++] = (float)has_left;
297   features[f_idx++] = (float)mi_size_wide_log2[left_bsize];
298   features[f_idx++] = (float)mi_size_high_log2[left_bsize];
299 
300   assert(f_idx == FEATURE_SIZE_SMS_PRUNE_PART);
301 }
302 
// Runs a per-bsize neural network on simple motion search features to decide
// whether PARTITION_HORZ and PARTITION_VERT can be pruned for a square block.
//  - Nothing happens when bsize has no model or when both of the model's
//    HORZ and VERT prune thresholds are zero.
//  - Otherwise the raw (un-normalized) features are written to features
//    (FEATURE_SIZE_SMS_PRUNE_PART floats) and *valid is set to 1, so callers
//    can reuse them.
//  - *prune_horz / *prune_vert are then set to whether the predicted class
//    probability is at or below its threshold, but only when
//    sf.simple_motion_search_prune_rect is on, the frame is not intra-only,
//    at least one rectangular partition is still allowed, and the frame is
//    not superres-scaled.
// partition_none_allowed, do_square_split, do_rectangular_split and the
// model's only_thresh are currently unused.
void av1_simple_motion_search_prune_part(
    AV1_COMP *const cpi, MACROBLOCK *x, PC_TREE *pc_tree, int mi_row,
    int mi_col, BLOCK_SIZE bsize, int *partition_none_allowed,
    int *partition_horz_allowed, int *partition_vert_allowed,
    int *do_square_split, int *do_rectangular_split, int *prune_horz,
    int *prune_vert, float *features, int *valid) {
  const AV1_COMMON *const cm = &cpi->common;

  // Select the model parameters for this block size.
  const NN_CONFIG *nn_config = NULL;
  const float *ml_mean = NULL;
  const float *ml_std = NULL;
  const float *prune_thresh = NULL;
  const float *only_thresh = NULL;
  switch (bsize) {
    case BLOCK_128X128:
      nn_config = &av1_simple_motion_search_prune_part_nn_config_128;
      ml_mean = av1_simple_motion_search_prune_part_mean_128;
      ml_std = av1_simple_motion_search_prune_part_std_128;
      prune_thresh = av1_simple_motion_search_prune_part_prune_thresh_128;
      only_thresh = av1_simple_motion_search_prune_part_only_thresh_128;
      break;
    case BLOCK_64X64:
      nn_config = &av1_simple_motion_search_prune_part_nn_config_64;
      ml_mean = av1_simple_motion_search_prune_part_mean_64;
      ml_std = av1_simple_motion_search_prune_part_std_64;
      prune_thresh = av1_simple_motion_search_prune_part_prune_thresh_64;
      only_thresh = av1_simple_motion_search_prune_part_only_thresh_64;
      break;
    case BLOCK_32X32:
      nn_config = &av1_simple_motion_search_prune_part_nn_config_32;
      ml_mean = av1_simple_motion_search_prune_part_mean_32;
      ml_std = av1_simple_motion_search_prune_part_std_32;
      prune_thresh = av1_simple_motion_search_prune_part_prune_thresh_32;
      only_thresh = av1_simple_motion_search_prune_part_only_thresh_32;
      break;
    case BLOCK_16X16:
      nn_config = &av1_simple_motion_search_prune_part_nn_config_16;
      ml_mean = av1_simple_motion_search_prune_part_mean_16;
      ml_std = av1_simple_motion_search_prune_part_std_16;
      prune_thresh = av1_simple_motion_search_prune_part_prune_thresh_16;
      only_thresh = av1_simple_motion_search_prune_part_only_thresh_16;
      break;
    case BLOCK_8X8:
      nn_config = &av1_simple_motion_search_prune_part_nn_config_8;
      ml_mean = av1_simple_motion_search_prune_part_mean_8;
      ml_std = av1_simple_motion_search_prune_part_std_8;
      prune_thresh = av1_simple_motion_search_prune_part_prune_thresh_8;
      only_thresh = av1_simple_motion_search_prune_part_only_thresh_8;
      break;
    default:
      assert(0 && "Unexpected block size in simple_motion_prune_part");
      break;
  }

  // Guard clause: no model, a block too small, or no usable threshold. The
  // NULL check must come first because prune_thresh is NULL without a model.
  if (nn_config == NULL || bsize < BLOCK_8X8 ||
      (prune_thresh[PARTITION_HORZ] == 0.0f &&
       prune_thresh[PARTITION_VERT] == 0.0f)) {
    return;
  }

  // Compute and normalize the features.
  simple_motion_search_prune_part_features(cpi, x, pc_tree, mi_row, mi_col,
                                           bsize, features);
  *valid = 1;

  float normalized[FEATURE_SIZE_SMS_PRUNE_PART] = { 0.0f };
  for (int i = 0; i < FEATURE_SIZE_SMS_PRUNE_PART; ++i) {
    normalized[i] = (features[i] - ml_mean[i]) / ml_std[i];
  }

  // 128x128 and 8x8 blocks only have the basic partition types.
  int class_count = EXT_PARTITION_TYPES;
  if (bsize == BLOCK_128X128 || bsize == BLOCK_8X8) {
    class_count = PARTITION_TYPES;
  }

  // Turn the network output into class probabilities.
  float logits[EXT_PARTITION_TYPES] = { 0.0f };
  float class_probs[EXT_PARTITION_TYPES] = { 0.0f };
  av1_nn_predict(normalized, nn_config, logits);
  aom_clear_system_state();
  av1_nn_softmax(logits, class_probs, class_count);

  // Decide whether rectangular partitions can be pruned.
  const int rect_still_allowed =
      *partition_horz_allowed || *partition_vert_allowed;
  if (cpi->sf.simple_motion_search_prune_rect && !frame_is_intra_only(cm) &&
      rect_still_allowed && bsize >= BLOCK_8X8 && !av1_superres_scaled(cm)) {
    *prune_horz =
        class_probs[PARTITION_HORZ] <= prune_thresh[PARTITION_HORZ];
    *prune_vert =
        class_probs[PARTITION_VERT] <= prune_thresh[PARTITION_VERT];
  }

  // Silence compiler warnings
  (void)only_thresh;
  (void)partition_none_allowed;
  (void)do_square_split;
  (void)do_rectangular_split;
}
394 
395 // Early terminates PARTITION_NONE using simple_motion_search features and the
396 // rate, distortion, and rdcost of PARTITION_NONE. This is only called when:
397 //  - The frame is a show frame
398 //  - The frame is not intra only
399 //  - The current bsize is > BLOCK_8X8
400 //  - blk_row + blk_height/2 < total_rows and blk_col + blk_width/2 < total_cols
av1_simple_motion_search_early_term_none(AV1_COMP * const cpi,MACROBLOCK * x,PC_TREE * pc_tree,int mi_row,int mi_col,BLOCK_SIZE bsize,const RD_STATS * none_rdc,int * early_terminate,float * simple_motion_features,int * simple_motion_features_are_valid)401 void av1_simple_motion_search_early_term_none(
402     AV1_COMP *const cpi, MACROBLOCK *x, PC_TREE *pc_tree, int mi_row,
403     int mi_col, BLOCK_SIZE bsize, const RD_STATS *none_rdc,
404     int *early_terminate, float *simple_motion_features,
405     int *simple_motion_features_are_valid) {
406   // TODO(chiyotsai@google.com): There are other features we can extract from
407   // PARTITION_NONE. Play with this later.
408   int f_idx = 0;
409   if (!*simple_motion_features_are_valid) {
410     simple_motion_search_prune_part_features(cpi, x, pc_tree, mi_row, mi_col,
411                                              bsize, simple_motion_features);
412     *simple_motion_features_are_valid = 1;
413   }
414   f_idx = 25;
415 
416   simple_motion_features[f_idx++] = logf(1.0f + (float)none_rdc->rate);
417   simple_motion_features[f_idx++] = logf(1.0f + (float)none_rdc->dist);
418   simple_motion_features[f_idx++] = logf(1.0f + (float)none_rdc->rdcost);
419 
420   assert(f_idx == FEATURE_SIZE_SMS_TERM_NONE);
421 
422   const float *ml_mean = NULL;
423   const float *ml_std = NULL;
424   const float *ml_model = NULL;
425 
426   if (bsize == BLOCK_128X128) {
427     ml_mean = av1_simple_motion_search_term_none_mean_128;
428     ml_std = av1_simple_motion_search_term_none_std_128;
429     ml_model = av1_simple_motion_search_term_none_model_128;
430   } else if (bsize == BLOCK_64X64) {
431     ml_mean = av1_simple_motion_search_term_none_mean_64;
432     ml_std = av1_simple_motion_search_term_none_std_64;
433     ml_model = av1_simple_motion_search_term_none_model_64;
434   } else if (bsize == BLOCK_32X32) {
435     ml_mean = av1_simple_motion_search_term_none_mean_32;
436     ml_std = av1_simple_motion_search_term_none_std_32;
437     ml_model = av1_simple_motion_search_term_none_model_32;
438   } else if (bsize == BLOCK_16X16) {
439     ml_mean = av1_simple_motion_search_term_none_mean_16;
440     ml_std = av1_simple_motion_search_term_none_std_16;
441     ml_model = av1_simple_motion_search_term_none_model_16;
442   } else {
443     assert(0 && "Unexpected block size in simple_motion_term_none");
444   }
445 
446   if (ml_model) {
447     float score = 0.0f;
448     for (f_idx = 0; f_idx < FEATURE_SIZE_SMS_TERM_NONE; f_idx++) {
449       score += ml_model[f_idx] *
450                (simple_motion_features[f_idx] - ml_mean[f_idx]) / ml_std[f_idx];
451     }
452     score += ml_model[FEATURE_SIZE_SMS_TERM_NONE];
453 
454     if (score >= 0.0f) {
455       *early_terminate = 1;
456     }
457   }
458 }
459 
// Performs full-pixel simple_motion_search against LAST_FRAME and ALTREF_FRAME
// on the whole block and on its four PARTITION_SPLIT quadrants, then writes
// 17 features:
//   features[0..1]   = log(1 + sse), log(1 + var) of the whole block
//   features[2..9]   = the same pair for each quadrant (raster order)
//   features[10]     = log(1 + dc_q^2 / 256), dc_q scaled to 8-bit precision
//   features[11..13] = has_above, mi width log2, mi height log2 of the above
//                      neighbor (bsize's dimensions if there is none)
//   features[14..16] = the same three values for the left neighbor
// The MVs found are cached in pc_tree (whole block and split children).
static void firstpass_simple_motion_search_features(
    AV1_COMP *const cpi, MACROBLOCK *x, PC_TREE *pc_tree, int mi_row,
    int mi_col, BLOCK_SIZE bsize, float *features) {
  assert(mi_size_wide[bsize] == mi_size_high[bsize]);
  assert(cpi->ref_frame_flags & av1_ref_frame_flag_list[LAST_FRAME] ||
         cpi->ref_frame_flags & av1_ref_frame_flag_list[ALTREF_FRAME]);

  const int refs[] = { LAST_FRAME, ALTREF_FRAME };
  const int ref_count = (int)(sizeof(refs) / sizeof(refs[0]));
  const int full_pel_only = 0;

  // (sse, var) pairs: whole block first, then the four split quadrants.
  unsigned int sse_var[10] = { 0 };
  unsigned int *pair = sse_var;

  // Search the whole block first (save code 4) so that its MVs are stored in
  // pc_tree before the quadrants are searched.
  simple_motion_search_get_best_ref(cpi, x, pc_tree, mi_row, mi_col, bsize,
                                    refs, ref_count, full_pel_only, 4,
                                    &pair[0], &pair[1]);
  pair += 2;

  const BLOCK_SIZE quad_bsize = get_partition_subsize(bsize, PARTITION_SPLIT);
  const int half_w_mi = mi_size_wide[bsize] / 2;
  const int half_h_mi = mi_size_high[bsize] / 2;
  for (int quad = 0; quad < 4; ++quad, pair += 2) {
    const int quad_row = mi_row + (quad >> 1) * half_h_mi;
    const int quad_col = mi_col + (quad & 1) * half_w_mi;
    simple_motion_search_get_best_ref(cpi, x, pc_tree, quad_row, quad_col,
                                      quad_bsize, refs, ref_count,
                                      full_pel_only, quad, &pair[0], &pair[1]);
  }

  const int sse_var_count = (int)(pair - sse_var);
  aom_clear_system_state();
  for (int i = 0; i < sse_var_count; ++i) {
    features[i] = logf(1.0f + (float)sse_var[i]);
  }
  int out = sse_var_count;

  // Restore the block's offsets before reading qindex and neighbor info.
  const MACROBLOCKD *xd = &x->e_mbd;
  set_offsets_for_motion_search(cpi, x, mi_row, mi_col, bsize);

  const int dc_q = av1_dc_quant_QTX(x->qindex, 0, xd->bd) >> (xd->bd - 8);
  features[out++] = logf(1.0f + (float)(dc_q * dc_q) / 256.0f);

  // Neighbor shapes; a missing neighbor reports the current bsize.
  const BLOCK_SIZE above_bsize =
      xd->above_mbmi ? xd->above_mbmi->sb_type : bsize;
  const BLOCK_SIZE left_bsize = xd->left_mbmi ? xd->left_mbmi->sb_type : bsize;

  features[out++] = xd->above_mbmi ? 1.0f : 0.0f;
  features[out++] = (float)mi_size_wide_log2[above_bsize];
  features[out++] = (float)mi_size_high_log2[above_bsize];

  features[out++] = xd->left_mbmi ? 1.0f : 0.0f;
  features[out++] = (float)mi_size_wide_log2[left_bsize];
  features[out++] = (float)mi_size_high_log2[left_bsize];
}
519 
// Decides whether the square split of a block can be skipped. A per-bsize
// neural network runs on the full-pixel simple motion search features from
// firstpass_simple_motion_search_features() plus the rate, distortion and
// rdcost of PARTITION_NONE. *do_square_split is cleared when the network score
// is below the model's threshold and left untouched otherwise. Only 32x32,
// 16x16 and 8x8 blocks are supported; other sizes assert and return unchanged.
void av1_firstpass_simple_motion_search_early_term(AV1_COMP *const cpi,
                                                   MACROBLOCK *x,
                                                   PC_TREE *pc_tree, int mi_row,
                                                   int mi_col, BLOCK_SIZE bsize,
                                                   const RD_STATS *none_rdc,
                                                   int *do_square_split) {
  const NN_CONFIG *nn_config = NULL;
  float thresh = 0.0f;
  const float *ml_mean = NULL, *ml_std = NULL;
  if (bsize == BLOCK_32X32) {
    nn_config = &av1_fp_simple_motion_search_term_none_nn_config_32;
    ml_mean = av1_fp_simple_motion_search_term_none_mean_32;
    ml_std = av1_fp_simple_motion_search_term_none_std_32;
    thresh = av1_fp_simple_motion_search_term_none_thresh_32;
  } else if (bsize == BLOCK_16X16) {
    nn_config = &av1_fp_simple_motion_search_term_none_nn_config_16;
    ml_mean = av1_fp_simple_motion_search_term_none_mean_16;
    ml_std = av1_fp_simple_motion_search_term_none_std_16;
    thresh = av1_fp_simple_motion_search_term_none_thresh_16;
  } else if (bsize == BLOCK_8X8) {
    nn_config = &av1_fp_simple_motion_search_term_none_nn_config_8;
    ml_mean = av1_fp_simple_motion_search_term_none_mean_8;
    ml_std = av1_fp_simple_motion_search_term_none_std_8;
    thresh = av1_fp_simple_motion_search_term_none_thresh_8;
  } else {
    assert(0 &&
           "Unexpected bsize in firstpass_simple_motion_search_early_term");
    return;
  }

  float ml_features[FEATURE_SIZE_FP_SMS_TERM_NONE] = { 0.0f };

  firstpass_simple_motion_search_features(cpi, x, pc_tree, mi_row, mi_col,
                                          bsize, ml_features);

  // firstpass_simple_motion_search_features() fills the first 17 entries; the
  // three PARTITION_NONE RD features are appended after them.
  const int num_sms_features = 17;
  int f_idx = num_sms_features;

  aom_clear_system_state();
  ml_features[f_idx++] = logf(1.0f + (float)none_rdc->rate);
  ml_features[f_idx++] = logf(1.0f + (float)none_rdc->dist);
  ml_features[f_idx++] = logf(1.0f + (float)none_rdc->rdcost);
  assert(f_idx == FEATURE_SIZE_FP_SMS_TERM_NONE);

  // Normalize every feature with the model's mean/std.
  for (int i = 0; i < FEATURE_SIZE_FP_SMS_TERM_NONE; i++) {
    ml_features[i] = (ml_features[i] - ml_mean[i]) / ml_std[i];
  }

  // Get probabilities
  float score = 0.0f;

  av1_nn_predict(ml_features, nn_config, &score);
  aom_clear_system_state();

  // Determine if we should prune square partitions.
  if (score < thresh) {
    *do_square_split = 0;
  }
}
575 
// Computes the FEATURE_SIZE_MAX_MIN_PART_PRED (13) superblock-level features
// used to predict the maximum partition size. A full-pixel simple motion
// search (zero start MV) is run on every 16x16 macroblock of the superblock at
// (mi_row, mi_col). The following statistics are written to features, in
// this order:
//   avg log(1 + sse), avg mv col, avg mv row, log(1 + dc_q^2 / 256),
//   max |mv col|, max |mv row|, max log(1 + sse),
//   min |mv col|, min |mv row|, min log(1 + sse),
//   var of log(1 + sse), var of mv col, var of mv row.
// MV components are x->best_mv divided by 8 (truncating), presumably
// converting 1/8-pel units to full pixels. Only 128x128 superblocks are
// supported (asserted).
void av1_get_max_min_partition_features(AV1_COMP *const cpi, MACROBLOCK *x,
                                        int mi_row, int mi_col,
                                        float *features) {
  AV1_COMMON *const cm = &cpi->common;
  MACROBLOCKD *xd = &x->e_mbd;
  const BLOCK_SIZE sb_size = cm->seq_params.sb_size;

  assert(sb_size == BLOCK_128X128);

  int f_idx = 0;

  const int dc_q = av1_dc_quant_QTX(x->qindex, 0, xd->bd) >> (xd->bd - 8);
  aom_clear_system_state();
  const float log_q_sq = logf(1.0f + (float)(dc_q * dc_q) / 256.0f);

  // Perform full-pixel single motion search in Y plane of 16x16 mbs in the sb
  float sum_mv_row_sq = 0;
  float sum_mv_row = 0;
  float min_abs_mv_row = FLT_MAX;
  float max_abs_mv_row = 0;

  float sum_mv_col_sq = 0;
  float sum_mv_col = 0;
  float min_abs_mv_col = FLT_MAX;
  float max_abs_mv_col = 0;

  float sum_log_sse_sq = 0;
  float sum_log_sse = 0;
  float min_log_sse = FLT_MAX;
  float max_log_sse = 0;

  const BLOCK_SIZE mb_size = BLOCK_16X16;
  const int mb_rows = block_size_high[sb_size] / block_size_high[mb_size];
  const int mb_cols = block_size_wide[sb_size] / block_size_wide[mb_size];
  const int mb_in_mi_size_high_log2 = mi_size_high_log2[mb_size];
  const int mb_in_mi_size_wide_log2 = mi_size_wide_log2[mb_size];

  for (int mb_row = 0; mb_row < mb_rows; mb_row++) {
    for (int mb_col = 0; mb_col < mb_cols; mb_col++) {
      const int this_mi_row = mi_row + (mb_row << mb_in_mi_size_high_log2);
      const int this_mi_col = mi_col + (mb_col << mb_in_mi_size_wide_log2);
      unsigned int sse = 0;
      unsigned int var = 0;
      const MV ref_mv_full = { .row = 0, .col = 0 };

      av1_simple_motion_sse_var(cpi, x, this_mi_row, this_mi_col, mb_size,
                                ref_mv_full, 0, &sse, &var);

      aom_clear_system_state();
      const float mv_row = (float)(x->best_mv.as_mv.row / 8);
      const float mv_col = (float)(x->best_mv.as_mv.col / 8);
      const float log_sse = logf(1.0f + (float)sse);
      const float abs_mv_row = fabsf(mv_row);
      const float abs_mv_col = fabsf(mv_col);

      sum_mv_row_sq += mv_row * mv_row;
      sum_mv_row += mv_row;
      sum_mv_col_sq += mv_col * mv_col;
      sum_mv_col += mv_col;

      if (abs_mv_row < min_abs_mv_row) min_abs_mv_row = abs_mv_row;
      if (abs_mv_row > max_abs_mv_row) max_abs_mv_row = abs_mv_row;
      if (abs_mv_col < min_abs_mv_col) min_abs_mv_col = abs_mv_col;
      if (abs_mv_col > max_abs_mv_col) max_abs_mv_col = abs_mv_col;

      sum_log_sse_sq += log_sse * log_sse;
      sum_log_sse += log_sse;
      if (log_sse < min_log_sse) min_log_sse = log_sse;
      if (log_sse > max_log_sse) max_log_sse = log_sse;
    }
  }
  aom_clear_system_state();

  // Average over the number of macroblocks actually visited instead of a
  // hard-coded 64, which only holds for 16x16 MBs in a 128x128 superblock.
  const float num_mbs = (float)(mb_rows * mb_cols);

  // Population variance: E[v^2] - E[v]^2.
  const float avg_mv_row = sum_mv_row / num_mbs;
  const float var_mv_row = sum_mv_row_sq / num_mbs - avg_mv_row * avg_mv_row;

  const float avg_mv_col = sum_mv_col / num_mbs;
  const float var_mv_col = sum_mv_col_sq / num_mbs - avg_mv_col * avg_mv_col;

  const float avg_log_sse = sum_log_sse / num_mbs;
  const float var_log_sse =
      sum_log_sse_sq / num_mbs - avg_log_sse * avg_log_sse;

  features[f_idx++] = avg_log_sse;
  features[f_idx++] = avg_mv_col;
  features[f_idx++] = avg_mv_row;
  features[f_idx++] = log_q_sq;
  features[f_idx++] = max_abs_mv_col;
  features[f_idx++] = max_abs_mv_row;
  features[f_idx++] = max_log_sse;
  features[f_idx++] = min_abs_mv_col;
  features[f_idx++] = min_abs_mv_row;
  features[f_idx++] = min_log_sse;
  features[f_idx++] = var_log_sse;
  features[f_idx++] = var_mv_col;
  features[f_idx++] = var_mv_row;

  assert(f_idx == FEATURE_SIZE_MAX_MIN_PART_PRED);
}
672 
av1_predict_max_partition(AV1_COMP * const cpi,MACROBLOCK * const x,const float * features)673 BLOCK_SIZE av1_predict_max_partition(AV1_COMP *const cpi, MACROBLOCK *const x,
674                                      const float *features) {
675   float scores[MAX_NUM_CLASSES_MAX_MIN_PART_PRED] = { 0.0f },
676         probs[MAX_NUM_CLASSES_MAX_MIN_PART_PRED] = { 0.0f };
677   const NN_CONFIG *nn_config = &av1_max_part_pred_nn_config;
678 
679   assert(cpi->sf.auto_max_partition_based_on_simple_motion != NOT_IN_USE);
680 
681   aom_clear_system_state();
682   av1_nn_predict(features, nn_config, scores);
683   av1_nn_softmax(scores, probs, MAX_NUM_CLASSES_MAX_MIN_PART_PRED);
684 
685   int result = MAX_NUM_CLASSES_MAX_MIN_PART_PRED - 1;
686   if (cpi->sf.auto_max_partition_based_on_simple_motion == DIRECT_PRED) {
687     result = 0;
688     float max_prob = probs[0];
689     for (int i = 1; i < MAX_NUM_CLASSES_MAX_MIN_PART_PRED; ++i) {
690       if (probs[i] > max_prob) {
691         max_prob = probs[i];
692         result = i;
693       }
694     }
695   } else if (cpi->sf.auto_max_partition_based_on_simple_motion ==
696              RELAXED_PRED) {
697     for (result = MAX_NUM_CLASSES_MAX_MIN_PART_PRED - 1; result >= 0;
698          --result) {
699       if (result < MAX_NUM_CLASSES_MAX_MIN_PART_PRED - 1) {
700         probs[result] += probs[result + 1];
701       }
702       if (probs[result] > 0.2) break;
703     }
704   } else if (cpi->sf.auto_max_partition_based_on_simple_motion == ADAPT_PRED) {
705     const BLOCK_SIZE sb_size = cpi->common.seq_params.sb_size;
706     MACROBLOCKD *const xd = &x->e_mbd;
707     // TODO(debargha): x->source_variance is unavailable at this point,
708     // so compute. The redundant recomputation later can be removed.
709     const unsigned int source_variance =
710         is_cur_buf_hbd(xd)
711             ? av1_high_get_sby_perpixel_variance(cpi, &x->plane[0].src, sb_size,
712                                                  xd->bd)
713             : av1_get_sby_perpixel_variance(cpi, &x->plane[0].src, sb_size);
714     if (source_variance > 16) {
715       const double thresh = source_variance < 128 ? 0.05 : 0.1;
716       for (result = MAX_NUM_CLASSES_MAX_MIN_PART_PRED - 1; result >= 0;
717            --result) {
718         if (result < MAX_NUM_CLASSES_MAX_MIN_PART_PRED - 1) {
719           probs[result] += probs[result + 1];
720         }
721         if (probs[result] > thresh) break;
722       }
723     }
724   }
725 
726   return (BLOCK_SIZE)((result + 2) * 3);
727 }
728